Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@

import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Optional;
Expand Down Expand Up @@ -117,14 +116,14 @@
* <li>SELECT ... FROM a, b WHERE 15.5 > ST_Distance(b.geometry, a.geometry)</li>
* </ul>
* <p>
* Joins expressed via ST_Contains and ST_Intersects functions must match all of
* Joins expressed via ST_Contains and ST_Intersects functions must match all
* the following criteria:
* <p>
* - arguments of the spatial function are non-scalar expressions;
* - one of the arguments uses symbols from left side of the join, the other from right.
* <p>
* Joins expressed via the ST_Distance function must use the less-than or less-than-or-equals operator
* to compare ST_Distance value with a radius and must match all of the following criteria:
* to compare ST_Distance value with a radius and must match all the following criteria:
* <p>
* - arguments of the spatial function are non-scalar expressions;
* - one of the arguments uses symbols from left side of the join, the other from right;
Expand Down Expand Up @@ -161,6 +160,53 @@ public class ExtractSpatialJoins
private final SplitManager splitManager;
private final PageSourceManager pageSourceManager;

/**
 * Classifies which input(s) of a join supply the variables referenced by an expression,
 * as computed by {@code inferVariableSide}.
 */
private enum VariableSide
{
// No referenced variable is produced by either join input (includes expressions with no variables at all)
Neither,
// At least one referenced variable is produced by the left input, and none by the right input
Left,
// At least one referenced variable is produced by the right input, and none by the left input
Right,
// Referenced variables are produced by both the left and the right input
Both
}

/**
 * Determines which input(s) of {@code joinNode} produce the variables referenced by {@code expression}.
 *
 * @param expression the expression whose referenced variables are inspected
 * @param joinNode the join whose left and right output variables are compared against
 * @return {@code Neither} when the expression references no variable or none that either input outputs,
 *         {@code Left} or {@code Right} when only that input outputs a referenced variable,
 *         and {@code Both} when each input outputs at least one referenced variable
 */
private static VariableSide inferVariableSide(RowExpression expression, JoinNode joinNode)
{
    Set<VariableReferenceExpression> referenced = extractUnique(expression);
    if (referenced.isEmpty()) {
        // Nothing to attribute to either side
        return VariableSide.Neither;
    }

    List<VariableReferenceExpression> leftOutputs = joinNode.getLeft().getOutputVariables();
    List<VariableReferenceExpression> rightOutputs = joinNode.getRight().getOutputVariables();

    // A side "contains" the expression as soon as one of its outputs is referenced
    boolean usesLeft = leftOutputs.stream().anyMatch(referenced::contains);
    boolean usesRight = rightOutputs.stream().anyMatch(referenced::contains);

    if (usesLeft) {
        return usesRight ? VariableSide.Both : VariableSide.Left;
    }
    return usesRight ? VariableSide.Right : VariableSide.Neither;
}

