From 9a5ebf89bd1ebb04a8cb19eb27b980469d78166d Mon Sep 17 00:00:00 2001
From: Piotr Findeisen
Date: Wed, 31 Aug 2022 18:50:39 +0200
Subject: [PATCH 1/3] Improve argument order in
ConnectorExpressionTranslator.translate
- The expression and its symbol types should be passed next to each other.
- TypeAnalyzer is usually provided after PlannerContext (until it's made
part of it)
---
.../io/trino/sql/planner/ConnectorExpressionTranslator.java | 4 ++--
.../trino/sql/planner/TestConnectorExpressionTranslator.java | 2 +-
.../test/java/io/trino/sql/planner/TestPartialTranslator.java | 4 ++--
.../iterative/rule/TestPushProjectionIntoTableScan.java | 2 +-
.../java/io/trino/plugin/iceberg/TestConstraintExtractor.java | 4 ++--
.../java/io/trino/plugin/postgresql/TestPostgreSqlClient.java | 4 ++--
6 files changed, 10 insertions(+), 10 deletions(-)
diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/ConnectorExpressionTranslator.java b/core/trino-main/src/main/java/io/trino/sql/planner/ConnectorExpressionTranslator.java
index 1f07a8073c4b..b2334c59e384 100644
--- a/core/trino-main/src/main/java/io/trino/sql/planner/ConnectorExpressionTranslator.java
+++ b/core/trino-main/src/main/java/io/trino/sql/planner/ConnectorExpressionTranslator.java
@@ -125,9 +125,9 @@ public static Expression translate(Session session, ConnectorExpression expressi
.orElseThrow(() -> new UnsupportedOperationException("Expression is not supported: " + expression.toString()));
}
- public static Optional translate(Session session, Expression expression, TypeAnalyzer types, TypeProvider inputTypes, PlannerContext plannerContext)
+ public static Optional translate(Session session, Expression expression, TypeProvider types, PlannerContext plannerContext, TypeAnalyzer typeAnalyzer)
{
- return new SqlToConnectorExpressionTranslator(session, types.getTypes(session, inputTypes, expression), plannerContext)
+ return new SqlToConnectorExpressionTranslator(session, typeAnalyzer.getTypes(session, types, expression), plannerContext)
.process(expression);
}
diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestConnectorExpressionTranslator.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestConnectorExpressionTranslator.java
index 0eea7ecde92a..c04f96848390 100644
--- a/core/trino-main/src/test/java/io/trino/sql/planner/TestConnectorExpressionTranslator.java
+++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestConnectorExpressionTranslator.java
@@ -470,7 +470,7 @@ private void assertTranslationToConnectorExpression(Session session, Expression
private void assertTranslationToConnectorExpression(Session session, Expression expression, Optional connectorExpression)
{
- Optional translation = translate(session, expression, TYPE_ANALYZER, TYPE_PROVIDER, PLANNER_CONTEXT);
+ Optional translation = translate(session, expression, TYPE_PROVIDER, PLANNER_CONTEXT, TYPE_ANALYZER);
assertEquals(connectorExpression.isPresent(), translation.isPresent());
translation.ifPresent(value -> assertEquals(value, connectorExpression.get()));
}
diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestPartialTranslator.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestPartialTranslator.java
index ce2fe443243d..fb6f674dd156 100644
--- a/core/trino-main/src/test/java/io/trino/sql/planner/TestPartialTranslator.java
+++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestPartialTranslator.java
@@ -101,7 +101,7 @@ private void assertPartialTranslation(Expression expression, List su
Map, ConnectorExpression> translation = extractPartialTranslations(expression, TEST_SESSION, TYPE_ANALYZER, TYPE_PROVIDER, PLANNER_CONTEXT);
assertEquals(subexpressions.size(), translation.size());
for (Expression subexpression : subexpressions) {
- assertEquals(translation.get(NodeRef.of(subexpression)), translate(TEST_SESSION, subexpression, TYPE_ANALYZER, TYPE_PROVIDER, PLANNER_CONTEXT).get());
+ assertEquals(translation.get(NodeRef.of(subexpression)), translate(TEST_SESSION, subexpression, TYPE_PROVIDER, PLANNER_CONTEXT, TYPE_ANALYZER).get());
}
}
@@ -109,6 +109,6 @@ private void assertFullTranslation(Expression expression)
{
Map, ConnectorExpression> translation = extractPartialTranslations(expression, TEST_SESSION, TYPE_ANALYZER, TYPE_PROVIDER, PLANNER_CONTEXT);
assertEquals(getOnlyElement(translation.keySet()), NodeRef.of(expression));
- assertEquals(getOnlyElement(translation.values()), translate(TEST_SESSION, expression, TYPE_ANALYZER, TYPE_PROVIDER, PLANNER_CONTEXT).get());
+ assertEquals(getOnlyElement(translation.values()), translate(TEST_SESSION, expression, TYPE_PROVIDER, PLANNER_CONTEXT, TYPE_ANALYZER).get());
}
}
diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushProjectionIntoTableScan.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushProjectionIntoTableScan.java
index 3c137bd286a9..f4e8b2ccf8cc 100644
--- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushProjectionIntoTableScan.java
+++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushProjectionIntoTableScan.java
@@ -154,7 +154,7 @@ constant, new LongLiteral("5"),
TransactionId transactionId = ruleTester.getQueryRunner().getTransactionManager().beginTransaction(false);
Session session = MOCK_SESSION.beginTransactionId(transactionId, ruleTester.getQueryRunner().getTransactionManager(), ruleTester.getQueryRunner().getAccessControl());
ImmutableMap connectorNames = inputProjections.entrySet().stream()
- .collect(toImmutableMap(Map.Entry::getKey, e -> translate(session, e.getValue(), typeAnalyzer, viewOf(types), ruleTester.getPlannerContext()).get().toString()));
+ .collect(toImmutableMap(Map.Entry::getKey, e -> translate(session, e.getValue(), viewOf(types), ruleTester.getPlannerContext(), typeAnalyzer).get().toString()));
ImmutableMap newNames = ImmutableMap.of(
identity, "projected_variable_" + connectorNames.get(identity),
dereference, "projected_dereference_" + connectorNames.get(dereference),
diff --git a/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestConstraintExtractor.java b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestConstraintExtractor.java
index 566fa4b528b2..7b3981fb776b 100644
--- a/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestConstraintExtractor.java
+++ b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestConstraintExtractor.java
@@ -239,10 +239,10 @@ private static ConnectorExpression connectorExpression(Expression expression, Ma
return ConnectorExpressionTranslator.translate(
TEST_SESSION,
expression,
- createTestingTypeAnalyzer(PLANNER_CONTEXT),
TypeProvider.viewOf(symbolTypes.entrySet().stream()
.collect(toImmutableMap(entry -> new Symbol(entry.getKey()), Map.Entry::getValue))),
- PLANNER_CONTEXT)
+ PLANNER_CONTEXT,
+ createTestingTypeAnalyzer(PLANNER_CONTEXT))
.orElseThrow();
}
diff --git a/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlClient.java b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlClient.java
index 7aae0329b414..f25bff7155ce 100644
--- a/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlClient.java
+++ b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlClient.java
@@ -446,10 +446,10 @@ private ConnectorExpression translateToConnectorExpression(Expression expression
return ConnectorExpressionTranslator.translate(
TEST_SESSION,
expression,
- createTestingTypeAnalyzer(PLANNER_CONTEXT),
TypeProvider.viewOf(symbolTypes.entrySet().stream()
.collect(toImmutableMap(entry -> new Symbol(entry.getKey()), Entry::getValue))),
- PLANNER_CONTEXT)
+ PLANNER_CONTEXT,
+ createTestingTypeAnalyzer(PLANNER_CONTEXT))
.orElseThrow();
}
}
From cd733970d288c686d99ffa3d7216c02b2952691e Mon Sep 17 00:00:00 2001
From: Piotr Findeisen
Date: Wed, 31 Aug 2022 18:56:42 +0200
Subject: [PATCH 2/3] Move conjuncts translation to
ConnectorExpressionTranslator
Move translation of individual conjuncts from `PushPredicateIntoTableScan`
into `ConnectorExpressionTranslator` to make it reusable.
This also changes `ConnectorExpressionTranslation` into a record, which
is very suitable for a composite method return type.
---
.../ConnectorExpressionTranslator.java | 43 +++++++++++
.../rule/PushPredicateIntoTableScan.java | 73 +++----------------
2 files changed, 53 insertions(+), 63 deletions(-)
diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/ConnectorExpressionTranslator.java b/core/trino-main/src/main/java/io/trino/sql/planner/ConnectorExpressionTranslator.java
index b2334c59e384..2a7dc2fa62af 100644
--- a/core/trino-main/src/main/java/io/trino/sql/planner/ConnectorExpressionTranslator.java
+++ b/core/trino-main/src/main/java/io/trino/sql/planner/ConnectorExpressionTranslator.java
@@ -21,6 +21,7 @@
import io.trino.Session;
import io.trino.metadata.LiteralFunction;
import io.trino.metadata.ResolvedFunction;
+import io.trino.plugin.base.expression.ConnectorExpressions;
import io.trino.security.AllowAllAccessControl;
import io.trino.spi.expression.Call;
import io.trino.spi.expression.ConnectorExpression;
@@ -71,6 +72,7 @@
import io.trino.type.Re2JRegexp;
import io.trino.type.Re2JRegexpType;
+import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
@@ -105,6 +107,8 @@
import static io.trino.spi.expression.StandardFunctions.SUBTRACT_FUNCTION_NAME;
import static io.trino.spi.type.BooleanType.BOOLEAN;
import static io.trino.spi.type.VarcharType.createVarcharType;
+import static io.trino.sql.ExpressionUtils.combineConjuncts;
+import static io.trino.sql.ExpressionUtils.extractConjuncts;
import static io.trino.sql.ExpressionUtils.isEffectivelyLiteral;
import static io.trino.sql.analyzer.TypeSignatureTranslator.toSqlType;
import static io.trino.sql.analyzer.TypeSignatureTranslator.toTypeSignature;
@@ -131,6 +135,36 @@ public static Optional translate(Session session, Expressio
.process(expression);
}
+ public static ConnectorExpressionTranslation translateConjuncts(
+ Session session,
+ Expression expression,
+ TypeProvider types,
+ PlannerContext plannerContext,
+ TypeAnalyzer typeAnalyzer)
+ {
+ Map, Type> remainingExpressionTypes = typeAnalyzer.getTypes(session, types, expression);
+ ConnectorExpressionTranslator.SqlToConnectorExpressionTranslator translator = new ConnectorExpressionTranslator.SqlToConnectorExpressionTranslator(
+ session,
+ remainingExpressionTypes,
+ plannerContext);
+
+ List conjuncts = extractConjuncts(expression);
+ List remaining = new ArrayList<>();
+ List converted = new ArrayList<>(conjuncts.size());
+ for (Expression conjunct : conjuncts) {
+ Optional connectorExpression = translator.process(conjunct);
+ if (connectorExpression.isPresent()) {
+ converted.add(connectorExpression.get());
+ }
+ else {
+ remaining.add(conjunct);
+ }
+ }
+ return new ConnectorExpressionTranslation(
+ ConnectorExpressions.and(converted),
+ combineConjuncts(plannerContext.getMetadata(), remaining));
+ }
+
@VisibleForTesting
static FunctionName functionNameForComparisonOperator(ComparisonExpression.Operator operator)
{
@@ -157,6 +191,15 @@ static FunctionName functionNameForArithmeticBinaryOperator(ArithmeticBinaryExpr
};
}
+ public record ConnectorExpressionTranslation(ConnectorExpression connectorExpression, Expression remainingExpression)
+ {
+ public ConnectorExpressionTranslation
+ {
+ requireNonNull(connectorExpression, "connectorExpression is null");
+ requireNonNull(remainingExpression, "remainingExpression is null");
+ }
+ }
+
private static class ConnectorToSqlExpressionTranslator
{
private final Session session;
diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushPredicateIntoTableScan.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushPredicateIntoTableScan.java
index d91d4ce49777..5e23ab5ec1cb 100644
--- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushPredicateIntoTableScan.java
+++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushPredicateIntoTableScan.java
@@ -25,7 +25,6 @@
import io.trino.metadata.TableHandle;
import io.trino.metadata.TableProperties;
import io.trino.metadata.TableProperties.TablePartitioning;
-import io.trino.plugin.base.expression.ConnectorExpressions;
import io.trino.spi.connector.ColumnHandle;
import io.trino.spi.connector.Constraint;
import io.trino.spi.connector.ConstraintApplicationResult;
@@ -35,6 +34,7 @@
import io.trino.spi.type.Type;
import io.trino.sql.PlannerContext;
import io.trino.sql.planner.ConnectorExpressionTranslator;
+import io.trino.sql.planner.ConnectorExpressionTranslator.ConnectorExpressionTranslation;
import io.trino.sql.planner.DomainTranslator;
import io.trino.sql.planner.ExpressionInterpreter;
import io.trino.sql.planner.LayoutConstraintEvaluator;
@@ -43,7 +43,6 @@
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.SymbolAllocator;
import io.trino.sql.planner.TypeAnalyzer;
-import io.trino.sql.planner.TypeProvider;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.plan.FilterNode;
import io.trino.sql.planner.plan.PlanNode;
@@ -180,12 +179,12 @@ public static Optional pushFilterIntoTableScan(
.transformKeys(node.getAssignments()::get)
.intersect(node.getEnforcedConstraint());
- ConnectorExpressionTranslation expressionTranslation = translateConjunctsToConnectorExpression(
+ ConnectorExpressionTranslation expressionTranslation = ConnectorExpressionTranslator.translateConjuncts(
session,
- plannerContext,
- typeAnalyzer,
+ decomposedPredicate.getRemainingExpression(),
symbolAllocator.getTypes(),
- decomposedPredicate.getRemainingExpression());
+ plannerContext,
+ typeAnalyzer);
Map connectorExpressionAssignments = node.getAssignments()
.entrySet().stream()
.collect(toImmutableMap(entry -> entry.getKey().getName(), Map.Entry::getValue));
@@ -207,18 +206,18 @@ public static Optional pushFilterIntoTableScan(
// Simplify the tuple domain to avoid creating an expression with too many nodes,
// which would be expensive to evaluate in the call to isCandidate below.
domainTranslator.toPredicate(session, newDomain.simplify().transformKeys(assignments::get))));
- constraint = new Constraint(newDomain, expressionTranslation.getConnectorExpression(), connectorExpressionAssignments, evaluator::isCandidate, evaluator.getArguments());
+ constraint = new Constraint(newDomain, expressionTranslation.connectorExpression(), connectorExpressionAssignments, evaluator::isCandidate, evaluator.getArguments());
}
else {
// Currently, invoking the expression interpreter is very expensive.
// TODO invoke the interpreter unconditionally when the interpreter becomes cheap enough.
- constraint = new Constraint(newDomain, expressionTranslation.getConnectorExpression(), connectorExpressionAssignments);
+ constraint = new Constraint(newDomain, expressionTranslation.connectorExpression(), connectorExpressionAssignments);
}
// check if new domain is wider than domain already provided by table scan
if (constraint.predicate().isEmpty() &&
// TODO do we need to track enforced ConnectorExpression in TableScanNode?
- TRUE.equals(expressionTranslation.getConnectorExpression()) &&
+ TRUE.equals(expressionTranslation.connectorExpression()) &&
newDomain.contains(node.getEnforcedConstraint())) {
Expression resultingPredicate = createResultingPredicate(
plannerContext,
@@ -276,7 +275,7 @@ public static Optional pushFilterIntoTableScan(
node.getUseConnectorNodePartitioning());
Expression remainingDecomposedPredicate;
- if (remainingConnectorExpression.isEmpty() || remainingConnectorExpression.get().equals(expressionTranslation.getConnectorExpression())) {
+ if (remainingConnectorExpression.isEmpty() || remainingConnectorExpression.get().equals(expressionTranslation.connectorExpression())) {
remainingDecomposedPredicate = decomposedPredicate.getRemainingExpression();
}
else {
@@ -292,7 +291,7 @@ public static Optional pushFilterIntoTableScan(
new ExpressionInterpreter(translatedExpression, plannerContext, session, translatedExpressionTypes)
.optimize(NoOpSymbolResolver.INSTANCE),
translatedExpressionTypes.get(NodeRef.of(translatedExpression)));
- remainingDecomposedPredicate = combineConjuncts(plannerContext.getMetadata(), translatedExpression, expressionTranslation.getRemainingExpression());
+ remainingDecomposedPredicate = combineConjuncts(plannerContext.getMetadata(), translatedExpression, expressionTranslation.remainingExpression());
}
Expression resultingPredicate = createResultingPredicate(
@@ -358,36 +357,6 @@ else if (isDeterministic(conjunct, metadata)) {
combineConjuncts(metadata, nonDeterministicPredicate));
}
- private static ConnectorExpressionTranslation translateConjunctsToConnectorExpression(
- Session session,
- PlannerContext plannerContext,
- TypeAnalyzer typeAnalyzer,
- TypeProvider types,
- Expression expression)
- {
- Map, Type> remainingExpressionTypes = typeAnalyzer.getTypes(session, types, expression);
- ConnectorExpressionTranslator.SqlToConnectorExpressionTranslator translator = new ConnectorExpressionTranslator.SqlToConnectorExpressionTranslator(
- session,
- remainingExpressionTypes,
- plannerContext);
-
- List conjuncts = extractConjuncts(expression);
- List remaining = new ArrayList<>();
- List converted = new ArrayList<>(conjuncts.size());
- for (Expression conjunct : conjuncts) {
- Optional connectorExpression = translator.process(conjunct);
- if (connectorExpression.isPresent()) {
- converted.add(connectorExpression.get());
- }
- else {
- remaining.add(conjunct);
- }
- }
- return new ConnectorExpressionTranslation(
- combineConjuncts(plannerContext.getMetadata(), remaining),
- ConnectorExpressions.and(converted));
- }
-
static Expression createResultingPredicate(
PlannerContext plannerContext,
Session session,
@@ -478,26 +447,4 @@ public Expression getNonDeterministicPredicate()
return nonDeterministicPredicate;
}
}
-
- private static class ConnectorExpressionTranslation
- {
- private final Expression remainingExpression;
- private final ConnectorExpression connectorExpression;
-
- public ConnectorExpressionTranslation(Expression remainingExpression, ConnectorExpression connectorExpression)
- {
- this.remainingExpression = requireNonNull(remainingExpression, "remainingExpression is null");
- this.connectorExpression = requireNonNull(connectorExpression, "connectorExpression is null");
- }
-
- public Expression getRemainingExpression()
- {
- return remainingExpression;
- }
-
- public ConnectorExpression getConnectorExpression()
- {
- return connectorExpression;
- }
- }
}
From 1a984bb0aaba1a7cabb7d72a4ec5026d48f39537 Mon Sep 17 00:00:00 2001
From: Piotr Findeisen
Date: Wed, 31 Aug 2022 18:43:38 +0200
Subject: [PATCH 3/3] Push whole join condition into connectors
Previously join pushdown was limited to simple comparison conditions
between variables. That's because the `applyJoin` was added before we
had real connector expressions. The commit exposes to connectors all
join condition's conjuncts that are translatable to
`ConnectorExpression`.
---
.../main/java/io/trino/metadata/Metadata.java | 3 +-
.../io/trino/metadata/MetadataManager.java | 5 +-
.../io/trino/sql/planner/PlanOptimizers.java | 2 +-
.../iterative/rule/PushJoinIntoTableScan.java | 118 +++---------------
.../trino/metadata/AbstractMockMetadata.java | 3 +-
.../metadata/CountingAccessMetadata.java | 5 +-
.../rule/TestPushJoinIntoTableScan.java | 45 +++++--
.../spi/connector/ConnectorMetadata.java | 47 +++++++
.../io/trino/spi/connector/JoinCondition.java | 93 ++++++++++++--
.../ClassLoaderSafeConnectorMetadata.java | 16 +++
10 files changed, 208 insertions(+), 129 deletions(-)
diff --git a/core/trino-main/src/main/java/io/trino/metadata/Metadata.java b/core/trino-main/src/main/java/io/trino/metadata/Metadata.java
index 5efd5b94e953..453826b74717 100644
--- a/core/trino-main/src/main/java/io/trino/metadata/Metadata.java
+++ b/core/trino-main/src/main/java/io/trino/metadata/Metadata.java
@@ -31,7 +31,6 @@
import io.trino.spi.connector.Constraint;
import io.trino.spi.connector.ConstraintApplicationResult;
import io.trino.spi.connector.JoinApplicationResult;
-import io.trino.spi.connector.JoinCondition;
import io.trino.spi.connector.JoinStatistics;
import io.trino.spi.connector.JoinType;
import io.trino.spi.connector.LimitApplicationResult;
@@ -494,7 +493,7 @@ Optional> applyJoin(
JoinType joinType,
TableHandle left,
TableHandle right,
- List joinConditions,
+ ConnectorExpression joinCondition,
Map leftAssignments,
Map rightAssignments,
JoinStatistics statistics);
diff --git a/core/trino-main/src/main/java/io/trino/metadata/MetadataManager.java b/core/trino-main/src/main/java/io/trino/metadata/MetadataManager.java
index d4b1fb4f8e88..326c8a264b2a 100644
--- a/core/trino-main/src/main/java/io/trino/metadata/MetadataManager.java
+++ b/core/trino-main/src/main/java/io/trino/metadata/MetadataManager.java
@@ -62,7 +62,6 @@
import io.trino.spi.connector.Constraint;
import io.trino.spi.connector.ConstraintApplicationResult;
import io.trino.spi.connector.JoinApplicationResult;
-import io.trino.spi.connector.JoinCondition;
import io.trino.spi.connector.JoinStatistics;
import io.trino.spi.connector.JoinType;
import io.trino.spi.connector.LimitApplicationResult;
@@ -1619,7 +1618,7 @@ public Optional> applyJoin(
JoinType joinType,
TableHandle left,
TableHandle right,
- List joinConditions,
+ ConnectorExpression joinCondition,
Map leftAssignments,
Map rightAssignments,
JoinStatistics statistics)
@@ -1640,7 +1639,7 @@ public Optional> applyJoin(
joinType,
left.getConnectorHandle(),
right.getConnectorHandle(),
- joinConditions,
+ joinCondition,
leftAssignments,
rightAssignments,
statistics);
diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/PlanOptimizers.java b/core/trino-main/src/main/java/io/trino/sql/planner/PlanOptimizers.java
index b7f07d7ffea7..e93faab73995 100644
--- a/core/trino-main/src/main/java/io/trino/sql/planner/PlanOptimizers.java
+++ b/core/trino-main/src/main/java/io/trino/sql/planner/PlanOptimizers.java
@@ -836,7 +836,7 @@ public PlanOptimizers(
.addAll(pushIntoTableScanRulesExceptJoins)
// PushJoinIntoTableScan must run after ReorderJoins (and DetermineJoinDistributionType)
// otherwise too early pushdown could prevent optimal plan from being selected.
- .add(new PushJoinIntoTableScan(metadata))
+ .add(new PushJoinIntoTableScan(plannerContext, typeAnalyzer))
// DetermineTableScanNodePartitioning is needed to needs to ensure all table handles have proper partitioning determined
// Must run before AddExchanges
.add(new DetermineTableScanNodePartitioning(metadata, nodePartitioningManager, taskCountEstimator))
diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushJoinIntoTableScan.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushJoinIntoTableScan.java
index ff64fe9d2eba..541070c5ab92 100644
--- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushJoinIntoTableScan.java
+++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushJoinIntoTableScan.java
@@ -15,25 +15,24 @@
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
-import com.google.common.collect.ImmutableSet;
import io.trino.Session;
import io.trino.cost.PlanNodeStatsEstimate;
import io.trino.matching.Capture;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
-import io.trino.metadata.Metadata;
import io.trino.metadata.TableHandle;
import io.trino.spi.connector.BasicRelationStatistics;
import io.trino.spi.connector.ColumnHandle;
import io.trino.spi.connector.JoinApplicationResult;
-import io.trino.spi.connector.JoinCondition;
import io.trino.spi.connector.JoinStatistics;
import io.trino.spi.connector.JoinType;
-import io.trino.spi.expression.Variable;
import io.trino.spi.predicate.Domain;
import io.trino.spi.predicate.TupleDomain;
-import io.trino.sql.ExpressionUtils;
+import io.trino.sql.PlannerContext;
+import io.trino.sql.planner.ConnectorExpressionTranslator;
+import io.trino.sql.planner.ConnectorExpressionTranslator.ConnectorExpressionTranslation;
import io.trino.sql.planner.Symbol;
+import io.trino.sql.planner.TypeAnalyzer;
import io.trino.sql.planner.TypeProvider;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.plan.Assignments;
@@ -43,14 +42,11 @@
import io.trino.sql.planner.plan.ProjectNode;
import io.trino.sql.planner.plan.TableScanNode;
import io.trino.sql.tree.BooleanLiteral;
-import io.trino.sql.tree.ComparisonExpression;
import io.trino.sql.tree.Expression;
-import io.trino.sql.tree.SymbolReference;
import java.util.List;
import java.util.Map;
import java.util.Optional;
-import java.util.Set;
import static com.google.common.base.Verify.verify;
import static com.google.common.collect.ImmutableList.toImmutableList;
@@ -59,7 +55,6 @@
import static io.trino.matching.Capture.newCapture;
import static io.trino.spi.predicate.Domain.onlyNull;
import static io.trino.sql.ExpressionUtils.and;
-import static io.trino.sql.ExpressionUtils.extractConjuncts;
import static io.trino.sql.planner.iterative.rule.Rules.deriveTableStatisticsForPushdown;
import static io.trino.sql.planner.plan.JoinNode.Type.FULL;
import static io.trino.sql.planner.plan.JoinNode.Type.LEFT;
@@ -81,11 +76,13 @@ public class PushJoinIntoTableScan
.with(left().matching(tableScan().capturedAs(LEFT_TABLE_SCAN)))
.with(right().matching(tableScan().capturedAs(RIGHT_TABLE_SCAN)));
- private final Metadata metadata;
+ private final PlannerContext plannerContext;
+ private final TypeAnalyzer typeAnalyzer;
- public PushJoinIntoTableScan(Metadata metadata)
+ public PushJoinIntoTableScan(PlannerContext plannerContext, TypeAnalyzer typeAnalyzer)
{
- this.metadata = requireNonNull(metadata, "metadata is null");
+ this.plannerContext = requireNonNull(plannerContext, "plannerContext is null");
+ this.typeAnalyzer = requireNonNull(typeAnalyzer, "typeAnalyzer is null");
}
@Override
@@ -113,9 +110,14 @@ public Result apply(JoinNode joinNode, Captures captures, Context context)
verify(!left.isUpdateTarget() && !right.isUpdateTarget(), "Unexpected Join over for-update table scan");
Expression effectiveFilter = getEffectiveFilter(joinNode);
- FilterSplitResult filterSplitResult = splitFilter(effectiveFilter, left.getOutputSymbols(), right.getOutputSymbols(), context);
+ ConnectorExpressionTranslation translation = ConnectorExpressionTranslator.translateConjuncts(
+ context.getSession(),
+ effectiveFilter,
+ context.getSymbolAllocator().getTypes(),
+ plannerContext,
+ typeAnalyzer);
- if (!filterSplitResult.getRemainingFilter().equals(BooleanLiteral.TRUE_LITERAL)) {
+ if (!translation.remainingExpression().equals(BooleanLiteral.TRUE_LITERAL)) {
// TODO add extra filter node above join
return Result.empty();
}
@@ -144,13 +146,13 @@ public Result apply(JoinNode joinNode, Captures captures, Context context)
*/
JoinStatistics joinStatistics = getJoinStatistics(joinNode, left, right, context);
- Optional> joinApplicationResult = metadata.applyJoin(
+ Optional> joinApplicationResult = plannerContext.getMetadata().applyJoin(
context.getSession(),
getJoinType(joinNode),
left.getTable(),
right.getTable(),
- filterSplitResult.getPushableConditions(),
- // TODO we could pass only subset of assignments here, those which are needed to resolve filterSplitResult.getPushableConditions
+ translation.connectorExpression(),
+ // TODO we could pass only subset of assignments here, those which are needed to resolve translation.getPushableConditions
leftAssignments,
rightAssignments,
joinStatistics);
@@ -254,88 +256,6 @@ public Expression getEffectiveFilter(JoinNode node)
return effectiveFilter;
}
- private FilterSplitResult splitFilter(Expression filter, List leftSymbolsList, List rightSymbolsList, Context context)
- {
- Set leftSymbols = ImmutableSet.copyOf(leftSymbolsList);
- Set rightSymbols = ImmutableSet.copyOf(rightSymbolsList);
-
- ImmutableList.Builder comparisonConditions = ImmutableList.builder();
- ImmutableList.Builder remainingConjuncts = ImmutableList.builder();
-
- for (Expression conjunct : extractConjuncts(filter)) {
- getPushableJoinCondition(conjunct, leftSymbols, rightSymbols, context)
- .ifPresentOrElse(comparisonConditions::add, () -> remainingConjuncts.add(conjunct));
- }
-
- return new FilterSplitResult(comparisonConditions.build(), ExpressionUtils.and(remainingConjuncts.build()));
- }
-
- private Optional getPushableJoinCondition(Expression conjunct, Set leftSymbols, Set rightSymbols, Context context)
- {
- if (!(conjunct instanceof ComparisonExpression)) {
- return Optional.empty();
- }
- ComparisonExpression comparison = (ComparisonExpression) conjunct;
-
- if (!(comparison.getLeft() instanceof SymbolReference) || !(comparison.getRight() instanceof SymbolReference)) {
- return Optional.empty();
- }
- Symbol left = Symbol.from(comparison.getLeft());
- Symbol right = Symbol.from(comparison.getRight());
- ComparisonExpression.Operator operator = comparison.getOperator();
-
- if (!leftSymbols.contains(left)) {
- // lets try with flipped expression
- Symbol tmp = left;
- left = right;
- right = tmp;
- operator = operator.flip();
- }
-
- if (leftSymbols.contains(left) && rightSymbols.contains(right)) {
- return Optional.of(new JoinCondition(
- joinConditionOperator(operator),
- new Variable(left.getName(), context.getSymbolAllocator().getTypes().get(left)),
- new Variable(right.getName(), context.getSymbolAllocator().getTypes().get(right))));
- }
- return Optional.empty();
- }
-
- private static class FilterSplitResult
- {
- private final List pushableConditions;
- private final Expression remainingFilter;
-
- public FilterSplitResult(List pushableConditions, Expression remainingFilter)
- {
- this.pushableConditions = requireNonNull(pushableConditions, "pushableConditions is null");
- this.remainingFilter = requireNonNull(remainingFilter, "remainingFilter is null");
- }
-
- public List getPushableConditions()
- {
- return pushableConditions;
- }
-
- public Expression getRemainingFilter()
- {
- return remainingFilter;
- }
- }
-
- private JoinCondition.Operator joinConditionOperator(ComparisonExpression.Operator operator)
- {
- return switch (operator) {
- case EQUAL -> JoinCondition.Operator.EQUAL;
- case NOT_EQUAL -> JoinCondition.Operator.NOT_EQUAL;
- case LESS_THAN -> JoinCondition.Operator.LESS_THAN;
- case LESS_THAN_OR_EQUAL -> JoinCondition.Operator.LESS_THAN_OR_EQUAL;
- case GREATER_THAN -> JoinCondition.Operator.GREATER_THAN;
- case GREATER_THAN_OR_EQUAL -> JoinCondition.Operator.GREATER_THAN_OR_EQUAL;
- case IS_DISTINCT_FROM -> JoinCondition.Operator.IS_DISTINCT_FROM;
- };
- }
-
private JoinType getJoinType(JoinNode joinNode)
{
return switch (joinNode.getType()) {
diff --git a/core/trino-main/src/test/java/io/trino/metadata/AbstractMockMetadata.java b/core/trino-main/src/test/java/io/trino/metadata/AbstractMockMetadata.java
index a22f5aba756f..b5bee784985d 100644
--- a/core/trino-main/src/test/java/io/trino/metadata/AbstractMockMetadata.java
+++ b/core/trino-main/src/test/java/io/trino/metadata/AbstractMockMetadata.java
@@ -37,7 +37,6 @@
import io.trino.spi.connector.Constraint;
import io.trino.spi.connector.ConstraintApplicationResult;
import io.trino.spi.connector.JoinApplicationResult;
-import io.trino.spi.connector.JoinCondition;
import io.trino.spi.connector.JoinStatistics;
import io.trino.spi.connector.JoinType;
import io.trino.spi.connector.LimitApplicationResult;
@@ -611,7 +610,7 @@ public Optional> applyJoin(
JoinType joinType,
TableHandle left,
TableHandle right,
- List joinConditions,
+ ConnectorExpression joinCondition,
Map leftAssignments,
Map rightAssignments,
JoinStatistics statistics)
diff --git a/core/trino-main/src/test/java/io/trino/metadata/CountingAccessMetadata.java b/core/trino-main/src/test/java/io/trino/metadata/CountingAccessMetadata.java
index d4dcee212246..6a9a87bc2553 100644
--- a/core/trino-main/src/test/java/io/trino/metadata/CountingAccessMetadata.java
+++ b/core/trino-main/src/test/java/io/trino/metadata/CountingAccessMetadata.java
@@ -33,7 +33,6 @@
import io.trino.spi.connector.Constraint;
import io.trino.spi.connector.ConstraintApplicationResult;
import io.trino.spi.connector.JoinApplicationResult;
-import io.trino.spi.connector.JoinCondition;
import io.trino.spi.connector.JoinStatistics;
import io.trino.spi.connector.JoinType;
import io.trino.spi.connector.LimitApplicationResult;
@@ -611,9 +610,9 @@ public Optional> applyAggregation(Sess
}
@Override
- public Optional> applyJoin(Session session, JoinType joinType, TableHandle left, TableHandle right, List joinConditions, Map leftAssignments, Map rightAssignments, JoinStatistics statistics)
+ public Optional> applyJoin(Session session, JoinType joinType, TableHandle left, TableHandle right, ConnectorExpression joinCondition, Map leftAssignments, Map rightAssignments, JoinStatistics statistics)
{
- return delegate.applyJoin(session, joinType, left, right, joinConditions, leftAssignments, rightAssignments, statistics);
+ return delegate.applyJoin(session, joinType, left, right, joinCondition, leftAssignments, rightAssignments, statistics);
}
@Override
diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushJoinIntoTableScan.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushJoinIntoTableScan.java
index 0586ed1f87c5..e1bf4f6ff83b 100644
--- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushJoinIntoTableScan.java
+++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushJoinIntoTableScan.java
@@ -29,6 +29,8 @@
import io.trino.spi.connector.JoinCondition;
import io.trino.spi.connector.JoinType;
import io.trino.spi.connector.SchemaTableName;
+import io.trino.spi.expression.Call;
+import io.trino.spi.expression.Constant;
import io.trino.spi.expression.Variable;
import io.trino.spi.predicate.Domain;
import io.trino.spi.predicate.NullableValue;
@@ -51,6 +53,7 @@
import static com.google.common.base.Predicates.equalTo;
import static com.google.common.collect.ImmutableList.toImmutableList;
+import static io.trino.spi.expression.StandardFunctions.MULTIPLY_FUNCTION_NAME;
import static io.trino.spi.predicate.Domain.onlyNull;
import static io.trino.spi.type.BigintType.BIGINT;
import static io.trino.sql.planner.assertions.PlanMatchPattern.anyTree;
@@ -143,7 +146,7 @@ public void testPushJoinIntoTableScan(JoinNode.Type joinType, Optional {
Symbol columnA1Symbol = p.symbol(COLUMN_A1);
Symbol columnA2Symbol = p.symbol(COLUMN_A2);
@@ -221,15 +224,33 @@ public static Object[][] testPushJoinIntoTableScanParams()
};
}
+ /**
+ * Test a scenario where join condition cannot be represented with simple comparisons.
+ */
@Test
- public void testPushJoinIntoTableScanDoesNotTriggerWithUnsupportedFilter()
+ public void testPushJoinIntoTableScanWithComplexFilter()
{
MockConnectorFactory connectorFactory = createMockConnectorFactory(
(session, applyJoinType, left, right, joinConditions, leftAssignments, rightAssignments) -> {
- throw new IllegalStateException("applyJoin should not be called!");
+ assertThat(joinConditions).as("joinConditions")
+ .isEqualTo(List.of(
+ new JoinCondition(
+ JoinCondition.Operator.GREATER_THAN,
+ new Call(
+ BIGINT,
+ MULTIPLY_FUNCTION_NAME,
+ List.of(
+ new Constant(44L, BIGINT),
+ new Variable("columna1", BIGINT))),
+ new Variable("columnb1", BIGINT))));
+ return Optional.of(new JoinApplicationResult<>(
+ JOIN_CONNECTOR_TABLE_HANDLE,
+ JOIN_TABLE_A_COLUMN_MAPPING,
+ JOIN_TABLE_B_COLUMN_MAPPING,
+ false));
});
try (RuleTester ruleTester = RuleTester.builder().withDefaultCatalogConnectorFactory(connectorFactory).build()) {
- ruleTester.assertThat(new PushJoinIntoTableScan(ruleTester.getMetadata()))
+ ruleTester.assertThat(new PushJoinIntoTableScan(ruleTester.getPlannerContext(), ruleTester.getTypeAnalyzer()))
.on(p -> {
Symbol columnA1Symbol = p.symbol(COLUMN_A1);
Symbol columnA2Symbol = p.symbol(COLUMN_A2);
@@ -255,7 +276,9 @@ public void testPushJoinIntoTableScanDoesNotTriggerWithUnsupportedFilter()
columnB1Symbol.toSymbolReference()));
})
.withSession(MOCK_SESSION)
- .doesNotFire();
+ .matches(
+ project(
+ tableScan(JOIN_PUSHDOWN_SCHEMA_TABLE_NAME.getTableName())));
}
}
@@ -270,7 +293,7 @@ public void testPushJoinIntoTableScanDoesNotFireForDifferentCatalogs()
ruleTester.getQueryRunner().createCatalog("another_catalog", "mock", ImmutableMap.of());
TableHandle tableBHandleAnotherCatalog = createTableHandle(new MockConnectorTableHandle(new SchemaTableName(SCHEMA, TABLE_B)), createTestCatalogHandle("another_catalog"));
- ruleTester.assertThat(new PushJoinIntoTableScan(ruleTester.getMetadata()))
+ ruleTester.assertThat(new PushJoinIntoTableScan(ruleTester.getPlannerContext(), ruleTester.getTypeAnalyzer()))
.on(p -> {
Symbol columnA1Symbol = p.symbol(COLUMN_A1);
Symbol columnA2Symbol = p.symbol(COLUMN_A2);
@@ -309,7 +332,7 @@ public void testPushJoinIntoTableScanDoesNotFireWhenDisabled()
throw new IllegalStateException("applyJoin should not be called!");
});
try (RuleTester ruleTester = RuleTester.builder().withDefaultCatalogConnectorFactory(connectorFactory).build()) {
- ruleTester.assertThat(new PushJoinIntoTableScan(ruleTester.getMetadata()))
+ ruleTester.assertThat(new PushJoinIntoTableScan(ruleTester.getPlannerContext(), ruleTester.getTypeAnalyzer()))
.on(p -> {
Symbol columnA1Symbol = p.symbol(COLUMN_A1);
Symbol columnA2Symbol = p.symbol(COLUMN_A2);
@@ -348,7 +371,7 @@ public void testPushJoinIntoTableScanDoesNotFireWhenAllPushdownsDisabled()
throw new IllegalStateException("applyJoin should not be called!");
});
try (RuleTester ruleTester = RuleTester.builder().withDefaultCatalogConnectorFactory(connectorFactory).build()) {
- ruleTester.assertThat(new PushJoinIntoTableScan(ruleTester.getMetadata()))
+ ruleTester.assertThat(new PushJoinIntoTableScan(ruleTester.getPlannerContext(), ruleTester.getTypeAnalyzer()))
.on(p -> {
Symbol columnA1Symbol = p.symbol(COLUMN_A1);
Symbol columnA2Symbol = p.symbol(COLUMN_A2);
@@ -384,7 +407,7 @@ public void testPushJoinIntoTableScanPreservesEnforcedConstraint(JoinNode.Type j
JOIN_TABLE_B_COLUMN_MAPPING,
false)));
try (RuleTester ruleTester = RuleTester.builder().withDefaultCatalogConnectorFactory(connectorFactory).build()) {
- ruleTester.assertThat(new PushJoinIntoTableScan(ruleTester.getMetadata()))
+ ruleTester.assertThat(new PushJoinIntoTableScan(ruleTester.getPlannerContext(), ruleTester.getTypeAnalyzer()))
.on(p -> {
Symbol columnA1Symbol = p.symbol(COLUMN_A1);
Symbol columnA2Symbol = p.symbol(COLUMN_A2);
@@ -485,7 +508,7 @@ public void testPushJoinIntoTableDoesNotFireForCrossJoin()
throw new IllegalStateException("applyJoin should not be called!");
});
try (RuleTester ruleTester = RuleTester.builder().withDefaultCatalogConnectorFactory(connectorFactory).build()) {
- ruleTester.assertThat(new PushJoinIntoTableScan(ruleTester.getMetadata()))
+ ruleTester.assertThat(new PushJoinIntoTableScan(ruleTester.getPlannerContext(), ruleTester.getTypeAnalyzer()))
.on(p -> {
Symbol columnA1Symbol = p.symbol(COLUMN_A1);
Symbol columnA2Symbol = p.symbol(COLUMN_A2);
@@ -524,7 +547,7 @@ public void testPushJoinIntoTableRequiresFullColumnHandleMappingInResult()
false)));
try (RuleTester ruleTester = RuleTester.builder().withDefaultCatalogConnectorFactory(connectorFactory).build()) {
assertThatThrownBy(() -> {
- ruleTester.assertThat(new PushJoinIntoTableScan(ruleTester.getMetadata()))
+ ruleTester.assertThat(new PushJoinIntoTableScan(ruleTester.getPlannerContext(), ruleTester.getTypeAnalyzer()))
.on(p -> {
Symbol columnA1Symbol = p.symbol(COLUMN_A1);
Symbol columnA2Symbol = p.symbol(COLUMN_A2);
diff --git a/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorMetadata.java b/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorMetadata.java
index aa0bc2f7a015..55ab1cfa1e13 100644
--- a/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorMetadata.java
+++ b/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorMetadata.java
@@ -15,6 +15,7 @@
import io.airlift.slice.Slice;
import io.trino.spi.TrinoException;
+import io.trino.spi.expression.Call;
import io.trino.spi.expression.ConnectorExpression;
import io.trino.spi.expression.Constant;
import io.trino.spi.expression.Variable;
@@ -36,6 +37,7 @@
import javax.annotation.Nullable;
+import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
@@ -49,6 +51,7 @@
import static io.trino.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR;
import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED;
+import static io.trino.spi.expression.StandardFunctions.AND_FUNCTION_NAME;
import static java.util.Collections.emptyList;
import static java.util.Collections.emptyMap;
import static java.util.stream.Collectors.toUnmodifiableList;
@@ -1282,6 +1285,50 @@ default Optional> applyAggreg
* It is required that mapping is provided for *all* column handles exposed previously by both left and right join sources.
*
*/
+ default Optional> applyJoin(
+ ConnectorSession session,
+ JoinType joinType,
+ ConnectorTableHandle left,
+ ConnectorTableHandle right,
+ ConnectorExpression joinCondition,
+ Map leftAssignments,
+ Map rightAssignments,
+ JoinStatistics statistics)
+ {
+ List conditions;
+ if (joinCondition instanceof Call call && AND_FUNCTION_NAME.equals(call.getFunctionName())) {
+ conditions = new ArrayList<>(call.getArguments().size());
+ for (ConnectorExpression argument : call.getArguments()) {
+ if (Constant.TRUE.equals(argument)) {
+ continue;
+ }
+ Optional condition = JoinCondition.from(argument, leftAssignments.keySet(), rightAssignments.keySet());
+ if (condition.isEmpty()) {
+ // We would need to add a FilterNode on top of the result
+ return Optional.empty();
+ }
+ conditions.add(condition.get());
+ }
+ }
+ else {
+ Optional condition = JoinCondition.from(joinCondition, leftAssignments.keySet(), rightAssignments.keySet());
+ if (condition.isEmpty()) {
+ return Optional.empty();
+ }
+ conditions = List.of(condition.get());
+ }
+ return applyJoin(
+ session,
+ joinType,
+ left,
+ right,
+ conditions,
+ leftAssignments,
+ rightAssignments,
+ statistics);
+ }
+
+ @Deprecated
default Optional> applyJoin(
ConnectorSession session,
JoinType joinType,
diff --git a/core/trino-spi/src/main/java/io/trino/spi/connector/JoinCondition.java b/core/trino-spi/src/main/java/io/trino/spi/connector/JoinCondition.java
index 4d2b27c12e20..83989c1066cf 100644
--- a/core/trino-spi/src/main/java/io/trino/spi/connector/JoinCondition.java
+++ b/core/trino-spi/src/main/java/io/trino/spi/connector/JoinCondition.java
@@ -13,36 +13,113 @@
*/
package io.trino.spi.connector;
+import io.trino.spi.expression.Call;
import io.trino.spi.expression.ConnectorExpression;
+import io.trino.spi.expression.FunctionName;
+import io.trino.spi.expression.StandardFunctions;
+import io.trino.spi.expression.Variable;
+import java.util.ArrayDeque;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
import java.util.Objects;
+import java.util.Optional;
+import java.util.Queue;
+import java.util.Set;
+import java.util.stream.Stream;
import static java.lang.String.format;
import static java.util.Objects.requireNonNull;
+import static java.util.function.Function.identity;
+import static java.util.stream.Collectors.toUnmodifiableMap;
+@Deprecated
public final class JoinCondition
{
+ @Deprecated
public enum Operator
{
- EQUAL("="),
- NOT_EQUAL("<>"),
- LESS_THAN("<"),
- LESS_THAN_OR_EQUAL("<="),
- GREATER_THAN(">"),
- GREATER_THAN_OR_EQUAL(">="),
- IS_DISTINCT_FROM("IS DISTINCT FROM");
+ EQUAL("=", StandardFunctions.EQUAL_OPERATOR_FUNCTION_NAME),
+ NOT_EQUAL("<>", StandardFunctions.NOT_EQUAL_OPERATOR_FUNCTION_NAME),
+ LESS_THAN("<", StandardFunctions.LESS_THAN_OPERATOR_FUNCTION_NAME),
+ LESS_THAN_OR_EQUAL("<=", StandardFunctions.LESS_THAN_OR_EQUAL_OPERATOR_FUNCTION_NAME),
+ GREATER_THAN(">", StandardFunctions.GREATER_THAN_OPERATOR_FUNCTION_NAME),
+ GREATER_THAN_OR_EQUAL(">=", StandardFunctions.GREATER_THAN_OR_EQUAL_OPERATOR_FUNCTION_NAME),
+ IS_DISTINCT_FROM("IS DISTINCT FROM", StandardFunctions.IS_DISTINCT_FROM_OPERATOR_FUNCTION_NAME),
+ /**/;
+
+ private static final Map byFunctionName = Stream.of(values())
+ .collect(toUnmodifiableMap(operator -> operator.callFunctionName, identity()));
private final String value;
+ private final FunctionName callFunctionName;
- Operator(String value)
+ Operator(String value, FunctionName callFunctionName)
{
this.value = value;
+ this.callFunctionName = callFunctionName;
}
public String getValue()
{
return value;
}
+
+ public Operator flip()
+ {
+ return switch (this) {
+ case EQUAL, NOT_EQUAL, IS_DISTINCT_FROM -> this;
+ case LESS_THAN -> GREATER_THAN;
+ case LESS_THAN_OR_EQUAL -> GREATER_THAN_OR_EQUAL;
+ case GREATER_THAN -> LESS_THAN;
+ case GREATER_THAN_OR_EQUAL -> LESS_THAN_OR_EQUAL;
+ };
+ }
+ }
+
+ public static Optional from(ConnectorExpression expression, Set leftSymbols, Set rightSymbols)
+ {
+ if (expression instanceof Call call && call.getArguments().size() == 2) {
+ return Optional.ofNullable(Operator.byFunctionName.get(call.getFunctionName()))
+ .flatMap(operator -> {
+ rightSymbols.stream().filter(leftSymbols::contains).findAny().ifPresent(symbol -> {
+ throw new IllegalArgumentException(
+ "Left and right symbol sets overlap, are both include %s: %s, %s".formatted(symbol, leftSymbols, rightSymbols));
+ });
+ ConnectorExpression left = call.getArguments().get(0);
+ ConnectorExpression right = call.getArguments().get(1);
+ Set leftExpressionSymbols = findVariableNames(left);
+ Set rightExpressionSymbols = findVariableNames(right);
+ if (leftSymbols.containsAll(leftExpressionSymbols) && rightSymbols.containsAll(rightExpressionSymbols)) {
+ return Optional.of(new JoinCondition(operator, left, right));
+ }
+ if (rightSymbols.containsAll(leftExpressionSymbols) && leftSymbols.containsAll(rightExpressionSymbols)) {
+ // normalize
+ return Optional.of(new JoinCondition(operator.flip(), right, left));
+ }
+ return Optional.empty();
+ });
+ }
+ return Optional.empty();
+ }
+
+ private static Set findVariableNames(ConnectorExpression expression)
+ {
+ Set variableNames = new HashSet<>();
+ Set visited = new HashSet<>();
+ Queue pending = new ArrayDeque<>(List.of(expression));
+ while (!pending.isEmpty()) {
+ ConnectorExpression next = pending.remove();
+ if (!visited.add(next)) {
+ continue;
+ }
+ pending.addAll(next.getChildren());
+ if (next instanceof Variable variable) {
+ variableNames.add(variable.getName());
+ }
+ }
+ return variableNames;
}
private final Operator operator;
diff --git a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/classloader/ClassLoaderSafeConnectorMetadata.java b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/classloader/ClassLoaderSafeConnectorMetadata.java
index 5f0ec0cf1d3a..44a1ca238ed9 100644
--- a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/classloader/ClassLoaderSafeConnectorMetadata.java
+++ b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/classloader/ClassLoaderSafeConnectorMetadata.java
@@ -908,6 +908,22 @@ public Optional> applyAggrega
}
}
+ @Override
+ public Optional> applyJoin(
+ ConnectorSession session,
+ JoinType joinType,
+ ConnectorTableHandle left,
+ ConnectorTableHandle right,
+ ConnectorExpression joinCondition,
+ Map leftAssignments,
+ Map rightAssignments,
+ JoinStatistics statistics)
+ {
+ try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(classLoader)) {
+ return delegate.applyJoin(session, joinType, left, right, joinCondition, leftAssignments, rightAssignments, statistics);
+ }
+ }
+
@Override
public Optional> applyJoin(
ConnectorSession session,