diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ExtractSpatialJoins.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ExtractSpatialJoins.java index 4b30f9de7f9aa..93fa95932846e 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ExtractSpatialJoins.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ExtractSpatialJoins.java @@ -67,7 +67,6 @@ import java.io.IOException; import java.io.UncheckedIOException; -import java.util.Collection; import java.util.List; import java.util.Map; import java.util.Optional; @@ -117,14 +116,14 @@ *
  • SELECT ... FROM a, b WHERE 15.5 > ST_Distance(b.geometry, a.geometry)
  • * *

    - * Joins expressed via ST_Contains and ST_Intersects functions must match all of + * Joins expressed via ST_Contains and ST_Intersects functions must match all * the following criteria: *

    * - arguments of the spatial function are non-scalar expressions; * - one of the arguments uses symbols from left side of the join, the other from right. *

    * Joins expressed via ST_Distance function must use less than or less than or equals operator - * to compare ST_Distance value with a radius and must match all of the following criteria: + * to compare ST_Distance value with a radius and must match all the following criteria: *

    * - arguments of the spatial function are non-scalar expressions; * - one of the arguments uses symbols from left side of the join, the other from right; @@ -161,6 +160,53 @@ public class ExtractSpatialJoins private final SplitManager splitManager; private final PageSourceManager pageSourceManager; + private enum VariableSide + { + Neither, + Left, + Right, + Both + } + + private static VariableSide inferVariableSide(RowExpression expression, JoinNode joinNode) + { + Set expressionVariables = extractUnique(expression); + + if (expressionVariables.isEmpty()) { + return VariableSide.Neither; + } + + List leftVariables = joinNode.getLeft().getOutputVariables(); + List rightVariables = joinNode.getRight().getOutputVariables(); + boolean leftContains = false; + boolean rightContains = false; + for (VariableReferenceExpression var : leftVariables) { + if (expressionVariables.contains(var)) { + leftContains = true; + break; + } + } + for (VariableReferenceExpression var : rightVariables) { + if (expressionVariables.contains(var)) { + rightContains = true; + break; + } + } + + if (leftContains && rightContains) { + return VariableSide.Both; + } + else if (leftContains) { + return VariableSide.Left; + } + else if (rightContains) { + return VariableSide.Right; + } + else { + return VariableSide.Neither; + } + } + public ExtractSpatialJoins(Metadata metadata, SplitManager splitManager, PageSourceManager pageSourceManager) { this.metadata = requireNonNull(metadata, "metadata is null"); @@ -305,18 +351,18 @@ private static Result tryCreateSpatialJoin( PlanNode leftNode = joinNode.getLeft(); PlanNode rightNode = joinNode.getRight(); - List leftVariables = leftNode.getOutputVariables(); - List rightVariables = rightNode.getOutputVariables(); - RowExpression radius; Optional newRadiusVariable; + VariableReferenceExpression radiusVariable; CallExpression newComparison; if (spatialComparisonMetadata.getOperatorType().get() == OperatorType.LESS_THAN || spatialComparisonMetadata.getOperatorType().get() == OperatorType.LESS_THAN_OR_EQUAL) { // ST_Distance(a, b) <= r radius = spatialComparison.getArguments().get(1); - Set radiusVariables = extractUnique(radius); - if (radiusVariables.isEmpty() || (rightVariables.containsAll(radiusVariables) && containsNone(leftVariables, radiusVariables))) { - newRadiusVariable = newRadiusVariable(context, radius); + VariableSide radiusSide = inferVariableSide(radius, joinNode); + if (radiusSide == VariableSide.Neither || radiusSide == VariableSide.Right) { + newRadiusVariable = newVariable(context, radius); + // If newRadiusVariable is empty, radius is VRE + radiusVariable = newRadiusVariable.orElseGet(() -> (VariableReferenceExpression) radius); newComparison = new CallExpression( spatialComparison.getSourceLocation(), spatialComparison.getDisplayName(), @@ -331,9 +377,11 @@ private static Result tryCreateSpatialJoin( else { // r >= ST_Distance(a, b) radius = spatialComparison.getArguments().get(0); - Set radiusVariables = extractUnique(radius); - if (radiusVariables.isEmpty() || (rightVariables.containsAll(radiusVariables) && containsNone(leftVariables, radiusVariables))) { - newRadiusVariable = newRadiusVariable(context, radius); + VariableSide radiusSide = inferVariableSide(radius, joinNode); + if (radiusSide == VariableSide.Neither || radiusSide == VariableSide.Right) { + newRadiusVariable = newVariable(context, radius); + // If newRadiusVariable is empty, radius is VRE + radiusVariable = newRadiusVariable.orElseGet(() -> (VariableReferenceExpression) radius); OperatorType flippedOperatorType = flip(spatialComparisonMetadata.getOperatorType().get()); FunctionHandle flippedHandle = getFlippedFunctionHandle(spatialComparison, metadata.getFunctionAndTypeManager()); newComparison = new CallExpression( @@ -365,7 +413,7 @@ private static Result tryCreateSpatialJoin( joinNode.getDistributionType(), joinNode.getDynamicFilters()); - return tryCreateSpatialJoin(context, newJoinNode, newFilter, nodeId, outputVariables, (CallExpression) newComparison.getArguments().get(0), Optional.of(newComparison.getArguments().get(1)), metadata, splitManager, pageSourceManager); + return tryCreateSpatialJoin(context, newJoinNode, newFilter, nodeId, outputVariables, (CallExpression) newComparison.getArguments().get(0), Optional.of(radiusVariable), metadata, splitManager, pageSourceManager); } private static Result tryCreateSpatialJoin( @@ -375,7 +423,7 @@ private static Result tryCreateSpatialJoin( PlanNodeId nodeId, List outputVariables, CallExpression spatialFunction, - Optional radius, + Optional radius, Metadata metadata, SplitManager splitManager, PageSourceManager pageSourceManager) @@ -393,18 +441,25 @@ private static Result tryCreateSpatialJoin( return Result.empty(); } - Set firstVariables = extractUnique(firstArgument); - Set secondVariables = extractUnique(secondArgument); - - if (firstVariables.isEmpty() || secondVariables.isEmpty()) { + VariableSide firstSide = inferVariableSide(firstArgument, joinNode); + VariableSide secondSide = inferVariableSide(secondArgument, joinNode); + boolean firstArgumentOnLeft; + if (firstSide == VariableSide.Left && secondSide == VariableSide.Right) { + firstArgumentOnLeft = true; + } + else if (firstSide == VariableSide.Right && secondSide == VariableSide.Left) { + firstArgumentOnLeft = false; + } + else { + // Spatial joins require each argument comes from only one side, and they come from opposite sides return Result.empty(); } // If either firstArgument or secondArgument is not a // VariableReferenceExpression, will replace the left/right join node // with a projection that adds the argument as a variable. - Optional newFirstVariable = newGeometryVariable(context, firstArgument); - Optional newSecondVariable = newGeometryVariable(context, secondArgument); + Optional newFirstVariable = newVariable(context, firstArgument); + Optional newSecondVariable = newVariable(context, secondArgument); PlanNode leftNode = joinNode.getLeft(); PlanNode rightNode = joinNode.getRight(); @@ -412,18 +467,14 @@ private static Result tryCreateSpatialJoin( PlanNode newRightNode; // Check if the order of arguments of the spatial function matches the order of join sides - int alignment = checkAlignment(joinNode, firstVariables, secondVariables); - if (alignment > 0) { + if (firstArgumentOnLeft) { newLeftNode = newFirstVariable.map(variable -> addProjection(context, leftNode, variable, firstArgument)).orElse(leftNode); newRightNode = newSecondVariable.map(variable -> addProjection(context, rightNode, variable, secondArgument)).orElse(rightNode); } - else if (alignment < 0) { + else { newLeftNode = newSecondVariable.map(variable -> addProjection(context, leftNode, variable, secondArgument)).orElse(leftNode); newRightNode = newFirstVariable.map(variable -> addProjection(context, rightNode, variable, firstArgument)).orElse(rightNode); } - else { - return Result.empty(); - } RowExpression newFirstArgument = mapToExpression(newFirstVariable, firstArgument); RowExpression newSecondArgument = mapToExpression(newSecondVariable, secondArgument); @@ -441,7 +492,7 @@ else if (alignment < 0) { leftPartitionVariable = Optional.of(context.getVariableAllocator().newVariable(newFirstArgument.getSourceLocation(), "pid", INTEGER)); rightPartitionVariable = Optional.of(context.getVariableAllocator().newVariable(newSecondArgument.getSourceLocation(), "pid", INTEGER)); - if (alignment > 0) { + if (firstArgumentOnLeft) { newLeftNode = addPartitioningNodes(context, functionAndTypeManager, newLeftNode, leftPartitionVariable.get(), kdbTree.get(), newFirstArgument, Optional.empty()); newRightNode = addPartitioningNodes(context, functionAndTypeManager, newRightNode, rightPartitionVariable.get(), kdbTree.get(), newSecondArgument, radius); } @@ -601,43 +652,12 @@ private static QualifiedObjectName toQualifiedObjectName(String name, String cat throw new PrestoException(INVALID_SPATIAL_PARTITIONING, format("Invalid name: %s", name)); } - private static int checkAlignment(JoinNode joinNode, Set maybeLeftVariables, Set maybeRightVariables) - { - List leftVariables = joinNode.getLeft().getOutputVariables(); - List rightVariables = joinNode.getRight().getOutputVariables(); - - if (leftVariables.containsAll(maybeLeftVariables) - && containsNone(leftVariables, maybeRightVariables) - && rightVariables.containsAll(maybeRightVariables) - && containsNone(rightVariables, maybeLeftVariables)) { - return 1; - } - - if (leftVariables.containsAll(maybeRightVariables) - && containsNone(leftVariables, maybeLeftVariables) - && rightVariables.containsAll(maybeLeftVariables) - && containsNone(rightVariables, maybeRightVariables)) { - return -1; - } - - return 0; - } - private static RowExpression mapToExpression(Optional optionalVariable, RowExpression defaultExpression) { return optionalVariable.map(RowExpression.class::cast).orElse(defaultExpression); } - private static Optional newGeometryVariable(Context context, RowExpression expression) - { - if (expression instanceof VariableReferenceExpression) { - return Optional.empty(); - } - - return Optional.of(context.getVariableAllocator().newVariable(expression)); - } - - private static Optional newRadiusVariable(Context context, RowExpression expression) + private static Optional newVariable(Context context, RowExpression expression) { if (expression instanceof VariableReferenceExpression) { return Optional.empty(); @@ -657,7 +677,7 @@ private static PlanNode addProjection(Context context, PlanNode node, VariableRe return new ProjectNode(node.getSourceLocation(), context.getIdAllocator().getNextId(), node, projections.build(), LOCAL); } - private static PlanNode addPartitioningNodes(Context context, FunctionAndTypeManager functionAndTypeManager, PlanNode node, VariableReferenceExpression partitionVariable, KdbTree kdbTree, RowExpression geometry, Optional radius) + private static PlanNode addPartitioningNodes(Context context, FunctionAndTypeManager functionAndTypeManager, PlanNode node, VariableReferenceExpression partitionVariable, KdbTree kdbTree, RowExpression geometry, Optional radius) { Assignments.Builder projections = Assignments.builder(); for (VariableReferenceExpression outputVariable : node.getOutputVariables()) { @@ -690,9 +710,4 @@ private static PlanNode addPartitioningNodes(Context context, FunctionAndTypeMan ImmutableMap.of(partitionsVariable, ImmutableList.of(partitionVariable)), Optional.empty()); } - - private static boolean containsNone(Collection values, Collection testValues) - { - return values.stream().noneMatch(ImmutableSet.copyOf(testValues)::contains); - } }