From f62797b95500b8f6c268a3ae9b31a709324d7a95 Mon Sep 17 00:00:00 2001 From: Xin Li Date: Wed, 31 Jul 2024 19:41:17 +0800 Subject: [PATCH 1/7] Remove schema in args --- datafusion/expr/src/function.rs | 5 +---- datafusion/functions-aggregate/src/array_agg.rs | 2 +- datafusion/functions-aggregate/src/first_last.rs | 4 ++-- datafusion/functions-aggregate/src/nth_value.rs | 2 +- datafusion/functions-aggregate/src/stddev.rs | 2 -- datafusion/physical-expr-common/src/aggregate/mod.rs | 6 ------ 6 files changed, 5 insertions(+), 16 deletions(-) diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs index d8be2b434732..1f3f8ace4d17 100644 --- a/datafusion/expr/src/function.rs +++ b/datafusion/expr/src/function.rs @@ -19,7 +19,7 @@ use crate::ColumnarValue; use crate::{Accumulator, Expr, PartitionEvaluator}; -use arrow::datatypes::{DataType, Field, Schema}; +use arrow::datatypes::{DataType, Field}; use datafusion_common::{DFSchema, Result}; use std::sync::Arc; @@ -54,9 +54,6 @@ pub struct AccumulatorArgs<'a> { /// The return type of the aggregate function. pub data_type: &'a DataType, - /// The schema of the input arguments - pub schema: &'a Schema, - /// The schema of the input arguments pub dfschema: &'a DFSchema, diff --git a/datafusion/functions-aggregate/src/array_agg.rs b/datafusion/functions-aggregate/src/array_agg.rs index 36c9d6a0d7c8..bb25de113525 100644 --- a/datafusion/functions-aggregate/src/array_agg.rs +++ b/datafusion/functions-aggregate/src/array_agg.rs @@ -136,7 +136,7 @@ impl AggregateUDFImpl for ArrayAgg { let ordering_dtypes = ordering_req .iter() - .map(|e| e.expr.data_type(acc_args.schema)) + .map(|e| e.expr.data_type(acc_args.dfschema.as_arrow())) .collect::>>()?; OrderSensitiveArrayAggAccumulator::try_new( diff --git a/datafusion/functions-aggregate/src/first_last.rs b/datafusion/functions-aggregate/src/first_last.rs index 587767b8e356..7563506b3b65 100644 --- a/datafusion/functions-aggregate/src/first_last.rs +++ b/datafusion/functions-aggregate/src/first_last.rs @@ -124,7 +124,7 @@ impl AggregateUDFImpl for FirstValue { let ordering_dtypes = ordering_req .iter() - .map(|e| e.expr.data_type(acc_args.schema)) + .map(|e| e.expr.data_type(acc_args.dfschema.as_arrow())) .collect::>>()?; // When requirement is empty, or it is signalled by outside caller that @@ -423,7 +423,7 @@ impl AggregateUDFImpl for LastValue { let ordering_dtypes = ordering_req .iter() - .map(|e| e.expr.data_type(acc_args.schema)) + .map(|e| e.expr.data_type(acc_args.dfschema.as_arrow())) .collect::>>()?; let requirement_satisfied = ordering_req.is_empty() || self.requirement_satisfied; diff --git a/datafusion/functions-aggregate/src/nth_value.rs b/datafusion/functions-aggregate/src/nth_value.rs index dc7c6c86f213..7c4b9a7f06c6 100644 --- a/datafusion/functions-aggregate/src/nth_value.rs +++ b/datafusion/functions-aggregate/src/nth_value.rs @@ -109,7 +109,7 @@ impl AggregateUDFImpl for NthValueAgg { let ordering_dtypes = ordering_req .iter() - .map(|e| e.expr.data_type(acc_args.schema)) + .map(|e| e.expr.data_type(acc_args.dfschema.as_arrow())) .collect::>>()?; NthValueAccumulator::try_new( diff --git a/datafusion/functions-aggregate/src/stddev.rs b/datafusion/functions-aggregate/src/stddev.rs index df757ddc0422..1d2257d90133 100644 --- a/datafusion/functions-aggregate/src/stddev.rs +++ b/datafusion/functions-aggregate/src/stddev.rs @@ -328,7 +328,6 @@ mod tests { let dfschema = DFSchema::empty(); let args1 = AccumulatorArgs { data_type: &DataType::Float64, - schema, dfschema: &dfschema, ignore_nulls: false, sort_exprs: &[], @@ -341,7 +340,6 @@ mod tests { let args2 = AccumulatorArgs { data_type: &DataType::Float64, - schema, dfschema: &dfschema, ignore_nulls: false, sort_exprs: &[], diff --git a/datafusion/physical-expr-common/src/aggregate/mod.rs b/datafusion/physical-expr-common/src/aggregate/mod.rs index 665cdd708329..0707301b2557 100644 --- a/datafusion/physical-expr-common/src/aggregate/mod.rs +++ b/datafusion/physical-expr-common/src/aggregate/mod.rs @@ -220,7 +220,6 @@ impl AggregateExprBuilder { logical_args, data_type, name, - schema: Arc::unwrap_or_clone(schema), dfschema, sort_exprs, ordering_req, @@ -456,7 +455,6 @@ pub struct AggregateFunctionExpr { /// Output / return type of this aggregate data_type: DataType, name: String, - schema: Schema, dfschema: DFSchema, // The logical order by expressions sort_exprs: Vec, @@ -522,7 +520,6 @@ impl AggregateExpr for AggregateFunctionExpr { fn create_accumulator(&self) -> Result> { let acc_args = AccumulatorArgs { data_type: &self.data_type, - schema: &self.schema, dfschema: &self.dfschema, ignore_nulls: self.ignore_nulls, sort_exprs: &self.sort_exprs, @@ -539,7 +536,6 @@ impl AggregateExpr for AggregateFunctionExpr { fn create_sliding_accumulator(&self) -> Result> { let args = AccumulatorArgs { data_type: &self.data_type, - schema: &self.schema, dfschema: &self.dfschema, ignore_nulls: self.ignore_nulls, sort_exprs: &self.sort_exprs, @@ -611,7 +607,6 @@ impl AggregateExpr for AggregateFunctionExpr { fn groups_accumulator_supported(&self) -> bool { let args = AccumulatorArgs { data_type: &self.data_type, - schema: &self.schema, dfschema: &self.dfschema, ignore_nulls: self.ignore_nulls, sort_exprs: &self.sort_exprs, @@ -627,7 +622,6 @@ impl AggregateExpr for AggregateFunctionExpr { fn create_groups_accumulator(&self) -> Result> { let args = AccumulatorArgs { data_type: &self.data_type, - schema: &self.schema, dfschema: &self.dfschema, ignore_nulls: self.ignore_nulls, sort_exprs: &self.sort_exprs, From 0faaa266d8a57a3e81ed6900e771582643773366 Mon Sep 17 00:00:00 2001 From: Xin Li Date: Thu, 1 Aug 2024 18:31:46 +0800 Subject: [PATCH 2/7] improve AccumulatorArgs --- .../physical_plan/parquet/opener.rs | 2 +- .../physical_plan/parquet/statistics.rs | 1 + .../core/src/execution/session_state.rs | 2 +- .../aggregate_statistics.rs | 3 +- .../combine_partial_final_agg.rs | 5 +- .../src/physical_optimizer/limit_pushdown.rs | 2 +- .../core/tests/fuzz_cases/aggregate_fuzz.rs | 3 +- .../src/expressions/column.rs | 4 +- datafusion/expr/src/expressions/mod.rs | 18 ++ datafusion/expr/src/function.rs | 6 +- datafusion/expr/src/lib.rs | 2 + .../src/physical_expr.rs | 6 +- datafusion/expr/src/tree_node.rs | 86 +++++++++- datafusion/expr/src/utils.rs | 141 ++++++++++++++++ .../src/approx_distinct.rs | 41 ++--- .../functions-aggregate/src/approx_median.rs | 2 +- .../src/approx_percentile_cont.rs | 43 +++-- .../src/approx_percentile_cont_with_weight.rs | 5 +- .../functions-aggregate/src/array_agg.rs | 13 +- datafusion/functions-aggregate/src/average.rs | 19 ++- datafusion/functions-aggregate/src/count.rs | 3 +- datafusion/functions-aggregate/src/median.rs | 2 +- .../functions-aggregate/src/nth_value.rs | 31 ++-- datafusion/functions-aggregate/src/stddev.rs | 8 +- .../functions-aggregate/src/string_agg.rs | 33 ++-- .../physical-expr-common/src/aggregate/mod.rs | 35 ++-- .../src/expressions/cast.rs | 4 +- .../src/expressions/literal.rs | 3 +- .../src/expressions/mod.rs | 1 - datafusion/physical-expr-common/src/lib.rs | 2 - .../physical-expr-common/src/sort_expr.rs | 3 +- .../physical-expr-common/src/tree_node.rs | 105 ------------ datafusion/physical-expr-common/src/utils.rs | 159 +----------------- datafusion/physical-expr/benches/case_when.rs | 4 +- datafusion/physical-expr/benches/is_null.rs | 4 +- .../src/equivalence/properties.rs | 6 +- .../physical-expr/src/expressions/binary.rs | 2 +- .../physical-expr/src/expressions/case.rs | 2 +- .../physical-expr/src/expressions/mod.rs | 2 +- datafusion/physical-expr/src/lib.rs | 4 +- datafusion/physical-expr/src/physical_expr.rs | 4 +- .../physical-plan/src/aggregates/mod.rs | 16 +- datafusion/physical-plan/src/union.rs | 2 +- datafusion/physical-plan/src/windows/mod.rs | 7 +- datafusion/proto/src/physical_plan/mod.rs | 6 +- .../tests/cases/roundtrip_physical_plan.rs | 22 +-- datafusion/substrait/src/serializer.rs | 1 - 47 files changed, 434 insertions(+), 441 deletions(-) rename datafusion/{physical-expr-common => expr}/src/expressions/column.rs (98%) create mode 100644 datafusion/expr/src/expressions/mod.rs rename datafusion/{physical-expr-common => expr}/src/physical_expr.rs (98%) delete mode 100644 datafusion/physical-expr-common/src/tree_node.rs diff --git a/datafusion/core/src/datasource/physical_plan/parquet/opener.rs b/datafusion/core/src/datasource/physical_plan/parquet/opener.rs index 4edc0ac525de..9e83d4dfe274 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/opener.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/opener.rs @@ -30,7 +30,7 @@ use crate::datasource::schema_adapter::SchemaAdapterFactory; use crate::physical_optimizer::pruning::PruningPredicate; use arrow_schema::{ArrowError, SchemaRef}; use datafusion_common::{exec_err, Result}; -use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use datafusion_expr::physical_expr::PhysicalExpr; use datafusion_physical_plan::metrics::ExecutionPlanMetricsSet; use futures::{StreamExt, TryStreamExt}; use log::debug; diff --git a/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs b/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs index 11b8f5fc6c79..eec7c95fff94 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs @@ -517,6 +517,7 @@ macro_rules! make_data_page_stats_iterator { } } + #[allow(clippy::redundant_closure_call)] impl<'a, I> Iterator for $iterator_type<'a, I> where I: Iterator, diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs index ccad0240fddb..fd9b5c786859 100644 --- a/datafusion/core/src/execution/session_state.rs +++ b/datafusion/core/src/execution/session_state.rs @@ -48,6 +48,7 @@ use datafusion_execution::runtime_env::RuntimeEnv; use datafusion_execution::TaskContext; use datafusion_expr::execution_props::ExecutionProps; use datafusion_expr::expr_rewriter::FunctionRewrite; +use datafusion_expr::physical_expr::PhysicalExpr; use datafusion_expr::planner::ExprPlanner; use datafusion_expr::registry::{FunctionRegistry, SerializerRegistry}; use datafusion_expr::simplify::SimplifyInfo; @@ -61,7 +62,6 @@ use datafusion_optimizer::{ Analyzer, AnalyzerRule, Optimizer, OptimizerConfig, OptimizerRule, }; use datafusion_physical_expr::create_physical_expr; -use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use datafusion_physical_optimizer::PhysicalOptimizerRule; use datafusion_physical_plan::ExecutionPlan; use datafusion_sql::parser::{DFParser, Statement}; diff --git a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs index a8332d1d55e4..590f9dc8fde1 100644 --- a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs +++ b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs @@ -324,6 +324,7 @@ pub(crate) mod tests { use arrow::datatypes::{DataType, Field, Schema}; use arrow::record_batch::RecordBatch; use datafusion_common::cast::as_int64_array; + use datafusion_common::ToDFSchema; use datafusion_functions_aggregate::count::count_udaf; use datafusion_physical_expr::expressions::cast; use datafusion_physical_expr::PhysicalExpr; @@ -421,7 +422,7 @@ pub(crate) mod tests { // Return appropriate expr depending if COUNT is for col or table (*) pub(crate) fn count_expr(&self, schema: &Schema) -> Arc { AggregateExprBuilder::new(count_udaf(), vec![self.column()]) - .schema(Arc::new(schema.clone())) + .dfschema(schema.clone().to_dfschema().unwrap()) .name(self.column_name()) .build() .unwrap() diff --git a/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs b/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs index 6f3274820c8c..ab547b86f582 100644 --- a/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs +++ b/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs @@ -174,6 +174,7 @@ mod tests { use crate::physical_plan::{displayable, Partitioning}; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; + use datafusion_common::ToDFSchema; use datafusion_functions_aggregate::count::count_udaf; use datafusion_functions_aggregate::sum::sum_udaf; use datafusion_physical_expr::expressions::col; @@ -279,7 +280,7 @@ mod tests { schema: &Schema, ) -> Arc { AggregateExprBuilder::new(count_udaf(), vec![expr]) - .schema(Arc::new(schema.clone())) + .dfschema(schema.clone().to_dfschema().unwrap()) .name(name) .build() .unwrap() @@ -363,7 +364,7 @@ mod tests { let aggr_expr = vec![ AggregateExprBuilder::new(sum_udaf(), vec![col("b", &schema)?]) - .schema(Arc::clone(&schema)) + .dfschema(Arc::clone(&schema).to_dfschema()?) .name("Sum(b)") .build() .unwrap(), diff --git a/datafusion/core/src/physical_optimizer/limit_pushdown.rs b/datafusion/core/src/physical_optimizer/limit_pushdown.rs index 4379a34a9426..ef641e40b78b 100644 --- a/datafusion/core/src/physical_optimizer/limit_pushdown.rs +++ b/datafusion/core/src/physical_optimizer/limit_pushdown.rs @@ -256,10 +256,10 @@ mod tests { use arrow_schema::{DataType, Field, Schema, SchemaRef}; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; + use datafusion_expr::expressions::column::col; use datafusion_expr::Operator; use datafusion_physical_expr::expressions::BinaryExpr; use datafusion_physical_expr::Partitioning; - use datafusion_physical_expr_common::expressions::column::col; use datafusion_physical_expr_common::expressions::lit; use datafusion_physical_plan::coalesce_batches::CoalesceBatchesExec; use datafusion_physical_plan::coalesce_partitions::CoalescePartitionsExec; diff --git a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs index 6f286c9aeba1..31fa59af8c18 100644 --- a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs @@ -32,6 +32,7 @@ use datafusion::physical_plan::memory::MemoryExec; use datafusion::physical_plan::{collect, displayable, ExecutionPlan}; use datafusion::prelude::{DataFrame, SessionConfig, SessionContext}; use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor}; +use datafusion_common::ToDFSchema; use datafusion_functions_aggregate::sum::sum_udaf; use datafusion_physical_expr::expressions::col; use datafusion_physical_expr::PhysicalSortExpr; @@ -106,7 +107,7 @@ async fn run_aggregate_test(input1: Vec, group_by_columns: Vec<&str let aggregate_expr = vec![ AggregateExprBuilder::new(sum_udaf(), vec![col("d", &schema).unwrap()]) - .schema(Arc::clone(&schema)) + .dfschema(Arc::clone(&schema).to_dfschema().unwrap()) .name("sum1") .build() .unwrap(), diff --git a/datafusion/physical-expr-common/src/expressions/column.rs b/datafusion/expr/src/expressions/column.rs similarity index 98% rename from datafusion/physical-expr-common/src/expressions/column.rs rename to datafusion/expr/src/expressions/column.rs index 5397599ea2dc..fa8e6188038a 100644 --- a/datafusion/physical-expr-common/src/expressions/column.rs +++ b/datafusion/expr/src/expressions/column.rs @@ -21,12 +21,12 @@ use std::any::Any; use std::hash::{Hash, Hasher}; use std::sync::Arc; +use crate::ColumnarValue; use arrow::{ datatypes::{DataType, Schema}, record_batch::RecordBatch, }; use datafusion_common::{internal_err, Result}; -use datafusion_expr::ColumnarValue; use crate::physical_expr::{down_cast_any_ref, PhysicalExpr}; @@ -89,7 +89,7 @@ impl PhysicalExpr for Column { /// Evaluate the expression fn evaluate(&self, batch: &RecordBatch) -> Result { self.bounds_check(batch.schema().as_ref())?; - Ok(ColumnarValue::Array(batch.column(self.index).clone())) + Ok(ColumnarValue::Array(Arc::clone(batch.column(self.index)))) } fn children(&self) -> Vec<&Arc> { diff --git a/datafusion/expr/src/expressions/mod.rs b/datafusion/expr/src/expressions/mod.rs new file mode 100644 index 000000000000..d102422081dc --- /dev/null +++ b/datafusion/expr/src/expressions/mod.rs @@ -0,0 +1,18 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +pub mod column; diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs index 1f3f8ace4d17..1ed793228ded 100644 --- a/datafusion/expr/src/function.rs +++ b/datafusion/expr/src/function.rs @@ -17,6 +17,7 @@ //! Function module contains typing and signature for built-in and user defined functions. +use crate::physical_expr::PhysicalExpr; use crate::ColumnarValue; use crate::{Accumulator, Expr, PartitionEvaluator}; use arrow::datatypes::{DataType, Field}; @@ -91,11 +92,8 @@ pub struct AccumulatorArgs<'a> { /// ``` pub is_distinct: bool, - /// The input types of the aggregate function. - pub input_types: &'a [DataType], - /// The logical expression of arguments the aggregate function takes. - pub input_exprs: &'a [Expr], + pub input_exprs: &'a [Arc], } /// [`StateFieldsArgs`] contains information about the fields that an diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index 0a5cf4653a22..3e02b0fdb3ed 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -46,10 +46,12 @@ pub mod expr; pub mod expr_fn; pub mod expr_rewriter; pub mod expr_schema; +pub mod expressions; pub mod function; pub mod groups_accumulator; pub mod interval_arithmetic; pub mod logical_plan; +pub mod physical_expr; pub mod planner; pub mod registry; pub mod simplify; diff --git a/datafusion/physical-expr-common/src/physical_expr.rs b/datafusion/expr/src/physical_expr.rs similarity index 98% rename from datafusion/physical-expr-common/src/physical_expr.rs rename to datafusion/expr/src/physical_expr.rs index e62606a42e6f..35bf5455df67 100644 --- a/datafusion/physical-expr-common/src/physical_expr.rs +++ b/datafusion/expr/src/physical_expr.rs @@ -23,15 +23,15 @@ use std::sync::Arc; use crate::expressions::column::Column; use crate::utils::scatter; +use crate::interval_arithmetic::Interval; +use crate::sort_properties::ExprProperties; +use crate::ColumnarValue; use arrow::array::BooleanArray; use arrow::compute::filter_record_batch; use arrow::datatypes::{DataType, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::{internal_err, not_impl_err, plan_err, Result}; -use datafusion_expr::interval_arithmetic::Interval; -use datafusion_expr::sort_properties::ExprProperties; -use datafusion_expr::ColumnarValue; /// See [create_physical_expr](https://docs.rs/datafusion/latest/datafusion/physical_expr/fn.create_physical_expr.html) /// for examples of creating `PhysicalExpr` from `Expr` diff --git a/datafusion/expr/src/tree_node.rs b/datafusion/expr/src/tree_node.rs index a97b9f010f79..813257122b65 100644 --- a/datafusion/expr/src/tree_node.rs +++ b/datafusion/expr/src/tree_node.rs @@ -17,15 +17,20 @@ //! Tree node implementation for logical expr +use std::fmt::{self, Display, Formatter}; +use std::sync::Arc; + use crate::expr::{ AggregateFunction, AggregateFunctionDefinition, Alias, Between, BinaryExpr, Case, Cast, GroupingSet, InList, InSubquery, Like, Placeholder, ScalarFunction, Sort, TryCast, Unnest, WindowFunction, }; +use crate::physical_expr::{with_new_children_if_necessary, PhysicalExpr}; use crate::{Expr, ExprFunctionExt}; use datafusion_common::tree_node::{ - Transformed, TreeNode, TreeNodeIterator, TreeNodeRecursion, + ConcreteTreeNode, DynTreeNode, Transformed, TreeNode, TreeNodeIterator, + TreeNodeRecursion, }; use datafusion_common::{map_until_stop_and_collect, Result}; @@ -401,3 +406,82 @@ fn transform_vec Result>>( ) -> Result>> { ve.into_iter().map_until_stop_and_collect(f) } + +impl DynTreeNode for dyn PhysicalExpr { + fn arc_children(&self) -> Vec<&Arc> { + self.children() + } + + fn with_new_arc_children( + &self, + arc_self: Arc, + new_children: Vec>, + ) -> Result> { + with_new_children_if_necessary(arc_self, new_children) + } +} + +/// A node object encapsulating a [`PhysicalExpr`] node with a payload. Since there are +/// two ways to access child plans—directly from the plan and through child nodes—it's +/// recommended to perform mutable operations via [`Self::update_expr_from_children`]. +#[derive(Debug)] +pub struct ExprContext { + /// The physical expression associated with this context. + pub expr: Arc, + /// Custom data payload of the node. + pub data: T, + /// Child contexts of this node. + pub children: Vec, +} + +impl ExprContext { + pub fn new(expr: Arc, data: T, children: Vec) -> Self { + Self { + expr, + data, + children, + } + } + + pub fn update_expr_from_children(mut self) -> Result { + let children_expr = self.children.iter().map(|c| Arc::clone(&c.expr)).collect(); + self.expr = with_new_children_if_necessary(self.expr, children_expr)?; + Ok(self) + } +} + +impl ExprContext { + pub fn new_default(plan: Arc) -> Self { + let children = plan + .children() + .into_iter() + .cloned() + .map(Self::new_default) + .collect(); + Self::new(plan, Default::default(), children) + } +} + +impl Display for ExprContext { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + write!(f, "expr: {:?}", self.expr)?; + write!(f, "data:{}", self.data)?; + write!(f, "") + } +} + +impl ConcreteTreeNode for ExprContext { + fn children(&self) -> &[Self] { + &self.children + } + + fn take_children(mut self) -> (Self, Vec) { + let children = std::mem::take(&mut self.children); + (self, children) + } + + fn with_new_children(mut self, children: Vec) -> Result { + self.children = children; + self.update_expr_from_children() + } +} diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 2ef1597abfd1..1d6919494587 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -23,12 +23,18 @@ use std::sync::Arc; use crate::expr::{Alias, Sort, WindowFunction}; use crate::expr_rewriter::strip_outer_reference; +use crate::physical_expr::PhysicalExpr; use crate::signature::{Signature, TypeSignature}; +use crate::sort_properties::ExprProperties; +use crate::tree_node::ExprContext; use crate::{ and, BinaryExpr, Expr, ExprSchemable, Filter, GroupingSet, LogicalPlan, Operator, }; +use arrow::array::MutableArrayData; +use arrow::compute::{and_kleene, is_not_null, SlicesIterator}; use arrow::datatypes::{DataType, Field, Schema, TimeUnit}; +use arrow_array::{make_array, Array, ArrayRef, BooleanArray}; use datafusion_common::tree_node::{ Transformed, TransformedResult, TreeNode, TreeNodeRecursion, }; @@ -1248,8 +1254,77 @@ impl AggregateOrderSensitivity { } } +/// Scatter `truthy` array by boolean mask. When the mask evaluates `true`, next values of `truthy` +/// are taken, when the mask evaluates `false` values null values are filled. +/// +/// # Arguments +/// * `mask` - Boolean values used to determine where to put the `truthy` values +/// * `truthy` - All values of this array are to scatter according to `mask` into final result. +pub fn scatter(mask: &BooleanArray, truthy: &dyn Array) -> Result { + let truthy = truthy.to_data(); + + // update the mask so that any null values become false + // (SlicesIterator doesn't respect nulls) + let mask = and_kleene(mask, &is_not_null(mask)?)?; + + let mut mutable = MutableArrayData::new(vec![&truthy], true, mask.len()); + + // the SlicesIterator slices only the true values. So the gaps left by this iterator we need to + // fill with falsy values + + // keep track of how much is filled + let mut filled = 0; + // keep track of current position we have in truthy array + let mut true_pos = 0; + + SlicesIterator::new(&mask).for_each(|(start, end)| { + // the gap needs to be filled with nulls + if start > filled { + mutable.extend_nulls(start - filled); + } + // fill with truthy values + let len = end - start; + mutable.extend(0, true_pos, true_pos + len); + true_pos += len; + filled = end; + }); + // the remaining part is falsy + if filled < mask.len() { + mutable.extend_nulls(mask.len() - filled); + } + + let data = mutable.freeze(); + Ok(make_array(data)) +} + +/// Represents a [`PhysicalExpr`] node with associated properties (order and +/// range) in a context where properties are tracked. +pub type ExprPropertiesNode = ExprContext; + +impl ExprPropertiesNode { + /// Constructs a new `ExprPropertiesNode` with unknown properties for a + /// given physical expression. This node initializes with default properties + /// and recursively applies this to all child expressions. + pub fn new_unknown(expr: Arc) -> Self { + let children = expr + .children() + .into_iter() + .cloned() + .map(Self::new_unknown) + .collect(); + Self { + expr, + data: ExprProperties::new_unknown(), + children, + } + } +} + #[cfg(test)] mod tests { + use arrow_array::Int32Array; + use datafusion_common::cast::{as_boolean_array, as_int32_array}; + use super::*; use crate::{ col, cube, expr, expr_vec_fmt, grouping_set, lit, rollup, @@ -1696,4 +1771,70 @@ mod tests { assert!(accum.contains(&Column::from_name("a"))); Ok(()) } + + #[test] + fn scatter_int() -> Result<()> { + let truthy = Arc::new(Int32Array::from(vec![1, 10, 11, 100])); + let mask = BooleanArray::from(vec![true, true, false, false, true]); + + // the output array is expected to be the same length as the mask array + let expected = + Int32Array::from_iter(vec![Some(1), Some(10), None, None, Some(11)]); + let result = scatter(&mask, truthy.as_ref())?; + let result = as_int32_array(&result)?; + + assert_eq!(&expected, result); + Ok(()) + } + + #[test] + fn scatter_int_end_with_false() -> Result<()> { + let truthy = Arc::new(Int32Array::from(vec![1, 10, 11, 100])); + let mask = BooleanArray::from(vec![true, false, true, false, false, false]); + + // output should be same length as mask + let expected = + Int32Array::from_iter(vec![Some(1), None, Some(10), None, None, None]); + let result = scatter(&mask, truthy.as_ref())?; + let result = as_int32_array(&result)?; + + assert_eq!(&expected, result); + Ok(()) + } + + #[test] + fn scatter_with_null_mask() -> Result<()> { + let truthy = Arc::new(Int32Array::from(vec![1, 10, 11])); + let mask: BooleanArray = vec![Some(false), None, Some(true), Some(true), None] + .into_iter() + .collect(); + + // output should treat nulls as though they are false + let expected = Int32Array::from_iter(vec![None, None, Some(1), Some(10), None]); + let result = scatter(&mask, truthy.as_ref())?; + let result = as_int32_array(&result)?; + + assert_eq!(&expected, result); + Ok(()) + } + + #[test] + fn scatter_boolean() -> Result<()> { + let truthy = Arc::new(BooleanArray::from(vec![false, false, false, true])); + let mask = BooleanArray::from(vec![true, true, false, false, true]); + + // the output array is expected to be the same length as the mask array + let expected = BooleanArray::from_iter(vec![ + Some(false), + Some(false), + None, + None, + Some(false), + ]); + let result = scatter(&mask, truthy.as_ref())?; + let result = as_boolean_array(&result)?; + + assert_eq!(&expected, result); + Ok(()) + } } diff --git a/datafusion/functions-aggregate/src/approx_distinct.rs b/datafusion/functions-aggregate/src/approx_distinct.rs index 56ef32e7ebe0..bcd132ec4910 100644 --- a/datafusion/functions-aggregate/src/approx_distinct.rs +++ b/datafusion/functions-aggregate/src/approx_distinct.rs @@ -277,28 +277,29 @@ impl AggregateUDFImpl for ApproxDistinct { } fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { - let accumulator: Box = match &acc_args.input_types[0] { - // TODO u8, i8, u16, i16 shall really be done using bitmap, not HLL - // TODO support for boolean (trivial case) - // https://github.com/apache/datafusion/issues/1109 - DataType::UInt8 => Box::new(NumericHLLAccumulator::::new()), - DataType::UInt16 => Box::new(NumericHLLAccumulator::::new()), - DataType::UInt32 => Box::new(NumericHLLAccumulator::::new()), - DataType::UInt64 => Box::new(NumericHLLAccumulator::::new()), - DataType::Int8 => Box::new(NumericHLLAccumulator::::new()), - DataType::Int16 => Box::new(NumericHLLAccumulator::::new()), - DataType::Int32 => Box::new(NumericHLLAccumulator::::new()), - DataType::Int64 => Box::new(NumericHLLAccumulator::::new()), - DataType::Utf8 => Box::new(StringHLLAccumulator::::new()), - DataType::LargeUtf8 => Box::new(StringHLLAccumulator::::new()), - DataType::Binary => Box::new(BinaryHLLAccumulator::::new()), - DataType::LargeBinary => Box::new(BinaryHLLAccumulator::::new()), - other => { - return not_impl_err!( + let accumulator: Box = + match &acc_args.input_exprs[0].data_type(acc_args.dfschema.as_arrow())? { + // TODO u8, i8, u16, i16 shall really be done using bitmap, not HLL + // TODO support for boolean (trivial case) + // https://github.com/apache/datafusion/issues/1109 + DataType::UInt8 => Box::new(NumericHLLAccumulator::::new()), + DataType::UInt16 => Box::new(NumericHLLAccumulator::::new()), + DataType::UInt32 => Box::new(NumericHLLAccumulator::::new()), + DataType::UInt64 => Box::new(NumericHLLAccumulator::::new()), + DataType::Int8 => Box::new(NumericHLLAccumulator::::new()), + DataType::Int16 => Box::new(NumericHLLAccumulator::::new()), + DataType::Int32 => Box::new(NumericHLLAccumulator::::new()), + DataType::Int64 => Box::new(NumericHLLAccumulator::::new()), + DataType::Utf8 => Box::new(StringHLLAccumulator::::new()), + DataType::LargeUtf8 => Box::new(StringHLLAccumulator::::new()), + DataType::Binary => Box::new(BinaryHLLAccumulator::::new()), + DataType::LargeBinary => Box::new(BinaryHLLAccumulator::::new()), + other => { + return not_impl_err!( "Support for 'approx_distinct' for data type {other} is not implemented" ) - } - }; + } + }; Ok(accumulator) } } diff --git a/datafusion/functions-aggregate/src/approx_median.rs b/datafusion/functions-aggregate/src/approx_median.rs index e12e3445a83e..f37c164799bd 100644 --- a/datafusion/functions-aggregate/src/approx_median.rs +++ b/datafusion/functions-aggregate/src/approx_median.rs @@ -113,7 +113,7 @@ impl AggregateUDFImpl for ApproxMedian { Ok(Box::new(ApproxPercentileAccumulator::new( 0.5_f64, - acc_args.input_types[0].clone(), + acc_args.input_exprs[0].data_type(acc_args.dfschema.as_arrow())?, ))) } } diff --git a/datafusion/functions-aggregate/src/approx_percentile_cont.rs b/datafusion/functions-aggregate/src/approx_percentile_cont.rs index 844e48f0a44d..8fdd45e71cf3 100644 --- a/datafusion/functions-aggregate/src/approx_percentile_cont.rs +++ b/datafusion/functions-aggregate/src/approx_percentile_cont.rs @@ -30,21 +30,20 @@ use arrow::{ }; use arrow_schema::{Field, Schema}; +use datafusion_common::DataFusionError; use datafusion_common::{ - downcast_value, internal_err, not_impl_err, plan_err, DFSchema, DataFusionError, - ScalarValue, + downcast_value, internal_err, not_impl_err, plan_err, ScalarValue, }; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; +use datafusion_expr::physical_expr::PhysicalExpr; use datafusion_expr::type_coercion::aggregates::{INTEGERS, NUMERICS}; use datafusion_expr::utils::format_state_name; use datafusion_expr::{ - Accumulator, AggregateUDFImpl, ColumnarValue, Expr, Signature, TypeSignature, - Volatility, + Accumulator, AggregateUDFImpl, ColumnarValue, Signature, TypeSignature, Volatility, }; use datafusion_physical_expr_common::aggregate::tdigest::{ TDigest, TryIntoF64, DEFAULT_MAX_SIZE, }; -use datafusion_physical_expr_common::utils::limited_convert_logical_expr_to_physical_expr_with_dfschema; make_udaf_expr_and_func!( ApproxPercentileCont, @@ -105,7 +104,7 @@ impl ApproxPercentileCont { None }; - let accumulator: ApproxPercentileAccumulator = match &args.input_types[0] { + let accumulator: ApproxPercentileAccumulator = match &args.input_exprs[0].data_type(args.dfschema.as_arrow())? { t @ (DataType::UInt8 | DataType::UInt16 | DataType::UInt32 @@ -134,24 +133,22 @@ impl ApproxPercentileCont { } } -fn get_lit_value(expr: &Expr) -> datafusion_common::Result { - let empty_schema = Arc::new(Schema::empty()); - let empty_batch = RecordBatch::new_empty(Arc::clone(&empty_schema)); - let dfschema = DFSchema::empty(); - let expr = - limited_convert_logical_expr_to_physical_expr_with_dfschema(expr, &dfschema)?; - let result = expr.evaluate(&empty_batch)?; - match result { - ColumnarValue::Array(_) => Err(DataFusionError::Internal(format!( - "The expr {:?} can't be evaluated to scalar value", - expr - ))), +fn get_lit_value( + physical_expr: &Arc, +) -> datafusion_common::Result { + match physical_expr.evaluate(&RecordBatch::new_empty(Arc::new(Schema::empty())))? { ColumnarValue::Scalar(scalar_value) => Ok(scalar_value), + ColumnarValue::Array(_) => internal_err!( + "The expr {:?} can't be evaluated to scalar value", + physical_expr + ), } } -fn validate_input_percentile_expr(expr: &Expr) -> datafusion_common::Result { - let lit = get_lit_value(expr)?; +fn validate_input_percentile_expr( + physical_expr: &Arc, +) -> datafusion_common::Result { + let lit = get_lit_value(physical_expr)?; let percentile = match &lit { ScalarValue::Float32(Some(q)) => *q as f64, ScalarValue::Float64(Some(q)) => *q, @@ -170,8 +167,10 @@ fn validate_input_percentile_expr(expr: &Expr) -> datafusion_common::Result Ok(percentile) } -fn validate_input_max_size_expr(expr: &Expr) -> datafusion_common::Result { - let lit = get_lit_value(expr)?; +fn validate_input_max_size_expr( + physical_expr: &Arc, +) -> datafusion_common::Result { + let lit = get_lit_value(physical_expr)?; let max_size = match &lit { ScalarValue::UInt8(Some(q)) => *q as usize, ScalarValue::UInt16(Some(q)) => *q as usize, diff --git a/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs b/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs index 0dbea1fb1ff7..0c62e996763a 100644 --- a/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs +++ b/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs @@ -17,6 +17,7 @@ use std::any::Any; use std::fmt::{Debug, Formatter}; +use std::sync::Arc; use arrow::{ array::ArrayRef, @@ -131,8 +132,8 @@ impl AggregateUDFImpl for ApproxPercentileContWithWeight { let sub_args = AccumulatorArgs { input_exprs: &[ - acc_args.input_exprs[0].clone(), - acc_args.input_exprs[2].clone(), + Arc::clone(&acc_args.input_exprs[0]), + Arc::clone(&acc_args.input_exprs[2]), ], ..acc_args }; diff --git a/datafusion/functions-aggregate/src/array_agg.rs b/datafusion/functions-aggregate/src/array_agg.rs index bb25de113525..abb344eaf693 100644 --- a/datafusion/functions-aggregate/src/array_agg.rs +++ b/datafusion/functions-aggregate/src/array_agg.rs @@ -117,16 +117,15 @@ impl AggregateUDFImpl for ArrayAgg { } fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { + let data_type = + acc_args.input_exprs[0].data_type(acc_args.dfschema.as_arrow())?; + if acc_args.is_distinct { - return Ok(Box::new(DistinctArrayAggAccumulator::try_new( - &acc_args.input_types[0], - )?)); + return Ok(Box::new(DistinctArrayAggAccumulator::try_new(&data_type)?)); } if acc_args.sort_exprs.is_empty() { - return Ok(Box::new(ArrayAggAccumulator::try_new( - &acc_args.input_types[0], - )?)); + return Ok(Box::new(ArrayAggAccumulator::try_new(&data_type)?)); } let ordering_req = limited_convert_logical_sort_exprs_to_physical_with_dfschema( @@ -140,7 +139,7 @@ impl AggregateUDFImpl for ArrayAgg { .collect::>>()?; OrderSensitiveArrayAggAccumulator::try_new( - &acc_args.input_types[0], + &data_type, &ordering_dtypes, ordering_req, acc_args.is_reversed, diff --git a/datafusion/functions-aggregate/src/average.rs b/datafusion/functions-aggregate/src/average.rs index 228bce1979a3..f9d10426df0b 100644 --- a/datafusion/functions-aggregate/src/average.rs +++ b/datafusion/functions-aggregate/src/average.rs @@ -93,7 +93,10 @@ impl AggregateUDFImpl for Avg { } use DataType::*; // instantiate specialized accumulator based for the type - match (&acc_args.input_types[0], acc_args.data_type) { + let input_type = + acc_args.input_exprs[0].data_type(acc_args.dfschema.as_arrow())?; + + match (&input_type, acc_args.data_type) { (Float64, Float64) => Ok(Box::::default()), ( Decimal128(sum_precision, sum_scale), @@ -120,7 +123,7 @@ impl AggregateUDFImpl for Avg { })), _ => exec_err!( "AvgAccumulator for ({} --> {})", - &acc_args.input_types[0], + &input_type, acc_args.data_type ), } @@ -154,10 +157,12 @@ impl AggregateUDFImpl for Avg { ) -> Result> { use DataType::*; // instantiate specialized accumulator based for the type - match (&args.input_types[0], args.data_type) { + let sum_data_type = &args.input_exprs[0].data_type(args.dfschema.as_arrow())?; + + match (sum_data_type, args.data_type) { (Float64, Float64) => { Ok(Box::new(AvgGroupsAccumulator::::new( - &args.input_types[0], + sum_data_type, args.data_type, |sum: f64, count: u64| Ok(sum / count as f64), ))) @@ -176,7 +181,7 @@ impl AggregateUDFImpl for Avg { move |sum: i128, count: u64| decimal_averager.avg(sum, count as i128); Ok(Box::new(AvgGroupsAccumulator::::new( - &args.input_types[0], + sum_data_type, args.data_type, avg_fn, ))) @@ -197,7 +202,7 @@ impl AggregateUDFImpl for Avg { }; Ok(Box::new(AvgGroupsAccumulator::::new( - &args.input_types[0], + sum_data_type, args.data_type, avg_fn, ))) @@ -205,7 +210,7 @@ impl AggregateUDFImpl for Avg { _ => not_impl_err!( "AvgGroupsAccumulator for ({} --> {})", - &args.input_types[0], + sum_data_type, args.data_type ), } diff --git a/datafusion/functions-aggregate/src/count.rs b/datafusion/functions-aggregate/src/count.rs index e2d59003fca1..aacff28baeea 100644 --- a/datafusion/functions-aggregate/src/count.rs +++ b/datafusion/functions-aggregate/src/count.rs @@ -148,7 +148,8 @@ impl AggregateUDFImpl for Count { return not_impl_err!("COUNT DISTINCT with multiple arguments"); } - let data_type = &acc_args.input_types[0]; + let data_type = + &acc_args.input_exprs[0].data_type(acc_args.dfschema.as_arrow())?; Ok(match data_type { // try and use a specialized accumulator if possible, otherwise fall back to generic accumulator DataType::Int8 => Box::new( diff --git a/datafusion/functions-aggregate/src/median.rs b/datafusion/functions-aggregate/src/median.rs index febf1fcd2fef..a0a1dbeb4d3c 100644 --- a/datafusion/functions-aggregate/src/median.rs +++ b/datafusion/functions-aggregate/src/median.rs @@ -133,7 +133,7 @@ impl AggregateUDFImpl for Median { }; } - let dt = &acc_args.input_types[0]; + let dt = &acc_args.input_exprs[0].data_type(acc_args.dfschema.as_arrow())?; downcast_integer! { dt => (helper, dt), DataType::Float16 => helper!(Float16Type, dt), diff --git a/datafusion/functions-aggregate/src/nth_value.rs b/datafusion/functions-aggregate/src/nth_value.rs index 7c4b9a7f06c6..6362bdcc9287 100644 --- a/datafusion/functions-aggregate/src/nth_value.rs +++ b/datafusion/functions-aggregate/src/nth_value.rs @@ -30,10 +30,11 @@ use datafusion_common::{exec_err, internal_err, not_impl_err, Result, ScalarValu use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::utils::format_state_name; use datafusion_expr::{ - Accumulator, AggregateUDFImpl, Expr, ReversedUDAF, Signature, Volatility, + Accumulator, AggregateUDFImpl, ReversedUDAF, Signature, Volatility, }; use datafusion_physical_expr_common::aggregate::merge_arrays::merge_ordered_arrays; use datafusion_physical_expr_common::aggregate::utils::ordering_fields; +use datafusion_physical_expr_common::expressions::Literal; use datafusion_physical_expr_common::sort_expr::{ limited_convert_logical_sort_exprs_to_physical_with_dfschema, LexOrdering, PhysicalSortExpr, @@ -87,20 +88,26 @@ impl AggregateUDFImpl for NthValueAgg { } fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { - let n = match acc_args.input_exprs[1] { - Expr::Literal(ScalarValue::Int64(Some(value))) => { - if acc_args.is_reversed { - Ok(-value) - } else { - Ok(value) + let Some(n) = acc_args.input_exprs[1] + .as_any() + .downcast_ref::() + .and_then(|lit| match lit.value() { + ScalarValue::Int64(Some(n)) => { + if acc_args.is_reversed { + Some(-n) + } else { + Some(*n) + } } - } - _ => not_impl_err!( + _ => None, + }) + else { + return not_impl_err!( "{} not supported for n: {}", self.name(), &acc_args.input_exprs[1] - ), - }?; + ); + }; let ordering_req = limited_convert_logical_sort_exprs_to_physical_with_dfschema( acc_args.sort_exprs, @@ -114,7 +121,7 @@ impl AggregateUDFImpl for NthValueAgg { NthValueAccumulator::try_new( n, - &acc_args.input_types[0], + &acc_args.input_exprs[0].data_type(acc_args.dfschema.as_arrow())?, &ordering_dtypes, ordering_req, ) diff --git a/datafusion/functions-aggregate/src/stddev.rs b/datafusion/functions-aggregate/src/stddev.rs index 1d2257d90133..caa27d059729 100644 --- a/datafusion/functions-aggregate/src/stddev.rs +++ b/datafusion/functions-aggregate/src/stddev.rs @@ -274,9 +274,9 @@ mod tests { use arrow::{array::*, datatypes::*}; use datafusion_common::DFSchema; + use datafusion_expr::expressions::column::{col, Column}; use datafusion_expr::AggregateUDF; use datafusion_physical_expr_common::aggregate::utils::get_accum_scalar_values_as_arrays; - use datafusion_physical_expr_common::expressions::column::col; use super::*; @@ -334,8 +334,7 @@ mod tests { name: "a", is_distinct: false, is_reversed: false, - input_types: &[DataType::Float64], - input_exprs: &[datafusion_expr::col("a")], + input_exprs: &[Arc::new(Column::new("a", 0))], }; let args2 = AccumulatorArgs { @@ -346,8 +345,7 @@ mod tests { name: "a", is_distinct: false, is_reversed: false, - input_types: &[DataType::Float64], - input_exprs: &[datafusion_expr::col("a")], + input_exprs: &[Arc::new(Column::new("a", 0))], }; let mut accum1 = agg1.accumulator(args1)?; diff --git a/datafusion/functions-aggregate/src/string_agg.rs b/datafusion/functions-aggregate/src/string_agg.rs index 371cc8fb9739..9d16616e1c9a 100644 --- a/datafusion/functions-aggregate/src/string_agg.rs +++ b/datafusion/functions-aggregate/src/string_agg.rs @@ -24,8 +24,9 @@ use datafusion_common::Result; use datafusion_common::{not_impl_err, ScalarValue}; use datafusion_expr::function::AccumulatorArgs; use datafusion_expr::{ - Accumulator, AggregateUDFImpl, Expr, Signature, TypeSignature, Volatility, + Accumulator, AggregateUDFImpl, Signature, TypeSignature, Volatility, }; +use datafusion_physical_expr_common::expressions::Literal; use std::any::Any; make_udaf_expr_and_func!( @@ -82,21 +83,25 @@ impl AggregateUDFImpl for StringAgg { } fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { - match &acc_args.input_exprs[1] { - Expr::Literal(ScalarValue::Utf8(Some(delimiter))) - | Expr::Literal(ScalarValue::LargeUtf8(Some(delimiter))) => { - Ok(Box::new(StringAggAccumulator::new(delimiter))) - } - Expr::Literal(ScalarValue::Utf8(None)) - | Expr::Literal(ScalarValue::LargeUtf8(None)) - | Expr::Literal(ScalarValue::Null) => { - Ok(Box::new(StringAggAccumulator::new(""))) - } - _ => not_impl_err!( + let Some(delimiter) = acc_args.input_exprs[1] + .as_any() + .downcast_ref::() + .and_then(|lit| match lit.value() { + ScalarValue::Utf8(Some(s)) => Some(s.as_str()), + ScalarValue::LargeUtf8(Some(s)) => Some(s.as_str()), + ScalarValue::Utf8(None) + | ScalarValue::LargeUtf8(None) + | ScalarValue::Null => Some(""), + _ => None, + }) + else { + return not_impl_err!( "StringAgg not supported for delimiter {}", &acc_args.input_exprs[1] - ), - } + ); + }; + + Ok(Box::new(StringAggAccumulator::new(delimiter))) } } diff --git a/datafusion/physical-expr-common/src/aggregate/mod.rs b/datafusion/physical-expr-common/src/aggregate/mod.rs index 0707301b2557..acf1d3c4aef6 100644 --- a/datafusion/physical-expr-common/src/aggregate/mod.rs +++ b/datafusion/physical-expr-common/src/aggregate/mod.rs @@ -18,9 +18,9 @@ use std::fmt::Debug; use std::{any::Any, sync::Arc}; -use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use arrow::datatypes::{DataType, Field, Schema}; -use datafusion_common::exec_err; +use datafusion_common::{exec_err, ToDFSchema}; use datafusion_common::{internal_err, not_impl_err, DFSchema, Result}; use datafusion_expr::function::StateFieldsArgs; use datafusion_expr::type_coercion::aggregates::check_arg_count; @@ -30,9 +30,9 @@ use datafusion_expr::{ function::AccumulatorArgs, Accumulator, AggregateUDF, Expr, GroupsAccumulator, }; -use crate::physical_expr::PhysicalExpr; use crate::sort_expr::{LexOrdering, PhysicalSortExpr}; use crate::utils::reverse_order_bys; +use datafusion_expr::physical_expr::PhysicalExpr; use self::utils::down_cast_any_ref; @@ -76,7 +76,7 @@ pub fn create_aggregate_expr( builder = builder.sort_exprs(sort_exprs.to_vec()); builder = builder.order_by(ordering_req.to_vec()); builder = builder.logical_exprs(input_exprs.to_vec()); - builder = builder.schema(Arc::new(schema.clone())); + builder = builder.dfschema(Arc::new(schema.clone()).to_dfschema()?); builder = builder.name(name); if ignore_nulls { @@ -109,8 +109,6 @@ pub fn create_aggregate_expr_with_dfschema( builder = builder.order_by(ordering_req.to_vec()); builder = builder.logical_exprs(input_exprs.to_vec()); builder = builder.dfschema(dfschema.clone()); - let schema: Schema = dfschema.into(); - builder = builder.schema(Arc::new(schema)); builder = builder.name(name); if ignore_nulls { @@ -138,8 +136,6 @@ pub struct AggregateExprBuilder { /// Logical expressions of the aggregate function, it will be deprecated in logical_args: Vec, name: String, - /// Arrow Schema for the aggregate function - schema: SchemaRef, /// Datafusion Schema for the aggregate function dfschema: DFSchema, /// The logical order by expressions, it will be deprecated in @@ -161,7 +157,6 @@ impl AggregateExprBuilder { args, logical_args: vec![], name: String::new(), - schema: Arc::new(Schema::empty()), dfschema: DFSchema::empty(), sort_exprs: vec![], ordering_req: vec![], @@ -177,7 +172,6 @@ impl AggregateExprBuilder { args, logical_args, name, - schema, dfschema, sort_exprs, ordering_req, @@ -195,7 +189,7 @@ impl AggregateExprBuilder { if !ordering_req.is_empty() { let ordering_types = ordering_req .iter() - .map(|e| e.expr.data_type(&schema)) + .map(|e| e.expr.data_type(dfschema.as_arrow())) .collect::>>()?; ordering_fields = utils::ordering_fields(&ordering_req, &ordering_types); @@ -203,7 +197,7 @@ impl AggregateExprBuilder { let input_exprs_types = args .iter() - .map(|arg| arg.data_type(&schema)) + .map(|arg| arg.data_type(dfschema.as_arrow())) .collect::>>()?; check_arg_count( @@ -236,11 +230,6 @@ impl AggregateExprBuilder { self } - pub fn schema(mut self, schema: SchemaRef) -> Self { - self.schema = schema; - self - } - pub fn dfschema(mut self, dfschema: DFSchema) -> Self { self.dfschema = dfschema; self @@ -524,8 +513,7 @@ impl AggregateExpr for AggregateFunctionExpr { ignore_nulls: self.ignore_nulls, sort_exprs: &self.sort_exprs, is_distinct: self.is_distinct, - input_types: &self.input_types, - input_exprs: &self.logical_args, + input_exprs: &self.args, name: &self.name, is_reversed: self.is_reversed, }; @@ -540,8 +528,7 @@ impl AggregateExpr for AggregateFunctionExpr { ignore_nulls: self.ignore_nulls, sort_exprs: &self.sort_exprs, is_distinct: self.is_distinct, - input_types: &self.input_types, - input_exprs: &self.logical_args, + input_exprs: &self.args, name: &self.name, is_reversed: self.is_reversed, }; @@ -611,8 +598,7 @@ impl AggregateExpr for AggregateFunctionExpr { ignore_nulls: self.ignore_nulls, sort_exprs: &self.sort_exprs, is_distinct: self.is_distinct, - input_types: &self.input_types, - input_exprs: &self.logical_args, + input_exprs: &self.args, name: &self.name, is_reversed: self.is_reversed, }; @@ -626,8 +612,7 @@ impl AggregateExpr for AggregateFunctionExpr { ignore_nulls: self.ignore_nulls, sort_exprs: &self.sort_exprs, is_distinct: self.is_distinct, - input_types: &self.input_types, - input_exprs: &self.logical_args, + input_exprs: &self.args, name: &self.name, is_reversed: self.is_reversed, }; diff --git a/datafusion/physical-expr-common/src/expressions/cast.rs b/datafusion/physical-expr-common/src/expressions/cast.rs index dd6131ad65c3..2b0058bb338f 100644 --- a/datafusion/physical-expr-common/src/expressions/cast.rs +++ b/datafusion/physical-expr-common/src/expressions/cast.rs @@ -20,7 +20,7 @@ use std::fmt; use std::hash::{Hash, Hasher}; use std::sync::Arc; -use crate::physical_expr::{down_cast_any_ref, PhysicalExpr}; +use datafusion_expr::physical_expr::{down_cast_any_ref, PhysicalExpr}; use arrow::compute::{can_cast_types, CastOptions}; use arrow::datatypes::{DataType, DataType::*, Schema}; @@ -235,7 +235,7 @@ pub fn cast( mod tests { use super::*; - use crate::expressions::column::col; + use datafusion_expr::expressions::column::col; use arrow::{ array::{ diff --git a/datafusion/physical-expr-common/src/expressions/literal.rs b/datafusion/physical-expr-common/src/expressions/literal.rs index b3cff1ef69ba..1be46d13f5fb 100644 --- a/datafusion/physical-expr-common/src/expressions/literal.rs +++ b/datafusion/physical-expr-common/src/expressions/literal.rs @@ -21,14 +21,13 @@ use std::any::Any; use std::hash::{Hash, Hasher}; use std::sync::Arc; -use crate::physical_expr::{down_cast_any_ref, PhysicalExpr}; - use arrow::{ datatypes::{DataType, Schema}, record_batch::RecordBatch, }; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::interval_arithmetic::Interval; +use datafusion_expr::physical_expr::{down_cast_any_ref, PhysicalExpr}; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_expr::{ColumnarValue, Expr}; diff --git a/datafusion/physical-expr-common/src/expressions/mod.rs b/datafusion/physical-expr-common/src/expressions/mod.rs index dd534cc07d20..b53bdc829440 100644 --- a/datafusion/physical-expr-common/src/expressions/mod.rs +++ b/datafusion/physical-expr-common/src/expressions/mod.rs @@ -16,7 +16,6 @@ // under the License. mod cast; -pub mod column; pub mod literal; pub use cast::{cast, cast_with_options, CastExpr}; diff --git a/datafusion/physical-expr-common/src/lib.rs b/datafusion/physical-expr-common/src/lib.rs index f03eedd4cf65..9a694b3478c4 100644 --- a/datafusion/physical-expr-common/src/lib.rs +++ b/datafusion/physical-expr-common/src/lib.rs @@ -20,7 +20,5 @@ pub mod binary_map; pub mod binary_view_map; pub mod datum; pub mod expressions; -pub mod physical_expr; pub mod sort_expr; -pub mod tree_node; pub mod utils; diff --git a/datafusion/physical-expr-common/src/sort_expr.rs b/datafusion/physical-expr-common/src/sort_expr.rs index 2b506b74216f..9f9beb3190e6 100644 --- a/datafusion/physical-expr-common/src/sort_expr.rs +++ b/datafusion/physical-expr-common/src/sort_expr.rs @@ -21,13 +21,12 @@ use std::fmt::Display; use std::hash::{Hash, Hasher}; use std::sync::Arc; -use crate::physical_expr::PhysicalExpr; use crate::utils::limited_convert_logical_expr_to_physical_expr_with_dfschema; - use arrow::compute::kernels::sort::{SortColumn, SortOptions}; use arrow::datatypes::Schema; use arrow::record_batch::RecordBatch; use datafusion_common::{exec_err, DFSchema, Result}; +use datafusion_expr::physical_expr::PhysicalExpr; use datafusion_expr::{ColumnarValue, Expr}; /// Represents Sort operation for a column in a RecordBatch diff --git a/datafusion/physical-expr-common/src/tree_node.rs b/datafusion/physical-expr-common/src/tree_node.rs deleted file mode 100644 index d9892ce55509..000000000000 --- a/datafusion/physical-expr-common/src/tree_node.rs +++ /dev/null @@ -1,105 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! This module provides common traits for visiting or rewriting tree nodes easily. - -use std::fmt::{self, Display, Formatter}; -use std::sync::Arc; - -use crate::physical_expr::{with_new_children_if_necessary, PhysicalExpr}; - -use datafusion_common::tree_node::{ConcreteTreeNode, DynTreeNode}; -use datafusion_common::Result; - -impl DynTreeNode for dyn PhysicalExpr { - fn arc_children(&self) -> Vec<&Arc> { - self.children() - } - - fn with_new_arc_children( - &self, - arc_self: Arc, - new_children: Vec>, - ) -> Result> { - with_new_children_if_necessary(arc_self, new_children) - } -} - -/// A node object encapsulating a [`PhysicalExpr`] node with a payload. Since there are -/// two ways to access child plans—directly from the plan and through child nodes—it's -/// recommended to perform mutable operations via [`Self::update_expr_from_children`]. -#[derive(Debug)] -pub struct ExprContext { - /// The physical expression associated with this context. - pub expr: Arc, - /// Custom data payload of the node. - pub data: T, - /// Child contexts of this node. - pub children: Vec, -} - -impl ExprContext { - pub fn new(expr: Arc, data: T, children: Vec) -> Self { - Self { - expr, - data, - children, - } - } - - pub fn update_expr_from_children(mut self) -> Result { - let children_expr = self.children.iter().map(|c| c.expr.clone()).collect(); - self.expr = with_new_children_if_necessary(self.expr, children_expr)?; - Ok(self) - } -} - -impl ExprContext { - pub fn new_default(plan: Arc) -> Self { - let children = plan - .children() - .into_iter() - .cloned() - .map(Self::new_default) - .collect(); - Self::new(plan, Default::default(), children) - } -} - -impl Display for ExprContext { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - write!(f, "expr: {:?}", self.expr)?; - write!(f, "data:{}", self.data)?; - write!(f, "") - } -} - -impl ConcreteTreeNode for ExprContext { - fn children(&self) -> &[Self] { - &self.children - } - - fn take_children(mut self) -> (Self, Vec) { - let children = std::mem::take(&mut self.children); - (self, children) - } - - fn with_new_children(mut self, children: Vec) -> Result { - self.children = children; - self.update_expr_from_children() - } -} diff --git a/datafusion/physical-expr-common/src/utils.rs b/datafusion/physical-expr-common/src/utils.rs index 0978a906a5dc..f4a1616c9131 100644 --- a/datafusion/physical-expr-common/src/utils.rs +++ b/datafusion/physical-expr-common/src/utils.rs @@ -17,86 +17,14 @@ use std::sync::Arc; -use arrow::array::{make_array, Array, ArrayRef, BooleanArray, MutableArrayData}; -use arrow::compute::{and_kleene, is_not_null, SlicesIterator}; - -use datafusion_common::{exec_err, DFSchema, Result}; -use datafusion_expr::expr::Alias; -use datafusion_expr::sort_properties::ExprProperties; -use datafusion_expr::Expr; - -use crate::expressions::column::Column; use crate::expressions::literal::Literal; use crate::expressions::CastExpr; -use crate::physical_expr::PhysicalExpr; use crate::sort_expr::PhysicalSortExpr; -use crate::tree_node::ExprContext; - -/// Represents a [`PhysicalExpr`] node with associated properties (order and -/// range) in a context where properties are tracked. -pub type ExprPropertiesNode = ExprContext; - -impl ExprPropertiesNode { - /// Constructs a new `ExprPropertiesNode` with unknown properties for a - /// given physical expression. This node initializes with default properties - /// and recursively applies this to all child expressions. - pub fn new_unknown(expr: Arc) -> Self { - let children = expr - .children() - .into_iter() - .cloned() - .map(Self::new_unknown) - .collect(); - Self { - expr, - data: ExprProperties::new_unknown(), - children, - } - } -} - -/// Scatter `truthy` array by boolean mask. When the mask evaluates `true`, next values of `truthy` -/// are taken, when the mask evaluates `false` values null values are filled. -/// -/// # Arguments -/// * `mask` - Boolean values used to determine where to put the `truthy` values -/// * `truthy` - All values of this array are to scatter according to `mask` into final result. -pub fn scatter(mask: &BooleanArray, truthy: &dyn Array) -> Result { - let truthy = truthy.to_data(); - - // update the mask so that any null values become false - // (SlicesIterator doesn't respect nulls) - let mask = and_kleene(mask, &is_not_null(mask)?)?; - - let mut mutable = MutableArrayData::new(vec![&truthy], true, mask.len()); - - // the SlicesIterator slices only the true values. So the gaps left by this iterator we need to - // fill with falsy values - - // keep track of how much is filled - let mut filled = 0; - // keep track of current position we have in truthy array - let mut true_pos = 0; - - SlicesIterator::new(&mask).for_each(|(start, end)| { - // the gap needs to be filled with nulls - if start > filled { - mutable.extend_nulls(start - filled); - } - // fill with truthy values - let len = end - start; - mutable.extend(0, true_pos, true_pos + len); - true_pos += len; - filled = end; - }); - // the remaining part is falsy - if filled < mask.len() { - mutable.extend_nulls(mask.len() - filled); - } - - let data = mutable.freeze(); - Ok(make_array(data)) -} +use datafusion_common::{exec_err, DFSchema, Result}; +use datafusion_expr::expr::Alias; +use datafusion_expr::expressions::column::Column; +use datafusion_expr::physical_expr::PhysicalExpr; +use datafusion_expr::Expr; /// Reverses the ORDER BY expression, which is useful during equivalent window /// expression construction. For instance, 'ORDER BY a ASC, NULLS LAST' turns into @@ -136,80 +64,3 @@ pub fn limited_convert_logical_expr_to_physical_expr_with_dfschema( ), } } - -#[cfg(test)] -mod tests { - use std::sync::Arc; - - use arrow::array::Int32Array; - - use datafusion_common::cast::{as_boolean_array, as_int32_array}; - - use super::*; - - #[test] - fn scatter_int() -> Result<()> { - let truthy = Arc::new(Int32Array::from(vec![1, 10, 11, 100])); - let mask = BooleanArray::from(vec![true, true, false, false, true]); - - // the output array is expected to be the same length as the mask array - let expected = - Int32Array::from_iter(vec![Some(1), Some(10), None, None, Some(11)]); - let result = scatter(&mask, truthy.as_ref())?; - let result = as_int32_array(&result)?; - - assert_eq!(&expected, result); - Ok(()) - } - - #[test] - fn scatter_int_end_with_false() -> Result<()> { - let truthy = Arc::new(Int32Array::from(vec![1, 10, 11, 100])); - let mask = BooleanArray::from(vec![true, false, true, false, false, false]); - - // output should be same length as mask - let expected = - Int32Array::from_iter(vec![Some(1), None, Some(10), None, None, None]); - let result = scatter(&mask, truthy.as_ref())?; - let result = as_int32_array(&result)?; - - assert_eq!(&expected, result); - Ok(()) - } - - #[test] - fn scatter_with_null_mask() -> Result<()> { - let truthy = Arc::new(Int32Array::from(vec![1, 10, 11])); - let mask: BooleanArray = vec![Some(false), None, Some(true), Some(true), None] - .into_iter() - .collect(); - - // output should treat nulls as though they are false - let expected = Int32Array::from_iter(vec![None, None, Some(1), Some(10), None]); - let result = scatter(&mask, truthy.as_ref())?; - let result = as_int32_array(&result)?; - - assert_eq!(&expected, result); - Ok(()) - } - - #[test] - fn scatter_boolean() -> Result<()> { - let truthy = Arc::new(BooleanArray::from(vec![false, false, false, true])); - let mask = BooleanArray::from(vec![true, true, false, false, true]); - - // the output array is expected to be the same length as the mask array - let expected = BooleanArray::from_iter(vec![ - Some(false), - Some(false), - None, - None, - Some(false), - ]); - let result = scatter(&mask, truthy.as_ref())?; - let result = as_boolean_array(&result)?; - - assert_eq!(&expected, result); - Ok(()) - } -} diff --git a/datafusion/physical-expr/benches/case_when.rs b/datafusion/physical-expr/benches/case_when.rs index 862edd9c1fac..74c52a9294cd 100644 --- a/datafusion/physical-expr/benches/case_when.rs +++ b/datafusion/physical-expr/benches/case_when.rs @@ -21,11 +21,11 @@ use arrow_array::builder::{Int32Builder, StringBuilder}; use arrow_schema::DataType; use criterion::{black_box, criterion_group, criterion_main, Criterion}; use datafusion_common::ScalarValue; +use datafusion_expr::expressions::column::Column; +use datafusion_expr::physical_expr::PhysicalExpr; use datafusion_expr::Operator; use datafusion_physical_expr::expressions::{BinaryExpr, CaseExpr}; -use datafusion_physical_expr_common::expressions::column::Column; use datafusion_physical_expr_common::expressions::Literal; -use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use std::sync::Arc; fn make_col(name: &str, index: usize) -> Arc { diff --git a/datafusion/physical-expr/benches/is_null.rs b/datafusion/physical-expr/benches/is_null.rs index 3dad8e9b456a..3cc7c4f0c37b 100644 --- a/datafusion/physical-expr/benches/is_null.rs +++ b/datafusion/physical-expr/benches/is_null.rs @@ -20,9 +20,9 @@ use arrow::record_batch::RecordBatch; use arrow_array::builder::Int32Builder; use arrow_schema::DataType; use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use datafusion_expr::expressions::column::Column; +use datafusion_expr::physical_expr::PhysicalExpr; use datafusion_physical_expr::expressions::{IsNotNullExpr, IsNullExpr}; -use datafusion_physical_expr_common::expressions::column::Column; -use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use std::sync::Arc; fn criterion_benchmark(c: &mut Criterion) { diff --git a/datafusion/physical-expr/src/equivalence/properties.rs b/datafusion/physical-expr/src/equivalence/properties.rs index a6e9fba28167..fcf3ac81c5c2 100644 --- a/datafusion/physical-expr/src/equivalence/properties.rs +++ b/datafusion/physical-expr/src/equivalence/properties.rs @@ -34,12 +34,12 @@ use crate::{ use arrow_schema::{SchemaRef, SortOptions}; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::{plan_err, JoinSide, JoinType, Result}; +use datafusion_expr::expressions::column::Column; use datafusion_expr::interval_arithmetic::Interval; +use datafusion_expr::physical_expr::with_new_schema; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; -use datafusion_physical_expr_common::expressions::column::Column; +use datafusion_expr::utils::ExprPropertiesNode; use datafusion_physical_expr_common::expressions::CastExpr; -use datafusion_physical_expr_common::physical_expr::with_new_schema; -use datafusion_physical_expr_common::utils::ExprPropertiesNode; use indexmap::{IndexMap, IndexSet}; use itertools::Itertools; diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs index c34dcdfb7598..b7d2202b9f31 100644 --- a/datafusion/physical-expr/src/expressions/binary.rs +++ b/datafusion/physical-expr/src/expressions/binary.rs @@ -684,8 +684,8 @@ mod tests { use crate::expressions::{col, lit, try_cast, Literal}; use datafusion_common::plan_datafusion_err; + use datafusion_expr::expressions::column::Column; use datafusion_expr::type_coercion::binary::get_input_types; - use datafusion_physical_expr_common::expressions::column::Column; /// Performs a binary operation, applying any type coercion necessary fn binary_op( diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index b428d562bd1b..b433f6c5e3d2 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -32,7 +32,7 @@ use datafusion_common::cast::as_boolean_array; use datafusion_common::{exec_err, internal_err, DataFusionError, Result, ScalarValue}; use datafusion_expr::ColumnarValue; -use datafusion_physical_expr_common::expressions::column::Column; +use datafusion_expr::expressions::column::Column; use datafusion_physical_expr_common::expressions::Literal; use itertools::Itertools; diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index 7cbe4e796844..951bef4521e3 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -47,8 +47,8 @@ pub use crate::PhysicalSortExpr; pub use binary::{binary, BinaryExpr}; pub use case::{case, CaseExpr}; +pub use datafusion_expr::expressions::column::{col, Column}; pub use datafusion_expr::utils::format_state_name; -pub use datafusion_physical_expr_common::expressions::column::{col, Column}; pub use datafusion_physical_expr_common::expressions::literal::{lit, Literal}; pub use datafusion_physical_expr_common::expressions::{cast, CastExpr}; pub use in_list::{in_list, InListExpr}; diff --git a/datafusion/physical-expr/src/lib.rs b/datafusion/physical-expr/src/lib.rs index 2e78119eba46..0fa01fc689d6 100644 --- a/datafusion/physical-expr/src/lib.rs +++ b/datafusion/physical-expr/src/lib.rs @@ -55,7 +55,7 @@ pub use physical_expr::{ PhysicalExprRef, }; -pub use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +pub use datafusion_expr::physical_expr::PhysicalExpr; pub use datafusion_physical_expr_common::sort_expr::{ LexOrdering, LexOrderingRef, LexRequirement, LexRequirementRef, PhysicalSortExpr, PhysicalSortRequirement, @@ -69,5 +69,5 @@ pub use utils::split_conjunction; // For backwards compatibility pub mod tree_node { - pub use datafusion_physical_expr_common::tree_node::ExprContext; + pub use datafusion_expr::tree_node::ExprContext; } diff --git a/datafusion/physical-expr/src/physical_expr.rs b/datafusion/physical-expr/src/physical_expr.rs index c60a772b9ce2..942b92abc43c 100644 --- a/datafusion/physical-expr/src/physical_expr.rs +++ b/datafusion/physical-expr/src/physical_expr.rs @@ -17,10 +17,10 @@ use std::sync::Arc; -use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use datafusion_expr::physical_expr::PhysicalExpr; use itertools::izip; -pub use datafusion_physical_expr_common::physical_expr::down_cast_any_ref; +pub use datafusion_expr::physical_expr::down_cast_any_ref; /// Shared [`PhysicalExpr`]. pub type PhysicalExprRef = Arc; diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index d1152038eb2a..fd2510fbf90d 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -1195,7 +1195,7 @@ mod tests { use arrow_array::{Float32Array, Int32Array}; use datafusion_common::{ assert_batches_eq, assert_batches_sorted_eq, internal_err, DFSchema, DFSchemaRef, - DataFusionError, ScalarValue, + DataFusionError, ScalarValue, ToDFSchema, }; use datafusion_execution::config::SessionConfig; use datafusion_execution::memory_pool::FairSpillPool; @@ -1352,7 +1352,7 @@ mod tests { }; let aggregates = vec![AggregateExprBuilder::new(count_udaf(), vec![lit(1i8)]) - .schema(Arc::clone(&input_schema)) + .dfschema(Arc::clone(&input_schema).to_dfschema()?) .name("COUNT(1)") .logical_exprs(vec![datafusion_expr::lit(1i8)]) .build()?]; @@ -1497,7 +1497,7 @@ mod tests { let aggregates: Vec> = vec![ AggregateExprBuilder::new(avg_udaf(), vec![col("b", &input_schema)?]) - .schema(Arc::clone(&input_schema)) + .dfschema(Arc::clone(&input_schema).to_dfschema()?) .name("AVG(b)") .build()?, ]; @@ -1793,7 +1793,7 @@ mod tests { // Median(a) fn test_median_agg_expr(schema: SchemaRef) -> Result> { AggregateExprBuilder::new(median_udaf(), vec![col("a", &schema)?]) - .schema(schema) + .dfschema(schema.to_dfschema()?) .name("MEDIAN(a)") .build() } @@ -1824,7 +1824,7 @@ mod tests { let aggregates_v2: Vec> = vec![ AggregateExprBuilder::new(avg_udaf(), vec![col("b", &input_schema)?]) - .schema(Arc::clone(&input_schema)) + .dfschema(Arc::clone(&input_schema).to_dfschema()?) .name("AVG(b)") .build()?, ]; @@ -1884,7 +1884,7 @@ mod tests { let aggregates: Vec> = vec![ AggregateExprBuilder::new(avg_udaf(), vec![col("a", &schema)?]) - .schema(Arc::clone(&schema)) + .dfschema(Arc::clone(&schema).to_dfschema()?) .name("AVG(a)") .build()?, ]; @@ -1924,7 +1924,7 @@ mod tests { let aggregates: Vec> = vec![ AggregateExprBuilder::new(avg_udaf(), vec![col("b", &schema)?]) - .schema(Arc::clone(&schema)) + .dfschema(Arc::clone(&schema).to_dfschema()?) .name("AVG(b)") .build()?, ]; @@ -2353,7 +2353,7 @@ mod tests { let aggregates: Vec> = vec![AggregateExprBuilder::new(count_udaf(), vec![lit(1)]) - .schema(Arc::clone(&schema)) + .dfschema(Arc::clone(&schema).to_dfschema()?) .name("1") .build()?]; diff --git a/datafusion/physical-plan/src/union.rs b/datafusion/physical-plan/src/union.rs index 9321fdb2cadf..41f0878595b7 100644 --- a/datafusion/physical-plan/src/union.rs +++ b/datafusion/physical-plan/src/union.rs @@ -586,8 +586,8 @@ mod tests { use arrow_schema::{DataType, SortOptions}; use datafusion_common::ScalarValue; + use datafusion_expr::expressions::column::col; use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; - use datafusion_physical_expr_common::expressions::column::col; // Generate a schema which consists of 7 columns (a, b, c, d, e, f, g) fn create_test_schema() -> Result { diff --git a/datafusion/physical-plan/src/windows/mod.rs b/datafusion/physical-plan/src/windows/mod.rs index a462430ca381..dbfea253959e 100644 --- a/datafusion/physical-plan/src/windows/mod.rs +++ b/datafusion/physical-plan/src/windows/mod.rs @@ -31,7 +31,7 @@ use crate::{ use arrow::datatypes::Schema; use arrow_schema::{DataType, Field, SchemaRef}; -use datafusion_common::{exec_err, DataFusionError, Result, ScalarValue}; +use datafusion_common::{exec_err, DataFusionError, Result, ScalarValue, ToDFSchema}; use datafusion_expr::{col, Expr, SortExpr}; use datafusion_expr::{ BuiltInWindowFunction, PartitionEvaluator, WindowFrame, WindowFunctionDefinition, @@ -145,7 +145,7 @@ pub fn create_window_expr( .collect::>(); let aggregate = AggregateExprBuilder::new(Arc::clone(fun), args.to_vec()) - .schema(Arc::new(input_schema.clone())) + .dfschema(Arc::new(input_schema.clone()).to_dfschema()?) .name(name) .order_by(order_by.to_vec()) .sort_exprs(sort_exprs) @@ -413,6 +413,7 @@ impl BuiltInWindowFunctionExpr for WindowUDFExpr { } } +#[allow(clippy::needless_borrow)] pub(crate) fn calc_requirements< T: Borrow>, S: Borrow, @@ -430,7 +431,7 @@ pub(crate) fn calc_requirements< let PhysicalSortExpr { expr, options } = element.borrow(); if !sort_reqs.iter().any(|e| e.expr.eq(expr)) { sort_reqs.push(PhysicalSortRequirement::new( - Arc::clone(expr), + Arc::clone(&expr), Some(*options), )); } diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index 1f433ff01d12..3c0d6664da17 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -61,7 +61,9 @@ use datafusion::physical_plan::windows::{BoundedWindowAggExec, WindowAggExec}; use datafusion::physical_plan::{ AggregateExpr, ExecutionPlan, InputOrderMode, PhysicalExpr, WindowExpr, }; -use datafusion_common::{internal_err, not_impl_err, DataFusionError, Result}; +use datafusion_common::{ + internal_err, not_impl_err, DataFusionError, Result, ToDFSchema, +}; use datafusion_expr::{AggregateUDF, ScalarUDF}; use crate::common::{byte_to_string, str_to_byte}; @@ -509,7 +511,7 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { // TODO: approx_percentile_cont and approx_percentile_cont_weight are not supported for UDAF from protobuf yet. // TODO: `order by` is not supported for UDAF yet - AggregateExprBuilder::new(agg_udf, input_phy_expr).schema(Arc::clone(&physical_schema)).name(name).with_ignore_nulls(agg_node.ignore_nulls).with_distinct(agg_node.distinct).build() + AggregateExprBuilder::new(agg_udf, input_phy_expr).dfschema(Arc::clone(&physical_schema).to_dfschema()?).name(name).with_ignore_nulls(agg_node.ignore_nulls).with_distinct(agg_node.distinct).build() } } }).transpose()?.ok_or_else(|| { diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index 3ddc122e3de2..caab8f0a77f7 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -79,7 +79,9 @@ use datafusion_common::file_options::csv_writer::CsvWriterOptions; use datafusion_common::file_options::json_writer::JsonWriterOptions; use datafusion_common::parsers::CompressionTypeVariant; use datafusion_common::stats::Precision; -use datafusion_common::{internal_err, not_impl_err, DataFusionError, Result}; +use datafusion_common::{ + internal_err, not_impl_err, DataFusionError, Result, ToDFSchema, +}; use datafusion_expr::{ Accumulator, AccumulatorFactoryFunction, AggregateUDF, ColumnarValue, ScalarUDF, Signature, SimpleAggregateUDF, WindowFrame, WindowFrameBound, @@ -295,7 +297,7 @@ fn roundtrip_window() -> Result<()> { avg_udaf(), vec![cast(col("b", &schema)?, &schema, DataType::Float64)?], ) - .schema(Arc::clone(&schema)) + .dfschema(Arc::clone(&schema).to_dfschema()?) .name("avg(b)") .build()?, &[], @@ -311,7 +313,7 @@ fn roundtrip_window() -> Result<()> { let args = vec![cast(col("a", &schema)?, &schema, DataType::Float64)?]; let sum_expr = AggregateExprBuilder::new(sum_udaf(), args) - .schema(Arc::clone(&schema)) + .dfschema(Arc::clone(&schema).to_dfschema()?) .name("SUM(a) RANGE BETWEEN CURRENT ROW AND UNBOUNDED PRECEEDING") .build()?; @@ -345,17 +347,17 @@ fn rountrip_aggregate() -> Result<()> { vec![(col("a", &schema)?, "unused".to_string())]; let avg_expr = AggregateExprBuilder::new(avg_udaf(), vec![col("b", &schema)?]) - .schema(Arc::clone(&schema)) + .dfschema(Arc::clone(&schema).to_dfschema()?) .name("AVG(b)") .build()?; let nth_expr = AggregateExprBuilder::new(nth_value_udaf(), vec![col("b", &schema)?, lit(1u64)]) - .schema(Arc::clone(&schema)) + .dfschema(Arc::clone(&schema).to_dfschema()?) .name("NTH_VALUE(b, 1)") .build()?; let str_agg_expr = AggregateExprBuilder::new(string_agg_udaf(), vec![col("b", &schema)?, lit(1u64)]) - .schema(Arc::clone(&schema)) + .dfschema(Arc::clone(&schema).to_dfschema()?) .name("NTH_VALUE(b, 1)") .build()?; @@ -395,7 +397,7 @@ fn rountrip_aggregate_with_limit() -> Result<()> { let aggregates: Vec> = vec![ AggregateExprBuilder::new(avg_udaf(), vec![col("b", &schema)?]) - .schema(Arc::clone(&schema)) + .dfschema(Arc::clone(&schema).to_dfschema()?) .name("AVG(b)") .build()?, ]; @@ -462,7 +464,7 @@ fn roundtrip_aggregate_udaf() -> Result<()> { let aggregates: Vec> = vec![ AggregateExprBuilder::new(Arc::new(udaf), vec![col("b", &schema)?]) - .schema(Arc::clone(&schema)) + .dfschema(Arc::clone(&schema).to_dfschema()?) .name("example_agg") .build()?, ]; @@ -957,7 +959,7 @@ fn roundtrip_aggregate_udf_extension_codec() -> Result<()> { vec![Arc::new(Literal::new(ScalarValue::from(42)))]; let aggr_expr = AggregateExprBuilder::new(Arc::clone(&udaf), aggr_args.clone()) - .schema(Arc::clone(&schema)) + .dfschema(Arc::clone(&schema).to_dfschema()?) .name("aggregate_udf") .build()?; @@ -982,7 +984,7 @@ fn roundtrip_aggregate_udf_extension_codec() -> Result<()> { )?); let aggr_expr = AggregateExprBuilder::new(udaf, aggr_args.clone()) - .schema(Arc::clone(&schema)) + .dfschema(Arc::clone(&schema).to_dfschema()?) .name("aggregate_udf") .distinct() .ignore_nulls() diff --git a/datafusion/substrait/src/serializer.rs b/datafusion/substrait/src/serializer.rs index 6b81e33dfc37..e8698253edb5 100644 --- a/datafusion/substrait/src/serializer.rs +++ b/datafusion/substrait/src/serializer.rs @@ -27,7 +27,6 @@ use substrait::proto::Plan; use std::fs::OpenOptions; use std::io::{Read, Write}; -#[allow(clippy::suspicious_open_options)] pub async fn serialize(sql: &str, ctx: &SessionContext, path: &str) -> Result<()> { let protobuf_out = serialize_bytes(sql, ctx).await; let mut file = OpenOptions::new().create(true).write(true).open(path)?; From d082d5162572c1c69626d580843c46359ccb21c9 Mon Sep 17 00:00:00 2001 From: Xin Li Date: Thu, 1 Aug 2024 18:51:53 +0800 Subject: [PATCH 3/7] fix 180 clippy --- .../core/src/datasource/physical_plan/parquet/statistics.rs | 1 - datafusion/functions-aggregate/src/approx_distinct.rs | 2 +- datafusion/substrait/src/serializer.rs | 1 + 3 files changed, 2 insertions(+), 2 deletions(-) diff --git a/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs b/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs index eec7c95fff94..11b8f5fc6c79 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs @@ -517,7 +517,6 @@ macro_rules! make_data_page_stats_iterator { } } - #[allow(clippy::redundant_closure_call)] impl<'a, I> Iterator for $iterator_type<'a, I> where I: Iterator, diff --git a/datafusion/functions-aggregate/src/approx_distinct.rs b/datafusion/functions-aggregate/src/approx_distinct.rs index bcd132ec4910..26909471fc67 100644 --- a/datafusion/functions-aggregate/src/approx_distinct.rs +++ b/datafusion/functions-aggregate/src/approx_distinct.rs @@ -297,7 +297,7 @@ impl AggregateUDFImpl for ApproxDistinct { other => { return not_impl_err!( "Support for 'approx_distinct' for data type {other} is not implemented" - ) + ) } }; Ok(accumulator) diff --git a/datafusion/substrait/src/serializer.rs b/datafusion/substrait/src/serializer.rs index e8698253edb5..6b81e33dfc37 100644 --- a/datafusion/substrait/src/serializer.rs +++ b/datafusion/substrait/src/serializer.rs @@ -27,6 +27,7 @@ use substrait::proto::Plan; use std::fs::OpenOptions; use std::io::{Read, Write}; +#[allow(clippy::suspicious_open_options)] pub async fn serialize(sql: &str, ctx: &SessionContext, path: &str) -> Result<()> { let protobuf_out = serialize_bytes(sql, ctx).await; let mut file = OpenOptions::new().create(true).write(true).open(path)?; From 5d973451c12c85eeb262480c9e2f0323efd8a4dc Mon Sep 17 00:00:00 2001 From: Xin Li Date: Sun, 4 Aug 2024 19:39:06 +0800 Subject: [PATCH 4/7] revert schema change v1 --- .../aggregate_statistics.rs | 3 +- .../combine_partial_final_agg.rs | 3 +- .../core/tests/fuzz_cases/aggregate_fuzz.rs | 3 +- datafusion/expr/src/function.rs | 5 +++- .../src/approx_distinct.rs | 2 +- .../functions-aggregate/src/approx_median.rs | 2 +- .../src/approx_percentile_cont.rs | 2 +- .../functions-aggregate/src/array_agg.rs | 4 +-- datafusion/functions-aggregate/src/average.rs | 4 +-- datafusion/functions-aggregate/src/count.rs | 2 +- .../functions-aggregate/src/first_last.rs | 4 +-- datafusion/functions-aggregate/src/median.rs | 2 +- .../functions-aggregate/src/nth_value.rs | 4 +-- datafusion/functions-aggregate/src/stddev.rs | 2 ++ .../physical-expr-common/src/aggregate/mod.rs | 28 +++++++++++++++---- datafusion/physical-plan/src/windows/mod.rs | 4 +-- datafusion/proto/src/physical_plan/mod.rs | 7 ++--- .../tests/cases/roundtrip_physical_plan.rs | 20 ++++++------- 18 files changed, 58 insertions(+), 43 deletions(-) diff --git a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs index cde8bb241ee4..a0f6f6a65b1f 100644 --- a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs +++ b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs @@ -313,7 +313,6 @@ pub(crate) mod tests { use arrow::datatypes::{DataType, Field, Schema}; use arrow::record_batch::RecordBatch; use datafusion_common::cast::as_int64_array; - use datafusion_common::ToDFSchema; use datafusion_functions_aggregate::count::count_udaf; use datafusion_physical_expr::expressions::cast; use datafusion_physical_expr::PhysicalExpr; @@ -411,7 +410,7 @@ pub(crate) mod tests { // Return appropriate expr depending if COUNT is for col or table (*) pub(crate) fn count_expr(&self, schema: &Schema) -> Arc { AggregateExprBuilder::new(count_udaf(), vec![self.column()]) - .dfschema(schema.clone().to_dfschema().unwrap()) + .schema(Arc::new(schema.clone())) .name(self.column_name()) .build() .unwrap() diff --git a/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs b/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs index ab547b86f582..63e9bac81fba 100644 --- a/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs +++ b/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs @@ -174,7 +174,6 @@ mod tests { use crate::physical_plan::{displayable, Partitioning}; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; - use datafusion_common::ToDFSchema; use datafusion_functions_aggregate::count::count_udaf; use datafusion_functions_aggregate::sum::sum_udaf; use datafusion_physical_expr::expressions::col; @@ -280,7 +279,7 @@ mod tests { schema: &Schema, ) -> Arc { AggregateExprBuilder::new(count_udaf(), vec![expr]) - .dfschema(schema.clone().to_dfschema().unwrap()) + .schema(schema.clone()) .name(name) .build() .unwrap() diff --git a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs index 31fa59af8c18..6f286c9aeba1 100644 --- a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs @@ -32,7 +32,6 @@ use datafusion::physical_plan::memory::MemoryExec; use datafusion::physical_plan::{collect, displayable, ExecutionPlan}; use datafusion::prelude::{DataFrame, SessionConfig, SessionContext}; use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor}; -use datafusion_common::ToDFSchema; use datafusion_functions_aggregate::sum::sum_udaf; use datafusion_physical_expr::expressions::col; use datafusion_physical_expr::PhysicalSortExpr; @@ -107,7 +106,7 @@ async fn run_aggregate_test(input1: Vec, group_by_columns: Vec<&str let aggregate_expr = vec![ AggregateExprBuilder::new(sum_udaf(), vec![col("d", &schema).unwrap()]) - .dfschema(Arc::clone(&schema).to_dfschema().unwrap()) + .schema(Arc::clone(&schema)) .name("sum1") .build() .unwrap(), diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs index 1ed793228ded..2c1e0110ba68 100644 --- a/datafusion/expr/src/function.rs +++ b/datafusion/expr/src/function.rs @@ -20,7 +20,7 @@ use crate::physical_expr::PhysicalExpr; use crate::ColumnarValue; use crate::{Accumulator, Expr, PartitionEvaluator}; -use arrow::datatypes::{DataType, Field}; +use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::{DFSchema, Result}; use std::sync::Arc; @@ -55,6 +55,9 @@ pub struct AccumulatorArgs<'a> { /// The return type of the aggregate function. pub data_type: &'a DataType, + /// The schema of the input arguments + pub schema: &'a Schema, + /// The schema of the input arguments pub dfschema: &'a DFSchema, diff --git a/datafusion/functions-aggregate/src/approx_distinct.rs b/datafusion/functions-aggregate/src/approx_distinct.rs index 26909471fc67..4bcb646291ce 100644 --- a/datafusion/functions-aggregate/src/approx_distinct.rs +++ b/datafusion/functions-aggregate/src/approx_distinct.rs @@ -278,7 +278,7 @@ impl AggregateUDFImpl for ApproxDistinct { fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { let accumulator: Box = - match &acc_args.input_exprs[0].data_type(acc_args.dfschema.as_arrow())? { + match &acc_args.input_exprs[0].data_type(acc_args.schema)? { // TODO u8, i8, u16, i16 shall really be done using bitmap, not HLL // TODO support for boolean (trivial case) // https://github.com/apache/datafusion/issues/1109 diff --git a/datafusion/functions-aggregate/src/approx_median.rs b/datafusion/functions-aggregate/src/approx_median.rs index 85b79cdd5267..f7a7be7723da 100644 --- a/datafusion/functions-aggregate/src/approx_median.rs +++ b/datafusion/functions-aggregate/src/approx_median.rs @@ -113,7 +113,7 @@ impl AggregateUDFImpl for ApproxMedian { Ok(Box::new(ApproxPercentileAccumulator::new( 0.5_f64, - acc_args.input_exprs[0].data_type(acc_args.dfschema.as_arrow())?, + acc_args.input_exprs[0].data_type(acc_args.schema)?, ))) } } diff --git a/datafusion/functions-aggregate/src/approx_percentile_cont.rs b/datafusion/functions-aggregate/src/approx_percentile_cont.rs index 03ab32e5ab03..efb9dc503a32 100644 --- a/datafusion/functions-aggregate/src/approx_percentile_cont.rs +++ b/datafusion/functions-aggregate/src/approx_percentile_cont.rs @@ -104,7 +104,7 @@ impl ApproxPercentileCont { None }; - let accumulator: ApproxPercentileAccumulator = match &args.input_exprs[0].data_type(args.dfschema.as_arrow())? { + let accumulator: ApproxPercentileAccumulator = match &args.input_exprs[0].data_type(args.schema)? { t @ (DataType::UInt8 | DataType::UInt16 | DataType::UInt32 diff --git a/datafusion/functions-aggregate/src/array_agg.rs b/datafusion/functions-aggregate/src/array_agg.rs index abb344eaf693..06750754006e 100644 --- a/datafusion/functions-aggregate/src/array_agg.rs +++ b/datafusion/functions-aggregate/src/array_agg.rs @@ -118,7 +118,7 @@ impl AggregateUDFImpl for ArrayAgg { fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { let data_type = - acc_args.input_exprs[0].data_type(acc_args.dfschema.as_arrow())?; + acc_args.input_exprs[0].data_type(acc_args.schema)?; if acc_args.is_distinct { return Ok(Box::new(DistinctArrayAggAccumulator::try_new(&data_type)?)); @@ -135,7 +135,7 @@ impl AggregateUDFImpl for ArrayAgg { let ordering_dtypes = ordering_req .iter() - .map(|e| e.expr.data_type(acc_args.dfschema.as_arrow())) + .map(|e| e.expr.data_type(acc_args.schema)) .collect::>>()?; OrderSensitiveArrayAggAccumulator::try_new( diff --git a/datafusion/functions-aggregate/src/average.rs b/datafusion/functions-aggregate/src/average.rs index f9d10426df0b..e27acb31a9d4 100644 --- a/datafusion/functions-aggregate/src/average.rs +++ b/datafusion/functions-aggregate/src/average.rs @@ -94,7 +94,7 @@ impl AggregateUDFImpl for Avg { use DataType::*; // instantiate specialized accumulator based for the type let input_type = - acc_args.input_exprs[0].data_type(acc_args.dfschema.as_arrow())?; + acc_args.input_exprs[0].data_type(acc_args.schema)?; match (&input_type, acc_args.data_type) { (Float64, Float64) => Ok(Box::::default()), @@ -157,7 +157,7 @@ impl AggregateUDFImpl for Avg { ) -> Result> { use DataType::*; // instantiate specialized accumulator based for the type - let sum_data_type = &args.input_exprs[0].data_type(args.dfschema.as_arrow())?; + let sum_data_type = &args.input_exprs[0].data_type(args.schema)?; match (sum_data_type, args.data_type) { (Float64, Float64) => { diff --git a/datafusion/functions-aggregate/src/count.rs b/datafusion/functions-aggregate/src/count.rs index 4668f1caf101..6b241fde4e8e 100644 --- a/datafusion/functions-aggregate/src/count.rs +++ b/datafusion/functions-aggregate/src/count.rs @@ -149,7 +149,7 @@ impl AggregateUDFImpl for Count { } let data_type = - &acc_args.input_exprs[0].data_type(acc_args.dfschema.as_arrow())?; + &acc_args.input_exprs[0].data_type(acc_args.schema)?; Ok(match data_type { // try and use a specialized accumulator if possible, otherwise fall back to generic accumulator DataType::Int8 => Box::new( diff --git a/datafusion/functions-aggregate/src/first_last.rs b/datafusion/functions-aggregate/src/first_last.rs index 7563506b3b65..587767b8e356 100644 --- a/datafusion/functions-aggregate/src/first_last.rs +++ b/datafusion/functions-aggregate/src/first_last.rs @@ -124,7 +124,7 @@ impl AggregateUDFImpl for FirstValue { let ordering_dtypes = ordering_req .iter() - .map(|e| e.expr.data_type(acc_args.dfschema.as_arrow())) + .map(|e| e.expr.data_type(acc_args.schema)) .collect::>>()?; // When requirement is empty, or it is signalled by outside caller that @@ -423,7 +423,7 @@ impl AggregateUDFImpl for LastValue { let ordering_dtypes = ordering_req .iter() - .map(|e| e.expr.data_type(acc_args.dfschema.as_arrow())) + .map(|e| e.expr.data_type(acc_args.schema)) .collect::>>()?; let requirement_satisfied = ordering_req.is_empty() || self.requirement_satisfied; diff --git a/datafusion/functions-aggregate/src/median.rs b/datafusion/functions-aggregate/src/median.rs index a0a1dbeb4d3c..fe12055e8712 100644 --- a/datafusion/functions-aggregate/src/median.rs +++ b/datafusion/functions-aggregate/src/median.rs @@ -133,7 +133,7 @@ impl AggregateUDFImpl for Median { }; } - let dt = &acc_args.input_exprs[0].data_type(acc_args.dfschema.as_arrow())?; + let dt = &acc_args.input_exprs[0].data_type(acc_args.schema)?; downcast_integer! { dt => (helper, dt), DataType::Float16 => helper!(Float16Type, dt), diff --git a/datafusion/functions-aggregate/src/nth_value.rs b/datafusion/functions-aggregate/src/nth_value.rs index 6362bdcc9287..af279db3ed8f 100644 --- a/datafusion/functions-aggregate/src/nth_value.rs +++ b/datafusion/functions-aggregate/src/nth_value.rs @@ -116,12 +116,12 @@ impl AggregateUDFImpl for NthValueAgg { let ordering_dtypes = ordering_req .iter() - .map(|e| e.expr.data_type(acc_args.dfschema.as_arrow())) + .map(|e| e.expr.data_type(acc_args.schema)) .collect::>>()?; NthValueAccumulator::try_new( n, - &acc_args.input_exprs[0].data_type(acc_args.dfschema.as_arrow())?, + &acc_args.input_exprs[0].data_type(acc_args.schema)?, &ordering_dtypes, ordering_req, ) diff --git a/datafusion/functions-aggregate/src/stddev.rs b/datafusion/functions-aggregate/src/stddev.rs index caa27d059729..859f42262eea 100644 --- a/datafusion/functions-aggregate/src/stddev.rs +++ b/datafusion/functions-aggregate/src/stddev.rs @@ -328,6 +328,7 @@ mod tests { let dfschema = DFSchema::empty(); let args1 = AccumulatorArgs { data_type: &DataType::Float64, + schema: &schema, dfschema: &dfschema, ignore_nulls: false, sort_exprs: &[], @@ -339,6 +340,7 @@ mod tests { let args2 = AccumulatorArgs { data_type: &DataType::Float64, + schema: &schema, dfschema: &dfschema, ignore_nulls: false, sort_exprs: &[], diff --git a/datafusion/physical-expr-common/src/aggregate/mod.rs b/datafusion/physical-expr-common/src/aggregate/mod.rs index acf1d3c4aef6..8d5d0ffaa793 100644 --- a/datafusion/physical-expr-common/src/aggregate/mod.rs +++ b/datafusion/physical-expr-common/src/aggregate/mod.rs @@ -18,9 +18,9 @@ use std::fmt::Debug; use std::{any::Any, sync::Arc}; -use arrow::datatypes::{DataType, Field, Schema}; +use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; -use datafusion_common::{exec_err, ToDFSchema}; +use datafusion_common::exec_err; use datafusion_common::{internal_err, not_impl_err, DFSchema, Result}; use datafusion_expr::function::StateFieldsArgs; use datafusion_expr::type_coercion::aggregates::check_arg_count; @@ -76,7 +76,7 @@ pub fn create_aggregate_expr( builder = builder.sort_exprs(sort_exprs.to_vec()); builder = builder.order_by(ordering_req.to_vec()); builder = builder.logical_exprs(input_exprs.to_vec()); - builder = builder.dfschema(Arc::new(schema.clone()).to_dfschema()?); + builder = builder.schema(Arc::new(schema.clone())); builder = builder.name(name); if ignore_nulls { @@ -108,7 +108,8 @@ pub fn create_aggregate_expr_with_dfschema( builder = builder.sort_exprs(sort_exprs.to_vec()); builder = builder.order_by(ordering_req.to_vec()); builder = builder.logical_exprs(input_exprs.to_vec()); - builder = builder.dfschema(dfschema.clone()); + let schema: Schema = dfschema.into(); + builder = builder.schema(Arc::new(schema)); builder = builder.name(name); if ignore_nulls { @@ -136,6 +137,8 @@ pub struct AggregateExprBuilder { /// Logical expressions of the aggregate function, it will be deprecated in logical_args: Vec, name: String, + /// Arrow Schema for the aggregate function + schema: SchemaRef, /// Datafusion Schema for the aggregate function dfschema: DFSchema, /// The logical order by expressions, it will be deprecated in @@ -157,6 +160,7 @@ impl AggregateExprBuilder { args, logical_args: vec![], name: String::new(), + schema: Arc::new(Schema::empty()), dfschema: DFSchema::empty(), sort_exprs: vec![], ordering_req: vec![], @@ -172,6 +176,7 @@ impl AggregateExprBuilder { args, logical_args, name, + schema, dfschema, sort_exprs, ordering_req, @@ -189,7 +194,7 @@ impl AggregateExprBuilder { if !ordering_req.is_empty() { let ordering_types = ordering_req .iter() - .map(|e| e.expr.data_type(dfschema.as_arrow())) + .map(|e| e.expr.data_type(&schema)) .collect::>>()?; ordering_fields = utils::ordering_fields(&ordering_req, &ordering_types); @@ -197,7 +202,7 @@ impl AggregateExprBuilder { let input_exprs_types = args .iter() - .map(|arg| arg.data_type(dfschema.as_arrow())) + .map(|arg| arg.data_type(&schema)) .collect::>>()?; check_arg_count( @@ -214,6 +219,7 @@ impl AggregateExprBuilder { logical_args, data_type, name, + schema: Arc::unwrap_or_clone(schema), dfschema, sort_exprs, ordering_req, @@ -230,6 +236,11 @@ impl AggregateExprBuilder { self } + pub fn schema(mut self, schema: SchemaRef) -> Self { + self.schema = schema; + self + } + pub fn dfschema(mut self, dfschema: DFSchema) -> Self { self.dfschema = dfschema; self @@ -444,6 +455,7 @@ pub struct AggregateFunctionExpr { /// Output / return type of this aggregate data_type: DataType, name: String, + schema: Schema, dfschema: DFSchema, // The logical order by expressions sort_exprs: Vec, @@ -509,6 +521,7 @@ impl AggregateExpr for AggregateFunctionExpr { fn create_accumulator(&self) -> Result> { let acc_args = AccumulatorArgs { data_type: &self.data_type, + schema: &self.schema, dfschema: &self.dfschema, ignore_nulls: self.ignore_nulls, sort_exprs: &self.sort_exprs, @@ -524,6 +537,7 @@ impl AggregateExpr for AggregateFunctionExpr { fn create_sliding_accumulator(&self) -> Result> { let args = AccumulatorArgs { data_type: &self.data_type, + schema: &self.schema, dfschema: &self.dfschema, ignore_nulls: self.ignore_nulls, sort_exprs: &self.sort_exprs, @@ -594,6 +608,7 @@ impl AggregateExpr for AggregateFunctionExpr { fn groups_accumulator_supported(&self) -> bool { let args = AccumulatorArgs { data_type: &self.data_type, + schema: &self.schema, dfschema: &self.dfschema, ignore_nulls: self.ignore_nulls, sort_exprs: &self.sort_exprs, @@ -608,6 +623,7 @@ impl AggregateExpr for AggregateFunctionExpr { fn create_groups_accumulator(&self) -> Result> { let args = AccumulatorArgs { data_type: &self.data_type, + schema: &self.schema, dfschema: &self.dfschema, ignore_nulls: self.ignore_nulls, sort_exprs: &self.sort_exprs, diff --git a/datafusion/physical-plan/src/windows/mod.rs b/datafusion/physical-plan/src/windows/mod.rs index 2bd33a476c2f..70e11498c88f 100644 --- a/datafusion/physical-plan/src/windows/mod.rs +++ b/datafusion/physical-plan/src/windows/mod.rs @@ -30,7 +30,7 @@ use crate::{ use arrow::datatypes::Schema; use arrow_schema::{DataType, Field, SchemaRef}; -use datafusion_common::{exec_err, DataFusionError, Result, ScalarValue, ToDFSchema}; +use datafusion_common::{exec_err, DataFusionError, Result, ScalarValue}; use datafusion_expr::{col, Expr, SortExpr}; use datafusion_expr::{ BuiltInWindowFunction, PartitionEvaluator, WindowFrame, WindowFunctionDefinition, @@ -127,7 +127,7 @@ pub fn create_window_expr( .collect::>(); let aggregate = AggregateExprBuilder::new(Arc::clone(fun), args.to_vec()) - .dfschema(Arc::new(input_schema.clone()).to_dfschema()?) + .schema(Arc::new(input_schema.clone())) .name(name) .order_by(order_by.to_vec()) .sort_exprs(sort_exprs) diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index a79eafe43846..8932cb883e26 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -61,9 +61,7 @@ use datafusion::physical_plan::windows::{BoundedWindowAggExec, WindowAggExec}; use datafusion::physical_plan::{ AggregateExpr, ExecutionPlan, InputOrderMode, PhysicalExpr, WindowExpr, }; -use datafusion_common::{ - internal_err, not_impl_err, DataFusionError, Result, ToDFSchema, -}; +use datafusion_common::{internal_err, not_impl_err, DataFusionError, Result}; use datafusion_expr::{AggregateUDF, ScalarUDF}; use crate::common::{byte_to_string, str_to_byte}; @@ -491,8 +489,7 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { // TODO: approx_percentile_cont and approx_percentile_cont_weight are not supported for UDAF from protobuf yet. // TODO: `order by` is not supported for UDAF yet - AggregateExprBuilder::new(agg_udf, input_phy_expr).dfschema(Arc::clone(&physical_schema).to_dfschema()?).name(name).with_ignore_nulls(agg_node.ignore_nulls).with_distinct(agg_node.distinct).build() - } + AggregateExprBuilder::new(agg_udf, input_phy_expr).schema(Arc::clone(&physical_schema)).name(name).with_ignore_nulls(agg_node.ignore_nulls).with_distinct(agg_node.distinct).build() } } }).transpose()?.ok_or_else(|| { proto_error("Invalid AggregateExpr, missing aggregate_function") diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index 9c180e219b5b..213a5590b742 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -82,7 +82,7 @@ use datafusion_common::file_options::json_writer::JsonWriterOptions; use datafusion_common::parsers::CompressionTypeVariant; use datafusion_common::stats::Precision; use datafusion_common::{ - internal_err, not_impl_err, DataFusionError, Result, ToDFSchema, + internal_err, not_impl_err, DataFusionError, Result, }; use datafusion_expr::{ Accumulator, AccumulatorFactoryFunction, AggregateUDF, ColumnarValue, ScalarUDF, @@ -297,7 +297,7 @@ fn roundtrip_window() -> Result<()> { avg_udaf(), vec![cast(col("b", &schema)?, &schema, DataType::Float64)?], ) - .dfschema(Arc::clone(&schema).to_dfschema()?) + .schema(Arc::clone(&schema)) .name("avg(b)") .build()?, &[], @@ -313,7 +313,7 @@ fn roundtrip_window() -> Result<()> { let args = vec![cast(col("a", &schema)?, &schema, DataType::Float64)?]; let sum_expr = AggregateExprBuilder::new(sum_udaf(), args) - .dfschema(Arc::clone(&schema).to_dfschema()?) + .schema(Arc::clone(&schema)) .name("SUM(a) RANGE BETWEEN CURRENT ROW AND UNBOUNDED PRECEEDING") .build()?; @@ -347,17 +347,17 @@ fn rountrip_aggregate() -> Result<()> { vec![(col("a", &schema)?, "unused".to_string())]; let avg_expr = AggregateExprBuilder::new(avg_udaf(), vec![col("b", &schema)?]) - .dfschema(Arc::clone(&schema).to_dfschema()?) + .schema(Arc::clone(&schema)) .name("AVG(b)") .build()?; let nth_expr = AggregateExprBuilder::new(nth_value_udaf(), vec![col("b", &schema)?, lit(1u64)]) - .dfschema(Arc::clone(&schema).to_dfschema()?) + .schema(Arc::clone(&schema)) .name("NTH_VALUE(b, 1)") .build()?; let str_agg_expr = AggregateExprBuilder::new(string_agg_udaf(), vec![col("b", &schema)?, lit(1u64)]) - .dfschema(Arc::clone(&schema).to_dfschema()?) + .schema(Arc::clone(&schema)) .name("NTH_VALUE(b, 1)") .build()?; @@ -397,7 +397,7 @@ fn rountrip_aggregate_with_limit() -> Result<()> { let aggregates: Vec> = vec![ AggregateExprBuilder::new(avg_udaf(), vec![col("b", &schema)?]) - .dfschema(Arc::clone(&schema).to_dfschema()?) + .schema(Arc::clone(&schema)) .name("AVG(b)") .build()?, ]; @@ -464,7 +464,7 @@ fn roundtrip_aggregate_udaf() -> Result<()> { let aggregates: Vec> = vec![ AggregateExprBuilder::new(Arc::new(udaf), vec![col("b", &schema)?]) - .dfschema(Arc::clone(&schema).to_dfschema()?) + .schema(Arc::clone(&schema)) .name("example_agg") .build()?, ]; @@ -966,7 +966,7 @@ fn roundtrip_aggregate_udf_extension_codec() -> Result<()> { vec![Arc::new(Literal::new(ScalarValue::from(42)))]; let aggr_expr = AggregateExprBuilder::new(Arc::clone(&udaf), aggr_args.clone()) - .dfschema(Arc::clone(&schema).to_dfschema()?) + .schema(Arc::clone(&schema)) .name("aggregate_udf") .build()?; @@ -991,7 +991,7 @@ fn roundtrip_aggregate_udf_extension_codec() -> Result<()> { )?); let aggr_expr = AggregateExprBuilder::new(udaf, aggr_args.clone()) - .dfschema(Arc::clone(&schema).to_dfschema()?) + .schema(Arc::clone(&schema)) .name("aggregate_udf") .distinct() .ignore_nulls() From 1ccc74042125fdf8bbb7d0999ff5f750ec6c438c Mon Sep 17 00:00:00 2001 From: Xin Li Date: Sun, 4 Aug 2024 19:46:30 +0800 Subject: [PATCH 5/7] revert schema change v2 --- datafusion/functions-aggregate/src/array_agg.rs | 3 +-- datafusion/functions-aggregate/src/average.rs | 3 +-- datafusion/functions-aggregate/src/count.rs | 3 +-- datafusion/physical-plan/src/aggregates/mod.rs | 16 ++++++++-------- datafusion/physical-plan/src/windows/mod.rs | 3 +-- .../proto/tests/cases/roundtrip_physical_plan.rs | 4 +--- 6 files changed, 13 insertions(+), 19 deletions(-) diff --git a/datafusion/functions-aggregate/src/array_agg.rs b/datafusion/functions-aggregate/src/array_agg.rs index 06750754006e..35e7c417ade1 100644 --- a/datafusion/functions-aggregate/src/array_agg.rs +++ b/datafusion/functions-aggregate/src/array_agg.rs @@ -117,8 +117,7 @@ impl AggregateUDFImpl for ArrayAgg { } fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { - let data_type = - acc_args.input_exprs[0].data_type(acc_args.schema)?; + let data_type = acc_args.input_exprs[0].data_type(acc_args.schema)?; if acc_args.is_distinct { return Ok(Box::new(DistinctArrayAggAccumulator::try_new(&data_type)?)); diff --git a/datafusion/functions-aggregate/src/average.rs b/datafusion/functions-aggregate/src/average.rs index e27acb31a9d4..77a6c4f7457e 100644 --- a/datafusion/functions-aggregate/src/average.rs +++ b/datafusion/functions-aggregate/src/average.rs @@ -93,8 +93,7 @@ impl AggregateUDFImpl for Avg { } use DataType::*; // instantiate specialized accumulator based for the type - let input_type = - acc_args.input_exprs[0].data_type(acc_args.schema)?; + let input_type = acc_args.input_exprs[0].data_type(acc_args.schema)?; match (&input_type, acc_args.data_type) { (Float64, Float64) => Ok(Box::::default()), diff --git a/datafusion/functions-aggregate/src/count.rs b/datafusion/functions-aggregate/src/count.rs index 6b241fde4e8e..51c58b73f181 100644 --- a/datafusion/functions-aggregate/src/count.rs +++ b/datafusion/functions-aggregate/src/count.rs @@ -148,8 +148,7 @@ impl AggregateUDFImpl for Count { return not_impl_err!("COUNT DISTINCT with multiple arguments"); } - let data_type = - &acc_args.input_exprs[0].data_type(acc_args.schema)?; + let data_type = &acc_args.input_exprs[0].data_type(acc_args.schema)?; Ok(match data_type { // try and use a specialized accumulator if possible, otherwise fall back to generic accumulator DataType::Int8 => Box::new( diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index 9d7f45603464..e54cd5c6ae96 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -1192,7 +1192,7 @@ mod tests { use arrow_array::{Float32Array, Int32Array}; use datafusion_common::{ assert_batches_eq, assert_batches_sorted_eq, internal_err, DFSchema, DFSchemaRef, - DataFusionError, ScalarValue, ToDFSchema, + DataFusionError, ScalarValue, }; use datafusion_execution::config::SessionConfig; use datafusion_execution::memory_pool::FairSpillPool; @@ -1349,7 +1349,7 @@ mod tests { }; let aggregates = vec![AggregateExprBuilder::new(count_udaf(), vec![lit(1i8)]) - .dfschema(Arc::clone(&input_schema).to_dfschema()?) + .schema(Arc::clone(&input_schema)) .name("COUNT(1)") .logical_exprs(vec![datafusion_expr::lit(1i8)]) .build()?]; @@ -1494,7 +1494,7 @@ mod tests { let aggregates: Vec> = vec![ AggregateExprBuilder::new(avg_udaf(), vec![col("b", &input_schema)?]) - .dfschema(Arc::clone(&input_schema).to_dfschema()?) + .schema(Arc::clone(&input_schema)) .name("AVG(b)") .build()?, ]; @@ -1790,7 +1790,7 @@ mod tests { // Median(a) fn test_median_agg_expr(schema: SchemaRef) -> Result> { AggregateExprBuilder::new(median_udaf(), vec![col("a", &schema)?]) - .dfschema(schema.to_dfschema()?) + .schema(Arc::clone(&schema)) .name("MEDIAN(a)") .build() } @@ -1821,7 +1821,7 @@ mod tests { let aggregates_v2: Vec> = vec![ AggregateExprBuilder::new(avg_udaf(), vec![col("b", &input_schema)?]) - .dfschema(Arc::clone(&input_schema).to_dfschema()?) + .schema(Arc::clone(&input_schema)) .name("AVG(b)") .build()?, ]; @@ -1881,7 +1881,7 @@ mod tests { let aggregates: Vec> = vec![ AggregateExprBuilder::new(avg_udaf(), vec![col("a", &schema)?]) - .dfschema(Arc::clone(&schema).to_dfschema()?) + .schema(Arc::clone(&schema)) .name("AVG(a)") .build()?, ]; @@ -1921,7 +1921,7 @@ mod tests { let aggregates: Vec> = vec![ AggregateExprBuilder::new(avg_udaf(), vec![col("b", &schema)?]) - .dfschema(Arc::clone(&schema).to_dfschema()?) + .schema(Arc::clone(&schema)) .name("AVG(b)") .build()?, ]; @@ -2350,7 +2350,7 @@ mod tests { let aggregates: Vec> = vec![AggregateExprBuilder::new(count_udaf(), vec![lit(1)]) - .dfschema(Arc::clone(&schema).to_dfschema()?) + .schema(Arc::clone(&schema)) .name("1") .build()?]; diff --git a/datafusion/physical-plan/src/windows/mod.rs b/datafusion/physical-plan/src/windows/mod.rs index 70e11498c88f..65cef28efc45 100644 --- a/datafusion/physical-plan/src/windows/mod.rs +++ b/datafusion/physical-plan/src/windows/mod.rs @@ -395,7 +395,6 @@ impl BuiltInWindowFunctionExpr for WindowUDFExpr { } } -#[allow(clippy::needless_borrow)] pub(crate) fn calc_requirements< T: Borrow>, S: Borrow, @@ -413,7 +412,7 @@ pub(crate) fn calc_requirements< let PhysicalSortExpr { expr, options } = element.borrow(); if !sort_reqs.iter().any(|e| e.expr.eq(expr)) { sort_reqs.push(PhysicalSortRequirement::new( - Arc::clone(&expr), + Arc::clone(expr), Some(*options), )); } diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index 213a5590b742..0e2bc9cbb3e2 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -81,9 +81,7 @@ use datafusion_common::file_options::csv_writer::CsvWriterOptions; use datafusion_common::file_options::json_writer::JsonWriterOptions; use datafusion_common::parsers::CompressionTypeVariant; use datafusion_common::stats::Precision; -use datafusion_common::{ - internal_err, not_impl_err, DataFusionError, Result, -}; +use datafusion_common::{internal_err, not_impl_err, DataFusionError, Result}; use datafusion_expr::{ Accumulator, AccumulatorFactoryFunction, AggregateUDF, ColumnarValue, ScalarUDF, Signature, SimpleAggregateUDF, WindowFrame, WindowFrameBound, From 4d50a944e80da093948217ea9d84606acf4ffb4f Mon Sep 17 00:00:00 2001 From: Xin Li Date: Mon, 5 Aug 2024 12:38:18 +0800 Subject: [PATCH 6/7] revert schema change v3 --- .../core/src/physical_optimizer/combine_partial_final_agg.rs | 4 ++-- datafusion/physical-expr-common/src/aggregate/mod.rs | 1 + datafusion/physical-plan/src/aggregates/mod.rs | 2 +- datafusion/proto/src/physical_plan/mod.rs | 3 ++- 4 files changed, 6 insertions(+), 4 deletions(-) diff --git a/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs b/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs index 63e9bac81fba..6f3274820c8c 100644 --- a/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs +++ b/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs @@ -279,7 +279,7 @@ mod tests { schema: &Schema, ) -> Arc { AggregateExprBuilder::new(count_udaf(), vec![expr]) - .schema(schema.clone()) + .schema(Arc::new(schema.clone())) .name(name) .build() .unwrap() @@ -363,7 +363,7 @@ mod tests { let aggr_expr = vec![ AggregateExprBuilder::new(sum_udaf(), vec![col("b", &schema)?]) - .dfschema(Arc::clone(&schema).to_dfschema()?) + .schema(Arc::clone(&schema)) .name("Sum(b)") .build() .unwrap(), diff --git a/datafusion/physical-expr-common/src/aggregate/mod.rs b/datafusion/physical-expr-common/src/aggregate/mod.rs index 8d5d0ffaa793..ec830e6e97e1 100644 --- a/datafusion/physical-expr-common/src/aggregate/mod.rs +++ b/datafusion/physical-expr-common/src/aggregate/mod.rs @@ -108,6 +108,7 @@ pub fn create_aggregate_expr_with_dfschema( builder = builder.sort_exprs(sort_exprs.to_vec()); builder = builder.order_by(ordering_req.to_vec()); builder = builder.logical_exprs(input_exprs.to_vec()); + builder = builder.dfschema(dfschema.clone()); let schema: Schema = dfschema.into(); builder = builder.schema(Arc::new(schema)); builder = builder.name(name); diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index e54cd5c6ae96..43f9f98283bb 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -1790,7 +1790,7 @@ mod tests { // Median(a) fn test_median_agg_expr(schema: SchemaRef) -> Result> { AggregateExprBuilder::new(median_udaf(), vec![col("a", &schema)?]) - .schema(Arc::clone(&schema)) + .schema(schema) .name("MEDIAN(a)") .build() } diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index 8932cb883e26..fbb9e442980b 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -489,7 +489,8 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { // TODO: approx_percentile_cont and approx_percentile_cont_weight are not supported for UDAF from protobuf yet. // TODO: `order by` is not supported for UDAF yet - AggregateExprBuilder::new(agg_udf, input_phy_expr).schema(Arc::clone(&physical_schema)).name(name).with_ignore_nulls(agg_node.ignore_nulls).with_distinct(agg_node.distinct).build() } + AggregateExprBuilder::new(agg_udf, input_phy_expr).schema(Arc::clone(&physical_schema)).name(name).with_ignore_nulls(agg_node.ignore_nulls).with_distinct(agg_node.distinct).build() + } } }).transpose()?.ok_or_else(|| { proto_error("Invalid AggregateExpr, missing aggregate_function") From ca4029954a414c8e73cd2136d4b61321c29eb602 Mon Sep 17 00:00:00 2001 From: Xin Li Date: Mon, 5 Aug 2024 12:53:28 +0800 Subject: [PATCH 7/7] fix clippy --- datafusion/functions-aggregate/src/stddev.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/datafusion/functions-aggregate/src/stddev.rs b/datafusion/functions-aggregate/src/stddev.rs index 859f42262eea..1e30c81f76aa 100644 --- a/datafusion/functions-aggregate/src/stddev.rs +++ b/datafusion/functions-aggregate/src/stddev.rs @@ -328,7 +328,7 @@ mod tests { let dfschema = DFSchema::empty(); let args1 = AccumulatorArgs { data_type: &DataType::Float64, - schema: &schema, + schema, dfschema: &dfschema, ignore_nulls: false, sort_exprs: &[], @@ -340,7 +340,7 @@ mod tests { let args2 = AccumulatorArgs { data_type: &DataType::Float64, - schema: &schema, + schema, dfschema: &dfschema, ignore_nulls: false, sort_exprs: &[],