|
| 1 | +use std::{collections::HashMap, sync::Arc}; |
| 2 | + |
| 3 | +use common_display::{tree::TreeDisplay, DisplayLevel}; |
| 4 | +use common_error::DaftResult; |
| 5 | +use daft_dsl::expr::bound_expr::BoundExpr; |
| 6 | +use daft_local_plan::LocalPhysicalPlan; |
| 7 | +use daft_logical_plan::{stats::StatsState, JoinType}; |
| 8 | +use daft_schema::schema::SchemaRef; |
| 9 | +use futures::{StreamExt, TryStreamExt}; |
| 10 | + |
| 11 | +use crate::{ |
| 12 | + pipeline_node::{ |
| 13 | + make_in_memory_scan_from_materialized_outputs, DistributedPipelineNode, NodeID, NodeName, |
| 14 | + PipelineNodeConfig, PipelineNodeContext, SubmittableTaskStream, |
| 15 | + }, |
| 16 | + scheduling::{ |
| 17 | + scheduler::{SchedulerHandle, SubmittableTask}, |
| 18 | + task::{SchedulingStrategy, SwordfishTask, TaskContext}, |
| 19 | + }, |
| 20 | + stage::{StageConfig, StageExecutionContext, TaskIDCounter}, |
| 21 | + utils::channel::{create_channel, Sender}, |
| 22 | +}; |
| 23 | + |
/// Distributed pipeline node implementing a broadcast hash join: the
/// `broadcaster` child's output is fully materialized and shipped to every
/// task produced by the `receiver` child, which then performs the join
/// locally (see `execution_loop`).
pub(crate) struct BroadcastJoinNode {
    config: PipelineNodeConfig,
    context: PipelineNodeContext,

    // Join properties
    left_on: Vec<BoundExpr>,
    right_on: Vec<BoundExpr>,
    null_equals_nulls: Option<Vec<bool>>,
    join_type: JoinType,
    // When true, the receiver's plan becomes the LEFT input of the generated
    // hash join and the broadcast data the RIGHT; when false, the reverse.
    is_swapped: bool,

    // Child whose entire output is materialized and broadcast to all tasks
    // (presumably the smaller side — chosen by the planner, not enforced here).
    broadcaster: Arc<dyn DistributedPipelineNode>,
    // Schema of the broadcaster's output, captured at construction time and
    // used to build the in-memory scan over the materialized data.
    broadcaster_schema: SchemaRef,
    // Child whose task stream is consumed one task at a time and rewritten
    // into join tasks.
    receiver: Arc<dyn DistributedPipelineNode>,
}
| 39 | + |
impl BroadcastJoinNode {
    const NODE_NAME: NodeName = "BroadcastJoin";

