5 changes: 4 additions & 1 deletion rust/datafusion/src/physical_plan/hash_join.rs
@@ -699,7 +699,10 @@ macro_rules! hash_array {
 }
 
 /// Creates hash values for every element in the row based on the values in the columns
-fn create_hashes(arrays: &[ArrayRef], random_state: &RandomState) -> Result<Vec<u64>> {
+pub fn create_hashes(
+    arrays: &[ArrayRef],
+    random_state: &RandomState,
+) -> Result<Vec<u64>> {
     let rows = arrays[0].len();
     let mut hashes = vec![0; rows];
 
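Making `create_hashes` public is what lets the repartition operator below reuse the join's row-hashing. A minimal sketch of calling it directly (assuming the `hash_join` module is publicly reachable from the crate root; the column values are illustrative):

```rust
use std::sync::Arc;

use ahash::RandomState;
use arrow::array::{ArrayRef, UInt32Array};
// Assumes `hash_join` is publicly reachable; inside the crate this is
// `use super::hash_join::create_hashes`, as in repartition.rs below.
use datafusion::physical_plan::hash_join::create_hashes;

fn main() -> datafusion::error::Result<()> {
    // Each row is hashed across all columns, so the arrays must share a length.
    let c0: ArrayRef = Arc::new(UInt32Array::from(vec![1, 2, 3, 4]));
    let c1: ArrayRef = Arc::new(UInt32Array::from(vec![10, 20, 30, 40]));

    // Reusing one RandomState keeps hashes comparable across batches.
    let random_state = RandomState::new();
    let hashes = create_hashes(&[c0, c1], &random_state)?;
    assert_eq!(hashes.len(), 4); // one u64 hash per row
    Ok(())
}
```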
102 changes: 86 additions & 16 deletions rust/datafusion/src/physical_plan/repartition.rs
@@ -18,18 +18,18 @@
 //! The repartition operator maps N input partitions to M output partitions based on a
 //! partitioning scheme.
 
-use std::any::Any;
 use std::pin::Pin;
 use std::sync::Arc;
 use std::task::{Context, Poll};
+use std::{any::Any, vec};
 
 use crate::error::{DataFusionError, Result};
 use crate::physical_plan::{ExecutionPlan, Partitioning};
-use arrow::datatypes::SchemaRef;
-use arrow::error::Result as ArrowResult;
 use arrow::record_batch::RecordBatch;
+use arrow::{array::Array, error::Result as ArrowResult};
+use arrow::{compute::take, datatypes::SchemaRef};
 
-use super::{RecordBatchStream, SendableRecordBatchStream};
+use super::{hash_join::create_hashes, RecordBatchStream, SendableRecordBatchStream};
 use async_trait::async_trait;
 
 use crossbeam::channel::{unbounded, Receiver, Sender};
@@ -120,23 +120,72 @@ impl ExecutionPlan for RepartitionExec {
             let (sender, receiver) = unbounded::<Option<ArrowResult<RecordBatch>>>();
             channels.push((sender, receiver));
         }
+        let random = ahash::RandomState::new();
 
         // launch one async task per *input* partition
         for i in 0..num_input_partitions {
+            let random_state = random.clone();
             let input = self.input.clone();
             let mut channels = channels.clone();
             let partitioning = self.partitioning.clone();
             let join_handle: JoinHandle<Result<()>> = tokio::spawn(async move {
                 let mut stream = input.execute(i).await?;
                 let mut counter = 0;
                 while let Some(result) = stream.next().await {
-                    match partitioning {
+                    match &partitioning {
                         Partitioning::RoundRobinBatch(_) => {
                             let output_partition = counter % num_output_partitions;
                             let tx = &mut channels[output_partition].0;
                             tx.send(Some(result)).map_err(|e| {
                                 DataFusionError::Execution(e.to_string())
                             })?;
                         }
+                        Partitioning::Hash(exprs, _) => {
+                            let input_batch = result?;
+                            let arrays = exprs
+                                .iter()
+                                .map(|expr| {
+                                    Ok(expr
+                                        .evaluate(&input_batch)?
+                                        .into_array(input_batch.num_rows()))
+                                })
+                                .collect::<Result<Vec<_>>>()?;
+                            // Hash arrays and compute buckets based on number of partitions
+                            let hashes = create_hashes(&arrays, &random_state)?;
+                            let mut indices = vec![vec![]; num_output_partitions];
+                            for (index, hash) in hashes.iter().enumerate() {
+                                indices
+                                    [(*hash % num_output_partitions as u64) as usize]
+                                    .push(index as u64)
+                            }
+                            for (num_output_partition, partition_indices) in
+                                indices.into_iter().enumerate()
+                            {
+                                let indices = partition_indices.into();
+                                // Produce batches based on indices
+                                let columns = input_batch
+                                    .columns()
+                                    .iter()
+                                    .map(|c| {
+                                        take(c.as_ref(), &indices, None).map_err(
+                                            |e| {
+                                                DataFusionError::Execution(
+                                                    e.to_string(),
+                                                )
+                                            },
+                                        )
+                                    })
+                                    .collect::<Result<Vec<Arc<dyn Array>>>>()?;
+                                let output_batch = RecordBatch::try_new(
+                                    input_batch.schema(),
+                                    columns,
+                                );
+                                let tx = &mut channels[num_output_partition].0;
+                                tx.send(Some(output_batch)).map_err(|e| {
+                                    DataFusionError::Execution(e.to_string())
+                                })?;
+                            }
+                        }
                         other => {
                             // this should be unreachable as long as the validation logic
                             // in the constructor is kept up-to-date
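The new `Partitioning::Hash` arm above has three steps: evaluate the partitioning expressions into arrays, bucket each row index by `hash % num_output_partitions`, and gather each bucket into its own batch with arrow's `take` kernel. A self-contained sketch of the bucket-and-gather steps, with made-up hashes standing in for `create_hashes`:

```rust
use std::sync::Arc;

use arrow::array::{ArrayRef, UInt32Array, UInt64Array};
use arrow::compute::take;
use arrow::error::Result as ArrowResult;

// Mirrors the bucketing loop in the Hash arm: each row index is routed to
// the output partition selected by `hash % num_output_partitions`.
fn bucket_rows(hashes: &[u64], num_output_partitions: usize) -> Vec<Vec<u64>> {
    let mut indices = vec![vec![]; num_output_partitions];
    for (row, hash) in hashes.iter().enumerate() {
        indices[(*hash % num_output_partitions as u64) as usize].push(row as u64);
    }
    indices
}

fn main() -> ArrowResult<()> {
    // Made-up row hashes; in the operator they come from create_hashes().
    let hashes = [7u64, 12, 3, 12, 8];
    let buckets = bucket_rows(&hashes, 4);
    // 7 % 4 = 3, 12 % 4 = 0, 3 % 4 = 3, 12 % 4 = 0, 8 % 4 = 0
    assert_eq!(buckets, vec![vec![1, 3, 4], vec![], vec![], vec![0, 2]]);

    // Gathering one bucket per column is what produces an output batch;
    // the operator converts each bucket to a UInt64Array the same way.
    let column: ArrayRef = Arc::new(UInt32Array::from(vec![10, 20, 30, 40, 50]));
    let indices = UInt64Array::from(buckets[0].clone());
    let gathered = take(column.as_ref(), &indices, None)?;
    assert_eq!(gathered.len(), 3); // rows 1, 3 and 4
    Ok(())
}
```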
@@ -181,17 +230,11 @@ impl RepartitionExec {
         input: Arc<dyn ExecutionPlan>,
         partitioning: Partitioning,
     ) -> Result<Self> {
-        match &partitioning {
-            Partitioning::RoundRobinBatch(_) => Ok(RepartitionExec {
-                input,
-                partitioning,
-                channels: Arc::new(Mutex::new(vec![])),
-            }),
-            other => Err(DataFusionError::NotImplemented(format!(
-                "Partitioning scheme not supported yet: {:?}",
-                other
-            ))),
-        }
+        Ok(RepartitionExec {
+            input,
+            partitioning,
+            channels: Arc::new(Mutex::new(vec![])),
+        })
     }
 }
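With the `NotImplemented` branch removed, `try_new` now accepts hash partitioning directly. A sketch of the call the old constructor rejected (the column name is illustrative, and `Column::new`'s exact signature may vary between versions):

```rust
use std::sync::Arc;

use datafusion::error::Result;
use datafusion::physical_plan::expressions::Column;
use datafusion::physical_plan::repartition::RepartitionExec;
use datafusion::physical_plan::{ExecutionPlan, Partitioning};

// Before this change, passing Partitioning::Hash here returned
// DataFusionError::NotImplemented; now it constructs the operator.
fn hash_repartition(input: Arc<dyn ExecutionPlan>) -> Result<RepartitionExec> {
    RepartitionExec::try_new(
        input,
        Partitioning::Hash(vec![Arc::new(Column::new("c0"))], 8),
    )
}
```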

@@ -305,6 +348,33 @@ mod tests {
         Ok(())
     }
 
+    #[tokio::test(flavor = "multi_thread")]
+    async fn many_to_many_hash_partition() -> Result<()> {
+        // define input partitions
+        let schema = test_schema();
+        let partition = create_vec_batches(&schema, 50);
+        let partitions = vec![partition.clone(), partition.clone(), partition.clone()];
+
+        let output_partitions = repartition(
+            &schema,
+            partitions,
+            Partitioning::Hash(
+                vec![Arc::new(crate::physical_plan::expressions::Column::new(
+                    &"c0",
+                ))],
+                8,
+            ),
+        )
+        .await?;
+
+        let total_rows: usize = output_partitions.iter().map(|x| x.len()).sum();
+
+        assert_eq!(8, output_partitions.len());

Contributor: would it make sense here also to assert on the distribution of rows (e.g. ensure that each batch has ~50*3 rows)?

Contributor Author: Makes sense, but I'm not sure how to do that currently, as it depends on the random state (it could happen that all of them end up on the same hash / partition in a very rare case).

+        assert_eq!(total_rows, 8 * 50 * 3);
+
+        Ok(())
+    }
+
     fn test_schema() -> Arc<Schema> {
         Arc::new(Schema::new(vec![Field::new("c0", DataType::UInt32, false)]))
     }
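On the review thread above: an exact per-partition count cannot be asserted because it depends on the hash's random state, but a loose bound still catches a degenerate repartition. A sketch of such a helper (not part of the PR; the helper name and the `Vec<RecordBatch>` partition shape are assumptions based on the test):

```rust
use arrow::record_batch::RecordBatch;

/// Assert that no single partition received every row. `total_rows` is the
/// known input size (8 * 50 * 3 in the test above). The check is still
/// probabilistic, but it only fails if *all* rows hash to one partition.
fn assert_not_degenerate(output_partitions: &[Vec<RecordBatch>], total_rows: usize) {
    for partition in output_partitions {
        let rows: usize = partition.iter().map(|batch| batch.num_rows()).sum();
        assert!(rows < total_rows, "all rows landed in a single partition");
    }
}
```

In the test above this would be called as `assert_not_degenerate(&output_partitions, 8 * 50 * 3)`.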