Skip to content

Commit 3478e08

Browse files
committed
feat: Add SpatialJoinNode to presto_protocol
To send SpatialJoinNodes to Velox, we need to serialize and deserialize them via presto_protocol. This change requires facebookincubator/velox#14339 for spatial joins to not cause an error. After this PR and the above land, spatial joins should be enabled, implemented as nested loop joins. This is not efficient, but it should be correct.
1 parent 8bf4ab1 commit 3478e08

File tree

6 files changed

+181
-0
lines changed

6 files changed

+181
-0
lines changed

presto-native-execution/presto_cpp/main/types/PrestoToVeloxQueryPlan.cpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1172,6 +1172,17 @@ core::JoinType toJoinType(protocol::JoinType type) {
11721172

11731173
VELOX_UNSUPPORTED("Unknown join type");
11741174
}
1175+
1176+
// Maps a protocol::SpatialJoinType to the corresponding core::JoinType.
// INNER maps to kInner and LEFT maps to kLeft; any other value falls through
// the switch and is reported via VELOX_UNSUPPORTED.
core::JoinType toJoinType(protocol::SpatialJoinType type) {
1177+
switch (type) {
1178+
case protocol::SpatialJoinType::INNER:
1179+
return core::JoinType::kInner;
1180+
case protocol::SpatialJoinType::LEFT:
1181+
return core::JoinType::kLeft;
1182+
}
1183+
1184+
// Reached only when `type` matches none of the cases above.
VELOX_UNSUPPORTED("Unknown spatial join type");
1185+
}
11751186
} // namespace
11761187

11771188
core::PlanNodePtr VeloxQueryPlanConverterBase::toVeloxQueryPlan(
@@ -1250,6 +1261,23 @@ core::PlanNodePtr VeloxQueryPlanConverterBase::toVeloxQueryPlan(
12501261
ROW(std::move(outputNames), std::move(outputTypes)));
12511262
}
12521263

