diff --git a/benchmarks/README.md b/benchmarks/README.md index c84af9a7e6f37..0b71628b2db12 100644 --- a/benchmarks/README.md +++ b/benchmarks/README.md @@ -790,7 +790,7 @@ Different queries are included to test nested loop joins under various workloads ## Hash Join -This benchmark focuses on the performance of queries with nested hash joins, minimizing other overheads such as scanning data sources or evaluating predicates. +This benchmark focuses on the performance of queries with hash joins, minimizing other overheads such as scanning data sources or evaluating predicates. Several queries are included to test hash joins under various workloads. @@ -802,6 +802,19 @@ Several queries are included to test hash joins under various workloads. ./bench.sh run hj ``` +## Sort Merge Join + +This benchmark focuses on the performance of queries with sort merge joins joins, minimizing other overheads such as scanning data sources or evaluating predicates. + +Several queries are included to test sort merge joins under various workloads. + +### Example Run + +```bash +# No need to generate data: this benchmark uses table function `range()` as the data source + +./bench.sh run smj +``` ## Cancellation Test performance of cancelling queries. diff --git a/benchmarks/bench.sh b/benchmarks/bench.sh index a91a6c32bd425..975f4ec08fa88 100755 --- a/benchmarks/bench.sh +++ b/benchmarks/bench.sh @@ -132,6 +132,7 @@ imdb: Join Order Benchmark (JOB) using the IMDB dataset conver cancellation: How long cancelling a query takes nlj: Benchmark for simple nested loop joins, testing various join scenarios hj: Benchmark for simple hash joins, testing various join scenarios +smj: Benchmark for simple sort merge joins, testing various join scenarios compile_profile: Compile and execute TPC-H across selected Cargo profiles, reporting timing and binary size @@ -324,6 +325,10 @@ main() { # hj uses range() function, no data generation needed echo "HJ benchmark does not require data generation" ;; + smj) + # smj uses range() function, no data generation needed + echo "SMJ benchmark does not require data generation" + ;; compile_profile) data_tpch "1" "parquet" ;; @@ -401,6 +406,7 @@ main() { run_nlj run_hj run_tpcds + run_smj ;; tpch) run_tpch "1" "parquet" @@ -514,6 +520,9 @@ main() { hj) run_hj ;; + smj) + run_smj + ;; compile_profile) run_compile_profile "${PROFILE_ARGS[@]}" ;; @@ -1234,6 +1243,14 @@ run_hj() { debug_run $CARGO_COMMAND --bin dfbench -- hj --iterations 5 -o "${RESULTS_FILE}" ${QUERY_ARG} } +# Runs the smj benchmark +run_smj() { + RESULTS_FILE="${RESULTS_DIR}/smj.json" + echo "RESULTS_FILE: ${RESULTS_FILE}" + echo "Running smj benchmark..." + debug_run $CARGO_COMMAND --bin dfbench -- smj --iterations 5 -o "${RESULTS_FILE}" ${QUERY_ARG} +} + compare_benchmarks() { BASE_RESULTS_DIR="${SCRIPT_DIR}/results" diff --git a/benchmarks/src/bin/dfbench.rs b/benchmarks/src/bin/dfbench.rs index 6fc382822d9ba..d842d306c1f65 100644 --- a/benchmarks/src/bin/dfbench.rs +++ b/benchmarks/src/bin/dfbench.rs @@ -34,7 +34,7 @@ static ALLOC: snmalloc_rs::SnMalloc = snmalloc_rs::SnMalloc; static ALLOC: mimalloc::MiMalloc = mimalloc::MiMalloc; use datafusion_benchmarks::{ - cancellation, clickbench, h2o, hj, imdb, nlj, sort_tpch, tpcds, tpch, + cancellation, clickbench, h2o, hj, imdb, nlj, smj, sort_tpch, tpcds, tpch, }; #[derive(Debug, StructOpt)] @@ -46,6 +46,7 @@ enum Options { HJ(hj::RunOpt), Imdb(imdb::RunOpt), Nlj(nlj::RunOpt), + Smj(smj::RunOpt), SortTpch(sort_tpch::RunOpt), Tpch(tpch::RunOpt), Tpcds(tpcds::RunOpt), @@ -63,6 +64,7 @@ pub async fn main() -> Result<()> { Options::HJ(opt) => opt.run().await, Options::Imdb(opt) => Box::pin(opt.run()).await, Options::Nlj(opt) => opt.run().await, + Options::Smj(opt) => opt.run().await, Options::SortTpch(opt) => opt.run().await, Options::Tpch(opt) => Box::pin(opt.run()).await, Options::Tpcds(opt) => Box::pin(opt.run()).await, diff --git a/benchmarks/src/lib.rs b/benchmarks/src/lib.rs index d885ec94a306c..a3bc221840ada 100644 --- a/benchmarks/src/lib.rs +++ b/benchmarks/src/lib.rs @@ -22,6 +22,7 @@ pub mod h2o; pub mod hj; pub mod imdb; pub mod nlj; +pub mod smj; pub mod sort_tpch; pub mod tpcds; pub mod tpch; diff --git a/benchmarks/src/smj.rs b/benchmarks/src/smj.rs new file mode 100644 index 0000000000000..32a620a12d4fb --- /dev/null +++ b/benchmarks/src/smj.rs @@ -0,0 +1,524 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::util::{BenchmarkRun, CommonOpt, QueryResult}; +use datafusion::physical_plan::execute_stream; +use datafusion::{error::Result, prelude::SessionContext}; +use datafusion_common::instant::Instant; +use datafusion_common::{exec_datafusion_err, exec_err, DataFusionError}; +use structopt::StructOpt; + +use futures::StreamExt; + +/// Run the Sort Merge Join (SMJ) benchmark +/// +/// This micro-benchmark focuses on the performance characteristics of SMJs. +/// +/// It uses equality join predicates (to ensure SMJ is selected) and varies: +/// - Join type: Inner/Left/Right/Full/LeftSemi/LeftAnti/RightSemi/RightAnti +/// - Key cardinality: 1:1, 1:N, N:M relationships +/// - Filter selectivity: Low (1%), Medium (10%), High (50%) +/// - Input sizes: Small to large, balanced and skewed +/// +/// All inputs are pre-sorted in CTEs before the join to isolate join +/// performance from sort overhead. +#[derive(Debug, StructOpt, Clone)] +#[structopt(verbatim_doc_comment)] +pub struct RunOpt { + /// Query number (between 1 and 20). If not specified, runs all queries + #[structopt(short, long)] + query: Option, + + /// Common options + #[structopt(flatten)] + common: CommonOpt, + + /// If present, write results json here + #[structopt(parse(from_os_str), short = "o", long = "output")] + output_path: Option, +} + +/// Inline SQL queries for SMJ benchmarks +/// +/// Each query's comment includes: +/// - Join type +/// - Left row count × Right row count +/// - Key cardinality (rows per key) +/// - Filter selectivity (if applicable) +const SMJ_QUERIES: &[&str] = &[ + // Q1: INNER 100K x 100K | 1:1 + r#" + WITH t1_sorted AS ( + SELECT value as key FROM range(100000) ORDER BY value + ), + t2_sorted AS ( + SELECT value as key FROM range(100000) ORDER BY value + ) + SELECT t1_sorted.key as k1, t2_sorted.key as k2 + FROM t1_sorted JOIN t2_sorted ON t1_sorted.key = t2_sorted.key + "#, + // Q2: INNER 100K x 1M | 1:10 + r#" + WITH t1_sorted AS ( + SELECT value % 10000 as key, value as data + FROM range(100000) + ORDER BY key, data + ), + t2_sorted AS ( + SELECT value % 10000 as key, value as data + FROM range(1000000) + ORDER BY key, data + ) + SELECT t1_sorted.key, t1_sorted.data as d1, t2_sorted.data as d2 + FROM t1_sorted JOIN t2_sorted ON t1_sorted.key = t2_sorted.key + "#, + // Q3: INNER 1M x 1M | 1:100 + r#" + WITH t1_sorted AS ( + SELECT value % 10000 as key, value as data + FROM range(1000000) + ORDER BY key, data + ), + t2_sorted AS ( + SELECT value % 10000 as key, value as data + FROM range(1000000) + ORDER BY key, data + ) + SELECT t1_sorted.key, t1_sorted.data as d1, t2_sorted.data as d2 + FROM t1_sorted JOIN t2_sorted ON t1_sorted.key = t2_sorted.key + "#, + // Q4: INNER 100K x 1M | 1:10 | 1% + r#" + WITH t1_sorted AS ( + SELECT value % 10000 as key, value as data + FROM range(100000) + ORDER BY key, data + ), + t2_sorted AS ( + SELECT value % 10000 as key, value as data + FROM range(1000000) + ORDER BY key, data + ) + SELECT t1_sorted.key, t1_sorted.data as d1, t2_sorted.data as d2 + FROM t1_sorted JOIN t2_sorted ON t1_sorted.key = t2_sorted.key + WHERE t2_sorted.data % 100 = 0 + "#, + // Q5: INNER 1M x 1M | 1:100 | 10% + r#" + WITH t1_sorted AS ( + SELECT value % 10000 as key, value as data + FROM range(1000000) + ORDER BY key, data + ), + t2_sorted AS ( + SELECT value % 10000 as key, value as data + FROM range(1000000) + ORDER BY key, data + ) + SELECT t1_sorted.key, t1_sorted.data as d1, t2_sorted.data as d2 + FROM t1_sorted JOIN t2_sorted ON t1_sorted.key = t2_sorted.key + WHERE t1_sorted.data <> t2_sorted.data AND t2_sorted.data % 10 = 0 + "#, + // Q6: LEFT 100K x 1M | 1:10 + r#" + WITH t1_sorted AS ( + SELECT value % 10500 as key, value as data + FROM range(100000) + ORDER BY key, data + ), + t2_sorted AS ( + SELECT value % 10000 as key, value as data + FROM range(1000000) + ORDER BY key, data + ) + SELECT t1_sorted.key, t1_sorted.data as d1, t2_sorted.data as d2 + FROM t1_sorted LEFT JOIN t2_sorted ON t1_sorted.key = t2_sorted.key + "#, + // Q7: LEFT 100K x 1M | 1:10 | 50% + r#" + WITH t1_sorted AS ( + SELECT value % 10000 as key, value as data + FROM range(100000) + ORDER BY key, data + ), + t2_sorted AS ( + SELECT value % 10000 as key, value as data + FROM range(1000000) + ORDER BY key, data + ) + SELECT t1_sorted.key, t1_sorted.data as d1, t2_sorted.data as d2 + FROM t1_sorted LEFT JOIN t2_sorted ON t1_sorted.key = t2_sorted.key + WHERE t2_sorted.data IS NULL OR t2_sorted.data % 2 = 0 + "#, + // Q8: FULL 100K x 100K | 1:10 + r#" + WITH t1_sorted AS ( + SELECT value % 10000 as key, value as data + FROM range(100000) + ORDER BY key, data + ), + t2_sorted AS ( + SELECT value % 12500 as key, value as data + FROM range(100000) + ORDER BY key, data + ) + SELECT t1_sorted.key as k1, t1_sorted.data as d1, + t2_sorted.key as k2, t2_sorted.data as d2 + FROM t1_sorted FULL JOIN t2_sorted ON t1_sorted.key = t2_sorted.key + "#, + // Q9: FULL 100K x 1M | 1:10 | 10% + r#" + WITH t1_sorted AS ( + SELECT value % 10000 as key, value as data + FROM range(100000) + ORDER BY key, data + ), + t2_sorted AS ( + SELECT value % 10000 as key, value as data + FROM range(1000000) + ORDER BY key, data + ) + SELECT t1_sorted.key as k1, t1_sorted.data as d1, + t2_sorted.key as k2, t2_sorted.data as d2 + FROM t1_sorted FULL JOIN t2_sorted ON t1_sorted.key = t2_sorted.key + WHERE (t1_sorted.data IS NULL OR t2_sorted.data IS NULL + OR t1_sorted.data <> t2_sorted.data) + AND (t1_sorted.data IS NULL OR t1_sorted.data % 10 = 0) + "#, + // Q10: LEFT SEMI 100K x 1M | 1:10 + r#" + WITH t1_sorted AS ( + SELECT value % 10000 as key, value as data + FROM range(100000) + ORDER BY key, data + ), + t2_sorted AS ( + SELECT value % 10000 as key + FROM range(1000000) + ORDER BY key + ) + SELECT t1_sorted.key, t1_sorted.data + FROM t1_sorted + WHERE EXISTS ( + SELECT 1 FROM t2_sorted + WHERE t2_sorted.key = t1_sorted.key + ) + "#, + // Q11: LEFT SEMI 100K x 1M | 1:10 | 1% + r#" + WITH t1_sorted AS ( + SELECT value % 10000 as key, value as data + FROM range(100000) + ORDER BY key, data + ), + t2_sorted AS ( + SELECT value % 10000 as key, value as data + FROM range(1000000) + ORDER BY key, data + ) + SELECT t1_sorted.key, t1_sorted.data + FROM t1_sorted + WHERE EXISTS ( + SELECT 1 FROM t2_sorted + WHERE t2_sorted.key = t1_sorted.key + AND t2_sorted.data <> t1_sorted.data + AND t2_sorted.data % 100 = 0 + ) + "#, + // Q12: LEFT SEMI 100K x 1M | 1:10 | 50% + r#" + WITH t1_sorted AS ( + SELECT value % 10000 as key, value as data + FROM range(100000) + ORDER BY key, data + ), + t2_sorted AS ( + SELECT value % 10000 as key, value as data + FROM range(1000000) + ORDER BY key, data + ) + SELECT t1_sorted.key, t1_sorted.data + FROM t1_sorted + WHERE EXISTS ( + SELECT 1 FROM t2_sorted + WHERE t2_sorted.key = t1_sorted.key + AND t2_sorted.data <> t1_sorted.data + AND t2_sorted.data % 2 = 0 + ) + "#, + // Q13: LEFT SEMI 100K x 1M | 1:10 | 90% + r#" + WITH t1_sorted AS ( + SELECT value % 10000 as key, value as data + FROM range(100000) + ORDER BY key, data + ), + t2_sorted AS ( + SELECT value % 10000 as key, value as data + FROM range(1000000) + ORDER BY key, data + ) + SELECT t1_sorted.key, t1_sorted.data + FROM t1_sorted + WHERE EXISTS ( + SELECT 1 FROM t2_sorted + WHERE t2_sorted.key = t1_sorted.key + AND t2_sorted.data % 10 <> 0 + ) + "#, + // Q14: LEFT ANTI 100K x 1M | 1:10 + r#" + WITH t1_sorted AS ( + SELECT value % 10500 as key, value as data + FROM range(100000) + ORDER BY key, data + ), + t2_sorted AS ( + SELECT value % 10000 as key + FROM range(1000000) + ORDER BY key + ) + SELECT t1_sorted.key, t1_sorted.data + FROM t1_sorted + WHERE NOT EXISTS ( + SELECT 1 FROM t2_sorted + WHERE t2_sorted.key = t1_sorted.key + ) + "#, + // Q15: LEFT ANTI 100K x 1M | 1:10 | partial match + r#" + WITH t1_sorted AS ( + SELECT value % 12000 as key, value as data + FROM range(100000) + ORDER BY key, data + ), + t2_sorted AS ( + SELECT value % 10000 as key + FROM range(1000000) + ORDER BY key + ) + SELECT t1_sorted.key, t1_sorted.data + FROM t1_sorted + WHERE NOT EXISTS ( + SELECT 1 FROM t2_sorted + WHERE t2_sorted.key = t1_sorted.key + ) + "#, + // Q16: LEFT ANTI 100K x 100K | 1:1 | stress + r#" + WITH t1_sorted AS ( + SELECT value % 11000 as key, value as data + FROM range(100000) + ORDER BY key, data + ), + t2_sorted AS ( + SELECT value % 10000 as key + FROM range(100000) + ORDER BY key + ) + SELECT t1_sorted.key, t1_sorted.data + FROM t1_sorted + WHERE NOT EXISTS ( + SELECT 1 FROM t2_sorted + WHERE t2_sorted.key = t1_sorted.key + ) + "#, + // Q17: INNER 100K x 5M | 1:50 | 5% + r#" + WITH t1_sorted AS ( + SELECT value % 10000 as key, value as data + FROM range(100000) + ORDER BY key, data + ), + t2_sorted AS ( + SELECT value % 10000 as key, value as data + FROM range(5000000) + ORDER BY key, data + ) + SELECT t1_sorted.key, t1_sorted.data as d1, t2_sorted.data as d2 + FROM t1_sorted JOIN t2_sorted ON t1_sorted.key = t2_sorted.key + WHERE t2_sorted.data <> t1_sorted.data AND t2_sorted.data % 20 = 0 + "#, + // Q18: LEFT SEMI 100K x 5M | 1:50 | 2% + r#" + WITH t1_sorted AS ( + SELECT value % 10000 as key, value as data + FROM range(100000) + ORDER BY key, data + ), + t2_sorted AS ( + SELECT value % 10000 as key, value as data + FROM range(5000000) + ORDER BY key, data + ) + SELECT t1_sorted.key, t1_sorted.data + FROM t1_sorted + WHERE EXISTS ( + SELECT 1 FROM t2_sorted + WHERE t2_sorted.key = t1_sorted.key + AND t2_sorted.data <> t1_sorted.data + AND t2_sorted.data % 50 = 0 + ) + "#, + // Q19: LEFT ANTI 100K x 5M | 1:50 | partial match + r#" + WITH t1_sorted AS ( + SELECT value % 15000 as key, value as data + FROM range(100000) + ORDER BY key, data + ), + t2_sorted AS ( + SELECT value % 10000 as key + FROM range(5000000) + ORDER BY key + ) + SELECT t1_sorted.key, t1_sorted.data + FROM t1_sorted + WHERE NOT EXISTS ( + SELECT 1 FROM t2_sorted + WHERE t2_sorted.key = t1_sorted.key + ) + "#, + // Q20: INNER 1M x 10M | 1:100 + GROUP BY + r#" + WITH t1_sorted AS ( + SELECT value % 10000 as key, value as data + FROM range(1000000) + ORDER BY key, data + ), + t2_sorted AS ( + SELECT value % 10000 as key, value as data + FROM range(10000000) + ORDER BY key, data + ) + SELECT t1_sorted.key, count(*) as cnt + FROM t1_sorted JOIN t2_sorted ON t1_sorted.key = t2_sorted.key + GROUP BY t1_sorted.key + "#, +]; + +impl RunOpt { + pub async fn run(self) -> Result<()> { + println!("Running SMJ benchmarks with the following options: {self:#?}\n"); + + // Define query range + let query_range = match self.query { + Some(query_id) => { + if query_id >= 1 && query_id <= SMJ_QUERIES.len() { + query_id..=query_id + } else { + return exec_err!( + "Query {query_id} not found. Available queries: 1 to {}", + SMJ_QUERIES.len() + ); + } + } + None => 1..=SMJ_QUERIES.len(), + }; + + let mut config = self.common.config()?; + // Disable hash joins to force SMJ + config = config.set_bool("datafusion.optimizer.prefer_hash_join", false); + let rt_builder = self.common.runtime_env_builder()?; + let ctx = SessionContext::new_with_config_rt(config, rt_builder.build_arc()?); + + let mut benchmark_run = BenchmarkRun::new(); + for query_id in query_range { + let query_index = query_id - 1; // Convert 1-based to 0-based index + + let sql = SMJ_QUERIES[query_index]; + benchmark_run.start_new_case(&format!("Query {query_id}")); + let query_run = self.benchmark_query(sql, &query_id.to_string(), &ctx).await; + match query_run { + Ok(query_results) => { + for iter in query_results { + benchmark_run.write_iter(iter.elapsed, iter.row_count); + } + } + Err(e) => { + return Err(DataFusionError::Context( + format!("SMJ benchmark Q{query_id} failed with error:"), + Box::new(e), + )); + } + } + } + + benchmark_run.maybe_write_json(self.output_path.as_ref())?; + Ok(()) + } + + async fn benchmark_query( + &self, + sql: &str, + query_name: &str, + ctx: &SessionContext, + ) -> Result> { + let mut query_results = vec![]; + + // Validate that the query plan includes a Sort Merge Join + let df = ctx.sql(sql).await?; + let physical_plan = df.create_physical_plan().await?; + let plan_string = format!("{physical_plan:#?}"); + + if !plan_string.contains("SortMergeJoinExec") { + return Err(exec_datafusion_err!( + "Query {query_name} does not use Sort Merge Join. Physical plan: {plan_string}" + )); + } + + for i in 0..self.common.iterations { + let start = Instant::now(); + + let row_count = Self::execute_sql_without_result_buffering(sql, ctx).await?; + + let elapsed = start.elapsed(); + + println!( + "Query {query_name} iteration {i} returned {row_count} rows in {elapsed:?}" + ); + + query_results.push(QueryResult { elapsed, row_count }); + } + + Ok(query_results) + } + + /// Executes the SQL query and drops each result batch after evaluation, to + /// minimizes memory usage by not buffering results. + /// + /// Returns the total result row count + async fn execute_sql_without_result_buffering( + sql: &str, + ctx: &SessionContext, + ) -> Result { + let mut row_count = 0; + + let df = ctx.sql(sql).await?; + let physical_plan = df.create_physical_plan().await?; + let mut stream = execute_stream(physical_plan, ctx.task_ctx())?; + + while let Some(batch) = stream.next().await { + row_count += batch?.num_rows(); + + // Evaluate the result and do nothing, the result will be dropped + // to reduce memory pressure + } + + Ok(row_count) + } +} diff --git a/datafusion/physical-plan/src/joins/sort_merge_join/stream.rs b/datafusion/physical-plan/src/joins/sort_merge_join/stream.rs index e266f1b5b76c4..4119a54cd5395 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join/stream.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join/stream.rs @@ -41,7 +41,8 @@ use crate::{PhysicalExpr, RecordBatchStream, SendableRecordBatchStream}; use arrow::array::{types::UInt64Type, *}; use arrow::compute::{ - self, concat_batches, filter_record_batch, is_not_null, take, SortOptions, + self, concat_batches, filter_record_batch, is_not_null, take, BatchCoalescer, + SortOptions, }; use arrow::datatypes::{DataType, SchemaRef, TimeUnit}; use arrow::error::ArrowError; @@ -155,21 +156,39 @@ impl StreamedBatch { } } + /// Number of unfrozen output pairs in this streamed batch + fn num_output_rows(&self) -> usize { + self.output_indices + .iter() + .map(|chunk| chunk.streamed_indices.len()) + .sum() + } + /// Appends new pair consisting of current streamed index and `buffered_idx` /// index of buffered batch with `buffered_batch_idx` index. fn append_output_pair( &mut self, buffered_batch_idx: Option, buffered_idx: Option, + batch_size: usize, + num_unfrozen_pairs: usize, ) { // If no current chunk exists or current chunk is not for current buffered batch, // create a new chunk if self.output_indices.is_empty() || self.buffered_batch_idx != buffered_batch_idx { + // Compute capacity only when creating a new chunk (infrequent operation). + // The capacity is the remaining space to reach batch_size. + // This should always be >= 1 since we only call this when num_unfrozen_pairs < batch_size. + debug_assert!( + batch_size > num_unfrozen_pairs, + "batch_size ({batch_size}) must be > num_unfrozen_pairs ({num_unfrozen_pairs})" + ); + let capacity = batch_size - num_unfrozen_pairs; self.output_indices.push(StreamedJoinedChunk { buffered_batch_idx, - streamed_indices: UInt64Builder::with_capacity(1), - buffered_indices: UInt64Builder::with_capacity(1), + streamed_indices: UInt64Builder::with_capacity(capacity), + buffered_indices: UInt64Builder::with_capacity(capacity), }); self.buffered_batch_idx = buffered_batch_idx; }; @@ -320,14 +339,10 @@ pub(super) struct SortMergeJoinStream { /// Current state of the stream pub state: SortMergeJoinState, /// Staging output array builders - pub staging_output_record_batches: JoinedRecordBatches, + pub joined_record_batches: JoinedRecordBatches, /// Output buffer. Currently used by filtering as it requires double buffering /// to avoid small/empty batches. Non-filtered join outputs directly from `staging_output_record_batches.batches` - pub output: RecordBatch, - /// Staging output size, including output batches and staging joined results. - /// Increased when we put rows into buffer and decreased after we actually output batches. - /// Used to trigger output when sufficient rows are ready - pub output_size: usize, + pub output: BatchCoalescer, /// The comparison result of current streamed row and buffered batches pub current_ordering: Ordering, /// Manages the process of spilling and reading back intermediate data @@ -347,26 +362,199 @@ pub(super) struct SortMergeJoinStream { pub streamed_batch_counter: AtomicUsize, } -/// Joined batches with attached join filter information +/// Staging area for joined data before output +/// +/// Accumulates joined rows until either: +/// - Target batch size reached (for efficiency) +/// - Stream exhausted (flush remaining data) pub(super) struct JoinedRecordBatches { /// Joined batches. Each batch is already joined columns from left and right sources - pub batches: Vec, - /// Filter match mask for each row(matched/non-matched) - pub filter_mask: BooleanBuilder, - /// Left row indices to glue together rows in `batches` and `filter_mask` - pub row_indices: UInt64Builder, - /// Which unique batch id the row belongs to - /// It is necessary to differentiate rows that are distributed the way when they point to the same - /// row index but in not the same batches - pub batch_ids: Vec, + pub(super) joined_batches: BatchCoalescer, + /// Did each output row pass the join filter? (detect if input row found any match) + pub(super) filter_mask: BooleanBuilder, + /// Which input row (within batch) produced each output row? (for grouping by input row) + pub(super) row_indices: UInt64Builder, + /// Which input batch did each output row come from? (disambiguate row_indices) + pub(super) batch_ids: Vec, } impl JoinedRecordBatches { - fn clear(&mut self) { - self.batches.clear(); + /// Concatenates all accumulated batches into a single RecordBatch + /// + /// Must drain ALL batches from BatchCoalescer for filtered joins to ensure + /// metadata alignment when applying get_corrected_filter_mask(). + pub(super) fn concat_batches(&mut self, schema: &SchemaRef) -> Result { + self.joined_batches.finish_buffered_batch()?; + + let mut all_batches = vec![]; + while let Some(batch) = self.joined_batches.next_completed_batch() { + all_batches.push(batch); + } + + match all_batches.as_slice() { + [] => unreachable!("concat_batches called with empty BatchCoalescer"), + [single_batch] => Ok(single_batch.clone()), + multiple_batches => Ok(concat_batches(schema, multiple_batches)?), + } + } + + /// Finishes and returns the metadata arrays, clearing the builders + /// + /// Returns (row_indices, filter_mask, batch_ids_ref) + /// Note: batch_ids is returned as a reference since it's still needed in the struct + fn finish_metadata(&mut self) -> (UInt64Array, BooleanArray, &[usize]) { + let row_indices = self.row_indices.finish(); + let filter_mask = self.filter_mask.finish(); + (row_indices, filter_mask, &self.batch_ids) + } + + /// Clears batches without touching metadata (for early return when no filtering needed) + fn clear_batches(&mut self, schema: &SchemaRef, batch_size: usize) { + self.joined_batches = BatchCoalescer::new(Arc::clone(schema), batch_size) + .with_biggest_coalesce_batch_size(Option::from(batch_size / 2)); + } + + /// Asserts that internal metadata arrays are consistent with each other + /// Only checks if metadata is actually being used (i.e., not all empty) + #[inline] + fn debug_assert_metadata_aligned(&self) { + // Metadata arrays should be aligned IF they're being used + // (For non-filtered joins, they may all be empty) + if self.filter_mask.len() > 0 + || self.row_indices.len() > 0 + || !self.batch_ids.is_empty() + { + debug_assert_eq!( + self.filter_mask.len(), + self.row_indices.len(), + "filter_mask and row_indices must have same length when metadata is used" + ); + debug_assert_eq!( + self.filter_mask.len(), + self.batch_ids.len(), + "filter_mask and batch_ids must have same length when metadata is used" + ); + } + } + + /// Asserts that if batches is empty, metadata is also empty + #[inline] + fn debug_assert_empty_consistency(&self) { + if self.joined_batches.is_empty() { + debug_assert_eq!( + self.filter_mask.len(), + 0, + "filter_mask should be empty when batches is empty" + ); + debug_assert_eq!( + self.row_indices.len(), + 0, + "row_indices should be empty when batches is empty" + ); + debug_assert_eq!( + self.batch_ids.len(), + 0, + "batch_ids should be empty when batches is empty" + ); + } + } + + /// Pushes a batch with null metadata (Full join null-joined rows only) + /// + /// These buffered rows had NO matching streamed rows. Since we can't group + /// by input row (no input row exists), we use null metadata as a sentinel. + /// + /// Maintains invariant: N rows → N metadata entries (nulls) + fn push_batch_with_null_metadata(&mut self, batch: RecordBatch, join_type: JoinType) { + debug_assert!( + matches!(join_type, JoinType::Full), + "push_batch_with_null_metadata should only be called for Full joins" + ); + + let num_rows = batch.num_rows(); + + self.filter_mask.append_nulls(num_rows); + self.row_indices.append_nulls(num_rows); + self.batch_ids.resize( + self.batch_ids.len() + num_rows, + 0, // batch_id = 0 for null-joined rows + ); + + self.debug_assert_metadata_aligned(); + self.joined_batches + .push_batch(batch) + .expect("Failed to push batch to BatchCoalescer"); + } + + /// Pushes a batch with filter metadata (filtered outer/semi/anti/mark joins) + /// + /// Deferred filtering: An input row may join with multiple buffered rows, but we + /// don't know yet if all matches failed the filter. We track metadata so + /// `get_corrected_filter_mask()` can later group by input row and decide: + /// - If any match passed: emit passing rows + /// - If all matches failed: emit null-joined row + /// + /// Maintains invariant: N rows → N metadata entries + fn push_batch_with_filter_metadata( + &mut self, + batch: RecordBatch, + row_indices: &UInt64Array, + filter_mask: &BooleanArray, + streamed_batch_id: usize, + join_type: JoinType, + ) { + debug_assert!( + matches!( + join_type, + JoinType::Left + | JoinType::LeftSemi + | JoinType::LeftMark + | JoinType::Right + | JoinType::RightSemi + | JoinType::RightMark + | JoinType::LeftAnti + | JoinType::RightAnti + | JoinType::Full + ), + "push_batch_with_filter_metadata should only be called for outer/semi/anti/mark joins that need deferred filtering" + ); + + debug_assert_eq!( + row_indices.len(), + filter_mask.len(), + "row_indices and filter_mask must have same length" + ); + + // For Full joins, we keep the pre_mask (with nulls), for others we keep the cleaned mask + self.filter_mask.extend(filter_mask); + self.row_indices.extend(row_indices); + self.batch_ids + .resize(self.batch_ids.len() + row_indices.len(), streamed_batch_id); + + self.debug_assert_metadata_aligned(); + self.joined_batches + .push_batch(batch) + .expect("Failed to push batch to BatchCoalescer"); + } + + /// Pushes a batch without metadata (non-filtered joins) + /// + /// No deferred filtering needed. Either every join match is output (Inner), + /// or null-joined rows are handled separately. No need to track which input + /// row produced which output row. + fn push_batch_without_metadata(&mut self, batch: RecordBatch, _join_type: JoinType) { + self.joined_batches + .push_batch(batch) + .expect("Failed to push batch to BatchCoalescer"); + } + + fn clear(&mut self, schema: &SchemaRef, batch_size: usize) { + self.joined_batches = BatchCoalescer::new(Arc::clone(schema), batch_size) + .with_biggest_coalesce_batch_size(Option::from(batch_size / 2)); self.batch_ids.clear(); self.filter_mask = BooleanBuilder::new(); self.row_indices = UInt64Builder::new(); + self.debug_assert_empty_consistency(); } } impl RecordBatchStream for SortMergeJoinStream { @@ -386,6 +574,21 @@ fn last_index_for_row( batch_ids: &[usize], indices_len: usize, ) -> bool { + debug_assert_eq!( + indices.len(), + indices_len, + "indices.len() should match indices_len parameter" + ); + debug_assert_eq!( + batch_ids.len(), + indices_len, + "batch_ids.len() should match indices_len" + ); + debug_assert!( + row_index < indices_len, + "row_index {row_index} should be < indices_len {indices_len}", + ); + row_index == indices_len - 1 || batch_ids[row_index] != batch_ids[row_index + 1] || indices.value(row_index) != indices.value(row_index + 1) @@ -575,51 +778,12 @@ impl Stream for SortMergeJoinStream { match self.current_ordering { Ordering::Less | Ordering::Equal => { if !streamed_exhausted { - if self.filter.is_some() - && matches!( - self.join_type, - JoinType::Left - | JoinType::LeftSemi - | JoinType::LeftMark - | JoinType::Right - | JoinType::RightSemi - | JoinType::RightMark - | JoinType::LeftAnti - | JoinType::RightAnti - | JoinType::Full - ) - { - self.freeze_all()?; - - // If join is filtered and there is joined tuples waiting - // to be filtered - if !self - .staging_output_record_batches - .batches - .is_empty() - { - // Apply filter on joined tuples and get filtered batch - let out_filtered_batch = - self.filter_joined_batch()?; - - // Append filtered batch to the output buffer - self.output = concat_batches( - &self.schema(), - [&self.output, &out_filtered_batch], - )?; - - // Send to output if the output buffer surpassed the `batch_size` - if self.output.num_rows() >= self.batch_size { - let record_batch = std::mem::replace( - &mut self.output, - RecordBatch::new_empty( - out_filtered_batch.schema(), - ), - ); - return Poll::Ready(Some(Ok( - record_batch, - ))); + if self.needs_deferred_filtering() { + match self.process_filtered_batches()? { + Poll::Ready(Some(batch)) => { + return Poll::Ready(Some(Ok(batch))); } + Poll::Ready(None) | Poll::Pending => {} } } @@ -669,78 +833,93 @@ impl Stream for SortMergeJoinStream { SortMergeJoinState::JoinOutput => { self.join_partial()?; - if self.output_size < self.batch_size { + if self.num_unfrozen_pairs() < self.batch_size { if self.buffered_data.scanning_finished() { self.buffered_data.scanning_reset(); self.state = SortMergeJoinState::Init; } } else { self.freeze_all()?; - if !self.staging_output_record_batches.batches.is_empty() { - let record_batch = self.output_record_batch_and_reset()?; - // For non-filtered join output whenever the target output batch size - // is hit. For filtered join its needed to output on later phase - // because target output batch size can be hit in the middle of - // filtering causing the filtering to be incomplete and causing - // correctness issues - if self.filter.is_some() - && matches!( - self.join_type, - JoinType::Left - | JoinType::LeftSemi - | JoinType::Right - | JoinType::RightSemi - | JoinType::LeftAnti - | JoinType::RightAnti - | JoinType::LeftMark - | JoinType::RightMark - | JoinType::Full - ) - { - continue; - } + // Verify metadata alignment before checking if we have batches to output + self.joined_record_batches.debug_assert_metadata_aligned(); + + // For filtered joins, skip output and let Init state handle it + if self.needs_deferred_filtering() { + continue; + } + + // For non-filtered joins, only output if we have a completed batch + // (opportunistic output when target batch size is reached) + if self + .joined_record_batches + .joined_batches + .has_completed_batch() + { + let record_batch = self + .joined_record_batches + .joined_batches + .next_completed_batch() + .expect("has_completed_batch was true"); + (&record_batch) + .record_output(&self.join_metrics.baseline_metrics()); return Poll::Ready(Some(Ok(record_batch))); } - return Poll::Pending; + // Otherwise keep buffering (don't output yet) } } SortMergeJoinState::Exhausted => { self.freeze_all()?; - // if there is still something not processed - if !self.staging_output_record_batches.batches.is_empty() { - if self.filter.is_some() - && matches!( - self.join_type, - JoinType::Left - | JoinType::LeftSemi - | JoinType::Right - | JoinType::RightSemi - | JoinType::LeftAnti - | JoinType::RightAnti - | JoinType::Full - | JoinType::LeftMark - | JoinType::RightMark - ) - { - let record_batch = self.filter_joined_batch()?; - return Poll::Ready(Some(Ok(record_batch))); - } else { - let record_batch = self.output_record_batch_and_reset()?; - return Poll::Ready(Some(Ok(record_batch))); - } - } else if self.output.num_rows() > 0 { - // if processed but still not outputted because it didn't hit batch size before - let schema = self.output.schema(); - let record_batch = std::mem::replace( - &mut self.output, - RecordBatch::new_empty(schema), - ); + // Verify metadata alignment before final output + self.joined_record_batches.debug_assert_metadata_aligned(); + + // For filtered joins, must concat and filter ALL data at once + if self.needs_deferred_filtering() + && !self.joined_record_batches.joined_batches.is_empty() + { + let record_batch = self.filter_joined_batch()?; + (&record_batch) + .record_output(&self.join_metrics.baseline_metrics()); return Poll::Ready(Some(Ok(record_batch))); - } else { - return Poll::Ready(None); } + + // For non-filtered joins, finish buffered data first + if !self.joined_record_batches.joined_batches.is_empty() { + self.joined_record_batches + .joined_batches + .finish_buffered_batch()?; + } + + // Output one completed batch at a time (stay in Exhausted until empty) + if self + .joined_record_batches + .joined_batches + .has_completed_batch() + { + let record_batch = self + .joined_record_batches + .joined_batches + .next_completed_batch() + .expect("has_completed_batch was true"); + (&record_batch) + .record_output(&self.join_metrics.baseline_metrics()); + return Poll::Ready(Some(Ok(record_batch))); + } + + // Finally check self.output BatchCoalescer (used by filtered joins) + return if !self.output.is_empty() { + self.output.finish_buffered_batch()?; + let record_batch = self + .output + .next_completed_batch() + .expect("Failed to get last batch"); + (&record_batch) + .record_output(&self.join_metrics.baseline_metrics()); + Poll::Ready(Some(Ok(record_batch))) + } else { + Poll::Ready(None) + }; } } } @@ -793,14 +972,15 @@ impl SortMergeJoinStream { on_streamed, on_buffered, filter, - staging_output_record_batches: JoinedRecordBatches { - batches: vec![], + joined_record_batches: JoinedRecordBatches { + joined_batches: BatchCoalescer::new(Arc::clone(&schema), batch_size) + .with_biggest_coalesce_batch_size(Option::from(batch_size / 2)), filter_mask: BooleanBuilder::new(), row_indices: UInt64Builder::new(), batch_ids: vec![], }, - output: RecordBatch::new_empty(schema), - output_size: 0, + output: BatchCoalescer::new(schema, batch_size) + .with_biggest_coalesce_batch_size(Option::from(batch_size / 2)), batch_size, join_type, join_metrics, @@ -811,6 +991,59 @@ impl SortMergeJoinStream { }) } + /// Number of unfrozen output pairs (used to decide when to freeze + output) + fn num_unfrozen_pairs(&self) -> usize { + self.streamed_batch.num_output_rows() + } + + /// Returns true if this join needs deferred filtering + /// + /// Deferred filtering is needed when a filter exists and the join type requires + /// ensuring each input row produces at least one output row (or exactly one for semi). + fn needs_deferred_filtering(&self) -> bool { + self.filter.is_some() + && matches!( + self.join_type, + JoinType::Left + | JoinType::LeftSemi + | JoinType::LeftMark + | JoinType::Right + | JoinType::RightSemi + | JoinType::RightMark + | JoinType::LeftAnti + | JoinType::RightAnti + | JoinType::Full + ) + } + + /// Process accumulated batches for filtered joins + /// + /// Freezes unfrozen pairs, applies deferred filtering, and outputs if ready. + /// Returns Poll::Ready with a batch if one is available, otherwise Poll::Pending. + fn process_filtered_batches(&mut self) -> Poll>> { + self.freeze_all()?; + + self.joined_record_batches.debug_assert_metadata_aligned(); + + if !self.joined_record_batches.joined_batches.is_empty() { + let out_filtered_batch = self.filter_joined_batch()?; + self.output + .push_batch(out_filtered_batch) + .expect("Failed to push output batch"); + + if self.output.has_completed_batch() { + let record_batch = self + .output + .next_completed_batch() + .expect("Failed to get output batch"); + (&record_batch).record_output(&self.join_metrics.baseline_metrics()); + return Poll::Ready(Some(Ok(record_batch))); + } + } + + Poll::Pending + } + /// Poll next streamed row fn poll_streamed_row(&mut self, cx: &mut Context) -> Poll>> { loop { @@ -1110,14 +1343,18 @@ impl SortMergeJoinStream { if join_buffered { // joining streamed/nulls and buffered while !self.buffered_data.scanning_finished() - && self.output_size < self.batch_size + && self.num_unfrozen_pairs() < self.batch_size { let scanning_idx = self.buffered_data.scanning_idx(); if join_streamed { // Join streamed row and buffered row + // Pass batch_size and num_unfrozen_pairs to compute capacity only when + // creating a new chunk (when buffered_batch_idx changes), not on every iteration. self.streamed_batch.append_output_pair( Some(self.buffered_data.scanning_batch_idx), Some(scanning_idx), + self.batch_size, + self.num_unfrozen_pairs(), ); } else { // Join nulls and buffered row for FULL join @@ -1126,7 +1363,6 @@ impl SortMergeJoinStream { .null_joined .push(scanning_idx); } - self.output_size += 1; self.buffered_data.scanning_advance(); if self.buffered_data.scanning_finished() { @@ -1144,9 +1380,14 @@ impl SortMergeJoinStream { // For Mark join we store a dummy id to indicate the row has a match let scanning_idx = mark_row_as_match.then_some(0); - self.streamed_batch - .append_output_pair(scanning_batch_idx, scanning_idx); - self.output_size += 1; + // Pass batch_size=1 and num_unfrozen_pairs=0 to get capacity of 1, + // since we only append a single null-joined pair here (not in a loop). + self.streamed_batch.append_output_pair( + scanning_batch_idx, + scanning_idx, + 1, + 0, + ); self.buffered_data.scanning_finish(); self.streamed_joined = true; } @@ -1156,6 +1397,10 @@ impl SortMergeJoinStream { fn freeze_all(&mut self) -> Result<()> { self.freeze_buffered(self.buffered_data.batches.len())?; self.freeze_streamed()?; + + // After freezing, metadata should be aligned + self.joined_record_batches.debug_assert_metadata_aligned(); + Ok(()) } @@ -1167,6 +1412,10 @@ impl SortMergeJoinStream { self.freeze_streamed()?; // Only freeze and produce the first batch in buffered_data as the batch is fully processed self.freeze_buffered(1)?; + + // After freezing, metadata should be aligned + self.joined_record_batches.debug_assert_metadata_aligned(); + Ok(()) } @@ -1189,21 +1438,8 @@ impl SortMergeJoinStream { &buffered_indices, buffered_batch, )? { - let num_rows = record_batch.num_rows(); - self.staging_output_record_batches - .filter_mask - .append_nulls(num_rows); - self.staging_output_record_batches - .row_indices - .append_nulls(num_rows); - self.staging_output_record_batches.batch_ids.resize( - self.staging_output_record_batches.batch_ids.len() + num_rows, - 0, - ); - - self.staging_output_record_batches - .batches - .push(record_batch); + self.joined_record_batches + .push_batch_with_null_metadata(record_batch, self.join_type); } buffered_batch.null_joined.clear(); } @@ -1235,21 +1471,8 @@ impl SortMergeJoinStream { &buffered_indices, buffered_batch, )? { - let num_rows = record_batch.num_rows(); - - self.staging_output_record_batches - .filter_mask - .append_nulls(num_rows); - self.staging_output_record_batches - .row_indices - .append_nulls(num_rows); - self.staging_output_record_batches.batch_ids.resize( - self.staging_output_record_batches.batch_ids.len() + num_rows, - 0, - ); - self.staging_output_record_batches - .batches - .push(record_batch); + self.joined_record_batches + .push_batch_with_null_metadata(record_batch, self.join_type); } buffered_batch.join_filter_not_matched_map.clear(); @@ -1378,7 +1601,9 @@ impl SortMergeJoinStream { }; // Push the filtered batch which contains rows passing join filter to the output - if matches!( + // For outer/semi/anti/mark joins with deferred filtering, push the unfiltered batch with metadata + // For INNER joins, filter immediately and push without metadata + let needs_deferred_filtering = matches!( self.join_type, JoinType::Left | JoinType::LeftSemi @@ -1389,32 +1614,29 @@ impl SortMergeJoinStream { | JoinType::LeftMark | JoinType::RightMark | JoinType::Full - ) { - self.staging_output_record_batches - .batches - .push(output_batch); - } else { - let filtered_batch = filter_record_batch(&output_batch, &mask)?; - self.staging_output_record_batches - .batches - .push(filtered_batch); - } + ); - if !matches!(self.join_type, JoinType::Full) { - self.staging_output_record_batches.filter_mask.extend(&mask); + if needs_deferred_filtering { + // Outer/semi/anti/mark joins: push unfiltered batch with metadata for deferred filtering + let mask_to_use = if !matches!(self.join_type, JoinType::Full) { + &mask + } else { + pre_mask + }; + + self.joined_record_batches.push_batch_with_filter_metadata( + output_batch, + &left_indices, + mask_to_use, + self.streamed_batch_counter.load(Relaxed), + self.join_type, + ); } else { - self.staging_output_record_batches - .filter_mask - .extend(pre_mask); + // INNER joins: filter immediately and push without metadata + let filtered_batch = filter_record_batch(&output_batch, &mask)?; + self.joined_record_batches + .push_batch_without_metadata(filtered_batch, self.join_type); } - self.staging_output_record_batches - .row_indices - .extend(&left_indices); - self.staging_output_record_batches.batch_ids.resize( - self.staging_output_record_batches.batch_ids.len() - + left_indices.len(), - self.streamed_batch_counter.load(Relaxed), - ); // For outer joins, we need to push the null joined rows to the output if // all joined rows are failed on the join filter. @@ -1443,15 +1665,10 @@ impl SortMergeJoinStream { ); } } - } else { - self.staging_output_record_batches - .batches - .push(output_batch); } } else { - self.staging_output_record_batches - .batches - .push(output_batch); + self.joined_record_batches + .push_batch_without_metadata(output_batch, self.join_type); } } @@ -1460,47 +1677,13 @@ impl SortMergeJoinStream { Ok(()) } - fn output_record_batch_and_reset(&mut self) -> Result { - let record_batch = - concat_batches(&self.schema, &self.staging_output_record_batches.batches)?; - (&record_batch).record_output(&self.join_metrics.baseline_metrics()); - // If join filter exists, `self.output_size` is not accurate as we don't know the exact - // number of rows in the output record batch. If streamed row joined with buffered rows, - // once join filter is applied, the number of output rows may be more than 1. - // If `record_batch` is empty, we should reset `self.output_size` to 0. It could be happened - // when the join filter is applied and all rows are filtered out. - if record_batch.num_rows() == 0 || record_batch.num_rows() > self.output_size { - self.output_size = 0; - } else { - self.output_size -= record_batch.num_rows(); - } - - if !(self.filter.is_some() - && matches!( - self.join_type, - JoinType::Left - | JoinType::LeftSemi - | JoinType::Right - | JoinType::RightSemi - | JoinType::LeftAnti - | JoinType::RightAnti - | JoinType::LeftMark - | JoinType::RightMark - | JoinType::Full - )) - { - self.staging_output_record_batches.batches.clear(); - } - - Ok(record_batch) - } - fn filter_joined_batch(&mut self) -> Result { - let record_batch = - concat_batches(&self.schema, &self.staging_output_record_batches.batches)?; - let mut out_indices = self.staging_output_record_batches.row_indices.finish(); - let mut out_mask = self.staging_output_record_batches.filter_mask.finish(); - let mut batch_ids = &self.staging_output_record_batches.batch_ids; + // Metadata should be aligned before processing + self.joined_record_batches.debug_assert_metadata_aligned(); + + let record_batch = self.joined_record_batches.concat_batches(&self.schema)?; + let (mut out_indices, mut out_mask, mut batch_ids) = + self.joined_record_batches.finish_metadata(); let default_batch_ids = vec![0; record_batch.num_rows()]; // If only nulls come in and indices sizes doesn't match with expected record batch count @@ -1514,11 +1697,41 @@ impl SortMergeJoinStream { batch_ids = &default_batch_ids; } + // After potential reconstruction, metadata should align with batch row count + debug_assert_eq!( + out_indices.len(), + record_batch.num_rows(), + "out_indices length should match record_batch row count" + ); + debug_assert_eq!( + out_mask.len(), + record_batch.num_rows(), + "out_mask length should match record_batch row count (unless empty)" + ); + debug_assert_eq!( + batch_ids.len(), + record_batch.num_rows(), + "batch_ids length should match record_batch row count" + ); + if out_mask.is_empty() { - self.staging_output_record_batches.batches.clear(); + self.joined_record_batches + .clear_batches(&self.schema, self.batch_size); return Ok(record_batch); } + // Validate inputs to get_corrected_filter_mask + debug_assert_eq!( + out_indices.len(), + out_mask.len(), + "out_indices and out_mask must have same length for get_corrected_filter_mask" + ); + debug_assert_eq!( + batch_ids.len(), + out_mask.len(), + "batch_ids and out_mask must have same length for get_corrected_filter_mask" + ); + let maybe_corrected_mask = get_corrected_filter_mask( self.join_type, &out_indices, @@ -1541,6 +1754,15 @@ impl SortMergeJoinStream { record_batch: &RecordBatch, corrected_mask: &BooleanArray, ) -> Result { + // Corrected mask should have length matching or exceeding record_batch rows + // (for outer joins it may be longer to include null-joined rows) + debug_assert!( + corrected_mask.len() >= record_batch.num_rows(), + "corrected_mask length ({}) should be >= record_batch rows ({})", + corrected_mask.len(), + record_batch.num_rows() + ); + let mut filtered_record_batch = filter_record_batch(record_batch, corrected_mask)?; let left_columns_length = self.streamed_schema.fields.len(); @@ -1666,7 +1888,8 @@ impl SortMergeJoinStream { )?; } - self.staging_output_record_batches.clear(); + self.joined_record_batches + .clear(&self.schema, self.batch_size); Ok(filtered_record_batch) } diff --git a/datafusion/physical-plan/src/joins/sort_merge_join/tests.rs b/datafusion/physical-plan/src/joins/sort_merge_join/tests.rs index 2e4725995b471..47a85b9b5c6ea 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join/tests.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join/tests.rs @@ -31,7 +31,7 @@ use arrow::array::{ BinaryArray, BooleanArray, Date32Array, Date64Array, FixedSizeBinaryArray, Int32Array, RecordBatch, UInt64Array, }; -use arrow::compute::{concat_batches, filter_record_batch, SortOptions}; +use arrow::compute::{filter_record_batch, BatchCoalescer, SortOptions}; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::JoinType::*; @@ -2374,14 +2374,14 @@ fn build_joined_record_batches() -> Result { ])); let mut batches = JoinedRecordBatches { - batches: vec![], + joined_batches: BatchCoalescer::new(Arc::clone(&schema), 8192), filter_mask: BooleanBuilder::new(), row_indices: UInt64Builder::new(), batch_ids: vec![], }; // Insert already prejoined non-filtered rows - batches.batches.push(RecordBatch::try_new( + batches.joined_batches.push_batch(RecordBatch::try_new( Arc::clone(&schema), vec![ Arc::new(Int32Array::from(vec![1, 1])), @@ -2389,9 +2389,9 @@ fn build_joined_record_batches() -> Result { Arc::new(Int32Array::from(vec![1, 1])), Arc::new(Int32Array::from(vec![11, 9])), ], - )?); + )?)?; - batches.batches.push(RecordBatch::try_new( + batches.joined_batches.push_batch(RecordBatch::try_new( Arc::clone(&schema), vec![ Arc::new(Int32Array::from(vec![1])), @@ -2399,9 +2399,9 @@ fn build_joined_record_batches() -> Result { Arc::new(Int32Array::from(vec![1])), Arc::new(Int32Array::from(vec![12])), ], - )?); + )?)?; - batches.batches.push(RecordBatch::try_new( + batches.joined_batches.push_batch(RecordBatch::try_new( Arc::clone(&schema), vec![ Arc::new(Int32Array::from(vec![1, 1])), @@ -2409,9 +2409,9 @@ fn build_joined_record_batches() -> Result { Arc::new(Int32Array::from(vec![1, 1])), Arc::new(Int32Array::from(vec![11, 13])), ], - )?); + )?)?; - batches.batches.push(RecordBatch::try_new( + batches.joined_batches.push_batch(RecordBatch::try_new( Arc::clone(&schema), vec![ Arc::new(Int32Array::from(vec![1])), @@ -2419,9 +2419,9 @@ fn build_joined_record_batches() -> Result { Arc::new(Int32Array::from(vec![1])), Arc::new(Int32Array::from(vec![12])), ], - )?); + )?)?; - batches.batches.push(RecordBatch::try_new( + batches.joined_batches.push_batch(RecordBatch::try_new( Arc::clone(&schema), vec![ Arc::new(Int32Array::from(vec![1, 1])), @@ -2429,7 +2429,7 @@ fn build_joined_record_batches() -> Result { Arc::new(Int32Array::from(vec![1, 1])), Arc::new(Int32Array::from(vec![12, 11])), ], - )?); + )?)?; let streamed_indices = vec![0, 0]; batches.batch_ids.extend(vec![0; streamed_indices.len()]); @@ -2479,9 +2479,9 @@ fn build_joined_record_batches() -> Result { #[tokio::test] async fn test_left_outer_join_filtered_mask() -> Result<()> { let mut joined_batches = build_joined_record_batches()?; - let schema = joined_batches.batches.first().unwrap().schema(); + let schema = joined_batches.joined_batches.schema(); - let output = concat_batches(&schema, &joined_batches.batches)?; + let output = joined_batches.concat_batches(&schema)?; let out_mask = joined_batches.filter_mask.finish(); let out_indices = joined_batches.row_indices.finish(); @@ -2686,9 +2686,9 @@ async fn test_left_outer_join_filtered_mask() -> Result<()> { async fn test_semi_join_filtered_mask() -> Result<()> { for join_type in [LeftSemi, RightSemi] { let mut joined_batches = build_joined_record_batches()?; - let schema = joined_batches.batches.first().unwrap().schema(); + let schema = joined_batches.joined_batches.schema(); - let output = concat_batches(&schema, &joined_batches.batches)?; + let output = joined_batches.concat_batches(&schema)?; let out_mask = joined_batches.filter_mask.finish(); let out_indices = joined_batches.row_indices.finish(); @@ -2861,9 +2861,9 @@ async fn test_semi_join_filtered_mask() -> Result<()> { async fn test_anti_join_filtered_mask() -> Result<()> { for join_type in [LeftAnti, RightAnti] { let mut joined_batches = build_joined_record_batches()?; - let schema = joined_batches.batches.first().unwrap().schema(); + let schema = joined_batches.joined_batches.schema(); - let output = concat_batches(&schema, &joined_batches.batches)?; + let output = joined_batches.concat_batches(&schema)?; let out_mask = joined_batches.filter_mask.finish(); let out_indices = joined_batches.row_indices.finish();