diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ExtractSpatialJoins.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ExtractSpatialJoins.java
index 4b30f9de7f9aa..93fa95932846e 100644
--- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ExtractSpatialJoins.java
+++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ExtractSpatialJoins.java
@@ -67,7 +67,6 @@
import java.io.IOException;
import java.io.UncheckedIOException;
-import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Optional;
@@ -117,14 +116,14 @@
*
SELECT ... FROM a, b WHERE 15.5 > ST_Distance(b.geometry, a.geometry)
*
*
- * Joins expressed via ST_Contains and ST_Intersects functions must match all of
+ * Joins expressed via ST_Contains and ST_Intersects functions must match all
* the following criteria:
*
* - arguments of the spatial function are non-scalar expressions;
* - one of the arguments uses symbols from left side of the join, the other from right.
*
* Joins expressed via ST_Distance function must use less than or less than or equals operator
- * to compare ST_Distance value with a radius and must match all of the following criteria:
+ * to compare ST_Distance value with a radius and must match all the following criteria:
*
* - arguments of the spatial function are non-scalar expressions;
* - one of the arguments uses symbols from left side of the join, the other from right;
@@ -161,6 +160,53 @@ public class ExtractSpatialJoins
private final SplitManager splitManager;
private final PageSourceManager pageSourceManager;
+ private enum VariableSide
+ {
+ Neither,
+ Left,
+ Right,
+ Both
+ }
+
+ private static VariableSide inferVariableSide(RowExpression expression, JoinNode joinNode)
+ {
+ Set expressionVariables = extractUnique(expression);
+
+ if (expressionVariables.isEmpty()) {
+ return VariableSide.Neither;
+ }
+
+ List leftVariables = joinNode.getLeft().getOutputVariables();
+ List rightVariables = joinNode.getRight().getOutputVariables();
+ boolean leftContains = false;
+ boolean rightContains = false;
+ for (VariableReferenceExpression var : leftVariables) {
+ if (expressionVariables.contains(var)) {
+ leftContains = true;
+ break;
+ }
+ }
+ for (VariableReferenceExpression var : rightVariables) {
+ if (expressionVariables.contains(var)) {
+ rightContains = true;
+ break;
+ }
+ }
+
+ if (leftContains && rightContains) {
+ return VariableSide.Both;
+ }
+ else if (leftContains) {
+ return VariableSide.Left;
+ }
+ else if (rightContains) {
+ return VariableSide.Right;
+ }
+ else {
+ return VariableSide.Neither;
+ }
+ }
+
public ExtractSpatialJoins(Metadata metadata, SplitManager splitManager, PageSourceManager pageSourceManager)
{
this.metadata = requireNonNull(metadata, "metadata is null");
@@ -305,18 +351,18 @@ private static Result tryCreateSpatialJoin(
PlanNode leftNode = joinNode.getLeft();
PlanNode rightNode = joinNode.getRight();
- List leftVariables = leftNode.getOutputVariables();
- List rightVariables = rightNode.getOutputVariables();
-
RowExpression radius;
Optional newRadiusVariable;
+ VariableReferenceExpression radiusVariable;
CallExpression newComparison;
if (spatialComparisonMetadata.getOperatorType().get() == OperatorType.LESS_THAN || spatialComparisonMetadata.getOperatorType().get() == OperatorType.LESS_THAN_OR_EQUAL) {
// ST_Distance(a, b) <= r
radius = spatialComparison.getArguments().get(1);
- Set radiusVariables = extractUnique(radius);
- if (radiusVariables.isEmpty() || (rightVariables.containsAll(radiusVariables) && containsNone(leftVariables, radiusVariables))) {
- newRadiusVariable = newRadiusVariable(context, radius);
+ VariableSide radiusSide = inferVariableSide(radius, joinNode);
+ if (radiusSide == VariableSide.Neither || radiusSide == VariableSide.Right) {
+ newRadiusVariable = newVariable(context, radius);
+ // If newRadiusVariable is empty, radius is VRE
+ radiusVariable = newRadiusVariable.orElseGet(() -> (VariableReferenceExpression) radius);
newComparison = new CallExpression(
spatialComparison.getSourceLocation(),
spatialComparison.getDisplayName(),
@@ -331,9 +377,11 @@ private static Result tryCreateSpatialJoin(
else {
// r >= ST_Distance(a, b)
radius = spatialComparison.getArguments().get(0);
- Set radiusVariables = extractUnique(radius);
- if (radiusVariables.isEmpty() || (rightVariables.containsAll(radiusVariables) && containsNone(leftVariables, radiusVariables))) {
- newRadiusVariable = newRadiusVariable(context, radius);
+ VariableSide radiusSide = inferVariableSide(radius, joinNode);
+ if (radiusSide == VariableSide.Neither || radiusSide == VariableSide.Right) {
+ newRadiusVariable = newVariable(context, radius);
+ // If newRadiusVariable is empty, radius is VRE
+ radiusVariable = newRadiusVariable.orElseGet(() -> (VariableReferenceExpression) radius);
OperatorType flippedOperatorType = flip(spatialComparisonMetadata.getOperatorType().get());
FunctionHandle flippedHandle = getFlippedFunctionHandle(spatialComparison, metadata.getFunctionAndTypeManager());
newComparison = new CallExpression(
@@ -365,7 +413,7 @@ private static Result tryCreateSpatialJoin(
joinNode.getDistributionType(),
joinNode.getDynamicFilters());
- return tryCreateSpatialJoin(context, newJoinNode, newFilter, nodeId, outputVariables, (CallExpression) newComparison.getArguments().get(0), Optional.of(newComparison.getArguments().get(1)), metadata, splitManager, pageSourceManager);
+ return tryCreateSpatialJoin(context, newJoinNode, newFilter, nodeId, outputVariables, (CallExpression) newComparison.getArguments().get(0), Optional.of(radiusVariable), metadata, splitManager, pageSourceManager);
}
private static Result tryCreateSpatialJoin(
@@ -375,7 +423,7 @@ private static Result tryCreateSpatialJoin(
PlanNodeId nodeId,
List outputVariables,
CallExpression spatialFunction,
- Optional radius,
+ Optional radius,
Metadata metadata,
SplitManager splitManager,
PageSourceManager pageSourceManager)
@@ -393,18 +441,25 @@ private static Result tryCreateSpatialJoin(
return Result.empty();
}
- Set firstVariables = extractUnique(firstArgument);
- Set secondVariables = extractUnique(secondArgument);
-
- if (firstVariables.isEmpty() || secondVariables.isEmpty()) {
+ VariableSide firstSide = inferVariableSide(firstArgument, joinNode);
+ VariableSide secondSide = inferVariableSide(secondArgument, joinNode);
+ boolean firstArgumentOnLeft;
+ if (firstSide == VariableSide.Left && secondSide == VariableSide.Right) {
+ firstArgumentOnLeft = true;
+ }
+ else if (firstSide == VariableSide.Right && secondSide == VariableSide.Left) {
+ firstArgumentOnLeft = false;
+ }
+ else {
+ // Spatial joins require each argument comes from only one side, and they come from opposite sides
return Result.empty();
}
// If either firstArgument or secondArgument is not a
// VariableReferenceExpression, will replace the left/right join node
// with a projection that adds the argument as a variable.
- Optional newFirstVariable = newGeometryVariable(context, firstArgument);
- Optional newSecondVariable = newGeometryVariable(context, secondArgument);
+ Optional newFirstVariable = newVariable(context, firstArgument);
+ Optional newSecondVariable = newVariable(context, secondArgument);
PlanNode leftNode = joinNode.getLeft();
PlanNode rightNode = joinNode.getRight();
@@ -412,18 +467,14 @@ private static Result tryCreateSpatialJoin(
PlanNode newRightNode;
// Check if the order of arguments of the spatial function matches the order of join sides
- int alignment = checkAlignment(joinNode, firstVariables, secondVariables);
- if (alignment > 0) {
+ if (firstArgumentOnLeft) {
newLeftNode = newFirstVariable.map(variable -> addProjection(context, leftNode, variable, firstArgument)).orElse(leftNode);
newRightNode = newSecondVariable.map(variable -> addProjection(context, rightNode, variable, secondArgument)).orElse(rightNode);
}
- else if (alignment < 0) {
+ else {
newLeftNode = newSecondVariable.map(variable -> addProjection(context, leftNode, variable, secondArgument)).orElse(leftNode);
newRightNode = newFirstVariable.map(variable -> addProjection(context, rightNode, variable, firstArgument)).orElse(rightNode);
}
- else {
- return Result.empty();
- }
RowExpression newFirstArgument = mapToExpression(newFirstVariable, firstArgument);
RowExpression newSecondArgument = mapToExpression(newSecondVariable, secondArgument);
@@ -441,7 +492,7 @@ else if (alignment < 0) {
leftPartitionVariable = Optional.of(context.getVariableAllocator().newVariable(newFirstArgument.getSourceLocation(), "pid", INTEGER));
rightPartitionVariable = Optional.of(context.getVariableAllocator().newVariable(newSecondArgument.getSourceLocation(), "pid", INTEGER));
- if (alignment > 0) {
+ if (firstArgumentOnLeft) {
newLeftNode = addPartitioningNodes(context, functionAndTypeManager, newLeftNode, leftPartitionVariable.get(), kdbTree.get(), newFirstArgument, Optional.empty());
newRightNode = addPartitioningNodes(context, functionAndTypeManager, newRightNode, rightPartitionVariable.get(), kdbTree.get(), newSecondArgument, radius);
}
@@ -601,43 +652,12 @@ private static QualifiedObjectName toQualifiedObjectName(String name, String cat
throw new PrestoException(INVALID_SPATIAL_PARTITIONING, format("Invalid name: %s", name));
}
- private static int checkAlignment(JoinNode joinNode, Set maybeLeftVariables, Set maybeRightVariables)
- {
- List leftVariables = joinNode.getLeft().getOutputVariables();
- List rightVariables = joinNode.getRight().getOutputVariables();
-
- if (leftVariables.containsAll(maybeLeftVariables)
- && containsNone(leftVariables, maybeRightVariables)
- && rightVariables.containsAll(maybeRightVariables)
- && containsNone(rightVariables, maybeLeftVariables)) {
- return 1;
- }
-
- if (leftVariables.containsAll(maybeRightVariables)
- && containsNone(leftVariables, maybeLeftVariables)
- && rightVariables.containsAll(maybeLeftVariables)
- && containsNone(rightVariables, maybeRightVariables)) {
- return -1;
- }
-
- return 0;
- }
-
private static RowExpression mapToExpression(Optional optionalVariable, RowExpression defaultExpression)
{
return optionalVariable.map(RowExpression.class::cast).orElse(defaultExpression);
}
- private static Optional newGeometryVariable(Context context, RowExpression expression)
- {
- if (expression instanceof VariableReferenceExpression) {
- return Optional.empty();
- }
-
- return Optional.of(context.getVariableAllocator().newVariable(expression));
- }
-
- private static Optional newRadiusVariable(Context context, RowExpression expression)
+ private static Optional newVariable(Context context, RowExpression expression)
{
if (expression instanceof VariableReferenceExpression) {
return Optional.empty();
@@ -657,7 +677,7 @@ private static PlanNode addProjection(Context context, PlanNode node, VariableRe
return new ProjectNode(node.getSourceLocation(), context.getIdAllocator().getNextId(), node, projections.build(), LOCAL);
}
- private static PlanNode addPartitioningNodes(Context context, FunctionAndTypeManager functionAndTypeManager, PlanNode node, VariableReferenceExpression partitionVariable, KdbTree kdbTree, RowExpression geometry, Optional radius)
+ private static PlanNode addPartitioningNodes(Context context, FunctionAndTypeManager functionAndTypeManager, PlanNode node, VariableReferenceExpression partitionVariable, KdbTree kdbTree, RowExpression geometry, Optional radius)
{
Assignments.Builder projections = Assignments.builder();
for (VariableReferenceExpression outputVariable : node.getOutputVariables()) {
@@ -690,9 +710,4 @@ private static PlanNode addPartitioningNodes(Context context, FunctionAndTypeMan
ImmutableMap.of(partitionsVariable, ImmutableList.of(partitionVariable)),
Optional.empty());
}
-
- private static boolean containsNone(Collection values, Collection testValues)
- {
- return values.stream().noneMatch(ImmutableSet.copyOf(testValues)::contains);
- }
}