    /// Constructs a broadcast-join pipeline node.
    ///
    /// `broadcaster` is the side whose output will be materialized and shipped
    /// to every receiver task; `receiver` is the side whose tasks are streamed
    /// through and rewritten into hash-join tasks. `is_swapped` records which
    /// side the planner placed on the left of the join (see `execution_loop`).
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        node_id: NodeID,
        logical_node_id: Option<NodeID>,
        stage_config: &StageConfig,
        left_on: Vec<BoundExpr>,
        right_on: Vec<BoundExpr>,
        null_equals_nulls: Option<Vec<bool>>,
        join_type: JoinType,
        is_swapped: bool,
        broadcaster: Arc<dyn DistributedPipelineNode>,
        receiver: Arc<dyn DistributedPipelineNode>,
        output_schema: SchemaRef,
    ) -> Self {
        let context = PipelineNodeContext::new(
            stage_config,
            node_id,
            Self::NODE_NAME,
            vec![broadcaster.node_id(), receiver.node_id()],
            vec![broadcaster.name(), receiver.name()],
            logical_node_id,
        );

        // For broadcast joins, we use the receiver's clustering spec since the broadcaster
        // will be gathered to all partitions
        let config = PipelineNodeConfig::new(
            output_schema,
            stage_config.config.clone(),
            receiver.config().clustering_spec.clone(),
        );

        // Capture the broadcaster's schema now so execution_loop can build the
        // in-memory scan without going back through the child node.
        let broadcaster_schema = broadcaster.config().schema.clone();
        Self {
            config,
            context,
            left_on,
            right_on,
            null_equals_nulls,
            join_type,
            is_swapped,
            broadcaster,
            broadcaster_schema,
            receiver,
        }
    }

    /// Wraps `self` in an `Arc` as the trait-object type the planner passes around.
    pub fn arced(self) -> Arc<dyn DistributedPipelineNode> {
        Arc::new(self)
    }

    /// Multi-line, human-readable description of this node used by `TreeDisplay`.
    fn multiline_display(&self) -> Vec<String> {
        use itertools::Itertools;
        let mut res = vec!["Broadcast Join".to_string()];
        res.push(format!(
            "Left on: {}",
            self.left_on.iter().map(|e| e.to_string()).join(", ")
        ));
        res.push(format!(
            "Right on: {}",
            self.right_on.iter().map(|e| e.to_string()).join(", ")
        ));
        res.push(format!("Join type: {}", self.join_type));
        res.push(format!("Is swapped: {}", self.is_swapped));
        if let Some(null_equals_nulls) = &self.null_equals_nulls {
            res.push(format!(
                "Null equals nulls: [{}]",
                null_equals_nulls.iter().map(|b| b.to_string()).join(", ")
            ));
        }
        res
    }

    /// Drives the join:
    /// 1. Fully materialize the broadcast side.
    /// 2. For each incoming receiver task, splice the materialized broadcast
    ///    data and the task's plan into a hash-join plan and emit a rewritten
    ///    task on `result_tx`.
    async fn execution_loop(
        self: Arc<Self>,
        broadcaster_input: SubmittableTaskStream,
        mut receiver_input: SubmittableTaskStream,
        task_id_counter: TaskIDCounter,
        result_tx: Sender<SubmittableTask<SwordfishTask>>,
        scheduler_handle: SchedulerHandle<SwordfishTask>,
    ) -> DaftResult<()> {
        // The complete build side must exist before any probe task can run,
        // so collect the whole broadcaster stream up front.
        let materialized_broadcast_data = broadcaster_input
            .materialize(scheduler_handle.clone())
            .try_collect::<Vec<_>>()
            .await?;
        // In-memory scan over the materialized partitions, keyed by this
        // node's id; the join plans built below reference it by that key.
        let materialized_broadcast_data_plan = make_in_memory_scan_from_materialized_outputs(
            &materialized_broadcast_data,
            self.broadcaster_schema.clone(),
            self.node_id(),
        )?;
        // Partition-set map attached to every rewritten task; the key must
        // match the in-memory scan created above.
        let broadcast_psets = HashMap::from([(
            self.node_id().to_string(),
            materialized_broadcast_data
                .into_iter()
                .flat_map(|output| output.into_inner().0)
                .collect::<Vec<_>>(),
        )]);
        while let Some(task) = receiver_input.next().await {
            let input_plan = task.task().plan();
            // Preserve the planner's original left/right orientation:
            // `is_swapped` means the broadcast side belongs on the right.
            let (left_plan, right_plan) = if self.is_swapped {
                (input_plan, materialized_broadcast_data_plan.clone())
            } else {
                (materialized_broadcast_data_plan.clone(), input_plan)
            };
            let join_plan = LocalPhysicalPlan::hash_join(
                left_plan,
                right_plan,
                self.left_on.clone(),
                self.right_on.clone(),
                self.null_equals_nulls.clone(),
                self.join_type,
                self.config.schema.clone(),
                StatsState::NotMaterialized,
            );

            // Merge the task's own partition sets with the broadcast data so
            // both scans in the join plan can resolve.
            let mut psets = task.task().psets().clone();
            psets.extend(broadcast_psets.clone());

            let config = task.task().config().clone();

            let task = task.with_new_task(SwordfishTask::new(
                TaskContext::from((self.context(), task_id_counter.next())),
                join_plan,
                config,
                psets,
                // Spread scheduling: the broadcast data travels with each
                // task's psets, so there is no locality constraint here.
                SchedulingStrategy::Spread,
                self.context().to_hashmap(),
            ));
            // A closed receiver means the consumer no longer wants tasks;
            // stop producing rather than error out.
            if result_tx.send(task).await.is_err() {
                break;
            }
        }
        Ok(())
    }
}
| 177 | + |
| 178 | +impl TreeDisplay for BroadcastJoinNode { |
| 179 | + fn display_as(&self, level: DisplayLevel) -> String { |
| 180 | + match level { |
| 181 | + DisplayLevel::Compact => self.get_name(), |
| 182 | + _ => self.multiline_display().join("\n"), |
| 183 | + } |
| 184 | + } |
| 185 | + |
| 186 | + fn get_children(&self) -> Vec<&dyn TreeDisplay> { |
| 187 | + vec![ |
| 188 | + self.broadcaster.as_tree_display(), |
| 189 | + self.receiver.as_tree_display(), |
| 190 | + ] |
| 191 | + } |
| 192 | + |
| 193 | + fn get_name(&self) -> String { |
| 194 | + Self::NODE_NAME.to_string() |
| 195 | + } |
| 196 | +} |
| 197 | + |
| 198 | +impl DistributedPipelineNode for BroadcastJoinNode { |
| 199 | + fn context(&self) -> &PipelineNodeContext { |
| 200 | + &self.context |
| 201 | + } |
| 202 | + |
| 203 | + fn config(&self) -> &PipelineNodeConfig { |
| 204 | + &self.config |
| 205 | + } |
| 206 | + |
| 207 | + fn children(&self) -> Vec<Arc<dyn DistributedPipelineNode>> { |
| 208 | + vec![self.broadcaster.clone(), self.receiver.clone()] |
| 209 | + } |
| 210 | + |
| 211 | + fn produce_tasks( |
| 212 | + self: Arc<Self>, |
| 213 | + stage_context: &mut StageExecutionContext, |
| 214 | + ) -> SubmittableTaskStream { |
| 215 | + let broadcaster_input = self.broadcaster.clone().produce_tasks(stage_context); |
| 216 | + let receiver_input = self.receiver.clone().produce_tasks(stage_context); |
| 217 | + |
| 218 | + let (result_tx, result_rx) = create_channel(1); |
| 219 | + let execution_loop = self.execution_loop( |
| 220 | + broadcaster_input, |
| 221 | + receiver_input, |
| 222 | + stage_context.task_id_counter(), |
| 223 | + result_tx, |
| 224 | + stage_context.scheduler_handle(), |
| 225 | + ); |
| 226 | + stage_context.spawn(execution_loop); |
| 227 | + |
| 228 | + SubmittableTaskStream::from(result_rx) |
| 229 | + } |
| 230 | + |
| 231 | + fn as_tree_display(&self) -> &dyn TreeDisplay { |
| 232 | + self |
| 233 | + } |
| 234 | +} |
0 commit comments