Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,19 @@ protected PlanNodeStatsEstimate visitComparisonExpression(ComparisonExpression n
SymbolStatsEstimate leftStats = getExpressionStats(left);
Optional<Symbol> leftSymbol = left instanceof SymbolReference ? Optional.of(Symbol.from(left)) : Optional.empty();
if (isEffectivelyLiteral(right)) {
OptionalDouble literal = doubleValueFromLiteral(getType(left), right);
Type type = getType(left);
Object literalValue = evaluateConstantExpression(
right,
type,
plannerContext,
session,
new AllowAllAccessControl(),
ImmutableMap.of());
if (literalValue == null) {
// Possible when we process `x IN (..., NULL)` case.
return input.mapOutputRowCount(rowCountEstimate -> 0.);
}
OptionalDouble literal = toStatsRepresentation(type, literalValue);
return estimateExpressionToLiteralComparison(input, leftStats, leftSymbol, literal, operator);
}

Expand Down Expand Up @@ -465,18 +477,6 @@ private boolean isEffectivelyLiteral(Expression expression)
{
return ExpressionUtils.isEffectivelyLiteral(plannerContext, session, expression);
}

private OptionalDouble doubleValueFromLiteral(Type type, Expression literal)
{
Object literalValue = evaluateConstantExpression(
literal,
type,
plannerContext,
session,
new AllowAllAccessControl(),
ImmutableMap.of());
return toStatsRepresentation(type, literalValue);
}
}

private static List<List<Expression>> extractCorrelatedGroups(List<Expression> terms, double filterConjunctionIndependenceFactor)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,16 +109,6 @@ private <T extends PlanNode> Optional<T> findFirstRecursive(PlanNode node)
return Optional.empty();
}

public <T extends PlanNode> Optional<T> findSingle()
{
List<T> all = findAll();
return switch (all.size()) {
case 0 -> Optional.empty();
case 1 -> Optional.of(all.get(0));
default -> throw new IllegalStateException("Multiple nodes found");
};
}

/**
* Return a list of matching nodes ordered as in pre-order traversal of the plan tree.
*/
Expand All @@ -134,15 +124,6 @@ public <T extends PlanNode> T findOnlyElement()
return getOnlyElement(findAll());
}

public <T extends PlanNode> T findOnlyElement(T defaultValue)
{
List<T> all = findAll();
if (all.size() == 0) {
return defaultValue;
}
return getOnlyElement(all);
}

private <T extends PlanNode> void findAllRecursive(PlanNode node, ImmutableList.Builder<T> nodes)
{
node = lookup.resolve(node);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -609,7 +609,9 @@ public void testBetweenOperatorFilter()
.nullsFraction(0.0));

assertExpression("'a' IN ('a', 'b')").equalTo(standardInputStatistics);
assertExpression("'a' IN ('a', 'b', NULL)").equalTo(standardInputStatistics);
assertExpression("'a' IN ('b', 'c')").outputRowsCount(0);
assertExpression("'a' IN ('b', 'c', NULL)").outputRowsCount(0);
assertExpression("CAST('b' AS VARCHAR(3)) IN (CAST('a' AS VARCHAR(3)), CAST('b' AS VARCHAR(3)))").equalTo(standardInputStatistics);
assertExpression("CAST('c' AS VARCHAR(3)) IN (CAST('a' AS VARCHAR(3)), CAST('b' AS VARCHAR(3)))").outputRowsCount(0);
}
Expand Down Expand Up @@ -685,6 +687,15 @@ public void testInPredicateFilter()
.highValue(7.5)
.nullsFraction(0.0));

// Multiple values some including NULL
assertExpression("x IN (DOUBLE '-42', 1.5e0, 2.5e0, 7.5e0, 314e0, CAST(NULL AS double))")
.outputRowsCount(56.25)
.symbolStats("x", symbolStats ->
symbolStats.distinctValuesCount(3.0)
.lowValue(1.5)
.highValue(7.5)
.nullsFraction(0.0));

