diff --git a/core/trino-main/src/main/java/io/trino/cost/FilterStatsCalculator.java b/core/trino-main/src/main/java/io/trino/cost/FilterStatsCalculator.java index 6698d7625a99..076abcc86a4d 100644 --- a/core/trino-main/src/main/java/io/trino/cost/FilterStatsCalculator.java +++ b/core/trino-main/src/main/java/io/trino/cost/FilterStatsCalculator.java @@ -409,7 +409,19 @@ protected PlanNodeStatsEstimate visitComparisonExpression(ComparisonExpression n SymbolStatsEstimate leftStats = getExpressionStats(left); Optional leftSymbol = left instanceof SymbolReference ? Optional.of(Symbol.from(left)) : Optional.empty(); if (isEffectivelyLiteral(right)) { - OptionalDouble literal = doubleValueFromLiteral(getType(left), right); + Type type = getType(left); + Object literalValue = evaluateConstantExpression( + right, + type, + plannerContext, + session, + new AllowAllAccessControl(), + ImmutableMap.of()); + if (literalValue == null) { + // Possible when we process `x IN (..., NULL)` case. + return input.mapOutputRowCount(rowCountEstimate -> 0.); + } + OptionalDouble literal = toStatsRepresentation(type, literalValue); return estimateExpressionToLiteralComparison(input, leftStats, leftSymbol, literal, operator); } @@ -465,18 +477,6 @@ private boolean isEffectivelyLiteral(Expression expression) { return ExpressionUtils.isEffectivelyLiteral(plannerContext, session, expression); } - - private OptionalDouble doubleValueFromLiteral(Type type, Expression literal) - { - Object literalValue = evaluateConstantExpression( - literal, - type, - plannerContext, - session, - new AllowAllAccessControl(), - ImmutableMap.of()); - return toStatsRepresentation(type, literalValue); - } } private static List> extractCorrelatedGroups(List terms, double filterConjunctionIndependenceFactor) diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PlanNodeSearcher.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PlanNodeSearcher.java index 
a04de1f5f41f..0057cd4e5b30 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PlanNodeSearcher.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PlanNodeSearcher.java @@ -109,16 +109,6 @@ private Optional findFirstRecursive(PlanNode node) return Optional.empty(); } - public Optional findSingle() - { - List all = findAll(); - return switch (all.size()) { - case 0 -> Optional.empty(); - case 1 -> Optional.of(all.get(0)); - default -> throw new IllegalStateException("Multiple nodes found"); - }; - } - /** * Return a list of matching nodes ordered as in pre-order traversal of the plan tree. */ @@ -134,15 +124,6 @@ public T findOnlyElement() return getOnlyElement(findAll()); } - public T findOnlyElement(T defaultValue) - { - List all = findAll(); - if (all.size() == 0) { - return defaultValue; - } - return getOnlyElement(all); - } - private void findAllRecursive(PlanNode node, ImmutableList.Builder nodes) { node = lookup.resolve(node); diff --git a/core/trino-main/src/test/java/io/trino/cost/TestFilterStatsCalculator.java b/core/trino-main/src/test/java/io/trino/cost/TestFilterStatsCalculator.java index 679f8e2e775a..23736bed3339 100644 --- a/core/trino-main/src/test/java/io/trino/cost/TestFilterStatsCalculator.java +++ b/core/trino-main/src/test/java/io/trino/cost/TestFilterStatsCalculator.java @@ -609,7 +609,9 @@ public void testBetweenOperatorFilter() .nullsFraction(0.0)); assertExpression("'a' IN ('a', 'b')").equalTo(standardInputStatistics); + assertExpression("'a' IN ('a', 'b', NULL)").equalTo(standardInputStatistics); assertExpression("'a' IN ('b', 'c')").outputRowsCount(0); + assertExpression("'a' IN ('b', 'c', NULL)").outputRowsCount(0); assertExpression("CAST('b' AS VARCHAR(3)) IN (CAST('a' AS VARCHAR(3)), CAST('b' AS VARCHAR(3)))").equalTo(standardInputStatistics); assertExpression("CAST('c' AS VARCHAR(3)) IN (CAST('a' AS VARCHAR(3)), CAST('b' AS VARCHAR(3)))").outputRowsCount(0); } @@ -685,6 +687,15 @@ 
public void testInPredicateFilter() .highValue(7.5) .nullsFraction(0.0)); + // Multiple values some including NULL + assertExpression("x IN (DOUBLE '-42', 1.5e0, 2.5e0, 7.5e0, 314e0, CAST(NULL AS double))") + .outputRowsCount(56.25) + .symbolStats("x", symbolStats -> + symbolStats.distinctValuesCount(3.0) + .lowValue(1.5) + .highValue(7.5) + .nullsFraction(0.0)); + // Multiple values in unknown range assertExpression("unknownRange IN (DOUBLE '-42', 1.5e0, 2.5e0, 7.5e0, 314e0)") .outputRowsCount(90.0) diff --git a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseJdbcConnectorTest.java b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseJdbcConnectorTest.java index 54d8ae0e7a90..2931cade8edd 100644 --- a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseJdbcConnectorTest.java +++ b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseJdbcConnectorTest.java @@ -16,9 +16,11 @@ import com.google.common.collect.ImmutableList; import io.airlift.units.Duration; import io.trino.Session; +import io.trino.execution.warnings.WarningCollector; import io.trino.spi.QueryId; import io.trino.spi.connector.JoinCondition; import io.trino.spi.connector.SortOrder; +import io.trino.sql.planner.Plan; import io.trino.sql.planner.assertions.PlanMatchPattern; import io.trino.sql.planner.plan.AggregationNode; import io.trino.sql.planner.plan.ExchangeNode; @@ -60,6 +62,7 @@ import static com.google.common.collect.MoreCollectors.toOptional; import static io.airlift.concurrent.Threads.daemonThreadsNamed; import static io.trino.SystemSessionProperties.USE_MARK_DISTINCT; +import static io.trino.execution.querystats.PlanOptimizersStatsCollector.createPlanOptimizersStatsCollector; import static io.trino.plugin.jdbc.JdbcDynamicFilteringSessionProperties.DYNAMIC_FILTERING_ENABLED; import static io.trino.plugin.jdbc.JdbcDynamicFilteringSessionProperties.DYNAMIC_FILTERING_WAIT_TIMEOUT; import static 
io.trino.plugin.jdbc.JdbcMetadataSessionProperties.DOMAIN_COMPACTION_THRESHOLD; @@ -1703,17 +1706,17 @@ public void testWriteTaskParallelismSessionProperty(int parallelism, int numberO getQueryRunner()::execute, "write_parallelism", "(a varchar(128), b bigint)")) { - assertUpdate(session, "INSERT INTO " + table.getName() + " (a, b) SELECT clerk, orderkey FROM tpch.sf100.orders LIMIT " + numberOfRows, numberOfRows, plan -> { - TableWriterNode.WriterTarget target = searchFrom(plan.getRoot()) - .where(node -> node instanceof TableWriterNode) - .findFirst() - .map(TableWriterNode.class::cast) - .map(TableWriterNode::getTarget) - .orElseThrow(); - - assertThat(target.getMaxWriterTasks(getQueryRunner().getMetadata(), getSession())) - .hasValue(parallelism); - }); + Plan plan = getQueryRunner().createPlan( + session, + "INSERT INTO " + table.getName() + " (a, b) SELECT clerk, orderkey FROM tpch.sf100.orders LIMIT " + numberOfRows, + WarningCollector.NOOP, + createPlanOptimizersStatsCollector()); + TableWriterNode.WriterTarget target = ((TableWriterNode) searchFrom(plan.getRoot()) + .where(node -> node instanceof TableWriterNode) + .findOnlyElement()).getTarget(); + + assertThat(target.getMaxWriterTasks(getQueryRunner().getMetadata(), getSession())) + .hasValue(parallelism); } } diff --git a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/TableStatisticsReader.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/TableStatisticsReader.java index 3907748e6c02..8e00d7e0c3da 100644 --- a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/TableStatisticsReader.java +++ b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/TableStatisticsReader.java @@ -60,6 +60,7 @@ import static com.google.common.collect.Streams.stream; import static io.trino.plugin.iceberg.ExpressionConverter.toIcebergExpression; import static io.trino.plugin.iceberg.IcebergErrorCode.ICEBERG_INVALID_METADATA; +import static 
io.trino.plugin.iceberg.IcebergMetadataColumn.isMetadataColumnId; import static io.trino.plugin.iceberg.IcebergSessionProperties.isExtendedStatisticsEnabled; import static io.trino.plugin.iceberg.IcebergUtil.getColumns; import static io.trino.spi.type.VarbinaryType.VARBINARY; @@ -140,7 +141,8 @@ public static TableStatistics makeTableStatistics( .collect(toUnmodifiableMap(Map.Entry::getKey, Map.Entry::getValue)); TableScan tableScan = icebergTable.newScan() - .filter(toIcebergExpression(effectivePredicate)) + // The table's enforced constraint may include, e.g., a $path column predicate, which is not handled by the Iceberg library. TODO: apply $path and $file_modified_time filters here + .filter(toIcebergExpression(effectivePredicate.filter((column, domain) -> !isMetadataColumnId(column.getId())))) .useSnapshot(snapshotId) .includeColumnStats(); diff --git a/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/BaseIcebergConnectorTest.java b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/BaseIcebergConnectorTest.java index a39a6e0f55a5..e78c0917981a 100644 --- a/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/BaseIcebergConnectorTest.java +++ b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/BaseIcebergConnectorTest.java @@ -5376,6 +5376,10 @@ public void testPathHiddenColumn() .returnsEmptyResult() .isFullyPushedDown(); + assertQuerySucceeds("SHOW STATS FOR (SELECT userid FROM " + tableName + " WHERE \"$path\" = '" + somePath + "')"); + // EXPLAIN triggers stats calculation and also rendering + assertQuerySucceeds("EXPLAIN SELECT userid FROM " + tableName + " WHERE \"$path\" = '" + somePath + "'"); + assertUpdate("DROP TABLE " + tableName); } @@ -5462,6 +5466,10 @@ public void testFileModifiedTimeHiddenColumn() assertThat(query("SELECT col FROM " + table.getName() + " WHERE \"$file_modified_time\" IS NULL")) .returnsEmptyResult() .isFullyPushedDown(); + + assertQuerySucceeds("SHOW STATS FOR (SELECT col FROM " + table.getName() + " WHERE 
\"$file_modified_time\" = from_iso8601_timestamp('" + fileModifiedTime.format(ISO_OFFSET_DATE_TIME) + "'))"); + // EXPLAIN triggers stats calculation and also rendering + assertQuerySucceeds("EXPLAIN SELECT col FROM " + table.getName() + " WHERE \"$file_modified_time\" = from_iso8601_timestamp('" + fileModifiedTime.format(ISO_OFFSET_DATE_TIME) + "')"); } } diff --git a/testing/trino-testing/src/main/java/io/trino/testing/DistributedQueryRunner.java b/testing/trino-testing/src/main/java/io/trino/testing/DistributedQueryRunner.java index 6e9a7affb07d..4776bb767f7c 100644 --- a/testing/trino-testing/src/main/java/io/trino/testing/DistributedQueryRunner.java +++ b/testing/trino-testing/src/main/java/io/trino/testing/DistributedQueryRunner.java @@ -17,12 +17,14 @@ import com.google.common.collect.ImmutableMap; import com.google.common.io.Closer; import com.google.errorprone.annotations.CanIgnoreReturnValue; +import com.google.inject.Key; import com.google.inject.Module; import io.airlift.discovery.server.testing.TestingDiscoveryServer; import io.airlift.log.Logger; import io.airlift.log.Logging; import io.airlift.testing.Assertions; import io.airlift.units.Duration; +import io.trino.FeaturesConfig; import io.trino.Session; import io.trino.Session.SessionBuilder; import io.trino.cost.StatsCalculator; @@ -50,8 +52,11 @@ import io.trino.split.PageSourceManager; import io.trino.split.SplitManager; import io.trino.sql.analyzer.QueryExplainer; +import io.trino.sql.parser.ParsingOptions; +import io.trino.sql.parser.SqlParser; import io.trino.sql.planner.NodePartitioningManager; import io.trino.sql.planner.Plan; +import io.trino.sql.tree.Statement; import io.trino.testing.containers.OpenTracingCollector; import io.trino.transaction.TransactionManager; import org.intellij.lang.annotations.Language; @@ -80,6 +85,10 @@ import static io.airlift.log.Level.WARN; import static io.airlift.testing.Closeables.closeAllSuppress; import static io.airlift.units.Duration.nanosSince; +import 
static io.trino.execution.querystats.PlanOptimizersStatsCollector.createPlanOptimizersStatsCollector; +import static io.trino.sql.parser.ParsingOptions.DecimalLiteralTreatment.AS_DECIMAL; +import static io.trino.sql.parser.ParsingOptions.DecimalLiteralTreatment.AS_DOUBLE; +import static io.trino.transaction.TransactionBuilder.transaction; import static java.lang.Boolean.parseBoolean; import static java.lang.System.getenv; import static java.util.Objects.requireNonNull; @@ -519,10 +528,18 @@ public MaterializedResultWithPlan executeWithPlan(Session session, String sql, W @Override public Plan createPlan(Session session, String sql, WarningCollector warningCollector, PlanOptimizersStatsCollector planOptimizersStatsCollector) { - QueryId queryId = executeWithQueryId(session, sql).getQueryId(); - Plan queryPlan = getQueryPlan(queryId); - coordinator.getQueryManager().cancelQuery(queryId); - return queryPlan; + if (session.getTransactionId().isEmpty()) { + return transaction(getTransactionManager(), getAccessControl()) + .singleStatement() + .execute(session, transactionSession -> { + return createPlan(transactionSession, sql, warningCollector, planOptimizersStatsCollector); + }); + } + + SqlParser sqlParser = coordinator.getInstance(Key.get(SqlParser.class)); + Statement statement = sqlParser.createStatement(sql, new ParsingOptions( + new FeaturesConfig().isParseDecimalLiteralsAsDouble() ? AS_DOUBLE : AS_DECIMAL)); + return coordinator.getQueryExplainer().getLogicalPlan(session, statement, ImmutableList.of(), WarningCollector.NOOP, createPlanOptimizersStatsCollector()); } public Plan getQueryPlan(QueryId queryId)