
Commit 9b7a41c

rich unit test

1 parent b2e92ac commit 9b7a41c

File tree: 3 files changed, +44 -17 lines

datafusion/core/tests/physical_optimizer/enforce_distribution.rs
datafusion/core/tests/sql/mod.rs
datafusion/physical-optimizer/src/enforce_distribution.rs

datafusion/core/tests/physical_optimizer/enforce_distribution.rs

Lines changed: 35 additions & 10 deletions

@@ -25,6 +25,7 @@ use crate::physical_optimizer::test_utils::{
 };
 use crate::physical_optimizer::test_utils::{parquet_exec_with_sort, trim_plan_display};

+use crate::sql::ExplainNormalizer;
 use arrow::compute::SortOptions;
 use datafusion::config::ConfigOptions;
 use datafusion::datasource::file_format::file_compression_type::FileCompressionType;
@@ -3169,9 +3170,11 @@ async fn apply_enforce_distribution_multiple_times() -> Result<()> {
     // Create a configuration
     let config = SessionConfig::new();
     let ctx = SessionContext::new_with_config(config);
-
+    let testdata = datafusion::test_util::arrow_test_data();
+    let csv_file = format!("{testdata}/csv/aggregate_test_100.csv");
     // Create table schema and data
-    let sql = "CREATE EXTERNAL TABLE aggregate_test_100 (
+    let sql = format!(
+        "CREATE EXTERNAL TABLE aggregate_test_100 (
     c1 VARCHAR NOT NULL,
     c2 TINYINT NOT NULL,
     c3 SMALLINT NOT NULL,
@@ -3187,10 +3190,11 @@ async fn apply_enforce_distribution_multiple_times() -> Result<()> {
     c13 VARCHAR NOT NULL
     )
     STORED AS CSV
-    LOCATION '../../testing/data/csv/aggregate_test_100.csv'
-    OPTIONS ('format.has_header' 'true')";
+    LOCATION '{csv_file}'
+    OPTIONS ('format.has_header' 'true')"
+    );

-    ctx.sql(sql).await?;
+    ctx.sql(sql.as_str()).await?;

     let df = ctx.sql("SELECT * FROM(SELECT * FROM aggregate_test_100 UNION ALL SELECT * FROM aggregate_test_100) ORDER BY c13 LIMIT 5").await?;
     let logical_plan = df.logical_plan().clone();
@@ -3228,12 +3232,33 @@ async fn apply_enforce_distribution_multiple_times() -> Result<()> {
     let optimized_physical_plan = planner
         .create_physical_plan(&optimized_logical_plan, &session_state)
         .await?;
+    let normalizer = ExplainNormalizer::new();
+    let actual = format!(
+        "{}",
+        displayable(optimized_physical_plan.as_ref()).indent(true)
+    )
+    .trim()
+    .lines()
+    // normalize paths
+    .map(|s| normalizer.normalize(s))
+    .collect::<Vec<_>>();
+    // Test the optimized plan is correct (after twice `EnforceDistribution`)
+    // The `fetch` is maintained after the second `EnforceDistribution`
+    let expected = vec![
+        "SortExec: TopK(fetch=5), expr=[c13@12 ASC NULLS LAST], preserve_partitioning=[false]",
+        "  CoalescePartitionsExec",
+        "    UnionExec",
+        "      SortExec: TopK(fetch=5), expr=[c13@12 ASC NULLS LAST], preserve_partitioning=[false]",
+        "        DataSourceExec: file_groups={1 group: [[ARROW_TEST_DATA/csv/aggregate_test_100.csv]]}, projection=[c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13], file_type=csv, has_header=true",
+        "      SortExec: TopK(fetch=5), expr=[c13@12 ASC NULLS LAST], preserve_partitioning=[false]",
+        "        DataSourceExec: file_groups={1 group: [[ARROW_TEST_DATA/csv/aggregate_test_100.csv]]}, projection=[c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13], file_type=csv, has_header=true",
+    ];
+    assert_eq!(
+        expected, actual,
+        "expected:\n{expected:#?}\nactual:\n\n{actual:#?}\n"
+    );

-    // println!("{}", displayable(optimized_physical_plan.as_ref()).indent(true));
-
-    let mut results = optimized_physical_plan
-        .execute(0, ctx.task_ctx().clone())
-        .unwrap();
+    let mut results = optimized_physical_plan.execute(0, ctx.task_ctx().clone())?;

     let batch = results.next().await.unwrap()?;
     // Without the fix of https://github.com/apache/datafusion/pull/14207, the number of rows will be 10
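
Aside: the assertion added above follows a render-then-compare pattern — format the optimized plan with displayable(...).indent(true), split the text into lines, normalize each line, and compare against a fixed listing. A minimal standalone sketch of that pattern, with a hypothetical pre-rendered plan string standing in for the real displayable output and with the normalization step left out, might look like this:

// Sketch only: the rendered text is made up; in the test above it comes from
// `displayable(optimized_physical_plan.as_ref()).indent(true)` and each line is
// additionally passed through `ExplainNormalizer::normalize`.
fn main() {
    let rendered = "\
SortExec: TopK(fetch=5), expr=[c13@12 ASC NULLS LAST], preserve_partitioning=[false]
  CoalescePartitionsExec
    UnionExec
";

    // Keeping the expected side as Vec<&str> makes the plan easy to read and
    // edit, and it still compares cleanly against the Vec<String> actual side.
    let actual: Vec<String> = rendered.trim().lines().map(|s| s.to_string()).collect();
    let expected = vec![
        "SortExec: TopK(fetch=5), expr=[c13@12 ASC NULLS LAST], preserve_partitioning=[false]",
        "  CoalescePartitionsExec",
        "    UnionExec",
    ];

    assert_eq!(
        expected, actual,
        "expected:\n{expected:#?}\nactual:\n{actual:#?}"
    );
    println!("plan text matched");
}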

datafusion/core/tests/sql/mod.rs

Lines changed: 2 additions & 2 deletions

@@ -243,7 +243,7 @@ pub struct ExplainNormalizer {
 }

 impl ExplainNormalizer {
-    fn new() -> Self {
+    pub(crate) fn new() -> Self {
         let mut replacements = vec![];

         let mut push_path = |path: PathBuf, key: &str| {
@@ -266,7 +266,7 @@ impl ExplainNormalizer {
         Self { replacements }
     }

-    fn normalize(&self, s: impl Into<String>) -> String {
+    pub(crate) fn normalize(&self, s: impl Into<String>) -> String {
         let mut s = s.into();
         for (from, to) in &self.replacements {
             s = s.replace(from, to);
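
The two methods made pub(crate) here implement a simple idea: keep a table of (from, to) string replacements (absolute test-data paths mapped to stable keys such as ARROW_TEST_DATA) and apply them to every line of rendered plan text. A minimal analogue, with an illustrative PathNormalizer name and a single hard-wired replacement rather than DataFusion's actual construction logic, could be:

use std::path::PathBuf;

// Illustrative analogue of ExplainNormalizer: a list of (from, to) pairs
// applied to each line of rendered plan text.
pub(crate) struct PathNormalizer {
    replacements: Vec<(String, String)>,
}

impl PathNormalizer {
    pub(crate) fn new(test_data_dir: PathBuf) -> Self {
        // Map the machine-specific test-data directory to a stable placeholder
        // so expected plan text does not depend on the local checkout path.
        let replacements = vec![(
            test_data_dir.to_string_lossy().into_owned(),
            "ARROW_TEST_DATA".to_string(),
        )];
        Self { replacements }
    }

    pub(crate) fn normalize(&self, s: impl Into<String>) -> String {
        let mut s = s.into();
        for (from, to) in &self.replacements {
            s = s.replace(from, to);
        }
        s
    }
}

fn main() {
    // The path below is a made-up example, not the real test-data location.
    let normalizer = PathNormalizer::new(PathBuf::from("/home/user/arrow-testing/data"));
    assert_eq!(
        normalizer.normalize("[[/home/user/arrow-testing/data/csv/aggregate_test_100.csv]]"),
        "[[ARROW_TEST_DATA/csv/aggregate_test_100.csv]]"
    );
}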

datafusion/physical-optimizer/src/enforce_distribution.rs

Lines changed: 7 additions & 5 deletions

@@ -1390,25 +1390,27 @@ pub fn ensure_distribution(
     } else {
         plan.with_new_children(children_plans)?
     };
+    let mut optimized_distribution_ctx =
+        DistributionContext::new(Arc::clone(&plan), data.clone(), children);

     // If `fetch` was not consumed, it means that there was `SortPreservingMergeExec` with fetch before
     // It was removed by `remove_dist_changing_operators`
     // and we need to add it back.
     if fetch.is_some() {
-        plan = Arc::new(
+        let plan = Arc::new(
             SortPreservingMergeExec::new(
                 plan.output_ordering()
                     .unwrap_or(&LexOrdering::default())
                     .clone(),
                 plan,
             )
             .with_fetch(fetch.take()),
-        )
+        );
+        optimized_distribution_ctx =
+            DistributionContext::new(plan, data, vec![optimized_distribution_ctx]);
     }

-    Ok(Transformed::yes(DistributionContext::new(
-        plan, data, children,
-    )))
+    Ok(Transformed::yes(optimized_distribution_ctx))
 }

 /// Distribution context that tracks distribution changing operators and fetch limits
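
The shape of this change — build the DistributionContext for the current plan and its children once, then, when a fetch is still pending, nest that whole context under a new merge parent instead of pairing the new top-level plan with the old children — can be sketched with simplified stand-in types. Everything below (Node, Ctx, finish) is hypothetical and only mirrors the nesting pattern, not DataFusion's API:

// Hypothetical stand-ins for a plan node and its distribution context.
#[derive(Debug)]
struct Node {
    name: &'static str,
    fetch: Option<usize>,
}

#[derive(Debug)]
struct Ctx {
    plan: Node,
    children: Vec<Ctx>,
}

// Mirrors the fixed control flow above: build the context once, and if a fetch
// is still pending, wrap that context in a new parent node so the context tree
// stays nested the same way as the plan tree.
fn finish(plan: Node, children: Vec<Ctx>, mut fetch: Option<usize>) -> Ctx {
    let mut ctx = Ctx { plan, children };
    if fetch.is_some() {
        let merge = Node {
            name: "SortPreservingMerge",
            fetch: fetch.take(),
        };
        ctx = Ctx {
            plan: merge,
            children: vec![ctx],
        };
    }
    ctx
}

fn main() {
    let leaf = Ctx {
        plan: Node { name: "DataSource", fetch: None },
        children: vec![],
    };
    let root = finish(Node { name: "Union", fetch: None }, vec![leaf], Some(5));
    // The pending fetch ends up on the newly added top-level merge node.
    assert_eq!(root.plan.fetch, Some(5));
    println!("{root:#?}");
}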
