diff --git a/rust/datafusion/src/physical_plan/hash_join.rs b/rust/datafusion/src/physical_plan/hash_join.rs index abec10d7fc6..25630a9ec8e 100644 --- a/rust/datafusion/src/physical_plan/hash_join.rs +++ b/rust/datafusion/src/physical_plan/hash_join.rs @@ -699,7 +699,10 @@ macro_rules! hash_array { } /// Creates hash values for every element in the row based on the values in the columns -fn create_hashes(arrays: &[ArrayRef], random_state: &RandomState) -> Result> { +pub fn create_hashes( + arrays: &[ArrayRef], + random_state: &RandomState, +) -> Result> { let rows = arrays[0].len(); let mut hashes = vec![0; rows]; diff --git a/rust/datafusion/src/physical_plan/repartition.rs b/rust/datafusion/src/physical_plan/repartition.rs index edabfde27c4..63854988fa5 100644 --- a/rust/datafusion/src/physical_plan/repartition.rs +++ b/rust/datafusion/src/physical_plan/repartition.rs @@ -18,18 +18,18 @@ //! The repartition operator maps N input partitions to M output partitions based on a //! partitioning scheme. -use std::any::Any; use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; +use std::{any::Any, vec}; use crate::error::{DataFusionError, Result}; use crate::physical_plan::{ExecutionPlan, Partitioning}; -use arrow::datatypes::SchemaRef; -use arrow::error::Result as ArrowResult; use arrow::record_batch::RecordBatch; +use arrow::{array::Array, error::Result as ArrowResult}; +use arrow::{compute::take, datatypes::SchemaRef}; -use super::{RecordBatchStream, SendableRecordBatchStream}; +use super::{hash_join::create_hashes, RecordBatchStream, SendableRecordBatchStream}; use async_trait::async_trait; use crossbeam::channel::{unbounded, Receiver, Sender}; @@ -120,8 +120,11 @@ impl ExecutionPlan for RepartitionExec { let (sender, receiver) = unbounded::>>(); channels.push((sender, receiver)); } + let random = ahash::RandomState::new(); + // launch one async task per *input* partition for i in 0..num_input_partitions { + let random_state = random.clone(); let input = self.input.clone(); let mut channels = channels.clone(); let partitioning = self.partitioning.clone(); @@ -129,7 +132,7 @@ impl ExecutionPlan for RepartitionExec { let mut stream = input.execute(i).await?; let mut counter = 0; while let Some(result) = stream.next().await { - match partitioning { + match &partitioning { Partitioning::RoundRobinBatch(_) => { let output_partition = counter % num_output_partitions; let tx = &mut channels[output_partition].0; @@ -137,6 +140,52 @@ impl ExecutionPlan for RepartitionExec { DataFusionError::Execution(e.to_string()) })?; } + Partitioning::Hash(exprs, _) => { + let input_batch = result?; + let arrays = exprs + .iter() + .map(|expr| { + Ok(expr + .evaluate(&input_batch)? + .into_array(input_batch.num_rows())) + }) + .collect::>>()?; + // Hash arrays and compute buckets based on number of partitions + let hashes = create_hashes(&arrays, &random_state)?; + let mut indices = vec![vec![]; num_output_partitions]; + for (index, hash) in hashes.iter().enumerate() { + indices + [(*hash % num_output_partitions as u64) as usize] + .push(index as u64) + } + for (num_output_partition, partition_indices) in + indices.into_iter().enumerate() + { + let indices = partition_indices.into(); + // Produce batches based on indices + let columns = input_batch + .columns() + .iter() + .map(|c| { + take(c.as_ref(), &indices, None).map_err( + |e| { + DataFusionError::Execution( + e.to_string(), + ) + }, + ) + }) + .collect::>>>()?; + let output_batch = RecordBatch::try_new( + input_batch.schema(), + columns, + ); + let tx = &mut channels[num_output_partition].0; + tx.send(Some(output_batch)).map_err(|e| { + DataFusionError::Execution(e.to_string()) + })?; + } + } other => { // this should be unreachable as long as the validation logic // in the constructor is kept up-to-date @@ -181,17 +230,11 @@ impl RepartitionExec { input: Arc, partitioning: Partitioning, ) -> Result { - match &partitioning { - Partitioning::RoundRobinBatch(_) => Ok(RepartitionExec { - input, - partitioning, - channels: Arc::new(Mutex::new(vec![])), - }), - other => Err(DataFusionError::NotImplemented(format!( - "Partitioning scheme not supported yet: {:?}", - other - ))), - } + Ok(RepartitionExec { + input, + partitioning, + channels: Arc::new(Mutex::new(vec![])), + }) } } @@ -305,6 +348,33 @@ mod tests { Ok(()) } + #[tokio::test(flavor = "multi_thread")] + async fn many_to_many_hash_partition() -> Result<()> { + // define input partitions + let schema = test_schema(); + let partition = create_vec_batches(&schema, 50); + let partitions = vec![partition.clone(), partition.clone(), partition.clone()]; + + let output_partitions = repartition( + &schema, + partitions, + Partitioning::Hash( + vec![Arc::new(crate::physical_plan::expressions::Column::new( + &"c0", + ))], + 8, + ), + ) + .await?; + + let total_rows: usize = output_partitions.iter().map(|x| x.len()).sum(); + + assert_eq!(8, output_partitions.len()); + assert_eq!(total_rows, 8 * 50 * 3); + + Ok(()) + } + fn test_schema() -> Arc { Arc::new(Schema::new(vec![Field::new("c0", DataType::UInt32, false)])) }