
Commit 00386c9

feat: Flotilla broadcast join (#4867)
## Changes Made

Add broadcast joins to Flotilla. The translation logic for determining the join strategy and the broadcast join size threshold is copied from the old Ray runner's physical plan translator. The broadcast join works by materializing the partition refs for the broadcast side, copying them into each of the receiver side's tasks (the "broadcast"), and then adding a hash join instruction. A minimal sketch of this flow follows the checklist.

## Related Issues

<!-- Link to related GitHub issues, e.g., "Closes #123" -->

## Checklist

- [ ] Documented in API Docs (if applicable)
- [ ] Documented in User Guide (if applicable)
- [ ] If adding a new documentation page, doc is added to `docs/mkdocs.yml` navigation
- [ ] Documentation builds and is formatted properly (tag @/ccmao1130 for docs review)
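Below is a minimal, self-contained sketch of the broadcast mechanism described above: materialize the broadcast side, then hand every receiver task a full copy. `Partition` and `Task` are toy stand-ins for illustration, not Daft types, and the real implementation attaches a hash-join plan instruction rather than storing raw rows:

```rust
// Toy illustration of the broadcast step: every receiver task gets a
// full copy of the materialized broadcast-side partitions.
#[derive(Clone, Debug)]
struct Partition(Vec<(i64, String)>);

#[derive(Debug)]
struct Task {
    receiver_side: Partition,
    broadcast_side: Vec<Partition>, // each task holds the full broadcast data
}

fn broadcast_join_tasks(broadcast: Vec<Partition>, receivers: Vec<Partition>) -> Vec<Task> {
    receivers
        .into_iter()
        .map(|r| Task {
            receiver_side: r,
            broadcast_side: broadcast.clone(), // the "broadcast": copy to every task
        })
        .collect()
}

fn main() {
    let broadcast = vec![Partition(vec![(1, "a".into())])];
    let receivers = vec![
        Partition(vec![(1, "x".into())]),
        Partition(vec![(2, "y".into())]),
    ];
    let tasks = broadcast_join_tasks(broadcast, receivers);
    assert_eq!(tasks.len(), 2); // one task per receiver partition
    println!("{tasks:?}");
}
```

One task is emitted per receiver partition, which mirrors how `BroadcastJoinNode::execution_loop` in the diff below rewrites each incoming receiver task.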
1 parent 40e3818 commit 00386c9

File tree

10 files changed (+539 −122 lines)

src/daft-distributed/src/pipeline_node/gather.rs

Lines changed: 2 additions & 2 deletions
```diff
@@ -9,7 +9,7 @@ use futures::TryStreamExt;
 use super::{DistributedPipelineNode, SubmittableTaskStream};
 use crate::{
     pipeline_node::{
-        make_in_memory_scan_from_materialized_outputs, NodeID, NodeName, PipelineNodeConfig,
+        make_in_memory_task_from_materialized_outputs, NodeID, NodeName, PipelineNodeConfig,
         PipelineNodeContext,
     },
     scheduling::{
@@ -80,7 +80,7 @@ impl GatherNode
             .await?;

         let self_clone = self.clone();
-        let task = make_in_memory_scan_from_materialized_outputs(
+        let task = make_in_memory_task_from_materialized_outputs(
             TaskContext::from((&self_clone.context, task_id_counter.next())),
             materialized,
             &(self_clone as Arc<dyn DistributedPipelineNode>),
```
src/daft-distributed/src/pipeline_node/join/broadcast_join.rs

Lines changed: 234 additions & 0 deletions
```rust
use std::{collections::HashMap, sync::Arc};

use common_display::{tree::TreeDisplay, DisplayLevel};
use common_error::DaftResult;
use daft_dsl::expr::bound_expr::BoundExpr;
use daft_local_plan::LocalPhysicalPlan;
use daft_logical_plan::{stats::StatsState, JoinType};
use daft_schema::schema::SchemaRef;
use futures::{StreamExt, TryStreamExt};

use crate::{
    pipeline_node::{
        make_in_memory_scan_from_materialized_outputs, DistributedPipelineNode, NodeID, NodeName,
        PipelineNodeConfig, PipelineNodeContext, SubmittableTaskStream,
    },
    scheduling::{
        scheduler::{SchedulerHandle, SubmittableTask},
        task::{SchedulingStrategy, SwordfishTask, TaskContext},
    },
    stage::{StageConfig, StageExecutionContext, TaskIDCounter},
    utils::channel::{create_channel, Sender},
};

pub(crate) struct BroadcastJoinNode {
    config: PipelineNodeConfig,
    context: PipelineNodeContext,

    // Join properties
    left_on: Vec<BoundExpr>,
    right_on: Vec<BoundExpr>,
    null_equals_nulls: Option<Vec<bool>>,
    join_type: JoinType,
    is_swapped: bool,

    broadcaster: Arc<dyn DistributedPipelineNode>,
    broadcaster_schema: SchemaRef,
    receiver: Arc<dyn DistributedPipelineNode>,
}

impl BroadcastJoinNode {
    const NODE_NAME: NodeName = "BroadcastJoin";

    #[allow(clippy::too_many_arguments)]
    pub fn new(
        node_id: NodeID,
        logical_node_id: Option<NodeID>,
        stage_config: &StageConfig,
        left_on: Vec<BoundExpr>,
        right_on: Vec<BoundExpr>,
        null_equals_nulls: Option<Vec<bool>>,
        join_type: JoinType,
        is_swapped: bool,
        broadcaster: Arc<dyn DistributedPipelineNode>,
        receiver: Arc<dyn DistributedPipelineNode>,
        output_schema: SchemaRef,
    ) -> Self {
        let context = PipelineNodeContext::new(
            stage_config,
            node_id,
            Self::NODE_NAME,
            vec![broadcaster.node_id(), receiver.node_id()],
            vec![broadcaster.name(), receiver.name()],
            logical_node_id,
        );

        // For broadcast joins, we use the receiver's clustering spec since the broadcaster
        // will be gathered to all partitions
        let config = PipelineNodeConfig::new(
            output_schema,
            stage_config.config.clone(),
            receiver.config().clustering_spec.clone(),
        );

        let broadcaster_schema = broadcaster.config().schema.clone();
        Self {
            config,
            context,
            left_on,
            right_on,
            null_equals_nulls,
            join_type,
            is_swapped,
            broadcaster,
            broadcaster_schema,
            receiver,
        }
    }

    pub fn arced(self) -> Arc<dyn DistributedPipelineNode> {
        Arc::new(self)
    }

    fn multiline_display(&self) -> Vec<String> {
        use itertools::Itertools;
        let mut res = vec!["Broadcast Join".to_string()];
        res.push(format!(
            "Left on: {}",
            self.left_on.iter().map(|e| e.to_string()).join(", ")
        ));
        res.push(format!(
            "Right on: {}",
            self.right_on.iter().map(|e| e.to_string()).join(", ")
        ));
        res.push(format!("Join type: {}", self.join_type));
        res.push(format!("Is swapped: {}", self.is_swapped));
        if let Some(null_equals_nulls) = &self.null_equals_nulls {
            res.push(format!(
                "Null equals nulls: [{}]",
                null_equals_nulls.iter().map(|b| b.to_string()).join(", ")
            ));
        }
        res
    }

    async fn execution_loop(
        self: Arc<Self>,
        broadcaster_input: SubmittableTaskStream,
        mut receiver_input: SubmittableTaskStream,
        task_id_counter: TaskIDCounter,
        result_tx: Sender<SubmittableTask<SwordfishTask>>,
        scheduler_handle: SchedulerHandle<SwordfishTask>,
    ) -> DaftResult<()> {
        let materialized_broadcast_data = broadcaster_input
            .materialize(scheduler_handle.clone())
            .try_collect::<Vec<_>>()
            .await?;
        let materialized_broadcast_data_plan = make_in_memory_scan_from_materialized_outputs(
            &materialized_broadcast_data,
            self.broadcaster_schema.clone(),
            self.node_id(),
        )?;
        let broadcast_psets = HashMap::from([(
            self.node_id().to_string(),
            materialized_broadcast_data
                .into_iter()
                .flat_map(|output| output.into_inner().0)
                .collect::<Vec<_>>(),
        )]);
        while let Some(task) = receiver_input.next().await {
            let input_plan = task.task().plan();
            let (left_plan, right_plan) = if self.is_swapped {
                (input_plan, materialized_broadcast_data_plan.clone())
            } else {
                (materialized_broadcast_data_plan.clone(), input_plan)
            };
            let join_plan = LocalPhysicalPlan::hash_join(
                left_plan,
                right_plan,
                self.left_on.clone(),
                self.right_on.clone(),
                self.null_equals_nulls.clone(),
                self.join_type,
                self.config.schema.clone(),
                StatsState::NotMaterialized,
            );

            let mut psets = task.task().psets().clone();
            psets.extend(broadcast_psets.clone());

            let config = task.task().config().clone();

            let task = task.with_new_task(SwordfishTask::new(
                TaskContext::from((self.context(), task_id_counter.next())),
                join_plan,
                config,
                psets,
                SchedulingStrategy::Spread,
                self.context().to_hashmap(),
            ));
            if result_tx.send(task).await.is_err() {
                break;
            }
        }
        Ok(())
    }
}

impl TreeDisplay for BroadcastJoinNode {
    fn display_as(&self, level: DisplayLevel) -> String {
        match level {
            DisplayLevel::Compact => self.get_name(),
            _ => self.multiline_display().join("\n"),
        }
    }

    fn get_children(&self) -> Vec<&dyn TreeDisplay> {
        vec![
            self.broadcaster.as_tree_display(),
            self.receiver.as_tree_display(),
        ]
    }

    fn get_name(&self) -> String {
        Self::NODE_NAME.to_string()
    }
}

impl DistributedPipelineNode for BroadcastJoinNode {
    fn context(&self) -> &PipelineNodeContext {
        &self.context
    }

    fn config(&self) -> &PipelineNodeConfig {
        &self.config
    }

    fn children(&self) -> Vec<Arc<dyn DistributedPipelineNode>> {
        vec![self.broadcaster.clone(), self.receiver.clone()]
    }

    fn produce_tasks(
        self: Arc<Self>,
        stage_context: &mut StageExecutionContext,
    ) -> SubmittableTaskStream {
        let broadcaster_input = self.broadcaster.clone().produce_tasks(stage_context);
        let receiver_input = self.receiver.clone().produce_tasks(stage_context);

        let (result_tx, result_rx) = create_channel(1);
        let execution_loop = self.execution_loop(
            broadcaster_input,
            receiver_input,
            stage_context.task_id_counter(),
            result_tx,
            stage_context.scheduler_handle(),
        );
        stage_context.spawn(execution_loop);

        SubmittableTaskStream::from(result_rx)
    }

    fn as_tree_display(&self) -> &dyn TreeDisplay {
        self
    }
}
```
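The loop in `execution_loop` above does two things to every receiver task: it picks which side of the hash join the materialized broadcast plan feeds (via `is_swapped`), and it merges the broadcast partition refs into the task's psets. Here is a self-contained sketch of those two steps, using toy `Plan`/`Task` types as stand-ins for Daft's `LocalPhysicalPlan` and `SwordfishTask`:

```rust
use std::collections::HashMap;

// Toy stand-ins for Daft's plan and task types; illustration only.
#[derive(Clone, Debug)]
enum Plan {
    InMemoryScan(String), // name of the pset this scan reads
    HashJoin { left: Box<Plan>, right: Box<Plan> },
}

#[derive(Debug)]
struct Task {
    plan: Plan,
    psets: HashMap<String, Vec<u64>>, // pset name -> partition refs
}

fn rewrite_task(
    task: Task,
    broadcast_plan: &Plan,
    broadcast_psets: &HashMap<String, Vec<u64>>,
    is_swapped: bool,
) -> Task {
    // 1. Side selection: the broadcast plan is the left input unless swapped.
    let (left, right) = if is_swapped {
        (task.plan, broadcast_plan.clone())
    } else {
        (broadcast_plan.clone(), task.plan)
    };
    // 2. Merge the broadcast partition refs into the task's psets so the
    //    broadcast side's in-memory scan can resolve them at execution time.
    let mut psets = task.psets;
    psets.extend(broadcast_psets.clone());
    Task {
        plan: Plan::HashJoin { left: Box::new(left), right: Box::new(right) },
        psets,
    }
}

fn main() {
    let broadcast_plan = Plan::InMemoryScan("broadcast".into());
    let broadcast_psets = HashMap::from([("broadcast".to_string(), vec![1, 2])]);
    let task = Task {
        plan: Plan::InMemoryScan("receiver".into()),
        psets: HashMap::from([("receiver".to_string(), vec![3])]),
    };
    let rewritten = rewrite_task(task, &broadcast_plan, &broadcast_psets, false);
    assert_eq!(rewritten.psets.len(), 2); // receiver psets plus broadcast psets
    println!("{rewritten:?}");
}
```

Keying the broadcast psets under a single shared name, as the real code does with `self.node_id().to_string()`, lets every rewritten task resolve the same materialized partitions.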

src/daft-distributed/src/pipeline_node/hash_join.rs renamed to src/daft-distributed/src/pipeline_node/join/hash_join.rs

Lines changed: 5 additions & 84 deletions
```diff
@@ -2,21 +2,18 @@ use std::{cmp::max, sync::Arc};

 use common_daft_config::DaftExecutionConfig;
 use common_display::{tree::TreeDisplay, DisplayLevel};
-use common_error::DaftResult;
-use daft_dsl::{expr::bound_expr::BoundExpr, join::normalize_join_keys};
+use daft_dsl::expr::bound_expr::BoundExpr;
 use daft_local_plan::LocalPhysicalPlan;
 use daft_logical_plan::{
-    ops::join::JoinPredicate, partitioning::HashClusteringConfig, stats::StatsState,
-    ClusteringSpec, JoinType,
+    partitioning::HashClusteringConfig, stats::StatsState, ClusteringSpec, JoinType,
 };
 use daft_schema::schema::SchemaRef;
 use futures::StreamExt;

-use super::{DistributedPipelineNode, SubmittableTaskStream};
 use crate::{
     pipeline_node::{
-        repartition::RepartitionNode, translate::LogicalPlanToPipelineNodeTranslator, NodeID,
-        NodeName, PipelineNodeConfig, PipelineNodeContext,
+        DistributedPipelineNode, NodeID, NodeName, PipelineNodeConfig, PipelineNodeContext,
+        SubmittableTaskStream,
     },
     scheduling::{
         scheduler::SubmittableTask,
@@ -201,7 +198,7 @@ impl DistributedPipelineNode for HashJoinNode {
     }
 }

-fn gen_num_partitions(
+pub(crate) fn gen_num_partitions(
     left_spec: &ClusteringSpec,
     right_spec: &ClusteringSpec,
     cfg: &DaftExecutionConfig,
@@ -228,79 +225,3 @@ fn gen_num_partitions(
         (_, _, a, b) => max(a, b),
     }
 }
-
-impl LogicalPlanToPipelineNodeTranslator {
-    pub(crate) fn gen_hash_join_nodes(
-        &mut self,
-        logical_node_id: Option<NodeID>,
-        join_on: JoinPredicate,
-
-        left: Arc<dyn DistributedPipelineNode>,
-        right: Arc<dyn DistributedPipelineNode>,
-
-        join_type: JoinType,
-        output_schema: SchemaRef,
-    ) -> DaftResult<Arc<dyn DistributedPipelineNode>> {
-        let (_, left_on, right_on, null_equals_nulls) = join_on.split_eq_preds();
-
-        let (left_on, right_on) = normalize_join_keys(
-            left_on,
-            right_on,
-            left.config().schema.clone(),
-            right.config().schema.clone(),
-        )?;
-        let left_on = BoundExpr::bind_all(&left_on, &left.config().schema)?;
-        let right_on = BoundExpr::bind_all(&right_on, &right.config().schema)?;
-
-        let num_partitions = gen_num_partitions(
-            left.config().clustering_spec.as_ref(),
-            right.config().clustering_spec.as_ref(),
-            self.stage_config.config.as_ref(),
-        );
-
-        let left = RepartitionNode::new(
-            self.get_next_pipeline_node_id(),
-            logical_node_id,
-            &self.stage_config,
-            daft_logical_plan::partitioning::RepartitionSpec::Hash(
-                daft_logical_plan::partitioning::HashRepartitionConfig::new(
-                    Some(num_partitions),
-                    left_on.clone().into_iter().map(|e| e.into()).collect(),
-                ),
-            ),
-            left.config().schema.clone(),
-            left,
-        )
-        .arced();
-
-        let right = RepartitionNode::new(
-            self.get_next_pipeline_node_id(),
-            logical_node_id,
-            &self.stage_config,
-            daft_logical_plan::partitioning::RepartitionSpec::Hash(
-                daft_logical_plan::partitioning::HashRepartitionConfig::new(
-                    Some(num_partitions),
-                    right_on.clone().into_iter().map(|e| e.into()).collect(),
-                ),
-            ),
-            right.config().schema.clone(),
-            right,
-        )
-        .arced();
-
-        Ok(HashJoinNode::new(
-            self.get_next_pipeline_node_id(),
-            logical_node_id,
-            &self.stage_config,
-            left_on,
-            right_on,
-            Some(null_equals_nulls),
-            join_type,
-            num_partitions,
-            left,
-            right,
-            output_schema,
-        )
-        .arced())
-    }
-}
```
src/daft-distributed/src/pipeline_node/join/mod.rs

Lines changed: 6 additions & 0 deletions
```rust
mod broadcast_join;
pub(crate) mod hash_join;
pub(crate) mod translate_join;

pub(crate) use broadcast_join::BroadcastJoinNode;
pub(crate) use hash_join::HashJoinNode;
```
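The `translate_join` module declared above carries the strategy-selection logic that the PR description says was ported from the old Ray runner's physical plan translator. That code is not shown in this diff; the following is purely a hypothetical sketch (the name `choose_join_strategy`, its signature, and the handling of unknown sizes are all assumptions) of what such a decision rule amounts to — broadcast the smaller side when its estimated size falls under the configured threshold:

```rust
// Hypothetical sketch of a broadcast-vs-hash strategy choice, as summarized
// in the PR description. Not the actual Daft code in translate_join.rs.
#[derive(Debug, PartialEq)]
enum JoinStrategy {
    Broadcast { is_swapped: bool },
    Hash,
}

fn choose_join_strategy(
    left_size_bytes: Option<usize>,
    right_size_bytes: Option<usize>,
    broadcast_threshold: usize,
) -> JoinStrategy {
    match (left_size_bytes, right_size_bytes) {
        // Broadcast the smaller side if its estimated size is under the threshold.
        (Some(l), Some(r)) if l.min(r) <= broadcast_threshold => {
            // In this sketch, "swapped" means the right input is the one broadcast.
            JoinStrategy::Broadcast { is_swapped: r < l }
        }
        // Fall back to a repartitioning hash join when sizes are unknown or large.
        _ => JoinStrategy::Hash,
    }
}

fn main() {
    assert_eq!(
        choose_join_strategy(Some(10 << 20), Some(1 << 10), 10 << 20),
        JoinStrategy::Broadcast { is_swapped: true }
    );
    assert_eq!(
        choose_join_strategy(None, Some(1 << 10), 10 << 20),
        JoinStrategy::Hash
    );
}
```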
