diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/OptimizerConfig.java b/core/trino-main/src/main/java/io/trino/sql/planner/OptimizerConfig.java index 8a679d88c3d3..dc0e12b2c200 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/OptimizerConfig.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/OptimizerConfig.java @@ -41,7 +41,7 @@ public class OptimizerConfig private int maxReorderedJoins = 9; private boolean enableStatsCalculator = true; - private boolean statisticsPrecalculationForPushdownEnabled; + private boolean statisticsPrecalculationForPushdownEnabled = true; private boolean collectPlanStatisticsForAllQueries; private boolean ignoreStatsCalculatorFailures = true; private boolean defaultFilterFactorEnabled; diff --git a/core/trino-main/src/test/java/io/trino/cost/TestOptimizerConfig.java b/core/trino-main/src/test/java/io/trino/cost/TestOptimizerConfig.java index 54ab07e1b21c..fcefbb35d214 100644 --- a/core/trino-main/src/test/java/io/trino/cost/TestOptimizerConfig.java +++ b/core/trino-main/src/test/java/io/trino/cost/TestOptimizerConfig.java @@ -53,7 +53,7 @@ public void testDefaults() .setUsePreferredWritePartitioning(true) .setPreferredWritePartitioningMinNumberOfPartitions(50) .setEnableStatsCalculator(true) - .setStatisticsPrecalculationForPushdownEnabled(false) + .setStatisticsPrecalculationForPushdownEnabled(true) .setCollectPlanStatisticsForAllQueries(false) .setIgnoreStatsCalculatorFailures(true) .setDefaultFilterFactorEnabled(false) @@ -96,7 +96,7 @@ public void testExplicitPropertyMappings() .put("memory-cost-weight", "0.3") .put("network-cost-weight", "0.2") .put("enable-stats-calculator", "false") - .put("statistics-precalculation-for-pushdown.enabled", "true") + .put("statistics-precalculation-for-pushdown.enabled", "false") .put("collect-plan-statistics-for-all-queries", "true") .put("optimizer.ignore-stats-calculator-failures", "false") .put("optimizer.default-filter-factor-enabled", "true") @@ -146,7 
+146,7 @@ public void testExplicitPropertyMappings() .setMemoryCostWeight(0.3) .setNetworkCostWeight(0.2) .setEnableStatsCalculator(false) - .setStatisticsPrecalculationForPushdownEnabled(true) + .setStatisticsPrecalculationForPushdownEnabled(false) .setCollectPlanStatisticsForAllQueries(true) .setIgnoreStatsCalculatorFailures(false) .setJoinDistributionType(BROADCAST) diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/DefaultJdbcMetadata.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/DefaultJdbcMetadata.java index 353b56ac51ab..59abd959ffdf 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/DefaultJdbcMetadata.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/DefaultJdbcMetadata.java @@ -94,12 +94,14 @@ public class DefaultJdbcMetadata private static final String SYNTHETIC_COLUMN_NAME_PREFIX = "_pfgnrtd_"; private final JdbcClient jdbcClient; + private final boolean precalculateStatisticsForPushdown; private final AtomicReference rollbackAction = new AtomicReference<>(); - public DefaultJdbcMetadata(JdbcClient jdbcClient) + public DefaultJdbcMetadata(JdbcClient jdbcClient, boolean precalculateStatisticsForPushdown) { this.jdbcClient = requireNonNull(jdbcClient, "jdbcClient is null"); + this.precalculateStatisticsForPushdown = precalculateStatisticsForPushdown; } @Override @@ -210,8 +212,8 @@ public Optional> applyFilter(C return Optional.of( remainingExpression.isPresent() - ? new ConstraintApplicationResult<>(handle, remainingFilter, remainingExpression.get(), false) - : new ConstraintApplicationResult<>(handle, remainingFilter, false)); + ? 
new ConstraintApplicationResult<>(handle, remainingFilter, remainingExpression.get(), precalculateStatisticsForPushdown) + : new ConstraintApplicationResult<>(handle, remainingFilter, precalculateStatisticsForPushdown)); } private JdbcTableHandle flushAttributesAsQuery(ConnectorSession session, JdbcTableHandle handle) @@ -270,7 +272,7 @@ public Optional> applyProjecti assignment.getValue(), ((JdbcColumnHandle) assignment.getValue()).getColumnType())) .collect(toImmutableList()), - false)); + precalculateStatisticsForPushdown)); } @Override @@ -364,7 +366,7 @@ public Optional> applyAggrega handle.getAllReferencedTables(), nextSyntheticColumnId); - return Optional.of(new AggregationApplicationResult<>(handle, projections.build(), resultAssignments.build(), ImmutableMap.of(), false)); + return Optional.of(new AggregationApplicationResult<>(handle, projections.build(), resultAssignments.build(), ImmutableMap.of(), precalculateStatisticsForPushdown)); } @Override @@ -449,7 +451,7 @@ public Optional> applyJoin( nextSyntheticColumnId), ImmutableMap.copyOf(newLeftColumns), ImmutableMap.copyOf(newRightColumns), - false)); + precalculateStatisticsForPushdown)); } private static Optional getVariableColumnHandle(Map assignments, ConnectorExpression expression) @@ -505,7 +507,7 @@ public Optional> applyLimit(Connect handle.getOtherReferencedTables(), handle.getNextSyntheticColumnId()); - return Optional.of(new LimitApplicationResult<>(handle, jdbcClient.isLimitGuaranteed(session), false)); + return Optional.of(new LimitApplicationResult<>(handle, jdbcClient.isLimitGuaranteed(session), precalculateStatisticsForPushdown)); } @Override @@ -554,7 +556,7 @@ public Optional> applyTopN( handle.getOtherReferencedTables(), handle.getNextSyntheticColumnId()); - return Optional.of(new TopNApplicationResult<>(sortedTableHandle, jdbcClient.isTopNGuaranteed(session), false)); + return Optional.of(new TopNApplicationResult<>(sortedTableHandle, jdbcClient.isTopNGuaranteed(session), 
precalculateStatisticsForPushdown)); } @Override diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/DefaultJdbcMetadataFactory.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/DefaultJdbcMetadataFactory.java index f8905be684ce..3f7913f71087 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/DefaultJdbcMetadataFactory.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/DefaultJdbcMetadataFactory.java @@ -49,6 +49,6 @@ public JdbcMetadata create(JdbcTransactionHandle transaction) protected JdbcMetadata create(JdbcClient transactionCachingJdbcClient) { - return new DefaultJdbcMetadata(transactionCachingJdbcClient); + return new DefaultJdbcMetadata(transactionCachingJdbcClient, true); } } diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcJoinPushdownConfig.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcJoinPushdownConfig.java new file mode 100644 index 000000000000..54e381d50714 --- /dev/null +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcJoinPushdownConfig.java @@ -0,0 +1,67 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.plugin.jdbc; + +import io.airlift.configuration.Config; +import io.airlift.configuration.ConfigDescription; +import io.airlift.units.DataSize; + +import java.util.Optional; + +public class JdbcJoinPushdownConfig +{ + private JoinPushdownStrategy joinPushdownStrategy = JoinPushdownStrategy.AUTOMATIC; + private Optional joinPushdownAutomaticMaxTableSize = Optional.empty(); + // Normally we would put 1.0 as a default value here to only allow joins which do not expand data to be pushed down. + // We use 1.25 to adjust for the fact that NDV estimations sometimes are off and joins which should be pushed down are not. + private double joinPushdownAutomaticMaxJoinToTablesRatio = 1.25; + + public JoinPushdownStrategy getJoinPushdownStrategy() + { + return joinPushdownStrategy; + } + + @Config("join-pushdown.strategy") + public JdbcJoinPushdownConfig setJoinPushdownStrategy(JoinPushdownStrategy joinPushdownStrategy) + { + this.joinPushdownStrategy = joinPushdownStrategy; + return this; + } + + public Optional getJoinPushdownAutomaticMaxTableSize() + { + return joinPushdownAutomaticMaxTableSize; + } + + @Config("experimental.join-pushdown.automatic.max-table-size") + @ConfigDescription("Maximum table size to be considered for join pushdown") + public JdbcJoinPushdownConfig setJoinPushdownAutomaticMaxTableSize(DataSize joinPushdownAutomaticMaxTableSize) + { + this.joinPushdownAutomaticMaxTableSize = Optional.ofNullable(joinPushdownAutomaticMaxTableSize); + return this; + } + + public double getJoinPushdownAutomaticMaxJoinToTablesRatio() + { + return joinPushdownAutomaticMaxJoinToTablesRatio; + } + + @Config("experimental.join-pushdown.automatic.max-join-to-tables-ratio") + @ConfigDescription("If estimated join output size is greater than or equal to ratio * sum of table sizes, then join pushdown will not be performed") + public JdbcJoinPushdownConfig setJoinPushdownAutomaticMaxJoinToTablesRatio(double joinPushdownAutomaticMaxJoinToTablesRatio) + { + 
this.joinPushdownAutomaticMaxJoinToTablesRatio = joinPushdownAutomaticMaxJoinToTablesRatio; + return this; + } +} diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcJoinPushdownSessionProperties.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcJoinPushdownSessionProperties.java new file mode 100644 index 000000000000..631d6c1d6ec0 --- /dev/null +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcJoinPushdownSessionProperties.java @@ -0,0 +1,83 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.plugin.jdbc; + +import com.google.common.collect.ImmutableList; +import io.airlift.units.DataSize; +import io.trino.plugin.base.session.SessionPropertiesProvider; +import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.session.PropertyMetadata; + +import javax.inject.Inject; + +import java.util.List; +import java.util.Optional; + +import static io.trino.plugin.base.session.PropertyMetadataUtil.dataSizeProperty; +import static io.trino.spi.session.PropertyMetadata.doubleProperty; +import static io.trino.spi.session.PropertyMetadata.enumProperty; + +public final class JdbcJoinPushdownSessionProperties + implements SessionPropertiesProvider +{ + public static final String JOIN_PUSHDOWN_STRATEGY = "join_pushdown_strategy"; + public static final String JOIN_PUSHDOWN_AUTOMATIC_MAX_TABLE_SIZE = "join_pushdown_automatic_max_table_size"; + public static final String JOIN_PUSHDOWN_AUTOMATIC_MAX_JOIN_TO_TABLES_RATIO = "join_pushdown_automatic_max_join_to_tables_ratio"; + + private final List> sessionProperties; + + @Inject + public JdbcJoinPushdownSessionProperties(JdbcJoinPushdownConfig joinPushdownConfig) + { + sessionProperties = ImmutableList.>builder() + .add(enumProperty( + JOIN_PUSHDOWN_STRATEGY, + "Join pushdown strategy", + JoinPushdownStrategy.class, + joinPushdownConfig.getJoinPushdownStrategy(), + false)) + .add(doubleProperty( + JOIN_PUSHDOWN_AUTOMATIC_MAX_JOIN_TO_TABLES_RATIO, + "If estimated join output size is greater than or equal to ratio * sum of table sizes, then join pushdown will not be performed", + joinPushdownConfig.getJoinPushdownAutomaticMaxJoinToTablesRatio(), + false)) + .add(dataSizeProperty( + JOIN_PUSHDOWN_AUTOMATIC_MAX_TABLE_SIZE, + "Maximum table size to be considered for join pushdown", + joinPushdownConfig.getJoinPushdownAutomaticMaxTableSize().orElse(null), + false)) + .build(); + } + + @Override + public List> getSessionProperties() + { + return sessionProperties; + } + + public static 
JoinPushdownStrategy getJoinPushdownStrategy(ConnectorSession session) + { + return session.getProperty(JOIN_PUSHDOWN_STRATEGY, JoinPushdownStrategy.class); + } + + public static Optional getJoinPushdownAutomaticMaxTableSize(ConnectorSession session) + { + return Optional.ofNullable(session.getProperty(JOIN_PUSHDOWN_AUTOMATIC_MAX_TABLE_SIZE, DataSize.class)); + } + + public static double getJoinPushdownAutomaticJoinToTablesRatio(ConnectorSession session) + { + return session.getProperty(JOIN_PUSHDOWN_AUTOMATIC_MAX_JOIN_TO_TABLES_RATIO, Double.class); + } +} diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcJoinPushdownSupportModule.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcJoinPushdownSupportModule.java new file mode 100644 index 000000000000..cc0693352473 --- /dev/null +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcJoinPushdownSupportModule.java @@ -0,0 +1,37 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.jdbc; + +import com.google.inject.Binder; +import io.airlift.configuration.AbstractConfigurationAwareModule; + +import static io.airlift.configuration.ConfigBinder.configBinder; +import static io.trino.plugin.jdbc.JdbcModule.bindSessionPropertiesProvider; + +/** + * A helper module for implementing cost-aware Join pushdown. It remains + * {@link io.trino.plugin.jdbc.JdbcClient}'s responsibility to provide cost-aware pushdown logic. 
+ */ +public class JdbcJoinPushdownSupportModule + extends AbstractConfigurationAwareModule +{ + @Override + protected void setup(Binder binder) + { + configBinder(binder).bindConfig(JdbcJoinPushdownConfig.class); + bindSessionPropertiesProvider(binder, JdbcJoinPushdownSessionProperties.class); + + configBinder(binder).bindConfigDefaults(JdbcMetadataConfig.class, config -> config.setJoinPushdownEnabled(true)); + } +} diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcJoinPushdownUtil.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcJoinPushdownUtil.java new file mode 100644 index 000000000000..87deefe3f377 --- /dev/null +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcJoinPushdownUtil.java @@ -0,0 +1,144 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.plugin.jdbc; + +import io.airlift.log.Logger; +import io.airlift.units.DataSize; +import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.connector.JoinStatistics; +import io.trino.spi.connector.JoinType; + +import java.util.Optional; +import java.util.function.Supplier; + +import static io.trino.plugin.jdbc.JdbcJoinPushdownSessionProperties.getJoinPushdownAutomaticJoinToTablesRatio; +import static io.trino.plugin.jdbc.JdbcJoinPushdownSessionProperties.getJoinPushdownAutomaticMaxTableSize; +import static io.trino.plugin.jdbc.JdbcJoinPushdownSessionProperties.getJoinPushdownStrategy; +import static java.lang.String.format; + +public final class JdbcJoinPushdownUtil +{ + private static final Logger LOG = Logger.get(JdbcJoinPushdownUtil.class); + + private JdbcJoinPushdownUtil() {} + + public static Optional implementJoinCostAware( + ConnectorSession session, + JoinType joinType, + PreparedQuery leftSource, + PreparedQuery rightSource, + JoinStatistics statistics, + Supplier> delegate) + { + // Calling out to the delegate (typically super.implementJoin()) before shouldPushDownJoinCostAware() so we can return quickly if we know we should not push down, even without + // analyzing table statistics. Getting table statistics can be expensive and we want to avoid that if possible. 
+ Optional result = delegate.get(); + if (result.isEmpty()) { + return Optional.empty(); + } + + JoinPushdownStrategy joinPushdownStrategy = getJoinPushdownStrategy(session); + switch (joinPushdownStrategy) { + case EAGER: + return result; + + case AUTOMATIC: + if (shouldPushDownJoinCostAware(session, joinType, leftSource, rightSource, statistics)) { + return result; + } + return Optional.empty(); + } + throw new IllegalArgumentException("Unsupported joinPushdownStrategy: " + joinPushdownStrategy); + } + + /** + * Common implementation of AUTOMATIC join pushdown strategy to be used in SEP Jdbc connectors + */ + public static boolean shouldPushDownJoinCostAware( + ConnectorSession session, + JoinType joinType, + PreparedQuery leftSource, + PreparedQuery rightSource, + JoinStatistics statistics) + { + long maxTableSizeBytes = getJoinPushdownAutomaticMaxTableSize(session).map(DataSize::toBytes).orElse(Long.MAX_VALUE); + + String joinSignature = ""; + if (LOG.isDebugEnabled()) { + joinSignature = diagnosticsJoinSignature(joinType, leftSource, rightSource); + } + + if (statistics.getLeftStatistics().isEmpty()) { + logNoPushdown(joinSignature, "left stats empty"); + return false; + } + + double leftDataSize = statistics.getLeftStatistics().get().getDataSize(); + if (leftDataSize > maxTableSizeBytes) { + logNoPushdown(joinSignature, () -> "left size " + leftDataSize + " > " + maxTableSizeBytes); + return false; + } + + if (statistics.getRightStatistics().isEmpty()) { + logNoPushdown(joinSignature, "right stats empty"); + return false; + } + + double rightDataSize = statistics.getRightStatistics().get().getDataSize(); + if (rightDataSize > maxTableSizeBytes) { + logNoPushdown(joinSignature, () -> "right size " + rightDataSize + " > " + maxTableSizeBytes); + return false; + } + + if (statistics.getJoinStatistics().isEmpty()) { + logNoPushdown(joinSignature, "join stats empty"); + return false; + } + + double joinDataSize = statistics.getJoinStatistics().get().getDataSize(); 
+ if (joinDataSize < getJoinPushdownAutomaticJoinToTablesRatio(session) * (leftDataSize + rightDataSize)) { + // This is a poor man's estimate of whether it makes more sense to perform the join in the source database or in SEP. + // The assumption here is that cost of performing join in source database is less than or equal to cost of join in SEP. + // We resolve tie for pessimistic case (both join costs equal) on cost of sending the data from source database to SEP. + LOG.debug("triggering join pushdown for %s", joinSignature); + return true; + } + logNoPushdown(joinSignature, () -> + "joinDataSize " + joinDataSize + " >= " + + getJoinPushdownAutomaticJoinToTablesRatio(session) + + " * (leftDataSize " + leftDataSize + + " + rightDataSize " + rightDataSize + ") = " + getJoinPushdownAutomaticJoinToTablesRatio(session) * (leftDataSize + rightDataSize)); + return false; + } + + private static String diagnosticsJoinSignature( + JoinType joinType, + PreparedQuery leftSource, + PreparedQuery rightSource) + { + return format("%s JOIN(%s; %s)", joinType, leftSource.getQuery(), rightSource.getQuery()); + } + + private static void logNoPushdown(String joinSignature, String reason) + { + logNoPushdown(joinSignature, () -> reason); + } + + private static void logNoPushdown(String joinSignature, Supplier reason) + { + if (LOG.isDebugEnabled()) { + LOG.debug("skipping join pushdown for %s; %s", joinSignature, reason.get()); + } + } +} diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcStatisticsConfig.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcStatisticsConfig.java new file mode 100644 index 000000000000..70176b591cfb --- /dev/null +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcStatisticsConfig.java @@ -0,0 +1,33 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.jdbc; + +import io.airlift.configuration.Config; + +public class JdbcStatisticsConfig +{ + private boolean enabled = true; + + public boolean isEnabled() + { + return enabled; + } + + @Config("statistics.enabled") + public JdbcStatisticsConfig setEnabled(boolean enabled) + { + this.enabled = enabled; + return this; + } +} diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JoinPushdownStrategy.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JoinPushdownStrategy.java new file mode 100644 index 000000000000..7c0c1e7e993d --- /dev/null +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JoinPushdownStrategy.java @@ -0,0 +1,27 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.jdbc; + +public enum JoinPushdownStrategy +{ + /** + * Try to push all joins except cross-joins to connector. + */ + EAGER, + /** + * Automatically determine whether to push the join to the connector based on table statistics. + * Do not perform join in absence of table statistics. 
+ */ + AUTOMATIC, +} diff --git a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseJdbcTableStatisticsTest.java b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseJdbcTableStatisticsTest.java new file mode 100644 index 000000000000..eeee51c6f31c --- /dev/null +++ b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseJdbcTableStatisticsTest.java @@ -0,0 +1,354 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.jdbc; + +import io.trino.Session; +import io.trino.SystemSessionProperties; +import io.trino.testing.AbstractTestQueryFramework; +import io.trino.testing.sql.TestTable; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.DataProvider; +import org.testng.annotations.Test; + +import java.util.Locale; + +import static com.google.common.base.Throwables.getStackTraceAsString; +import static io.trino.testing.sql.TestTable.randomTableSuffix; +import static java.lang.String.format; +import static org.assertj.core.api.Assertions.assertThat; + +public abstract class BaseJdbcTableStatisticsTest + extends AbstractTestQueryFramework +{ + // Currently this class serves as a common "interface" to define cases that should be covered. + // TODO extend it to provide reusable blocks to reduce boiler-plate. 
+ + @BeforeClass + public void setUpTables() + { + setUpTableFromTpch("region"); + setUpTableFromTpch("nation"); + } + + private void setUpTableFromTpch(String tableName) + { + // Create the table. Some subclasses use shared resources, so the table may actually exist. + computeActual("CREATE TABLE IF NOT EXISTS " + tableName + " AS TABLE tpch.tiny." + tableName); + // Sanity check on state of the table in case it already existed. + assertThat(query("SELECT count(*) FROM " + tableName)) + .matches("SELECT count(*) FROM tpch.tiny." + tableName); + + try { + gatherStats(tableName); + } + catch (Exception e) { + // gatherStats does not have to be idempotent, so we need to ignore certain errors + if (getStackTraceAsString(e).toLowerCase(Locale.ENGLISH).contains( + // wording comes from Synapse + "there are already statistics on table")) { + // ignore + } + else { + throw e; + } + } + } + + @Test + public abstract void testNotAnalyzed(); + + @Test + public abstract void testBasic(); + + @Test + public void testEmptyTable() + { + String tableName = "test_stats_table_empty_" + randomTableSuffix(); + computeActual(format("CREATE TABLE %s AS SELECT orderkey, custkey, orderpriority, comment FROM tpch.tiny.orders WHERE false", tableName)); + try { + gatherStats(tableName); + checkEmptyTableStats(tableName); + } + finally { + assertUpdate("DROP TABLE " + tableName); + } + } + + protected void checkEmptyTableStats(String tableName) + { + assertQuery( + "SHOW STATS FOR " + tableName, + "VALUES " + + "('orderkey', 0, 0, 1, null, null, null)," + + "('custkey', 0, 0, 1, null, null, null)," + + "('orderpriority', 0, 0, 1, null, null, null)," + + "('comment', 0, 0, 1, null, null, null)," + + "(null, null, null, null, 0, null, null)"); + } + + @Test + public abstract void testAllNulls(); + + @Test + public abstract void testNullsFraction(); + + @Test + public abstract void testAverageColumnLength(); + + @Test + public abstract void testPartitionedTable(); + + @Test + public abstract 
void testView(); + + @Test + public abstract void testMaterializedView(); + + @Test(dataProvider = "testCaseColumnNamesDataProvider") + public abstract void testCaseColumnNames(String tableName); + + @DataProvider + public Object[][] testCaseColumnNamesDataProvider() + { + return new Object[][] { + {"TEST_STATS_MIXED_UNQUOTED_UPPER"}, + {"test_stats_mixed_unquoted_lower"}, + {"test_stats_mixed_uNQuoTeD_miXED"}, + {"\"TEST_STATS_MIXED_QUOTED_UPPER\""}, + {"\"test_stats_mixed_quoted_lower\""}, + {"\"test_stats_mixed_QuoTeD_miXED\""}, + }; + } + + @Test + public abstract void testNumericCornerCases(); + + @Test + public void testStatsWithPredicatePushdown() + { + // Predicate on a numeric column. Should be eligible for pushdown. + String query = "SELECT * FROM nation WHERE regionkey = 1"; + + // Verify query can be pushed down, that's the situation we want to test for. + assertThat(query(query)).isFullyPushedDown(); + + assertThat(query("SHOW STATS FOR (" + query + ")")) + // Not testing average length and min/max, as this would make the test less reusable and is not that important to test. + .projected(0, 2, 3, 4) + .skippingTypesCheck() + .matches("VALUES " + + "('nationkey', 5e0, 0e0, null)," + + "('name', 5e0, 0e0, null)," + + "('regionkey', 1e0, 0e0, null)," + + "('comment', 5e0, 0e0, null)," + + "(null, null, null, 5e0)"); + } + + @Test + public void testStatsWithVarcharPredicatePushdown() + { + // Predicate on a varchar column. May or may not be pushed down, may or may not be subsumed. + assertThat(query("SHOW STATS FOR (SELECT * FROM nation WHERE name = 'PERU')")) + // Not testing average length and min/max, as this would make the test less reusable and is not that important to test. 
+ .projected(0, 2, 3, 4) + .skippingTypesCheck() + .matches("VALUES " + + "('nationkey', 1e0, 0e0, null)," + + "('name', 1e0, 0e0, null)," + + "('regionkey', 1e0, 0e0, null)," + + "('comment', 1e0, 0e0, null)," + + "(null, null, null, 1e0)"); + + try (TestTable table = new TestTable( + getQueryRunner()::execute, + "varchar_duplicates", + // each letter A-E repeated 5 times + " AS SELECT nationkey, chr(codepoint('A') + nationkey / 5) fl FROM tpch.tiny.nation")) { + gatherStats(table.getName()); + + assertThat(query("SHOW STATS FOR (SELECT * FROM " + table.getName() + " WHERE fl = 'B')")) + .projected(0, 2, 3, 4) + .skippingTypesCheck() + .matches("VALUES " + + "('nationkey', 5e0, 0e0, null)," + + "('fl', 1e0, 0e0, null)," + + "(null, null, null, 5e0)"); + } + } + + /** + * Verify that when {@value SystemSessionProperties#STATISTICS_PRECALCULATION_FOR_PUSHDOWN_ENABLED} is disabled, + * the connector still returns reasonable statistics. + */ + @Test + public void testStatsWithPredicatePushdownWithStatsPrecalculationDisabled() + { + // Predicate on a numeric column. Should be eligible for pushdown. + String query = "SELECT * FROM nation WHERE regionkey = 1"; + Session session = Session.builder(getSession()) + .setSystemProperty(SystemSessionProperties.STATISTICS_PRECALCULATION_FOR_PUSHDOWN_ENABLED, "false") + .build(); + + // Verify query can be pushed down, that's the situation we want to test for. + assertThat(query(session, query)).isFullyPushedDown(); + + assertThat(query(session, "SHOW STATS FOR (" + query + ")")) + // Not testing average length and min/max, as this would make the test less reusable and is not that important to test. 
+ .projected(0, 2, 3, 4) + .skippingTypesCheck() + .matches("VALUES " + + "('nationkey', 25e0, 0e0, null)," + + "('name', 25e0, 0e0, null)," + + "('regionkey', 5e0, 0e0, null)," + + "('comment', 25e0, 0e0, null)," + + "(null, null, null, 25e0)"); + } + + @Test + public void testStatsWithLimitPushdown() + { + // Just limit, should be eligible for pushdown. + String query = "SELECT regionkey, nationkey FROM nation LIMIT 2"; + + // Verify query can be pushed down, that's the situation we want to test for. + // it's important that we test with LIMIT value smaller than table row count, hence need to skip results check + assertThat(query(query)).skipResultsCorrectnessCheckForPushdown().isFullyPushedDown(); + + assertThat(query("SHOW STATS FOR (" + query + ")")) + // Not testing average length and min/max, as this would make the test less reusable and is not that important to test. + .projected(0, 2, 3, 4) + .skippingTypesCheck() + .matches("VALUES " + + "('regionkey', 2e0, 0e0, null)," + + "('nationkey', 2e0, 0e0, null)," + + "(null, null, null, 2e0)"); + } + + @Test + public void testStatsWithTopNPushdown() + { + // TopN on a numeric column, should be eligible for pushdown. + String query = "SELECT regionkey, nationkey FROM nation ORDER BY regionkey LIMIT 2"; + + // Verify query can be pushed down, that's the situation we want to test for. + // it's important that we test with LIMIT value smaller than table row count and we intentionally sort on a non-unique regionkey, hence need to skip results check. + assertThat(query(query)).skipResultsCorrectnessCheckForPushdown().isFullyPushedDown(); + + assertThat(query("SHOW STATS FOR (" + query + ")")) + // Not testing average length and min/max, as this would make the test less reusable and is not that important to test. 
+ .projected(0, 2, 3, 4) + .skippingTypesCheck() + .matches("VALUES " + + "('regionkey', 2e0, 0e0, null)," + + "('nationkey', 2e0, 0e0, null)," + + "(null, null, null, 2e0)"); + } + + @Test + public void testStatsWithDistinctPushdown() + { + // Just distinct, should be eligible for pushdown. + String query = "SELECT DISTINCT regionkey FROM nation"; + + // Verify query can be pushed down, that's the situation we want to test for. + assertThat(query(query)).isFullyPushedDown(); + + assertThat(query("SHOW STATS FOR (" + query + ")")) + // Not testing average length and min/max, as this would make the test less reusable and is not that important to test. + .projected(0, 2, 3, 4) + .skippingTypesCheck() + .matches("VALUES " + + "('regionkey', 5e0, 0e0, null)," + + "(null, null, null, 5e0)"); + } + + @Test + public void testStatsWithDistinctLimitPushdown() + { + // Distinct with limit (DistinctLimitNode), should be eligible for pushdown. + String query = "SELECT DISTINCT regionkey FROM nation LIMIT 3"; + + // Verify query can be pushed down, that's the situation we want to test for. + // it's important that we test with LIMIT value smaller than count(DISTINCT regionkey), hence need to skip results check + assertThat(query(query)).skipResultsCorrectnessCheckForPushdown().isFullyPushedDown(); + + assertThat(query("SHOW STATS FOR (" + query + ")")) + // Not testing average length and min/max, as this would make the test less reusable and is not that important to test. + .projected(0, 2, 3, 4) + .skippingTypesCheck() + .matches("VALUES " + + "('regionkey', 3e0, 0e0, null)," + + "(null, null, null, 3e0)"); + } + + @Test + public void testStatsWithAggregationPushdown() + { + // Simple aggregation, should be eligible for pushdown. + String query = "SELECT regionkey, max(nationkey) max_nationkey, count(*) c FROM nation GROUP BY regionkey"; + + // Verify query can be pushed down, that's the situation we want to test for. 
+ assertThat(query(query)).isFullyPushedDown(); + + assertThat(query("SHOW STATS FOR (" + query + ")")) + // Not testing average length and min/max, as this would make the test less reusable and is not that important to test. + .projected(0, 2, 3, 4) + .skippingTypesCheck() + .matches("VALUES " + + "('regionkey', 5e0, 0e0, null)," + + "('max_nationkey', null, null, null)," + + "('c', null, null, null)," + + "(null, null, null, 5e0)"); + } + + @Test + public void testStatsWithSimpleJoinPushdown() + { + // Simple filtering join with no predicates, should be eligible for pushdown. + String query = "SELECT n.name n_name FROM nation n JOIN region r ON n.nationkey = r.regionkey"; + + // Verify query can be pushed down, that's the situation we want to test for. + assertThat(query(query)).isFullyPushedDown(); + + assertThat(query("SHOW STATS FOR (" + query + ")")) + // Not testing average length and min/max, as this would make the test less reusable and is not that important to test. + .projected(0, 2, 3, 4) + .skippingTypesCheck() + .matches("VALUES " + + "('n_name', 5e0, 0e0, null)," + + "(null, null, null, 5e0)"); + } + + @Test + public void testStatsWithJoinPushdown() + { + // Simple join with heavily filtered side, should be eligible for pushdown. + String query = "SELECT r.regionkey regionkey, r.name r_name, n.name n_name FROM region r JOIN nation n ON r.regionkey = n.regionkey WHERE n.nationkey = 5"; + + // Verify query can be pushed down, that's the situation we want to test for. + assertThat(query(query)).isFullyPushedDown(); + + assertThat(query("SHOW STATS FOR (" + query + ")")) + // Not testing average length and min/max, as this would make the test less reusable and is not that important to test. 
+ .projected(0, 2, 3, 4) + .skippingTypesCheck() + .matches("VALUES " + + "('regionkey', 1e0, 0e0, null)," + + "('r_name', 1e0, 0e0, null)," + + "('n_name', 1e0, 0e0, null)," + + "(null, null, null, 1e0)"); + } + + protected abstract void gatherStats(String tableName); +} diff --git a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestDefaultJdbcMetadata.java b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestDefaultJdbcMetadata.java index 0aca2c89214b..0ad21f947292 100644 --- a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestDefaultJdbcMetadata.java +++ b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestDefaultJdbcMetadata.java @@ -67,7 +67,7 @@ public void setUp() throws Exception { database = new TestingDatabase(); - metadata = new DefaultJdbcMetadata(new GroupingSetsEnabledJdbcClient(database.getJdbcClient())); + metadata = new DefaultJdbcMetadata(new GroupingSetsEnabledJdbcClient(database.getJdbcClient()), false); tableHandle = metadata.getTableHandle(SESSION, new SchemaTableName("example", "numbers")); } diff --git a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestJdbcJoinPushdownConfig.java b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestJdbcJoinPushdownConfig.java new file mode 100644 index 000000000000..e51a320b88d6 --- /dev/null +++ b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestJdbcJoinPushdownConfig.java @@ -0,0 +1,54 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.jdbc; + +import com.google.common.collect.ImmutableMap; +import io.airlift.configuration.testing.ConfigAssertions; +import io.airlift.units.DataSize; +import org.testng.annotations.Test; + +import java.util.Map; + +import static io.airlift.configuration.testing.ConfigAssertions.assertFullMapping; +import static io.trino.plugin.jdbc.JoinPushdownStrategy.AUTOMATIC; +import static io.trino.plugin.jdbc.JoinPushdownStrategy.EAGER; + +public class TestJdbcJoinPushdownConfig +{ + @Test + public void testDefaults() + { + ConfigAssertions.assertRecordedDefaults(ConfigAssertions.recordDefaults(JdbcJoinPushdownConfig.class) + .setJoinPushdownStrategy(AUTOMATIC) + .setJoinPushdownAutomaticMaxTableSize(null) + .setJoinPushdownAutomaticMaxJoinToTablesRatio(1.25)); + } + + @Test + public void testExplicitPropertyMappings() + { + Map properties = ImmutableMap.builder() + .put("join-pushdown.strategy", "EAGER") + .put("experimental.join-pushdown.automatic.max-table-size", "10MB") + .put("experimental.join-pushdown.automatic.max-join-to-tables-ratio", "2.0") + .buildOrThrow(); + + JdbcJoinPushdownConfig expected = new JdbcJoinPushdownConfig() + .setJoinPushdownStrategy(EAGER) + .setJoinPushdownAutomaticMaxTableSize(DataSize.valueOf("10MB")) + .setJoinPushdownAutomaticMaxJoinToTablesRatio(2.0); + + assertFullMapping(properties, expected); + } +} diff --git a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestJdbcStatisticsConfig.java b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestJdbcStatisticsConfig.java new file mode 100644 index 000000000000..dd527695d605 --- /dev/null +++ b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestJdbcStatisticsConfig.java @@ -0,0 +1,46 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.jdbc; + +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.Test; + +import java.util.Map; + +import static io.airlift.configuration.testing.ConfigAssertions.assertFullMapping; +import static io.airlift.configuration.testing.ConfigAssertions.assertRecordedDefaults; +import static io.airlift.configuration.testing.ConfigAssertions.recordDefaults; + +public class TestJdbcStatisticsConfig +{ + @Test + public void testDefaults() + { + assertRecordedDefaults(recordDefaults(JdbcStatisticsConfig.class) + .setEnabled(true)); + } + + @Test + public void testExplicitPropertyMappings() + { + Map properties = ImmutableMap.builder() + .put("statistics.enabled", "false") + .buildOrThrow(); + + JdbcStatisticsConfig expected = new JdbcStatisticsConfig() + .setEnabled(false); + + assertFullMapping(properties, expected); + } +} diff --git a/plugin/trino-memory/src/test/java/io/trino/plugin/memory/TestMemoryTableStatistics.java b/plugin/trino-memory/src/test/java/io/trino/plugin/memory/TestMemoryTableStatistics.java index f8709914306f..d79b4e80d9f5 100644 --- a/plugin/trino-memory/src/test/java/io/trino/plugin/memory/TestMemoryTableStatistics.java +++ b/plugin/trino-memory/src/test/java/io/trino/plugin/memory/TestMemoryTableStatistics.java @@ -16,10 +16,12 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import io.trino.Session; +import io.trino.sql.planner.OptimizerConfig; import io.trino.testing.AbstractTestQueryFramework; import io.trino.testing.QueryRunner; 
import org.testng.annotations.Test; +import static com.google.common.base.Verify.verify; import static io.trino.SystemSessionProperties.STATISTICS_PRECALCULATION_FOR_PUSHDOWN_ENABLED; import static io.trino.plugin.memory.MemoryQueryRunner.createMemoryQueryRunner; import static io.trino.tpch.TpchTable.NATION; @@ -85,16 +87,16 @@ public void testStatsWithLimitPushdown() assertThat(query(query)).skipResultsCorrectnessCheckForPushdown().isFullyPushedDown(); assertQuery( + disableStatisticsPrecalculation(getSession()), "SHOW STATS FOR (" + query + ")", "VALUES " + "('nationkey', null, null, null, null, null, null)," + "('name', null, null, null, null, null, null)," + "('regionkey', null, null, null, null, null, null)," + "('comment', null, null, null, null, null, null)," + - "(null, null, null, null, 25, null, null)"); // TODO should be 3 + "(null, null, null, null, 25, null, null)"); assertQuery( - enableStatisticsPrecalculation(getSession()), "SHOW STATS FOR (" + query + ")", "VALUES " + "('nationkey', null, null, null, null, null, null)," + @@ -111,16 +113,16 @@ public void testStatsWithSamplePushdown() assertThat(query(query)).skipResultsCorrectnessCheckForPushdown().isFullyPushedDown(); assertQuery( + disableStatisticsPrecalculation(getSession()), "SHOW STATS FOR (" + query + ")", "VALUES " + "('nationkey', null, null, null, null, null, null)," + "('name', null, null, null, null, null, null)," + "('regionkey', null, null, null, null, null, null)," + "('comment', null, null, null, null, null, null)," + - "(null, null, null, null, 25, null, null)"); // TODO should be 12.5 + "(null, null, null, null, 25, null, null)"); assertQuery( - enableStatisticsPrecalculation(getSession()), "SHOW STATS FOR (" + query + ")", "VALUES " + "('nationkey', null, null, null, null, null, null)," + @@ -130,10 +132,11 @@ public void testStatsWithSamplePushdown() "(null, null, null, null, 12.5, null, null)"); } - private Session enableStatisticsPrecalculation(Session base) + private Session 
disableStatisticsPrecalculation(Session base) { + verify(new OptimizerConfig().isStatisticsPrecalculationForPushdownEnabled(), "this assumes precalculation is enabled by default"); return Session.builder(base) - .setSystemProperty(STATISTICS_PRECALCULATION_FOR_PUSHDOWN_ENABLED, "true") + .setSystemProperty(STATISTICS_PRECALCULATION_FOR_PUSHDOWN_ENABLED, "false") .build(); } } diff --git a/plugin/trino-phoenix/src/main/java/io/trino/plugin/phoenix/PhoenixMetadata.java b/plugin/trino-phoenix/src/main/java/io/trino/plugin/phoenix/PhoenixMetadata.java index 900cbf232406..0e40b90147d4 100644 --- a/plugin/trino-phoenix/src/main/java/io/trino/plugin/phoenix/PhoenixMetadata.java +++ b/plugin/trino-phoenix/src/main/java/io/trino/plugin/phoenix/PhoenixMetadata.java @@ -67,7 +67,7 @@ public class PhoenixMetadata @Inject public PhoenixMetadata(PhoenixClient phoenixClient, IdentifierMapping identifierMapping) { - super(phoenixClient); + super(phoenixClient, false); this.phoenixClient = requireNonNull(phoenixClient, "phoenixClient is null"); this.identifierMapping = requireNonNull(identifierMapping, "identifierMapping is null"); } diff --git a/plugin/trino-phoenix5/src/main/java/io/trino/plugin/phoenix5/PhoenixMetadata.java b/plugin/trino-phoenix5/src/main/java/io/trino/plugin/phoenix5/PhoenixMetadata.java index 618403de4111..cb340f739a40 100644 --- a/plugin/trino-phoenix5/src/main/java/io/trino/plugin/phoenix5/PhoenixMetadata.java +++ b/plugin/trino-phoenix5/src/main/java/io/trino/plugin/phoenix5/PhoenixMetadata.java @@ -72,7 +72,7 @@ public class PhoenixMetadata @Inject public PhoenixMetadata(PhoenixClient phoenixClient, IdentifierMapping identifierMapping) { - super(phoenixClient); + super(phoenixClient, false); this.phoenixClient = requireNonNull(phoenixClient, "phoenixClient is null"); this.identifierMapping = requireNonNull(identifierMapping, "identifierMapping is null"); } diff --git a/plugin/trino-postgresql/pom.xml b/plugin/trino-postgresql/pom.xml index 
d15016581cb9..1d9e6d211e96 100644 --- a/plugin/trino-postgresql/pom.xml +++ b/plugin/trino-postgresql/pom.xml @@ -73,6 +73,11 @@ joda-time + + org.jdbi + jdbi3-core + + org.postgresql postgresql diff --git a/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java b/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java index 81f9bc43e798..79e33f2b59c9 100644 --- a/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java +++ b/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java @@ -32,6 +32,7 @@ import io.trino.plugin.jdbc.JdbcExpression; import io.trino.plugin.jdbc.JdbcJoinCondition; import io.trino.plugin.jdbc.JdbcSortItem; +import io.trino.plugin.jdbc.JdbcStatisticsConfig; import io.trino.plugin.jdbc.JdbcTableHandle; import io.trino.plugin.jdbc.JdbcTypeHandle; import io.trino.plugin.jdbc.LongReadFunction; @@ -75,10 +76,16 @@ import io.trino.spi.connector.ConnectorSession; import io.trino.spi.connector.ConnectorTableMetadata; import io.trino.spi.connector.JoinCondition; +import io.trino.spi.connector.JoinStatistics; +import io.trino.spi.connector.JoinType; import io.trino.spi.connector.SchemaTableName; import io.trino.spi.connector.TableNotFoundException; import io.trino.spi.expression.ConnectorExpression; import io.trino.spi.predicate.Domain; +import io.trino.spi.predicate.TupleDomain; +import io.trino.spi.statistics.ColumnStatistics; +import io.trino.spi.statistics.Estimate; +import io.trino.spi.statistics.TableStatistics; import io.trino.spi.type.ArrayType; import io.trino.spi.type.CharType; import io.trino.spi.type.DecimalType; @@ -94,6 +101,8 @@ import io.trino.spi.type.TypeManager; import io.trino.spi.type.TypeSignature; import io.trino.spi.type.VarcharType; +import org.jdbi.v3.core.Handle; +import org.jdbi.v3.core.Jdbi; import org.postgresql.core.TypeInfo; import org.postgresql.jdbc.PgConnection; @@ -125,7 +134,9 @@ import 
java.util.stream.Stream; import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Throwables.throwIfInstanceOf; import static com.google.common.base.Verify.verify; +import static com.google.common.collect.ImmutableMap.toImmutableMap; import static io.airlift.slice.Slices.utf8Slice; import static io.trino.plugin.base.util.JsonTypeUtil.jsonParse; import static io.trino.plugin.base.util.JsonTypeUtil.toJsonValue; @@ -134,6 +145,7 @@ import static io.trino.plugin.jdbc.DecimalSessionSessionProperties.getDecimalRounding; import static io.trino.plugin.jdbc.DecimalSessionSessionProperties.getDecimalRoundingMode; import static io.trino.plugin.jdbc.JdbcErrorCode.JDBC_ERROR; +import static io.trino.plugin.jdbc.JdbcJoinPushdownUtil.implementJoinCostAware; import static io.trino.plugin.jdbc.JdbcMetadataSessionProperties.getDomainCompactionThreshold; import static io.trino.plugin.jdbc.PredicatePushdownController.DISABLE_PUSHDOWN; import static io.trino.plugin.jdbc.PredicatePushdownController.FULL_PUSHDOWN; @@ -216,6 +228,7 @@ import static java.math.RoundingMode.UNNECESSARY; import static java.sql.DatabaseMetaData.columnNoNulls; import static java.util.Objects.requireNonNull; +import static java.util.function.Function.identity; import static java.util.stream.Collectors.joining; public class PostgreSqlClient @@ -238,6 +251,7 @@ public class PostgreSqlClient private final Type uuidType; private final MapType varcharMapType; private final List tableTypes; + private final boolean statisticsEnabled; private final ConnectorExpressionRewriter connectorExpressionRewriter; private final AggregateFunctionRewriter aggregateFunctionRewriter; @@ -263,6 +277,7 @@ public class PostgreSqlClient public PostgreSqlClient( BaseJdbcConfig config, PostgreSqlConfig postgreSqlConfig, + JdbcStatisticsConfig statisticsConfig, ConnectionFactory connectionFactory, QueryBuilder queryBuilder, TypeManager typeManager, @@ -280,6 +295,8 @@ public PostgreSqlClient( } 
this.tableTypes = tableTypes.build(); + this.statisticsEnabled = requireNonNull(statisticsConfig, "statisticsConfig is null").isEnabled(); + this.connectorExpressionRewriter = JdbcConnectorExpressionRewriterBuilder.newBuilder() .addStandardRules(this::quoted) // TODO allow all comparison operators for numeric types @@ -828,6 +845,129 @@ public OptionalLong delete(ConnectorSession session, JdbcTableHandle handle) } } + @Override + public TableStatistics getTableStatistics(ConnectorSession session, JdbcTableHandle handle, TupleDomain tupleDomain) + { + if (!statisticsEnabled) { + return TableStatistics.empty(); + } + if (!handle.isNamedRelation()) { + return TableStatistics.empty(); + } + try { + return readTableStatistics(session, handle); + } + catch (SQLException | RuntimeException e) { + throwIfInstanceOf(e, TrinoException.class); + throw new TrinoException(JDBC_ERROR, "Failed fetching statistics for table: " + handle, e); + } + } + + private TableStatistics readTableStatistics(ConnectorSession session, JdbcTableHandle table) + throws SQLException + { + checkArgument(table.isNamedRelation(), "Relation is not a table: %s", table); + + try (Connection connection = connectionFactory.openConnection(session); + Handle handle = Jdbi.open(connection)) { + StatisticsDao statisticsDao = new StatisticsDao(handle); + + Optional optionalRowCount = readRowCountTableStat(statisticsDao, table); + if (optionalRowCount.isEmpty()) { + // Table not found + return TableStatistics.empty(); + } + long rowCount = optionalRowCount.get(); + + TableStatistics.Builder tableStatistics = TableStatistics.builder(); + tableStatistics.setRowCount(Estimate.of(rowCount)); + + if (rowCount == 0) { + return tableStatistics.build(); + } + + Map columnStatistics = statisticsDao.getColumnStatistics(table.getSchemaName(), table.getTableName()).stream() + .collect(toImmutableMap(ColumnStatisticsResult::getColumnName, identity())); + + for (JdbcColumnHandle column : this.getColumns(session, table)) { + 
ColumnStatisticsResult result = columnStatistics.get(column.getColumnName()); + if (result == null) { + continue; + } + + ColumnStatistics statistics = ColumnStatistics.builder() + .setNullsFraction(result.getNullsFraction() + .map(Estimate::of) + .orElseGet(Estimate::unknown)) + .setDistinctValuesCount(result.getDistinctValuesIndicator() + .map(distinctValuesIndicator -> { + if (distinctValuesIndicator >= 0.0) { + return distinctValuesIndicator; + } + return -distinctValuesIndicator * rowCount; + }) + .map(Estimate::of) + .orElseGet(Estimate::unknown)) + .setDataSize(result.getAverageColumnLength() + .flatMap(averageColumnLength -> + result.getNullsFraction().map(nullsFraction -> + Estimate.of(1.0 * averageColumnLength * rowCount * (1 - nullsFraction)))) + .orElseGet(Estimate::unknown)) + .build(); + + tableStatistics.setColumnStatistics(column, statistics); + } + + return tableStatistics.build(); + } + } + + private static Optional readRowCountTableStat(StatisticsDao statisticsDao, JdbcTableHandle table) + { + Optional rowCount = statisticsDao.getRowCountFromPgClass(table.getSchemaName(), table.getTableName()); + if (rowCount.isEmpty()) { + // Table not found + return Optional.empty(); + } + + if (statisticsDao.isPartitionedTable(table.getSchemaName(), table.getTableName())) { + Optional partitionedTableRowCount = statisticsDao.getRowCountPartitionedTableFromPgClass(table.getSchemaName(), table.getTableName()); + if (partitionedTableRowCount.isPresent()) { + return partitionedTableRowCount; + } + + return statisticsDao.getRowCountPartitionedTableFromPgStats(table.getSchemaName(), table.getTableName()); + } + + if (rowCount.get() == 0) { + // `pg_class.reltuples = 0` may mean an empty table or a recently populated table (CTAS, LOAD or INSERT) + // `pg_stat_all_tables.n_live_tup` can be way off, so we use it only as a fallback + rowCount = statisticsDao.getRowCountFromPgStat(table.getSchemaName(), table.getTableName()); + } + + return rowCount; + } + + @Override + 
public Optional implementJoin( + ConnectorSession session, + JoinType joinType, + PreparedQuery leftSource, + PreparedQuery rightSource, + List joinConditions, + Map rightAssignments, + Map leftAssignments, + JoinStatistics statistics) + { + return implementJoinCostAware( + session, + joinType, + leftSource, + rightSource, + statistics, + () -> super.implementJoin(session, joinType, leftSource, rightSource, joinConditions, rightAssignments, leftAssignments, statistics)); + } + @Override protected boolean isSupportedJoinCondition(ConnectorSession session, JdbcJoinCondition joinCondition) { @@ -1266,4 +1406,131 @@ private ColumnMapping uuidColumnMapping() (resultSet, columnIndex) -> javaUuidToTrinoUuid((UUID) resultSet.getObject(columnIndex)), uuidWriteFunction()); } + + private static class StatisticsDao + { + private final Handle handle; + + public StatisticsDao(Handle handle) + { + this.handle = requireNonNull(handle, "handle is null"); + } + + Optional getRowCountFromPgClass(String schema, String tableName) + { + return handle.createQuery("" + + "SELECT reltuples " + + "FROM pg_class " + + "WHERE relnamespace = (SELECT oid FROM pg_namespace WHERE nspname = :schema) " + + "AND relname = :table_name") + .bind("schema", schema) + .bind("table_name", tableName) + .mapTo(Long.class) + .findOne(); + } + + Optional getRowCountFromPgStat(String schema, String tableName) + { + return handle.createQuery("SELECT n_live_tup FROM pg_stat_all_tables WHERE schemaname = :schema AND relname = :table_name") + .bind("schema", schema) + .bind("table_name", tableName) + .mapTo(Long.class) + .findOne(); + } + + Optional getRowCountPartitionedTableFromPgClass(String schema, String tableName) + { + return handle.createQuery("" + + "SELECT SUM(child.reltuples) " + + "FROM pg_inherits " + + "JOIN pg_class parent ON pg_inherits.inhparent = parent.oid " + + "JOIN pg_class child ON pg_inherits.inhrelid = child.oid " + + "JOIN pg_namespace parent_ns ON parent_ns.oid = parent.relnamespace " + 
+ "JOIN pg_namespace child_ns ON child_ns.oid = child.relnamespace " + + "WHERE parent.oid = :schema_table_name::regclass") + .bind("schema_table_name", format("%s.%s", schema, tableName)) + .mapTo(Long.class) + .findOne(); + } + + Optional getRowCountPartitionedTableFromPgStats(String schema, String tableName) + { + return handle.createQuery("" + + "SELECT SUM(stat.n_live_tup) " + + "FROM pg_inherits " + + "JOIN pg_class parent ON pg_inherits.inhparent = parent.oid " + + "JOIN pg_class child ON pg_inherits.inhrelid = child.oid " + + "JOIN pg_namespace parent_ns ON parent_ns.oid = parent.relnamespace " + + "JOIN pg_namespace child_ns ON child_ns.oid = child.relnamespace " + + "JOIN pg_stat_all_tables stat ON stat.schemaname = child_ns.nspname AND stat.relname = child.relname " + + "WHERE parent.oid = :schema_table_name::regclass") + .bind("schema_table_name", format("%s.%s", schema, tableName)) + .mapTo(Long.class) + .findOne(); + } + + List getColumnStatistics(String schema, String tableName) + { + return handle.createQuery("SELECT attname, null_frac, n_distinct, avg_width FROM pg_stats WHERE schemaname = :schema AND tablename = :table_name") + .bind("schema", schema) + .bind("table_name", tableName) + .map((rs, ctx) -> new ColumnStatisticsResult( + requireNonNull(rs.getString("attname"), "attname is null"), + Optional.ofNullable(rs.getObject("null_frac", Float.class)), + Optional.ofNullable(rs.getObject("n_distinct", Float.class)), + Optional.ofNullable(rs.getObject("avg_width", Integer.class)))) + .list(); + } + + boolean isPartitionedTable(String schema, String tableName) + { + return handle.createQuery("" + + "SELECT true " + + "FROM pg_class " + + "WHERE relnamespace = (SELECT oid FROM pg_namespace WHERE nspname = :schema) " + + "AND relname = :table_name " + + "AND relkind = 'p'") + .bind("schema", schema) + .bind("table_name", tableName) + .mapTo(Boolean.class) + .findOne() + .orElse(false); + } + } + + private static class ColumnStatisticsResult + { + 
private final String columnName; + private final Optional nullsFraction; + private final Optional distinctValuesIndicator; + private final Optional averageColumnLength; + + public ColumnStatisticsResult(String columnName, Optional nullsFraction, Optional distinctValuesIndicator, Optional averageColumnLength) + { + this.columnName = columnName; + this.nullsFraction = nullsFraction; + this.distinctValuesIndicator = distinctValuesIndicator; + this.averageColumnLength = averageColumnLength; + } + + public String getColumnName() + { + return columnName; + } + + public Optional getNullsFraction() + { + return nullsFraction; + } + + public Optional getDistinctValuesIndicator() + { + return distinctValuesIndicator; + } + + public Optional getAverageColumnLength() + { + return averageColumnLength; + } + } } diff --git a/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClientModule.java b/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClientModule.java index e75fbea4d7f2..d1056ed7ae84 100644 --- a/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClientModule.java +++ b/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClientModule.java @@ -24,6 +24,8 @@ import io.trino.plugin.jdbc.DriverConnectionFactory; import io.trino.plugin.jdbc.ForBaseJdbc; import io.trino.plugin.jdbc.JdbcClient; +import io.trino.plugin.jdbc.JdbcJoinPushdownSupportModule; +import io.trino.plugin.jdbc.JdbcStatisticsConfig; import io.trino.plugin.jdbc.QueryBuilder; import io.trino.plugin.jdbc.RemoteQueryCancellationModule; import io.trino.plugin.jdbc.credential.CredentialProvider; @@ -41,9 +43,11 @@ public void setup(Binder binder) { binder.bind(JdbcClient.class).annotatedWith(ForBaseJdbc.class).to(PostgreSqlClient.class).in(Scopes.SINGLETON); configBinder(binder).bindConfig(PostgreSqlConfig.class); + configBinder(binder).bindConfig(JdbcStatisticsConfig.class); bindSessionPropertiesProvider(binder, 
PostgreSqlSessionProperties.class); newOptionalBinder(binder, QueryBuilder.class).setBinding().to(CollationAwareQueryBuilder.class).in(Scopes.SINGLETON); install(new DecimalModule()); + install(new JdbcJoinPushdownSupportModule()); install(new RemoteQueryCancellationModule()); } diff --git a/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlClient.java b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlClient.java index c9dfbccf621c..fcfe35f78a58 100644 --- a/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlClient.java +++ b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlClient.java @@ -20,6 +20,7 @@ import io.trino.plugin.jdbc.JdbcClient; import io.trino.plugin.jdbc.JdbcColumnHandle; import io.trino.plugin.jdbc.JdbcExpression; +import io.trino.plugin.jdbc.JdbcStatisticsConfig; import io.trino.plugin.jdbc.JdbcTypeHandle; import io.trino.plugin.jdbc.mapping.DefaultIdentifierMapping; import io.trino.spi.connector.AggregateFunction; @@ -93,6 +94,7 @@ public class TestPostgreSqlClient private static final JdbcClient JDBC_CLIENT = new PostgreSqlClient( new BaseJdbcConfig(), new PostgreSqlConfig(), + new JdbcStatisticsConfig(), session -> { throw new UnsupportedOperationException(); }, new DefaultQueryBuilder(), TESTING_TYPE_MANAGER, diff --git a/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlConnectorTest.java b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlConnectorTest.java index 0260222456b2..1ed767b6578a 100644 --- a/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlConnectorTest.java +++ b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlConnectorTest.java @@ -947,4 +947,13 @@ protected TestView createSleepingView(Duration minimalQueryDuration) long secondsToSleep = round(minimalQueryDuration.convertTo(SECONDS).getValue() + 
1); return new TestView(onRemoteDatabase(), "test_sleeping_view", format("SELECT 1 FROM pg_sleep(%d)", secondsToSleep)); } + + @Override + protected Session joinPushdownEnabled(Session session) + { + return Session.builder(super.joinPushdownEnabled(session)) + // strategy is AUTOMATIC by default and would not work for certain test cases (even if statistics are collected) + .setCatalogSessionProperty(session.getCatalog().orElseThrow(), "join_pushdown_strategy", "EAGER") + .build(); + } } diff --git a/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlTableStatistics.java b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlTableStatistics.java new file mode 100644 index 000000000000..cc54e0dbc95d --- /dev/null +++ b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlTableStatistics.java @@ -0,0 +1,466 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
 */
package io.trino.plugin.postgresql;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.trino.plugin.jdbc.BaseJdbcTableStatisticsTest;
import io.trino.testing.QueryRunner;
import io.trino.testing.sql.TestTable;
import org.jdbi.v3.core.HandleConsumer;
import org.jdbi.v3.core.Jdbi;
import org.testng.annotations.Test;

import java.util.List;
import java.util.Properties;

import static io.trino.plugin.postgresql.PostgreSqlQueryRunner.createPostgreSqlQueryRunner;
import static io.trino.testing.sql.TestTable.fromColumns;
import static io.trino.tpch.TpchTable.ORDERS;
import static java.lang.String.format;
import static java.util.stream.Collectors.joining;

/**
 * Verifies that the PostgreSQL connector exposes table statistics (row count, NDV,
 * nulls fraction, data size) through {@code SHOW STATS}, for plain tables, partitioned
 * tables, views, materialized views, mixed-case column names and numeric corner cases.
 * <p>
 * Statistics are sourced from PostgreSQL's own {@code ANALYZE}; the {@code gatherStats*}
 * helpers run {@code ANALYZE} directly against the backing server via JDBC and retry
 * until {@code pg_class.reltuples} matches the exact row count.
 */
public class TestPostgreSqlTableStatistics
        extends BaseJdbcTableStatisticsTest
{
    private TestingPostgreSqlServer postgreSqlServer;

    @Override
    protected QueryRunner createQueryRunner()
            throws Exception
    {
        postgreSqlServer = closeAfterClass(new TestingPostgreSqlServer());
        return createPostgreSqlQueryRunner(
                postgreSqlServer,
                ImmutableMap.of(),
                ImmutableMap.<String, String>builder()
                        .put("connection-url", postgreSqlServer.getJdbcUrl())
                        .put("connection-user", postgreSqlServer.getUser())
                        .put("connection-password", postgreSqlServer.getPassword())
                        // needed so mixed-case PostgreSQL identifiers resolve; exercised by testCaseColumnNames
                        .put("case-insensitive-name-matching", "true")
                        .buildOrThrow(),
                ImmutableList.of(ORDERS));
    }

    // A table that was never ANALYZEd has no statistics: every SHOW STATS cell is null
    // except the row count, which the connector still reports.
    @Override
    @Test(invocationCount = 10, successPercentage = 50) // PostgreSQL can auto-analyze data before we SHOW STATS
    public void testNotAnalyzed()
    {
        String tableName = "test_stats_not_analyzed";
        assertUpdate("DROP TABLE IF EXISTS " + tableName);
        computeActual(format("CREATE TABLE %s AS SELECT * FROM tpch.tiny.orders", tableName));
        try {
            assertQuery(
                    "SHOW STATS FOR " + tableName,
                    "VALUES " +
                            "('orderkey', null, null, null, null, null, null)," +
                            "('custkey', null, null, null, null, null, null)," +
                            "('orderstatus', null, null, null, null, null, null)," +
                            "('totalprice', null, null, null, null, null, null)," +
                            "('orderdate', null, null, null, null, null, null)," +
                            "('orderpriority', null, null, null, null, null, null)," +
                            "('clerk', null, null, null, null, null, null)," +
                            "('shippriority', null, null, null, null, null, null)," +
                            "('comment', null, null, null, null, null, null)," +
                            "(null, null, null, null, 15000, null, null)");
        }
        finally {
            assertUpdate("DROP TABLE " + tableName);
        }
    }

    // Happy path: after ANALYZE, NDV, nulls fraction and (for variable-width columns)
    // data size are exposed for the full tiny orders table.
    @Override
    @Test
    public void testBasic()
    {
        String tableName = "test_stats_orders";
        assertUpdate("DROP TABLE IF EXISTS " + tableName);
        computeActual(format("CREATE TABLE %s AS SELECT * FROM tpch.tiny.orders", tableName));
        try {
            gatherStats(tableName);
            assertQuery(
                    "SHOW STATS FOR " + tableName,
                    "VALUES " +
                            "('orderkey', null, 15000, 0, null, null, null)," +
                            "('custkey', null, 1000, 0, null, null, null)," +
                            "('orderstatus', 30000, 3, 0, null, null, null)," +
                            "('totalprice', null, 14996, 0, null, null, null)," +
                            "('orderdate', null, 2401, 0, null, null, null)," +
                            "('orderpriority', 135000, 5, 0, null, null, null)," +
                            "('clerk', 240000, 1000, 0, null, null, null)," +
                            "('shippriority', null, 1, 0, null, null, null)," +
                            "('comment', 735000, 14995, 0, null, null, null)," +
                            "(null, null, null, null, 15000, null, null)");
        }
        finally {
            assertUpdate("DROP TABLE " + tableName);
        }
    }

    // All-NULL columns: NDV and data size collapse to 0 and nulls fraction is 1.
    @Override
    @Test
    public void testAllNulls()
    {
        String tableName = "test_stats_table_all_nulls";
        assertUpdate("DROP TABLE IF EXISTS " + tableName);
        computeActual(format("CREATE TABLE %s AS SELECT orderkey, custkey, orderpriority, comment FROM tpch.tiny.orders WHERE false", tableName));
        try {
            // Only orderkey is listed in the INSERT, so every column ends up NULL in all 3 rows
            computeActual(format("INSERT INTO %s (orderkey) VALUES NULL, NULL, NULL", tableName));
            gatherStats(tableName);
            assertQuery(
                    "SHOW STATS FOR " + tableName,
                    "VALUES " +
                            "('orderkey', 0, 0, 1, null, null, null)," +
                            "('custkey', 0, 0, 1, null, null, null)," +
                            "('orderpriority', 0, 0, 1, null, null, null)," +
                            "('comment', 0, 0, 1, null, null, null)," +
                            "(null, null, null, null, 3, null, null)");
        }
        finally {
            assertUpdate("DROP TABLE " + tableName);
        }
    }

    // Partially-NULL columns: every 3rd custkey and every 5th orderpriority is NULL,
    // so the expected nulls fractions are 1/3 and 1/5 respectively.
    @Override
    @Test
    public void testNullsFraction()
    {
        String tableName = "test_stats_table_with_nulls";
        assertUpdate("DROP TABLE IF EXISTS " + tableName);
        assertUpdate("" +
                        "CREATE TABLE " + tableName + " AS " +
                        "SELECT " +
                        "    orderkey, " +
                        "    if(orderkey % 3 = 0, NULL, custkey) custkey, " +
                        "    if(orderkey % 5 = 0, NULL, orderpriority) orderpriority " +
                        "FROM tpch.tiny.orders",
                15000);
        try {
            gatherStats(tableName);
            assertQuery(
                    "SHOW STATS FOR " + tableName,
                    "VALUES " +
                            "('orderkey', null, 15000, 0, null, null, null)," +
                            "('custkey', null, 1000, 0.3333333333333333, null, null, null)," +
                            "('orderpriority', 108000, 5, 0.2, null, null, null)," +
                            "(null, null, null, null, 15000, null, null)");
        }
        finally {
            assertUpdate("DROP TABLE " + tableName);
        }
    }

    // Data-size estimates for varchar columns of various widths and null densities.
    @Override
    @Test
    public void testAverageColumnLength()
    {
        String tableName = "test_stats_table_avg_col_len";
        assertUpdate("DROP TABLE IF EXISTS " + tableName);
        computeActual("" +
                "CREATE TABLE " + tableName + " AS SELECT " +
                "  orderkey, " +
                "  'abc' v3_in_3, " +
                "  CAST('abc' AS varchar(42)) v3_in_42, " +
                "  if(orderkey = 1, '0123456789', NULL) single_10v_value, " +
                "  if(orderkey % 2 = 0, '0123456789', NULL) half_10v_value, " +
                "  if(orderkey % 2 = 0, CAST((1000000 - orderkey) * (1000000 - orderkey) AS varchar(20)), NULL) half_distinct_20v_value, " + // 12 chars each
                "  CAST(NULL AS varchar(10)) all_nulls " +
                "FROM tpch.tiny.orders " +
                "ORDER BY orderkey LIMIT 100");
        try {
            gatherStats(tableName);
            assertQuery(
                    "SHOW STATS FOR " + tableName,
                    "VALUES " +
                            "('orderkey', null, 100, 0, null, null, null)," +
                            "('v3_in_3', 400, 1, 0, null, null, null)," +
                            "('v3_in_42', 400, 1, 0, null, null, null)," +
                            "('single_10v_value', 11, 1, 0.99, null, null, null)," +
                            "('half_10v_value', 550, 1, 0.5, null, null, null)," +
                            "('half_distinct_20v_value', 650, 50, 0.5, null, null, null)," +
                            "('all_nulls', 0, 0, 1, null, null, null)," +
                            "(null, null, null, null, 100, null, null)");
        }
        finally {
            assertUpdate("DROP TABLE " + tableName);
        }
    }

    // PostgreSQL declarative partitioning: statistics collected on child partitions are
    // not reflected on the parent; only ANALYZE of the parent itself exposes them.
    @Override
    @Test
    public void testPartitionedTable()
    {
        String tableName = "test_stats_orders_part";
        String firstPartitionedTable = "test_stats_orders_part_1990_1994";
        String secondPartitionedTable = "test_stats_orders_part_1995_1999";
        assertUpdate("DROP TABLE IF EXISTS " + tableName);
        assertUpdate("DROP TABLE IF EXISTS " + firstPartitionedTable);
        assertUpdate("DROP TABLE IF EXISTS " + secondPartitionedTable);

        executeInPostgres("CREATE TABLE " + tableName + " (LIKE orders) PARTITION BY RANGE(orderdate)");
        try {
            // Verify the behavior when a partitioned table doesn't have child tables
            assertQuery(
                    "SHOW STATS FOR " + tableName,
                    "VALUES " +
                            "('orderkey', null, null, null, null, null, null)," +
                            "('custkey', null, null, null, null, null, null)," +
                            "('orderstatus', null, null, null, null, null, null)," +
                            "('totalprice', null, null, null, null, null, null)," +
                            "('orderdate', null, null, null, null, null, null)," +
                            "('orderpriority', null, null, null, null, null, null)," +
                            "('clerk', null, null, null, null, null, null)," +
                            "('shippriority', null, null, null, null, null, null)," +
                            "('comment', null, null, null, null, null, null)," +
                            "(null, null, null, null, null, null, null)");
            gatherStats(tableName);
            assertQuery(
                    "SHOW STATS FOR " + tableName,
                    "VALUES " +
                            "('orderkey', null, null, null, null, null, null)," +
                            "('custkey', null, null, null, null, null, null)," +
                            "('orderstatus', null, null, null, null, null, null)," +
                            "('totalprice', null, null, null, null, null, null)," +
                            "('orderdate', null, null, null, null, null, null)," +
                            "('orderpriority', null, null, null, null, null, null)," +
                            "('clerk', null, null, null, null, null, null)," +
                            "('shippriority', null, null, null, null, null, null)," +
                            "('comment', null, null, null, null, null, null)," +
                            "(null, null, null, null, null, null, null)");

            // Create child tables
            executeInPostgres(format("CREATE TABLE %s PARTITION OF %s FOR VALUES FROM ('1990-01-01') TO ('1995-01-01')", firstPartitionedTable, tableName));
            executeInPostgres(format("CREATE TABLE %s PARTITION OF %s FOR VALUES FROM ('1995-01-01') TO ('1999-12-31')", secondPartitionedTable, tableName));
            executeInPostgres(format("INSERT INTO %s SELECT * FROM orders WHERE orderdate <= '1994-12-31'", firstPartitionedTable));
            executeInPostgres(format("INSERT INTO %s SELECT * FROM orders WHERE orderdate >= '1995-01-01'", secondPartitionedTable));

            // Analyzing child tables doesn't expose the statistics
            gatherStats(firstPartitionedTable);
            gatherStats(secondPartitionedTable);
            assertQuery(
                    "SHOW STATS FOR " + tableName,
                    "VALUES " +
                            "('orderkey', null, null, null, null, null, null)," +
                            "('custkey', null, null, null, null, null, null)," +
                            "('orderstatus', null, null, null, null, null, null)," +
                            "('totalprice', null, null, null, null, null, null)," +
                            "('orderdate', null, null, null, null, null, null)," +
                            "('orderpriority', null, null, null, null, null, null)," +
                            "('clerk', null, null, null, null, null, null)," +
                            "('shippriority', null, null, null, null, null, null)," +
                            "('comment', null, null, null, null, null, null)," +
                            "(null, null, null, null, 15000, null, null)");

            // Analyzing parent table exposes the statistics
            gatherStatsPartitionedTable(tableName, ImmutableList.of(firstPartitionedTable, secondPartitionedTable));
            assertQuery(
                    "SHOW STATS FOR " + tableName,
                    "VALUES " +
                            "('orderkey', null, 15000, 0, null, null, null)," +
                            "('custkey', null, 1000, 0, null, null, null)," +
                            "('orderstatus', 30000, 3, 0, null, null, null)," +
                            "('totalprice', null, 14996, 0, null, null, null)," +
                            "('orderdate', null, 2401, 0, null, null, null)," +
                            "('orderpriority', 135000, 5, 0, null, null, null)," +
                            "('clerk', 240000, 1000, 0, null, null, null)," +
                            "('shippriority', null, 1, 0, null, null, null)," +
                            "('comment', 735000, 14995, 0, null, null, null)," +
                            "(null, null, null, null, 15000, null, null)");
        }
        finally {
            assertUpdate("DROP TABLE " + tableName); // This removes child tables too
        }
    }

    // Views cannot be analyzed in PostgreSQL, so SHOW STATS yields all nulls.
    @Override
    @Test
    public void testView()
    {
        String tableName = "test_stats_view";
        executeInPostgres("CREATE OR REPLACE VIEW " + tableName + " AS SELECT orderkey, custkey, orderpriority, comment FROM orders");
        try {
            assertQuery(
                    "SHOW STATS FOR " + tableName,
                    "VALUES " +
                            "('orderkey', null, null, null, null, null, null)," +
                            "('custkey', null, null, null, null, null, null)," +
                            "('orderpriority', null, null, null, null, null, null)," +
                            "('comment', null, null, null, null, null, null)," +
                            "(null, null, null, null, null, null, null)");
            // It's not possible to ANALYZE a VIEW in PostgreSQL
        }
        finally {
            executeInPostgres("DROP VIEW " + tableName);
        }
    }

    // Materialized views are backed by storage, so they can be ANALYZEd and do
    // expose statistics, unlike plain views.
    @Override
    @Test
    public void testMaterializedView()
    {
        String tableName = "test_stats_materialized_view";
        executeInPostgres("DROP MATERIALIZED VIEW IF EXISTS " + tableName);
        executeInPostgres("" +
                "CREATE MATERIALIZED VIEW " + tableName + " " +
                "AS SELECT orderkey, custkey, orderpriority, comment FROM orders");
        try {
            gatherStats(tableName);
            assertQuery(
                    "SHOW STATS FOR " + tableName,
                    "VALUES " +
                            "('orderkey', null, 15000, 0, null, null, null)," +
                            "('custkey', null, 1000, 0, null, null, null)," +
                            "('orderpriority', 135000, 5, 0, null, null, null)," +
                            "('comment', 735000, 14995, 0, null, null, null)," +
                            "(null, null, null, null, 15000, null, null)");
        }
        finally {
            executeInPostgres("DROP MATERIALIZED VIEW " + tableName);
        }
    }

    // Mixed-case PostgreSQL column names (quoted and unquoted) must still map to
    // statistics; SHOW STATS reports them in lower case. Table names come from a
    // data provider (presumably declared in BaseJdbcTableStatisticsTest — confirm).
    @Override
    @Test(dataProvider = "testCaseColumnNamesDataProvider")
    public void testCaseColumnNames(String tableName)
    {
        executeInPostgres("" +
                "CREATE TABLE " + tableName + " " +
                "AS SELECT " +
                "  orderkey AS CASE_UNQUOTED_UPPER, " +
                "  custkey AS case_unquoted_lower, " +
                "  orderstatus AS cASe_uNQuoTeD_miXED, " +
                "  totalprice AS \"CASE_QUOTED_UPPER\", " +
                "  orderdate AS \"case_quoted_lower\"," +
                "  orderpriority AS \"CasE_QuoTeD_miXED\" " +
                "FROM orders");
        try {
            gatherStats(tableName);
            assertQuery(
                    "SHOW STATS FOR " + tableName,
                    "VALUES " +
                            "('case_unquoted_upper', null, 15000, 0, null, null, null)," +
                            "('case_unquoted_lower', null, 1000, 0, null, null, null)," +
                            "('case_unquoted_mixed', 30000, 3, 0, null, null, null)," +
                            "('case_quoted_upper', null, 14996, 0, null, null, null)," +
                            "('case_quoted_lower', null, 2401, 0, null, null, null)," +
                            "('case_quoted_mixed', 135000, 5, 0, null, null, null)," +
                            "(null, null, null, null, 15000, null, null)");
        }
        finally {
            executeInPostgres("DROP TABLE " + tableName);
        }
    }

    // Statistics for values that stress numeric handling: infinities, NaNs, doubles
    // near the precision limit, and short/long decimals with extreme scale.
    @Override
    @Test
    public void testNumericCornerCases()
    {
        try (TestTable table = fromColumns(
                getQueryRunner()::execute,
                "test_numeric_corner_cases_",
                ImmutableMap.<String, List<String>>builder()
                        .put("only_negative_infinity double", List.of("-infinity()", "-infinity()", "-infinity()", "-infinity()"))
                        .put("only_positive_infinity double", List.of("infinity()", "infinity()", "infinity()", "infinity()"))
                        .put("mixed_infinities double", List.of("-infinity()", "infinity()", "-infinity()", "infinity()"))
                        .put("mixed_infinities_and_numbers double", List.of("-infinity()", "infinity()", "-5.0", "7.0"))
                        .put("nans_only double", List.of("nan()", "nan()"))
                        .put("nans_and_numbers double", List.of("nan()", "nan()", "-5.0", "7.0"))
                        .put("large_doubles double", List.of("CAST(-50371909150609548946090.0 AS DOUBLE)", "CAST(50371909150609548946090.0 AS DOUBLE)")) // 2^77 DIV 3
                        .put("short_decimals_big_fraction decimal(16,15)", List.of("-1.234567890123456", "1.234567890123456"))
                        .put("short_decimals_big_integral decimal(16,1)", List.of("-123456789012345.6", "123456789012345.6"))
                        .put("long_decimals_big_fraction decimal(38,37)", List.of("-1.2345678901234567890123456789012345678", "1.2345678901234567890123456789012345678"))
                        .put("long_decimals_middle decimal(38,16)", List.of("-1234567890123456.7890123456789012345678", "1234567890123456.7890123456789012345678"))
                        .put("long_decimals_big_integral decimal(38,1)", List.of("-1234567890123456789012345678901234567.8", "1234567890123456789012345678901234567.8"))
                        .buildOrThrow(),
                "null")) {
            gatherStats(table.getName());
            assertQuery(
                    "SHOW STATS FOR " + table.getName(),
                    "VALUES " +
                            "('only_negative_infinity', null, 1, 0, null, null, null)," +
                            "('only_positive_infinity', null, 1, 0, null, null, null)," +
                            "('mixed_infinities', null, 2, 0, null, null, null)," +
                            "('mixed_infinities_and_numbers', null, 4.0, 0.0, null, null, null)," +
                            "('nans_only', null, 1.0, 0.5, null, null, null)," +
                            "('nans_and_numbers', null, 3.0, 0.0, null, null, null)," +
                            "('large_doubles', null, 2.0, 0.5, null, null, null)," +
                            "('short_decimals_big_fraction', null, 2.0, 0.5, null, null, null)," +
                            "('short_decimals_big_integral', null, 2.0, 0.5, null, null, null)," +
                            "('long_decimals_big_fraction', null, 2.0, 0.5, null, null, null)," +
                            "('long_decimals_middle', null, 2.0, 0.5, null, null, null)," +
                            "('long_decimals_big_integral', null, 2.0, 0.5, null, null, null)," +
                            "(null, null, null, null, 4, null, null)");
        }
    }

    // Runs a single SQL statement directly on the backing PostgreSQL server,
    // bypassing Trino entirely.
    private void executeInPostgres(String sql)
    {
        inPostgres(handle -> handle.execute(sql));
    }

    /**
     * Runs {@code ANALYZE} on the given table directly in PostgreSQL, retrying
     * (up to 5 times) until {@code pg_class.reltuples} equals the exact row count,
     * so subsequent SHOW STATS assertions see stable, exact statistics.
     */
    @Override
    protected void gatherStats(String tableName)
    {
        inPostgres(handle -> {
            handle.execute("ANALYZE " + tableName);
            for (int i = 0; i < 5; i++) {
                long actualCount = handle.createQuery("SELECT count(*) FROM " + tableName)
                        .mapTo(Long.class)
                        .findOnly();
                long estimatedCount = handle.createQuery(format("SELECT reltuples FROM pg_class WHERE oid = '%s'::regclass::oid", tableName))
                        .mapTo(Long.class)
                        .findOnly();
                if (actualCount == estimatedCount) {
                    return;
                }
                handle.execute("ANALYZE " + tableName);
            }
            throw new IllegalStateException("Stats not gathered"); // for small test tables reltuples should be exact
        });
    }

    /**
     * Like {@link #gatherStats}, but for a declaratively-partitioned parent table:
     * ANALYZE runs on the parent, while the reltuples check sums over the given
     * child partitions (the parent itself holds no rows).
     */
    private void gatherStatsPartitionedTable(String parentTableName, List<String> childTableNames)
    {
        // Build an "IN (...)" list of child table OIDs for the pg_class lookup
        String parameter = childTableNames.stream()
                .map(tableName -> format("'%s'::regclass::oid", tableName))
                .collect(joining(", "));
        inPostgres(handle -> {
            handle.execute("ANALYZE " + parentTableName);
            for (int i = 0; i < 5; i++) {
                long actualCount = handle.createQuery("SELECT count(*) FROM " + parentTableName)
                        .mapTo(Long.class)
                        .findOnly();
                long estimatedCount = handle.createQuery(format("SELECT SUM(reltuples) FROM pg_class WHERE oid IN (%s)", parameter))
                        .mapTo(Long.class)
                        .findOnly();
                if (actualCount == estimatedCount) {
                    return;
                }
                handle.execute("ANALYZE " + parentTableName);
            }
            throw new IllegalStateException("Stats not gathered"); // for small test tables reltuples should be exact
        });
    }

    /**
     * Opens a Jdbi handle directly against the PostgreSQL server (schema {@code tpch})
     * and runs the callback on it. {@code E} propagates whatever checked exception
     * the callback declares.
     */
    private <E extends Exception> void inPostgres(HandleConsumer<E> callback)
            throws E
    {
        Properties properties = new Properties();
        properties.setProperty("currentSchema", "tpch");
        properties.setProperty("user", postgreSqlServer.getUser());
        properties.setProperty("password", postgreSqlServer.getPassword());
        Jdbi.create(postgreSqlServer.getJdbcUrl(), properties)
                .useHandle(callback);
    }
}