public ExtractSpatialJoins(Metadata metadata, SplitManager splitManager, PageSourceManager pageSourceManager)
{
this.metadata = requireNonNull(metadata, "metadata is null");
Expand Down Expand Up @@ -305,18 +351,18 @@ private static Result tryCreateSpatialJoin(
PlanNode leftNode = joinNode.getLeft();
PlanNode rightNode = joinNode.getRight();

List<VariableReferenceExpression> leftVariables = leftNode.getOutputVariables();
List<VariableReferenceExpression> rightVariables = rightNode.getOutputVariables();

RowExpression radius;
Optional<VariableReferenceExpression> newRadiusVariable;
VariableReferenceExpression radiusVariable;
CallExpression newComparison;
if (spatialComparisonMetadata.getOperatorType().get() == OperatorType.LESS_THAN || spatialComparisonMetadata.getOperatorType().get() == OperatorType.LESS_THAN_OR_EQUAL) {
// ST_Distance(a, b) <= r
radius = spatialComparison.getArguments().get(1);
Set<VariableReferenceExpression> radiusVariables = extractUnique(radius);
if (radiusVariables.isEmpty() || (rightVariables.containsAll(radiusVariables) && containsNone(leftVariables, radiusVariables))) {
newRadiusVariable = newRadiusVariable(context, radius);
VariableSide radiusSide = inferVariableSide(radius, joinNode);
if (radiusSide == VariableSide.Neither || radiusSide == VariableSide.Right) {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

issue: The logic for handling radiusSide may miss edge cases.

The current logic does not account for the 'Both' case in radiusSide, which may result in incorrect handling when variables from both sides are present. Please update the logic to address this scenario.

newRadiusVariable = newVariable(context, radius);
// If newRadiusVariable is empty, radius is VRE
radiusVariable = newRadiusVariable.orElseGet(() -> (VariableReferenceExpression) radius);
Comment on lines +356 to +365
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

issue (bug_risk): Casting radius to VariableReferenceExpression may be unsafe.

If newRadiusVariable is empty and radius is not a VariableReferenceExpression, a ClassCastException may occur. Add a type check or assertion to prevent this.

newComparison = new CallExpression(
spatialComparison.getSourceLocation(),
spatialComparison.getDisplayName(),
Expand All @@ -331,9 +377,11 @@ private static Result tryCreateSpatialJoin(
else {
// r >= ST_Distance(a, b)
radius = spatialComparison.getArguments().get(0);
Set<VariableReferenceExpression> radiusVariables = extractUnique(radius);
if (radiusVariables.isEmpty() || (rightVariables.containsAll(radiusVariables) && containsNone(leftVariables, radiusVariables))) {
newRadiusVariable = newRadiusVariable(context, radius);
VariableSide radiusSide = inferVariableSide(radius, joinNode);
if (radiusSide == VariableSide.Neither || radiusSide == VariableSide.Right) {
newRadiusVariable = newVariable(context, radius);
// If newRadiusVariable is empty, radius is VRE
radiusVariable = newRadiusVariable.orElseGet(() -> (VariableReferenceExpression) radius);
OperatorType flippedOperatorType = flip(spatialComparisonMetadata.getOperatorType().get());
FunctionHandle flippedHandle = getFlippedFunctionHandle(spatialComparison, metadata.getFunctionAndTypeManager());
newComparison = new CallExpression(
Expand Down Expand Up @@ -365,7 +413,7 @@ private static Result tryCreateSpatialJoin(
joinNode.getDistributionType(),
joinNode.getDynamicFilters());

return tryCreateSpatialJoin(context, newJoinNode, newFilter, nodeId, outputVariables, (CallExpression) newComparison.getArguments().get(0), Optional.of(newComparison.getArguments().get(1)), metadata, splitManager, pageSourceManager);
return tryCreateSpatialJoin(context, newJoinNode, newFilter, nodeId, outputVariables, (CallExpression) newComparison.getArguments().get(0), Optional.of(radiusVariable), metadata, splitManager, pageSourceManager);
}

private static Result tryCreateSpatialJoin(
Expand All @@ -375,7 +423,7 @@ private static Result tryCreateSpatialJoin(
PlanNodeId nodeId,
List<VariableReferenceExpression> outputVariables,
CallExpression spatialFunction,
Optional<RowExpression> radius,
Optional<VariableReferenceExpression> radius,
Metadata metadata,
SplitManager splitManager,
PageSourceManager pageSourceManager)
Expand All @@ -393,37 +441,40 @@ private static Result tryCreateSpatialJoin(
return Result.empty();
}

Set<VariableReferenceExpression> firstVariables = extractUnique(firstArgument);
Set<VariableReferenceExpression> secondVariables = extractUnique(secondArgument);

if (firstVariables.isEmpty() || secondVariables.isEmpty()) {
VariableSide firstSide = inferVariableSide(firstArgument, joinNode);
VariableSide secondSide = inferVariableSide(secondArgument, joinNode);
boolean firstArgumentOnLeft;
if (firstSide == VariableSide.Left && secondSide == VariableSide.Right) {
firstArgumentOnLeft = true;
}
else if (firstSide == VariableSide.Right && secondSide == VariableSide.Left) {
firstArgumentOnLeft = false;
}
else {
// Spatial joins require that each argument come from exactly one side, and that the two arguments come from opposite sides
return Result.empty();
}

// If either firstArgument or secondArgument is not a
// VariableReferenceExpression, will replace the left/right join node
// with a projection that adds the argument as a variable.
Optional<VariableReferenceExpression> newFirstVariable = newGeometryVariable(context, firstArgument);
Optional<VariableReferenceExpression> newSecondVariable = newGeometryVariable(context, secondArgument);
Optional<VariableReferenceExpression> newFirstVariable = newVariable(context, firstArgument);
Optional<VariableReferenceExpression> newSecondVariable = newVariable(context, secondArgument);

PlanNode leftNode = joinNode.getLeft();
PlanNode rightNode = joinNode.getRight();
PlanNode newLeftNode;
PlanNode newRightNode;

// Check if the order of arguments of the spatial function matches the order of join sides
int alignment = checkAlignment(joinNode, firstVariables, secondVariables);
if (alignment > 0) {
if (firstArgumentOnLeft) {
newLeftNode = newFirstVariable.map(variable -> addProjection(context, leftNode, variable, firstArgument)).orElse(leftNode);
newRightNode = newSecondVariable.map(variable -> addProjection(context, rightNode, variable, secondArgument)).orElse(rightNode);
}
else if (alignment < 0) {
else {
newLeftNode = newSecondVariable.map(variable -> addProjection(context, leftNode, variable, secondArgument)).orElse(leftNode);
newRightNode = newFirstVariable.map(variable -> addProjection(context, rightNode, variable, firstArgument)).orElse(rightNode);
}
else {
return Result.empty();
}

RowExpression newFirstArgument = mapToExpression(newFirstVariable, firstArgument);
RowExpression newSecondArgument = mapToExpression(newSecondVariable, secondArgument);
Expand All @@ -441,7 +492,7 @@ else if (alignment < 0) {
leftPartitionVariable = Optional.of(context.getVariableAllocator().newVariable(newFirstArgument.getSourceLocation(), "pid", INTEGER));
rightPartitionVariable = Optional.of(context.getVariableAllocator().newVariable(newSecondArgument.getSourceLocation(), "pid", INTEGER));

if (alignment > 0) {
if (firstArgumentOnLeft) {
newLeftNode = addPartitioningNodes(context, functionAndTypeManager, newLeftNode, leftPartitionVariable.get(), kdbTree.get(), newFirstArgument, Optional.empty());
newRightNode = addPartitioningNodes(context, functionAndTypeManager, newRightNode, rightPartitionVariable.get(), kdbTree.get(), newSecondArgument, radius);
}
Expand Down Expand Up @@ -601,43 +652,12 @@ private static QualifiedObjectName toQualifiedObjectName(String name, String cat
throw new PrestoException(INVALID_SPATIAL_PARTITIONING, format("Invalid name: %s", name));
}

/**
 * Checks whether two variable sets line up with the two inputs of {@code joinNode}.
 *
 * @param joinNode the join whose left and right output variables are consulted
 * @param maybeLeftVariables variables expected to come from the left input
 * @param maybeRightVariables variables expected to come from the right input
 * @return {@code 1} when the sets match the left/right inputs in the given order,
 *         {@code -1} when they match with the sides swapped, and {@code 0} otherwise
 */
private static int checkAlignment(JoinNode joinNode, Set<VariableReferenceExpression> maybeLeftVariables, Set<VariableReferenceExpression> maybeRightVariables)
{
    List<VariableReferenceExpression> leftOutputs = joinNode.getLeft().getOutputVariables();
    List<VariableReferenceExpression> rightOutputs = joinNode.getRight().getOutputVariables();

    // Each set must be fully covered by its own side and must not touch the other side
    boolean matchesGivenOrder = leftOutputs.containsAll(maybeLeftVariables)
            && containsNone(leftOutputs, maybeRightVariables)
            && rightOutputs.containsAll(maybeRightVariables)
            && containsNone(rightOutputs, maybeLeftVariables);
    if (matchesGivenOrder) {
        return 1;
    }

    // Same test with the roles of the two sets exchanged
    boolean matchesSwappedOrder = leftOutputs.containsAll(maybeRightVariables)
            && containsNone(leftOutputs, maybeLeftVariables)
            && rightOutputs.containsAll(maybeLeftVariables)
            && containsNone(rightOutputs, maybeRightVariables);
    return matchesSwappedOrder ? -1 : 0;
}

/**
 * Returns the variable held by {@code optionalVariable} as a {@code RowExpression},
 * or {@code defaultExpression} when the optional is empty.
 */
private static RowExpression mapToExpression(Optional<VariableReferenceExpression> optionalVariable, RowExpression defaultExpression)
{
    if (optionalVariable.isPresent()) {
        return optionalVariable.get();
    }
    return defaultExpression;
}

/**
 * Allocates a new variable for {@code expression} unless it already is a variable reference.
 *
 * @return the newly allocated variable, or empty when {@code expression} is a {@code VariableReferenceExpression}
 */
private static Optional<VariableReferenceExpression> newGeometryVariable(Context context, RowExpression expression)
{
    if (!(expression instanceof VariableReferenceExpression)) {
        return Optional.of(context.getVariableAllocator().newVariable(expression));
    }

    // Already a variable reference: no new variable is allocated
    return Optional.empty();
}

private static Optional<VariableReferenceExpression> newRadiusVariable(Context context, RowExpression expression)
private static Optional<VariableReferenceExpression> newVariable(Context context, RowExpression expression)
{
if (expression instanceof VariableReferenceExpression) {
return Optional.empty();
Expand All @@ -657,7 +677,7 @@ private static PlanNode addProjection(Context context, PlanNode node, VariableRe
return new ProjectNode(node.getSourceLocation(), context.getIdAllocator().getNextId(), node, projections.build(), LOCAL);
}

private static PlanNode addPartitioningNodes(Context context, FunctionAndTypeManager functionAndTypeManager, PlanNode node, VariableReferenceExpression partitionVariable, KdbTree kdbTree, RowExpression geometry, Optional<RowExpression> radius)
private static PlanNode addPartitioningNodes(Context context, FunctionAndTypeManager functionAndTypeManager, PlanNode node, VariableReferenceExpression partitionVariable, KdbTree kdbTree, RowExpression geometry, Optional<VariableReferenceExpression> radius)
{
Assignments.Builder projections = Assignments.builder();
for (VariableReferenceExpression outputVariable : node.getOutputVariables()) {
Expand Down Expand Up @@ -690,9 +710,4 @@ private static PlanNode addPartitioningNodes(Context context, FunctionAndTypeMan
ImmutableMap.of(partitionsVariable, ImmutableList.of(partitionVariable)),
Optional.empty());
}

/**
 * Tells whether {@code values} and {@code testValues} share no element.
 *
 * @return {@code true} when no element of {@code values} is present in {@code testValues}
 */
private static boolean containsNone(Collection<VariableReferenceExpression> values, Collection<VariableReferenceExpression> testValues)
{
    // Build the lookup set once, before scanning
    ImmutableSet<VariableReferenceExpression> lookup = ImmutableSet.copyOf(testValues);
    for (VariableReferenceExpression candidate : values) {
        if (lookup.contains(candidate)) {
            return false;
        }
    }
    return true;
}
}
Loading