-
Notifications
You must be signed in to change notification settings - Fork 5.5k
refactor: Refactor ExtractSpatialJoin #26250
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
+80
−65
Merged
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -67,7 +67,6 @@ | |
|
|
||
| import java.io.IOException; | ||
| import java.io.UncheckedIOException; | ||
| import java.util.Collection; | ||
| import java.util.List; | ||
| import java.util.Map; | ||
| import java.util.Optional; | ||
|
|
@@ -117,14 +116,14 @@ | |
| * <li>SELECT ... FROM a, b WHERE 15.5 > ST_Distance(b.geometry, a.geometry)</li> | ||
| * </ul> | ||
| * <p> | ||
| * Joins expressed via ST_Contains and ST_Intersects functions must match all of | ||
| * Joins expressed via ST_Contains and ST_Intersects functions must match all | ||
| * the following criteria: | ||
| * <p> | ||
| * - arguments of the spatial function are non-scalar expressions; | ||
| * - one of the arguments uses symbols from left side of the join, the other from right. | ||
| * <p> | ||
| * Joins expressed via ST_Distance function must use less than or less than or equals operator | ||
| * to compare ST_Distance value with a radius and must match all of the following criteria: | ||
| * to compare ST_Distance value with a radius and must match all the following criteria: | ||
| * <p> | ||
| * - arguments of the spatial function are non-scalar expressions; | ||
| * - one of the arguments uses symbols from left side of the join, the other from right; | ||
|
|
@@ -161,6 +160,53 @@ public class ExtractSpatialJoins | |
| private final SplitManager splitManager; | ||
| private final PageSourceManager pageSourceManager; | ||
|
|
||
| private enum VariableSide | ||
| { | ||
| Neither, | ||
| Left, | ||
| Right, | ||
| Both | ||
| } | ||
|
|
||
| private static VariableSide inferVariableSide(RowExpression expression, JoinNode joinNode) | ||
| { | ||
| Set<VariableReferenceExpression> expressionVariables = extractUnique(expression); | ||
|
|
||
| if (expressionVariables.isEmpty()) { | ||
| return VariableSide.Neither; | ||
| } | ||
|
|
||
| List<VariableReferenceExpression> leftVariables = joinNode.getLeft().getOutputVariables(); | ||
| List<VariableReferenceExpression> rightVariables = joinNode.getRight().getOutputVariables(); | ||
| boolean leftContains = false; | ||
| boolean rightContains = false; | ||
| for (VariableReferenceExpression var : leftVariables) { | ||
| if (expressionVariables.contains(var)) { | ||
| leftContains = true; | ||
| break; | ||
| } | ||
| } | ||
| for (VariableReferenceExpression var : rightVariables) { | ||
| if (expressionVariables.contains(var)) { | ||
| rightContains = true; | ||
| break; | ||
| } | ||
| } | ||
|
|
||
| if (leftContains && rightContains) { | ||
| return VariableSide.Both; | ||
| } | ||
| else if (leftContains) { | ||
| return VariableSide.Left; | ||
| } | ||
| else if (rightContains) { | ||
| return VariableSide.Right; | ||
| } | ||
| else { | ||
| return VariableSide.Neither; | ||
| } | ||
| } | ||
|
|
||
| public ExtractSpatialJoins(Metadata metadata, SplitManager splitManager, PageSourceManager pageSourceManager) | ||
| { | ||
| this.metadata = requireNonNull(metadata, "metadata is null"); | ||
|
|
@@ -305,18 +351,18 @@ private static Result tryCreateSpatialJoin( | |
| PlanNode leftNode = joinNode.getLeft(); | ||
| PlanNode rightNode = joinNode.getRight(); | ||
|
|
||
| List<VariableReferenceExpression> leftVariables = leftNode.getOutputVariables(); | ||
| List<VariableReferenceExpression> rightVariables = rightNode.getOutputVariables(); | ||
|
|
||
| RowExpression radius; | ||
| Optional<VariableReferenceExpression> newRadiusVariable; | ||
| VariableReferenceExpression radiusVariable; | ||
| CallExpression newComparison; | ||
| if (spatialComparisonMetadata.getOperatorType().get() == OperatorType.LESS_THAN || spatialComparisonMetadata.getOperatorType().get() == OperatorType.LESS_THAN_OR_EQUAL) { | ||
| // ST_Distance(a, b) <= r | ||
| radius = spatialComparison.getArguments().get(1); | ||
| Set<VariableReferenceExpression> radiusVariables = extractUnique(radius); | ||
| if (radiusVariables.isEmpty() || (rightVariables.containsAll(radiusVariables) && containsNone(leftVariables, radiusVariables))) { | ||
| newRadiusVariable = newRadiusVariable(context, radius); | ||
| VariableSide radiusSide = inferVariableSide(radius, joinNode); | ||
| if (radiusSide == VariableSide.Neither || radiusSide == VariableSide.Right) { | ||
| newRadiusVariable = newVariable(context, radius); | ||
| // If newRadiusVariable is empty, radius is VRE | ||
| radiusVariable = newRadiusVariable.orElseGet(() -> (VariableReferenceExpression) radius); | ||
|
Comment on lines
+356
to
+365
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. issue (bug_risk): Casting radius to VariableReferenceExpression may be unsafe. If newRadiusVariable is empty and radius is not a VariableReferenceExpression, a ClassCastException may occur. Add a type check or assertion to prevent this. |
||
| newComparison = new CallExpression( | ||
| spatialComparison.getSourceLocation(), | ||
| spatialComparison.getDisplayName(), | ||
|
|
@@ -331,9 +377,11 @@ private static Result tryCreateSpatialJoin( | |
| else { | ||
| // r >= ST_Distance(a, b) | ||
| radius = spatialComparison.getArguments().get(0); | ||
| Set<VariableReferenceExpression> radiusVariables = extractUnique(radius); | ||
| if (radiusVariables.isEmpty() || (rightVariables.containsAll(radiusVariables) && containsNone(leftVariables, radiusVariables))) { | ||
| newRadiusVariable = newRadiusVariable(context, radius); | ||
| VariableSide radiusSide = inferVariableSide(radius, joinNode); | ||
| if (radiusSide == VariableSide.Neither || radiusSide == VariableSide.Right) { | ||
| newRadiusVariable = newVariable(context, radius); | ||
| // If newRadiusVariable is empty, radius is VRE | ||
| radiusVariable = newRadiusVariable.orElseGet(() -> (VariableReferenceExpression) radius); | ||
| OperatorType flippedOperatorType = flip(spatialComparisonMetadata.getOperatorType().get()); | ||
| FunctionHandle flippedHandle = getFlippedFunctionHandle(spatialComparison, metadata.getFunctionAndTypeManager()); | ||
| newComparison = new CallExpression( | ||
|
|
@@ -365,7 +413,7 @@ private static Result tryCreateSpatialJoin( | |
| joinNode.getDistributionType(), | ||
| joinNode.getDynamicFilters()); | ||
|
|
||
| return tryCreateSpatialJoin(context, newJoinNode, newFilter, nodeId, outputVariables, (CallExpression) newComparison.getArguments().get(0), Optional.of(newComparison.getArguments().get(1)), metadata, splitManager, pageSourceManager); | ||
| return tryCreateSpatialJoin(context, newJoinNode, newFilter, nodeId, outputVariables, (CallExpression) newComparison.getArguments().get(0), Optional.of(radiusVariable), metadata, splitManager, pageSourceManager); | ||
| } | ||
|
|
||
| private static Result tryCreateSpatialJoin( | ||
|
|
@@ -375,7 +423,7 @@ private static Result tryCreateSpatialJoin( | |
| PlanNodeId nodeId, | ||
| List<VariableReferenceExpression> outputVariables, | ||
| CallExpression spatialFunction, | ||
| Optional<RowExpression> radius, | ||
| Optional<VariableReferenceExpression> radius, | ||
| Metadata metadata, | ||
| SplitManager splitManager, | ||
| PageSourceManager pageSourceManager) | ||
|
|
@@ -393,37 +441,40 @@ private static Result tryCreateSpatialJoin( | |
| return Result.empty(); | ||
| } | ||
|
|
||
| Set<VariableReferenceExpression> firstVariables = extractUnique(firstArgument); | ||
| Set<VariableReferenceExpression> secondVariables = extractUnique(secondArgument); | ||
|
|
||
| if (firstVariables.isEmpty() || secondVariables.isEmpty()) { | ||
| VariableSide firstSide = inferVariableSide(firstArgument, joinNode); | ||
| VariableSide secondSide = inferVariableSide(secondArgument, joinNode); | ||
| boolean firstArgumentOnLeft; | ||
| if (firstSide == VariableSide.Left && secondSide == VariableSide.Right) { | ||
| firstArgumentOnLeft = true; | ||
| } | ||
| else if (firstSide == VariableSide.Right && secondSide == VariableSide.Left) { | ||
| firstArgumentOnLeft = false; | ||
| } | ||
| else { | ||
| // Spatial joins require each argument comes from only one side, and they come from opposite sides | ||
| return Result.empty(); | ||
| } | ||
|
|
||
| // If either firstArgument or secondArgument is not a | ||
| // VariableReferenceExpression, will replace the left/right join node | ||
| // with a projection that adds the argument as a variable. | ||
| Optional<VariableReferenceExpression> newFirstVariable = newGeometryVariable(context, firstArgument); | ||
| Optional<VariableReferenceExpression> newSecondVariable = newGeometryVariable(context, secondArgument); | ||
| Optional<VariableReferenceExpression> newFirstVariable = newVariable(context, firstArgument); | ||
| Optional<VariableReferenceExpression> newSecondVariable = newVariable(context, secondArgument); | ||
|
|
||
| PlanNode leftNode = joinNode.getLeft(); | ||
| PlanNode rightNode = joinNode.getRight(); | ||
| PlanNode newLeftNode; | ||
| PlanNode newRightNode; | ||
|
|
||
| // Check if the order of arguments of the spatial function matches the order of join sides | ||
| int alignment = checkAlignment(joinNode, firstVariables, secondVariables); | ||
| if (alignment > 0) { | ||
| if (firstArgumentOnLeft) { | ||
| newLeftNode = newFirstVariable.map(variable -> addProjection(context, leftNode, variable, firstArgument)).orElse(leftNode); | ||
| newRightNode = newSecondVariable.map(variable -> addProjection(context, rightNode, variable, secondArgument)).orElse(rightNode); | ||
| } | ||
| else if (alignment < 0) { | ||
| else { | ||
| newLeftNode = newSecondVariable.map(variable -> addProjection(context, leftNode, variable, secondArgument)).orElse(leftNode); | ||
| newRightNode = newFirstVariable.map(variable -> addProjection(context, rightNode, variable, firstArgument)).orElse(rightNode); | ||
| } | ||
| else { | ||
| return Result.empty(); | ||
| } | ||
|
|
||
| RowExpression newFirstArgument = mapToExpression(newFirstVariable, firstArgument); | ||
| RowExpression newSecondArgument = mapToExpression(newSecondVariable, secondArgument); | ||
|
|
@@ -441,7 +492,7 @@ else if (alignment < 0) { | |
| leftPartitionVariable = Optional.of(context.getVariableAllocator().newVariable(newFirstArgument.getSourceLocation(), "pid", INTEGER)); | ||
| rightPartitionVariable = Optional.of(context.getVariableAllocator().newVariable(newSecondArgument.getSourceLocation(), "pid", INTEGER)); | ||
|
|
||
| if (alignment > 0) { | ||
| if (firstArgumentOnLeft) { | ||
| newLeftNode = addPartitioningNodes(context, functionAndTypeManager, newLeftNode, leftPartitionVariable.get(), kdbTree.get(), newFirstArgument, Optional.empty()); | ||
| newRightNode = addPartitioningNodes(context, functionAndTypeManager, newRightNode, rightPartitionVariable.get(), kdbTree.get(), newSecondArgument, radius); | ||
| } | ||
|
|
@@ -601,43 +652,12 @@ private static QualifiedObjectName toQualifiedObjectName(String name, String cat | |
| throw new PrestoException(INVALID_SPATIAL_PARTITIONING, format("Invalid name: %s", name)); | ||
| } | ||
|
|
||
| private static int checkAlignment(JoinNode joinNode, Set<VariableReferenceExpression> maybeLeftVariables, Set<VariableReferenceExpression> maybeRightVariables) | ||
| { | ||
| List<VariableReferenceExpression> leftVariables = joinNode.getLeft().getOutputVariables(); | ||
| List<VariableReferenceExpression> rightVariables = joinNode.getRight().getOutputVariables(); | ||
|
|
||
| if (leftVariables.containsAll(maybeLeftVariables) | ||
| && containsNone(leftVariables, maybeRightVariables) | ||
| && rightVariables.containsAll(maybeRightVariables) | ||
| && containsNone(rightVariables, maybeLeftVariables)) { | ||
| return 1; | ||
| } | ||
|
|
||
| if (leftVariables.containsAll(maybeRightVariables) | ||
| && containsNone(leftVariables, maybeLeftVariables) | ||
| && rightVariables.containsAll(maybeLeftVariables) | ||
| && containsNone(rightVariables, maybeRightVariables)) { | ||
| return -1; | ||
| } | ||
|
|
||
| return 0; | ||
| } | ||
|
|
||
| private static RowExpression mapToExpression(Optional<VariableReferenceExpression> optionalVariable, RowExpression defaultExpression) | ||
| { | ||
| return optionalVariable.map(RowExpression.class::cast).orElse(defaultExpression); | ||
| } | ||
|
|
||
| private static Optional<VariableReferenceExpression> newGeometryVariable(Context context, RowExpression expression) | ||
| { | ||
| if (expression instanceof VariableReferenceExpression) { | ||
| return Optional.empty(); | ||
| } | ||
|
|
||
| return Optional.of(context.getVariableAllocator().newVariable(expression)); | ||
| } | ||
|
|
||
| private static Optional<VariableReferenceExpression> newRadiusVariable(Context context, RowExpression expression) | ||
| private static Optional<VariableReferenceExpression> newVariable(Context context, RowExpression expression) | ||
| { | ||
| if (expression instanceof VariableReferenceExpression) { | ||
| return Optional.empty(); | ||
|
|
@@ -657,7 +677,7 @@ private static PlanNode addProjection(Context context, PlanNode node, VariableRe | |
| return new ProjectNode(node.getSourceLocation(), context.getIdAllocator().getNextId(), node, projections.build(), LOCAL); | ||
| } | ||
|
|
||
| private static PlanNode addPartitioningNodes(Context context, FunctionAndTypeManager functionAndTypeManager, PlanNode node, VariableReferenceExpression partitionVariable, KdbTree kdbTree, RowExpression geometry, Optional<RowExpression> radius) | ||
| private static PlanNode addPartitioningNodes(Context context, FunctionAndTypeManager functionAndTypeManager, PlanNode node, VariableReferenceExpression partitionVariable, KdbTree kdbTree, RowExpression geometry, Optional<VariableReferenceExpression> radius) | ||
| { | ||
| Assignments.Builder projections = Assignments.builder(); | ||
| for (VariableReferenceExpression outputVariable : node.getOutputVariables()) { | ||
|
|
@@ -690,9 +710,4 @@ private static PlanNode addPartitioningNodes(Context context, FunctionAndTypeMan | |
| ImmutableMap.of(partitionsVariable, ImmutableList.of(partitionVariable)), | ||
| Optional.empty()); | ||
| } | ||
|
|
||
| private static boolean containsNone(Collection<VariableReferenceExpression> values, Collection<VariableReferenceExpression> testValues) | ||
| { | ||
| return values.stream().noneMatch(ImmutableSet.copyOf(testValues)::contains); | ||
| } | ||
| } | ||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
issue: The logic for handling radiusSide may miss edge cases.
The current logic does not account for the 'Both' case in radiusSide, which may result in incorrect handling when variables from both sides are present. Please update the logic to address this scenario.