diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/BaseJdbcClient.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/BaseJdbcClient.java index 08cf2bc48b19..b4d1201c44d4 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/BaseJdbcClient.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/BaseJdbcClient.java @@ -420,6 +420,32 @@ public Connection getConnection(ConnectorSession session, JdbcSplit split) return connection; } + @Override + public PreparedQuery prepareQuery( + ConnectorSession session, + JdbcTableHandle table, + Optional>> groupingSets, + List columns, + Map columnExpressions) + { + try (Connection connection = connectionFactory.openConnection(session)) { + PreparedQuery preparedQuery = new QueryBuilder(this).prepareQuery( + session, + connection, + table.getRelationHandle(), + groupingSets, + columns, + columnExpressions, + table.getConstraint(), + Optional.empty()); + preparedQuery = preparedQuery.transformQuery(tryApplyLimit(table.getLimit())); + return preparedQuery; + } + catch (SQLException e) { + throw new TrinoException(JDBC_ERROR, e); + } + } + @Override public PreparedStatement buildSql(ConnectorSession session, Connection connection, JdbcSplit split, JdbcTableHandle table, List columns) throws SQLException @@ -428,9 +454,10 @@ public PreparedStatement buildSql(ConnectorSession session, Connection connectio PreparedQuery preparedQuery = queryBuilder.prepareQuery( session, connection, - table.getRemoteTableName(), - table.getGroupingSets(), + table.getRelationHandle(), + Optional.empty(), columns, + ImmutableMap.of(), table.getConstraint(), split.getAdditionalPredicate()); preparedQuery = preparedQuery.transformQuery(tryApplyLimit(table.getLimit())); diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/CachingJdbcClient.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/CachingJdbcClient.java index 51f63de3cec3..81c450c9cf4b 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/CachingJdbcClient.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/CachingJdbcClient.java @@ -190,6 +190,17 @@ public void abortReadConnection(Connection connection) delegate.abortReadConnection(connection); } + @Override + public PreparedQuery prepareQuery( + ConnectorSession session, + JdbcTableHandle table, + Optional>> groupingSets, + List columns, + Map columnExpressions) + { + return delegate.prepareQuery(session, table, groupingSets, columns, columnExpressions); + } + @Override public PreparedStatement buildSql(ConnectorSession session, Connection connection, JdbcSplit split, JdbcTableHandle table, List columns) throws SQLException diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/ForwardingJdbcClient.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/ForwardingJdbcClient.java index 11353a36f668..fa2dc5a5d52e 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/ForwardingJdbcClient.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/ForwardingJdbcClient.java @@ -141,6 +141,17 @@ public void abortReadConnection(Connection connection) delegate().abortReadConnection(connection); } + @Override + public PreparedQuery prepareQuery( + ConnectorSession session, + JdbcTableHandle table, + Optional>> groupingSets, + List columns, + Map columnExpressions) + { + return delegate().prepareQuery(session, table, groupingSets, columns, columnExpressions); + } + @Override public PreparedStatement buildSql(ConnectorSession session, Connection connection, JdbcSplit split, JdbcTableHandle tableHandle, List columnHandles) throws SQLException diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcClient.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcClient.java index 044462fe0fae..bfa12310576d 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcClient.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcClient.java @@ -91,6 +91,13 @@ default void abortReadConnection(Connection connection) // most drivers do not need this } + PreparedQuery prepareQuery( + ConnectorSession session, + JdbcTableHandle table, + Optional>> groupingSets, + List columns, + Map columnExpressions); + PreparedStatement buildSql(ConnectorSession session, Connection connection, JdbcSplit split, JdbcTableHandle table, List columns) throws SQLException; diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcColumnHandle.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcColumnHandle.java index 1f74e888930a..22b7231c1590 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcColumnHandle.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcColumnHandle.java @@ -14,7 +14,6 @@ package io.trino.plugin.jdbc; import com.fasterxml.jackson.annotation.JsonCreator; -import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.base.Joiner; import io.trino.spi.connector.ColumnHandle; @@ -23,14 +22,12 @@ import java.util.Objects; import java.util.Optional; -import java.util.function.Function; import static java.util.Objects.requireNonNull; public final class JdbcColumnHandle implements ColumnHandle { - private final Optional expression; private final String columnName; private final JdbcTypeHandle jdbcTypeHandle; private final Type columnType; @@ -40,7 +37,7 @@ public final class JdbcColumnHandle // All and only required fields public JdbcColumnHandle(String columnName, JdbcTypeHandle jdbcTypeHandle, Type columnType) { - this(Optional.empty(), columnName, jdbcTypeHandle, columnType, true, Optional.empty()); + this(columnName, jdbcTypeHandle, columnType, true, Optional.empty()); } /** @@ -49,7 +46,7 @@ public JdbcColumnHandle(String columnName, JdbcTypeHandle jdbcTypeHandle, Type c @Deprecated public JdbcColumnHandle(String columnName, JdbcTypeHandle jdbcTypeHandle, Type columnType, boolean nullable) { - this(Optional.empty(), columnName, jdbcTypeHandle, columnType, nullable, Optional.empty()); + this(columnName, jdbcTypeHandle, columnType, nullable, Optional.empty()); } /** @@ -58,14 +55,12 @@ public JdbcColumnHandle(String columnName, JdbcTypeHandle jdbcTypeHandle, Type c @Deprecated @JsonCreator public JdbcColumnHandle( - @JsonProperty("expression") Optional expression, @JsonProperty("columnName") String columnName, @JsonProperty("jdbcTypeHandle") JdbcTypeHandle jdbcTypeHandle, @JsonProperty("columnType") Type columnType, @JsonProperty("nullable") boolean nullable, @JsonProperty("comment") Optional comment) { - this.expression = requireNonNull(expression, "expression is null"); this.columnName = requireNonNull(columnName, "columnName is null"); this.jdbcTypeHandle = requireNonNull(jdbcTypeHandle, "jdbcTypeHandle is null"); this.columnType = requireNonNull(columnType, "columnType is null"); @@ -73,12 +68,6 @@ public JdbcColumnHandle( this.comment = requireNonNull(comment, "comment is null"); } - @JsonProperty - public Optional getExpression() - { - return expression; - } - @JsonProperty public String getColumnName() { @@ -109,12 +98,6 @@ public Optional getComment() return comment; } - @JsonIgnore - public boolean isSynthetic() - { - return expression.isPresent(); - } - public ColumnMetadata getColumnMetadata() { return ColumnMetadata.builder() @@ -148,19 +131,11 @@ public int hashCode() public String toString() { return Joiner.on(":").skipNulls().join( - expression.orElse(null), columnName, columnType.getDisplayName(), jdbcTypeHandle.getJdbcTypeName().orElse(null)); } - public String toSqlExpression(Function identifierQuote) - { - requireNonNull(identifierQuote, "identifierQuote is null"); - return expression - .orElseGet(() -> identifierQuote.apply(columnName)); - } - public static Builder builder() { return new Builder(); @@ -173,7 +148,6 @@ public static Builder builderFrom(JdbcColumnHandle handle) public static final class Builder { - private Optional expression = Optional.empty(); private String columnName; private JdbcTypeHandle jdbcTypeHandle; private Type columnType; @@ -184,7 +158,6 @@ public Builder() {} private Builder(JdbcColumnHandle handle) { - this.expression = handle.getExpression(); this.columnName = handle.getColumnName(); this.jdbcTypeHandle = handle.getJdbcTypeHandle(); this.columnType = handle.getColumnType(); @@ -192,12 +165,6 @@ private Builder(JdbcColumnHandle handle) this.comment = handle.getComment(); } - public Builder setExpression(Optional expression) - { - this.expression = expression; - return this; - } - public Builder setColumnName(String columnName) { this.columnName = columnName; @@ -231,7 +198,6 @@ public Builder setComment(Optional comment) public JdbcColumnHandle build() { return new JdbcColumnHandle( - expression, columnName, jdbcTypeHandle, columnType, diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcMetadata.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcMetadata.java index a75d55c638d2..4cdf41aca08a 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcMetadata.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcMetadata.java @@ -56,7 +56,6 @@ import java.util.Map; import java.util.Optional; import java.util.OptionalLong; -import java.util.Set; import java.util.concurrent.atomic.AtomicReference; import static com.google.common.base.Functions.identity; @@ -71,7 +70,7 @@ public class JdbcMetadata implements ConnectorMetadata { - private static final String SYNTHETIC_COLUMN_NAME_PREFIX = "_presto_generated_"; + private static final String SYNTHETIC_COLUMN_NAME_PREFIX = "_pfgnrtd_"; private final JdbcClient jdbcClient; private final boolean allowDropTable; @@ -114,21 +113,6 @@ public Optional> applyFilter(C { JdbcTableHandle handle = (JdbcTableHandle) table; - if (handle.getGroupingSets().isPresent()) { - if (constraint.getSummary().isNone()) { - return Optional.empty(); - } - - Set constraintColumns = constraint.getSummary().getDomains().orElseThrow().keySet(); - List> groupingSets = handle.getGroupingSets().get(); - boolean canPushDown = groupingSets.stream() - .allMatch(groupingSet -> ImmutableSet.copyOf(groupingSet).containsAll(constraintColumns)); - - if (!canPushDown) { - return Optional.empty(); - } - } - TupleDomain oldDomain = handle.getConstraint(); TupleDomain newDomain = oldDomain.intersect(constraint.getSummary()); @@ -166,16 +150,25 @@ public Optional> applyFilter(C } handle = new JdbcTableHandle( - handle.getSchemaTableName(), - handle.getRemoteTableName(), + handle.getRelationHandle(), newDomain, - handle.getGroupingSets(), handle.getLimit(), handle.getColumns()); return Optional.of(new ConstraintApplicationResult<>(handle, remainingFilter)); } + private JdbcTableHandle flushAttributesAsQuery(ConnectorSession session, JdbcTableHandle handle) + { + List columns = jdbcClient.getColumns(session, handle); + PreparedQuery preparedQuery = jdbcClient.prepareQuery(session, handle, Optional.empty(), columns, ImmutableMap.of()); + return new JdbcTableHandle( + new JdbcQueryRelationHandle(preparedQuery), + TupleDomain.all(), + OptionalLong.empty(), + Optional.of(columns)); + } + @Override public Optional> applyProjection( ConnectorSession session, @@ -195,10 +188,8 @@ public Optional> applyProjecti return Optional.of(new ProjectionApplicationResult<>( new JdbcTableHandle( - handle.getSchemaTableName(), - handle.getRemoteTableName(), + handle.getRelationHandle(), handle.getConstraint(), - handle.getGroupingSets(), handle.getLimit(), Optional.of(newColumns)), projections, @@ -224,13 +215,10 @@ public Optional> applyAggrega JdbcTableHandle handle = (JdbcTableHandle) table; - if (handle.getLimit().isPresent()) { - // handle's limit is applied after aggregations, so we cannot apply aggregations if limit is already set - return Optional.empty(); - } + // Global aggregation is represented by [[]] + verify(!groupingSets.isEmpty(), "No grouping sets provided"); - if (handle.getGroupingSets().isPresent()) { - // table handle cannot express aggregation on top of aggregation + if (groupingSets.size() > 1 && !jdbcClient.supportsGroupingSets()) { return Optional.empty(); } @@ -239,11 +227,8 @@ public Optional> applyAggrega return Optional.empty(); } - // Global aggregation is represented by [[]] - verify(!groupingSets.isEmpty(), "No grouping sets provided"); - - if (groupingSets.size() > 1 && !jdbcClient.supportsGroupingSets()) { - return Optional.empty(); + if (handle.getLimit().isPresent()) { + handle = flushAttributesAsQuery(session, handle); } List columns = jdbcClient.getColumns(session, handle); @@ -255,6 +240,7 @@ public Optional> applyAggrega ImmutableList.Builder newColumns = ImmutableList.builder(); ImmutableList.Builder projections = ImmutableList.builder(); ImmutableList.Builder resultAssignments = ImmutableList.builder(); + ImmutableMap.Builder expressions = ImmutableMap.builder(); for (AggregateFunction aggregate : aggregates) { Optional expression = jdbcClient.implementAggregation(session, aggregate, assignments); if (expression.isEmpty()) { @@ -265,31 +251,47 @@ public Optional> applyAggrega syntheticNextIdentifier++; } + String columnName = SYNTHETIC_COLUMN_NAME_PREFIX + syntheticNextIdentifier; JdbcColumnHandle newColumn = JdbcColumnHandle.builder() - .setExpression(Optional.of(expression.get().getExpression())) - .setColumnName(SYNTHETIC_COLUMN_NAME_PREFIX + syntheticNextIdentifier) + .setColumnName(columnName) .setJdbcTypeHandle(expression.get().getJdbcTypeHandle()) .setColumnType(aggregate.getOutputType()) .setComment(Optional.of("synthetic")) .build(); - syntheticNextIdentifier++; newColumns.add(newColumn); projections.add(new Variable(newColumn.getColumnName(), aggregate.getOutputType())); resultAssignments.add(new Assignment(newColumn.getColumnName(), newColumn, aggregate.getOutputType())); + expressions.put(columnName, expression.get().getExpression()); + + syntheticNextIdentifier++; } + List> groupingSetsAsJdbcColumnHandles = groupingSets.stream() + .map(groupingSet -> groupingSet.stream() + .map(JdbcColumnHandle.class::cast) + .collect(toImmutableList())) + .collect(toImmutableList()); + + List newColumnsList = newColumns.build(); + + PreparedQuery preparedQuery = jdbcClient.prepareQuery( + session, + handle, + Optional.of(groupingSetsAsJdbcColumnHandles), + ImmutableList.builder() + .addAll(groupingSetsAsJdbcColumnHandles.stream() + .flatMap(List::stream) + .distinct() + .iterator()) + .addAll(newColumnsList) + .build(), + expressions.build()); handle = new JdbcTableHandle( - handle.getSchemaTableName(), - handle.getRemoteTableName(), - handle.getConstraint(), - Optional.of(groupingSets.stream() - .map(groupingSet -> groupingSet.stream() - .map(JdbcColumnHandle.class::cast) - .collect(toImmutableList())) - .collect(toImmutableList())), - OptionalLong.empty(), // limit - Optional.of(newColumns.build())); + new JdbcQueryRelationHandle(preparedQuery), + TupleDomain.all(), + OptionalLong.empty(), + Optional.of(newColumnsList)); return Optional.of(new AggregationApplicationResult<>(handle, projections.build(), resultAssignments.build(), ImmutableMap.of())); } @@ -308,10 +310,8 @@ public Optional> applyLimit(Connect } handle = new JdbcTableHandle( - handle.getSchemaTableName(), - handle.getRemoteTableName(), + handle.getRelationHandle(), handle.getConstraint(), - handle.getGroupingSets(), OptionalLong.of(limit), handle.getColumns()); @@ -346,7 +346,11 @@ public ConnectorTableMetadata getTableMetadata(ConnectorSession session, Connect for (JdbcColumnHandle column : jdbcClient.getColumns(session, handle)) { columnMetadata.add(column.getColumnMetadata()); } - return new ConnectorTableMetadata(handle.getSchemaTableName(), columnMetadata.build(), jdbcClient.getTableProperties(session, handle)); + SchemaTableName schemaTableName = handle.isNamedRelation() + ? handle.getSchemaTableName() + // TODO (https://github.com/trinodb/trino/issues/6694) SchemaTableName should not be required for synthetic ConnectorTableHandle + : new SchemaTableName("_prepared", "query"); + return new ConnectorTableMetadata(schemaTableName, columnMetadata.build(), jdbcClient.getTableProperties(session, handle)); } @Override @@ -436,7 +440,6 @@ public ConnectorInsertTableHandle beginInsert(ConnectorSession session, Connecto verify(!((JdbcTableHandle) tableHandle).isSynthetic(), "Not a table reference: %s", tableHandle); List columnHandles = columns.stream() .map(JdbcColumnHandle.class::cast) - .peek(columnHandle -> verify(!columnHandle.isSynthetic(), "Not a column reference: %s", columnHandle)) .collect(toImmutableList()); JdbcOutputTableHandle handle = jdbcClient.beginInsertTable(session, (JdbcTableHandle) tableHandle, columnHandles); setRollback(() -> jdbcClient.rollbackCreateTable(session, handle)); @@ -463,7 +466,6 @@ public void setColumnComment(ConnectorSession session, ConnectorTableHandle tabl JdbcTableHandle tableHandle = (JdbcTableHandle) table; JdbcColumnHandle columnHandle = (JdbcColumnHandle) column; verify(!tableHandle.isSynthetic(), "Not a table reference: %s", tableHandle); - verify(!columnHandle.isSynthetic(), "Not a column reference: %s", columnHandle); jdbcClient.setColumnComment(session, tableHandle, columnHandle, comment); } @@ -481,7 +483,6 @@ public void dropColumn(ConnectorSession session, ConnectorTableHandle table, Col JdbcTableHandle tableHandle = (JdbcTableHandle) table; JdbcColumnHandle columnHandle = (JdbcColumnHandle) column; verify(!tableHandle.isSynthetic(), "Not a table reference: %s", tableHandle); - verify(!columnHandle.isSynthetic(), "Not a column reference: %s", columnHandle); jdbcClient.dropColumn(session, tableHandle, columnHandle); } @@ -491,7 +492,6 @@ public void renameColumn(ConnectorSession session, ConnectorTableHandle table, C JdbcTableHandle tableHandle = (JdbcTableHandle) table; JdbcColumnHandle columnHandle = (JdbcColumnHandle) column; verify(!tableHandle.isSynthetic(), "Not a table reference: %s", tableHandle); - verify(!columnHandle.isSynthetic(), "Not a column reference: %s", columnHandle); jdbcClient.renameColumn(session, tableHandle, columnHandle, target); } diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcNamedRelationHandle.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcNamedRelationHandle.java new file mode 100644 index 000000000000..0fdd68349d21 --- /dev/null +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcNamedRelationHandle.java @@ -0,0 +1,78 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.jdbc; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import io.trino.spi.connector.SchemaTableName; + +import java.util.Objects; + +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; + +public class JdbcNamedRelationHandle + extends JdbcRelationHandle +{ + private final SchemaTableName schemaTableName; + private final RemoteTableName remoteTableName; + + @JsonCreator + public JdbcNamedRelationHandle( + @JsonProperty("schemaTableName") SchemaTableName schemaTableName, + @JsonProperty("remoteTableName") RemoteTableName remoteTableName) + { + this.schemaTableName = requireNonNull(schemaTableName, "schemaTableName is null"); + this.remoteTableName = requireNonNull(remoteTableName, "remoteTable is null"); + } + + @JsonProperty + public SchemaTableName getSchemaTableName() + { + return schemaTableName; + } + + @JsonProperty + public RemoteTableName getRemoteTableName() + { + return remoteTableName; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + JdbcNamedRelationHandle that = (JdbcNamedRelationHandle) o; + return Objects.equals(schemaTableName, that.schemaTableName) + // remoteTableName is not compared here, as required by TestJdbcTableHandle#testEquivalence TODO document why this is important + /**/; + } + + @Override + public int hashCode() + { + return Objects.hash(schemaTableName); + } + + @Override + public String toString() + { + return format("%s %s", schemaTableName, remoteTableName); + } +} diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcQueryRelationHandle.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcQueryRelationHandle.java new file mode 100644 index 000000000000..78291150f1f5 --- /dev/null +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcQueryRelationHandle.java @@ -0,0 +1,43 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.jdbc; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import static java.lang.String.format; + +public class JdbcQueryRelationHandle + extends JdbcRelationHandle +{ + private final PreparedQuery preparedQuery; + + @JsonCreator + public JdbcQueryRelationHandle(PreparedQuery preparedQuery) + { + this.preparedQuery = preparedQuery; + } + + @JsonProperty + public PreparedQuery getPreparedQuery() + { + return preparedQuery; + } + + @Override + public String toString() + { + return format("Query[%s]", preparedQuery.getQuery()); + } +} diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcRelationHandle.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcRelationHandle.java new file mode 100644 index 000000000000..dfc2604d2a6e --- /dev/null +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcRelationHandle.java @@ -0,0 +1,30 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.jdbc; + +import com.fasterxml.jackson.annotation.JsonSubTypes; +import com.fasterxml.jackson.annotation.JsonTypeInfo; + +@JsonTypeInfo( + use = JsonTypeInfo.Id.NAME, + property = "@type") +@JsonSubTypes({ + @JsonSubTypes.Type(value = JdbcNamedRelationHandle.class, name = "named"), + @JsonSubTypes.Type(value = JdbcQueryRelationHandle.class, name = "query"), +}) +public abstract class JdbcRelationHandle +{ + @Override + public abstract String toString(); +} diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcTableHandle.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcTableHandle.java index 5b3ee7f5e4e0..4a6d311afe2b 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcTableHandle.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcTableHandle.java @@ -29,24 +29,21 @@ import java.util.Optional; import java.util.OptionalLong; -import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkState; import static com.google.common.collect.ImmutableList.toImmutableList; import static java.util.Objects.requireNonNull; public final class JdbcTableHandle implements ConnectorTableHandle { - private final SchemaTableName schemaTableName; - private final RemoteTableName remoteTableName; - private final TupleDomain constraint; + private final JdbcRelationHandle relationHandle; - // semantically aggregation is applied after constraint - private final Optional>> groupingSets; + private final TupleDomain constraint; - // semantically limit is applied after aggregation + // semantically limit is applied after constraint private final OptionalLong limit; - // columns of the relation described by this handle, after projections, aggregations, etc. + // columns of the relation described by this handle private final Optional> columns; @Deprecated @@ -58,68 +55,64 @@ public JdbcTableHandle(SchemaTableName schemaTableName, @Nullable String catalog public JdbcTableHandle(SchemaTableName schemaTableName, RemoteTableName remoteTableName) { this( - schemaTableName, - remoteTableName, + new JdbcNamedRelationHandle(schemaTableName, remoteTableName), TupleDomain.all(), - Optional.empty(), OptionalLong.empty(), Optional.empty()); } @JsonCreator public JdbcTableHandle( - @JsonProperty("schemaTableName") SchemaTableName schemaTableName, - @JsonProperty("remoteTableName") RemoteTableName remoteTableName, + @JsonProperty("relationHandle") JdbcRelationHandle relationHandle, @JsonProperty("constraint") TupleDomain constraint, - @JsonProperty("groupingSets") Optional>> groupingSets, @JsonProperty("limit") OptionalLong limit, @JsonProperty("columns") Optional> columns) { - this.schemaTableName = requireNonNull(schemaTableName, "schemaTableName is null"); - this.remoteTableName = requireNonNull(remoteTableName, "remoteTable is null"); + this.relationHandle = requireNonNull(relationHandle, "relationHandle is null"); this.constraint = requireNonNull(constraint, "constraint is null"); - requireNonNull(groupingSets, "groupingSets is null"); - checkArgument(groupingSets.isEmpty() || !groupingSets.get().isEmpty(), "Global aggregation should be represented by [[]]"); - this.groupingSets = groupingSets.map(JdbcTableHandle::copy); - this.limit = requireNonNull(limit, "limit is null"); requireNonNull(columns, "columns is null"); - checkArgument(groupingSets.isEmpty() || columns.isPresent(), "columns should be present when groupingSets is present"); this.columns = columns.map(ImmutableList::copyOf); } - @JsonProperty + @JsonIgnore public SchemaTableName getSchemaTableName() { - return schemaTableName; + return getNamedRelation().getSchemaTableName(); } - @JsonProperty + @JsonIgnore public RemoteTableName getRemoteTableName() { - return remoteTableName; + return getNamedRelation().getRemoteTableName(); + } + + @JsonProperty + public JdbcRelationHandle getRelationHandle() + { + return relationHandle; } @Deprecated @Nullable public String getCatalogName() { - return remoteTableName.getCatalogName().orElse(null); + return getRemoteTableName().getCatalogName().orElse(null); } @Deprecated @Nullable public String getSchemaName() { - return remoteTableName.getSchemaName().orElse(null); + return getRemoteTableName().getSchemaName().orElse(null); } @Deprecated public String getTableName() { - return remoteTableName.getTableName(); + return getRemoteTableName().getTableName(); } @JsonProperty @@ -128,12 +121,6 @@ public TupleDomain getConstraint() return constraint; } - @JsonProperty - public Optional>> getGroupingSets() - { - return groupingSets; - } - @JsonProperty public OptionalLong getLimit() { @@ -146,10 +133,22 @@ public Optional> getColumns() return columns; } + private JdbcNamedRelationHandle getNamedRelation() + { + checkState(isNamedRelation(), "The table handle does not represent a named relation: %s", this); + return (JdbcNamedRelationHandle) relationHandle; + } + @JsonIgnore public boolean isSynthetic() { - return !constraint.isAll() || groupingSets.isPresent() || limit.isPresent(); + return !isNamedRelation() || !constraint.isAll() || limit.isPresent(); + } + + @JsonIgnore + public boolean isNamedRelation() + { + return relationHandle instanceof JdbcNamedRelationHandle; } @Override @@ -162,9 +161,8 @@ public boolean equals(Object obj) return false; } JdbcTableHandle o = (JdbcTableHandle) obj; - return Objects.equals(this.schemaTableName, o.schemaTableName) && + return Objects.equals(this.relationHandle, o.relationHandle) && Objects.equals(this.constraint, o.constraint) && - Objects.equals(this.groupingSets, o.groupingSets) && Objects.equals(this.limit, o.limit) && Objects.equals(this.columns, o.columns); } @@ -172,18 +170,16 @@ public boolean equals(Object obj) @Override public int hashCode() { - return Objects.hash(schemaTableName, constraint, groupingSets, limit, columns); + return Objects.hash(relationHandle, constraint, limit, columns); } @Override public String toString() { StringBuilder builder = new StringBuilder(); - builder.append(schemaTableName).append(" "); - builder.append(remoteTableName); + builder.append(relationHandle); limit.ifPresent(value -> builder.append(" limit=").append(value)); columns.ifPresent(value -> builder.append(" columns=").append(value)); - groupingSets.ifPresent(value -> builder.append(" groupingSets=").append(value)); return builder.toString(); } diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/QueryBuilder.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/QueryBuilder.java index 76ae10507456..030bf94f6cf7 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/QueryBuilder.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/QueryBuilder.java @@ -16,10 +16,12 @@ import com.google.common.base.Joiner; import com.google.common.base.VerifyException; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import io.airlift.log.Logger; import io.airlift.slice.Slice; import io.trino.spi.connector.ColumnHandle; import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.connector.SchemaTableName; import io.trino.spi.predicate.Domain; import io.trino.spi.predicate.Range; import io.trino.spi.predicate.TupleDomain; @@ -78,9 +80,13 @@ public PreparedStatement buildSql( PreparedQuery preparedQuery = prepareQuery( session, connection, - remoteTableName, + new JdbcNamedRelationHandle( + // This dummy SchemaTableName is not used for anything here. It's provided only to implement the deprecated buildSql() method + new SchemaTableName(remoteTableName.getSchemaName().orElse(""), remoteTableName.getTableName()), + remoteTableName), groupingSets, columns, + ImmutableMap.of(), tupleDomain, additionalPredicate); preparedQuery = preparedQuery.transformQuery(sqlFunction); @@ -90,17 +96,37 @@ public PreparedStatement buildSql( public PreparedQuery prepareQuery( ConnectorSession session, Connection connection, - RemoteTableName remoteTableName, + JdbcRelationHandle baseRelation, Optional>> groupingSets, List columns, + Map columnExpressions, TupleDomain tupleDomain, Optional additionalPredicate) { - String sql = "SELECT " + getProjection(columns); - sql += " FROM " + getRelation(remoteTableName); + if (!tupleDomain.isNone()) { + Map domains = tupleDomain.getDomains().orElseThrow(); + columns.stream() + .filter(domains::containsKey) + .filter(column -> columnExpressions.containsKey(column.getColumnName())) + .findFirst() + .ifPresent(column -> { throw new IllegalArgumentException(format("Column %s has an expression and a constraint attached at the same time", column)); }); + } ImmutableList.Builder accumulator = ImmutableList.builder(); + String sql = "SELECT " + getProjection(columns, columnExpressions); + if (baseRelation instanceof JdbcNamedRelationHandle) { + sql += " FROM " + getRelation(((JdbcNamedRelationHandle) baseRelation).getRemoteTableName()); + } + else if (baseRelation instanceof JdbcQueryRelationHandle) { + PreparedQuery preparedQuery = ((JdbcQueryRelationHandle) baseRelation).getPreparedQuery(); + sql += " FROM (" + preparedQuery.getQuery() + ") o"; + accumulator.addAll(preparedQuery.getParameters()); + } + else { + throw new IllegalArgumentException("Unsupported relation: " + baseRelation); + } + List clauses = toConjuncts(client, session, connection, tupleDomain, accumulator::add); if (additionalPredicate.isPresent()) { clauses = ImmutableList.builder() @@ -161,7 +187,7 @@ protected String getRelation(RemoteTableName remoteTableName) return client.quoted(remoteTableName); } - protected String getProjection(List columns) + protected String getProjection(List columns, Map columnExpressions) { if (columns.isEmpty()) { return "1 x"; @@ -169,10 +195,10 @@ protected String getProjection(List columns) return columns.stream() .map(jdbcColumnHandle -> { String columnAlias = client.quoted(jdbcColumnHandle.getColumnName()); - if (jdbcColumnHandle.getExpression().isEmpty()) { + String expression = columnExpressions.get(jdbcColumnHandle.getColumnName()); + if (expression == null) { return columnAlias; } - String expression = jdbcColumnHandle.toSqlExpression(client::quoted); return format("%s AS %s", expression, columnAlias); }) .collect(joining(", ")); diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/ImplementAvgDecimal.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/ImplementAvgDecimal.java index 988232c717cf..d9f47b4eee20 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/ImplementAvgDecimal.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/ImplementAvgDecimal.java @@ -61,7 +61,7 @@ public Optional rewrite(AggregateFunction aggregateFunction, Cap verify(aggregateFunction.getOutputType().equals(type)); return Optional.of(new JdbcExpression( - format("CAST(avg(%s) AS decimal(%s, %s))", columnHandle.toSqlExpression(context.getIdentifierQuote()), type.getPrecision(), type.getScale()), + format("CAST(avg(%s) AS decimal(%s, %s))", context.getIdentifierQuote().apply(columnHandle.getColumnName()), type.getPrecision(), type.getScale()), columnHandle.getJdbcTypeHandle())); } } diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/ImplementAvgFloatingPoint.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/ImplementAvgFloatingPoint.java index 0d10361a10a1..7dbdcef736da 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/ImplementAvgFloatingPoint.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/ImplementAvgFloatingPoint.java @@ -61,7 +61,7 @@ public Optional rewrite(AggregateFunction aggregateFunction, Cap verify(aggregateFunction.getOutputType() == columnHandle.getColumnType()); return Optional.of(new JdbcExpression( - format("avg(%s)", columnHandle.toSqlExpression(context.getIdentifierQuote())), + format("avg(%s)", context.getIdentifierQuote().apply(columnHandle.getColumnName())), columnHandle.getJdbcTypeHandle())); } } diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/ImplementCount.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/ImplementCount.java index 4fc22124dcc4..13ab557c7ece 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/ImplementCount.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/ImplementCount.java @@ -70,7 +70,7 @@ public Optional rewrite(AggregateFunction aggregateFunction, Cap verify(aggregateFunction.getOutputType() == BIGINT); return Optional.of(new JdbcExpression( - format("count(%s)", columnHandle.toSqlExpression(context.getIdentifierQuote())), + format("count(%s)", context.getIdentifierQuote().apply(columnHandle.getColumnName())), bigintTypeHandle)); } } diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/ImplementMinMax.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/ImplementMinMax.java index 705f80e5772f..7614ca4724e5 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/ImplementMinMax.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/ImplementMinMax.java @@ -56,7 +56,7 @@ public Optional rewrite(AggregateFunction aggregateFunction, Cap verify(columnHandle.getColumnType().equals(aggregateFunction.getOutputType())); return Optional.of(new JdbcExpression( - format("%s(%s)", aggregateFunction.getFunctionName(), columnHandle.toSqlExpression(context.getIdentifierQuote())), + format("%s(%s)", aggregateFunction.getFunctionName(), context.getIdentifierQuote().apply(columnHandle.getColumnName())), columnHandle.getJdbcTypeHandle())); } } diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/ImplementSum.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/ImplementSum.java index 7bfb3860cdc7..dc2281f6e89b 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/ImplementSum.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/ImplementSum.java @@ -79,7 +79,7 @@ else if (aggregateFunction.getOutputType() instanceof DecimalType) { } return Optional.of(new JdbcExpression( - format("sum(%s)", columnHandle.toSqlExpression(context.getIdentifierQuote())), + format("sum(%s)", context.getIdentifierQuote().apply(columnHandle.getColumnName())), resultTypeHandle)); } } diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/jmx/JdbcClientStats.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/jmx/JdbcClientStats.java index d7c6b78fa062..4020bc33d2ef 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/jmx/JdbcClientStats.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/jmx/JdbcClientStats.java @@ -23,6 +23,7 @@ public final class JdbcClientStats private final JdbcApiStats beginCreateTable = new JdbcApiStats(); private final JdbcApiStats beginInsertTable = new JdbcApiStats(); private final JdbcApiStats buildInsertSql = new JdbcApiStats(); + private final JdbcApiStats prepareQuery = new JdbcApiStats(); private final JdbcApiStats buildSql = new JdbcApiStats(); private final JdbcApiStats commitCreateTable = new JdbcApiStats(); private final JdbcApiStats createSchema = new JdbcApiStats(); @@ -86,6 +87,13 @@ public JdbcApiStats getBuildInsertSql() return buildInsertSql; } + @Managed + @Nested + public JdbcApiStats getPrepareQuery() + { + return prepareQuery; + } + @Managed @Nested public JdbcApiStats getBuildSql() diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/jmx/StatisticsAwareJdbcClient.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/jmx/StatisticsAwareJdbcClient.java index f98e2c27850e..7d7ff061bb7b 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/jmx/StatisticsAwareJdbcClient.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/jmx/StatisticsAwareJdbcClient.java @@ -21,6 +21,7 @@ import io.trino.plugin.jdbc.JdbcSplit; import io.trino.plugin.jdbc.JdbcTableHandle; import io.trino.plugin.jdbc.JdbcTypeHandle; +import io.trino.plugin.jdbc.PreparedQuery; import io.trino.plugin.jdbc.RemoteTableName; import io.trino.plugin.jdbc.WriteFunction; import io.trino.plugin.jdbc.WriteMapping; @@ -158,6 +159,17 @@ public void abortReadConnection(Connection connection) stats.getAbortReadConnection().wrap(() -> delegate().abortReadConnection(connection)); } + @Override + public PreparedQuery prepareQuery( + ConnectorSession session, + JdbcTableHandle table, + Optional>> groupingSets, + List columns, + Map columnExpressions) + { + return stats.getPrepareQuery().wrap(() -> delegate().prepareQuery(session, table, groupingSets, columns, columnExpressions)); + } + @Override public PreparedStatement buildSql(ConnectorSession session, Connection connection, JdbcSplit split, JdbcTableHandle tableHandle, List columnHandles) throws SQLException diff --git a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestJdbcColumnHandle.java b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestJdbcColumnHandle.java index d0b2c6ffeac1..fbba9dd8485c 100644 --- a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestJdbcColumnHandle.java +++ b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestJdbcColumnHandle.java @@ -30,7 +30,7 @@ public class TestJdbcColumnHandle @Test public void testJsonRoundTrip() { - assertJsonRoundTrip(COLUMN_CODEC, new JdbcColumnHandle(Optional.empty(), "columnName", JDBC_VARCHAR, VARCHAR, true, Optional.of("some comment"))); + assertJsonRoundTrip(COLUMN_CODEC, new JdbcColumnHandle("columnName", JDBC_VARCHAR, VARCHAR, true, Optional.of("some comment"))); } @Test diff --git a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestJdbcMetadata.java b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestJdbcMetadata.java index 6927000a7d99..4787789fbffe 100644 --- a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestJdbcMetadata.java +++ b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestJdbcMetadata.java @@ -307,7 +307,15 @@ public void testCombineFiltersWithAggregationPushdown() JdbcTableHandle tableHandleWithFilter = applyFilter(session, aggregatedTable, new Constraint(TupleDomain.withColumnDomains(ImmutableMap.of(groupByColumn, secondDomain)))); assertEquals( tableHandleWithFilter.getConstraint().getDomains(), - Optional.of(ImmutableMap.of(groupByColumn, Domain.singleValue(VARCHAR, utf8Slice("one"))))); + // The query effectively intersects firstDomain and secondDomain, but this is not visible in JdbcTableHandle.constraint, + // as firstDomain has been converted into a PreparedQuery + Optional.of(ImmutableMap.of(groupByColumn, secondDomain))); + assertEquals( + ((JdbcQueryRelationHandle) tableHandleWithFilter.getRelationHandle()).getPreparedQuery().getQuery(), + "SELECT \"TEXT\", count(*) AS \"_pfgnrtd_1\" " + + "FROM \"" + database.getDatabaseName() + "\".\"EXAMPLE\".\"NUMBERS\" " + + "WHERE \"TEXT\" IN (?,?) " + + "GROUP BY \"TEXT\""); } @Test @@ -324,11 +332,18 @@ public void testNonGroupKeyPredicatePushdown() ConnectorTableHandle aggregatedTable = applyCountAggregation(session, baseTableHandle, ImmutableList.of(ImmutableList.of(groupByColumn))); Domain domain = Domain.singleValue(BIGINT, 123L); - Optional> filterResult = metadata.applyFilter( + JdbcTableHandle tableHandleWithFilter = applyFilter( session, aggregatedTable, new Constraint(TupleDomain.withColumnDomains(ImmutableMap.of(nonGroupByColumn, domain)))); - assertThat(filterResult).isEmpty(); + assertEquals( + tableHandleWithFilter.getConstraint().getDomains(), + Optional.of(ImmutableMap.of(nonGroupByColumn, domain))); + assertEquals( + ((JdbcQueryRelationHandle) tableHandleWithFilter.getRelationHandle()).getPreparedQuery().getQuery(), + "SELECT \"TEXT\", count(*) AS \"_pfgnrtd_1\" " + + "FROM \"" + database.getDatabaseName() + "\".\"EXAMPLE\".\"NUMBERS\" " + + "GROUP BY \"TEXT\""); } @Test @@ -346,11 +361,18 @@ public void tesMultiGroupKeyPredicatePushdown() ConnectorTableHandle aggregatedTable = applyCountAggregation(session, baseTableHandle, ImmutableList.of(ImmutableList.of(textColumn, valueColumn), ImmutableList.of(textColumn))); Domain domain = Domain.singleValue(BIGINT, 123L); - Optional> filterResult = metadata.applyFilter( + JdbcTableHandle tableHandleWithFilter = applyFilter( session, aggregatedTable, new Constraint(TupleDomain.withColumnDomains(ImmutableMap.of(valueColumn, domain)))); - assertThat(filterResult).isEmpty(); + assertEquals( + tableHandleWithFilter.getConstraint().getDomains(), + Optional.of(ImmutableMap.of(valueColumn, domain))); + assertEquals( + ((JdbcQueryRelationHandle) tableHandleWithFilter.getRelationHandle()).getPreparedQuery().getQuery(), + "SELECT \"TEXT\", \"VALUE\", count(*) AS \"_pfgnrtd_1\" " + + "FROM \"" + database.getDatabaseName() + "\".\"EXAMPLE\".\"NUMBERS\" " + + "GROUP BY GROUPING SETS ((\"TEXT\", \"VALUE\"), (\"TEXT\"))"); } private JdbcTableHandle applyCountAggregation(ConnectorSession session, ConnectorTableHandle tableHandle, List> groupByColumns) diff --git a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestJdbcQueryBuilder.java b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestJdbcQueryBuilder.java index 26cc4c0ba918..33a88bae2e71 100644 --- a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestJdbcQueryBuilder.java +++ b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestJdbcQueryBuilder.java @@ -20,6 +20,7 @@ import com.google.common.collect.Multiset; import io.trino.spi.connector.ColumnHandle; import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.connector.SchemaTableName; import io.trino.spi.predicate.Domain; import io.trino.spi.predicate.Range; import io.trino.spi.predicate.SortedRangeSet; @@ -45,6 +46,7 @@ import java.time.LocalTime; import java.util.List; import java.util.Locale; +import java.util.Map; import java.util.Optional; import java.util.function.Function; import java.util.stream.LongStream; @@ -87,7 +89,9 @@ @Test(singleThreaded = true) public class TestJdbcQueryBuilder { - private static final RemoteTableName TEST_TABLE = new RemoteTableName(Optional.empty(), Optional.empty(), "test_table"); + private static final JdbcNamedRelationHandle TEST_TABLE = new JdbcNamedRelationHandle(new SchemaTableName( + "some_test_schema", "test_table"), + new RemoteTableName(Optional.empty(), Optional.empty(), "test_table")); private static final ConnectorSession SESSION = TestingConnectorSession.builder() .setPropertyMetadata(new JdbcMetadataSessionProperties(new JdbcMetadataConfig(), Optional.empty()).getSessionProperties()) .build(); @@ -220,7 +224,7 @@ public void testNormalBuildSql() Connection connection = database.getConnection(); QueryBuilder queryBuilder = new QueryBuilder(jdbcClient); - PreparedQuery preparedQuery = queryBuilder.prepareQuery(SESSION, connection, TEST_TABLE, Optional.empty(), columns, tupleDomain, Optional.empty()); + PreparedQuery preparedQuery = queryBuilder.prepareQuery(SESSION, connection, TEST_TABLE, Optional.empty(), columns, Map.of(), tupleDomain, Optional.empty()); try (PreparedStatement preparedStatement = queryBuilder.prepareStatement(SESSION, connection, preparedQuery)) { assertThat(preparedQuery.getQuery()).isEqualTo("" + "SELECT \"col_0\", \"col_1\", \"col_2\", \"col_3\", \"col_4\", \"col_5\", " + @@ -268,6 +272,7 @@ public void testBuildSqlWithDomainComplement() TEST_TABLE, Optional.empty(), List.of(columns.get(0), columns.get(3), columns.get(9)), + Map.of(), tupleDomain, Optional.empty()); try (PreparedStatement preparedStatement = queryBuilder.prepareStatement(SESSION, connection, preparedQuery)) { @@ -301,7 +306,7 @@ public void testBuildSqlWithFloat() Connection connection = database.getConnection(); QueryBuilder queryBuilder = new QueryBuilder(jdbcClient); - PreparedQuery preparedQuery = queryBuilder.prepareQuery(SESSION, connection, TEST_TABLE, Optional.empty(), columns, tupleDomain, Optional.empty()); + PreparedQuery preparedQuery = queryBuilder.prepareQuery(SESSION, connection, TEST_TABLE, Optional.empty(), columns, Map.of(), tupleDomain, Optional.empty()); try (PreparedStatement preparedStatement = queryBuilder.prepareStatement(SESSION, connection, preparedQuery)) { assertThat(preparedQuery.getQuery()).isEqualTo("" + "SELECT \"col_0\", \"col_1\", \"col_2\", \"col_3\", \"col_4\", \"col_5\", " + @@ -335,7 +340,7 @@ public void testBuildSqlWithVarchar() Connection connection = database.getConnection(); QueryBuilder queryBuilder = new QueryBuilder(jdbcClient); - PreparedQuery preparedQuery = queryBuilder.prepareQuery(SESSION, connection, TEST_TABLE, Optional.empty(), columns, tupleDomain, Optional.empty()); + PreparedQuery preparedQuery = queryBuilder.prepareQuery(SESSION, connection, TEST_TABLE, Optional.empty(), columns, Map.of(), tupleDomain, Optional.empty()); try (PreparedStatement preparedStatement = queryBuilder.prepareStatement(SESSION, connection, preparedQuery)) { assertThat(preparedQuery.getQuery()).isEqualTo("" + "SELECT \"col_0\", \"col_1\", \"col_2\", \"col_3\", \"col_4\", \"col_5\", " + @@ -371,7 +376,7 @@ public void testBuildSqlWithChar() Connection connection = database.getConnection(); QueryBuilder queryBuilder = new QueryBuilder(jdbcClient); - PreparedQuery preparedQuery = queryBuilder.prepareQuery(SESSION, connection, TEST_TABLE, Optional.empty(), columns, tupleDomain, Optional.empty()); + PreparedQuery preparedQuery = queryBuilder.prepareQuery(SESSION, connection, TEST_TABLE, Optional.empty(), columns, Map.of(), tupleDomain, Optional.empty()); try (PreparedStatement preparedStatement = queryBuilder.prepareStatement(SESSION, connection, preparedQuery)) { assertThat(preparedQuery.getQuery()).isEqualTo("" + "SELECT \"col_0\", \"col_1\", \"col_2\", \"col_3\", \"col_4\", \"col_5\", " + @@ -412,7 +417,7 @@ public void testBuildSqlWithDateTime() Connection connection = database.getConnection(); QueryBuilder queryBuilder = new QueryBuilder(jdbcClient); - PreparedQuery preparedQuery = queryBuilder.prepareQuery(SESSION, connection, TEST_TABLE, Optional.empty(), columns, tupleDomain, Optional.empty()); + PreparedQuery preparedQuery = queryBuilder.prepareQuery(SESSION, connection, TEST_TABLE, Optional.empty(), columns, Map.of(), tupleDomain, Optional.empty()); try (PreparedStatement preparedStatement = queryBuilder.prepareStatement(SESSION, connection, preparedQuery)) { assertThat(preparedQuery.getQuery()).isEqualTo("" + "SELECT \"col_0\", \"col_1\", \"col_2\", \"col_3\", \"col_4\", \"col_5\", " + @@ -453,7 +458,7 @@ public void testBuildSqlWithTimestamp() Connection connection = database.getConnection(); QueryBuilder queryBuilder = new QueryBuilder(jdbcClient); - PreparedQuery preparedQuery = queryBuilder.prepareQuery(SESSION, connection, TEST_TABLE, Optional.empty(), columns, tupleDomain, Optional.empty()); + PreparedQuery preparedQuery = queryBuilder.prepareQuery(SESSION, connection, TEST_TABLE, Optional.empty(), columns, Map.of(), tupleDomain, Optional.empty()); try (PreparedStatement preparedStatement = queryBuilder.prepareStatement(SESSION, connection, preparedQuery)) { assertThat(preparedQuery.getQuery()).isEqualTo("" + "SELECT \"col_0\", \"col_1\", \"col_2\", \"col_3\", \"col_4\", \"col_5\", " + @@ -485,7 +490,7 @@ public void testBuildSqlWithLimit() Connection connection = database.getConnection(); Function function = sql -> sql + " LIMIT 10"; QueryBuilder queryBuilder = new QueryBuilder(jdbcClient); - PreparedQuery preparedQuery = queryBuilder.prepareQuery(SESSION, connection, TEST_TABLE, Optional.empty(), columns, TupleDomain.all(), Optional.empty()); + PreparedQuery preparedQuery = queryBuilder.prepareQuery(SESSION, connection, TEST_TABLE, Optional.empty(), columns, Map.of(), TupleDomain.all(), Optional.empty()); preparedQuery = preparedQuery.transformQuery(function); try (PreparedStatement preparedStatement = queryBuilder.prepareStatement(SESSION, connection, preparedQuery)) { assertThat(preparedQuery.getQuery()).isEqualTo("" + @@ -513,7 +518,7 @@ public void testEmptyBuildSql() Connection connection = database.getConnection(); QueryBuilder queryBuilder = new QueryBuilder(jdbcClient); - PreparedQuery preparedQuery = queryBuilder.prepareQuery(SESSION, connection, TEST_TABLE, Optional.empty(), columns, tupleDomain, Optional.empty()); + PreparedQuery preparedQuery = queryBuilder.prepareQuery(SESSION, connection, TEST_TABLE, Optional.empty(), columns, Map.of(), tupleDomain, Optional.empty()); try (PreparedStatement preparedStatement = queryBuilder.prepareStatement(SESSION, connection, preparedQuery)) { assertThat(preparedQuery.getQuery()).isEqualTo("" + "SELECT \"col_0\", \"col_1\", \"col_2\", \"col_3\", \"col_4\", \"col_5\", " + @@ -533,7 +538,6 @@ public void testAggregation() List projectedColumns = ImmutableList.of( this.columns.get(2), new JdbcColumnHandle( - Optional.of("sum(\"col_0\")"), "s", JDBC_BIGINT, BIGINT, @@ -548,6 +552,7 @@ public void testAggregation() TEST_TABLE, Optional.of(ImmutableList.of(ImmutableList.of(this.columns.get(2)))), projectedColumns, + Map.of("s", "sum(\"col_0\")"), TupleDomain.all(), Optional.empty()); try (PreparedStatement preparedStatement = queryBuilder.prepareStatement(SESSION, connection, preparedQuery)) { @@ -576,7 +581,6 @@ public void testAggregationWithFilter() List projectedColumns = ImmutableList.of( this.columns.get(2), new JdbcColumnHandle( - Optional.of("sum(\"col_0\")"), "s", JDBC_BIGINT, BIGINT, @@ -591,6 +595,7 @@ public void testAggregationWithFilter() TEST_TABLE, Optional.of(ImmutableList.of(ImmutableList.of(this.columns.get(2)))), projectedColumns, + Map.of("s", "sum(\"col_0\")"), tupleDomain, Optional.empty()); try (PreparedStatement preparedStatement = queryBuilder.prepareStatement(SESSION, connection, preparedQuery)) { diff --git a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestJdbcRecordSetProvider.java b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestJdbcRecordSetProvider.java index cbcdbacda3fa..c0c6b3cd51af 100644 --- a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestJdbcRecordSetProvider.java +++ b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestJdbcRecordSetProvider.java @@ -184,10 +184,8 @@ public void testTupleDomain() private RecordCursor getCursor(JdbcTableHandle jdbcTableHandle, List columns, TupleDomain domain) { jdbcTableHandle = new JdbcTableHandle( - jdbcTableHandle.getSchemaTableName(), - jdbcTableHandle.getRemoteTableName(), + jdbcTableHandle.getRelationHandle(), domain, - Optional.empty(), OptionalLong.empty(), Optional.empty()); diff --git a/plugin/trino-druid/src/main/java/io/trino/plugin/druid/DruidJdbcClient.java b/plugin/trino-druid/src/main/java/io/trino/plugin/druid/DruidJdbcClient.java index 3c22ab897b72..15c9d75d559c 100644 --- a/plugin/trino-druid/src/main/java/io/trino/plugin/druid/DruidJdbcClient.java +++ b/plugin/trino-druid/src/main/java/io/trino/plugin/druid/DruidJdbcClient.java @@ -19,10 +19,12 @@ import io.trino.plugin.jdbc.ColumnMapping; import io.trino.plugin.jdbc.ConnectionFactory; import io.trino.plugin.jdbc.JdbcColumnHandle; +import io.trino.plugin.jdbc.JdbcNamedRelationHandle; import io.trino.plugin.jdbc.JdbcOutputTableHandle; import io.trino.plugin.jdbc.JdbcSplit; import io.trino.plugin.jdbc.JdbcTableHandle; import io.trino.plugin.jdbc.JdbcTypeHandle; +import io.trino.plugin.jdbc.PreparedQuery; import io.trino.plugin.jdbc.RemoteTableName; import io.trino.plugin.jdbc.WriteFunction; import io.trino.plugin.jdbc.WriteMapping; @@ -43,6 +45,7 @@ import java.util.ArrayList; import java.util.Collection; import java.util.List; +import java.util.Map; import java.util.Objects; import java.util.Optional; import java.util.function.BiFunction; @@ -155,23 +158,36 @@ public WriteMapping toWriteMapping(ConnectorSession session, Type type) return legacyToWriteMapping(session, type); } + @Override + public PreparedQuery prepareQuery(ConnectorSession session, JdbcTableHandle table, Optional>> groupingSets, List columns, Map columnExpressions) + { + return super.prepareQuery(session, prepareTableHandleForQuery(table), groupingSets, columns, columnExpressions); + } + @Override public PreparedStatement buildSql(ConnectorSession session, Connection connection, JdbcSplit split, JdbcTableHandle table, List columns) throws SQLException { - String schemaName = table.getSchemaName(); - checkArgument("druid".equals(schemaName), "Only \"druid\" schema is supported"); + return super.buildSql(session, connection, split, prepareTableHandleForQuery(table), columns); + } + + private JdbcTableHandle prepareTableHandleForQuery(JdbcTableHandle table) + { + if (table.isNamedRelation()) { + String schemaName = table.getSchemaName(); + checkArgument("druid".equals(schemaName), "Only \"druid\" schema is supported"); - table = new JdbcTableHandle( - table.getSchemaTableName(), - // Druid doesn't like table names to be qualified with catalog names in the SQL query, hence we null out the catalog. - new RemoteTableName(Optional.empty(), table.getRemoteTableName().getSchemaName(), table.getRemoteTableName().getTableName()), - table.getConstraint(), - table.getGroupingSets(), - table.getLimit(), - table.getColumns()); + table = new JdbcTableHandle( + new JdbcNamedRelationHandle( + table.getSchemaTableName(), + // Druid doesn't like table names to be qualified with catalog names in the SQL query, hence we null out the catalog. + new RemoteTableName(Optional.empty(), table.getRemoteTableName().getSchemaName(), table.getRemoteTableName().getTableName())), + table.getConstraint(), + table.getLimit(), + table.getColumns()); + } - return super.buildSql(session, connection, split, table, columns); + return table; } /* diff --git a/plugin/trino-mysql/src/main/java/io/trino/plugin/mysql/ImplementAvgBigint.java b/plugin/trino-mysql/src/main/java/io/trino/plugin/mysql/ImplementAvgBigint.java index 491129efe28d..4f07cb6bf1c5 100644 --- a/plugin/trino-mysql/src/main/java/io/trino/plugin/mysql/ImplementAvgBigint.java +++ b/plugin/trino-mysql/src/main/java/io/trino/plugin/mysql/ImplementAvgBigint.java @@ -61,7 +61,7 @@ public Optional rewrite(AggregateFunction aggregateFunction, Cap verify(aggregateFunction.getOutputType() == DOUBLE); return Optional.of(new JdbcExpression( - format("avg((%s * 1.0))", columnHandle.toSqlExpression(context.getIdentifierQuote())), + format("avg((%s * 1.0))", context.getIdentifierQuote().apply(columnHandle.getColumnName())), new JdbcTypeHandle(Types.DOUBLE, Optional.of("double"), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()))); } } diff --git a/plugin/trino-mysql/src/test/java/io/trino/plugin/mysql/BaseMySqlIntegrationSmokeTest.java b/plugin/trino-mysql/src/test/java/io/trino/plugin/mysql/BaseMySqlIntegrationSmokeTest.java index 5bba1554d0ec..ef0b7e22a3a5 100644 --- a/plugin/trino-mysql/src/test/java/io/trino/plugin/mysql/BaseMySqlIntegrationSmokeTest.java +++ b/plugin/trino-mysql/src/test/java/io/trino/plugin/mysql/BaseMySqlIntegrationSmokeTest.java @@ -237,7 +237,7 @@ public void testAggregationPushdown() "SELECT regionkey, sum(nationkey) " + "FROM (SELECT * FROM nation WHERE regionkey < 3 LIMIT 11) " + "GROUP BY regionkey")) - .isNotFullyPushedDown(AggregationNode.class); + .isFullyPushedDown(); // decimals try (AutoCloseable ignoreTable = withTable("tpch.test_aggregation_pushdown", "(short_decimal decimal(9, 3), long_decimal decimal(30, 10))")) { @@ -349,7 +349,7 @@ public void testPredicatePushdown() // predicate over aggregation result assertThat(query("SELECT regionkey, sum(nationkey) FROM nation GROUP BY regionkey HAVING sum(nationkey) = 77")) .matches("VALUES (BIGINT '3', BIGINT '77')") - .isNotFullyPushedDown(FilterNode.class); + .isFullyPushedDown(); } private AutoCloseable withTable(String tableName, String tableDefinition) diff --git a/plugin/trino-phoenix/src/main/java/io/trino/plugin/phoenix/PhoenixClient.java b/plugin/trino-phoenix/src/main/java/io/trino/plugin/phoenix/PhoenixClient.java index f19c74e9a521..9e8ce8fa5079 100644 --- a/plugin/trino-phoenix/src/main/java/io/trino/plugin/phoenix/PhoenixClient.java +++ b/plugin/trino-phoenix/src/main/java/io/trino/plugin/phoenix/PhoenixClient.java @@ -247,7 +247,7 @@ public PreparedStatement buildSql(ConnectorSession session, Connection connectio session, connection, table.getRemoteTableName(), - table.getGroupingSets(), + Optional.empty(), columnHandles, phoenixSplit.getConstraint(), split.getAdditionalPredicate(), diff --git a/plugin/trino-phoenix/src/main/java/io/trino/plugin/phoenix/PhoenixSplitManager.java b/plugin/trino-phoenix/src/main/java/io/trino/plugin/phoenix/PhoenixSplitManager.java index 452873450da1..87fa7bd8bd0d 100644 --- a/plugin/trino-phoenix/src/main/java/io/trino/plugin/phoenix/PhoenixSplitManager.java +++ b/plugin/trino-phoenix/src/main/java/io/trino/plugin/phoenix/PhoenixSplitManager.java @@ -86,7 +86,7 @@ public ConnectorSplitSource getSplits( session, connection, tableHandle.getRemoteTableName(), - tableHandle.getGroupingSets(), + Optional.empty(), columns, tableHandle.getConstraint(), Optional.empty(), diff --git a/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/ImplementAvgBigint.java b/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/ImplementAvgBigint.java index 1fcc11f80d02..410c480eafa5 100644 --- a/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/ImplementAvgBigint.java +++ b/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/ImplementAvgBigint.java @@ -58,7 +58,7 @@ public Optional rewrite(AggregateFunction aggregateFunction, Cap verify(aggregateFunction.getOutputType() == DOUBLE); return Optional.of(new JdbcExpression( - format("avg(CAST(%s AS double precision))", columnHandle.toSqlExpression(context.getIdentifierQuote())), + format("avg(CAST(%s AS double precision))", context.getIdentifierQuote().apply(columnHandle.getColumnName())), new JdbcTypeHandle(Types.DOUBLE, Optional.of("double"), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()))); } } diff --git a/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlIntegrationSmokeTest.java b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlIntegrationSmokeTest.java index ce41763124a0..610a3a26a97d 100644 --- a/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlIntegrationSmokeTest.java +++ b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlIntegrationSmokeTest.java @@ -254,7 +254,7 @@ public void testPredicatePushdown() // predicate over aggregation result assertThat(query("SELECT regionkey, sum(nationkey) FROM nation GROUP BY regionkey HAVING sum(nationkey) = 77")) .matches("VALUES (BIGINT '3', BIGINT '77')") - .isNotFullyPushedDown(FilterNode.class); + .isFullyPushedDown(); } @Test @@ -380,7 +380,7 @@ public void testAggregationPushdown() assertThat(query("SELECT count(nationkey) FROM nation")).isFullyPushedDown(); assertThat(query("SELECT count(1) FROM nation")).isFullyPushedDown(); assertThat(query("SELECT count() FROM nation")).isFullyPushedDown(); - assertThat(query("SELECT count(DISTINCT regionkey) FROM nation")).isNotFullyPushedDown(AggregationNode.class); + assertThat(query("SELECT count(DISTINCT regionkey) FROM nation")).isFullyPushedDown(); // GROUP BY assertThat(query("SELECT regionkey, min(nationkey) FROM nation GROUP BY regionkey")).isFullyPushedDown(); @@ -401,7 +401,7 @@ public void testAggregationPushdown() "SELECT regionkey, sum(nationkey) " + "FROM (SELECT * FROM nation WHERE regionkey < 3 LIMIT 11) " + "GROUP BY regionkey")) - .isNotFullyPushedDown(AggregationNode.class); + .isFullyPushedDown(); // decimals try (AutoCloseable ignore = withTable("tpch.test_aggregation_pushdown", "(short_decimal decimal(9, 3), long_decimal decimal(30, 10))")) { diff --git a/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/ImplementAvgBigint.java b/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/ImplementAvgBigint.java index 00450c3fbf45..d8d474a2d1d4 100644 --- a/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/ImplementAvgBigint.java +++ b/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/ImplementAvgBigint.java @@ -58,7 +58,7 @@ public Optional rewrite(AggregateFunction aggregateFunction, Cap verify(aggregateFunction.getOutputType() == DOUBLE); return Optional.of(new JdbcExpression( - format("avg(CAST(%s AS double precision))", columnHandle.toSqlExpression(context.getIdentifierQuote())), + format("avg(CAST(%s AS double precision))", context.getIdentifierQuote().apply(columnHandle.getColumnName())), new JdbcTypeHandle(Types.DOUBLE, Optional.of("double"), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()))); } } diff --git a/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/ImplementSqlServerStddevPop.java b/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/ImplementSqlServerStddevPop.java index 4071064bb721..32a7394c4e27 100644 --- a/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/ImplementSqlServerStddevPop.java +++ b/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/ImplementSqlServerStddevPop.java @@ -60,7 +60,7 @@ public Optional rewrite(AggregateFunction aggregateFunction, Cap verify(aggregateFunction.getOutputType().equals(DOUBLE)); return Optional.of(new JdbcExpression( - format("STDEVP(%s)", columnHandle.toSqlExpression(context.getIdentifierQuote())), + format("STDEVP(%s)", context.getIdentifierQuote().apply(columnHandle.getColumnName())), columnHandle.getJdbcTypeHandle())); } } diff --git a/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/ImplementSqlServerStdev.java b/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/ImplementSqlServerStdev.java index 716113310b54..ca7e1599eb77 100644 --- a/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/ImplementSqlServerStdev.java +++ b/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/ImplementSqlServerStdev.java @@ -63,7 +63,7 @@ public Optional rewrite(AggregateFunction aggregateFunction, Cap verify(aggregateFunction.getOutputType().equals(DOUBLE)); return Optional.of(new JdbcExpression( - format("STDEV(%s)", columnHandle.toSqlExpression(context.getIdentifierQuote())), + format("STDEV(%s)", context.getIdentifierQuote().apply(columnHandle.getColumnName())), columnHandle.getJdbcTypeHandle())); } } diff --git a/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/ImplementSqlServerVariance.java b/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/ImplementSqlServerVariance.java index 317f9f41c7ec..b568ecfce07e 100644 --- a/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/ImplementSqlServerVariance.java +++ b/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/ImplementSqlServerVariance.java @@ -63,7 +63,7 @@ public Optional rewrite(AggregateFunction aggregateFunction, Cap verify(aggregateFunction.getOutputType().equals(DOUBLE)); return Optional.of(new JdbcExpression( - format("VAR(%s)", columnHandle.toSqlExpression(context.getIdentifierQuote())), + format("VAR(%s)", context.getIdentifierQuote().apply(columnHandle.getColumnName())), columnHandle.getJdbcTypeHandle())); } } diff --git a/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/ImplementSqlServerVariancePop.java b/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/ImplementSqlServerVariancePop.java index 2f59c9356b85..1ca0e7c7e894 100644 --- a/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/ImplementSqlServerVariancePop.java +++ b/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/ImplementSqlServerVariancePop.java @@ -60,7 +60,7 @@ public Optional rewrite(AggregateFunction aggregateFunction, Cap verify(aggregateFunction.getOutputType().equals(DOUBLE)); return Optional.of(new JdbcExpression( - format("VARP(%s)", columnHandle.toSqlExpression(context.getIdentifierQuote())), + format("VARP(%s)", context.getIdentifierQuote().apply(columnHandle.getColumnName())), columnHandle.getJdbcTypeHandle())); } } diff --git a/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/SqlServerClient.java b/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/SqlServerClient.java index 4b6a6d8e7627..cd313f76c7ad 100644 --- a/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/SqlServerClient.java +++ b/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/SqlServerClient.java @@ -367,6 +367,9 @@ protected String createTableSql(RemoteTableName remoteTableName, List co @Override public Map getTableProperties(ConnectorSession session, JdbcTableHandle tableHandle) { + if (!tableHandle.isNamedRelation()) { + return ImmutableMap.of(); + } try (Connection connection = configureConnectionTransactionIsolation(connectionFactory.openConnection(session)); Handle handle = Jdbi.open(connection)) { return getTableDataCompression(handle, tableHandle) diff --git a/plugin/trino-sqlserver/src/test/java/io/trino/plugin/sqlserver/TestSqlServerIntegrationSmokeTest.java b/plugin/trino-sqlserver/src/test/java/io/trino/plugin/sqlserver/TestSqlServerIntegrationSmokeTest.java index 07fa26106b41..2b73067ff252 100644 --- a/plugin/trino-sqlserver/src/test/java/io/trino/plugin/sqlserver/TestSqlServerIntegrationSmokeTest.java +++ b/plugin/trino-sqlserver/src/test/java/io/trino/plugin/sqlserver/TestSqlServerIntegrationSmokeTest.java @@ -18,7 +18,6 @@ import io.trino.Session; import io.trino.sql.planner.plan.AggregationNode; import io.trino.sql.planner.plan.FilterNode; -import io.trino.sql.planner.plan.ProjectNode; import io.trino.testing.AbstractTestIntegrationSmokeTest; import io.trino.testing.QueryRunner; import io.trino.testing.sql.TestTable; @@ -114,7 +113,7 @@ public void testAggregationPushdown() "SELECT regionkey, sum(nationkey) " + "FROM (SELECT * FROM nation WHERE regionkey < 3 LIMIT 11) " + "GROUP BY regionkey")) - .isNotFullyPushedDown(AggregationNode.class); + .isFullyPushedDown(); // decimals try (AutoCloseable ignoreTable = withTable("test_aggregation_pushdown", "(short_decimal decimal(9, 3), long_decimal decimal(30, 10), varchar_column varchar(10))")) { @@ -148,9 +147,9 @@ public void testAggregationPushdown() assertThat(query("SELECT min(varchar_column) FROM test_aggregation_pushdown WHERE varchar_column ='ala'")).isFullyPushedDown(); // not supported yet - assertThat(query("SELECT min(DISTINCT short_decimal) FROM test_aggregation_pushdown")).isNotFullyPushedDown(AggregationNode.class); + assertThat(query("SELECT min(DISTINCT short_decimal) FROM test_aggregation_pushdown")).isFullyPushedDown(); assertThat(query("SELECT DISTINCT short_decimal, min(long_decimal) FROM test_aggregation_pushdown GROUP BY short_decimal")) - .isNotFullyPushedDown(AggregationNode.class, ProjectNode.class); + .isFullyPushedDown(); } // array_agg returns array, which is not supported @@ -316,7 +315,7 @@ public void testPredicatePushdown() // predicate over aggregation result assertThat(query("SELECT regionkey, sum(nationkey) FROM nation GROUP BY regionkey HAVING sum(nationkey) = 77")) .matches("VALUES (BIGINT '3', BIGINT '77')") - .isNotFullyPushedDown(FilterNode.class); + .isFullyPushedDown(); // decimals try (AutoCloseable ignoreTable = withTable("test_decimal_pushdown",