Skip to content

Commit d23ed2a

Browse files
thinkharderdev authored and avantgardnerio committed
Fix bug in swap_hash_join (#278)
* Try and fix swap_hash_join
* Only swap projections when join does not have projections
* just backport upstream fix
* remove println
1 parent 8a222b6 commit d23ed2a

File tree

1 file changed

+60
-61
lines changed

1 file changed

+60
-61
lines changed

datafusion/core/src/physical_optimizer/join_selection.rs

Lines changed: 60 additions & 61 deletions
Original file line number | Diff line number | Diff line change
@@ -23,8 +23,6 @@
2323
//! pipeline-friendly ones. To achieve the second goal, it selects the proper
2424
//! `PartitionMode` and the build side using the available statistics for hash joins.
2525
26-
use std::sync::Arc;
27-
2826
use crate::config::ConfigOptions;
2927
use crate::error::Result;
3028
use crate::physical_plan::joins::utils::{ColumnIndex, JoinFilter};
@@ -34,6 +32,7 @@ use crate::physical_plan::joins::{
3432
};
3533
use crate::physical_plan::projection::ProjectionExec;
3634
use crate::physical_plan::{ExecutionPlan, ExecutionPlanProperties};
35+
use std::sync::Arc;
3736

3837
use arrow_schema::Schema;
3938
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
@@ -1178,6 +1177,65 @@ mod tests_statistical {
11781177
);
11791178
}
11801179

1180+
#[rstest(
1181+
join_type, projection, small_on_right,
1182+
case::inner(JoinType::Inner, vec![1], true),
1183+
case::left(JoinType::Left, vec![1], true),
1184+
case::right(JoinType::Right, vec![1], true),
1185+
case::full(JoinType::Full, vec![1], true),
1186+
case::left_anti(JoinType::LeftAnti, vec![0], false),
1187+
case::left_semi(JoinType::LeftSemi, vec![0], false),
1188+
case::right_anti(JoinType::RightAnti, vec![0], true),
1189+
case::right_semi(JoinType::RightSemi, vec![0], true),
1190+
)]
1191+
#[tokio::test]
1192+
async fn test_hash_join_swap_on_joins_with_projections(
1193+
join_type: JoinType,
1194+
projection: Vec<usize>,
1195+
small_on_right: bool,
1196+
) -> Result<()> {
1197+
let (big, small) = create_big_and_small();
1198+
1199+
let left = if small_on_right { &big } else { &small };
1200+
let right = if small_on_right { &small } else { &big };
1201+
1202+
let left_on = if small_on_right {
1203+
"big_col"
1204+
} else {
1205+
"small_col"
1206+
};
1207+
let right_on = if small_on_right {
1208+
"small_col"
1209+
} else {
1210+
"big_col"
1211+
};
1212+
1213+
let join = Arc::new(HashJoinExec::try_new(
1214+
Arc::clone(left),
1215+
Arc::clone(right),
1216+
vec![(
1217+
Arc::new(Column::new_with_schema(left_on, &left.schema())?),
1218+
Arc::new(Column::new_with_schema(right_on, &right.schema())?),
1219+
)],
1220+
None,
1221+
&join_type,
1222+
Some(projection),
1223+
PartitionMode::Partitioned,
1224+
false,
1225+
)?);
1226+
1227+
let swapped = swap_hash_join(&join.clone(), PartitionMode::Partitioned)
1228+
.expect("swap_hash_join must support joins with projections");
1229+
let swapped_join = swapped.as_any().downcast_ref::<HashJoinExec>().expect(
1230+
"ProjectionExec won't be added above if HashJoinExec contains embedded projection",
1231+
);
1232+
1233+
assert_eq!(swapped_join.projection, Some(vec![0_usize]));
1234+
assert_eq!(swapped.schema().fields.len(), 1);
1235+
assert_eq!(swapped.schema().fields[0].name(), "small_col");
1236+
Ok(())
1237+
}
1238+
11811239
#[rstest(
11821240
join_type,
11831241
case::inner(JoinType::Inner),
@@ -1307,65 +1365,6 @@ mod tests_statistical {
13071365
);
13081366
}
13091367

1310-
#[rstest(
1311-
join_type, projection, small_on_right,
1312-
case::inner(JoinType::Inner, vec![1], true),
1313-
case::left(JoinType::Left, vec![1], true),
1314-
case::right(JoinType::Right, vec![1], true),
1315-
case::full(JoinType::Full, vec![1], true),
1316-
case::left_anti(JoinType::LeftAnti, vec![0], false),
1317-
case::left_semi(JoinType::LeftSemi, vec![0], false),
1318-
case::right_anti(JoinType::RightAnti, vec![0], true),
1319-
case::right_semi(JoinType::RightSemi, vec![0], true),
1320-
)]
1321-
#[tokio::test]
1322-
async fn test_hash_join_swap_on_joins_with_projections(
1323-
join_type: JoinType,
1324-
projection: Vec<usize>,
1325-
small_on_right: bool,
1326-
) -> Result<()> {
1327-
let (big, small) = create_big_and_small();
1328-
1329-
let left = if small_on_right { &big } else { &small };
1330-
let right = if small_on_right { &small } else { &big };
1331-
1332-
let left_on = if small_on_right {
1333-
"big_col"
1334-
} else {
1335-
"small_col"
1336-
};
1337-
let right_on = if small_on_right {
1338-
"small_col"
1339-
} else {
1340-
"big_col"
1341-
};
1342-
1343-
let join = Arc::new(HashJoinExec::try_new(
1344-
Arc::clone(left),
1345-
Arc::clone(right),
1346-
vec![(
1347-
Arc::new(Column::new_with_schema(left_on, &left.schema())?),
1348-
Arc::new(Column::new_with_schema(right_on, &right.schema())?),
1349-
)],
1350-
None,
1351-
&join_type,
1352-
Some(projection),
1353-
PartitionMode::Partitioned,
1354-
false,
1355-
)?);
1356-
1357-
let swapped = swap_hash_join(&join.clone(), PartitionMode::Partitioned)
1358-
.expect("swap_hash_join must support joins with projections");
1359-
let swapped_join = swapped.as_any().downcast_ref::<HashJoinExec>().expect(
1360-
"ProjectionExec won't be added above if HashJoinExec contains embedded projection",
1361-
);
1362-
1363-
assert_eq!(swapped_join.projection, Some(vec![0_usize]));
1364-
assert_eq!(swapped.schema().fields.len(), 1);
1365-
assert_eq!(swapped.schema().fields[0].name(), "small_col");
1366-
Ok(())
1367-
}
1368-
13691368
#[tokio::test]
13701369
async fn test_swap_reverting_projection() {
13711370
let left_schema = Schema::new(vec![

0 commit comments

Comments (0)