From 6187e3f817e0646ac7c24022b7a53e302a5080d5 Mon Sep 17 00:00:00 2001 From: James Gill Date: Mon, 6 Oct 2025 13:51:45 -0400 Subject: [PATCH] feat(geo): Add geometry and radius variables to SpatialJoinNode In Java, the optimizer's ExtractSpatialJoins as well as the LocalExecutionPlanner must find build and probe expressions that contain geometries for the join to be planned or executed. In ExtractSpatialJoins, this work was then thrown away. Velox also needs to do this when constructing its SpatialJoinBuild/Probe. Instead of adding planner logic in Velox, this extracts the already calculated variables for probe and build geometries and optionally a radius and adds them to SpatialJoinNode. This will allow us to pass them to Velox. Also update presto cpp protocol to reflect these new variables. --- .../iterative/rule/ExtractSpatialJoins.java | 11 ++ .../rule/RowExpressionRewriteRuleSet.java | 3 + .../optimizations/PredicatePushDown.java | 6 + .../PruneUnreferencedOutputs.java | 16 ++- .../UnaliasSymbolReferences.java | 15 ++- .../main/types/PrestoToVeloxQueryPlan.cpp | 7 + .../connector/hive/presto_protocol_hive.cpp | 20 +-- .../iceberg/presto_protocol_iceberg.cpp | 28 ++-- .../core/presto_protocol_core.cpp | 125 ++++++++++++------ .../core/presto_protocol_core.h | 11 +- .../presto/spi/plan/SpatialJoinNode.java | 82 +++++++++++- 11 files changed, 256 insertions(+), 68 deletions(-) diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ExtractSpatialJoins.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ExtractSpatialJoins.java index 93fa95932846e..db28daede971c 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ExtractSpatialJoins.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ExtractSpatialJoins.java @@ -460,6 +460,8 @@ else if (firstSide == VariableSide.Right && secondSide == VariableSide.Left) { // with a projection that adds the argument as a variable. Optional newFirstVariable = newVariable(context, firstArgument); Optional newSecondVariable = newVariable(context, secondArgument); + VariableReferenceExpression leftGeometryVariable; + VariableReferenceExpression rightGeometryVariable; PlanNode leftNode = joinNode.getLeft(); PlanNode rightNode = joinNode.getRight(); @@ -470,10 +472,16 @@ else if (firstSide == VariableSide.Right && secondSide == VariableSide.Left) { if (firstArgumentOnLeft) { newLeftNode = newFirstVariable.map(variable -> addProjection(context, leftNode, variable, firstArgument)).orElse(leftNode); newRightNode = newSecondVariable.map(variable -> addProjection(context, rightNode, variable, secondArgument)).orElse(rightNode); + // If new variables are empty, argument is VariableReferenceExpression + leftGeometryVariable = newFirstVariable.orElseGet(() -> (VariableReferenceExpression) firstArgument); + rightGeometryVariable = newSecondVariable.orElseGet(() -> (VariableReferenceExpression) secondArgument); } else { newLeftNode = newSecondVariable.map(variable -> addProjection(context, leftNode, variable, secondArgument)).orElse(leftNode); newRightNode = newFirstVariable.map(variable -> addProjection(context, rightNode, variable, firstArgument)).orElse(rightNode); + // If new variables are empty, argument is VariableReferenceExpression + leftGeometryVariable = newSecondVariable.orElseGet(() -> (VariableReferenceExpression) secondArgument); + rightGeometryVariable = newFirstVariable.orElseGet(() -> (VariableReferenceExpression) firstArgument); } RowExpression newFirstArgument = mapToExpression(newFirstVariable, firstArgument); @@ -512,6 +520,9 @@ else if (firstSide == VariableSide.Right && secondSide == VariableSide.Left) { newLeftNode, newRightNode, outputVariables, + leftGeometryVariable, + rightGeometryVariable, + radius, newFilter, leftPartitionVariable, rightPartitionVariable, diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RowExpressionRewriteRuleSet.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RowExpressionRewriteRuleSet.java index e969326ecfff5..9938f9ea4c740 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RowExpressionRewriteRuleSet.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RowExpressionRewriteRuleSet.java @@ -218,6 +218,9 @@ public Result apply(SpatialJoinNode spatialJoinNode, Captures captures, Context spatialJoinNode.getLeft(), spatialJoinNode.getRight(), spatialJoinNode.getOutputVariables(), + spatialJoinNode.getProbeGeometryVariable(), + spatialJoinNode.getBuildGeometryVariable(), + spatialJoinNode.getRadiusVariable(), rewritten, spatialJoinNode.getLeftPartitionVariable(), spatialJoinNode.getRightPartitionVariable(), diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/PredicatePushDown.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/PredicatePushDown.java index 81d7eebfedb5e..88bda8083ab12 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/PredicatePushDown.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/PredicatePushDown.java @@ -914,6 +914,9 @@ public PlanNode visitSpatialJoin(SpatialJoinNode node, RewriteContext cont PlanNode left = context.rewrite(node.getLeft()); PlanNode right = context.rewrite(node.getRight()); - return new SpatialJoinNode(node.getSourceLocation(), node.getId(), node.getType(), left, right, canonicalizeAndDistinct(node.getOutputVariables()), canonicalize(node.getFilter()), canonicalize(node.getLeftPartitionVariable()), canonicalize(node.getRightPartitionVariable()), node.getKdbTree()); + return new SpatialJoinNode( + node.getSourceLocation(), + node.getId(), + node.getType(), + left, + right, + canonicalizeAndDistinct(node.getOutputVariables()), + canonicalize(node.getProbeGeometryVariable()), + canonicalize(node.getBuildGeometryVariable()), + canonicalize(node.getRadiusVariable()), + canonicalize(node.getFilter()), + canonicalize(node.getLeftPartitionVariable()), + canonicalize(node.getRightPartitionVariable()), + node.getKdbTree()); } @Override diff --git a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxQueryPlan.cpp b/presto-native-execution/presto_cpp/main/types/PrestoToVeloxQueryPlan.cpp index b314bba7484c0..1f22f435a8612 100644 --- a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxQueryPlan.cpp +++ b/presto-native-execution/presto_cpp/main/types/PrestoToVeloxQueryPlan.cpp @@ -1284,11 +1284,18 @@ core::PlanNodePtr VeloxQueryPlanConverterBase::toVeloxQueryPlan( const std::shared_ptr& tableWriteInfo, const protocol::TaskId& taskId) { auto joinType = toJoinType(node->type); + std::optional radiusVariable = std::nullopt; + if (node->radiusVariable) { + radiusVariable = exprConverter_.toVeloxExpr(*node->radiusVariable); + } return std::make_shared( node->id, joinType, exprConverter_.toVeloxExpr(node->filter), + exprConverter_.toVeloxExpr(node->probeGeometryVariable), + exprConverter_.toVeloxExpr(node->buildGeometryVariable), + radiusVariable, toVeloxQueryPlan(node->left, tableWriteInfo, taskId), toVeloxQueryPlan(node->right, tableWriteInfo, taskId), toRowType(node->outputVariables, typeParser_)); diff --git a/presto-native-execution/presto_cpp/presto_protocol/connector/hive/presto_protocol_hive.cpp b/presto-native-execution/presto_cpp/presto_protocol/connector/hive/presto_protocol_hive.cpp index 8011da82eee47..24e90b78e2e8c 100644 --- a/presto-native-execution/presto_cpp/presto_protocol/connector/hive/presto_protocol_hive.cpp +++ b/presto-native-execution/presto_cpp/presto_protocol/connector/hive/presto_protocol_hive.cpp @@ -370,9 +370,10 @@ namespace facebook::presto::protocol::hive { // NOLINTNEXTLINE: cppcoreguidelines-avoid-c-arrays static const std::pair - BucketFunctionType_enum_table[] = { // NOLINT: cert-err58-cpp - {BucketFunctionType::HIVE_COMPATIBLE, "HIVE_COMPATIBLE"}, - {BucketFunctionType::PRESTO_NATIVE, "PRESTO_NATIVE"}}; + BucketFunctionType_enum_table[] = + { // NOLINT: cert-err58-cpp + {BucketFunctionType::HIVE_COMPATIBLE, "HIVE_COMPATIBLE"}, + {BucketFunctionType::PRESTO_NATIVE, "PRESTO_NATIVE"}}; void to_json(json& j, const BucketFunctionType& e) { static_assert( std::is_enum::value, @@ -598,12 +599,13 @@ namespace facebook::presto::protocol::hive { // NOLINTNEXTLINE: cppcoreguidelines-avoid-c-arrays static const std::pair - HiveCompressionCodec_enum_table[] = { // NOLINT: cert-err58-cpp - {HiveCompressionCodec::NONE, "NONE"}, - {HiveCompressionCodec::SNAPPY, "SNAPPY"}, - {HiveCompressionCodec::GZIP, "GZIP"}, - {HiveCompressionCodec::LZ4, "LZ4"}, - {HiveCompressionCodec::ZSTD, "ZSTD"}}; + HiveCompressionCodec_enum_table[] = + { // NOLINT: cert-err58-cpp + {HiveCompressionCodec::NONE, "NONE"}, + {HiveCompressionCodec::SNAPPY, "SNAPPY"}, + {HiveCompressionCodec::GZIP, "GZIP"}, + {HiveCompressionCodec::LZ4, "LZ4"}, + {HiveCompressionCodec::ZSTD, "ZSTD"}}; void to_json(json& j, const HiveCompressionCodec& e) { static_assert( std::is_enum::value, diff --git a/presto-native-execution/presto_cpp/presto_protocol/connector/iceberg/presto_protocol_iceberg.cpp b/presto-native-execution/presto_cpp/presto_protocol/connector/iceberg/presto_protocol_iceberg.cpp index 3229da2e88d07..6d03a5ce52b12 100644 --- a/presto-native-execution/presto_cpp/presto_protocol/connector/iceberg/presto_protocol_iceberg.cpp +++ b/presto-native-execution/presto_cpp/presto_protocol/connector/iceberg/presto_protocol_iceberg.cpp @@ -25,11 +25,12 @@ namespace facebook::presto::protocol::iceberg { // NOLINTNEXTLINE: cppcoreguidelines-avoid-c-arrays static const std::pair - ChangelogOperation_enum_table[] = { // NOLINT: cert-err58-cpp - {ChangelogOperation::INSERT, "INSERT"}, - {ChangelogOperation::DELETE, "DELETE"}, - {ChangelogOperation::UPDATE_BEFORE, "UPDATE_BEFORE"}, - {ChangelogOperation::UPDATE_AFTER, "UPDATE_AFTER"}}; + ChangelogOperation_enum_table[] = + { // NOLINT: cert-err58-cpp + {ChangelogOperation::INSERT, "INSERT"}, + {ChangelogOperation::DELETE, "DELETE"}, + {ChangelogOperation::UPDATE_BEFORE, "UPDATE_BEFORE"}, + {ChangelogOperation::UPDATE_AFTER, "UPDATE_AFTER"}}; void to_json(json& j, const ChangelogOperation& e) { static_assert( std::is_enum::value, @@ -508,14 +509,15 @@ namespace facebook::presto::protocol::iceberg { // NOLINTNEXTLINE: cppcoreguidelines-avoid-c-arrays static const std::pair - PartitionTransformType_enum_table[] = { // NOLINT: cert-err58-cpp - {PartitionTransformType::IDENTITY, "IDENTITY"}, - {PartitionTransformType::YEAR, "YEAR"}, - {PartitionTransformType::MONTH, "MONTH"}, - {PartitionTransformType::DAY, "DAY"}, - {PartitionTransformType::HOUR, "HOUR"}, - {PartitionTransformType::BUCKET, "BUCKET"}, - {PartitionTransformType::TRUNCATE, "TRUNCATE"}}; + PartitionTransformType_enum_table[] = + { // NOLINT: cert-err58-cpp + {PartitionTransformType::IDENTITY, "IDENTITY"}, + {PartitionTransformType::YEAR, "YEAR"}, + {PartitionTransformType::MONTH, "MONTH"}, + {PartitionTransformType::DAY, "DAY"}, + {PartitionTransformType::HOUR, "HOUR"}, + {PartitionTransformType::BUCKET, "BUCKET"}, + {PartitionTransformType::TRUNCATE, "TRUNCATE"}}; void to_json(json& j, const PartitionTransformType& e) { static_assert( std::is_enum::value, diff --git a/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.cpp b/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.cpp index 50d7de6024075..cc1647256b9e5 100644 --- a/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.cpp +++ b/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.cpp @@ -36,10 +36,11 @@ namespace facebook::presto::protocol { // NOLINTNEXTLINE: cppcoreguidelines-avoid-c-arrays static const std::pair - NodeSelectionStrategy_enum_table[] = { // NOLINT: cert-err58-cpp - {NodeSelectionStrategy::HARD_AFFINITY, "HARD_AFFINITY"}, - {NodeSelectionStrategy::SOFT_AFFINITY, "SOFT_AFFINITY"}, - {NodeSelectionStrategy::NO_PREFERENCE, "NO_PREFERENCE"}}; + NodeSelectionStrategy_enum_table[] = + { // NOLINT: cert-err58-cpp + {NodeSelectionStrategy::HARD_AFFINITY, "HARD_AFFINITY"}, + {NodeSelectionStrategy::SOFT_AFFINITY, "SOFT_AFFINITY"}, + {NodeSelectionStrategy::NO_PREFERENCE, "NO_PREFERENCE"}}; void to_json(json& j, const NodeSelectionStrategy& e) { static_assert( std::is_enum::value, @@ -558,11 +559,12 @@ namespace facebook::presto::protocol { // NOLINTNEXTLINE: cppcoreguidelines-avoid-c-arrays static const std::pair - AggregationNodeStep_enum_table[] = { // NOLINT: cert-err58-cpp - {AggregationNodeStep::PARTIAL, "PARTIAL"}, - {AggregationNodeStep::FINAL, "FINAL"}, - {AggregationNodeStep::INTERMEDIATE, "INTERMEDIATE"}, - {AggregationNodeStep::SINGLE, "SINGLE"}}; + AggregationNodeStep_enum_table[] = + { // NOLINT: cert-err58-cpp + {AggregationNodeStep::PARTIAL, "PARTIAL"}, + {AggregationNodeStep::FINAL, "FINAL"}, + {AggregationNodeStep::INTERMEDIATE, "INTERMEDIATE"}, + {AggregationNodeStep::SINGLE, "SINGLE"}}; void to_json(json& j, const AggregationNodeStep& e) { static_assert( std::is_enum::value, @@ -2808,10 +2810,11 @@ namespace facebook::presto::protocol { // NOLINTNEXTLINE: cppcoreguidelines-avoid-c-arrays static const std::pair - BuiltInFunctionKind_enum_table[] = { // NOLINT: cert-err58-cpp - {BuiltInFunctionKind::ENGINE, "ENGINE"}, - {BuiltInFunctionKind::PLUGIN, "PLUGIN"}, - {BuiltInFunctionKind::WORKER, "WORKER"}}; + BuiltInFunctionKind_enum_table[] = + { // NOLINT: cert-err58-cpp + {BuiltInFunctionKind::ENGINE, "ENGINE"}, + {BuiltInFunctionKind::PLUGIN, "PLUGIN"}, + {BuiltInFunctionKind::WORKER, "WORKER"}}; void to_json(json& j, const BuiltInFunctionKind& e) { static_assert( std::is_enum::value, @@ -6144,9 +6147,10 @@ namespace facebook::presto::protocol { // NOLINTNEXTLINE: cppcoreguidelines-avoid-c-arrays static const std::pair - JoinDistributionType_enum_table[] = { // NOLINT: cert-err58-cpp - {JoinDistributionType::PARTITIONED, "PARTITIONED"}, - {JoinDistributionType::REPLICATED, "REPLICATED"}}; + JoinDistributionType_enum_table[] = + { // NOLINT: cert-err58-cpp + {JoinDistributionType::PARTITIONED, "PARTITIONED"}, + {JoinDistributionType::REPLICATED, "REPLICATED"}}; void to_json(json& j, const JoinDistributionType& e) { static_assert( std::is_enum::value, @@ -8211,14 +8215,17 @@ namespace facebook::presto::protocol { // NOLINTNEXTLINE: cppcoreguidelines-avoid-c-arrays static const std::pair - StageExecutionStrategy_enum_table[] = { // NOLINT: cert-err58-cpp - {StageExecutionStrategy::UNGROUPED_EXECUTION, "UNGROUPED_EXECUTION"}, - {StageExecutionStrategy::FIXED_LIFESPAN_SCHEDULE_GROUPED_EXECUTION, - "FIXED_LIFESPAN_SCHEDULE_GROUPED_EXECUTION"}, - {StageExecutionStrategy::DYNAMIC_LIFESPAN_SCHEDULE_GROUPED_EXECUTION, - "DYNAMIC_LIFESPAN_SCHEDULE_GROUPED_EXECUTION"}, - {StageExecutionStrategy::RECOVERABLE_GROUPED_EXECUTION, - "RECOVERABLE_GROUPED_EXECUTION"}}; + StageExecutionStrategy_enum_table[] = + { // NOLINT: cert-err58-cpp + {StageExecutionStrategy::UNGROUPED_EXECUTION, + "UNGROUPED_EXECUTION"}, + {StageExecutionStrategy::FIXED_LIFESPAN_SCHEDULE_GROUPED_EXECUTION, + "FIXED_LIFESPAN_SCHEDULE_GROUPED_EXECUTION"}, + {StageExecutionStrategy:: + DYNAMIC_LIFESPAN_SCHEDULE_GROUPED_EXECUTION, + "DYNAMIC_LIFESPAN_SCHEDULE_GROUPED_EXECUTION"}, + {StageExecutionStrategy::RECOVERABLE_GROUPED_EXECUTION, + "RECOVERABLE_GROUPED_EXECUTION"}}; void to_json(json& j, const StageExecutionStrategy& e) { static_assert( std::is_enum::value, @@ -9606,6 +9613,27 @@ void to_json(json& j, const SpatialJoinNode& p) { "SpatialJoinNode", "List", "outputVariables"); + to_json_key( + j, + "probeGeometryVariable", + p.probeGeometryVariable, + "SpatialJoinNode", + "VariableReferenceExpression", + "probeGeometryVariable"); + to_json_key( + j, + "buildGeometryVariable", + p.buildGeometryVariable, + "SpatialJoinNode", + "VariableReferenceExpression", + "buildGeometryVariable"); + to_json_key( + j, + "radiusVariable", + p.radiusVariable, + "SpatialJoinNode", + "VariableReferenceExpression", + "radiusVariable"); to_json_key( j, "filter", p.filter, "SpatialJoinNode", "RowExpression", "filter"); to_json_key( @@ -9639,6 +9667,27 @@ void from_json(const json& j, SpatialJoinNode& p) { "SpatialJoinNode", "List", "outputVariables"); + from_json_key( + j, + "probeGeometryVariable", + p.probeGeometryVariable, + "SpatialJoinNode", + "VariableReferenceExpression", + "probeGeometryVariable"); + from_json_key( + j, + "buildGeometryVariable", + p.buildGeometryVariable, + "SpatialJoinNode", + "VariableReferenceExpression", + "buildGeometryVariable"); + from_json_key( + j, + "radiusVariable", + p.radiusVariable, + "SpatialJoinNode", + "VariableReferenceExpression", + "radiusVariable"); from_json_key( j, "filter", p.filter, "SpatialJoinNode", "RowExpression", "filter"); from_json_key( @@ -9888,12 +9937,13 @@ namespace facebook::presto::protocol { // NOLINTNEXTLINE: cppcoreguidelines-avoid-c-arrays static const std::pair - SystemPartitionFunction_enum_table[] = { // NOLINT: cert-err58-cpp - {SystemPartitionFunction::SINGLE, "SINGLE"}, - {SystemPartitionFunction::HASH, "HASH"}, - {SystemPartitionFunction::ROUND_ROBIN, "ROUND_ROBIN"}, - {SystemPartitionFunction::BROADCAST, "BROADCAST"}, - {SystemPartitionFunction::UNKNOWN, "UNKNOWN"}}; + SystemPartitionFunction_enum_table[] = + { // NOLINT: cert-err58-cpp + {SystemPartitionFunction::SINGLE, "SINGLE"}, + {SystemPartitionFunction::HASH, "HASH"}, + {SystemPartitionFunction::ROUND_ROBIN, "ROUND_ROBIN"}, + {SystemPartitionFunction::BROADCAST, "BROADCAST"}, + {SystemPartitionFunction::UNKNOWN, "UNKNOWN"}}; void to_json(json& j, const SystemPartitionFunction& e) { static_assert( std::is_enum::value, @@ -9930,13 +9980,14 @@ namespace facebook::presto::protocol { // NOLINTNEXTLINE: cppcoreguidelines-avoid-c-arrays static const std::pair - SystemPartitioning_enum_table[] = { // NOLINT: cert-err58-cpp - {SystemPartitioning::SINGLE, "SINGLE"}, - {SystemPartitioning::FIXED, "FIXED"}, - {SystemPartitioning::SOURCE, "SOURCE"}, - {SystemPartitioning::SCALED, "SCALED"}, - {SystemPartitioning::COORDINATOR_ONLY, "COORDINATOR_ONLY"}, - {SystemPartitioning::ARBITRARY, "ARBITRARY"}}; + SystemPartitioning_enum_table[] = + { // NOLINT: cert-err58-cpp + {SystemPartitioning::SINGLE, "SINGLE"}, + {SystemPartitioning::FIXED, "FIXED"}, + {SystemPartitioning::SOURCE, "SOURCE"}, + {SystemPartitioning::SCALED, "SCALED"}, + {SystemPartitioning::COORDINATOR_ONLY, "COORDINATOR_ONLY"}, + {SystemPartitioning::ARBITRARY, "ARBITRARY"}}; void to_json(json& j, const SystemPartitioning& e) { static_assert( std::is_enum::value, diff --git a/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.h b/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.h index dae1c63b907d5..2b1e4eb66c14e 100644 --- a/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.h +++ b/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.h @@ -69,21 +69,21 @@ extern const char* const PRESTO_ABORT_TASK_URL_PARAM; class Exception : public std::runtime_error { public: explicit Exception(const std::string& message) - : std::runtime_error(message){}; + : std::runtime_error(message) {}; }; class TypeError : public Exception { public: - explicit TypeError(const std::string& message) : Exception(message){}; + explicit TypeError(const std::string& message) : Exception(message) {}; }; class OutOfRange : public Exception { public: - explicit OutOfRange(const std::string& message) : Exception(message){}; + explicit OutOfRange(const std::string& message) : Exception(message) {}; }; class ParseError : public Exception { public: - explicit ParseError(const std::string& message) : Exception(message){}; + explicit ParseError(const std::string& message) : Exception(message) {}; }; using String = std::string; @@ -2209,6 +2209,9 @@ struct SpatialJoinNode : public PlanNode { std::shared_ptr left = {}; std::shared_ptr right = {}; List outputVariables = {}; + VariableReferenceExpression probeGeometryVariable = {}; + VariableReferenceExpression buildGeometryVariable = {}; + std::shared_ptr radiusVariable = {}; std::shared_ptr filter = {}; std::shared_ptr leftPartitionVariable = {}; std::shared_ptr rightPartitionVariable = {}; diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/plan/SpatialJoinNode.java b/presto-spi/src/main/java/com/facebook/presto/spi/plan/SpatialJoinNode.java index 3c0a6d9e46394..a69c21c87267b 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/plan/SpatialJoinNode.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/plan/SpatialJoinNode.java @@ -68,6 +68,9 @@ public static SpatialJoinType fromJoinNodeType(JoinType joinNodeType) private final PlanNode left; private final PlanNode right; private final List outputVariables; + private final VariableReferenceExpression probeGeometryVariable; + private final VariableReferenceExpression buildGeometryVariable; + private final Optional radiusVariable; private final RowExpression filter; private final Optional leftPartitionVariable; private final Optional rightPartitionVariable; @@ -88,12 +91,29 @@ public SpatialJoinNode( @JsonProperty("left") PlanNode left, @JsonProperty("right") PlanNode right, @JsonProperty("outputVariables") List outputVariables, + @JsonProperty("probeGeometryVariable") VariableReferenceExpression probeGeometryVariable, + @JsonProperty("buildGeometryVariable") VariableReferenceExpression buildGeometryVariable, + @JsonProperty("radiusVariable") Optional radiusVariable, @JsonProperty("filter") RowExpression filter, @JsonProperty("leftPartitionVariable") Optional leftPartitionVariable, @JsonProperty("rightPartitionVariable") Optional rightPartitionVariable, @JsonProperty("kdbTree") Optional kdbTree) { - this(sourceLocation, id, Optional.empty(), type, left, right, outputVariables, filter, leftPartitionVariable, rightPartitionVariable, kdbTree); + this( + sourceLocation, + id, + Optional.empty(), + type, + left, + right, + outputVariables, + probeGeometryVariable, + buildGeometryVariable, + radiusVariable, + filter, + leftPartitionVariable, + rightPartitionVariable, + kdbTree); } public SpatialJoinNode( @@ -104,6 +124,9 @@ public SpatialJoinNode( PlanNode left, PlanNode right, List outputVariables, + VariableReferenceExpression probeGeometryVariable, + VariableReferenceExpression buildGeometryVariable, + Optional radiusVariable, RowExpression filter, Optional leftPartitionVariable, Optional rightPartitionVariable, @@ -116,10 +139,17 @@ public SpatialJoinNode( this.right = requireNonNull(right, "right is null"); this.outputVariables = unmodifiableList(new ArrayList<>(requireNonNull(outputVariables, "outputVariables is null"))); this.filter = requireNonNull(filter, "filter is null"); + this.probeGeometryVariable = requireNonNull(probeGeometryVariable, "probeGeometryVariable is null"); + this.buildGeometryVariable = requireNonNull(buildGeometryVariable, "buildGeometryVariable is null"); + this.radiusVariable = requireNonNull(radiusVariable, "radiusVariable is null"); this.leftPartitionVariable = requireNonNull(leftPartitionVariable, "leftPartitionVariable is null"); this.rightPartitionVariable = requireNonNull(rightPartitionVariable, "rightPartitionVariable is null"); this.kdbTree = requireNonNull(kdbTree, "kdbTree is null"); + checkArgument(left.getOutputVariables().contains(probeGeometryVariable), "Left join input does not contain probe geometry variable"); + checkArgument(right.getOutputVariables().contains(buildGeometryVariable), "Right join input does not contain build geometry variable"); + radiusVariable.ifPresent(radius -> checkArgument(right.getOutputVariables().contains(radius), "Right join input does not contain radius variable")); + Set inputSymbols = new LinkedHashSet<>(); inputSymbols.addAll(left.getOutputVariables()); inputSymbols.addAll(right.getOutputVariables()); @@ -163,6 +193,24 @@ public RowExpression getFilter() return filter; } + @JsonProperty + public VariableReferenceExpression getProbeGeometryVariable() + { + return probeGeometryVariable; + } + + @JsonProperty + public VariableReferenceExpression getBuildGeometryVariable() + { + return buildGeometryVariable; + } + + @JsonProperty + public Optional getRadiusVariable() + { + return radiusVariable; + } + @JsonProperty public Optional getLeftPartitionVariable() { @@ -213,12 +261,40 @@ public R accept(PlanVisitor visitor, C context) public PlanNode replaceChildren(List newChildren) { checkArgument(newChildren.size() == 2, "expected newChildren to contain 2 nodes"); - return new SpatialJoinNode(getSourceLocation(), getId(), getStatsEquivalentPlanNode(), type, newChildren.get(0), newChildren.get(1), outputVariables, filter, leftPartitionVariable, rightPartitionVariable, kdbTree); + return new SpatialJoinNode( + getSourceLocation(), + getId(), + getStatsEquivalentPlanNode(), + type, + newChildren.get(0), + newChildren.get(1), + outputVariables, + probeGeometryVariable, + buildGeometryVariable, + radiusVariable, + filter, + leftPartitionVariable, + rightPartitionVariable, + kdbTree); } @Override public PlanNode assignStatsEquivalentPlanNode(Optional statsEquivalentPlanNode) { - return new SpatialJoinNode(getSourceLocation(), getId(), statsEquivalentPlanNode, type, left, right, outputVariables, filter, leftPartitionVariable, rightPartitionVariable, kdbTree); + return new SpatialJoinNode( + getSourceLocation(), + getId(), + statsEquivalentPlanNode, + type, + left, + right, + outputVariables, + probeGeometryVariable, + buildGeometryVariable, + radiusVariable, + filter, + leftPartitionVariable, + rightPartitionVariable, + kdbTree); } }