diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ExtractSpatialJoins.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ExtractSpatialJoins.java index 93fa95932846e..db28daede971c 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ExtractSpatialJoins.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ExtractSpatialJoins.java @@ -460,6 +460,8 @@ else if (firstSide == VariableSide.Right && secondSide == VariableSide.Left) { // with a projection that adds the argument as a variable. Optional newFirstVariable = newVariable(context, firstArgument); Optional newSecondVariable = newVariable(context, secondArgument); + VariableReferenceExpression leftGeometryVariable; + VariableReferenceExpression rightGeometryVariable; PlanNode leftNode = joinNode.getLeft(); PlanNode rightNode = joinNode.getRight(); @@ -470,10 +472,16 @@ else if (firstSide == VariableSide.Right && secondSide == VariableSide.Left) { if (firstArgumentOnLeft) { newLeftNode = newFirstVariable.map(variable -> addProjection(context, leftNode, variable, firstArgument)).orElse(leftNode); newRightNode = newSecondVariable.map(variable -> addProjection(context, rightNode, variable, secondArgument)).orElse(rightNode); + // If new variables are empty, argument is VariableReferenceExpression + leftGeometryVariable = newFirstVariable.orElseGet(() -> (VariableReferenceExpression) firstArgument); + rightGeometryVariable = newSecondVariable.orElseGet(() -> (VariableReferenceExpression) secondArgument); } else { newLeftNode = newSecondVariable.map(variable -> addProjection(context, leftNode, variable, secondArgument)).orElse(leftNode); newRightNode = newFirstVariable.map(variable -> addProjection(context, rightNode, variable, firstArgument)).orElse(rightNode); + // If new variables are empty, argument is VariableReferenceExpression + leftGeometryVariable = newSecondVariable.orElseGet(() -> (VariableReferenceExpression) secondArgument); + rightGeometryVariable = newFirstVariable.orElseGet(() -> (VariableReferenceExpression) firstArgument); } RowExpression newFirstArgument = mapToExpression(newFirstVariable, firstArgument); @@ -512,6 +520,9 @@ else if (firstSide == VariableSide.Right && secondSide == VariableSide.Left) { newLeftNode, newRightNode, outputVariables, + leftGeometryVariable, + rightGeometryVariable, + radius, newFilter, leftPartitionVariable, rightPartitionVariable, diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RowExpressionRewriteRuleSet.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RowExpressionRewriteRuleSet.java index e969326ecfff5..9938f9ea4c740 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RowExpressionRewriteRuleSet.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RowExpressionRewriteRuleSet.java @@ -218,6 +218,9 @@ public Result apply(SpatialJoinNode spatialJoinNode, Captures captures, Context spatialJoinNode.getLeft(), spatialJoinNode.getRight(), spatialJoinNode.getOutputVariables(), + spatialJoinNode.getProbeGeometryVariable(), + spatialJoinNode.getBuildGeometryVariable(), + spatialJoinNode.getRadiusVariable(), rewritten, spatialJoinNode.getLeftPartitionVariable(), spatialJoinNode.getRightPartitionVariable(), diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/PredicatePushDown.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/PredicatePushDown.java index 81d7eebfedb5e..88bda8083ab12 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/PredicatePushDown.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/PredicatePushDown.java @@ -914,6 +914,9 @@ public PlanNode visitSpatialJoin(SpatialJoinNode node, RewriteContext cont PlanNode left = context.rewrite(node.getLeft()); PlanNode right = context.rewrite(node.getRight()); - return new SpatialJoinNode(node.getSourceLocation(), node.getId(), node.getType(), left, right, canonicalizeAndDistinct(node.getOutputVariables()), canonicalize(node.getFilter()), canonicalize(node.getLeftPartitionVariable()), canonicalize(node.getRightPartitionVariable()), node.getKdbTree()); + return new SpatialJoinNode( + node.getSourceLocation(), + node.getId(), + node.getType(), + left, + right, + canonicalizeAndDistinct(node.getOutputVariables()), + canonicalize(node.getProbeGeometryVariable()), + canonicalize(node.getBuildGeometryVariable()), + canonicalize(node.getRadiusVariable()), + canonicalize(node.getFilter()), + canonicalize(node.getLeftPartitionVariable()), + canonicalize(node.getRightPartitionVariable()), + node.getKdbTree()); } @Override diff --git a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxQueryPlan.cpp b/presto-native-execution/presto_cpp/main/types/PrestoToVeloxQueryPlan.cpp index b314bba7484c0..1f22f435a8612 100644 --- a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxQueryPlan.cpp +++ b/presto-native-execution/presto_cpp/main/types/PrestoToVeloxQueryPlan.cpp @@ -1284,11 +1284,18 @@ core::PlanNodePtr VeloxQueryPlanConverterBase::toVeloxQueryPlan( const std::shared_ptr& tableWriteInfo, const protocol::TaskId& taskId) { auto joinType = toJoinType(node->type); + std::optional radiusVariable = std::nullopt; + if (node->radiusVariable) { + radiusVariable = exprConverter_.toVeloxExpr(*node->radiusVariable); + } return std::make_shared( node->id, joinType, exprConverter_.toVeloxExpr(node->filter), + exprConverter_.toVeloxExpr(node->probeGeometryVariable), + exprConverter_.toVeloxExpr(node->buildGeometryVariable), + radiusVariable, toVeloxQueryPlan(node->left, tableWriteInfo, taskId), toVeloxQueryPlan(node->right, tableWriteInfo, taskId), toRowType(node->outputVariables, typeParser_)); diff --git a/presto-native-execution/presto_cpp/presto_protocol/connector/hive/presto_protocol_hive.cpp b/presto-native-execution/presto_cpp/presto_protocol/connector/hive/presto_protocol_hive.cpp index 8011da82eee47..24e90b78e2e8c 100644 --- a/presto-native-execution/presto_cpp/presto_protocol/connector/hive/presto_protocol_hive.cpp +++ b/presto-native-execution/presto_cpp/presto_protocol/connector/hive/presto_protocol_hive.cpp @@ -370,9 +370,10 @@ namespace facebook::presto::protocol::hive { // NOLINTNEXTLINE: cppcoreguidelines-avoid-c-arrays static const std::pair - BucketFunctionType_enum_table[] = { // NOLINT: cert-err58-cpp - {BucketFunctionType::HIVE_COMPATIBLE, "HIVE_COMPATIBLE"}, - {BucketFunctionType::PRESTO_NATIVE, "PRESTO_NATIVE"}}; + BucketFunctionType_enum_table[] = + { // NOLINT: cert-err58-cpp + {BucketFunctionType::HIVE_COMPATIBLE, "HIVE_COMPATIBLE"}, + {BucketFunctionType::PRESTO_NATIVE, "PRESTO_NATIVE"}}; void to_json(json& j, const BucketFunctionType& e) { static_assert( std::is_enum::value, @@ -598,12 +599,13 @@ namespace facebook::presto::protocol::hive { // NOLINTNEXTLINE: cppcoreguidelines-avoid-c-arrays static const std::pair - HiveCompressionCodec_enum_table[] = { // NOLINT: cert-err58-cpp - {HiveCompressionCodec::NONE, "NONE"}, - {HiveCompressionCodec::SNAPPY, "SNAPPY"}, - {HiveCompressionCodec::GZIP, "GZIP"}, - {HiveCompressionCodec::LZ4, "LZ4"}, - {HiveCompressionCodec::ZSTD, "ZSTD"}}; + HiveCompressionCodec_enum_table[] = + { // NOLINT: cert-err58-cpp + {HiveCompressionCodec::NONE, "NONE"}, + {HiveCompressionCodec::SNAPPY, "SNAPPY"}, + {HiveCompressionCodec::GZIP, "GZIP"}, + {HiveCompressionCodec::LZ4, "LZ4"}, + {HiveCompressionCodec::ZSTD, "ZSTD"}}; void to_json(json& j, const HiveCompressionCodec& e) { static_assert( std::is_enum::value, diff --git a/presto-native-execution/presto_cpp/presto_protocol/connector/iceberg/presto_protocol_iceberg.cpp b/presto-native-execution/presto_cpp/presto_protocol/connector/iceberg/presto_protocol_iceberg.cpp index 3229da2e88d07..6d03a5ce52b12 100644 --- a/presto-native-execution/presto_cpp/presto_protocol/connector/iceberg/presto_protocol_iceberg.cpp +++ b/presto-native-execution/presto_cpp/presto_protocol/connector/iceberg/presto_protocol_iceberg.cpp @@ -25,11 +25,12 @@ namespace facebook::presto::protocol::iceberg { // NOLINTNEXTLINE: cppcoreguidelines-avoid-c-arrays static const std::pair - ChangelogOperation_enum_table[] = { // NOLINT: cert-err58-cpp - {ChangelogOperation::INSERT, "INSERT"}, - {ChangelogOperation::DELETE, "DELETE"}, - {ChangelogOperation::UPDATE_BEFORE, "UPDATE_BEFORE"}, - {ChangelogOperation::UPDATE_AFTER, "UPDATE_AFTER"}}; + ChangelogOperation_enum_table[] = + { // NOLINT: cert-err58-cpp + {ChangelogOperation::INSERT, "INSERT"}, + {ChangelogOperation::DELETE, "DELETE"}, + {ChangelogOperation::UPDATE_BEFORE, "UPDATE_BEFORE"}, + {ChangelogOperation::UPDATE_AFTER, "UPDATE_AFTER"}}; void to_json(json& j, const ChangelogOperation& e) { static_assert( std::is_enum::value, @@ -508,14 +509,15 @@ namespace facebook::presto::protocol::iceberg { // NOLINTNEXTLINE: cppcoreguidelines-avoid-c-arrays static const std::pair - PartitionTransformType_enum_table[] = { // NOLINT: cert-err58-cpp - {PartitionTransformType::IDENTITY, "IDENTITY"}, - {PartitionTransformType::YEAR, "YEAR"}, - {PartitionTransformType::MONTH, "MONTH"}, - {PartitionTransformType::DAY, "DAY"}, - {PartitionTransformType::HOUR, "HOUR"}, - {PartitionTransformType::BUCKET, "BUCKET"}, - {PartitionTransformType::TRUNCATE, "TRUNCATE"}}; + PartitionTransformType_enum_table[] = + { // NOLINT: cert-err58-cpp + {PartitionTransformType::IDENTITY, "IDENTITY"}, + {PartitionTransformType::YEAR, "YEAR"}, + {PartitionTransformType::MONTH, "MONTH"}, + {PartitionTransformType::DAY, "DAY"}, + {PartitionTransformType::HOUR, "HOUR"}, + {PartitionTransformType::BUCKET, "BUCKET"}, + {PartitionTransformType::TRUNCATE, "TRUNCATE"}}; void to_json(json& j, const PartitionTransformType& e) { static_assert( std::is_enum::value, diff --git a/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.cpp b/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.cpp index 50d7de6024075..cc1647256b9e5 100644 --- a/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.cpp +++ b/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.cpp @@ -36,10 +36,11 @@ namespace facebook::presto::protocol { // NOLINTNEXTLINE: cppcoreguidelines-avoid-c-arrays static const std::pair - NodeSelectionStrategy_enum_table[] = { // NOLINT: cert-err58-cpp - {NodeSelectionStrategy::HARD_AFFINITY, "HARD_AFFINITY"}, - {NodeSelectionStrategy::SOFT_AFFINITY, "SOFT_AFFINITY"}, - {NodeSelectionStrategy::NO_PREFERENCE, "NO_PREFERENCE"}}; + NodeSelectionStrategy_enum_table[] = + { // NOLINT: cert-err58-cpp + {NodeSelectionStrategy::HARD_AFFINITY, "HARD_AFFINITY"}, + {NodeSelectionStrategy::SOFT_AFFINITY, "SOFT_AFFINITY"}, + {NodeSelectionStrategy::NO_PREFERENCE, "NO_PREFERENCE"}}; void to_json(json& j, const NodeSelectionStrategy& e) { static_assert( std::is_enum::value, @@ -558,11 +559,12 @@ namespace facebook::presto::protocol { // NOLINTNEXTLINE: cppcoreguidelines-avoid-c-arrays static const std::pair - AggregationNodeStep_enum_table[] = { // NOLINT: cert-err58-cpp - {AggregationNodeStep::PARTIAL, "PARTIAL"}, - {AggregationNodeStep::FINAL, "FINAL"}, - {AggregationNodeStep::INTERMEDIATE, "INTERMEDIATE"}, - {AggregationNodeStep::SINGLE, "SINGLE"}}; + AggregationNodeStep_enum_table[] = + { // NOLINT: cert-err58-cpp + {AggregationNodeStep::PARTIAL, "PARTIAL"}, + {AggregationNodeStep::FINAL, "FINAL"}, + {AggregationNodeStep::INTERMEDIATE, "INTERMEDIATE"}, + {AggregationNodeStep::SINGLE, "SINGLE"}}; void to_json(json& j, const AggregationNodeStep& e) { static_assert( std::is_enum::value, @@ -2808,10 +2810,11 @@ namespace facebook::presto::protocol { // NOLINTNEXTLINE: cppcoreguidelines-avoid-c-arrays static const std::pair - BuiltInFunctionKind_enum_table[] = { // NOLINT: cert-err58-cpp - {BuiltInFunctionKind::ENGINE, "ENGINE"}, - {BuiltInFunctionKind::PLUGIN, "PLUGIN"}, - {BuiltInFunctionKind::WORKER, "WORKER"}}; + BuiltInFunctionKind_enum_table[] = + { // NOLINT: cert-err58-cpp + {BuiltInFunctionKind::ENGINE, "ENGINE"}, + {BuiltInFunctionKind::PLUGIN, "PLUGIN"}, + {BuiltInFunctionKind::WORKER, "WORKER"}}; void to_json(json& j, const BuiltInFunctionKind& e) { static_assert( std::is_enum::value, @@ -6144,9 +6147,10 @@ namespace facebook::presto::protocol { // NOLINTNEXTLINE: cppcoreguidelines-avoid-c-arrays static const std::pair - JoinDistributionType_enum_table[] = { // NOLINT: cert-err58-cpp - {JoinDistributionType::PARTITIONED, "PARTITIONED"}, - {JoinDistributionType::REPLICATED, "REPLICATED"}}; + JoinDistributionType_enum_table[] = + { // NOLINT: cert-err58-cpp + {JoinDistributionType::PARTITIONED, "PARTITIONED"}, + {JoinDistributionType::REPLICATED, "REPLICATED"}}; void to_json(json& j, const JoinDistributionType& e) { static_assert( std::is_enum::value, @@ -8211,14 +8215,17 @@ namespace facebook::presto::protocol { // NOLINTNEXTLINE: cppcoreguidelines-avoid-c-arrays static const std::pair - StageExecutionStrategy_enum_table[] = { // NOLINT: cert-err58-cpp - {StageExecutionStrategy::UNGROUPED_EXECUTION, "UNGROUPED_EXECUTION"}, - {StageExecutionStrategy::FIXED_LIFESPAN_SCHEDULE_GROUPED_EXECUTION, - "FIXED_LIFESPAN_SCHEDULE_GROUPED_EXECUTION"}, - {StageExecutionStrategy::DYNAMIC_LIFESPAN_SCHEDULE_GROUPED_EXECUTION, - "DYNAMIC_LIFESPAN_SCHEDULE_GROUPED_EXECUTION"}, - {StageExecutionStrategy::RECOVERABLE_GROUPED_EXECUTION, - "RECOVERABLE_GROUPED_EXECUTION"}}; + StageExecutionStrategy_enum_table[] = + { // NOLINT: cert-err58-cpp + {StageExecutionStrategy::UNGROUPED_EXECUTION, + "UNGROUPED_EXECUTION"}, + {StageExecutionStrategy::FIXED_LIFESPAN_SCHEDULE_GROUPED_EXECUTION, + "FIXED_LIFESPAN_SCHEDULE_GROUPED_EXECUTION"}, + {StageExecutionStrategy:: + DYNAMIC_LIFESPAN_SCHEDULE_GROUPED_EXECUTION, + "DYNAMIC_LIFESPAN_SCHEDULE_GROUPED_EXECUTION"}, + {StageExecutionStrategy::RECOVERABLE_GROUPED_EXECUTION, + "RECOVERABLE_GROUPED_EXECUTION"}}; void to_json(json& j, const StageExecutionStrategy& e) { static_assert( std::is_enum::value, @@ -9606,6 +9613,27 @@ void to_json(json& j, const SpatialJoinNode& p) { "SpatialJoinNode", "List", "outputVariables"); + to_json_key( + j, + "probeGeometryVariable", + p.probeGeometryVariable, + "SpatialJoinNode", + "VariableReferenceExpression", + "probeGeometryVariable"); + to_json_key( + j, + "buildGeometryVariable", + p.buildGeometryVariable, + "SpatialJoinNode", + "VariableReferenceExpression", + "buildGeometryVariable"); + to_json_key( + j, + "radiusVariable", + p.radiusVariable, + "SpatialJoinNode", + "VariableReferenceExpression", + "radiusVariable"); to_json_key( j, "filter", p.filter, "SpatialJoinNode", "RowExpression", "filter"); to_json_key( @@ -9639,6 +9667,27 @@ void from_json(const json& j, SpatialJoinNode& p) { "SpatialJoinNode", "List", "outputVariables"); + from_json_key( + j, + "probeGeometryVariable", + p.probeGeometryVariable, + "SpatialJoinNode", + "VariableReferenceExpression", + "probeGeometryVariable"); + from_json_key( + j, + "buildGeometryVariable", + p.buildGeometryVariable, + "SpatialJoinNode", + "VariableReferenceExpression", + "buildGeometryVariable"); + from_json_key( + j, + "radiusVariable", + p.radiusVariable, + "SpatialJoinNode", + "VariableReferenceExpression", + "radiusVariable"); from_json_key( j, "filter", p.filter, "SpatialJoinNode", "RowExpression", "filter"); from_json_key( @@ -9888,12 +9937,13 @@ namespace facebook::presto::protocol { // NOLINTNEXTLINE: cppcoreguidelines-avoid-c-arrays static const std::pair - SystemPartitionFunction_enum_table[] = { // NOLINT: cert-err58-cpp - {SystemPartitionFunction::SINGLE, "SINGLE"}, - {SystemPartitionFunction::HASH, "HASH"}, - {SystemPartitionFunction::ROUND_ROBIN, "ROUND_ROBIN"}, - {SystemPartitionFunction::BROADCAST, "BROADCAST"}, - {SystemPartitionFunction::UNKNOWN, "UNKNOWN"}}; + SystemPartitionFunction_enum_table[] = + { // NOLINT: cert-err58-cpp + {SystemPartitionFunction::SINGLE, "SINGLE"}, + {SystemPartitionFunction::HASH, "HASH"}, + {SystemPartitionFunction::ROUND_ROBIN, "ROUND_ROBIN"}, + {SystemPartitionFunction::BROADCAST, "BROADCAST"}, + {SystemPartitionFunction::UNKNOWN, "UNKNOWN"}}; void to_json(json& j, const SystemPartitionFunction& e) { static_assert( std::is_enum::value, @@ -9930,13 +9980,14 @@ namespace facebook::presto::protocol { // NOLINTNEXTLINE: cppcoreguidelines-avoid-c-arrays static const std::pair - SystemPartitioning_enum_table[] = { // NOLINT: cert-err58-cpp - {SystemPartitioning::SINGLE, "SINGLE"}, - {SystemPartitioning::FIXED, "FIXED"}, - {SystemPartitioning::SOURCE, "SOURCE"}, - {SystemPartitioning::SCALED, "SCALED"}, - {SystemPartitioning::COORDINATOR_ONLY, "COORDINATOR_ONLY"}, - {SystemPartitioning::ARBITRARY, "ARBITRARY"}}; + SystemPartitioning_enum_table[] = + { // NOLINT: cert-err58-cpp + {SystemPartitioning::SINGLE, "SINGLE"}, + {SystemPartitioning::FIXED, "FIXED"}, + {SystemPartitioning::SOURCE, "SOURCE"}, + {SystemPartitioning::SCALED, "SCALED"}, + {SystemPartitioning::COORDINATOR_ONLY, "COORDINATOR_ONLY"}, + {SystemPartitioning::ARBITRARY, "ARBITRARY"}}; void to_json(json& j, const SystemPartitioning& e) { static_assert( std::is_enum::value, diff --git a/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.h b/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.h index dae1c63b907d5..2b1e4eb66c14e 100644 --- a/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.h +++ b/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.h @@ -69,21 +69,21 @@ extern const char* const PRESTO_ABORT_TASK_URL_PARAM; class Exception : public std::runtime_error { public: explicit Exception(const std::string& message) - : std::runtime_error(message){}; + : std::runtime_error(message) {}; }; class TypeError : public Exception { public: - explicit TypeError(const std::string& message) : Exception(message){}; + explicit TypeError(const std::string& message) : Exception(message) {}; }; class OutOfRange : public Exception { public: - explicit OutOfRange(const std::string& message) : Exception(message){}; + explicit OutOfRange(const std::string& message) : Exception(message) {}; }; class ParseError : public Exception { public: - explicit ParseError(const std::string& message) : Exception(message){}; + explicit ParseError(const std::string& message) : Exception(message) {}; }; using String = std::string; @@ -2209,6 +2209,9 @@ struct SpatialJoinNode : public PlanNode { std::shared_ptr left = {}; std::shared_ptr right = {}; List outputVariables = {}; + VariableReferenceExpression probeGeometryVariable = {}; + VariableReferenceExpression buildGeometryVariable = {}; + std::shared_ptr radiusVariable = {}; std::shared_ptr filter = {}; std::shared_ptr leftPartitionVariable = {}; std::shared_ptr rightPartitionVariable = {}; diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/plan/SpatialJoinNode.java b/presto-spi/src/main/java/com/facebook/presto/spi/plan/SpatialJoinNode.java index 3c0a6d9e46394..a69c21c87267b 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/plan/SpatialJoinNode.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/plan/SpatialJoinNode.java @@ -68,6 +68,9 @@ public static SpatialJoinType fromJoinNodeType(JoinType joinNodeType) private final PlanNode left; private final PlanNode right; private final List outputVariables; + private final VariableReferenceExpression probeGeometryVariable; + private final VariableReferenceExpression buildGeometryVariable; + private final Optional radiusVariable; private final RowExpression filter; private final Optional leftPartitionVariable; private final Optional rightPartitionVariable; @@ -88,12 +91,29 @@ public SpatialJoinNode( @JsonProperty("left") PlanNode left, @JsonProperty("right") PlanNode right, @JsonProperty("outputVariables") List outputVariables, + @JsonProperty("probeGeometryVariable") VariableReferenceExpression probeGeometryVariable, + @JsonProperty("buildGeometryVariable") VariableReferenceExpression buildGeometryVariable, + @JsonProperty("radiusVariable") Optional radiusVariable, @JsonProperty("filter") RowExpression filter, @JsonProperty("leftPartitionVariable") Optional leftPartitionVariable, @JsonProperty("rightPartitionVariable") Optional rightPartitionVariable, @JsonProperty("kdbTree") Optional kdbTree) { - this(sourceLocation, id, Optional.empty(), type, left, right, outputVariables, filter, leftPartitionVariable, rightPartitionVariable, kdbTree); + this( + sourceLocation, + id, + Optional.empty(), + type, + left, + right, + outputVariables, + probeGeometryVariable, + buildGeometryVariable, + radiusVariable, + filter, + leftPartitionVariable, + rightPartitionVariable, + kdbTree); } public SpatialJoinNode( @@ -104,6 +124,9 @@ public SpatialJoinNode( PlanNode left, PlanNode right, List outputVariables, + VariableReferenceExpression probeGeometryVariable, + VariableReferenceExpression buildGeometryVariable, + Optional radiusVariable, RowExpression filter, Optional leftPartitionVariable, Optional rightPartitionVariable, @@ -116,10 +139,17 @@ public SpatialJoinNode( this.right = requireNonNull(right, "right is null"); this.outputVariables = unmodifiableList(new ArrayList<>(requireNonNull(outputVariables, "outputVariables is null"))); this.filter = requireNonNull(filter, "filter is null"); + this.probeGeometryVariable = requireNonNull(probeGeometryVariable, "probeGeometryVariable is null"); + this.buildGeometryVariable = requireNonNull(buildGeometryVariable, "buildGeometryVariable is null"); + this.radiusVariable = requireNonNull(radiusVariable, "radiusVariable is null"); this.leftPartitionVariable = requireNonNull(leftPartitionVariable, "leftPartitionVariable is null"); this.rightPartitionVariable = requireNonNull(rightPartitionVariable, "rightPartitionVariable is null"); this.kdbTree = requireNonNull(kdbTree, "kdbTree is null"); + checkArgument(left.getOutputVariables().contains(probeGeometryVariable), "Left join input does not contain probe geometry variable"); + checkArgument(right.getOutputVariables().contains(buildGeometryVariable), "Right join input does not contain build geometry variable"); + radiusVariable.ifPresent(radius -> checkArgument(right.getOutputVariables().contains(radius), "Right join input does not contain radius variable")); + Set inputSymbols = new LinkedHashSet<>(); inputSymbols.addAll(left.getOutputVariables()); inputSymbols.addAll(right.getOutputVariables()); @@ -163,6 +193,24 @@ public RowExpression getFilter() return filter; } + @JsonProperty + public VariableReferenceExpression getProbeGeometryVariable() + { + return probeGeometryVariable; + } + + @JsonProperty + public VariableReferenceExpression getBuildGeometryVariable() + { + return buildGeometryVariable; + } + + @JsonProperty + public Optional getRadiusVariable() + { + return radiusVariable; + } + @JsonProperty public Optional getLeftPartitionVariable() { @@ -213,12 +261,40 @@ public R accept(PlanVisitor visitor, C context) public PlanNode replaceChildren(List newChildren) { checkArgument(newChildren.size() == 2, "expected newChildren to contain 2 nodes"); - return new SpatialJoinNode(getSourceLocation(), getId(), getStatsEquivalentPlanNode(), type, newChildren.get(0), newChildren.get(1), outputVariables, filter, leftPartitionVariable, rightPartitionVariable, kdbTree); + return new SpatialJoinNode( + getSourceLocation(), + getId(), + getStatsEquivalentPlanNode(), + type, + newChildren.get(0), + newChildren.get(1), + outputVariables, + probeGeometryVariable, + buildGeometryVariable, + radiusVariable, + filter, + leftPartitionVariable, + rightPartitionVariable, + kdbTree); } @Override public PlanNode assignStatsEquivalentPlanNode(Optional statsEquivalentPlanNode) { - return new SpatialJoinNode(getSourceLocation(), getId(), statsEquivalentPlanNode, type, left, right, outputVariables, filter, leftPartitionVariable, rightPartitionVariable, kdbTree); + return new SpatialJoinNode( + getSourceLocation(), + getId(), + statsEquivalentPlanNode, + type, + left, + right, + outputVariables, + probeGeometryVariable, + buildGeometryVariable, + radiusVariable, + filter, + leftPartitionVariable, + rightPartitionVariable, + kdbTree); } }