Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 2 additions & 33 deletions src/daft-distributed/src/pipeline_node/join/hash_join.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
use std::{cmp::max, sync::Arc};
use std::sync::Arc;

use common_daft_config::DaftExecutionConfig;
use common_display::{tree::TreeDisplay, DisplayLevel};
use daft_dsl::expr::bound_expr::BoundExpr;
use daft_local_plan::LocalPhysicalPlan;
use daft_logical_plan::{
partitioning::HashClusteringConfig, stats::StatsState, ClusteringSpec, JoinType,
};
use daft_logical_plan::{partitioning::HashClusteringConfig, stats::StatsState, JoinType};
use daft_schema::schema::SchemaRef;
use futures::StreamExt;

Expand Down Expand Up @@ -197,31 +194,3 @@ impl DistributedPipelineNode for HashJoinNode {
self
}
}

pub(crate) fn gen_num_partitions(
left_spec: &ClusteringSpec,
right_spec: &ClusteringSpec,
cfg: &DaftExecutionConfig,
) -> usize {
let is_left_hash_partitioned = matches!(left_spec, ClusteringSpec::Hash(_));
let is_right_hash_partitioned = matches!(right_spec, ClusteringSpec::Hash(_));
let num_left_partitions = left_spec.num_partitions();
let num_right_partitions = right_spec.num_partitions();

match (
is_left_hash_partitioned,
is_right_hash_partitioned,
num_left_partitions,
num_right_partitions,
) {
(true, true, a, b) | (false, false, a, b) => max(a, b),
(_, _, 1, x) | (_, _, x, 1) => x,
(true, false, a, b) if (a as f64) >= (b as f64) * cfg.hash_join_partition_size_leniency => {
a
}
(false, true, a, b) if (b as f64) >= (a as f64) * cfg.hash_join_partition_size_leniency => {
b
}
(_, _, a, b) => max(a, b),
}
}
98 changes: 71 additions & 27 deletions src/daft-distributed/src/pipeline_node/join/translate_join.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
use std::sync::Arc;
use std::{cmp::max, sync::Arc};

use common_error::DaftResult;
use daft_dsl::{expr::bound_expr::BoundExpr, ExprRef};
use daft_dsl::{expr::bound_expr::BoundExpr, is_partition_compatible, ExprRef};
use daft_logical_plan::{
ops::Join,
partitioning::{HashRepartitionConfig, RepartitionSpec},
stats::ApproxStats,
JoinStrategy, JoinType,
ClusteringSpec, JoinStrategy, JoinType,
};
use daft_schema::schema::SchemaRef;

use super::hash_join::gen_num_partitions;
use crate::pipeline_node::{
join::{BroadcastJoinNode, HashJoinNode},
translate::LogicalPlanToPipelineNodeTranslator,
Expand Down Expand Up @@ -75,31 +74,76 @@ impl LogicalPlanToPipelineNodeTranslator {
join_type: JoinType,
output_schema: SchemaRef,
) -> DaftResult<Arc<dyn DistributedPipelineNode>> {
let num_partitions = gen_num_partitions(
left.config().clustering_spec.as_ref(),
right.config().clustering_spec.as_ref(),
self.stage_config.config.as_ref(),
);
let left_spec = left.config().clustering_spec.as_ref();
let right_spec = right.config().clustering_spec.as_ref();

let left = self.gen_shuffle_node(
logical_node_id,
RepartitionSpec::Hash(HashRepartitionConfig::new(
Some(num_partitions),
left_on.iter().map(|e| e.clone().into()).collect(),
)),
left.config().schema.clone(),
left,
)?;
let is_left_hash_partitioned = matches!(left_spec, ClusteringSpec::Hash(..))
&& is_partition_compatible(
&left_spec.partition_by(),
left_on.iter().map(|e| e.inner()),
);
let is_right_hash_partitioned = matches!(right_spec, ClusteringSpec::Hash(..))
&& is_partition_compatible(
&right_spec.partition_by(),
right_on.iter().map(|e| e.inner()),
);
let num_left_partitions = left_spec.num_partitions();
let num_right_partitions = right_spec.num_partitions();

let right = self.gen_shuffle_node(
logical_node_id,
RepartitionSpec::Hash(HashRepartitionConfig::new(
Some(num_partitions),
right_on.iter().map(|e| e.clone().into()).collect(),
)),
right.config().schema.clone(),
right,
)?;
let num_partitions = match (
is_left_hash_partitioned,
is_right_hash_partitioned,
num_left_partitions,
num_right_partitions,
) {
(true, true, a, b) | (false, false, a, b) => max(a, b),
(_, _, 1, x) | (_, _, x, 1) => x,
(true, false, a, b)
if (a as f64)
>= (b as f64) * self.stage_config.config.hash_join_partition_size_leniency =>
{
a
}
(false, true, a, b)
if (b as f64)
>= (a as f64) * self.stage_config.config.hash_join_partition_size_leniency =>
{
b
}
(_, _, a, b) => max(a, b),
};

let left = if num_left_partitions != num_partitions
|| (num_partitions > 1 && !is_left_hash_partitioned)
{
self.gen_shuffle_node(
logical_node_id,
RepartitionSpec::Hash(HashRepartitionConfig::new(
Some(num_partitions),
left_on.iter().map(|e| e.clone().into()).collect(),
)),
left.config().schema.clone(),
left,
)?
} else {
left
};

let right = if num_right_partitions != num_partitions
|| (num_partitions > 1 && !is_right_hash_partitioned)
{
self.gen_shuffle_node(
logical_node_id,
RepartitionSpec::Hash(HashRepartitionConfig::new(
Some(num_partitions),
right_on.iter().map(|e| e.clone().into()).collect(),
)),
right.config().schema.clone(),
right,
)?
} else {
right
};

Ok(HashJoinNode::new(
self.get_next_pipeline_node_id(),
Expand Down
Loading