Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions rust/datafusion/examples/simple_udaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ use arrow::{

use datafusion::{error::Result, logical_plan::create_udaf, physical_plan::Accumulator};
use datafusion::{prelude::*, scalar::ScalarValue};
use std::{cell::RefCell, rc::Rc, sync::Arc};
use std::sync::Arc;

// create local execution context with an in-memory table
fn create_context() -> Result<ExecutionContext> {
Expand Down Expand Up @@ -138,7 +138,7 @@ async fn main() -> Result<()> {
// the return type; DataFusion expects this to match the type returned by `evaluate`.
Arc::new(DataType::Float64),
// This is the accumulator factory; DataFusion uses it to create new accumulators.
Arc::new(|| Ok(Rc::new(RefCell::new(GeometricMean::new())))),
Arc::new(|| Ok(Box::new(GeometricMean::new()))),
// This is the description of the state. `state()` must match the types here.
Arc::new(vec![DataType::Float64, DataType::UInt32]),
);
Expand Down
8 changes: 2 additions & 6 deletions rust/datafusion/src/execution/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -537,8 +537,8 @@ mod tests {
ArrayRef, Float64Array, Int32Array, PrimitiveArrayOps, StringArray,
};
use arrow::compute::add;
use std::fs::File;
use std::thread::{self, JoinHandle};
use std::{cell::RefCell, fs::File, rc::Rc};
use std::{io::prelude::*, sync::Mutex};
use tempfile::TempDir;
use test::*;
Expand Down Expand Up @@ -1371,11 +1371,7 @@ mod tests {
"MY_AVG",
DataType::Float64,
Arc::new(DataType::Float64),
Arc::new(|| {
Ok(Rc::new(RefCell::new(AvgAccumulator::try_new(
&DataType::Float64,
)?)))
}),
Arc::new(|| Ok(Box::new(AvgAccumulator::try_new(&DataType::Float64)?))),
Arc::new(vec![DataType::UInt64, DataType::Float64]),
);

Expand Down
4 changes: 2 additions & 2 deletions rust/datafusion/src/physical_plan/aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,11 @@ use crate::physical_plan::distinct_expressions;
use crate::physical_plan::expressions;
use arrow::datatypes::{DataType, Schema};
use expressions::{avg_return_type, sum_return_type};
use std::{cell::RefCell, fmt, rc::Rc, str::FromStr, sync::Arc};
use std::{fmt, str::FromStr, sync::Arc};

/// the implementation of an aggregate function
pub type AccumulatorFunctionImplementation =
Arc<dyn Fn() -> Result<Rc<RefCell<dyn Accumulator>>> + Send + Sync>;
Arc<dyn Fn() -> Result<Box<dyn Accumulator>> + Send + Sync>;

/// This signature describes the types an aggregator uses to serialize
/// its state, given its return datatype.
Expand Down
17 changes: 6 additions & 11 deletions rust/datafusion/src/physical_plan/distinct_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,9 @@

//! Implementations for DISTINCT expressions, e.g. `COUNT(DISTINCT c)`

use std::cell::RefCell;
use std::convert::TryFrom;
use std::fmt::Debug;
use std::hash::Hash;
use std::rc::Rc;
use std::sync::Arc;

use arrow::datatypes::{DataType, Field};
Expand Down Expand Up @@ -93,12 +91,12 @@ impl AggregateExpr for DistinctCount {
self.exprs.clone()
}

fn create_accumulator(&self) -> Result<Rc<RefCell<dyn Accumulator>>> {
Ok(Rc::new(RefCell::new(DistinctCountAccumulator {
fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
Ok(Box::new(DistinctCountAccumulator {
values: FnvHashSet::default(),
data_types: self.input_data_types.clone(),
count_data_type: self.data_type.clone(),
})))
}))
}
}

Expand Down Expand Up @@ -282,8 +280,7 @@ mod tests {
DataType::UInt64,
);

let accum = agg.create_accumulator()?;
let mut accum = accum.borrow_mut();
let mut accum = agg.create_accumulator()?;
accum.update_batch(arrays)?;

Ok((accum.state()?, accum.evaluate()?))
Expand All @@ -300,8 +297,7 @@ mod tests {
DataType::UInt64,
);

let accum = agg.create_accumulator()?;
let mut accum = accum.borrow_mut();
let mut accum = agg.create_accumulator()?;

for row in rows.iter() {
accum.update(row)?
Expand All @@ -324,8 +320,7 @@ mod tests {
DataType::UInt64,
);

let accum = agg.create_accumulator()?;
let mut accum = accum.borrow_mut();
let mut accum = agg.create_accumulator()?;
accum.merge_batch(arrays)?;

Ok((accum.state()?, accum.evaluate()?))
Expand Down
34 changes: 13 additions & 21 deletions rust/datafusion/src/physical_plan/expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,9 @@

//! Defines physical expressions that can be evaluated at runtime during query execution

use std::convert::TryFrom;
use std::fmt;
use std::rc::Rc;
use std::sync::Arc;
use std::{cell::RefCell, convert::TryFrom};

use crate::error::{ExecutionError, Result};
use crate::logical_plan::Operator;
Expand Down Expand Up @@ -162,10 +161,8 @@ impl AggregateExpr for Sum {
vec![self.expr.clone()]
}

fn create_accumulator(&self) -> Result<Rc<RefCell<dyn Accumulator>>> {
Ok(Rc::new(RefCell::new(SumAccumulator::try_new(
&self.data_type,
)?)))
fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
Ok(Box::new(SumAccumulator::try_new(&self.data_type)?))
}
}

Expand Down Expand Up @@ -391,11 +388,11 @@ impl AggregateExpr for Avg {
])
}

