From 9a5ebf89bd1ebb04a8cb19eb27b980469d78166d Mon Sep 17 00:00:00 2001 From: Piotr Findeisen Date: Wed, 31 Aug 2022 18:50:39 +0200 Subject: [PATCH 1/3] Improve argument order in ConnectorExpressionTranslator.translate - The expression and its symbol types should be passed next to each other. - TypeAnalyzer is usually provided after PlannerContext (until it's made part of it) --- .../io/trino/sql/planner/ConnectorExpressionTranslator.java | 4 ++-- .../trino/sql/planner/TestConnectorExpressionTranslator.java | 2 +- .../test/java/io/trino/sql/planner/TestPartialTranslator.java | 4 ++-- .../iterative/rule/TestPushProjectionIntoTableScan.java | 2 +- .../java/io/trino/plugin/iceberg/TestConstraintExtractor.java | 4 ++-- .../java/io/trino/plugin/postgresql/TestPostgreSqlClient.java | 4 ++-- 6 files changed, 10 insertions(+), 10 deletions(-) diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/ConnectorExpressionTranslator.java b/core/trino-main/src/main/java/io/trino/sql/planner/ConnectorExpressionTranslator.java index 1f07a8073c4b..b2334c59e384 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/ConnectorExpressionTranslator.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/ConnectorExpressionTranslator.java @@ -125,9 +125,9 @@ public static Expression translate(Session session, ConnectorExpression expressi .orElseThrow(() -> new UnsupportedOperationException("Expression is not supported: " + expression.toString())); } - public static Optional translate(Session session, Expression expression, TypeAnalyzer types, TypeProvider inputTypes, PlannerContext plannerContext) + public static Optional translate(Session session, Expression expression, TypeProvider types, PlannerContext plannerContext, TypeAnalyzer typeAnalyzer) { - return new SqlToConnectorExpressionTranslator(session, types.getTypes(session, inputTypes, expression), plannerContext) + return new SqlToConnectorExpressionTranslator(session, typeAnalyzer.getTypes(session, types, expression), plannerContext) .process(expression); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestConnectorExpressionTranslator.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestConnectorExpressionTranslator.java index 0eea7ecde92a..c04f96848390 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestConnectorExpressionTranslator.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestConnectorExpressionTranslator.java @@ -470,7 +470,7 @@ private void assertTranslationToConnectorExpression(Session session, Expression private void assertTranslationToConnectorExpression(Session session, Expression expression, Optional connectorExpression) { - Optional translation = translate(session, expression, TYPE_ANALYZER, TYPE_PROVIDER, PLANNER_CONTEXT); + Optional translation = translate(session, expression, TYPE_PROVIDER, PLANNER_CONTEXT, TYPE_ANALYZER); assertEquals(connectorExpression.isPresent(), translation.isPresent()); translation.ifPresent(value -> assertEquals(value, connectorExpression.get())); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestPartialTranslator.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestPartialTranslator.java index ce2fe443243d..fb6f674dd156 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestPartialTranslator.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestPartialTranslator.java @@ -101,7 +101,7 @@ private void assertPartialTranslation(Expression expression, List su Map, ConnectorExpression> translation = extractPartialTranslations(expression, TEST_SESSION, TYPE_ANALYZER, TYPE_PROVIDER, PLANNER_CONTEXT); assertEquals(subexpressions.size(), translation.size()); for (Expression subexpression : subexpressions) { - assertEquals(translation.get(NodeRef.of(subexpression)), translate(TEST_SESSION, subexpression, TYPE_ANALYZER, TYPE_PROVIDER, PLANNER_CONTEXT).get()); + assertEquals(translation.get(NodeRef.of(subexpression)), translate(TEST_SESSION, subexpression, TYPE_PROVIDER, PLANNER_CONTEXT, TYPE_ANALYZER).get()); } } @@ -109,6 +109,6 @@ private void assertFullTranslation(Expression expression) { Map, ConnectorExpression> translation = extractPartialTranslations(expression, TEST_SESSION, TYPE_ANALYZER, TYPE_PROVIDER, PLANNER_CONTEXT); assertEquals(getOnlyElement(translation.keySet()), NodeRef.of(expression)); - assertEquals(getOnlyElement(translation.values()), translate(TEST_SESSION, expression, TYPE_ANALYZER, TYPE_PROVIDER, PLANNER_CONTEXT).get()); + assertEquals(getOnlyElement(translation.values()), translate(TEST_SESSION, expression, TYPE_PROVIDER, PLANNER_CONTEXT, TYPE_ANALYZER).get()); } } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushProjectionIntoTableScan.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushProjectionIntoTableScan.java index 3c137bd286a9..f4e8b2ccf8cc 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushProjectionIntoTableScan.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushProjectionIntoTableScan.java @@ -154,7 +154,7 @@ constant, new LongLiteral("5"), TransactionId transactionId = ruleTester.getQueryRunner().getTransactionManager().beginTransaction(false); Session session = MOCK_SESSION.beginTransactionId(transactionId, ruleTester.getQueryRunner().getTransactionManager(), ruleTester.getQueryRunner().getAccessControl()); ImmutableMap connectorNames = inputProjections.entrySet().stream() - .collect(toImmutableMap(Map.Entry::getKey, e -> translate(session, e.getValue(), typeAnalyzer, viewOf(types), ruleTester.getPlannerContext()).get().toString())); + .collect(toImmutableMap(Map.Entry::getKey, e -> translate(session, e.getValue(), viewOf(types), ruleTester.getPlannerContext(), typeAnalyzer).get().toString())); ImmutableMap newNames = ImmutableMap.of( identity, "projected_variable_" + connectorNames.get(identity), dereference, "projected_dereference_" + connectorNames.get(dereference), diff --git a/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestConstraintExtractor.java b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestConstraintExtractor.java index 566fa4b528b2..7b3981fb776b 100644 --- a/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestConstraintExtractor.java +++ b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestConstraintExtractor.java @@ -239,10 +239,10 @@ private static ConnectorExpression connectorExpression(Expression expression, Ma return ConnectorExpressionTranslator.translate( TEST_SESSION, expression, - createTestingTypeAnalyzer(PLANNER_CONTEXT), TypeProvider.viewOf(symbolTypes.entrySet().stream() .collect(toImmutableMap(entry -> new Symbol(entry.getKey()), Map.Entry::getValue))), - PLANNER_CONTEXT) + PLANNER_CONTEXT, + createTestingTypeAnalyzer(PLANNER_CONTEXT)) .orElseThrow(); } diff --git a/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlClient.java b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlClient.java index 7aae0329b414..f25bff7155ce 100644 --- a/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlClient.java +++ b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlClient.java @@ -446,10 +446,10 @@ private ConnectorExpression translateToConnectorExpression(Expression expression return ConnectorExpressionTranslator.translate( TEST_SESSION, expression, - createTestingTypeAnalyzer(PLANNER_CONTEXT), TypeProvider.viewOf(symbolTypes.entrySet().stream() .collect(toImmutableMap(entry -> new Symbol(entry.getKey()), Entry::getValue))), - PLANNER_CONTEXT) + PLANNER_CONTEXT, + createTestingTypeAnalyzer(PLANNER_CONTEXT)) .orElseThrow(); } } From cd733970d288c686d99ffa3d7216c02b2952691e Mon Sep 17 00:00:00 2001 From: Piotr Findeisen Date: Wed, 31 Aug 2022 18:56:42 +0200 Subject: [PATCH 2/3] Move conjuncts translation to ConnectorExpressionTranslator Move translation of individual conjuncts from `PushPredicateIntoTableScan` into `ConnectorExpressionTranslator` to make it reusable. This also changes `ConnectorExpressionTranslation` into a record, which is very suitable for a composite method return type. --- .../ConnectorExpressionTranslator.java | 43 +++++++++++ .../rule/PushPredicateIntoTableScan.java | 73 +++---------------- 2 files changed, 53 insertions(+), 63 deletions(-) diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/ConnectorExpressionTranslator.java b/core/trino-main/src/main/java/io/trino/sql/planner/ConnectorExpressionTranslator.java index b2334c59e384..2a7dc2fa62af 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/ConnectorExpressionTranslator.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/ConnectorExpressionTranslator.java @@ -21,6 +21,7 @@ import io.trino.Session; import io.trino.metadata.LiteralFunction; import io.trino.metadata.ResolvedFunction; +import io.trino.plugin.base.expression.ConnectorExpressions; import io.trino.security.AllowAllAccessControl; import io.trino.spi.expression.Call; import io.trino.spi.expression.ConnectorExpression; @@ -71,6 +72,7 @@ import io.trino.type.Re2JRegexp; import io.trino.type.Re2JRegexpType; +import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.Optional; @@ -105,6 +107,8 @@ import static io.trino.spi.expression.StandardFunctions.SUBTRACT_FUNCTION_NAME; import static io.trino.spi.type.BooleanType.BOOLEAN; import static io.trino.spi.type.VarcharType.createVarcharType; +import static io.trino.sql.ExpressionUtils.combineConjuncts; +import static io.trino.sql.ExpressionUtils.extractConjuncts; import static io.trino.sql.ExpressionUtils.isEffectivelyLiteral; import static io.trino.sql.analyzer.TypeSignatureTranslator.toSqlType; import static io.trino.sql.analyzer.TypeSignatureTranslator.toTypeSignature; @@ -131,6 +135,36 @@ public static Optional translate(Session session, Expressio .process(expression); } + public static ConnectorExpressionTranslation translateConjuncts( + Session session, + Expression expression, + TypeProvider types, + PlannerContext plannerContext, + TypeAnalyzer typeAnalyzer) + { + Map, Type> remainingExpressionTypes = typeAnalyzer.getTypes(session, types, expression); + ConnectorExpressionTranslator.SqlToConnectorExpressionTranslator translator = new ConnectorExpressionTranslator.SqlToConnectorExpressionTranslator( + session, + remainingExpressionTypes, + plannerContext); + + List conjuncts = extractConjuncts(expression); + List remaining = new ArrayList<>(); + List converted = new ArrayList<>(conjuncts.size()); + for (Expression conjunct : conjuncts) { + Optional connectorExpression = translator.process(conjunct); + if (connectorExpression.isPresent()) { + converted.add(connectorExpression.get()); + } + else { + remaining.add(conjunct); + } + } + return new ConnectorExpressionTranslation( + ConnectorExpressions.and(converted), + combineConjuncts(plannerContext.getMetadata(), remaining)); + } + @VisibleForTesting static FunctionName functionNameForComparisonOperator(ComparisonExpression.Operator operator) { @@ -157,6 +191,15 @@ static FunctionName functionNameForArithmeticBinaryOperator(ArithmeticBinaryExpr }; } + public record ConnectorExpressionTranslation(ConnectorExpression connectorExpression, Expression remainingExpression) + { + public ConnectorExpressionTranslation + { + requireNonNull(connectorExpression, "connectorExpression is null"); + requireNonNull(remainingExpression, "remainingExpression is null"); + } + } + private static class ConnectorToSqlExpressionTranslator { private final Session session; diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushPredicateIntoTableScan.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushPredicateIntoTableScan.java index d91d4ce49777..5e23ab5ec1cb 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushPredicateIntoTableScan.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushPredicateIntoTableScan.java @@ -25,7 +25,6 @@ import io.trino.metadata.TableHandle; import io.trino.metadata.TableProperties; import io.trino.metadata.TableProperties.TablePartitioning; -import io.trino.plugin.base.expression.ConnectorExpressions; import io.trino.spi.connector.ColumnHandle; import io.trino.spi.connector.Constraint; import io.trino.spi.connector.ConstraintApplicationResult; @@ -35,6 +34,7 @@ import io.trino.spi.type.Type; import io.trino.sql.PlannerContext; import io.trino.sql.planner.ConnectorExpressionTranslator; +import io.trino.sql.planner.ConnectorExpressionTranslator.ConnectorExpressionTranslation; import io.trino.sql.planner.DomainTranslator; import io.trino.sql.planner.ExpressionInterpreter; import io.trino.sql.planner.LayoutConstraintEvaluator; @@ -43,7 +43,6 @@ import io.trino.sql.planner.Symbol; import io.trino.sql.planner.SymbolAllocator; import io.trino.sql.planner.TypeAnalyzer; -import io.trino.sql.planner.TypeProvider; import io.trino.sql.planner.iterative.Rule; import io.trino.sql.planner.plan.FilterNode; import io.trino.sql.planner.plan.PlanNode; @@ -180,12 +179,12 @@ public static Optional pushFilterIntoTableScan( .transformKeys(node.getAssignments()::get) .intersect(node.getEnforcedConstraint()); - ConnectorExpressionTranslation expressionTranslation = translateConjunctsToConnectorExpression( + ConnectorExpressionTranslation expressionTranslation = ConnectorExpressionTranslator.translateConjuncts( session, - plannerContext, - typeAnalyzer, + decomposedPredicate.getRemainingExpression(), symbolAllocator.getTypes(), - decomposedPredicate.getRemainingExpression()); + plannerContext, + typeAnalyzer); Map connectorExpressionAssignments = node.getAssignments() .entrySet().stream() .collect(toImmutableMap(entry -> entry.getKey().getName(), Map.Entry::getValue)); @@ -207,18 +206,18 @@ public static Optional pushFilterIntoTableScan( // Simplify the tuple domain to avoid creating an expression with too many nodes, // which would be expensive to evaluate in the call to isCandidate below. domainTranslator.toPredicate(session, newDomain.simplify().transformKeys(assignments::get)))); - constraint = new Constraint(newDomain, expressionTranslation.getConnectorExpression(), connectorExpressionAssignments, evaluator::isCandidate, evaluator.getArguments()); + constraint = new Constraint(newDomain, expressionTranslation.connectorExpression(), connectorExpressionAssignments, evaluator::isCandidate, evaluator.getArguments()); } else { // Currently, invoking the expression interpreter is very expensive. // TODO invoke the interpreter unconditionally when the interpreter becomes cheap enough. - constraint = new Constraint(newDomain, expressionTranslation.getConnectorExpression(), connectorExpressionAssignments); + constraint = new Constraint(newDomain, expressionTranslation.connectorExpression(), connectorExpressionAssignments); } // check if new domain is wider than domain already provided by table scan if (constraint.predicate().isEmpty() && // TODO do we need to track enforced ConnectorExpression in TableScanNode? - TRUE.equals(expressionTranslation.getConnectorExpression()) && + TRUE.equals(expressionTranslation.connectorExpression()) && newDomain.contains(node.getEnforcedConstraint())) { Expression resultingPredicate = createResultingPredicate( plannerContext, @@ -276,7 +275,7 @@ public static Optional pushFilterIntoTableScan( node.getUseConnectorNodePartitioning()); Expression remainingDecomposedPredicate; - if (remainingConnectorExpression.isEmpty() || remainingConnectorExpression.get().equals(expressionTranslation.getConnectorExpression())) { + if (remainingConnectorExpression.isEmpty() || remainingConnectorExpression.get().equals(expressionTranslation.connectorExpression())) { remainingDecomposedPredicate = decomposedPredicate.getRemainingExpression(); } else { @@ -292,7 +291,7 @@ public static Optional pushFilterIntoTableScan( new ExpressionInterpreter(translatedExpression, plannerContext, session, translatedExpressionTypes) .optimize(NoOpSymbolResolver.INSTANCE), translatedExpressionTypes.get(NodeRef.of(translatedExpression))); - remainingDecomposedPredicate = combineConjuncts(plannerContext.getMetadata(), translatedExpression, expressionTranslation.getRemainingExpression()); + remainingDecomposedPredicate = combineConjuncts(plannerContext.getMetadata(), translatedExpression, expressionTranslation.remainingExpression()); } Expression resultingPredicate = createResultingPredicate( @@ -358,36 +357,6 @@ else if (isDeterministic(conjunct, metadata)) { combineConjuncts(metadata, nonDeterministicPredicate)); } - private static ConnectorExpressionTranslation translateConjunctsToConnectorExpression( - Session session, - PlannerContext plannerContext, - TypeAnalyzer typeAnalyzer, - TypeProvider types, - Expression expression) - { - Map, Type> remainingExpressionTypes = typeAnalyzer.getTypes(session, types, expression); - ConnectorExpressionTranslator.SqlToConnectorExpressionTranslator translator = new ConnectorExpressionTranslator.SqlToConnectorExpressionTranslator( - session, - remainingExpressionTypes, - plannerContext); - - List conjuncts = extractConjuncts(expression); - List remaining = new ArrayList<>(); - List converted = new ArrayList<>(conjuncts.size()); - for (Expression conjunct : conjuncts) { - Optional connectorExpression = translator.process(conjunct); - if (connectorExpression.isPresent()) { - converted.add(connectorExpression.get()); - } - else { - remaining.add(conjunct); - } - } - return new ConnectorExpressionTranslation( - combineConjuncts(plannerContext.getMetadata(), remaining), - ConnectorExpressions.and(converted)); - } - static Expression createResultingPredicate( PlannerContext plannerContext, Session session, @@ -478,26 +447,4 @@ public Expression getNonDeterministicPredicate() return nonDeterministicPredicate; } } - - private static class ConnectorExpressionTranslation - { - private final Expression remainingExpression; - private final ConnectorExpression connectorExpression; - - public ConnectorExpressionTranslation(Expression remainingExpression, ConnectorExpression connectorExpression) - { - this.remainingExpression = requireNonNull(remainingExpression, "remainingExpression is null"); - this.connectorExpression = requireNonNull(connectorExpression, "connectorExpression is null"); - } - - public Expression getRemainingExpression() - { - return remainingExpression; - } - - public ConnectorExpression getConnectorExpression() - { - return connectorExpression; - } - } } From 1a984bb0aaba1a7cabb7d72a4ec5026d48f39537 Mon Sep 17 00:00:00 2001 From: Piotr Findeisen Date: Wed, 31 Aug 2022 18:43:38 +0200 Subject: [PATCH 3/3] Push whole join condition into connectors Previously join pushdown was limited to simple comparison conditions between variables. That's because the `applyJoin` was added before we had real connector expressions. The commit exposes to connectors all join condition's conjuncts that are translatable to `ConnectorExpression`. --- .../main/java/io/trino/metadata/Metadata.java | 3 +- .../io/trino/metadata/MetadataManager.java | 5 +- .../io/trino/sql/planner/PlanOptimizers.java | 2 +- .../iterative/rule/PushJoinIntoTableScan.java | 118 +++--------------- .../trino/metadata/AbstractMockMetadata.java | 3 +- .../metadata/CountingAccessMetadata.java | 5 +- .../rule/TestPushJoinIntoTableScan.java | 45 +++++-- .../spi/connector/ConnectorMetadata.java | 47 +++++++ .../io/trino/spi/connector/JoinCondition.java | 93 ++++++++++++-- .../ClassLoaderSafeConnectorMetadata.java | 16 +++ 10 files changed, 208 insertions(+), 129 deletions(-) diff --git a/core/trino-main/src/main/java/io/trino/metadata/Metadata.java b/core/trino-main/src/main/java/io/trino/metadata/Metadata.java index 5efd5b94e953..453826b74717 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/Metadata.java +++ b/core/trino-main/src/main/java/io/trino/metadata/Metadata.java @@ -31,7 +31,6 @@ import io.trino.spi.connector.Constraint; import io.trino.spi.connector.ConstraintApplicationResult; import io.trino.spi.connector.JoinApplicationResult; -import io.trino.spi.connector.JoinCondition; import io.trino.spi.connector.JoinStatistics; import io.trino.spi.connector.JoinType; import io.trino.spi.connector.LimitApplicationResult; @@ -494,7 +493,7 @@ Optional> applyJoin( JoinType joinType, TableHandle left, TableHandle right, - List joinConditions, + ConnectorExpression joinCondition, Map leftAssignments, Map rightAssignments, JoinStatistics statistics); diff --git a/core/trino-main/src/main/java/io/trino/metadata/MetadataManager.java b/core/trino-main/src/main/java/io/trino/metadata/MetadataManager.java index d4b1fb4f8e88..326c8a264b2a 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/MetadataManager.java +++ b/core/trino-main/src/main/java/io/trino/metadata/MetadataManager.java @@ -62,7 +62,6 @@ import io.trino.spi.connector.Constraint; import io.trino.spi.connector.ConstraintApplicationResult; import io.trino.spi.connector.JoinApplicationResult; -import io.trino.spi.connector.JoinCondition; import io.trino.spi.connector.JoinStatistics; import io.trino.spi.connector.JoinType; import io.trino.spi.connector.LimitApplicationResult; @@ -1619,7 +1618,7 @@ public Optional> applyJoin( JoinType joinType, TableHandle left, TableHandle right, - List joinConditions, + ConnectorExpression joinCondition, Map leftAssignments, Map rightAssignments, JoinStatistics statistics) @@ -1640,7 +1639,7 @@ public Optional> applyJoin( joinType, left.getConnectorHandle(), right.getConnectorHandle(), - joinConditions, + joinCondition, leftAssignments, rightAssignments, statistics); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/PlanOptimizers.java b/core/trino-main/src/main/java/io/trino/sql/planner/PlanOptimizers.java index b7f07d7ffea7..e93faab73995 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/PlanOptimizers.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/PlanOptimizers.java @@ -836,7 +836,7 @@ public PlanOptimizers( .addAll(pushIntoTableScanRulesExceptJoins) // PushJoinIntoTableScan must run after ReorderJoins (and DetermineJoinDistributionType) // otherwise too early pushdown could prevent optimal plan from being selected. - .add(new PushJoinIntoTableScan(metadata)) + .add(new PushJoinIntoTableScan(plannerContext, typeAnalyzer)) // DetermineTableScanNodePartitioning is needed to needs to ensure all table handles have proper partitioning determined // Must run before AddExchanges .add(new DetermineTableScanNodePartitioning(metadata, nodePartitioningManager, taskCountEstimator)) diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushJoinIntoTableScan.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushJoinIntoTableScan.java index ff64fe9d2eba..541070c5ab92 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushJoinIntoTableScan.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushJoinIntoTableScan.java @@ -15,25 +15,24 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import com.google.common.collect.ImmutableSet; import io.trino.Session; import io.trino.cost.PlanNodeStatsEstimate; import io.trino.matching.Capture; import io.trino.matching.Captures; import io.trino.matching.Pattern; -import io.trino.metadata.Metadata; import io.trino.metadata.TableHandle; import io.trino.spi.connector.BasicRelationStatistics; import io.trino.spi.connector.ColumnHandle; import io.trino.spi.connector.JoinApplicationResult; -import io.trino.spi.connector.JoinCondition; import io.trino.spi.connector.JoinStatistics; import io.trino.spi.connector.JoinType; -import io.trino.spi.expression.Variable; import io.trino.spi.predicate.Domain; import io.trino.spi.predicate.TupleDomain; -import io.trino.sql.ExpressionUtils; +import io.trino.sql.PlannerContext; +import io.trino.sql.planner.ConnectorExpressionTranslator; +import io.trino.sql.planner.ConnectorExpressionTranslator.ConnectorExpressionTranslation; import io.trino.sql.planner.Symbol; +import io.trino.sql.planner.TypeAnalyzer; import io.trino.sql.planner.TypeProvider; import io.trino.sql.planner.iterative.Rule; import io.trino.sql.planner.plan.Assignments; @@ -43,14 +42,11 @@ import io.trino.sql.planner.plan.ProjectNode; import io.trino.sql.planner.plan.TableScanNode; import io.trino.sql.tree.BooleanLiteral; -import io.trino.sql.tree.ComparisonExpression; import io.trino.sql.tree.Expression; -import io.trino.sql.tree.SymbolReference; import java.util.List; import java.util.Map; import java.util.Optional; -import java.util.Set; import static com.google.common.base.Verify.verify; import static com.google.common.collect.ImmutableList.toImmutableList; @@ -59,7 +55,6 @@ import static io.trino.matching.Capture.newCapture; import static io.trino.spi.predicate.Domain.onlyNull; import static io.trino.sql.ExpressionUtils.and; -import static io.trino.sql.ExpressionUtils.extractConjuncts; import static io.trino.sql.planner.iterative.rule.Rules.deriveTableStatisticsForPushdown; import static io.trino.sql.planner.plan.JoinNode.Type.FULL; import static io.trino.sql.planner.plan.JoinNode.Type.LEFT; @@ -81,11 +76,13 @@ public class PushJoinIntoTableScan .with(left().matching(tableScan().capturedAs(LEFT_TABLE_SCAN))) .with(right().matching(tableScan().capturedAs(RIGHT_TABLE_SCAN))); - private final Metadata metadata; + private final PlannerContext plannerContext; + private final TypeAnalyzer typeAnalyzer; - public PushJoinIntoTableScan(Metadata metadata) + public PushJoinIntoTableScan(PlannerContext plannerContext, TypeAnalyzer typeAnalyzer) { - this.metadata = requireNonNull(metadata, "metadata is null"); + this.plannerContext = requireNonNull(plannerContext, "plannerContext is null"); + this.typeAnalyzer = requireNonNull(typeAnalyzer, "typeAnalyzer is null"); } @Override @@ -113,9 +110,14 @@ public Result apply(JoinNode joinNode, Captures captures, Context context) verify(!left.isUpdateTarget() && !right.isUpdateTarget(), "Unexpected Join over for-update table scan"); Expression effectiveFilter = getEffectiveFilter(joinNode); - FilterSplitResult filterSplitResult = splitFilter(effectiveFilter, left.getOutputSymbols(), right.getOutputSymbols(), context); + ConnectorExpressionTranslation translation = ConnectorExpressionTranslator.translateConjuncts( + context.getSession(), + effectiveFilter, + context.getSymbolAllocator().getTypes(), + plannerContext, + typeAnalyzer); - if (!filterSplitResult.getRemainingFilter().equals(BooleanLiteral.TRUE_LITERAL)) { + if (!translation.remainingExpression().equals(BooleanLiteral.TRUE_LITERAL)) { // TODO add extra filter node above join return Result.empty(); } @@ -144,13 +146,13 @@ public Result apply(JoinNode joinNode, Captures captures, Context context) */ JoinStatistics joinStatistics = getJoinStatistics(joinNode, left, right, context); - Optional> joinApplicationResult = metadata.applyJoin( + Optional> joinApplicationResult = plannerContext.getMetadata().applyJoin( context.getSession(), getJoinType(joinNode), left.getTable(), right.getTable(), - filterSplitResult.getPushableConditions(), - // TODO we could pass only subset of assignments here, those which are needed to resolve filterSplitResult.getPushableConditions + translation.connectorExpression(), + // TODO we could pass only subset of assignments here, those which are needed to resolve translation.getPushableConditions leftAssignments, rightAssignments, joinStatistics); @@ -254,88 +256,6 @@ public Expression getEffectiveFilter(JoinNode node) return effectiveFilter; } - private FilterSplitResult splitFilter(Expression filter, List leftSymbolsList, List rightSymbolsList, Context context) - { - Set leftSymbols = ImmutableSet.copyOf(leftSymbolsList); - Set rightSymbols = ImmutableSet.copyOf(rightSymbolsList); - - ImmutableList.Builder comparisonConditions = ImmutableList.builder(); - ImmutableList.Builder remainingConjuncts = ImmutableList.builder(); - - for (Expression conjunct : extractConjuncts(filter)) { - getPushableJoinCondition(conjunct, leftSymbols, rightSymbols, context) - .ifPresentOrElse(comparisonConditions::add, () -> remainingConjuncts.add(conjunct)); - } - - return new FilterSplitResult(comparisonConditions.build(), ExpressionUtils.and(remainingConjuncts.build())); - } - - private Optional getPushableJoinCondition(Expression conjunct, Set leftSymbols, Set rightSymbols, Context context) - { - if (!(conjunct instanceof ComparisonExpression)) { - return Optional.empty(); - } - ComparisonExpression comparison = (ComparisonExpression) conjunct; - - if (!(comparison.getLeft() instanceof SymbolReference) || !(comparison.getRight() instanceof SymbolReference)) { - return Optional.empty(); - } - Symbol left = Symbol.from(comparison.getLeft()); - Symbol right = Symbol.from(comparison.getRight()); - ComparisonExpression.Operator operator = comparison.getOperator(); - - if (!leftSymbols.contains(left)) { - // lets try with flipped expression - Symbol tmp = left; - left = right; - right = tmp; - operator = operator.flip(); - } - - if (leftSymbols.contains(left) && rightSymbols.contains(right)) { - return Optional.of(new JoinCondition( - joinConditionOperator(operator), - new Variable(left.getName(), context.getSymbolAllocator().getTypes().get(left)), - new Variable(right.getName(), context.getSymbolAllocator().getTypes().get(right)))); - } - return Optional.empty(); - } - - private static class FilterSplitResult - { - private final List pushableConditions; - private final Expression remainingFilter; - - public FilterSplitResult(List pushableConditions, Expression remainingFilter) - { - this.pushableConditions = requireNonNull(pushableConditions, "pushableConditions is null"); - this.remainingFilter = requireNonNull(remainingFilter, "remainingFilter is null"); - } - - public List getPushableConditions() - { - return pushableConditions; - } - - public Expression getRemainingFilter() - { - return remainingFilter; - } - } - - private JoinCondition.Operator joinConditionOperator(ComparisonExpression.Operator operator) - { - return switch (operator) { - case EQUAL -> JoinCondition.Operator.EQUAL; - case NOT_EQUAL -> JoinCondition.Operator.NOT_EQUAL; - case LESS_THAN -> JoinCondition.Operator.LESS_THAN; - case LESS_THAN_OR_EQUAL -> JoinCondition.Operator.LESS_THAN_OR_EQUAL; - case GREATER_THAN -> JoinCondition.Operator.GREATER_THAN; - case GREATER_THAN_OR_EQUAL -> JoinCondition.Operator.GREATER_THAN_OR_EQUAL; - case IS_DISTINCT_FROM -> JoinCondition.Operator.IS_DISTINCT_FROM; - }; - } - private JoinType getJoinType(JoinNode joinNode) { return switch (joinNode.getType()) { diff --git a/core/trino-main/src/test/java/io/trino/metadata/AbstractMockMetadata.java b/core/trino-main/src/test/java/io/trino/metadata/AbstractMockMetadata.java index a22f5aba756f..b5bee784985d 100644 --- a/core/trino-main/src/test/java/io/trino/metadata/AbstractMockMetadata.java +++ b/core/trino-main/src/test/java/io/trino/metadata/AbstractMockMetadata.java @@ -37,7 +37,6 @@ import io.trino.spi.connector.Constraint; import io.trino.spi.connector.ConstraintApplicationResult; import io.trino.spi.connector.JoinApplicationResult; -import io.trino.spi.connector.JoinCondition; import io.trino.spi.connector.JoinStatistics; import io.trino.spi.connector.JoinType; import io.trino.spi.connector.LimitApplicationResult; @@ -611,7 +610,7 @@ public Optional> applyJoin( JoinType joinType, TableHandle left, TableHandle right, - List joinConditions, + ConnectorExpression joinCondition, Map leftAssignments, Map rightAssignments, JoinStatistics statistics) diff --git a/core/trino-main/src/test/java/io/trino/metadata/CountingAccessMetadata.java b/core/trino-main/src/test/java/io/trino/metadata/CountingAccessMetadata.java index d4dcee212246..6a9a87bc2553 100644 --- a/core/trino-main/src/test/java/io/trino/metadata/CountingAccessMetadata.java +++ b/core/trino-main/src/test/java/io/trino/metadata/CountingAccessMetadata.java @@ -33,7 +33,6 @@ import io.trino.spi.connector.Constraint; import io.trino.spi.connector.ConstraintApplicationResult; import io.trino.spi.connector.JoinApplicationResult; -import io.trino.spi.connector.JoinCondition; import io.trino.spi.connector.JoinStatistics; import io.trino.spi.connector.JoinType; import io.trino.spi.connector.LimitApplicationResult; @@ -611,9 +610,9 @@ public Optional> applyAggregation(Sess } @Override - public Optional> applyJoin(Session session, JoinType joinType, TableHandle left, TableHandle right, List joinConditions, Map leftAssignments, Map rightAssignments, JoinStatistics statistics) + public Optional> applyJoin(Session session, JoinType joinType, TableHandle left, TableHandle right, ConnectorExpression joinCondition, Map leftAssignments, Map rightAssignments, JoinStatistics statistics) { - return delegate.applyJoin(session, joinType, left, right, joinConditions, leftAssignments, rightAssignments, statistics); + return delegate.applyJoin(session, joinType, left, right, joinCondition, leftAssignments, rightAssignments, statistics); } @Override diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushJoinIntoTableScan.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushJoinIntoTableScan.java index 0586ed1f87c5..e1bf4f6ff83b 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushJoinIntoTableScan.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushJoinIntoTableScan.java @@ -29,6 +29,8 @@ import io.trino.spi.connector.JoinCondition; import io.trino.spi.connector.JoinType; import io.trino.spi.connector.SchemaTableName; +import io.trino.spi.expression.Call; +import io.trino.spi.expression.Constant; import io.trino.spi.expression.Variable; import io.trino.spi.predicate.Domain; import io.trino.spi.predicate.NullableValue; @@ -51,6 +53,7 @@ import static com.google.common.base.Predicates.equalTo; import static com.google.common.collect.ImmutableList.toImmutableList; +import static io.trino.spi.expression.StandardFunctions.MULTIPLY_FUNCTION_NAME; import static io.trino.spi.predicate.Domain.onlyNull; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.sql.planner.assertions.PlanMatchPattern.anyTree; @@ -143,7 +146,7 @@ public void testPushJoinIntoTableScan(JoinNode.Type joinType, Optional { Symbol columnA1Symbol = p.symbol(COLUMN_A1); Symbol columnA2Symbol = p.symbol(COLUMN_A2); @@ -221,15 +224,33 @@ public static Object[][] testPushJoinIntoTableScanParams() }; } + /** + * Test a scenario where join condition cannot be represented with simple comparisons. + */ @Test - public void testPushJoinIntoTableScanDoesNotTriggerWithUnsupportedFilter() + public void testPushJoinIntoTableScanWithComplexFilter() { MockConnectorFactory connectorFactory = createMockConnectorFactory( (session, applyJoinType, left, right, joinConditions, leftAssignments, rightAssignments) -> { - throw new IllegalStateException("applyJoin should not be called!"); + assertThat(joinConditions).as("joinConditions") + .isEqualTo(List.of( + new JoinCondition( + JoinCondition.Operator.GREATER_THAN, + new Call( + BIGINT, + MULTIPLY_FUNCTION_NAME, + List.of( + new Constant(44L, BIGINT), + new Variable("columna1", BIGINT))), + new Variable("columnb1", BIGINT)))); + return Optional.of(new JoinApplicationResult<>( + JOIN_CONNECTOR_TABLE_HANDLE, + JOIN_TABLE_A_COLUMN_MAPPING, + JOIN_TABLE_B_COLUMN_MAPPING, + false)); }); try (RuleTester ruleTester = RuleTester.builder().withDefaultCatalogConnectorFactory(connectorFactory).build()) { - ruleTester.assertThat(new PushJoinIntoTableScan(ruleTester.getMetadata())) + ruleTester.assertThat(new PushJoinIntoTableScan(ruleTester.getPlannerContext(), ruleTester.getTypeAnalyzer())) .on(p -> { Symbol columnA1Symbol = p.symbol(COLUMN_A1); Symbol columnA2Symbol = p.symbol(COLUMN_A2); @@ -255,7 +276,9 @@ public void testPushJoinIntoTableScanDoesNotTriggerWithUnsupportedFilter() columnB1Symbol.toSymbolReference())); }) .withSession(MOCK_SESSION) - .doesNotFire(); + .matches( + project( + tableScan(JOIN_PUSHDOWN_SCHEMA_TABLE_NAME.getTableName()))); } } @@ -270,7 +293,7 @@ public void testPushJoinIntoTableScanDoesNotFireForDifferentCatalogs() ruleTester.getQueryRunner().createCatalog("another_catalog", "mock", ImmutableMap.of()); TableHandle tableBHandleAnotherCatalog = createTableHandle(new MockConnectorTableHandle(new SchemaTableName(SCHEMA, TABLE_B)), createTestCatalogHandle("another_catalog")); - ruleTester.assertThat(new PushJoinIntoTableScan(ruleTester.getMetadata())) + ruleTester.assertThat(new PushJoinIntoTableScan(ruleTester.getPlannerContext(), ruleTester.getTypeAnalyzer())) .on(p -> { Symbol columnA1Symbol = p.symbol(COLUMN_A1); Symbol columnA2Symbol = p.symbol(COLUMN_A2); @@ -309,7 +332,7 @@ public void testPushJoinIntoTableScanDoesNotFireWhenDisabled() throw new IllegalStateException("applyJoin should not be called!"); }); try (RuleTester ruleTester = RuleTester.builder().withDefaultCatalogConnectorFactory(connectorFactory).build()) { - ruleTester.assertThat(new PushJoinIntoTableScan(ruleTester.getMetadata())) + ruleTester.assertThat(new PushJoinIntoTableScan(ruleTester.getPlannerContext(), ruleTester.getTypeAnalyzer())) .on(p -> { Symbol columnA1Symbol = p.symbol(COLUMN_A1); Symbol columnA2Symbol = p.symbol(COLUMN_A2); @@ -348,7 +371,7 @@ public void testPushJoinIntoTableScanDoesNotFireWhenAllPushdownsDisabled() throw new IllegalStateException("applyJoin should not be called!"); }); try (RuleTester ruleTester = RuleTester.builder().withDefaultCatalogConnectorFactory(connectorFactory).build()) { - ruleTester.assertThat(new PushJoinIntoTableScan(ruleTester.getMetadata())) + ruleTester.assertThat(new PushJoinIntoTableScan(ruleTester.getPlannerContext(), ruleTester.getTypeAnalyzer())) .on(p -> { Symbol columnA1Symbol = p.symbol(COLUMN_A1); Symbol columnA2Symbol = p.symbol(COLUMN_A2); @@ -384,7 +407,7 @@ public void testPushJoinIntoTableScanPreservesEnforcedConstraint(JoinNode.Type j JOIN_TABLE_B_COLUMN_MAPPING, false))); try (RuleTester ruleTester = RuleTester.builder().withDefaultCatalogConnectorFactory(connectorFactory).build()) { - ruleTester.assertThat(new PushJoinIntoTableScan(ruleTester.getMetadata())) + ruleTester.assertThat(new PushJoinIntoTableScan(ruleTester.getPlannerContext(), ruleTester.getTypeAnalyzer())) .on(p -> { Symbol columnA1Symbol = p.symbol(COLUMN_A1); Symbol columnA2Symbol = p.symbol(COLUMN_A2); @@ -485,7 +508,7 @@ public void testPushJoinIntoTableDoesNotFireForCrossJoin() throw new IllegalStateException("applyJoin should not be called!"); }); try (RuleTester ruleTester = RuleTester.builder().withDefaultCatalogConnectorFactory(connectorFactory).build()) { - ruleTester.assertThat(new PushJoinIntoTableScan(ruleTester.getMetadata())) + ruleTester.assertThat(new PushJoinIntoTableScan(ruleTester.getPlannerContext(), ruleTester.getTypeAnalyzer())) .on(p -> { Symbol columnA1Symbol = p.symbol(COLUMN_A1); Symbol columnA2Symbol = p.symbol(COLUMN_A2); @@ -524,7 +547,7 @@ public void testPushJoinIntoTableRequiresFullColumnHandleMappingInResult() false))); try (RuleTester ruleTester = RuleTester.builder().withDefaultCatalogConnectorFactory(connectorFactory).build()) { assertThatThrownBy(() -> { - ruleTester.assertThat(new PushJoinIntoTableScan(ruleTester.getMetadata())) + ruleTester.assertThat(new PushJoinIntoTableScan(ruleTester.getPlannerContext(), ruleTester.getTypeAnalyzer())) .on(p -> { Symbol columnA1Symbol = p.symbol(COLUMN_A1); Symbol columnA2Symbol = p.symbol(COLUMN_A2); diff --git a/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorMetadata.java b/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorMetadata.java index aa0bc2f7a015..55ab1cfa1e13 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorMetadata.java +++ b/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorMetadata.java @@ -15,6 +15,7 @@ import io.airlift.slice.Slice; import io.trino.spi.TrinoException; +import io.trino.spi.expression.Call; import io.trino.spi.expression.ConnectorExpression; import io.trino.spi.expression.Constant; import io.trino.spi.expression.Variable; @@ -36,6 +37,7 @@ import javax.annotation.Nullable; +import java.util.ArrayList; import java.util.Collection; import java.util.HashMap; import java.util.Iterator; @@ -49,6 +51,7 @@ import static io.trino.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED; +import static io.trino.spi.expression.StandardFunctions.AND_FUNCTION_NAME; import static java.util.Collections.emptyList; import static java.util.Collections.emptyMap; import static java.util.stream.Collectors.toUnmodifiableList; @@ -1282,6 +1285,50 @@ default Optional> applyAggreg * It is required that mapping is provided for *all* column handles exposed previously by both left and right join sources. *

*/ + default Optional> applyJoin( + ConnectorSession session, + JoinType joinType, + ConnectorTableHandle left, + ConnectorTableHandle right, + ConnectorExpression joinCondition, + Map leftAssignments, + Map rightAssignments, + JoinStatistics statistics) + { + List conditions; + if (joinCondition instanceof Call call && AND_FUNCTION_NAME.equals(call.getFunctionName())) { + conditions = new ArrayList<>(call.getArguments().size()); + for (ConnectorExpression argument : call.getArguments()) { + if (Constant.TRUE.equals(argument)) { + continue; + } + Optional condition = JoinCondition.from(argument, leftAssignments.keySet(), rightAssignments.keySet()); + if (condition.isEmpty()) { + // We would need to add a FilterNode on top of the result + return Optional.empty(); + } + conditions.add(condition.get()); + } + } + else { + Optional condition = JoinCondition.from(joinCondition, leftAssignments.keySet(), rightAssignments.keySet()); + if (condition.isEmpty()) { + return Optional.empty(); + } + conditions = List.of(condition.get()); + } + return applyJoin( + session, + joinType, + left, + right, + conditions, + leftAssignments, + rightAssignments, + statistics); + } + + @Deprecated default Optional> applyJoin( ConnectorSession session, JoinType joinType, diff --git a/core/trino-spi/src/main/java/io/trino/spi/connector/JoinCondition.java b/core/trino-spi/src/main/java/io/trino/spi/connector/JoinCondition.java index 4d2b27c12e20..83989c1066cf 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/connector/JoinCondition.java +++ b/core/trino-spi/src/main/java/io/trino/spi/connector/JoinCondition.java @@ -13,36 +13,113 @@ */ package io.trino.spi.connector; +import io.trino.spi.expression.Call; import io.trino.spi.expression.ConnectorExpression; +import io.trino.spi.expression.FunctionName; +import io.trino.spi.expression.StandardFunctions; +import io.trino.spi.expression.Variable; +import java.util.ArrayDeque; +import java.util.HashSet; +import java.util.List; +import java.util.Map; import java.util.Objects; +import java.util.Optional; +import java.util.Queue; +import java.util.Set; +import java.util.stream.Stream; import static java.lang.String.format; import static java.util.Objects.requireNonNull; +import static java.util.function.Function.identity; +import static java.util.stream.Collectors.toUnmodifiableMap; +@Deprecated public final class JoinCondition { + @Deprecated public enum Operator { - EQUAL("="), - NOT_EQUAL("<>"), - LESS_THAN("<"), - LESS_THAN_OR_EQUAL("<="), - GREATER_THAN(">"), - GREATER_THAN_OR_EQUAL(">="), - IS_DISTINCT_FROM("IS DISTINCT FROM"); + EQUAL("=", StandardFunctions.EQUAL_OPERATOR_FUNCTION_NAME), + NOT_EQUAL("<>", StandardFunctions.NOT_EQUAL_OPERATOR_FUNCTION_NAME), + LESS_THAN("<", StandardFunctions.LESS_THAN_OPERATOR_FUNCTION_NAME), + LESS_THAN_OR_EQUAL("<=", StandardFunctions.LESS_THAN_OR_EQUAL_OPERATOR_FUNCTION_NAME), + GREATER_THAN(">", StandardFunctions.GREATER_THAN_OPERATOR_FUNCTION_NAME), + GREATER_THAN_OR_EQUAL(">=", StandardFunctions.GREATER_THAN_OR_EQUAL_OPERATOR_FUNCTION_NAME), + IS_DISTINCT_FROM("IS DISTINCT FROM", StandardFunctions.IS_DISTINCT_FROM_OPERATOR_FUNCTION_NAME), + /**/; + + private static final Map byFunctionName = Stream.of(values()) + .collect(toUnmodifiableMap(operator -> operator.callFunctionName, identity())); private final String value; + private final FunctionName callFunctionName; - Operator(String value) + Operator(String value, FunctionName callFunctionName) { this.value = value; + this.callFunctionName = callFunctionName; } public String getValue() { return value; } + + public Operator flip() + { + return switch (this) { + case EQUAL, NOT_EQUAL, IS_DISTINCT_FROM -> this; + case LESS_THAN -> GREATER_THAN; + case LESS_THAN_OR_EQUAL -> GREATER_THAN_OR_EQUAL; + case GREATER_THAN -> LESS_THAN; + case GREATER_THAN_OR_EQUAL -> LESS_THAN_OR_EQUAL; + }; + } + } + + public static Optional from(ConnectorExpression expression, Set leftSymbols, Set rightSymbols) + { + if (expression instanceof Call call && call.getArguments().size() == 2) { + return Optional.ofNullable(Operator.byFunctionName.get(call.getFunctionName())) + .flatMap(operator -> { + rightSymbols.stream().filter(leftSymbols::contains).findAny().ifPresent(symbol -> { + throw new IllegalArgumentException( + "Left and right symbol sets overlap, are both include %s: %s, %s".formatted(symbol, leftSymbols, rightSymbols)); + }); + ConnectorExpression left = call.getArguments().get(0); + ConnectorExpression right = call.getArguments().get(1); + Set leftExpressionSymbols = findVariableNames(left); + Set rightExpressionSymbols = findVariableNames(right); + if (leftSymbols.containsAll(leftExpressionSymbols) && rightSymbols.containsAll(rightExpressionSymbols)) { + return Optional.of(new JoinCondition(operator, left, right)); + } + if (rightSymbols.containsAll(leftExpressionSymbols) && leftSymbols.containsAll(rightExpressionSymbols)) { + // normalize + return Optional.of(new JoinCondition(operator.flip(), right, left)); + } + return Optional.empty(); + }); + } + return Optional.empty(); + } + + private static Set findVariableNames(ConnectorExpression expression) + { + Set variableNames = new HashSet<>(); + Set visited = new HashSet<>(); + Queue pending = new ArrayDeque<>(List.of(expression)); + while (!pending.isEmpty()) { + ConnectorExpression next = pending.remove(); + if (!visited.add(next)) { + continue; + } + pending.addAll(next.getChildren()); + if (next instanceof Variable variable) { + variableNames.add(variable.getName()); + } + } + return variableNames; } private final Operator operator; diff --git a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/classloader/ClassLoaderSafeConnectorMetadata.java b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/classloader/ClassLoaderSafeConnectorMetadata.java index 5f0ec0cf1d3a..44a1ca238ed9 100644 --- a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/classloader/ClassLoaderSafeConnectorMetadata.java +++ b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/classloader/ClassLoaderSafeConnectorMetadata.java @@ -908,6 +908,22 @@ public Optional> applyAggrega } } + @Override + public Optional> applyJoin( + ConnectorSession session, + JoinType joinType, + ConnectorTableHandle left, + ConnectorTableHandle right, + ConnectorExpression joinCondition, + Map leftAssignments, + Map rightAssignments, + JoinStatistics statistics) + { + try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(classLoader)) { + return delegate.applyJoin(session, joinType, left, right, joinCondition, leftAssignments, rightAssignments, statistics); + } + } + @Override public Optional> applyJoin( ConnectorSession session,