diff --git a/core/trino-main/src/main/java/io/trino/sql/analyzer/Analysis.java b/core/trino-main/src/main/java/io/trino/sql/analyzer/Analysis.java index f8e1df49d2e6..d49c33d6097a 100644 --- a/core/trino-main/src/main/java/io/trino/sql/analyzer/Analysis.java +++ b/core/trino-main/src/main/java/io/trino/sql/analyzer/Analysis.java @@ -96,6 +96,9 @@ import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableMap.toImmutableMap; import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static io.trino.sql.analyzer.QueryType.DESCRIBE; +import static io.trino.sql.analyzer.QueryType.EXPLAIN; +import static java.lang.Boolean.FALSE; import static java.lang.String.format; import static java.util.Collections.emptyList; import static java.util.Collections.unmodifiableMap; @@ -202,8 +205,7 @@ public class Analysis private Optional analyzeTarget = Optional.empty(); private Optional> updatedColumns = Optional.empty(); - // for describe input and describe output - private final boolean isDescribe; + private final QueryType queryType; // for recursive view detection private final Deque tablesForView = new ArrayDeque<>(); @@ -213,11 +215,11 @@ public class Analysis private final Multimap originColumnDetails = ArrayListMultimap.create(); private final Multimap, Field> fieldLineage = ArrayListMultimap.create(); - public Analysis(@Nullable Statement root, Map, Expression> parameters, boolean isDescribe) + public Analysis(@Nullable Statement root, Map, Expression> parameters, QueryType queryType) { this.root = root; this.parameters = ImmutableMap.copyOf(requireNonNull(parameters, "parameters is null")); - this.isDescribe = isDescribe; + this.queryType = requireNonNull(queryType, "queryType is null"); } public Statement getStatement() @@ -238,23 +240,25 @@ public Optional getTarget() }); } - public void setUpdateType(String updateType, QualifiedObjectName targetName, Optional
targetTable, Optional> targetColumns) + public void setUpdateType(String updateType) { - this.updateType = updateType; - this.target = Optional.of(new UpdateTarget(targetName, targetTable, targetColumns)); + if (queryType != EXPLAIN) { + this.updateType = updateType; + } } - public void resetUpdateType() + public void setUpdateTarget(QualifiedObjectName targetName, Optional
targetTable, Optional> targetColumns) { - this.updateType = null; - this.target = Optional.empty(); + this.target = Optional.of(new UpdateTarget(targetName, targetTable, targetColumns)); } public boolean isUpdateTarget(Table table) { - return ("DELETE".equals(updateType) || "UPDATE".equals(updateType)) && - target.orElseThrow(() -> new IllegalStateException("Update target not set")) - .getTable().orElseThrow(() -> new IllegalStateException("Table reference not set in update target")) == table; // intentional comparison by reference + requireNonNull(table, "table is null"); + return target + .flatMap(UpdateTarget::getTable) + .map(tableReference -> tableReference == table) // intentional comparison by reference + .orElse(FALSE); } public boolean isSkipMaterializedViewRefresh() @@ -840,9 +844,14 @@ public Map, Expression> getParameters() return parameters; } + public QueryType getQueryType() + { + return queryType; + } + public boolean isDescribe() { - return isDescribe; + return queryType == DESCRIBE; } public void setJoinUsing(Join node, JoinUsingAnalysis analysis) diff --git a/core/trino-main/src/main/java/io/trino/sql/analyzer/Analyzer.java b/core/trino-main/src/main/java/io/trino/sql/analyzer/Analyzer.java index b179bedcef48..72d034bed3f6 100644 --- a/core/trino-main/src/main/java/io/trino/sql/analyzer/Analyzer.java +++ b/core/trino-main/src/main/java/io/trino/sql/analyzer/Analyzer.java @@ -38,6 +38,7 @@ import static io.trino.sql.analyzer.ExpressionTreeUtils.extractAggregateFunctions; import static io.trino.sql.analyzer.ExpressionTreeUtils.extractExpressions; import static io.trino.sql.analyzer.ExpressionTreeUtils.extractWindowExpressions; +import static io.trino.sql.analyzer.QueryType.OTHERS; import static io.trino.sql.analyzer.SemanticExceptions.semanticException; import static java.util.Objects.requireNonNull; @@ -80,13 +81,13 @@ public Analyzer( public Analysis analyze(Statement statement) { - return analyze(statement, false); + return analyze(statement, OTHERS); } - public Analysis analyze(Statement statement, boolean isDescribe) + public Analysis analyze(Statement statement, QueryType queryType) { Statement rewrittenStatement = StatementRewrite.rewrite(session, metadata, sqlParser, queryExplainer, statement, parameters, parameterLookup, groupProvider, accessControl, warningCollector, statsCalculator); - Analysis analysis = new Analysis(rewrittenStatement, parameterLookup, isDescribe); + Analysis analysis = new Analysis(rewrittenStatement, parameterLookup, queryType); StatementAnalyzer analyzer = new StatementAnalyzer(analysis, metadata, sqlParser, groupProvider, accessControl, session, warningCollector, CorrelationSupport.ALLOWED); analyzer.analyze(rewrittenStatement, Optional.empty()); diff --git a/core/trino-main/src/main/java/io/trino/sql/analyzer/ExpressionAnalyzer.java b/core/trino-main/src/main/java/io/trino/sql/analyzer/ExpressionAnalyzer.java index 6e3cd67adeb0..5859e8ca83ed 100644 --- a/core/trino-main/src/main/java/io/trino/sql/analyzer/ExpressionAnalyzer.java +++ b/core/trino-main/src/main/java/io/trino/sql/analyzer/ExpressionAnalyzer.java @@ -2539,9 +2539,9 @@ public static ExpressionAnalysis analyzeExpressions( Iterable expressions, Map, Expression> parameters, WarningCollector warningCollector, - boolean isDescribe) + QueryType queryType) { - Analysis analysis = new Analysis(null, parameters, isDescribe); + Analysis analysis = new Analysis(null, parameters, queryType); ExpressionAnalyzer analyzer = create(analysis, session, metadata, sqlParser, groupProvider, accessControl, types, warningCollector); for (Expression expression : expressions) { analyzer.analyze( diff --git a/core/trino-main/src/main/java/io/trino/sql/analyzer/QueryExplainer.java b/core/trino-main/src/main/java/io/trino/sql/analyzer/QueryExplainer.java index bb44b769e289..693739f41220 100644 --- a/core/trino-main/src/main/java/io/trino/sql/analyzer/QueryExplainer.java +++ b/core/trino-main/src/main/java/io/trino/sql/analyzer/QueryExplainer.java @@ -47,6 +47,7 @@ import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED; import static io.trino.sql.ParameterUtils.parameterExtractor; +import static io.trino.sql.analyzer.QueryType.EXPLAIN; import static io.trino.sql.planner.LogicalPlanner.Stage.OPTIMIZED_AND_VALIDATED; import static io.trino.sql.planner.planprinter.IoPlanPrinter.textIoPlan; import static java.lang.String.format; @@ -118,7 +119,7 @@ public QueryExplainer( public Analysis analyze(Session session, Statement statement, List parameters, WarningCollector warningCollector) { Analyzer analyzer = new Analyzer(session, metadata, sqlParser, groupProvider, accessControl, Optional.of(this), parameters, parameterExtractor(statement, parameters), warningCollector, statsCalculator); - return analyzer.analyze(statement); + return analyzer.analyze(statement, EXPLAIN); } public String getPlan(Session session, Statement statement, Type planType, List parameters, WarningCollector warningCollector) diff --git a/core/trino-main/src/main/java/io/trino/sql/analyzer/QueryType.java b/core/trino-main/src/main/java/io/trino/sql/analyzer/QueryType.java new file mode 100644 index 000000000000..e840cbfa4fb7 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/sql/analyzer/QueryType.java @@ -0,0 +1,22 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.analyzer; + +public enum QueryType +{ + DESCRIBE, + EXPLAIN, + OTHERS, + /**/; +} diff --git a/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java b/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java index 89fd8da5df46..16a0c7cc1a48 100644 --- a/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java +++ b/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java @@ -502,8 +502,8 @@ protected Scope visitInsert(Insert insert, Optional scope) .map(Type::toString), Column::new); - analysis.setUpdateType( - "INSERT", + analysis.setUpdateType("INSERT"); + analysis.setUpdateTarget( targetTable, Optional.empty(), Optional.of(Streams.zip( @@ -526,11 +526,11 @@ protected Scope visitRefreshMaterializedView(RefreshMaterializedView refreshMate } accessControl.checkCanRefreshMaterializedView(session.toSecurityContext(), name); + analysis.setUpdateType("REFRESH MATERIALIZED VIEW"); if (metadata.delegateMaterializedViewRefreshToConnector(session, name)) { analysis.setDelegatedRefreshMaterializedView(name); - analysis.setUpdateType( - "REFRESH MATERIALIZED VIEW", + analysis.setUpdateTarget( name, Optional.empty(), Optional.empty()); @@ -590,8 +590,7 @@ protected Scope visitRefreshMaterializedView(RefreshMaterializedView refreshMate .map(Type::toString), Column::new); - analysis.setUpdateType( - "REFRESH MATERIALIZED VIEW", + analysis.setUpdateTarget( targetTable, Optional.empty(), Optional.of(Streams.zip( @@ -699,7 +698,8 @@ protected Scope visitDelete(Delete node, Optional scope) Scope tableScope = analyzer.analyzeForUpdate(table, scope, UpdateKind.DELETE); node.getWhere().ifPresent(where -> analyzeWhere(node, tableScope, where)); - analysis.setUpdateType("DELETE", tableName, Optional.of(table), Optional.empty()); + analysis.setUpdateType("DELETE"); + analysis.setUpdateTarget(tableName, Optional.of(table), Optional.empty()); return createAndAssignScope(node, scope, Field.newUnqualified("rows", BIGINT)); } @@ -708,7 +708,8 @@ protected Scope visitDelete(Delete node, Optional scope) protected Scope visitAnalyze(Analyze node, Optional scope) { QualifiedObjectName tableName = createQualifiedObjectName(session, node, node.getTableName()); - analysis.setUpdateType("ANALYZE", tableName, Optional.empty(), Optional.empty()); + analysis.setUpdateType("ANALYZE"); + analysis.setUpdateTarget(tableName, Optional.empty(), Optional.empty()); // verify the target table exists and it's not a view if (metadata.getView(session, tableName).isPresent()) { @@ -763,7 +764,8 @@ protected Scope visitCreateTableAsSelect(CreateTableAsSelect node, Optional scope) validateColumns(node, queryScope.getRelationType()); - analysis.setUpdateType( - "CREATE VIEW", + analysis.setUpdateType("CREATE VIEW"); + analysis.setUpdateTarget( viewName, Optional.empty(), Optional.of(queryScope.getRelationType().getVisibleFields().stream() @@ -1069,8 +1071,8 @@ protected Scope visitCreateMaterializedView(CreateMaterializedView node, Optiona validateColumns(node, queryScope.getRelationType()); - analysis.setUpdateType( - "CREATE MATERIALIZED VIEW", + analysis.setUpdateType("CREATE MATERIALIZED VIEW"); + analysis.setUpdateTarget( viewName, Optional.empty(), Optional.of( @@ -1153,7 +1155,6 @@ private void validateColumnAliasesCount(List columnAliases, int sour protected Scope visitExplain(Explain node, Optional scope) { process(node.getStatement(), scope); - analysis.resetUpdateType(); return createAndAssignScope(node, scope, Field.newUnqualified("Query Plan", VARCHAR)); } @@ -1161,7 +1162,6 @@ protected Scope visitExplain(Explain node, Optional scope) protected Scope visitExplainAnalyze(ExplainAnalyze node, Optional scope) { process(node.getStatement(), scope); - analysis.resetUpdateType(); return createAndAssignScope(node, scope, Field.newUnqualified("Query Plan", VARCHAR)); } @@ -1904,7 +1904,7 @@ protected Scope visitSampledRelation(SampledRelation relation, Optional s ImmutableList.of(samplePercentage), analysis.getParameters(), WarningCollector.NOOP, - analysis.isDescribe()) + analysis.getQueryType()) .getExpressionTypes(); Type samplePercentageType = expressionTypes.get(NodeRef.of(samplePercentage)); @@ -2313,8 +2313,8 @@ protected Scope visitUpdate(Update update, Optional scope) analysis.recordSubqueries(update, analyses.get(index)); } - analysis.setUpdateType( - "UPDATE", + analysis.setUpdateType("UPDATE"); + analysis.setUpdateTarget( tableName, Optional.of(table), Optional.of(updatedColumns.stream() diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/TypeAnalyzer.java b/core/trino-main/src/main/java/io/trino/sql/planner/TypeAnalyzer.java index ba81b7d31112..0bdbc1c40bb6 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/TypeAnalyzer.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/TypeAnalyzer.java @@ -30,6 +30,7 @@ import java.util.Map; import static io.trino.sql.analyzer.ExpressionAnalyzer.analyzeExpressions; +import static io.trino.sql.analyzer.QueryType.OTHERS; /** * This class is to facilitate obtaining the type of an expression and its subexpressions @@ -59,7 +60,7 @@ public Map, Type> getTypes(Session session, TypeProvider inp expressions, ImmutableMap.of(), WarningCollector.NOOP, - false) + OTHERS) .getExpressionTypes(); } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/BeginTableWrite.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/BeginTableWrite.java index d09e28409d89..4bb84e8d2e41 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/BeginTableWrite.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/BeginTableWrite.java @@ -22,10 +22,12 @@ import io.trino.sql.planner.PlanNodeIdAllocator; import io.trino.sql.planner.SymbolAllocator; import io.trino.sql.planner.TypeProvider; +import io.trino.sql.planner.plan.AssignUniqueId; import io.trino.sql.planner.plan.DeleteNode; import io.trino.sql.planner.plan.ExchangeNode; import io.trino.sql.planner.plan.FilterNode; import io.trino.sql.planner.plan.JoinNode; +import io.trino.sql.planner.plan.MarkDistinctNode; import io.trino.sql.planner.plan.PlanNode; import io.trino.sql.planner.plan.ProjectNode; import io.trino.sql.planner.plan.SemiJoinNode; @@ -48,8 +50,8 @@ import java.util.Optional; import java.util.Set; +import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.Iterables.getOnlyElement; -import static io.trino.sql.planner.optimizations.QueryCardinalityUtil.isAtMostScalar; import static io.trino.sql.planner.plan.ChildReplacer.replaceChildren; import static io.trino.sql.planner.planprinter.PlanPrinter.textLogicalPlan; import static java.util.stream.Collectors.toSet; @@ -254,7 +256,9 @@ private WriterTarget createWriterTarget(WriterTarget target) private TableHandle findTableScanHandle(PlanNode node) { if (node instanceof TableScanNode) { - return ((TableScanNode) node).getTable(); + TableScanNode tableScanNode = (TableScanNode) node; + checkArgument(((TableScanNode) node).isUpdateTarget(), "TableScanNode should be an updatable target"); + return tableScanNode.getTable(); } if (node instanceof FilterNode) { return findTableScanHandle(((FilterNode) node).getSource()); @@ -267,9 +271,13 @@ private TableHandle findTableScanHandle(PlanNode node) } if (node instanceof JoinNode) { JoinNode joinNode = (JoinNode) node; - if (joinNode.getType() == JoinNode.Type.INNER && isAtMostScalar(joinNode.getRight())) { - return findTableScanHandle(joinNode.getLeft()); - } + return findTableScanHandle(joinNode.getLeft()); + } + if (node instanceof AssignUniqueId) { + return findTableScanHandle(((AssignUniqueId) node).getSource()); + } + if (node instanceof MarkDistinctNode) { + return findTableScanHandle(((MarkDistinctNode) node).getSource()); } throw new IllegalArgumentException("Invalid descendant for DeleteNode or UpdateNode: " + node.getClass().getName()); } @@ -303,11 +311,16 @@ private PlanNode rewriteModifyTableScan(PlanNode node, TableHandle handle) return replaceChildren(node, ImmutableList.of(source, ((SemiJoinNode) node).getFilteringSource())); } if (node instanceof JoinNode) { - JoinNode joinNode = (JoinNode) node; - if (joinNode.getType() == JoinNode.Type.INNER && isAtMostScalar(joinNode.getRight())) { - PlanNode source = rewriteModifyTableScan(joinNode.getLeft(), handle); - return replaceChildren(node, ImmutableList.of(source, joinNode.getRight())); - } + PlanNode source = rewriteModifyTableScan(((JoinNode) node).getLeft(), handle); + return replaceChildren(node, ImmutableList.of(source, ((JoinNode) node).getRight())); + } + if (node instanceof AssignUniqueId) { + PlanNode source = rewriteModifyTableScan(((AssignUniqueId) node).getSource(), handle); + return replaceChildren(node, ImmutableList.of(source)); + } + if (node instanceof MarkDistinctNode) { + PlanNode source = rewriteModifyTableScan(((MarkDistinctNode) node).getSource(), handle); + return replaceChildren(node, ImmutableList.of(source)); } throw new IllegalArgumentException("Invalid descendant for DeleteNode or UpdateNode: " + node.getClass().getName()); } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PropertyDerivations.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PropertyDerivations.java index 61b453fa0a10..1ee622230347 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PropertyDerivations.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PropertyDerivations.java @@ -230,7 +230,7 @@ public ActualProperties visitAssignUniqueId(AssignUniqueId node, List parameters = getParameters(statement); diff --git a/core/trino-main/src/main/java/io/trino/sql/rewrite/DescribeOutputRewrite.java b/core/trino-main/src/main/java/io/trino/sql/rewrite/DescribeOutputRewrite.java index 5ee44a10e24a..99ebfc6d3871 100644 --- a/core/trino-main/src/main/java/io/trino/sql/rewrite/DescribeOutputRewrite.java +++ b/core/trino-main/src/main/java/io/trino/sql/rewrite/DescribeOutputRewrite.java @@ -53,6 +53,7 @@ import static io.trino.sql.QueryUtil.selectList; import static io.trino.sql.QueryUtil.simpleQuery; import static io.trino.sql.QueryUtil.values; +import static io.trino.sql.analyzer.QueryType.DESCRIBE; import static io.trino.type.TypeUtils.getDisplayLabel; import static java.util.Objects.requireNonNull; @@ -121,7 +122,7 @@ protected Node visitDescribeOutput(DescribeOutput node, Void context) Statement statement = parser.createStatement(sqlString, createParsingOptions(session)); Analyzer analyzer = new Analyzer(session, metadata, parser, groupProvider, accessControl, queryExplainer, parameters, parameterLookup, warningCollector, statsCalculator); - Analysis analysis = analyzer.analyze(statement, true); + Analysis analysis = analyzer.analyze(statement, DESCRIBE); Optional limit = Optional.empty(); Row[] rows = analysis.getRootScope().getRelationType().getVisibleFields().stream().map(field -> createDescribeOutputRow(field, analysis)).toArray(Row[]::new); diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestLogicalPlanner.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestLogicalPlanner.java index 1584fa90372a..84559f930756 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestLogicalPlanner.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestLogicalPlanner.java @@ -771,6 +771,19 @@ public void testCorrelatedScalarSubqueryInSelect() anyTree(tableScan("nation", ImmutableMap.of("n_regionkey", "regionkey"))))), anyTree( tableScan("region", ImmutableMap.of("r_regionkey", "regionkey"))))))))); + + assertDistributedPlan("SELECT name, (SELECT name FROM region WHERE regionkey = nation.regionkey) FROM nation", + automaticJoinDistribution(), + anyTree( + filter(format("CASE \"is_distinct\" WHEN true THEN true ELSE CAST(fail(%s, 'Scalar sub-query has returned multiple rows') AS boolean) END", SUBQUERY_MULTIPLE_ROWS.toErrorCode().getCode()), + project( + markDistinct("is_distinct", ImmutableList.of("unique"), + join(LEFT, ImmutableList.of(equiJoinClause("n_regionkey", "r_regionkey")), + project( + assignUniqueId("unique", + tableScan("nation", ImmutableMap.of("n_regionkey", "regionkey", "n_name", "name")))), + anyTree( + tableScan("region", ImmutableMap.of("r_regionkey", "regionkey"))))))))); } @Test diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/test/PlanBuilder.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/test/PlanBuilder.java index 88182fd1cf2c..2420d822f20f 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/test/PlanBuilder.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/test/PlanBuilder.java @@ -82,11 +82,13 @@ import io.trino.sql.planner.plan.TableScanNode; import io.trino.sql.planner.plan.TableWriterNode; import io.trino.sql.planner.plan.TableWriterNode.DeleteTarget; +import io.trino.sql.planner.plan.TableWriterNode.UpdateTarget; import io.trino.sql.planner.plan.TopNNode; import io.trino.sql.planner.plan.TopNRankingNode; import io.trino.sql.planner.plan.TopNRankingNode.RankingType; import io.trino.sql.planner.plan.UnionNode; import io.trino.sql.planner.plan.UnnestNode; +import io.trino.sql.planner.plan.UpdateNode; import io.trino.sql.planner.plan.ValuesNode; import io.trino.sql.planner.plan.WindowNode; import io.trino.sql.planner.plan.WindowNode.Specification; @@ -95,7 +97,7 @@ import io.trino.sql.tree.NullLiteral; import io.trino.sql.tree.Row; import io.trino.testing.TestingHandle; -import io.trino.testing.TestingMetadata; +import io.trino.testing.TestingMetadata.TestingColumnHandle; import io.trino.testing.TestingMetadata.TestingTableHandle; import io.trino.testing.TestingTransactionHandle; @@ -606,7 +608,7 @@ public TableScanBuilder setSymbols(List symbols) public TableScanBuilder setAssignmentsForSymbols(List symbols) { - return setAssignments(symbols.stream().collect(toImmutableMap(identity(), symbol -> new TestingMetadata.TestingColumnHandle(symbol.getName())))); + return setAssignments(symbols.stream().collect(toImmutableMap(identity(), symbol -> new TestingColumnHandle(symbol.getName())))); } public TableScanBuilder setAssignments(Map assignments) @@ -694,6 +696,49 @@ private DeleteTarget deleteTarget(SchemaTableName schemaTableName) schemaTableName); } + public TableFinishNode tableUpdate(SchemaTableName schemaTableName, PlanNode updateSource, Symbol updateRowId, List columnsToBeUpdated) + { + UpdateTarget updateTarget = updateTarget( + schemaTableName, + columnsToBeUpdated.stream() + .map(Symbol::getName) + .collect(toImmutableList())); + return new TableFinishNode( + idAllocator.getNextId(), + exchange(e -> e + .addSource(new UpdateNode( + idAllocator.getNextId(), + updateSource, + updateTarget, + updateRowId, + ImmutableList.builder() + .addAll(columnsToBeUpdated) + .add(updateRowId) + .build(), + ImmutableList.of(updateRowId))) + .addInputsSet(updateRowId) + .singleDistributionPartitioningScheme(updateRowId)), + updateTarget, + updateRowId, + Optional.empty(), + Optional.empty()); + } + + private UpdateTarget updateTarget(SchemaTableName schemaTableName, List columnsToBeUpdated) + { + return new UpdateTarget( + Optional.of(new TableHandle( + new CatalogName("testConnector"), + new TestingTableHandle(), + TestingTransactionHandle.create(), + Optional.of(TestingHandle.INSTANCE))), + schemaTableName, + columnsToBeUpdated, + columnsToBeUpdated.stream() + .map(TestingColumnHandle::new) + .collect(toImmutableList())); + } + public ExchangeNode gatheringExchange(ExchangeNode.Scope scope, PlanNode child) { return exchange(builder -> builder.type(ExchangeNode.Type.GATHER) diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestBeginTableWrite.java b/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestBeginTableWrite.java new file mode 100644 index 000000000000..ae230ce04389 --- /dev/null +++ b/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestBeginTableWrite.java @@ -0,0 +1,159 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.planner.optimizations; + +import com.google.common.collect.ImmutableList; +import io.trino.Session; +import io.trino.execution.warnings.WarningCollector; +import io.trino.metadata.AbstractMockMetadata; +import io.trino.metadata.Metadata; +import io.trino.metadata.TableHandle; +import io.trino.spi.connector.ColumnHandle; +import io.trino.spi.connector.SchemaTableName; +import io.trino.spi.type.BigintType; +import io.trino.sql.planner.PlanNodeIdAllocator; +import io.trino.sql.planner.SymbolAllocator; +import io.trino.sql.planner.iterative.rule.test.PlanBuilder; +import io.trino.sql.planner.plan.PlanNode; +import org.testng.annotations.Test; + +import java.util.List; +import java.util.function.Function; + +import static io.trino.sql.planner.TypeProvider.empty; +import static io.trino.sql.planner.plan.JoinNode.Type.INNER; +import static io.trino.testing.TestingSession.testSessionBuilder; +import static org.assertj.core.api.Assertions.assertThatCode; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +public class TestBeginTableWrite +{ + @Test + public void testValidDelete() + { + assertThatCode(() -> applyOptimization( + p -> p.tableDelete( + new SchemaTableName("sch", "tab"), + p.tableScan(ImmutableList.of(p.symbol("rowId")), true), + p.symbol("rowId", BigintType.BIGINT)))) + .doesNotThrowAnyException(); + } + + @Test + public void testValidUpdate() + { + assertThatCode(() -> applyOptimization( + p -> p.tableUpdate( + new SchemaTableName("sch", "tab"), + p.tableScan(ImmutableList.of(p.symbol("columnToBeUpdated")), true), + p.symbol("rowId", BigintType.BIGINT), + ImmutableList.of(p.symbol("columnToBeUpdated"))))) + .doesNotThrowAnyException(); + } + + @Test + public void testDeleteWithNonDeletableTableScan() + { + assertThatThrownBy(() -> applyOptimization( + p -> p.tableDelete( + new SchemaTableName("sch", "tab"), + p.join( + INNER, + p.tableScan(ImmutableList.of(), false), + p.limit( + 1, + p.tableScan(ImmutableList.of(p.symbol("rowId")), true))), + p.symbol("rowId", BigintType.BIGINT)))) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("TableScanNode should be an updatable target"); + } + + @Test + public void testUpdateWithNonUpdatableTableScan() + { + assertThatThrownBy(() -> applyOptimization( + p -> p.tableUpdate( + new SchemaTableName("sch", "tab"), + p.join( + INNER, + p.tableScan(ImmutableList.of(), false), + p.limit( + 1, + p.tableScan(ImmutableList.of(p.symbol("columnToBeUpdated"), p.symbol("rowId")), true))), + p.symbol("rowId", BigintType.BIGINT), + ImmutableList.of(p.symbol("columnToBeUpdated"))))) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("TableScanNode should be an updatable target"); + } + + @Test + public void testDeleteWithInvalidNode() + { + assertThatThrownBy(() -> applyOptimization( + p -> p.tableDelete( + new SchemaTableName("sch", "tab"), + p.distinctLimit( + 10, + ImmutableList.of(p.symbol("rowId")), + p.tableScan(ImmutableList.of(p.symbol("a")), true)), + p.symbol("rowId", BigintType.BIGINT)))) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Invalid descendant for DeleteNode or UpdateNode: io.trino.sql.planner.plan.DistinctLimitNode"); + } + + @Test + public void testUpdateWithInvalidNode() + { + assertThatThrownBy(() -> applyOptimization( + p -> p.tableUpdate( + new SchemaTableName("sch", "tab"), + p.distinctLimit( + 10, + ImmutableList.of(p.symbol("a"), p.symbol("rowId")), + p.tableScan(ImmutableList.of(p.symbol("a")), true)), + p.symbol("rowId", BigintType.BIGINT), + ImmutableList.of(p.symbol("columnToBeUpdated"))))) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Invalid descendant for DeleteNode or UpdateNode: io.trino.sql.planner.plan.DistinctLimitNode"); + } + + private void applyOptimization(Function planProvider) + { + Metadata metadata = new MockMetadata(); + new BeginTableWrite(metadata) + .optimize( + planProvider.apply(new PlanBuilder(new PlanNodeIdAllocator(), metadata)), + testSessionBuilder().build(), + empty(), + new SymbolAllocator(), + new PlanNodeIdAllocator(), + WarningCollector.NOOP); + } + + private static class MockMetadata + extends AbstractMockMetadata + { + @Override + public TableHandle beginDelete(Session session, TableHandle tableHandle) + { + return tableHandle; + } + + @Override + public TableHandle beginUpdate(Session session, TableHandle tableHandle, List updatedColumns) + { + return tableHandle; + } + } +} diff --git a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseJdbcConnectorTest.java b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseJdbcConnectorTest.java index 211d34d5b2b6..6dbb00151ec2 100644 --- a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseJdbcConnectorTest.java +++ b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseJdbcConnectorTest.java @@ -1358,6 +1358,14 @@ public void testDeleteWithSubquery() .hasStackTraceContaining("TrinoException: Unsupported delete"); } + @Override + public void testExplainAnalyzeWithDeleteWithSubquery() + { + skipTestUnless(hasBehavior(SUPPORTS_CREATE_TABLE) && hasBehavior(SUPPORTS_ROW_LEVEL_DELETE)); + assertThatThrownBy(super::testExplainAnalyzeWithDeleteWithSubquery) + .hasStackTraceContaining("TrinoException: Unsupported delete"); + } + @Override public void testDeleteWithSemiJoin() { diff --git a/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/TestCassandraConnectorTest.java b/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/TestCassandraConnectorTest.java index e39684b79e63..5df3bbf7551d 100644 --- a/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/TestCassandraConnectorTest.java +++ b/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/TestCassandraConnectorTest.java @@ -876,6 +876,13 @@ public void testDeleteWithSubquery() .hasStackTraceContaining("Delete without primary key or partition key is not supported"); } + @Override + public void testExplainAnalyzeWithDeleteWithSubquery() + { + assertThatThrownBy(super::testExplainAnalyzeWithDeleteWithSubquery) + .hasStackTraceContaining("Delete without primary key or partition key is not supported"); + } + @Override public void testDeleteWithVarcharPredicate() { diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveConnectorTest.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveConnectorTest.java index f5858763d7f9..4091af5c4254 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveConnectorTest.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveConnectorTest.java @@ -242,6 +242,13 @@ public void testDeleteWithSubquery() .hasStackTraceContaining("Deletes must match whole partitions for non-transactional tables"); } + @Override + public void testExplainAnalyzeWithDeleteWithSubquery() + { + assertThatThrownBy(super::testExplainAnalyzeWithDeleteWithSubquery) + .hasStackTraceContaining("Deletes must match whole partitions for non-transactional tables"); + } + @Override public void testDeleteWithVarcharPredicate() { diff --git a/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/BaseIcebergConnectorTest.java b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/BaseIcebergConnectorTest.java index 5041a41d4016..5d47edc0e174 100644 --- a/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/BaseIcebergConnectorTest.java +++ b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/BaseIcebergConnectorTest.java @@ -168,6 +168,14 @@ public void testDeleteWithSubquery() .hasStackTraceContaining("This connector only supports delete where one or more identity-transformed partitions are deleted entirely"); } + @Override + public void testExplainAnalyzeWithDeleteWithSubquery() + { + // Deletes are covered with testMetadataDelete test methods + assertThatThrownBy(super::testExplainAnalyzeWithDeleteWithSubquery) + .hasStackTraceContaining("This connector only supports delete where one or more identity-transformed partitions are deleted entirely"); + } + @Override public void testDeleteWithVarcharPredicate() { diff --git a/testing/trino-product-tests/src/main/java/io/trino/tests/product/hive/TestHiveTransactionalTable.java b/testing/trino-product-tests/src/main/java/io/trino/tests/product/hive/TestHiveTransactionalTable.java index 2e5a8c419dd3..d9473dc3f555 100644 --- a/testing/trino-product-tests/src/main/java/io/trino/tests/product/hive/TestHiveTransactionalTable.java +++ b/testing/trino-product-tests/src/main/java/io/trino/tests/product/hive/TestHiveTransactionalTable.java @@ -1378,7 +1378,6 @@ public void testAcidUpdateMajorCompaction() @Test(groups = HIVE_TRANSACTIONAL, timeOut = TEST_TIMEOUT) public void testAcidUpdateWithSubqueryPredicate() { - // TODO support UPDATE with correlated subquery in assignment withTemporaryTable("test_update_subquery", true, false, NONE, tableName -> { onTrino().executeQuery(format("CREATE TABLE %s (column1 INT, column2 varchar) WITH (transactional = true)", tableName)); onTrino().executeQuery(format("INSERT INTO %s VALUES (1, 'x')", tableName)); @@ -1397,16 +1396,17 @@ public void testAcidUpdateWithSubqueryPredicate() }); // WHERE with correlated subquery - assertQueryFailure(() -> onTrino().executeQuery(format("UPDATE %s SET column2 = 'row updated yet again' WHERE column2 = (SELECT name FROM tpch.tiny.region WHERE regionkey = column1)", tableName))) - // TODO (https://github.com/trinodb/trino/issues/3325) support correlated UPDATE - .hasMessageMatching("\\QQuery failed (#\\E\\S+\\Q): Invalid descendant for DeleteNode or UpdateNode: io.trino.sql.planner.plan.MarkDistinctNode"); + onTrino().executeQuery(format("UPDATE %s SET column2 = 'row updated yet again' WHERE column2 = (SELECT name FROM tpch.tiny.region WHERE regionkey = column1)", tableName)); + verifySelectForTrinoAndHive("SELECT * FROM " + tableName, "true", row(1, "row updated"), row(2, "another row updated")); + + onTrino().executeQuery(format("UPDATE %s SET column2 = 'row updated yet again' WHERE column2 != (SELECT name FROM tpch.tiny.region WHERE regionkey = column1)", tableName)); + verifySelectForTrinoAndHive("SELECT * FROM " + tableName, "true", row(1, "row updated yet again"), row(2, "row updated yet again")); }); } @Test(groups = HIVE_TRANSACTIONAL, timeOut = TEST_TIMEOUT) public void testAcidUpdateWithSubqueryAssignment() { - // TODO support UPDATE with correlated subquery in assignment withTemporaryTable("test_update_subquery", true, false, NONE, tableName -> { onTrino().executeQuery(format("CREATE TABLE %s (column1 INT, column2 varchar) WITH (transactional = true)", tableName)); onTrino().executeQuery(format("INSERT INTO %s VALUES (1, 'x')", tableName)); @@ -1421,14 +1421,15 @@ public void testAcidUpdateWithSubqueryAssignment() // UPDATE while reading from another transactional table. Multiple transactional could interfere with ConnectorMetadata.beginQuery onTrino().executeQuery(format("UPDATE %s SET column2 = (SELECT min(name) FROM %s)", tableName, secondTable)); - // TODO (https://github.com/trinodb/trino/issues/8268) verifySelectForTrinoAndHive("SELECT * FROM " + tableName, "true", row(1, "AFRICA"), row(2, "AFRICA")); - verifySelect("onTrino", onTrino(), "SELECT * FROM " + tableName, "true", row(1, "AFRICA"), row(2, "AFRICA")); + verifySelectForTrinoAndHive("SELECT * FROM " + tableName, "true", row(1, "AFRICA"), row(2, "AFRICA")); + + onTrino().executeQuery(format("UPDATE %s SET column2 = (SELECT name FROM %s WHERE column1 = regionkey + 1)", tableName, secondTable)); + verifySelectForTrinoAndHive("SELECT * FROM " + tableName, "true", row(1, "AFRICA"), row(2, "AMERICA")); }); // SET with correlated subquery - assertQueryFailure(() -> onTrino().executeQuery(format("UPDATE %s SET column2 = (SELECT name FROM tpch.tiny.region WHERE column1 = regionkey)", tableName))) - // TODO (https://github.com/trinodb/trino/issues/3325) support correlated UPDATE - .hasMessageMatching("\\QQuery failed (#\\E\\S+\\Q): Invalid descendant for DeleteNode or UpdateNode: io.trino.sql.planner.plan.MarkDistinctNode"); + onTrino().executeQuery(format("UPDATE %s SET column2 = (SELECT name FROM tpch.tiny.region WHERE column1 = regionkey)", tableName)); + verifySelectForTrinoAndHive("SELECT * FROM " + tableName, "true", row(1, "AMERICA"), row(2, "ASIA")); }); } diff --git a/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestDistributedQueries.java b/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestDistributedQueries.java index 01d182e82a6d..864be9340d93 100644 --- a/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestDistributedQueries.java +++ b/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestDistributedQueries.java @@ -702,6 +702,20 @@ public void testDeleteWithSubquery() assertUpdate("DROP TABLE " + tableName); } + @Test + public void testExplainAnalyzeWithDeleteWithSubquery() + { + skipTestUnlessSupportsDeletes(); + + String tableName = "test_delete_" + randomTableSuffix(); + + // delete using a subquery + assertUpdate("CREATE TABLE " + tableName + " AS SELECT * FROM nation", 25); + assertExplainAnalyze("EXPLAIN ANALYZE DELETE FROM " + tableName + " WHERE regionkey IN (SELECT regionkey FROM region WHERE name LIKE 'A%' LIMIT 1)", + "SemiJoin.*"); + assertUpdate("DROP TABLE " + tableName); + } + @Test public void testDeleteWithSemiJoin() {