// Multiple values in unknown range
assertExpression("unknownRange IN (DOUBLE '-42', 1.5e0, 2.5e0, 7.5e0, 314e0)")
.outputRowsCount(90.0)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@
import com.google.common.collect.ImmutableList;
import io.airlift.units.Duration;
import io.trino.Session;
import io.trino.execution.warnings.WarningCollector;
import io.trino.spi.QueryId;
import io.trino.spi.connector.JoinCondition;
import io.trino.spi.connector.SortOrder;
import io.trino.sql.planner.Plan;
import io.trino.sql.planner.assertions.PlanMatchPattern;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.ExchangeNode;
Expand Down Expand Up @@ -60,6 +62,7 @@
import static com.google.common.collect.MoreCollectors.toOptional;
import static io.airlift.concurrent.Threads.daemonThreadsNamed;
import static io.trino.SystemSessionProperties.USE_MARK_DISTINCT;
import static io.trino.execution.querystats.PlanOptimizersStatsCollector.createPlanOptimizersStatsCollector;
import static io.trino.plugin.jdbc.JdbcDynamicFilteringSessionProperties.DYNAMIC_FILTERING_ENABLED;
import static io.trino.plugin.jdbc.JdbcDynamicFilteringSessionProperties.DYNAMIC_FILTERING_WAIT_TIMEOUT;
import static io.trino.plugin.jdbc.JdbcMetadataSessionProperties.DOMAIN_COMPACTION_THRESHOLD;
Expand Down Expand Up @@ -1703,17 +1706,17 @@ public void testWriteTaskParallelismSessionProperty(int parallelism, int numberO
getQueryRunner()::execute,
"write_parallelism",
"(a varchar(128), b bigint)")) {
assertUpdate(session, "INSERT INTO " + table.getName() + " (a, b) SELECT clerk, orderkey FROM tpch.sf100.orders LIMIT " + numberOfRows, numberOfRows, plan -> {
TableWriterNode.WriterTarget target = searchFrom(plan.getRoot())
.where(node -> node instanceof TableWriterNode)
.findFirst()
.map(TableWriterNode.class::cast)
.map(TableWriterNode::getTarget)
.orElseThrow();

assertThat(target.getMaxWriterTasks(getQueryRunner().getMetadata(), getSession()))
.hasValue(parallelism);
});
Plan plan = getQueryRunner().createPlan(
Comment thread
findepi marked this conversation as resolved.
Outdated
session,
"INSERT INTO " + table.getName() + " (a, b) SELECT clerk, orderkey FROM tpch.sf100.orders LIMIT " + numberOfRows,
WarningCollector.NOOP,
createPlanOptimizersStatsCollector());
TableWriterNode.WriterTarget target = ((TableWriterNode) searchFrom(plan.getRoot())
.where(node -> node instanceof TableWriterNode)
.findOnlyElement()).getTarget();

assertThat(target.getMaxWriterTasks(getQueryRunner().getMetadata(), getSession()))
.hasValue(parallelism);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
import static com.google.common.collect.Streams.stream;
import static io.trino.plugin.iceberg.ExpressionConverter.toIcebergExpression;
import static io.trino.plugin.iceberg.IcebergErrorCode.ICEBERG_INVALID_METADATA;
import static io.trino.plugin.iceberg.IcebergMetadataColumn.isMetadataColumnId;
import static io.trino.plugin.iceberg.IcebergSessionProperties.isExtendedStatisticsEnabled;
import static io.trino.plugin.iceberg.IcebergUtil.getColumns;
import static io.trino.spi.type.VarbinaryType.VARBINARY;
Expand Down Expand Up @@ -140,7 +141,8 @@ public static TableStatistics makeTableStatistics(
.collect(toUnmodifiableMap(Map.Entry::getKey, Map.Entry::getValue));

TableScan tableScan = icebergTable.newScan()
.filter(toIcebergExpression(effectivePredicate))
// Table enforced constraint may include eg $path column predicate which is not handled by Iceberg library TODO apply $path and $file_modified_time filters here
.filter(toIcebergExpression(effectivePredicate.filter((column, domain) -> !isMetadataColumnId(column.getId()))))
.useSnapshot(snapshotId)
.includeColumnStats();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5376,6 +5376,10 @@ public void testPathHiddenColumn()
.returnsEmptyResult()
.isFullyPushedDown();

assertQuerySucceeds("SHOW STATS FOR (SELECT userid FROM " + tableName + " WHERE \"$path\" = '" + somePath + "')");
// EXPLAIN triggers stats calculation and also rendering
assertQuerySucceeds("EXPLAIN SELECT userid FROM " + tableName + " WHERE \"$path\" = '" + somePath + "'");

assertUpdate("DROP TABLE " + tableName);
}

Expand Down Expand Up @@ -5462,6 +5466,10 @@ public void testFileModifiedTimeHiddenColumn()
assertThat(query("SELECT col FROM " + table.getName() + " WHERE \"$file_modified_time\" IS NULL"))
.returnsEmptyResult()
.isFullyPushedDown();

assertQuerySucceeds("SHOW STATS FOR (SELECT col FROM " + table.getName() + " WHERE \"$file_modified_time\" = from_iso8601_timestamp('" + fileModifiedTime.format(ISO_OFFSET_DATE_TIME) + "'))");
// EXPLAIN triggers stats calculation and also rendering
assertQuerySucceeds("EXPLAIN SELECT col FROM " + table.getName() + " WHERE \"$file_modified_time\" = from_iso8601_timestamp('" + fileModifiedTime.format(ISO_OFFSET_DATE_TIME) + "')");
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,14 @@
import com.google.common.collect.ImmutableMap;
import com.google.common.io.Closer;
import com.google.errorprone.annotations.CanIgnoreReturnValue;
import com.google.inject.Key;
import com.google.inject.Module;
import io.airlift.discovery.server.testing.TestingDiscoveryServer;
import io.airlift.log.Logger;
import io.airlift.log.Logging;
import io.airlift.testing.Assertions;
import io.airlift.units.Duration;
import io.trino.FeaturesConfig;
import io.trino.Session;
import io.trino.Session.SessionBuilder;
import io.trino.cost.StatsCalculator;
Expand Down Expand Up @@ -50,8 +52,11 @@
import io.trino.split.PageSourceManager;
import io.trino.split.SplitManager;
import io.trino.sql.analyzer.QueryExplainer;
import io.trino.sql.parser.ParsingOptions;
import io.trino.sql.parser.SqlParser;
import io.trino.sql.planner.NodePartitioningManager;
import io.trino.sql.planner.Plan;
import io.trino.sql.tree.Statement;
import io.trino.testing.containers.OpenTracingCollector;
import io.trino.transaction.TransactionManager;
import org.intellij.lang.annotations.Language;
Expand Down Expand Up @@ -80,6 +85,10 @@
import static io.airlift.log.Level.WARN;
import static io.airlift.testing.Closeables.closeAllSuppress;
import static io.airlift.units.Duration.nanosSince;
import static io.trino.execution.querystats.PlanOptimizersStatsCollector.createPlanOptimizersStatsCollector;
import static io.trino.sql.parser.ParsingOptions.DecimalLiteralTreatment.AS_DECIMAL;
import static io.trino.sql.parser.ParsingOptions.DecimalLiteralTreatment.AS_DOUBLE;
import static io.trino.transaction.TransactionBuilder.transaction;
import static java.lang.Boolean.parseBoolean;
import static java.lang.System.getenv;
import static java.util.Objects.requireNonNull;
Expand Down Expand Up @@ -519,10 +528,18 @@ public MaterializedResultWithPlan executeWithPlan(Session session, String sql, W
@Override
public Plan createPlan(Session session, String sql, WarningCollector warningCollector, PlanOptimizersStatsCollector planOptimizersStatsCollector)
{
QueryId queryId = executeWithQueryId(session, sql).getQueryId();
Plan queryPlan = getQueryPlan(queryId);
coordinator.getQueryManager().cancelQuery(queryId);
return queryPlan;
if (session.getTransactionId().isEmpty()) {
return transaction(getTransactionManager(), getAccessControl())
.singleStatement()
.execute(session, transactionSession -> {
return createPlan(transactionSession, sql, warningCollector, planOptimizersStatsCollector);
});
}

SqlParser sqlParser = coordinator.getInstance(Key.get(SqlParser.class));
Statement statement = sqlParser.createStatement(sql, new ParsingOptions(
new FeaturesConfig().isParseDecimalLiteralsAsDouble() ? AS_DOUBLE : AS_DECIMAL));
return coordinator.getQueryExplainer().getLogicalPlan(session, statement, ImmutableList.of(), WarningCollector.NOOP, createPlanOptimizersStatsCollector());
}

public Plan getQueryPlan(QueryId queryId)
Expand Down