fn create_accumulator(&self) -> Result<Rc<RefCell<dyn Accumulator>>> {
Ok(Rc::new(RefCell::new(AvgAccumulator::try_new(
fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
Ok(Box::new(AvgAccumulator::try_new(
// avg is f64
&DataType::Float64,
)?)))
)?))
}

fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
Expand Down Expand Up @@ -521,10 +518,8 @@ impl AggregateExpr for Max {
vec![self.expr.clone()]
}

fn create_accumulator(&self) -> Result<Rc<RefCell<dyn Accumulator>>> {
Ok(Rc::new(RefCell::new(MaxAccumulator::try_new(
&self.data_type,
)?)))
fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
Ok(Box::new(MaxAccumulator::try_new(&self.data_type)?))
}
}

Expand Down Expand Up @@ -774,10 +769,8 @@ impl AggregateExpr for Min {
vec![self.expr.clone()]
}

fn create_accumulator(&self) -> Result<Rc<RefCell<dyn Accumulator>>> {
Ok(Rc::new(RefCell::new(MinAccumulator::try_new(
&self.data_type,
)?)))
fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
Ok(Box::new(MinAccumulator::try_new(&self.data_type)?))
}
}

Expand Down Expand Up @@ -869,8 +862,8 @@ impl AggregateExpr for Count {
vec![self.expr.clone()]
}

fn create_accumulator(&self) -> Result<Rc<RefCell<dyn Accumulator>>> {
Ok(Rc::new(RefCell::new(CountAccumulator::new())))
fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
Ok(Box::new(CountAccumulator::new()))
}
}

Expand Down Expand Up @@ -2476,13 +2469,12 @@ mod tests {
batch: &RecordBatch,
agg: Arc<dyn AggregateExpr>,
) -> Result<ScalarValue> {
let accum = agg.create_accumulator()?;
let mut accum = agg.create_accumulator()?;
let expr = agg.expressions();
let values = expr
.iter()
.map(|e| e.evaluate(batch))
.collect::<Result<Vec<_>>>()?;
let mut accum = accum.borrow_mut();
accum.update_batch(&values)?;
accum.evaluate()
}
Expand Down
34 changes: 12 additions & 22 deletions rust/datafusion/src/physical_plan/hash_aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@
//! Defines the execution plan for the hash aggregate operation

use std::any::Any;
use std::cell::RefCell;
use std::rc::Rc;
use std::sync::Arc;