1264+
// Converts a protocol::SpatialJoinNode into a core::SpatialJoinNode.
//
// The join type is mapped with toJoinType(), the join filter is converted
// with exprConverter_, both children are converted recursively, and the
// output row type is built from node->outputVariables. The node's
// leftPartitionVariable, rightPartitionVariable and kdbTree fields are not
// read by this conversion.
//
// @param node Spatial join node received from the coordinator.
// @param tableWriteInfo Forwarded unchanged to the child conversions.
// @param taskId Forwarded unchanged to the child conversions.
// @return The converted plan node.
core::PlanNodePtr VeloxQueryPlanConverterBase::toVeloxQueryPlan(
    const std::shared_ptr<const protocol::SpatialJoinNode>& node,
    const std::shared_ptr<protocol::TableWriteInfo>& tableWriteInfo,
    const protocol::TaskId& taskId) {
  // Resolve the join type first so an unsupported type is rejected before any
  // child plan is converted.
  const auto joinType = toJoinType(node->type);

  // A spatial join has no equi-join keys; the join condition is carried
  // entirely by the filter expression.
  return std::make_shared<core::SpatialJoinNode>(
      node->id,
      joinType,
      exprConverter_.toVeloxExpr(node->filter),
      toVeloxQueryPlan(node->left, tableWriteInfo, taskId),
      toVeloxQueryPlan(node->right, tableWriteInfo, taskId),
      toRowType(node->outputVariables, typeParser_));
}
1280+
12531281
std::shared_ptr<const core::IndexLookupJoinNode>
12541282
VeloxQueryPlanConverterBase::toVeloxQueryPlan(
12551283
const std::shared_ptr<const protocol::IndexJoinNode>& node,
@@ -1830,6 +1858,10 @@ core::PlanNodePtr VeloxQueryPlanConverterBase::toVeloxQueryPlan(
18301858
std::dynamic_pointer_cast<const protocol::MergeJoinNode>(node)) {
18311859
return toVeloxQueryPlan(join, tableWriteInfo, taskId);
18321860
}
1861+
if (auto spatialJoin =
1862+
std::dynamic_pointer_cast<const protocol::SpatialJoinNode>(node)) {
1863+
return toVeloxQueryPlan(spatialJoin, tableWriteInfo, taskId);
1864+
}
18331865
if (auto remoteSource =
18341866
std::dynamic_pointer_cast<const protocol::RemoteSourceNode>(node)) {
18351867
return toVeloxQueryPlan(remoteSource, tableWriteInfo, taskId);

presto-native-execution/presto_cpp/main/types/PrestoToVeloxQueryPlan.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,11 @@ class VeloxQueryPlanConverterBase {
110110
const std::shared_ptr<protocol::TableWriteInfo>& tableWriteInfo,
111111
const protocol::TaskId& taskId);
112112

113+
// Converts a protocol::SpatialJoinNode into a Velox plan node.
velox::core::PlanNodePtr toVeloxQueryPlan(
114+
const std::shared_ptr<const protocol::SpatialJoinNode>& node,
115+
const std::shared_ptr<protocol::TableWriteInfo>& tableWriteInfo,
116+
const protocol::TaskId& taskId);
117+
113118
std::shared_ptr<const velox::core::IndexLookupJoinNode> toVeloxQueryPlan(
114119
const std::shared_ptr<const protocol::IndexJoinNode>& node,
115120
const std::shared_ptr<protocol::TableWriteInfo>& tableWriteInfo,

presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.cpp

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -728,6 +728,10 @@ void to_json(json& j, const std::shared_ptr<PlanNode>& p) {
728728
j = *std::static_pointer_cast<SemiJoinNode>(p);
729729
return;
730730
}
731+
if (type == ".SpatialJoinNode") {
732+
j = *std::static_pointer_cast<SpatialJoinNode>(p);
733+
return;
734+
}
731735
if (type == ".TableScanNode") {
732736
j = *std::static_pointer_cast<TableScanNode>(p);
733737
return;
@@ -896,6 +900,12 @@ void from_json(const json& j, std::shared_ptr<PlanNode>& p) {
896900
p = std::static_pointer_cast<PlanNode>(k);
897901
return;
898902
}
903+
if (type == ".SpatialJoinNode") {
904+
std::shared_ptr<SpatialJoinNode> k = std::make_shared<SpatialJoinNode>();
905+
j.get_to(*k);
906+
p = std::static_pointer_cast<PlanNode>(k);
907+
return;
908+
}
899909
if (type == ".TableScanNode") {
900910
std::shared_ptr<TableScanNode> k = std::make_shared<TableScanNode>();
901911
j.get_to(*k);
@@ -9343,6 +9353,115 @@ void from_json(const json& j, SortedRangeSet& p) {
93439353
namespace facebook::presto::protocol {
93449354
// Loosely copied this here from NLOHMANN_JSON_SERIALIZE_ENUM()
93459355

9356+
// Lookup table pairing each SpatialJoinType enumerator with its JSON string.
// The first entry (INNER) doubles as the fallback used by the to_json and
// from_json overloads for SpatialJoinType when no entry matches.
// NOLINTNEXTLINE: cppcoreguidelines-avoid-c-arrays
9357+
static const std::pair<SpatialJoinType, json> SpatialJoinType_enum_table[] =
9358+
{ // NOLINT: cert-err58-cpp
9359+
{SpatialJoinType::INNER, "INNER"},
9360+
{SpatialJoinType::LEFT, "LEFT"}};
9361+
// Serializes a SpatialJoinType to JSON by looking up `e` in
// SpatialJoinType_enum_table and assigning the paired JSON value to `j`.
// If `e` is not present in the table, the first table entry is used instead.
void to_json(json& j, const SpatialJoinType& e) {
9362+
static_assert(
9363+
std::is_enum<SpatialJoinType>::value, "SpatialJoinType must be an enum!");
9364+
// Linear search for the entry whose enumerator equals `e`.
const auto* it = std::find_if(
9365+
std::begin(SpatialJoinType_enum_table),
9366+
std::end(SpatialJoinType_enum_table),
9367+
[e](const std::pair<SpatialJoinType, json>& ej_pair) -> bool {
9368+
return ej_pair.first == e;
9369+
});
9370+
// Not found: fall back to the first table entry rather than failing.
j = ((it != std::end(SpatialJoinType_enum_table))
9371+
? it
9372+
: std::begin(SpatialJoinType_enum_table))
9373+
->second;
9374+
}
9375+
// Deserializes a SpatialJoinType from JSON by looking up `j` in
// SpatialJoinType_enum_table and assigning the paired enumerator to `e`.
// NOTE(review): a JSON value that matches no table entry does not raise an
// error here; `e` is silently set to the first table entry's enumerator
// (SpatialJoinType::INNER).
void from_json(const json& j, SpatialJoinType& e) {
9376+
static_assert(
9377+
std::is_enum<SpatialJoinType>::value, "SpatialJoinType must be an enum!");
9378+
// Linear search for the entry whose JSON value equals `j`.
const auto* it = std::find_if(
9379+
std::begin(SpatialJoinType_enum_table),
9380+
std::end(SpatialJoinType_enum_table),
9381+
[&j](const std::pair<SpatialJoinType, json>& ej_pair) -> bool {
9382+
return ej_pair.second == j;
9383+
});
9384+
// Not found: fall back to the first table entry rather than failing.
e = ((it != std::end(SpatialJoinType_enum_table))
9385+
? it
9386+
: std::begin(SpatialJoinType_enum_table))
9387+
->first;
9388+
}
9389+
} // namespace facebook::presto::protocol
9390+
namespace facebook::presto::protocol {
9391+
// Default constructor: tags the node with the ".SpatialJoinNode" type key,
// the same string the PlanNode serializers compare against when dispatching.
SpatialJoinNode::SpatialJoinNode() noexcept {
9392+
_type = ".SpatialJoinNode";
9393+
}
9394+
9395+
// Serializes a SpatialJoinNode into `j`.
// `j` is reset to an empty JSON object, tagged with "@type" =
// ".SpatialJoinNode", and then each field is written through to_json_key
// under a key equal to the field name: id, type, left, right,
// outputVariables, filter, leftPartitionVariable, rightPartitionVariable,
// kdbTree.
// NOTE(review): the trailing string arguments to to_json_key (class name,
// type name, field name) look like error-reporting context -- confirm
// against to_json_key's definition.
void to_json(json& j, const SpatialJoinNode& p) {
9396+
j = json::object();
9397+
j["@type"] = ".SpatialJoinNode";
9398+
to_json_key(j, "id", p.id, "SpatialJoinNode", "PlanNodeId", "id");
9399+
to_json_key(j, "type", p.type, "SpatialJoinNode", "SpatialJoinType", "type");
9400+
to_json_key(j, "left", p.left, "SpatialJoinNode", "PlanNode", "left");
9401+
to_json_key(j, "right", p.right, "SpatialJoinNode", "PlanNode", "right");
9402+
to_json_key(
9403+
j,
9404+
"outputVariables",
9405+
p.outputVariables,
9406+
"SpatialJoinNode",
9407+
"List<VariableReferenceExpression>",
9408+
"outputVariables");
9409+
to_json_key(
9410+
j, "filter", p.filter, "SpatialJoinNode", "RowExpression", "filter");
9411+
to_json_key(
9412+
j,
9413+
"leftPartitionVariable",
9414+
p.leftPartitionVariable,
9415+
"SpatialJoinNode",
9416+
"VariableReferenceExpression",
9417+
"leftPartitionVariable");
9418+
to_json_key(
9419+
j,
9420+
"rightPartitionVariable",
9421+
p.rightPartitionVariable,
9422+
"SpatialJoinNode",
9423+
"VariableReferenceExpression",
9424+
"rightPartitionVariable");
9425+
to_json_key(j, "kdbTree", p.kdbTree, "SpatialJoinNode", "String", "kdbTree");
9426+
}
9427+
9428+
// Deserializes a SpatialJoinNode from `j`.
// The "@type" entry of `j` is copied into p._type, then each field is read
// through from_json_key from the key equal to the field name: id, type, left,
// right, outputVariables, filter, leftPartitionVariable,
// rightPartitionVariable, kdbTree.
// NOTE(review): how missing or null keys are handled is decided inside
// from_json_key, which is not visible here -- confirm before relying on any
// of these fields being optional.
void from_json(const json& j, SpatialJoinNode& p) {
9429+
// "@type" is read unconditionally, without checking that the key exists.
p._type = j["@type"];
9430+
from_json_key(j, "id", p.id, "SpatialJoinNode", "PlanNodeId", "id");
9431+
from_json_key(
9432+
j, "type", p.type, "SpatialJoinNode", "SpatialJoinType", "type");
9433+
from_json_key(j, "left", p.left, "SpatialJoinNode", "PlanNode", "left");
9434+
from_json_key(j, "right", p.right, "SpatialJoinNode", "PlanNode", "right");
9435+
from_json_key(
9436+
j,
9437+
"outputVariables",
9438+
p.outputVariables,
9439+
"SpatialJoinNode",
9440+
"List<VariableReferenceExpression>",
9441+
"outputVariables");
9442+
from_json_key(
9443+
j, "filter", p.filter, "SpatialJoinNode", "RowExpression", "filter");
9444+
from_json_key(
9445+
j,
9446+
"leftPartitionVariable",
9447+
p.leftPartitionVariable,
9448+
"SpatialJoinNode",
9449+
"VariableReferenceExpression",
9450+
"leftPartitionVariable");
9451+
from_json_key(
9452+
j,
9453+
"rightPartitionVariable",
9454+
p.rightPartitionVariable,
9455+
"SpatialJoinNode",
9456+
"VariableReferenceExpression",
9457+
"rightPartitionVariable");
9458+
from_json_key(
9459+
j, "kdbTree", p.kdbTree, "SpatialJoinNode", "String", "kdbTree");
9460+
}
9461+
} // namespace facebook::presto::protocol
9462+
namespace facebook::presto::protocol {
9463+
// Loosely copied this here from NLOHMANN_JSON_SERIALIZE_ENUM()
9464+
93469465
// NOLINTNEXTLINE: cppcoreguidelines-avoid-c-arrays
93479466
static const std::pair<Form, json> Form_enum_table[] =
93489467
{ // NOLINT: cert-err58-cpp

presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2164,6 +2164,27 @@ void to_json(json& j, const SortedRangeSet& p);
21642164
void from_json(const json& j, SortedRangeSet& p);
21652165
} // namespace facebook::presto::protocol
21662166
namespace facebook::presto::protocol {
2167+
// Join type carried by a SpatialJoinNode; only INNER and LEFT exist.
enum class SpatialJoinType { INNER, LEFT };
2168+
// JSON serializer / deserializer for SpatialJoinType.
extern void to_json(json& j, const SpatialJoinType& e);
2169+
extern void from_json(const json& j, SpatialJoinType& e);
2170+
} // namespace facebook::presto::protocol
2171+
namespace facebook::presto::protocol {
2172+
// Protocol-level plan node for a spatial join. Derives from PlanNode; every
// member is value-initialized (pointers start out null).
struct SpatialJoinNode : public PlanNode {
2173+
// Join type (INNER or LEFT).
SpatialJoinType type = {};
2174+
// Left and right input plans.
std::shared_ptr<PlanNode> left = {};
2175+
std::shared_ptr<PlanNode> right = {};
2176+
// Variables produced by this node.
List<VariableReferenceExpression> outputVariables = {};
2177+
// Join filter expression.
std::shared_ptr<RowExpression> filter = {};
2178+
// NOTE(review): the three members below are presumably optional inputs for
// partitioned spatial joins (null when absent) -- confirm against the Java
// SpatialJoinNode.
std::shared_ptr<VariableReferenceExpression> leftPartitionVariable = {};
2179+
std::shared_ptr<VariableReferenceExpression> rightPartitionVariable = {};
2180+
std::shared_ptr<String> kdbTree = {};
2181+
2182+
SpatialJoinNode() noexcept;
2183+
};
2184+
// JSON serializer / deserializer for SpatialJoinNode.
void to_json(json& j, const SpatialJoinNode& p);
2185+
void from_json(const json& j, SpatialJoinNode& p);
2186+
} // namespace facebook::presto::protocol
2187+
namespace facebook::presto::protocol {
21672188
enum class Form {
21682189
IF,
21692190
NULL_IF,

presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,7 @@ AbstractClasses:
160160
- { name: RemoteSourceNode, key: com.facebook.presto.sql.planner.plan.RemoteSourceNode }
161161
- { name: SampleNode, key: com.facebook.presto.sql.planner.plan.SampleNode }
162162
- { name: SemiJoinNode, key: .SemiJoinNode }
163+
- { name: SpatialJoinNode, key: .SpatialJoinNode }
163164
- { name: TableScanNode, key: .TableScanNode }
164165
- { name: TableWriterNode, key: .TableWriterNode }
165166
- { name: TableWriterMergeNode, key: com.facebook.presto.sql.planner.plan.TableWriterMergeNode }
@@ -320,6 +321,7 @@ JavaClasses:
320321
- presto-spi/src/main/java/com/facebook/presto/spi/plan/JoinNode.java
321322
- presto-spi/src/main/java/com/facebook/presto/spi/plan/SemiJoinNode.java
322323
- presto-spi/src/main/java/com/facebook/presto/spi/plan/MergeJoinNode.java
324+
- presto-spi/src/main/java/com/facebook/presto/spi/plan/SpatialJoinNode.java
323325
- presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/IndexJoinNode.java
324326
- presto-spi/src/main/java/com/facebook/presto/spi/plan/IndexSourceNode.java
325327
- presto-spi/src/main/java/com/facebook/presto/spi/plan/TopNNode.java

presto-native-execution/presto_cpp/presto_protocol/presto_protocol.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,7 @@ AbstractClasses:
155155
- { name: RemoteSourceNode, key: com.facebook.presto.sql.planner.plan.RemoteSourceNode }
156156
- { name: SampleNode, key: com.facebook.presto.sql.planner.plan.SampleNode }
157157
- { name: SemiJoinNode, key: .SemiJoinNode }
158+
- { name: SpatialJoinNode, key: .SpatialJoinNode }
158159
- { name: TableScanNode, key: .TableScanNode }
159160
- { name: TableWriterNode, key: .TableWriterNode }
160161
- { name: TableWriterMergeNode, key: com.facebook.presto.sql.planner.plan.TableWriterMergeNode }
@@ -360,6 +361,7 @@ JavaClasses:
360361
- presto-spi/src/main/java/com/facebook/presto/spi/plan/JoinNode.java
361362
- presto-spi/src/main/java/com/facebook/presto/spi/plan/SemiJoinNode.java
362363
- presto-spi/src/main/java/com/facebook/presto/spi/plan/MergeJoinNode.java
364+
- presto-spi/src/main/java/com/facebook/presto/spi/plan/SpatialJoinNode.java
363365
- presto-spi/src/main/java/com/facebook/presto/spi/plan/TopNNode.java
364366
- presto-hive/src/main/java/com/facebook/presto/hive/HivePartitioningHandle.java
365367
- presto-main/src/main/java/com/facebook/presto/split/EmptySplit.java

0 commit comments

Comments
 (0)