diff --git a/rust/datafusion/examples/simple_udaf.rs b/rust/datafusion/examples/simple_udaf.rs index 4d3cc23696a..1f41f0db410 100644 --- a/rust/datafusion/examples/simple_udaf.rs +++ b/rust/datafusion/examples/simple_udaf.rs @@ -24,7 +24,7 @@ use arrow::{ use datafusion::{error::Result, logical_plan::create_udaf, physical_plan::Accumulator}; use datafusion::{prelude::*, scalar::ScalarValue}; -use std::{cell::RefCell, rc::Rc, sync::Arc}; +use std::sync::Arc; // create local execution context with an in-memory table fn create_context() -> Result { @@ -138,7 +138,7 @@ async fn main() -> Result<()> { // the return type; DataFusion expects this to match the type returned by `evaluate`. Arc::new(DataType::Float64), // This is the accumulator factory; DataFusion uses it to create new accumulators. - Arc::new(|| Ok(Rc::new(RefCell::new(GeometricMean::new())))), + Arc::new(|| Ok(Box::new(GeometricMean::new()))), // This is the description of the state. `state()` must match the types here. Arc::new(vec![DataType::Float64, DataType::UInt32]), ); diff --git a/rust/datafusion/src/execution/context.rs b/rust/datafusion/src/execution/context.rs index eabc779e49d..8df18c2ccc6 100644 --- a/rust/datafusion/src/execution/context.rs +++ b/rust/datafusion/src/execution/context.rs @@ -537,8 +537,8 @@ mod tests { ArrayRef, Float64Array, Int32Array, PrimitiveArrayOps, StringArray, }; use arrow::compute::add; + use std::fs::File; use std::thread::{self, JoinHandle}; - use std::{cell::RefCell, fs::File, rc::Rc}; use std::{io::prelude::*, sync::Mutex}; use tempfile::TempDir; use test::*; @@ -1371,11 +1371,7 @@ mod tests { "MY_AVG", DataType::Float64, Arc::new(DataType::Float64), - Arc::new(|| { - Ok(Rc::new(RefCell::new(AvgAccumulator::try_new( - &DataType::Float64, - )?))) - }), + Arc::new(|| Ok(Box::new(AvgAccumulator::try_new(&DataType::Float64)?))), Arc::new(vec![DataType::UInt64, DataType::Float64]), ); diff --git a/rust/datafusion/src/physical_plan/aggregates.rs b/rust/datafusion/src/physical_plan/aggregates.rs index 40bb562b0e4..d417c41855d 100644 --- a/rust/datafusion/src/physical_plan/aggregates.rs +++ b/rust/datafusion/src/physical_plan/aggregates.rs @@ -36,11 +36,11 @@ use crate::physical_plan::distinct_expressions; use crate::physical_plan::expressions; use arrow::datatypes::{DataType, Schema}; use expressions::{avg_return_type, sum_return_type}; -use std::{cell::RefCell, fmt, rc::Rc, str::FromStr, sync::Arc}; +use std::{fmt, str::FromStr, sync::Arc}; /// the implementation of an aggregate function pub type AccumulatorFunctionImplementation = - Arc Result>> + Send + Sync>; + Arc Result> + Send + Sync>; /// This signature corresponds to which types an aggregator serializes /// its state, given its return datatype. diff --git a/rust/datafusion/src/physical_plan/distinct_expressions.rs b/rust/datafusion/src/physical_plan/distinct_expressions.rs index 2d2ab627d44..cc771078609 100644 --- a/rust/datafusion/src/physical_plan/distinct_expressions.rs +++ b/rust/datafusion/src/physical_plan/distinct_expressions.rs @@ -17,11 +17,9 @@ //! Implementations for DISTINCT expressions, e.g. `COUNT(DISTINCT c)` -use std::cell::RefCell; use std::convert::TryFrom; use std::fmt::Debug; use std::hash::Hash; -use std::rc::Rc; use std::sync::Arc; use arrow::datatypes::{DataType, Field}; @@ -93,12 +91,12 @@ impl AggregateExpr for DistinctCount { self.exprs.clone() } - fn create_accumulator(&self) -> Result>> { - Ok(Rc::new(RefCell::new(DistinctCountAccumulator { + fn create_accumulator(&self) -> Result> { + Ok(Box::new(DistinctCountAccumulator { values: FnvHashSet::default(), data_types: self.input_data_types.clone(), count_data_type: self.data_type.clone(), - }))) + })) } } @@ -282,8 +280,7 @@ mod tests { DataType::UInt64, ); - let accum = agg.create_accumulator()?; - let mut accum = accum.borrow_mut(); + let mut accum = agg.create_accumulator()?; accum.update_batch(arrays)?; Ok((accum.state()?, accum.evaluate()?)) @@ -300,8 +297,7 @@ mod tests { DataType::UInt64, ); - let accum = agg.create_accumulator()?; - let mut accum = accum.borrow_mut(); + let mut accum = agg.create_accumulator()?; for row in rows.iter() { accum.update(row)? @@ -324,8 +320,7 @@ mod tests { DataType::UInt64, ); - let accum = agg.create_accumulator()?; - let mut accum = accum.borrow_mut(); + let mut accum = agg.create_accumulator()?; accum.merge_batch(arrays)?; Ok((accum.state()?, accum.evaluate()?)) diff --git a/rust/datafusion/src/physical_plan/expressions.rs b/rust/datafusion/src/physical_plan/expressions.rs index 4c9029e7195..1f5dafdc19d 100644 --- a/rust/datafusion/src/physical_plan/expressions.rs +++ b/rust/datafusion/src/physical_plan/expressions.rs @@ -17,10 +17,9 @@ //! Defines physical expressions that can evaluated at runtime during query execution +use std::convert::TryFrom; use std::fmt; -use std::rc::Rc; use std::sync::Arc; -use std::{cell::RefCell, convert::TryFrom}; use crate::error::{ExecutionError, Result}; use crate::logical_plan::Operator; @@ -162,10 +161,8 @@ impl AggregateExpr for Sum { vec![self.expr.clone()] } - fn create_accumulator(&self) -> Result>> { - Ok(Rc::new(RefCell::new(SumAccumulator::try_new( - &self.data_type, - )?))) + fn create_accumulator(&self) -> Result> { + Ok(Box::new(SumAccumulator::try_new(&self.data_type)?)) } } @@ -391,11 +388,11 @@ impl AggregateExpr for Avg { ]) } - fn create_accumulator(&self) -> Result>> { - Ok(Rc::new(RefCell::new(AvgAccumulator::try_new( + fn create_accumulator(&self) -> Result> { + Ok(Box::new(AvgAccumulator::try_new( // avg is f64 &DataType::Float64, - )?))) + )?)) } fn expressions(&self) -> Vec> { @@ -521,10 +518,8 @@ impl AggregateExpr for Max { vec![self.expr.clone()] } - fn create_accumulator(&self) -> Result>> { - Ok(Rc::new(RefCell::new(MaxAccumulator::try_new( - &self.data_type, - )?))) + fn create_accumulator(&self) -> Result> { + Ok(Box::new(MaxAccumulator::try_new(&self.data_type)?)) } } @@ -774,10 +769,8 @@ impl AggregateExpr for Min { vec![self.expr.clone()] } - fn create_accumulator(&self) -> Result>> { - Ok(Rc::new(RefCell::new(MinAccumulator::try_new( - &self.data_type, - )?))) + fn create_accumulator(&self) -> Result> { + Ok(Box::new(MinAccumulator::try_new(&self.data_type)?)) } } @@ -869,8 +862,8 @@ impl AggregateExpr for Count { vec![self.expr.clone()] } - fn create_accumulator(&self) -> Result>> { - Ok(Rc::new(RefCell::new(CountAccumulator::new()))) + fn create_accumulator(&self) -> Result> { + Ok(Box::new(CountAccumulator::new())) } } @@ -2476,13 +2469,12 @@ mod tests { batch: &RecordBatch, agg: Arc, ) -> Result { - let accum = agg.create_accumulator()?; + let mut accum = agg.create_accumulator()?; let expr = agg.expressions(); let values = expr .iter() .map(|e| e.evaluate(batch)) .collect::>>()?; - let mut accum = accum.borrow_mut(); accum.update_batch(&values)?; accum.evaluate() } diff --git a/rust/datafusion/src/physical_plan/hash_aggregate.rs b/rust/datafusion/src/physical_plan/hash_aggregate.rs index 5f4fe9876b7..53b74c2db40 100644 --- a/rust/datafusion/src/physical_plan/hash_aggregate.rs +++ b/rust/datafusion/src/physical_plan/hash_aggregate.rs @@ -18,8 +18,6 @@ //! Defines the execution plan for the hash aggregate operation use std::any::Any; -use std::cell::RefCell; -use std::rc::Rc; use std::sync::Arc; use crate::error::{ExecutionError, Result}; @@ -278,9 +276,8 @@ fn group_aggregate_batch( .map(|(_, (accumulator_set, indices))| { // 2.2 accumulator_set - .iter() - .zip(&aggr_input_values) .into_iter() + .zip(&aggr_input_values) .map(|(accumulator, aggr_array)| { ( accumulator, @@ -300,12 +297,10 @@ fn group_aggregate_batch( }) // 2.4 .map(|(accumulator, values)| match mode { - AggregateMode::Partial => { - accumulator.borrow_mut().update_batch(&values) - } + AggregateMode::Partial => accumulator.update_batch(&values), AggregateMode::Final => { // note: the aggregation here is over states, not values, thus the merge - accumulator.borrow_mut().merge_batch(&values) + accumulator.merge_batch(&values) } }) .collect::>() @@ -335,7 +330,7 @@ impl GroupedHashAggregateIterator { } } -type AccumulatorSet = Vec>>; +type AccumulatorSet = Vec>; impl Iterator for GroupedHashAggregateIterator { type Item = ArrowResult; @@ -490,7 +485,7 @@ impl HashAggregateIterator { fn aggregate_batch( mode: &AggregateMode, batch: &RecordBatch, - accumulators: &AccumulatorSet, + accumulators: &mut AccumulatorSet, expressions: &Vec>>, ) -> Result<()> { // 1.1 iterate accumulators and respective expressions together @@ -499,7 +494,7 @@ fn aggregate_batch( // 1.1 accumulators - .iter() + .into_iter() .zip(expressions) .map(|(accum, expr)| { // 1.2 @@ -510,8 +505,8 @@ fn aggregate_batch( // 1.3 match mode { - AggregateMode::Partial => accum.borrow_mut().update_batch(values), - AggregateMode::Final => accum.borrow_mut().merge_batch(values), + AggregateMode::Partial => accum.update_batch(values), + AggregateMode::Final => accum.merge_batch(values), } }) .collect::>() @@ -528,7 +523,7 @@ impl Iterator for HashAggregateIterator { // return single batch self.finished = true; - let accumulators = match create_accumulators(&self.aggr_expr) { + let mut accumulators = match create_accumulators(&self.aggr_expr) { Ok(e) => e, Err(e) => return Some(Err(ExecutionError::into_arrow_external_error(e))), }; @@ -547,7 +542,7 @@ impl Iterator for HashAggregateIterator { .as_mut() .into_iter() .map(|batch| { - aggregate_batch(&mode, &batch?, &accumulators, &expressions) + aggregate_batch(&mode, &batch?, &mut accumulators, &expressions) .map_err(ExecutionError::into_arrow_external_error) }) .collect::>() @@ -655,7 +650,7 @@ fn finalize_aggregation( // build the vector of states let a = accumulators .iter() - .map(|accumulator| accumulator.borrow_mut().state()) + .map(|accumulator| accumulator.state()) .map(|value| { value.and_then(|e| { Ok(e.iter().map(|v| v.to_array()).collect::>()) @@ -668,12 +663,7 @@ fn finalize_aggregation( // merge the state to the final value accumulators .iter() - .map(|accumulator| { - accumulator - .borrow_mut() - .evaluate() - .and_then(|v| Ok(v.to_array())) - }) + .map(|accumulator| accumulator.evaluate().and_then(|v| Ok(v.to_array()))) .collect::>>() } } diff --git a/rust/datafusion/src/physical_plan/mod.rs b/rust/datafusion/src/physical_plan/mod.rs index ac33c67f6ac..1d6c46afe09 100644 --- a/rust/datafusion/src/physical_plan/mod.rs +++ b/rust/datafusion/src/physical_plan/mod.rs @@ -18,9 +18,7 @@ //! Traits for physical query plan, supporting parallel execution for partitioned relations. use std::any::Any; -use std::cell::RefCell; use std::fmt::{Debug, Display}; -use std::rc::Rc; use std::sync::Arc; use crate::execution::context::ExecutionContextState; @@ -122,7 +120,7 @@ pub trait AggregateExpr: Send + Sync + Debug { /// the accumulator used to accumulate values from the expressions. /// the accumulator expects the same number of arguments as `expressions` and must /// return states with the same description as `state_fields` - fn create_accumulator(&self) -> Result>>; + fn create_accumulator(&self) -> Result>; /// the fields that encapsulate the Accumulator's state /// the number of fields here equals the number of states that the accumulator contains diff --git a/rust/datafusion/src/physical_plan/udaf.rs b/rust/datafusion/src/physical_plan/udaf.rs index 933fd237c65..db86e1447ab 100644 --- a/rust/datafusion/src/physical_plan/udaf.rs +++ b/rust/datafusion/src/physical_plan/udaf.rs @@ -18,7 +18,7 @@ //! This module contains functions and structs supporting user-defined aggregate functions. use fmt::{Debug, Formatter}; -use std::{cell::RefCell, fmt, rc::Rc}; +use std::fmt; use arrow::{ datatypes::Field, @@ -150,7 +150,7 @@ impl AggregateExpr for AggregateFunctionExpr { Ok(Field::new(&self.name, self.data_type.clone(), true)) } - fn create_accumulator(&self) -> Result>> { + fn create_accumulator(&self) -> Result> { (self.fun.accumulator)() } }