use crate::error::{ExecutionError, Result};
Expand Down Expand Up @@ -278,9 +276,8 @@ fn group_aggregate_batch(
.map(|(_, (accumulator_set, indices))| {
// 2.2
accumulator_set
.iter()
.zip(&aggr_input_values)
.into_iter()
.zip(&aggr_input_values)
.map(|(accumulator, aggr_array)| {
(
accumulator,
Expand All @@ -300,12 +297,10 @@ fn group_aggregate_batch(
})
// 2.4
.map(|(accumulator, values)| match mode {
AggregateMode::Partial => {
accumulator.borrow_mut().update_batch(&values)
}
AggregateMode::Partial => accumulator.update_batch(&values),
AggregateMode::Final => {
// note: the aggregation here is over states, not values, thus the merge
accumulator.borrow_mut().merge_batch(&values)
accumulator.merge_batch(&values)
}
})
.collect::<Result<()>>()
Expand Down Expand Up @@ -335,7 +330,7 @@ impl GroupedHashAggregateIterator {
}
}

type AccumulatorSet = Vec<Rc<RefCell<dyn Accumulator>>>;
type AccumulatorSet = Vec<Box<dyn Accumulator>>;

impl Iterator for GroupedHashAggregateIterator {
type Item = ArrowResult<RecordBatch>;
Expand Down Expand Up @@ -490,7 +485,7 @@ impl HashAggregateIterator {
fn aggregate_batch(
mode: &AggregateMode,
batch: &RecordBatch,
accumulators: &AccumulatorSet,
accumulators: &mut AccumulatorSet,
expressions: &Vec<Vec<Arc<dyn PhysicalExpr>>>,
) -> Result<()> {
// 1.1 iterate accumulators and respective expressions together
Expand All @@ -499,7 +494,7 @@ fn aggregate_batch(

// 1.1
accumulators
.iter()
.into_iter()
.zip(expressions)
.map(|(accum, expr)| {
// 1.2
Expand All @@ -510,8 +505,8 @@ fn aggregate_batch(

// 1.3
match mode {
AggregateMode::Partial => accum.borrow_mut().update_batch(values),
AggregateMode::Final => accum.borrow_mut().merge_batch(values),
AggregateMode::Partial => accum.update_batch(values),
AggregateMode::Final => accum.merge_batch(values),
}
})
.collect::<Result<()>>()
Expand All @@ -528,7 +523,7 @@ impl Iterator for HashAggregateIterator {
// return single batch
self.finished = true;

let accumulators = match create_accumulators(&self.aggr_expr) {
let mut accumulators = match create_accumulators(&self.aggr_expr) {
Ok(e) => e,
Err(e) => return Some(Err(ExecutionError::into_arrow_external_error(e))),
};
Expand All @@ -547,7 +542,7 @@ impl Iterator for HashAggregateIterator {
.as_mut()
.into_iter()
.map(|batch| {
aggregate_batch(&mode, &batch?, &accumulators, &expressions)
aggregate_batch(&mode, &batch?, &mut accumulators, &expressions)
.map_err(ExecutionError::into_arrow_external_error)
})
.collect::<ArrowResult<()>>()
Expand Down Expand Up @@ -655,7 +650,7 @@ fn finalize_aggregation(
// build the vector of states
let a = accumulators
.iter()
.map(|accumulator| accumulator.borrow_mut().state())
.map(|accumulator| accumulator.state())
.map(|value| {
value.and_then(|e| {
Ok(e.iter().map(|v| v.to_array()).collect::<Vec<ArrayRef>>())
Expand All @@ -668,12 +663,7 @@ fn finalize_aggregation(
// merge the state to the final value
accumulators
.iter()
.map(|accumulator| {
accumulator
.borrow_mut()
.evaluate()
.and_then(|v| Ok(v.to_array()))
})
.map(|accumulator| accumulator.evaluate().and_then(|v| Ok(v.to_array())))
.collect::<Result<Vec<ArrayRef>>>()
}
}
Expand Down
4 changes: 1 addition & 3 deletions rust/datafusion/src/physical_plan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,7 @@
//! Traits for physical query plan, supporting parallel execution for partitioned relations.

use std::any::Any;
use std::cell::RefCell;
use std::fmt::{Debug, Display};
use std::rc::Rc;
use std::sync::Arc;

use crate::execution::context::ExecutionContextState;
Expand Down Expand Up @@ -122,7 +120,7 @@ pub trait AggregateExpr: Send + Sync + Debug {
/// the accumulator used to accumulate values from the expressions.
/// the accumulator expects the same number of arguments as `expressions` and must
/// return states with the same description as `state_fields`
fn create_accumulator(&self) -> Result<Rc<RefCell<dyn Accumulator>>>;
fn create_accumulator(&self) -> Result<Box<dyn Accumulator>>;

/// the fields that encapsulate the Accumulator's state
/// the number of fields here equals the number of states that the accumulator contains
Expand Down
4 changes: 2 additions & 2 deletions rust/datafusion/src/physical_plan/udaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
//! This module contains functions and structs supporting user-defined aggregate functions.

use fmt::{Debug, Formatter};
use std::{cell::RefCell, fmt, rc::Rc};
use std::fmt;

use arrow::{
datatypes::Field,
Expand Down Expand Up @@ -150,7 +150,7 @@ impl AggregateExpr for AggregateFunctionExpr {
Ok(Field::new(&self.name, self.data_type.clone(), true))
}

fn create_accumulator(&self) -> Result<Rc<RefCell<dyn Accumulator>>> {
fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
(self.fun.accumulator)()
}
}