diff --git a/datafusion-examples/examples/advanced_udaf.rs b/datafusion-examples/examples/advanced_udaf.rs index c1ee946df77e8..7b1d3e94b2efe 100644 --- a/datafusion-examples/examples/advanced_udaf.rs +++ b/datafusion-examples/examples/advanced_udaf.rs @@ -25,6 +25,7 @@ use arrow::array::{ }; use arrow::datatypes::{ArrowNativeTypeOp, ArrowPrimitiveType, Float64Type, UInt32Type}; use arrow::record_batch::RecordBatch; +use arrow_schema::FieldRef; use datafusion::common::{cast::as_float64_array, ScalarValue}; use datafusion::error::Result; use datafusion::logical_expr::{ @@ -92,10 +93,10 @@ impl AggregateUDFImpl for GeoMeanUdaf { } /// This is the description of the state. accumulator's state() must match the types here. - fn state_fields(&self, args: StateFieldsArgs) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { Ok(vec![ - Field::new("prod", args.return_type().clone(), true), - Field::new("n", DataType::UInt32, true), + Field::new("prod", args.return_type().clone(), true).into(), + Field::new("n", DataType::UInt32, true).into(), ]) } @@ -401,7 +402,7 @@ impl AggregateUDFImpl for SimplifiedGeoMeanUdaf { unimplemented!("should not be invoked") } - fn state_fields(&self, _args: StateFieldsArgs) -> Result> { + fn state_fields(&self, _args: StateFieldsArgs) -> Result> { unimplemented!("should not be invoked") } diff --git a/datafusion-examples/examples/advanced_udwf.rs b/datafusion-examples/examples/advanced_udwf.rs index 8330e783319d5..4f00e04e7e993 100644 --- a/datafusion-examples/examples/advanced_udwf.rs +++ b/datafusion-examples/examples/advanced_udwf.rs @@ -23,6 +23,7 @@ use arrow::{ array::{ArrayRef, AsArray, Float64Array}, datatypes::Float64Type, }; +use arrow_schema::FieldRef; use datafusion::common::ScalarValue; use datafusion::error::Result; use datafusion::functions_aggregate::average::avg_udaf; @@ -87,8 +88,8 @@ impl WindowUDFImpl for SmoothItUdf { Ok(Box::new(MyPartitionEvaluator::new())) } - fn field(&self, field_args: 
WindowUDFFieldArgs) -> Result { - Ok(Field::new(field_args.name(), DataType::Float64, true)) + fn field(&self, field_args: WindowUDFFieldArgs) -> Result { + Ok(Field::new(field_args.name(), DataType::Float64, true).into()) } } @@ -205,8 +206,8 @@ impl WindowUDFImpl for SimplifySmoothItUdf { Some(Box::new(simplify)) } - fn field(&self, field_args: WindowUDFFieldArgs) -> Result { - Ok(Field::new(field_args.name(), DataType::Float64, true)) + fn field(&self, field_args: WindowUDFFieldArgs) -> Result { + Ok(Field::new(field_args.name(), DataType::Float64, true).into()) } } diff --git a/datafusion/core/tests/physical_optimizer/projection_pushdown.rs b/datafusion/core/tests/physical_optimizer/projection_pushdown.rs index e00a44188e575..7c00d323a8e69 100644 --- a/datafusion/core/tests/physical_optimizer/projection_pushdown.rs +++ b/datafusion/core/tests/physical_optimizer/projection_pushdown.rs @@ -128,7 +128,7 @@ fn test_update_matching_exprs() -> Result<()> { Arc::new(Column::new("b", 1)), )), ], - Field::new("f", DataType::Int32, true), + Field::new("f", DataType::Int32, true).into(), )), Arc::new(CaseExpr::try_new( Some(Arc::new(Column::new("d", 2))), @@ -193,7 +193,7 @@ fn test_update_matching_exprs() -> Result<()> { Arc::new(Column::new("b", 1)), )), ], - Field::new("f", DataType::Int32, true), + Field::new("f", DataType::Int32, true).into(), )), Arc::new(CaseExpr::try_new( Some(Arc::new(Column::new("d", 3))), @@ -261,7 +261,7 @@ fn test_update_projected_exprs() -> Result<()> { Arc::new(Column::new("b", 1)), )), ], - Field::new("f", DataType::Int32, true), + Field::new("f", DataType::Int32, true).into(), )), Arc::new(CaseExpr::try_new( Some(Arc::new(Column::new("d", 2))), @@ -326,7 +326,7 @@ fn test_update_projected_exprs() -> Result<()> { Arc::new(Column::new("b_new", 1)), )), ], - Field::new("f", DataType::Int32, true), + Field::new("f", DataType::Int32, true).into(), )), Arc::new(CaseExpr::try_new( Some(Arc::new(Column::new("d_new", 3))), diff --git 
a/datafusion/core/tests/user_defined/user_defined_aggregates.rs b/datafusion/core/tests/user_defined/user_defined_aggregates.rs index 203fb6e85237e..ae517795ab955 100644 --- a/datafusion/core/tests/user_defined/user_defined_aggregates.rs +++ b/datafusion/core/tests/user_defined/user_defined_aggregates.rs @@ -32,6 +32,7 @@ use arrow::array::{ StringArray, StructArray, UInt64Array, }; use arrow::datatypes::{Fields, Schema}; +use arrow_schema::FieldRef; use datafusion::common::test_util::batches_to_string; use datafusion::dataframe::DataFrame; use datafusion::datasource::MemTable; @@ -572,7 +573,7 @@ impl TimeSum { // Returns the same type as its input let return_type = timestamp_type.clone(); - let state_fields = vec![Field::new("sum", timestamp_type, true)]; + let state_fields = vec![Field::new("sum", timestamp_type, true).into()]; let volatility = Volatility::Immutable; @@ -672,7 +673,7 @@ impl FirstSelector { let state_fields = state_type .into_iter() .enumerate() - .map(|(i, t)| Field::new(format!("{i}"), t, true)) + .map(|(i, t)| Field::new(format!("{i}"), t, true).into()) .collect::>(); // Possible input signatures @@ -932,9 +933,10 @@ impl AggregateUDFImpl for MetadataBasedAggregateUdf { unimplemented!("this should never be called since return_field is implemented"); } - fn return_field(&self, _arg_fields: &[Field]) -> Result { + fn return_field(&self, _arg_fields: &[FieldRef]) -> Result { Ok(Field::new(self.name(), DataType::UInt64, true) - .with_metadata(self.metadata.clone())) + .with_metadata(self.metadata.clone()) + .into()) } fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { diff --git a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs index b5960ae5bd8d9..25458efa4fa55 100644 --- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs @@ -28,7 +28,7 @@ use 
arrow::array::{ use arrow::compute::kernels::numeric::add; use arrow::datatypes::{DataType, Field, Schema}; use arrow_schema::extension::{Bool8, CanonicalExtensionType, ExtensionType}; -use arrow_schema::ArrowError; +use arrow_schema::{ArrowError, FieldRef}; use datafusion::common::test_util::batches_to_string; use datafusion::execution::context::{FunctionFactory, RegisterFunction, SessionState}; use datafusion::prelude::*; @@ -814,7 +814,7 @@ impl ScalarUDFImpl for TakeUDF { /// /// 1. If the third argument is '0', return the type of the first argument /// 2. If the third argument is '1', return the type of the second argument - fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { if args.arg_fields.len() != 3 { return plan_err!("Expected 3 arguments, got {}.", args.arg_fields.len()); } @@ -845,7 +845,8 @@ impl ScalarUDFImpl for TakeUDF { self.name(), args.arg_fields[take_idx].data_type().to_owned(), true, - )) + ) + .into()) } // The actual implementation @@ -1412,9 +1413,10 @@ impl ScalarUDFImpl for MetadataBasedUdf { ); } - fn return_field_from_args(&self, _args: ReturnFieldArgs) -> Result { + fn return_field_from_args(&self, _args: ReturnFieldArgs) -> Result { Ok(Field::new(self.name(), DataType::UInt64, true) - .with_metadata(self.metadata.clone())) + .with_metadata(self.metadata.clone()) + .into()) } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { @@ -1562,14 +1564,15 @@ impl ScalarUDFImpl for ExtensionBasedUdf { Ok(DataType::Utf8) } - fn return_field_from_args(&self, _args: ReturnFieldArgs) -> Result { + fn return_field_from_args(&self, _args: ReturnFieldArgs) -> Result { Ok(Field::new("canonical_extension_udf", DataType::Utf8, true) - .with_extension_type(MyUserExtentionType {})) + .with_extension_type(MyUserExtentionType {}) + .into()) } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { assert_eq!(args.arg_fields.len(), 1); - let 
input_field = args.arg_fields[0]; + let input_field = args.arg_fields[0].as_ref(); let output_as_bool = matches!( CanonicalExtensionType::try_from(input_field), diff --git a/datafusion/core/tests/user_defined/user_defined_window_functions.rs b/datafusion/core/tests/user_defined/user_defined_window_functions.rs index 6798c0d308de7..bcd2c3945e392 100644 --- a/datafusion/core/tests/user_defined/user_defined_window_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_window_functions.rs @@ -23,6 +23,7 @@ use arrow::array::{ UInt64Array, }; use arrow::datatypes::{DataType, Field, Schema}; +use arrow_schema::FieldRef; use datafusion::common::test_util::batches_to_string; use datafusion::common::{Result, ScalarValue}; use datafusion::prelude::SessionContext; @@ -564,8 +565,8 @@ impl OddCounter { &self.aliases } - fn field(&self, field_args: WindowUDFFieldArgs) -> Result { - Ok(Field::new(field_args.name(), DataType::Int64, true)) + fn field(&self, field_args: WindowUDFFieldArgs) -> Result { + Ok(Field::new(field_args.name(), DataType::Int64, true).into()) } } @@ -683,7 +684,7 @@ impl WindowUDFImpl for VariadicWindowUDF { unimplemented!("unnecessary for testing"); } - fn field(&self, _: WindowUDFFieldArgs) -> Result { + fn field(&self, _: WindowUDFFieldArgs) -> Result { unimplemented!("unnecessary for testing"); } } @@ -809,9 +810,10 @@ impl WindowUDFImpl for MetadataBasedWindowUdf { Ok(Box::new(MetadataBasedPartitionEvaluator { double_output })) } - fn field(&self, field_args: WindowUDFFieldArgs) -> Result { + fn field(&self, field_args: WindowUDFFieldArgs) -> Result { Ok(Field::new(field_args.name(), DataType::UInt64, true) - .with_metadata(self.metadata.clone())) + .with_metadata(self.metadata.clone()) + .into()) } } diff --git a/datafusion/expr-common/src/type_coercion/aggregates.rs b/datafusion/expr-common/src/type_coercion/aggregates.rs index 7da4e938f5dd6..e9377ce7de5a2 100644 --- a/datafusion/expr-common/src/type_coercion/aggregates.rs +++ 
b/datafusion/expr-common/src/type_coercion/aggregates.rs @@ -17,7 +17,7 @@ use crate::signature::TypeSignature; use arrow::datatypes::{ - DataType, Field, TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE, + DataType, FieldRef, TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE, DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE, }; @@ -89,7 +89,7 @@ pub static TIMES: &[DataType] = &[ /// number of input types. pub fn check_arg_count( func_name: &str, - input_fields: &[Field], + input_fields: &[FieldRef], signature: &TypeSignature, ) -> Result<()> { match signature { diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index a081a5430d409..fe5ea2ecd5b8b 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -28,7 +28,7 @@ use crate::logical_plan::Subquery; use crate::Volatility; use crate::{udaf, ExprSchemable, Operator, Signature, WindowFrame, WindowUDF}; -use arrow::datatypes::{DataType, Field, FieldRef}; +use arrow::datatypes::{DataType, FieldRef}; use datafusion_common::cse::{HashNode, NormalizeEq, Normalizeable}; use datafusion_common::tree_node::{ Transformed, TransformedResult, TreeNode, TreeNodeContainer, TreeNodeRecursion, @@ -846,10 +846,10 @@ impl WindowFunctionDefinition { /// Returns the datatype of the window function pub fn return_field( &self, - input_expr_fields: &[Field], + input_expr_fields: &[FieldRef], _input_expr_nullable: &[bool], display_name: &str, - ) -> Result { + ) -> Result { match self { WindowFunctionDefinition::AggregateUDF(fun) => { fun.return_field(input_expr_fields) diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index cee356a2b42cf..67e80a8d9bba9 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -37,7 +37,7 @@ use crate::{ use arrow::compute::kernels::cast_utils::{ parse_interval_day_time, parse_interval_month_day_nano, parse_interval_year_month, }; -use arrow::datatypes::{DataType, Field}; +use 
arrow::datatypes::{DataType, Field, FieldRef}; use datafusion_common::{plan_err, Column, Result, ScalarValue, Spans, TableReference}; use datafusion_functions_window_common::field::WindowUDFFieldArgs; use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; @@ -492,6 +492,7 @@ pub fn create_udaf( .into_iter() .enumerate() .map(|(i, t)| Field::new(format!("{i}"), t, true)) + .map(Arc::new) .collect::>(); AggregateUDF::from(SimpleAggregateUDF::new( name, @@ -510,7 +511,7 @@ pub struct SimpleAggregateUDF { signature: Signature, return_type: DataType, accumulator: AccumulatorFactoryFunction, - state_fields: Vec, + state_fields: Vec, } impl Debug for SimpleAggregateUDF { @@ -533,7 +534,7 @@ impl SimpleAggregateUDF { return_type: DataType, volatility: Volatility, accumulator: AccumulatorFactoryFunction, - state_fields: Vec, + state_fields: Vec, ) -> Self { let name = name.into(); let signature = Signature::exact(input_type, volatility); @@ -553,7 +554,7 @@ impl SimpleAggregateUDF { signature: Signature, return_type: DataType, accumulator: AccumulatorFactoryFunction, - state_fields: Vec, + state_fields: Vec, ) -> Self { let name = name.into(); Self { @@ -590,7 +591,7 @@ impl AggregateUDFImpl for SimpleAggregateUDF { (self.accumulator)(acc_args) } - fn state_fields(&self, _args: StateFieldsArgs) -> Result> { + fn state_fields(&self, _args: StateFieldsArgs) -> Result> { Ok(self.state_fields.clone()) } } @@ -678,12 +679,12 @@ impl WindowUDFImpl for SimpleWindowUDF { (self.partition_evaluator_factory)() } - fn field(&self, field_args: WindowUDFFieldArgs) -> Result { - Ok(Field::new( + fn field(&self, field_args: WindowUDFFieldArgs) -> Result { + Ok(Arc::new(Field::new( field_args.name(), self.return_type.clone(), true, - )) + ))) } } diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 6022182bfe67f..bdf9911b006c7 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -27,7 +27,7 @@ 
use crate::type_coercion::functions::{ use crate::udf::ReturnFieldArgs; use crate::{utils, LogicalPlan, Projection, Subquery, WindowFunctionDefinition}; use arrow::compute::can_cast_types; -use arrow::datatypes::{DataType, Field}; +use arrow::datatypes::{DataType, Field, FieldRef}; use datafusion_common::{ not_impl_err, plan_datafusion_err, plan_err, Column, DataFusionError, ExprSchema, Result, Spans, TableReference, @@ -160,10 +160,10 @@ impl ExprSchemable for Expr { }) => { let fields = args .iter() - .map(|e| e.to_field(schema).map(|(_, f)| f.as_ref().clone())) + .map(|e| e.to_field(schema).map(|(_, f)| f)) .collect::>>()?; - let new_fields = - fields_with_aggregate_udf(&fields, func).map_err(|err| { + let new_fields = fields_with_aggregate_udf(&fields, func) + .map_err(|err| { let data_types = fields .iter() .map(|f| f.data_type().clone()) @@ -180,7 +180,9 @@ impl ExprSchemable for Expr { &data_types ) ) - })?; + })? + .into_iter() + .collect::>(); Ok(func.return_field(&new_fields)?.data_type().clone()) } Expr::Not(_) @@ -408,17 +410,21 @@ impl ExprSchemable for Expr { } } - Ok(field.with_metadata(combined_metadata)) - } - Expr::Negative(expr) => { - expr.to_field(schema).map(|(_, f)| f.as_ref().clone()) + Ok(Arc::new(field.with_metadata(combined_metadata))) } - Expr::Column(c) => schema.field_from_column(c).cloned(), + Expr::Negative(expr) => expr.to_field(schema).map(|(_, f)| f), + Expr::Column(c) => schema.field_from_column(c).map(|f| Arc::new(f.clone())), Expr::OuterReferenceColumn(ty, _) => { - Ok(Field::new(&schema_name, ty.clone(), true)) + Ok(Arc::new(Field::new(&schema_name, ty.clone(), true))) } - Expr::ScalarVariable(ty, _) => Ok(Field::new(&schema_name, ty.clone(), true)), - Expr::Literal(l) => Ok(Field::new(&schema_name, l.data_type(), l.is_null())), + Expr::ScalarVariable(ty, _) => { + Ok(Arc::new(Field::new(&schema_name, ty.clone(), true))) + } + Expr::Literal(l) => Ok(Arc::new(Field::new( + &schema_name, + l.data_type(), + l.is_null(), + ))), 
Expr::IsNull(_) | Expr::IsNotNull(_) | Expr::IsTrue(_) @@ -428,10 +434,10 @@ impl ExprSchemable for Expr { | Expr::IsNotFalse(_) | Expr::IsNotUnknown(_) | Expr::Exists { .. } => { - Ok(Field::new(&schema_name, DataType::Boolean, false)) + Ok(Arc::new(Field::new(&schema_name, DataType::Boolean, false))) } Expr::ScalarSubquery(subquery) => { - Ok(subquery.subquery.schema().field(0).clone()) + Ok(Arc::new(subquery.subquery.schema().field(0).clone())) } Expr::BinaryExpr(BinaryExpr { ref left, @@ -443,18 +449,18 @@ impl ExprSchemable for Expr { let mut coercer = BinaryTypeCoercer::new(&lhs_type, op, &rhs_type); coercer.set_lhs_spans(left.spans().cloned().unwrap_or_default()); coercer.set_rhs_spans(right.spans().cloned().unwrap_or_default()); - Ok(Field::new( + Ok(Arc::new(Field::new( &schema_name, coercer.get_result_type()?, lhs_nullable || rhs_nullable, - )) + ))) } Expr::WindowFunction(window_function) => { let (dt, nullable) = self.data_type_and_nullable_with_window_function( schema, window_function, )?; - Ok(Field::new(&schema_name, dt, nullable)) + Ok(Arc::new(Field::new(&schema_name, dt, nullable))) } Expr::AggregateFunction(aggregate_function) => { let AggregateFunction { @@ -465,11 +471,11 @@ impl ExprSchemable for Expr { let fields = args .iter() - .map(|e| e.to_field(schema).map(|(_, f)| f.as_ref().clone())) + .map(|e| e.to_field(schema).map(|(_, f)| f)) .collect::>>()?; // Verify that function is invoked with correct number and type of arguments as defined in `TypeSignature` - let new_fields = - fields_with_aggregate_udf(&fields, func).map_err(|err| { + let new_fields = fields_with_aggregate_udf(&fields, func) + .map_err(|err| { let arg_types = fields .iter() .map(|f| f.data_type()) @@ -487,7 +493,9 @@ impl ExprSchemable for Expr { &arg_types, ) ) - })?; + })? 
+ .into_iter() + .collect::>(); func.return_field(&new_fields) } @@ -519,7 +527,8 @@ impl ExprSchemable for Expr { .into_iter() .zip(new_data_types) .map(|(f, d)| f.as_ref().clone().with_data_type(d)) - .collect::>(); + .map(Arc::new) + .collect::>(); let arguments = args .iter() @@ -538,7 +547,8 @@ impl ExprSchemable for Expr { // _ => Ok((self.get_type(schema)?, self.nullable(schema)?)), Expr::Cast(Cast { expr, data_type }) => expr .to_field(schema) - .map(|(_, f)| f.as_ref().clone().with_data_type(data_type.clone())), + .map(|(_, f)| f.as_ref().clone().with_data_type(data_type.clone())) + .map(Arc::new), Expr::Like(_) | Expr::SimilarTo(_) | Expr::Not(_) @@ -550,14 +560,17 @@ impl ExprSchemable for Expr { | Expr::Wildcard { .. } | Expr::GroupingSet(_) | Expr::Placeholder(_) - | Expr::Unnest(_) => Ok(Field::new( + | Expr::Unnest(_) => Ok(Arc::new(Field::new( &schema_name, self.get_type(schema)?, self.nullable(schema)?, - )), + ))), }?; - Ok((relation, Arc::new(field.with_name(schema_name)))) + Ok(( + relation, + Arc::new(field.as_ref().clone().with_name(schema_name)), + )) } /// Wraps this expression in a cast to a target [arrow::datatypes::DataType]. @@ -612,7 +625,7 @@ impl Expr { let fields = args .iter() - .map(|e| e.to_field(schema).map(|(_, f)| f.as_ref().clone())) + .map(|e| e.to_field(schema).map(|(_, f)| f)) .collect::>>()?; match fun { WindowFunctionDefinition::AggregateUDF(udaf) => { @@ -621,8 +634,8 @@ impl Expr { .map(|f| f.data_type()) .cloned() .collect::>(); - let new_fields = - fields_with_aggregate_udf(&fields, udaf).map_err(|err| { + let new_fields = fields_with_aggregate_udf(&fields, udaf) + .map_err(|err| { plan_datafusion_err!( "{} {}", match err { @@ -635,7 +648,9 @@ impl Expr { &data_types ) ) - })?; + })? 
+ .into_iter() + .collect::>(); let return_field = udaf.return_field(&new_fields)?; @@ -647,8 +662,8 @@ impl Expr { .map(|f| f.data_type()) .cloned() .collect::>(); - let new_fields = - fields_with_window_udf(&fields, udwf).map_err(|err| { + let new_fields = fields_with_window_udf(&fields, udwf) + .map_err(|err| { plan_datafusion_err!( "{} {}", match err { @@ -661,7 +676,9 @@ impl Expr { &data_types ) ) - })?; + })? + .into_iter() + .collect::>(); let (_, function_name) = self.qualified_name(); let field_args = WindowUDFFieldArgs::new(&new_fields, &function_name); diff --git a/datafusion/expr/src/test/function_stub.rs b/datafusion/expr/src/test/function_stub.rs index a753f4c376c63..673908a4d7e7d 100644 --- a/datafusion/expr/src/test/function_stub.rs +++ b/datafusion/expr/src/test/function_stub.rs @@ -22,7 +22,7 @@ use std::any::Any; use arrow::datatypes::{ - DataType, Field, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, + DataType, FieldRef, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, }; use datafusion_common::{exec_err, not_impl_err, utils::take_function_args, Result}; @@ -175,7 +175,7 @@ impl AggregateUDFImpl for Sum { unreachable!("stub should not have accumulate()") } - fn state_fields(&self, _args: StateFieldsArgs) -> Result> { + fn state_fields(&self, _args: StateFieldsArgs) -> Result> { unreachable!("stub should not have state_fields()") } @@ -254,7 +254,7 @@ impl AggregateUDFImpl for Count { false } - fn state_fields(&self, _args: StateFieldsArgs) -> Result> { + fn state_fields(&self, _args: StateFieldsArgs) -> Result> { not_impl_err!("no impl for stub") } @@ -336,7 +336,7 @@ impl AggregateUDFImpl for Min { Ok(DataType::Int64) } - fn state_fields(&self, _args: StateFieldsArgs) -> Result> { + fn state_fields(&self, _args: StateFieldsArgs) -> Result> { not_impl_err!("no impl for stub") } @@ -421,7 +421,7 @@ impl AggregateUDFImpl for Max { Ok(DataType::Int64) } - fn state_fields(&self, _args: StateFieldsArgs) -> Result> { + fn state_fields(&self, 
_args: StateFieldsArgs) -> Result> { not_impl_err!("no impl for stub") } @@ -491,7 +491,7 @@ impl AggregateUDFImpl for Avg { not_impl_err!("no impl for stub") } - fn state_fields(&self, _args: StateFieldsArgs) -> Result> { + fn state_fields(&self, _args: StateFieldsArgs) -> Result> { not_impl_err!("no impl for stub") } fn aliases(&self) -> &[String] { diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index 6d1ed238646d3..763a4e6539fd8 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -17,7 +17,7 @@ use super::binary::binary_numeric_coercion; use crate::{AggregateUDF, ScalarUDF, Signature, TypeSignature, WindowUDF}; -use arrow::datatypes::Field; +use arrow::datatypes::FieldRef; use arrow::{ compute::can_cast_types, datatypes::{DataType, TimeUnit}, @@ -84,9 +84,9 @@ pub fn data_types_with_scalar_udf( /// For more details on coercion in general, please see the /// [`type_coercion`](crate::type_coercion) module. pub fn fields_with_aggregate_udf( - current_fields: &[Field], + current_fields: &[FieldRef], func: &AggregateUDF, -) -> Result> { +) -> Result> { let signature = func.signature(); let type_signature = &signature.type_signature; @@ -121,7 +121,10 @@ pub fn fields_with_aggregate_udf( Ok(current_fields .iter() .zip(updated_types) - .map(|(current_field, new_type)| current_field.clone().with_data_type(new_type)) + .map(|(current_field, new_type)| { + current_field.as_ref().clone().with_data_type(new_type) + }) + .map(Arc::new) .collect()) } @@ -133,9 +136,9 @@ pub fn fields_with_aggregate_udf( /// For more details on coercion in general, please see the /// [`type_coercion`](crate::type_coercion) module. 
pub fn fields_with_window_udf( - current_fields: &[Field], + current_fields: &[FieldRef], func: &WindowUDF, -) -> Result> { +) -> Result> { let signature = func.signature(); let type_signature = &signature.type_signature; @@ -170,7 +173,10 @@ pub fn fields_with_window_udf( Ok(current_fields .iter() .zip(updated_types) - .map(|(current_field, new_type)| current_field.clone().with_data_type(new_type)) + .map(|(current_field, new_type)| { + current_field.as_ref().clone().with_data_type(new_type) + }) + .map(Arc::new) .collect()) } diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index 3a8d0253a3892..d1bf45ce2fe8a 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -24,7 +24,7 @@ use std::hash::{DefaultHasher, Hash, Hasher}; use std::sync::Arc; use std::vec; -use arrow::datatypes::{DataType, Field}; +use arrow::datatypes::{DataType, Field, FieldRef}; use datafusion_common::{exec_err, not_impl_err, Result, ScalarValue, Statistics}; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; @@ -227,7 +227,7 @@ impl AggregateUDF { /// Return the field of the function given its input fields /// /// See [`AggregateUDFImpl::return_field`] for more details. - pub fn return_field(&self, args: &[Field]) -> Result { + pub fn return_field(&self, args: &[FieldRef]) -> Result { self.inner.return_field(args) } @@ -241,7 +241,7 @@ impl AggregateUDF { /// for more details. 
/// /// This is used to support multi-phase aggregations - pub fn state_fields(&self, args: StateFieldsArgs) -> Result> { + pub fn state_fields(&self, args: StateFieldsArgs) -> Result> { self.inner.state_fields(args) } @@ -363,8 +363,8 @@ where /// # Basic Example /// ``` /// # use std::any::Any; -/// # use std::sync::LazyLock; -/// # use arrow::datatypes::DataType; +/// # use std::sync::{Arc, LazyLock}; +/// # use arrow::datatypes::{DataType, FieldRef}; /// # use datafusion_common::{DataFusionError, plan_err, Result}; /// # use datafusion_expr::{col, ColumnarValue, Signature, Volatility, Expr, Documentation}; /// # use datafusion_expr::{AggregateUDFImpl, AggregateUDF, Accumulator, function::{AccumulatorArgs, StateFieldsArgs}}; @@ -408,10 +408,10 @@ where /// } /// // This is the accumulator factory; DataFusion uses it to create new accumulators. /// fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { unimplemented!() } -/// fn state_fields(&self, args: StateFieldsArgs) -> Result> { +/// fn state_fields(&self, args: StateFieldsArgs) -> Result> { /// Ok(vec![ -/// args.return_field.clone().with_name("value"), -/// Field::new("ordering", DataType::UInt32, true) +/// Arc::new(args.return_field.as_ref().clone().with_name("value")), +/// Arc::new(Field::new("ordering", DataType::UInt32, true)) /// ]) /// } /// fn documentation(&self) -> Option<&Documentation> { @@ -698,12 +698,16 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { /// 2. return types based on the **values** of the arguments (rather than /// their **types**. /// 3. 
return types based on metadata within the fields of the inputs - fn return_field(&self, arg_fields: &[Field]) -> Result { + fn return_field(&self, arg_fields: &[FieldRef]) -> Result { let arg_types: Vec<_> = arg_fields.iter().map(|f| f.data_type()).cloned().collect(); let data_type = self.return_type(&arg_types)?; - Ok(Field::new(self.name(), data_type, self.is_nullable())) + Ok(Arc::new(Field::new( + self.name(), + data_type, + self.is_nullable(), + ))) } /// Whether the aggregate function is nullable. @@ -744,14 +748,16 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { /// The name of the fields must be unique within the query and thus should /// be derived from `name`. See [`format_state_name`] for a utility function /// to generate a unique name. - fn state_fields(&self, args: StateFieldsArgs) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { let fields = vec![args .return_field + .as_ref() .clone() .with_name(format_state_name(args.name, "value"))]; Ok(fields .into_iter() + .map(Arc::new) .chain(args.ordering_fields.to_vec()) .collect()) } @@ -1045,7 +1051,7 @@ impl AggregateUDFImpl for AliasedAggregateUDFImpl { &self.aliases } - fn state_fields(&self, args: StateFieldsArgs) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { self.inner.state_fields(args) } @@ -1178,7 +1184,7 @@ pub enum SetMonotonicity { #[cfg(test)] mod test { use crate::{AggregateUDF, AggregateUDFImpl}; - use arrow::datatypes::{DataType, Field}; + use arrow::datatypes::{DataType, FieldRef}; use datafusion_common::Result; use datafusion_expr_common::accumulator::Accumulator; use datafusion_expr_common::signature::{Signature, Volatility}; @@ -1224,7 +1230,7 @@ mod test { ) -> Result> { unimplemented!() } - fn state_fields(&self, _args: StateFieldsArgs) -> Result> { + fn state_fields(&self, _args: StateFieldsArgs) -> Result> { unimplemented!() } } @@ -1264,7 +1270,7 @@ mod test { ) -> Result> { unimplemented!() } - fn state_fields(&self, _args: 
StateFieldsArgs) -> Result> { + fn state_fields(&self, _args: StateFieldsArgs) -> Result> { unimplemented!() } } diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index 3983d17516457..816929a1fba17 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -21,7 +21,7 @@ use crate::expr::schema_name_from_exprs_comma_separated_without_space; use crate::simplify::{ExprSimplifyResult, SimplifyInfo}; use crate::sort_properties::{ExprProperties, SortProperties}; use crate::{ColumnarValue, Documentation, Expr, Signature}; -use arrow::datatypes::{DataType, Field}; +use arrow::datatypes::{DataType, Field, FieldRef}; use datafusion_common::{not_impl_err, ExprSchema, Result, ScalarValue}; use datafusion_expr_common::interval_arithmetic::Interval; use std::any::Any; @@ -181,7 +181,7 @@ impl ScalarUDF { /// Return the datatype this function returns given the input argument types. /// /// See [`ScalarUDFImpl::return_field_from_args`] for more details. - pub fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + pub fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { self.inner.return_field_from_args(args) } @@ -293,20 +293,20 @@ where /// Arguments passed to [`ScalarUDFImpl::invoke_with_args`] when invoking a /// scalar function. -pub struct ScalarFunctionArgs<'a, 'b> { +pub struct ScalarFunctionArgs { /// The evaluated arguments to the function pub args: Vec, /// Field associated with each arg, if it exists - pub arg_fields: Vec<&'a Field>, + pub arg_fields: Vec, /// The number of rows in record batch being evaluated pub number_rows: usize, /// The return field of the scalar function returned (from `return_type` /// or `return_field_from_args`) when creating the physical expression /// from the logical expression - pub return_field: &'b Field, + pub return_field: FieldRef, } -impl<'a, 'b> ScalarFunctionArgs<'a, 'b> { +impl ScalarFunctionArgs { /// The return type of the function. 
See [`Self::return_field`] for more /// details. pub fn return_type(&self) -> &DataType { @@ -324,7 +324,7 @@ impl<'a, 'b> ScalarFunctionArgs<'a, 'b> { #[derive(Debug)] pub struct ReturnFieldArgs<'a> { /// The data types of the arguments to the function - pub arg_fields: &'a [Field], + pub arg_fields: &'a [FieldRef], /// Is argument `i` to the function a scalar (constant)? /// /// If the argument `i` is not a scalar, it will be None @@ -476,15 +476,16 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { /// `DataType::Struct`. /// /// ```rust - /// # use arrow::datatypes::{DataType, Field}; + /// # use std::sync::Arc; + /// # use arrow::datatypes::{DataType, Field, FieldRef}; /// # use datafusion_common::Result; /// # use datafusion_expr::ReturnFieldArgs; - /// # struct Example{}; + /// # struct Example{} /// # impl Example { - /// fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + /// fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { /// // report output is only nullable if any one of the arguments are nullable /// let nullable = args.arg_fields.iter().any(|f| f.is_nullable()); - /// let field = Field::new("ignored_name", DataType::Int32, true); + /// let field = Arc::new(Field::new("ignored_name", DataType::Int32, true)); /// Ok(field) /// } /// # } @@ -504,7 +505,7 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { /// This function **must** consistently return the same type for the same /// logical input even if the input is simplified (e.g. it must return the same /// value for `('foo' | 'bar')` as it does for ('foobar'). 
- fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { let data_types = args .arg_fields .iter() @@ -512,7 +513,7 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { .cloned() .collect::>(); let return_type = self.return_type(&data_types)?; - Ok(Field::new(self.name(), return_type, true)) + Ok(Arc::new(Field::new(self.name(), return_type, true))) } #[deprecated( @@ -766,7 +767,7 @@ impl ScalarUDFImpl for AliasedScalarUDFImpl { self.inner.return_type(arg_types) } - fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { self.inner.return_field_from_args(args) } diff --git a/datafusion/expr/src/udwf.rs b/datafusion/expr/src/udwf.rs index a52438fcc99cc..c0187735d6025 100644 --- a/datafusion/expr/src/udwf.rs +++ b/datafusion/expr/src/udwf.rs @@ -26,7 +26,7 @@ use std::{ sync::Arc, }; -use arrow::datatypes::{DataType, Field}; +use arrow::datatypes::{DataType, FieldRef}; use crate::expr::WindowFunction; use crate::{ @@ -179,7 +179,7 @@ impl WindowUDF { /// Returns the field of the final result of evaluating this window function. /// /// See [`WindowUDFImpl::field`] for more details. 
- pub fn field(&self, field_args: WindowUDFFieldArgs) -> Result { + pub fn field(&self, field_args: WindowUDFFieldArgs) -> Result { self.inner.field(field_args) } @@ -236,7 +236,7 @@ where /// ``` /// # use std::any::Any; /// # use std::sync::LazyLock; -/// # use arrow::datatypes::{DataType, Field}; +/// # use arrow::datatypes::{DataType, Field, FieldRef}; /// # use datafusion_common::{DataFusionError, plan_err, Result}; /// # use datafusion_expr::{col, Signature, Volatility, PartitionEvaluator, WindowFrame, ExprFunctionExt, Documentation}; /// # use datafusion_expr::{WindowUDFImpl, WindowUDF}; @@ -279,9 +279,9 @@ where /// ) -> Result> { /// unimplemented!() /// } -/// fn field(&self, field_args: WindowUDFFieldArgs) -> Result { +/// fn field(&self, field_args: WindowUDFFieldArgs) -> Result { /// if let Some(DataType::Int32) = field_args.get_input_field(0).map(|f| f.data_type().clone()) { -/// Ok(Field::new(field_args.name(), DataType::Int32, false)) +/// Ok(Field::new(field_args.name(), DataType::Int32, false).into()) /// } else { /// plan_err!("smooth_it only accepts Int32 arguments") /// } @@ -386,12 +386,12 @@ pub trait WindowUDFImpl: Debug + Send + Sync { hasher.finish() } - /// The [`Field`] of the final result of evaluating this window function. + /// The [`FieldRef`] of the final result of evaluating this window function. /// /// Call `field_args.name()` to get the fully qualified name for defining - /// the [`Field`]. For a complete example see the implementation in the + /// the [`FieldRef`]. For a complete example see the implementation in the /// [Basic Example](WindowUDFImpl#basic-example) section. - fn field(&self, field_args: WindowUDFFieldArgs) -> Result; + fn field(&self, field_args: WindowUDFFieldArgs) -> Result; /// Allows the window UDF to define a custom result ordering. 
/// @@ -537,7 +537,7 @@ impl WindowUDFImpl for AliasedWindowUDFImpl { hasher.finish() } - fn field(&self, field_args: WindowUDFFieldArgs) -> Result { + fn field(&self, field_args: WindowUDFFieldArgs) -> Result { self.inner.field(field_args) } @@ -588,7 +588,7 @@ pub mod window_doc_sections { #[cfg(test)] mod test { use crate::{PartitionEvaluator, WindowUDF, WindowUDFImpl}; - use arrow::datatypes::{DataType, Field}; + use arrow::datatypes::{DataType, FieldRef}; use datafusion_common::Result; use datafusion_expr_common::signature::{Signature, Volatility}; use datafusion_functions_window_common::field::WindowUDFFieldArgs; @@ -630,7 +630,7 @@ mod test { ) -> Result> { unimplemented!() } - fn field(&self, _field_args: WindowUDFFieldArgs) -> Result { + fn field(&self, _field_args: WindowUDFFieldArgs) -> Result { unimplemented!() } } @@ -669,7 +669,7 @@ mod test { ) -> Result> { unimplemented!() } - fn field(&self, _field_args: WindowUDFFieldArgs) -> Result { + fn field(&self, _field_args: WindowUDFFieldArgs) -> Result { unimplemented!() } } diff --git a/datafusion/ffi/src/udf/mod.rs b/datafusion/ffi/src/udf/mod.rs index c35a53205eb4f..303acc783b2e4 100644 --- a/datafusion/ffi/src/udf/mod.rs +++ b/datafusion/ffi/src/udf/mod.rs @@ -31,6 +31,7 @@ use arrow::{ error::ArrowError, ffi::{from_ffi, to_ffi, FFI_ArrowSchema}, }; +use arrow_schema::FieldRef; use datafusion::logical_expr::ReturnFieldArgs; use datafusion::{ error::DataFusionError, @@ -149,7 +150,7 @@ unsafe extern "C" fn return_field_from_args_fn_wrapper( let return_type = udf .return_field_from_args((&args_ref).into()) - .and_then(|f| FFI_ArrowSchema::try_from(f).map_err(DataFusionError::from)) + .and_then(|f| FFI_ArrowSchema::try_from(&f).map_err(DataFusionError::from)) .map(WrappedSchema); rresult!(return_type) @@ -188,20 +189,23 @@ unsafe extern "C" fn invoke_with_args_fn_wrapper( .collect::>(); let args = rresult_return!(args); - let return_field = rresult_return!(Field::try_from(&return_field.0)); + let 
return_field = rresult_return!(Field::try_from(&return_field.0)).into(); - let arg_fields_owned = arg_fields + let arg_fields = arg_fields .into_iter() - .map(|wrapped_field| (&wrapped_field.0).try_into().map_err(DataFusionError::from)) - .collect::>>(); - let arg_fields_owned = rresult_return!(arg_fields_owned); - let arg_fields = arg_fields_owned.iter().collect::>(); + .map(|wrapped_field| { + Field::try_from(&wrapped_field.0) + .map(Arc::new) + .map_err(DataFusionError::from) + }) + .collect::>>(); + let arg_fields = rresult_return!(arg_fields); let args = ScalarFunctionArgs { args, arg_fields, number_rows, - return_field: &return_field, + return_field, }; let result = rresult_return!(udf @@ -323,14 +327,18 @@ impl ScalarUDFImpl for ForeignScalarUDF { result.and_then(|r| (&r.0).try_into().map_err(DataFusionError::from)) } - fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { let args: FFI_ReturnFieldArgs = args.try_into()?; let result = unsafe { (self.udf.return_field_from_args)(&self.udf, args) }; let result = df_result!(result); - result.and_then(|r| (&r.0).try_into().map_err(DataFusionError::from)) + result.and_then(|r| { + Field::try_from(&r.0) + .map(Arc::new) + .map_err(DataFusionError::from) + }) } fn invoke_with_args(&self, invoke_args: ScalarFunctionArgs) -> Result { @@ -357,7 +365,7 @@ impl ScalarUDFImpl for ForeignScalarUDF { let arg_fields_wrapped = arg_fields .iter() - .map(|field| FFI_ArrowSchema::try_from(*field)) + .map(FFI_ArrowSchema::try_from) .collect::, ArrowError>>()?; let arg_fields = arg_fields_wrapped @@ -365,6 +373,7 @@ impl ScalarUDFImpl for ForeignScalarUDF { .map(WrappedSchema) .collect::>(); + let return_field = return_field.as_ref().clone(); let return_field = WrappedSchema(FFI_ArrowSchema::try_from(return_field)?); let result = unsafe { diff --git a/datafusion/ffi/src/udf/return_type_args.rs b/datafusion/ffi/src/udf/return_type_args.rs index 
40e577591c340..c437c9537be6f 100644 --- a/datafusion/ffi/src/udf/return_type_args.rs +++ b/datafusion/ffi/src/udf/return_type_args.rs @@ -19,14 +19,14 @@ use abi_stable::{ std_types::{ROption, RVec}, StableAbi, }; -use arrow::datatypes::Field; +use arrow_schema::FieldRef; use datafusion::{ common::exec_datafusion_err, error::DataFusionError, logical_expr::ReturnFieldArgs, scalar::ScalarValue, }; use crate::arrow_wrappers::WrappedSchema; -use crate::util::{rvec_wrapped_to_vec_field, vec_field_to_rvec_wrapped}; +use crate::util::{rvec_wrapped_to_vec_fieldref, vec_fieldref_to_rvec_wrapped}; use prost::Message; /// A stable struct for sharing a [`ReturnFieldArgs`] across FFI boundaries. @@ -42,7 +42,7 @@ impl TryFrom> for FFI_ReturnFieldArgs { type Error = DataFusionError; fn try_from(value: ReturnFieldArgs) -> Result { - let arg_fields = vec_field_to_rvec_wrapped(value.arg_fields)?; + let arg_fields = vec_fieldref_to_rvec_wrapped(value.arg_fields)?; let scalar_arguments: Result, Self::Error> = value .scalar_arguments .iter() @@ -70,12 +70,12 @@ impl TryFrom> for FFI_ReturnFieldArgs { // appears a restriction based on the need to have a borrowed ScalarValue // in the arguments when converted to ReturnFieldArgs pub struct ForeignReturnFieldArgsOwned { - arg_fields: Vec, + arg_fields: Vec, scalar_arguments: Vec>, } pub struct ForeignReturnFieldArgs<'a> { - arg_fields: &'a [Field], + arg_fields: &'a [FieldRef], scalar_arguments: Vec>, } @@ -83,7 +83,7 @@ impl TryFrom<&FFI_ReturnFieldArgs> for ForeignReturnFieldArgsOwned { type Error = DataFusionError; fn try_from(value: &FFI_ReturnFieldArgs) -> Result { - let arg_fields = rvec_wrapped_to_vec_field(&value.arg_fields)?; + let arg_fields = rvec_wrapped_to_vec_fieldref(&value.arg_fields)?; let scalar_arguments: Result, Self::Error> = value .scalar_arguments .iter() diff --git a/datafusion/ffi/src/util.rs b/datafusion/ffi/src/util.rs index f199d1523cbcc..3eb57963b44f8 100644 --- a/datafusion/ffi/src/util.rs +++ 
b/datafusion/ffi/src/util.rs @@ -19,6 +19,8 @@ use crate::arrow_wrappers::WrappedSchema; use abi_stable::std_types::RVec; use arrow::datatypes::Field; use arrow::{datatypes::DataType, ffi::FFI_ArrowSchema}; +use arrow_schema::FieldRef; +use std::sync::Arc; /// This macro is a helpful conversion utility to conver from an abi_stable::RResult to a /// DataFusion result. @@ -66,8 +68,8 @@ macro_rules! rresult_return { /// This is a utility function to convert a slice of [`Field`] to its equivalent /// FFI friendly counterpart, [`WrappedSchema`] -pub fn vec_field_to_rvec_wrapped( - fields: &[Field], +pub fn vec_fieldref_to_rvec_wrapped( + fields: &[FieldRef], ) -> Result, arrow::error::ArrowError> { Ok(fields .iter() @@ -80,10 +82,13 @@ pub fn vec_field_to_rvec_wrapped( /// This is a utility function to convert an FFI friendly vector of [`WrappedSchema`] /// to their equivalent [`Field`]. -pub fn rvec_wrapped_to_vec_field( +pub fn rvec_wrapped_to_vec_fieldref( fields: &RVec, -) -> Result, arrow::error::ArrowError> { - fields.iter().map(|d| Field::try_from(&d.0)).collect() +) -> Result, arrow::error::ArrowError> { + fields + .iter() + .map(|d| Field::try_from(&d.0).map(Arc::new)) + .collect() } /// This is a utility function to convert a slice of [`DataType`] to its equivalent diff --git a/datafusion/functions-aggregate-common/src/accumulator.rs b/datafusion/functions-aggregate-common/src/accumulator.rs index eba4f6b70d2bc..01b16f1b0a8cc 100644 --- a/datafusion/functions-aggregate-common/src/accumulator.rs +++ b/datafusion/functions-aggregate-common/src/accumulator.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. 
-use arrow::datatypes::{DataType, Field, Schema}; +use arrow::datatypes::{DataType, FieldRef, Schema}; use datafusion_common::Result; use datafusion_expr_common::accumulator::Accumulator; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; @@ -28,7 +28,7 @@ use std::sync::Arc; #[derive(Debug)] pub struct AccumulatorArgs<'a> { /// The return field of the aggregate function. - pub return_field: &'a Field, + pub return_field: FieldRef, /// The schema of the input arguments pub schema: &'a Schema, @@ -89,13 +89,13 @@ pub struct StateFieldsArgs<'a> { pub name: &'a str, /// The input fields of the aggregate function. - pub input_fields: &'a [Field], + pub input_fields: &'a [FieldRef], /// The return fields of the aggregate function. - pub return_field: &'a Field, + pub return_field: FieldRef, /// The ordering fields of the aggregate function. - pub ordering_fields: &'a [Field], + pub ordering_fields: &'a [FieldRef], /// Whether the aggregate function is distinct. pub is_distinct: bool, diff --git a/datafusion/functions-aggregate-common/src/utils.rs b/datafusion/functions-aggregate-common/src/utils.rs index 083dac615b5d1..229d9a900105a 100644 --- a/datafusion/functions-aggregate-common/src/utils.rs +++ b/datafusion/functions-aggregate-common/src/utils.rs @@ -18,7 +18,7 @@ use std::sync::Arc; use arrow::array::{ArrayRef, AsArray}; -use arrow::datatypes::ArrowNativeType; +use arrow::datatypes::{ArrowNativeType, FieldRef}; use arrow::{ array::ArrowNativeTypeOp, compute::SortOptions, @@ -92,7 +92,7 @@ pub fn ordering_fields( ordering_req: &LexOrdering, // Data type of each expression in the ordering requirement data_types: &[DataType], -) -> Vec { +) -> Vec { ordering_req .iter() .zip(data_types.iter()) @@ -104,6 +104,7 @@ pub fn ordering_fields( true, ) }) + .map(Arc::new) .collect() } diff --git a/datafusion/functions-aggregate/benches/count.rs b/datafusion/functions-aggregate/benches/count.rs index fc7561dd8a569..d5abf6b8ac281 100644 --- 
a/datafusion/functions-aggregate/benches/count.rs +++ b/datafusion/functions-aggregate/benches/count.rs @@ -28,7 +28,7 @@ use std::sync::Arc; fn prepare_accumulator() -> Box { let schema = Arc::new(Schema::new(vec![Field::new("f", DataType::Int32, true)])); let accumulator_args = AccumulatorArgs { - return_field: &Field::new("f", DataType::Int64, true), + return_field: Field::new("f", DataType::Int64, true).into(), schema: &schema, ignore_nulls: false, ordering_req: &LexOrdering::default(), diff --git a/datafusion/functions-aggregate/benches/sum.rs b/datafusion/functions-aggregate/benches/sum.rs index d05d5c5676c5d..25df78b15f11c 100644 --- a/datafusion/functions-aggregate/benches/sum.rs +++ b/datafusion/functions-aggregate/benches/sum.rs @@ -26,10 +26,10 @@ use datafusion_physical_expr_common::sort_expr::LexOrdering; use std::sync::Arc; fn prepare_accumulator(data_type: &DataType) -> Box { - let field = Field::new("f", data_type.clone(), true); - let schema = Arc::new(Schema::new(vec![field.clone()])); + let field = Field::new("f", data_type.clone(), true).into(); + let schema = Arc::new(Schema::new(vec![Arc::clone(&field)])); let accumulator_args = AccumulatorArgs { - return_field: &field, + return_field: field, schema: &schema, ignore_nulls: false, ordering_req: &LexOrdering::default(), diff --git a/datafusion/functions-aggregate/src/approx_distinct.rs b/datafusion/functions-aggregate/src/approx_distinct.rs index c97dba1925ca9..0d5dcd5c2085a 100644 --- a/datafusion/functions-aggregate/src/approx_distinct.rs +++ b/datafusion/functions-aggregate/src/approx_distinct.rs @@ -23,7 +23,7 @@ use arrow::array::{ GenericBinaryArray, GenericStringArray, OffsetSizeTrait, PrimitiveArray, }; use arrow::datatypes::{ - ArrowPrimitiveType, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, + ArrowPrimitiveType, FieldRef, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type, }; use arrow::{array::ArrayRef, datatypes::DataType, 
datatypes::Field}; @@ -322,12 +322,13 @@ impl AggregateUDFImpl for ApproxDistinct { Ok(DataType::UInt64) } - fn state_fields(&self, args: StateFieldsArgs) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { Ok(vec![Field::new( format_state_name(args.name, "hll_registers"), DataType::Binary, false, - )]) + ) + .into()]) } fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { diff --git a/datafusion/functions-aggregate/src/approx_median.rs b/datafusion/functions-aggregate/src/approx_median.rs index 9a202879d94ab..0f2e3039ca9f1 100644 --- a/datafusion/functions-aggregate/src/approx_median.rs +++ b/datafusion/functions-aggregate/src/approx_median.rs @@ -17,11 +17,11 @@ //! Defines physical expressions for APPROX_MEDIAN that can be evaluated MEDIAN at runtime during query execution +use arrow::datatypes::DataType::{Float64, UInt64}; +use arrow::datatypes::{DataType, Field, FieldRef}; use std::any::Any; use std::fmt::Debug; - -use arrow::datatypes::DataType::{Float64, UInt64}; -use arrow::datatypes::{DataType, Field}; +use std::sync::Arc; use datafusion_common::{not_impl_err, plan_err, Result}; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; @@ -91,7 +91,7 @@ impl AggregateUDFImpl for ApproxMedian { self } - fn state_fields(&self, args: StateFieldsArgs) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { Ok(vec![ Field::new(format_state_name(args.name, "max_size"), UInt64, false), Field::new(format_state_name(args.name, "sum"), Float64, false), @@ -103,7 +103,10 @@ impl AggregateUDFImpl for ApproxMedian { Field::new_list_field(Float64, true), false, ), - ]) + ] + .into_iter() + .map(Arc::new) + .collect()) } fn name(&self) -> &str { diff --git a/datafusion/functions-aggregate/src/approx_percentile_cont.rs b/datafusion/functions-aggregate/src/approx_percentile_cont.rs index 41281733f5deb..024c0a823fa9e 100644 --- a/datafusion/functions-aggregate/src/approx_percentile_cont.rs +++ 
b/datafusion/functions-aggregate/src/approx_percentile_cont.rs @@ -22,6 +22,7 @@ use std::sync::Arc; use arrow::array::{Array, RecordBatch}; use arrow::compute::{filter, is_not_null}; +use arrow::datatypes::FieldRef; use arrow::{ array::{ ArrayRef, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, @@ -29,7 +30,6 @@ use arrow::{ }, datatypes::{DataType, Field, Schema}, }; - use datafusion_common::{ downcast_value, internal_err, not_impl_datafusion_err, not_impl_err, plan_err, Result, ScalarValue, @@ -256,7 +256,7 @@ impl AggregateUDFImpl for ApproxPercentileCont { #[allow(rustdoc::private_intra_doc_links)] /// See [`TDigest::to_scalar_state()`] for a description of the serialized /// state. - fn state_fields(&self, args: StateFieldsArgs) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { Ok(vec![ Field::new( format_state_name(args.name, "max_size"), @@ -288,7 +288,10 @@ impl AggregateUDFImpl for ApproxPercentileCont { Field::new_list_field(DataType::Float64, true), false, ), - ]) + ] + .into_iter() + .map(Arc::new) + .collect()) } fn name(&self) -> &str { diff --git a/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs b/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs index 0316757f26d08..5180d45889620 100644 --- a/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs +++ b/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs @@ -20,11 +20,8 @@ use std::fmt::{Debug, Formatter}; use std::mem::size_of_val; use std::sync::Arc; -use arrow::{ - array::ArrayRef, - datatypes::{DataType, Field}, -}; - +use arrow::datatypes::FieldRef; +use arrow::{array::ArrayRef, datatypes::DataType}; use datafusion_common::ScalarValue; use datafusion_common::{not_impl_err, plan_err, Result}; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; @@ -174,7 +171,7 @@ impl AggregateUDFImpl for ApproxPercentileContWithWeight { 
#[allow(rustdoc::private_intra_doc_links)] /// See [`TDigest::to_scalar_state()`] for a description of the serialized /// state. - fn state_fields(&self, args: StateFieldsArgs) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { self.approx_percentile_cont.state_fields(args) } diff --git a/datafusion/functions-aggregate/src/array_agg.rs b/datafusion/functions-aggregate/src/array_agg.rs index 54479ee99fc3a..71278767a83fc 100644 --- a/datafusion/functions-aggregate/src/array_agg.rs +++ b/datafusion/functions-aggregate/src/array_agg.rs @@ -21,7 +21,7 @@ use arrow::array::{ new_empty_array, Array, ArrayRef, AsArray, BooleanArray, ListArray, StructArray, }; use arrow::compute::{filter, SortOptions}; -use arrow::datatypes::{DataType, Field, Fields}; +use arrow::datatypes::{DataType, Field, FieldRef, Fields}; use datafusion_common::cast::as_list_array; use datafusion_common::utils::{get_row_at_idx, SingleRowListArrayBuilder}; @@ -109,14 +109,15 @@ impl AggregateUDFImpl for ArrayAgg { )))) } - fn state_fields(&self, args: StateFieldsArgs) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { if args.is_distinct { return Ok(vec![Field::new_list( format_state_name(args.name, "distinct_array_agg"), // See COMMENTS.md to understand why nullable is set to true Field::new_list_field(args.input_fields[0].data_type().clone(), true), true, - )]); + ) + .into()]); } let mut fields = vec![Field::new_list( @@ -124,18 +125,22 @@ impl AggregateUDFImpl for ArrayAgg { // See COMMENTS.md to understand why nullable is set to true Field::new_list_field(args.input_fields[0].data_type().clone(), true), true, - )]; + ) + .into()]; if args.ordering_fields.is_empty() { return Ok(fields); } let orderings = args.ordering_fields.to_vec(); - fields.push(Field::new_list( - format_state_name(args.name, "array_agg_orderings"), - Field::new_list_field(DataType::Struct(Fields::from(orderings)), true), - false, - )); + fields.push( + Field::new_list( + 
format_state_name(args.name, "array_agg_orderings"), + Field::new_list_field(DataType::Struct(Fields::from(orderings)), true), + false, + ) + .into(), + ); Ok(fields) } @@ -691,7 +696,6 @@ impl OrderSensitiveArrayAggAccumulator { fn evaluate_orderings(&self) -> Result { let fields = ordering_fields(self.ordering_req.as_ref(), &self.datatypes[1..]); let num_columns = fields.len(); - let struct_field = Fields::from(fields.clone()); let mut column_wise_ordering_values = vec![]; for i in 0..num_columns { @@ -708,6 +712,7 @@ impl OrderSensitiveArrayAggAccumulator { column_wise_ordering_values.push(array); } + let struct_field = Fields::from(fields); let ordering_array = StructArray::try_new(struct_field, column_wise_ordering_values, None)?; Ok(SingleRowListArrayBuilder::new(Arc::new(ordering_array)).build_list_scalar()) @@ -717,7 +722,7 @@ impl OrderSensitiveArrayAggAccumulator { #[cfg(test)] mod tests { use super::*; - use arrow::datatypes::Schema; + use arrow::datatypes::{FieldRef, Schema}; use datafusion_common::cast::as_generic_string_array; use datafusion_common::internal_err; use datafusion_physical_expr::expressions::Column; @@ -984,7 +989,7 @@ mod tests { } struct ArrayAggAccumulatorBuilder { - return_field: Field, + return_field: FieldRef, distinct: bool, ordering: LexOrdering, schema: Schema, @@ -997,7 +1002,7 @@ mod tests { fn new(data_type: DataType) -> Self { Self { - return_field: Field::new("f", data_type.clone(), true), + return_field: Field::new("f", data_type.clone(), true).into(), distinct: false, ordering: Default::default(), schema: Schema { @@ -1029,7 +1034,7 @@ mod tests { fn build(&self) -> Result> { ArrayAgg::default().accumulator(AccumulatorArgs { - return_field: &self.return_field, + return_field: Arc::clone(&self.return_field), schema: &self.schema, ignore_nulls: false, ordering_req: &self.ordering, diff --git a/datafusion/functions-aggregate/src/average.rs b/datafusion/functions-aggregate/src/average.rs index 3ca39aa315892..3c1d33e093b50 
100644 --- a/datafusion/functions-aggregate/src/average.rs +++ b/datafusion/functions-aggregate/src/average.rs @@ -26,7 +26,7 @@ use arrow::compute::sum; use arrow::datatypes::{ i256, ArrowNativeType, DataType, Decimal128Type, Decimal256Type, DecimalType, DurationMicrosecondType, DurationMillisecondType, DurationNanosecondType, - DurationSecondType, Field, Float64Type, TimeUnit, UInt64Type, + DurationSecondType, Field, FieldRef, Float64Type, TimeUnit, UInt64Type, }; use datafusion_common::{ exec_err, not_impl_err, utils::take_function_args, Result, ScalarValue, @@ -164,7 +164,7 @@ impl AggregateUDFImpl for Avg { } } - fn state_fields(&self, args: StateFieldsArgs) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { Ok(vec![ Field::new( format_state_name(args.name, "count"), @@ -176,7 +176,10 @@ impl AggregateUDFImpl for Avg { args.input_fields[0].data_type().clone(), true, ), - ]) + ] + .into_iter() + .map(Arc::new) + .collect()) } fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool { diff --git a/datafusion/functions-aggregate/src/bit_and_or_xor.rs b/datafusion/functions-aggregate/src/bit_and_or_xor.rs index aeaeefcd7a72c..4512162ba5d33 100644 --- a/datafusion/functions-aggregate/src/bit_and_or_xor.rs +++ b/datafusion/functions-aggregate/src/bit_and_or_xor.rs @@ -25,8 +25,8 @@ use std::mem::{size_of, size_of_val}; use ahash::RandomState; use arrow::array::{downcast_integer, Array, ArrayRef, AsArray}; use arrow::datatypes::{ - ArrowNativeType, ArrowNumericType, DataType, Field, Int16Type, Int32Type, Int64Type, - Int8Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type, + ArrowNativeType, ArrowNumericType, DataType, Field, FieldRef, Int16Type, Int32Type, + Int64Type, Int8Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type, }; use datafusion_common::cast::as_list_array; @@ -263,7 +263,7 @@ impl AggregateUDFImpl for BitwiseOperation { downcast_bitwise_accumulator!(acc_args, self.operation, acc_args.is_distinct) } - fn 
state_fields(&self, args: StateFieldsArgs) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { if self.operation == BitwiseOperationType::Xor && args.is_distinct { Ok(vec![Field::new_list( format_state_name( @@ -273,13 +273,15 @@ impl AggregateUDFImpl for BitwiseOperation { // See COMMENTS.md to understand why nullable is set to true Field::new_list_field(args.return_type().clone(), true), false, - )]) + ) + .into()]) } else { Ok(vec![Field::new( format_state_name(args.name, self.name()), args.return_field.data_type().clone(), true, - )]) + ) + .into()]) } } diff --git a/datafusion/functions-aggregate/src/bool_and_or.rs b/datafusion/functions-aggregate/src/bool_and_or.rs index 034a28c27bb7f..e5de6d76217fb 100644 --- a/datafusion/functions-aggregate/src/bool_and_or.rs +++ b/datafusion/functions-aggregate/src/bool_and_or.rs @@ -24,8 +24,8 @@ use arrow::array::ArrayRef; use arrow::array::BooleanArray; use arrow::compute::bool_and as compute_bool_and; use arrow::compute::bool_or as compute_bool_or; -use arrow::datatypes::DataType; use arrow::datatypes::Field; +use arrow::datatypes::{DataType, FieldRef}; use datafusion_common::internal_err; use datafusion_common::{downcast_value, not_impl_err}; @@ -150,12 +150,13 @@ impl AggregateUDFImpl for BoolAnd { Ok(Box::::default()) } - fn state_fields(&self, args: StateFieldsArgs) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { Ok(vec![Field::new( format_state_name(args.name, self.name()), DataType::Boolean, true, - )]) + ) + .into()]) } fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool { @@ -288,12 +289,13 @@ impl AggregateUDFImpl for BoolOr { Ok(Box::::default()) } - fn state_fields(&self, args: StateFieldsArgs) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { Ok(vec![Field::new( format_state_name(args.name, self.name()), DataType::Boolean, true, - )]) + ) + .into()]) } fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> 
bool { diff --git a/datafusion/functions-aggregate/src/correlation.rs b/datafusion/functions-aggregate/src/correlation.rs index ac57256ce882f..0a7345245ca8c 100644 --- a/datafusion/functions-aggregate/src/correlation.rs +++ b/datafusion/functions-aggregate/src/correlation.rs @@ -27,7 +27,7 @@ use arrow::array::{ UInt64Array, }; use arrow::compute::{and, filter, is_not_null, kernels::cast}; -use arrow::datatypes::{Float64Type, UInt64Type}; +use arrow::datatypes::{FieldRef, Float64Type, UInt64Type}; use arrow::{ array::ArrayRef, datatypes::{DataType, Field}, @@ -117,7 +117,7 @@ impl AggregateUDFImpl for Correlation { Ok(Box::new(CorrelationAccumulator::try_new()?)) } - fn state_fields(&self, args: StateFieldsArgs) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { let name = args.name; Ok(vec![ Field::new(format_state_name(name, "count"), DataType::UInt64, true), @@ -130,7 +130,10 @@ impl AggregateUDFImpl for Correlation { DataType::Float64, true, ), - ]) + ] + .into_iter() + .map(Arc::new) + .collect()) } fn documentation(&self) -> Option<&Documentation> { diff --git a/datafusion/functions-aggregate/src/count.rs b/datafusion/functions-aggregate/src/count.rs index 42078c7355783..eccd0cd05187b 100644 --- a/datafusion/functions-aggregate/src/count.rs +++ b/datafusion/functions-aggregate/src/count.rs @@ -40,6 +40,7 @@ use arrow::{ }, }; +use arrow::datatypes::FieldRef; use arrow::{ array::{Array, BooleanArray, Int64Array, PrimitiveArray}, buffer::BooleanBuffer, @@ -201,20 +202,22 @@ impl AggregateUDFImpl for Count { false } - fn state_fields(&self, args: StateFieldsArgs) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { if args.is_distinct { Ok(vec![Field::new_list( format_state_name(args.name, "count distinct"), // See COMMENTS.md to understand why nullable is set to true Field::new_list_field(args.input_fields[0].data_type().clone(), true), false, - )]) + ) + .into()]) } else { Ok(vec![Field::new( 
format_state_name(args.name, "count"), DataType::Int64, false, - )]) + ) + .into()]) } } diff --git a/datafusion/functions-aggregate/src/covariance.rs b/datafusion/functions-aggregate/src/covariance.rs index d4ae27533c6db..9f37a73e5429e 100644 --- a/datafusion/functions-aggregate/src/covariance.rs +++ b/datafusion/functions-aggregate/src/covariance.rs @@ -17,15 +17,12 @@ //! [`CovarianceSample`]: covariance sample aggregations. -use std::fmt::Debug; -use std::mem::size_of_val; - +use arrow::datatypes::FieldRef; use arrow::{ array::{ArrayRef, Float64Array, UInt64Array}, compute::kernels::cast, datatypes::{DataType, Field}, }; - use datafusion_common::{ downcast_value, plan_err, unwrap_or_internal_err, DataFusionError, Result, ScalarValue, @@ -38,6 +35,9 @@ use datafusion_expr::{ }; use datafusion_functions_aggregate_common::stats::StatsType; use datafusion_macros::user_doc; +use std::fmt::Debug; +use std::mem::size_of_val; +use std::sync::Arc; make_udaf_expr_and_func!( CovarianceSample, @@ -120,7 +120,7 @@ impl AggregateUDFImpl for CovarianceSample { Ok(DataType::Float64) } - fn state_fields(&self, args: StateFieldsArgs) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { let name = args.name; Ok(vec![ Field::new(format_state_name(name, "count"), DataType::UInt64, true), @@ -131,7 +131,10 @@ impl AggregateUDFImpl for CovarianceSample { DataType::Float64, true, ), - ]) + ] + .into_iter() + .map(Arc::new) + .collect()) } fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { @@ -210,7 +213,7 @@ impl AggregateUDFImpl for CovariancePopulation { Ok(DataType::Float64) } - fn state_fields(&self, args: StateFieldsArgs) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { let name = args.name; Ok(vec![ Field::new(format_state_name(name, "count"), DataType::UInt64, true), @@ -221,7 +224,10 @@ impl AggregateUDFImpl for CovariancePopulation { DataType::Float64, true, ), - ]) + ] + .into_iter() + .map(Arc::new) + .collect()) } 
fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { diff --git a/datafusion/functions-aggregate/src/first_last.rs b/datafusion/functions-aggregate/src/first_last.rs index 8264d5fa74cb5..e8022245dba55 100644 --- a/datafusion/functions-aggregate/src/first_last.rs +++ b/datafusion/functions-aggregate/src/first_last.rs @@ -29,8 +29,8 @@ use arrow::array::{ use arrow::buffer::{BooleanBuffer, NullBuffer}; use arrow::compute::{self, LexicographicalComparator, SortColumn, SortOptions}; use arrow::datatypes::{ - DataType, Date32Type, Date64Type, Decimal128Type, Decimal256Type, Field, Float16Type, - Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, + DataType, Date32Type, Date64Type, Decimal128Type, Decimal256Type, Field, FieldRef, + Float16Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, Time32MillisecondType, Time32SecondType, Time64MicrosecondType, Time64NanosecondType, TimeUnit, TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, @@ -169,14 +169,15 @@ impl AggregateUDFImpl for FirstValue { .map(|acc| Box::new(acc.with_requirement_satisfied(requirement_satisfied)) as _) } - fn state_fields(&self, args: StateFieldsArgs) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { let mut fields = vec![Field::new( format_state_name(args.name, "first_value"), args.return_type().clone(), true, - )]; + ) + .into()]; fields.extend(args.ordering_fields.to_vec()); - fields.push(Field::new("is_set", DataType::Boolean, true)); + fields.push(Field::new("is_set", DataType::Boolean, true).into()); Ok(fields) } @@ -1046,7 +1047,7 @@ impl AggregateUDFImpl for LastValue { .map(|acc| Box::new(acc.with_requirement_satisfied(requirement_satisfied)) as _) } - fn state_fields(&self, args: StateFieldsArgs) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { let StateFieldsArgs { name, input_fields, @@ -1058,9 +1059,10 @@ 
impl AggregateUDFImpl for LastValue { format_state_name(name, "last_value"), input_fields[0].data_type().clone(), true, - )]; + ) + .into()]; fields.extend(ordering_fields.to_vec()); - fields.push(Field::new("is_set", DataType::Boolean, true)); + fields.push(Field::new("is_set", DataType::Boolean, true).into()); Ok(fields) } diff --git a/datafusion/functions-aggregate/src/grouping.rs b/datafusion/functions-aggregate/src/grouping.rs index 445774ff11e7d..0727cf33036a0 100644 --- a/datafusion/functions-aggregate/src/grouping.rs +++ b/datafusion/functions-aggregate/src/grouping.rs @@ -20,8 +20,8 @@ use std::any::Any; use std::fmt; -use arrow::datatypes::DataType; use arrow::datatypes::Field; +use arrow::datatypes::{DataType, FieldRef}; use datafusion_common::{not_impl_err, Result}; use datafusion_expr::function::AccumulatorArgs; use datafusion_expr::function::StateFieldsArgs; @@ -105,12 +105,13 @@ impl AggregateUDFImpl for Grouping { Ok(DataType::Int32) } - fn state_fields(&self, args: StateFieldsArgs) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { Ok(vec![Field::new( format_state_name(args.name, "grouping"), DataType::Int32, true, - )]) + ) + .into()]) } fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { diff --git a/datafusion/functions-aggregate/src/median.rs b/datafusion/functions-aggregate/src/median.rs index 3d3f385033595..bfaea4b2398cc 100644 --- a/datafusion/functions-aggregate/src/median.rs +++ b/datafusion/functions-aggregate/src/median.rs @@ -35,7 +35,7 @@ use arrow::{ use arrow::array::Array; use arrow::array::ArrowNativeTypeOp; -use arrow::datatypes::{ArrowNativeType, ArrowPrimitiveType}; +use arrow::datatypes::{ArrowNativeType, ArrowPrimitiveType, FieldRef}; use datafusion_common::{ internal_datafusion_err, internal_err, DataFusionError, HashSet, Result, ScalarValue, @@ -125,7 +125,7 @@ impl AggregateUDFImpl for Median { Ok(arg_types[0].clone()) } - fn state_fields(&self, args: StateFieldsArgs) -> Result> { + fn 
state_fields(&self, args: StateFieldsArgs) -> Result> { //Intermediate state is a list of the elements we have collected so far let field = Field::new_list_field(args.input_fields[0].data_type().clone(), true); let state_name = if args.is_distinct { @@ -138,7 +138,8 @@ impl AggregateUDFImpl for Median { format_state_name(args.name, state_name), DataType::List(Arc::new(field)), true, - )]) + ) + .into()]) } fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { diff --git a/datafusion/functions-aggregate/src/nth_value.rs b/datafusion/functions-aggregate/src/nth_value.rs index 8a7c721dd4724..1525b2f991a1f 100644 --- a/datafusion/functions-aggregate/src/nth_value.rs +++ b/datafusion/functions-aggregate/src/nth_value.rs @@ -24,7 +24,7 @@ use std::mem::{size_of, size_of_val}; use std::sync::Arc; use arrow::array::{new_empty_array, ArrayRef, AsArray, StructArray}; -use arrow::datatypes::{DataType, Field, Fields}; +use arrow::datatypes::{DataType, Field, FieldRef, Fields}; use datafusion_common::utils::{get_row_at_idx, SingleRowListArrayBuilder}; use datafusion_common::{exec_err, internal_err, not_impl_err, Result, ScalarValue}; @@ -164,7 +164,7 @@ impl AggregateUDFImpl for NthValueAgg { .map(|acc| Box::new(acc) as _) } - fn state_fields(&self, args: StateFieldsArgs) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { let mut fields = vec![Field::new_list( format_state_name(self.name(), "nth_value"), // See COMMENTS.md to understand why nullable is set to true @@ -179,7 +179,7 @@ impl AggregateUDFImpl for NthValueAgg { false, )); } - Ok(fields) + Ok(fields.into_iter().map(Arc::new).collect()) } fn aliases(&self) -> &[String] { @@ -400,7 +400,6 @@ impl Accumulator for NthValueAccumulator { impl NthValueAccumulator { fn evaluate_orderings(&self) -> Result { let fields = ordering_fields(self.ordering_req.as_ref(), &self.datatypes[1..]); - let struct_field = Fields::from(fields.clone()); let mut column_wise_ordering_values = vec![]; let 
num_columns = fields.len(); @@ -418,6 +417,7 @@ impl NthValueAccumulator { column_wise_ordering_values.push(array); } + let struct_field = Fields::from(fields); let ordering_array = StructArray::try_new(struct_field, column_wise_ordering_values, None)?; diff --git a/datafusion/functions-aggregate/src/regr.rs b/datafusion/functions-aggregate/src/regr.rs index 82575d15e50b8..0f84aa1323f52 100644 --- a/datafusion/functions-aggregate/src/regr.rs +++ b/datafusion/functions-aggregate/src/regr.rs @@ -18,6 +18,7 @@ //! Defines physical expressions that can evaluated at runtime during query execution use arrow::array::Float64Array; +use arrow::datatypes::FieldRef; use arrow::{ array::{ArrayRef, UInt64Array}, compute::cast, @@ -38,7 +39,7 @@ use datafusion_expr::{ use std::any::Any; use std::fmt::Debug; use std::mem::size_of_val; -use std::sync::LazyLock; +use std::sync::{Arc, LazyLock}; macro_rules! make_regr_udaf_expr_and_func { ($EXPR_FN:ident, $AGGREGATE_UDF_FN:ident, $REGR_TYPE:expr) => { @@ -278,7 +279,7 @@ impl AggregateUDFImpl for Regr { Ok(Box::new(RegrAccumulator::try_new(&self.regr_type)?)) } - fn state_fields(&self, args: StateFieldsArgs) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { Ok(vec![ Field::new( format_state_name(args.name, "count"), @@ -310,7 +311,10 @@ impl AggregateUDFImpl for Regr { DataType::Float64, true, ), - ]) + ] + .into_iter() + .map(Arc::new) + .collect()) } fn documentation(&self) -> Option<&Documentation> { diff --git a/datafusion/functions-aggregate/src/stddev.rs b/datafusion/functions-aggregate/src/stddev.rs index 5d3a6d5f70a7a..f948df840e73b 100644 --- a/datafusion/functions-aggregate/src/stddev.rs +++ b/datafusion/functions-aggregate/src/stddev.rs @@ -23,8 +23,8 @@ use std::mem::align_of_val; use std::sync::Arc; use arrow::array::Float64Array; +use arrow::datatypes::FieldRef; use arrow::{array::ArrayRef, datatypes::DataType, datatypes::Field}; - use datafusion_common::{internal_err, not_impl_err, Result}; use 
datafusion_common::{plan_err, ScalarValue}; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; @@ -109,7 +109,7 @@ impl AggregateUDFImpl for Stddev { Ok(DataType::Float64) } - fn state_fields(&self, args: StateFieldsArgs) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { Ok(vec![ Field::new( format_state_name(args.name, "count"), @@ -122,7 +122,10 @@ impl AggregateUDFImpl for Stddev { true, ), Field::new(format_state_name(args.name, "m2"), DataType::Float64, true), - ]) + ] + .into_iter() + .map(Arc::new) + .collect()) } fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { @@ -217,7 +220,7 @@ impl AggregateUDFImpl for StddevPop { &self.signature } - fn state_fields(&self, args: StateFieldsArgs) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { Ok(vec![ Field::new( format_state_name(args.name, "count"), @@ -230,7 +233,10 @@ impl AggregateUDFImpl for StddevPop { true, ), Field::new(format_state_name(args.name, "m2"), DataType::Float64, true), - ]) + ] + .into_iter() + .map(Arc::new) + .collect()) } fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { @@ -436,7 +442,7 @@ mod tests { schema: &Schema, ) -> Result { let args1 = AccumulatorArgs { - return_field: &Field::new("f", DataType::Float64, true), + return_field: Field::new("f", DataType::Float64, true).into(), schema, ignore_nulls: false, ordering_req: &LexOrdering::default(), @@ -447,7 +453,7 @@ mod tests { }; let args2 = AccumulatorArgs { - return_field: &Field::new("f", DataType::Float64, true), + return_field: Field::new("f", DataType::Float64, true).into(), schema, ignore_nulls: false, ordering_req: &LexOrdering::default(), diff --git a/datafusion/functions-aggregate/src/string_agg.rs b/datafusion/functions-aggregate/src/string_agg.rs index d59f8a576e784..3f7e503acfced 100644 --- a/datafusion/functions-aggregate/src/string_agg.rs +++ b/datafusion/functions-aggregate/src/string_agg.rs @@ -19,7 +19,7 @@ use 
crate::array_agg::ArrayAgg; use arrow::array::ArrayRef; -use arrow::datatypes::{DataType, Field}; +use arrow::datatypes::{DataType, Field, FieldRef}; use datafusion_common::cast::as_generic_string_array; use datafusion_common::Result; use datafusion_common::{internal_err, not_impl_err, ScalarValue}; @@ -129,7 +129,7 @@ impl AggregateUDFImpl for StringAgg { Ok(DataType::LargeUtf8) } - fn state_fields(&self, args: StateFieldsArgs) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { self.array_agg.state_fields(args) } @@ -154,11 +154,12 @@ impl AggregateUDFImpl for StringAgg { }; let array_agg_acc = self.array_agg.accumulator(AccumulatorArgs { - return_field: &Field::new( + return_field: Field::new( "f", DataType::new_list(acc_args.return_field.data_type().clone(), true), true, - ), + ) + .into(), exprs: &filter_index(acc_args.exprs, 1), ..acc_args })?; @@ -440,7 +441,7 @@ mod tests { fn build(&self) -> Result> { StringAgg::new().accumulator(AccumulatorArgs { - return_field: &Field::new("f", DataType::LargeUtf8, true), + return_field: Field::new("f", DataType::LargeUtf8, true).into(), schema: &self.schema, ignore_nulls: false, ordering_req: &self.ordering, diff --git a/datafusion/functions-aggregate/src/sum.rs b/datafusion/functions-aggregate/src/sum.rs index aaa0c4b94a7f5..37d208ffb03ad 100644 --- a/datafusion/functions-aggregate/src/sum.rs +++ b/datafusion/functions-aggregate/src/sum.rs @@ -26,8 +26,8 @@ use std::mem::{size_of, size_of_val}; use arrow::array::Array; use arrow::array::ArrowNativeTypeOp; use arrow::array::{ArrowNumericType, AsArray}; -use arrow::datatypes::ArrowNativeType; use arrow::datatypes::ArrowPrimitiveType; +use arrow::datatypes::{ArrowNativeType, FieldRef}; use arrow::datatypes::{ DataType, Decimal128Type, Decimal256Type, Float64Type, Int64Type, UInt64Type, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, @@ -201,20 +201,22 @@ impl AggregateUDFImpl for Sum { } } - fn state_fields(&self, args: StateFieldsArgs) -> 
Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { if args.is_distinct { Ok(vec![Field::new_list( format_state_name(args.name, "sum distinct"), // See COMMENTS.md to understand why nullable is set to true Field::new_list_field(args.return_type().clone(), true), false, - )]) + ) + .into()]) } else { Ok(vec![Field::new( format_state_name(args.name, "sum"), args.return_type().clone(), true, - )]) + ) + .into()]) } } diff --git a/datafusion/functions-aggregate/src/variance.rs b/datafusion/functions-aggregate/src/variance.rs index 53e3e0cc56cd2..586b2dab0ae6b 100644 --- a/datafusion/functions-aggregate/src/variance.rs +++ b/datafusion/functions-aggregate/src/variance.rs @@ -18,15 +18,13 @@ //! [`VarianceSample`]: variance sample aggregations. //! [`VariancePopulation`]: variance population aggregations. +use arrow::datatypes::FieldRef; use arrow::{ array::{Array, ArrayRef, BooleanArray, Float64Array, UInt64Array}, buffer::NullBuffer, compute::kernels::cast, datatypes::{DataType, Field}, }; -use std::mem::{size_of, size_of_val}; -use std::{fmt::Debug, sync::Arc}; - use datafusion_common::{downcast_value, not_impl_err, plan_err, Result, ScalarValue}; use datafusion_expr::{ function::{AccumulatorArgs, StateFieldsArgs}, @@ -38,6 +36,8 @@ use datafusion_functions_aggregate_common::{ aggregate::groups_accumulator::accumulate::accumulate, stats::StatsType, }; use datafusion_macros::user_doc; +use std::mem::{size_of, size_of_val}; +use std::{fmt::Debug, sync::Arc}; make_udaf_expr_and_func!( VarianceSample, @@ -107,13 +107,16 @@ impl AggregateUDFImpl for VarianceSample { Ok(DataType::Float64) } - fn state_fields(&self, args: StateFieldsArgs) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { let name = args.name; Ok(vec![ Field::new(format_state_name(name, "count"), DataType::UInt64, true), Field::new(format_state_name(name, "mean"), DataType::Float64, true), Field::new(format_state_name(name, "m2"), DataType::Float64, true), - ]) + ] + 
.into_iter() + .map(Arc::new) + .collect()) } fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { @@ -200,13 +203,16 @@ impl AggregateUDFImpl for VariancePopulation { Ok(DataType::Float64) } - fn state_fields(&self, args: StateFieldsArgs) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { let name = args.name; Ok(vec![ Field::new(format_state_name(name, "count"), DataType::UInt64, true), Field::new(format_state_name(name, "mean"), DataType::Float64, true), Field::new(format_state_name(name, "m2"), DataType::Float64, true), - ]) + ] + .into_iter() + .map(Arc::new) + .collect()) } fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { diff --git a/datafusion/functions-nested/benches/map.rs b/datafusion/functions-nested/benches/map.rs index e861de95d8637..a752a47bcbaa5 100644 --- a/datafusion/functions-nested/benches/map.rs +++ b/datafusion/functions-nested/benches/map.rs @@ -97,19 +97,20 @@ fn criterion_benchmark(c: &mut Criterion) { let return_type = map_udf() .return_type(&[DataType::Utf8, DataType::Int32]) .expect("should get return type"); - let return_field = &Field::new("f", return_type, true); + let arg_fields = vec![ + Field::new("a", keys.data_type(), true).into(), + Field::new("a", values.data_type(), true).into(), + ]; + let return_field = Field::new("f", return_type, true).into(); b.iter(|| { black_box( map_udf() .invoke_with_args(ScalarFunctionArgs { args: vec![keys.clone(), values.clone()], - arg_fields: vec![ - &Field::new("a", keys.data_type(), true), - &Field::new("a", values.data_type(), true), - ], + arg_fields: arg_fields.clone(), number_rows: 1, - return_field, + return_field: Arc::clone(&return_field), }) .expect("map should work on valid values"), ); diff --git a/datafusion/functions-nested/src/map_values.rs b/datafusion/functions-nested/src/map_values.rs index bd313aa9ed48e..8247fdd4a74ce 100644 --- a/datafusion/functions-nested/src/map_values.rs +++ b/datafusion/functions-nested/src/map_values.rs @@ 
-99,7 +99,7 @@ impl ScalarUDFImpl for MapValuesFunc { fn return_field_from_args( &self, args: datafusion_expr::ReturnFieldArgs, - ) -> Result { + ) -> Result { let [map_type] = take_function_args(self.name(), args.arg_fields)?; Ok(Field::new( @@ -107,7 +107,8 @@ impl ScalarUDFImpl for MapValuesFunc { DataType::List(get_map_values_field_as_list_field(map_type.data_type())?), // Nullable if the map is nullable args.arg_fields.iter().any(|x| x.is_nullable()), - )) + ) + .into()) } fn invoke_with_args( @@ -154,7 +155,7 @@ fn get_map_values_field_as_list_field(map_type: &DataType) -> Result { #[cfg(test)] mod tests { use crate::map_values::MapValuesFunc; - use arrow::datatypes::{DataType, Field}; + use arrow::datatypes::{DataType, Field, FieldRef}; use datafusion_common::ScalarValue; use datafusion_expr::ScalarUDFImpl; use std::sync::Arc; @@ -165,7 +166,7 @@ mod tests { is_map_nullable: bool, is_keys_nullable: bool, is_values_nullable: bool, - ) -> Field { + ) -> FieldRef { Field::new_map( "something", "entries", @@ -178,6 +179,7 @@ mod tests { false, is_map_nullable, ) + .into() } fn get_list_field( @@ -185,7 +187,7 @@ mod tests { is_list_nullable: bool, list_item_type: DataType, is_list_items_nullable: bool, - ) -> Field { + ) -> FieldRef { Field::new_list( name, Arc::new(Field::new_list_field( @@ -194,9 +196,10 @@ mod tests { )), is_list_nullable, ) + .into() } - fn get_return_field(field: Field) -> Field { + fn get_return_field(field: FieldRef) -> FieldRef { let func = MapValuesFunc::new(); let args = datafusion_expr::ReturnFieldArgs { arg_fields: &[field], diff --git a/datafusion/functions-window-common/src/expr.rs b/datafusion/functions-window-common/src/expr.rs index eb2516a1e5569..774cd5182b30b 100644 --- a/datafusion/functions-window-common/src/expr.rs +++ b/datafusion/functions-window-common/src/expr.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. 
-use datafusion_common::arrow::datatypes::Field; +use datafusion_common::arrow::datatypes::FieldRef; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use std::sync::Arc; @@ -27,7 +27,7 @@ pub struct ExpressionArgs<'a> { input_exprs: &'a [Arc], /// The corresponding fields of expressions passed as arguments /// to the user-defined window function. - input_fields: &'a [Field], + input_fields: &'a [FieldRef], } impl<'a> ExpressionArgs<'a> { @@ -42,7 +42,7 @@ impl<'a> ExpressionArgs<'a> { /// pub fn new( input_exprs: &'a [Arc], - input_fields: &'a [Field], + input_fields: &'a [FieldRef], ) -> Self { Self { input_exprs, @@ -56,9 +56,9 @@ impl<'a> ExpressionArgs<'a> { self.input_exprs } - /// Returns the [`Field`]s corresponding to the input expressions + /// Returns the [`FieldRef`]s corresponding to the input expressions /// to the user-defined window function. - pub fn input_fields(&self) -> &'a [Field] { + pub fn input_fields(&self) -> &'a [FieldRef] { self.input_fields } } diff --git a/datafusion/functions-window-common/src/field.rs b/datafusion/functions-window-common/src/field.rs index 9e1898908c957..8d22efa3bcf44 100644 --- a/datafusion/functions-window-common/src/field.rs +++ b/datafusion/functions-window-common/src/field.rs @@ -15,14 +15,14 @@ // specific language governing permissions and limitations // under the License. -use datafusion_common::arrow::datatypes::Field; +use datafusion_common::arrow::datatypes::FieldRef; /// Metadata for defining the result field from evaluating a /// user-defined window function. pub struct WindowUDFFieldArgs<'a> { /// The fields corresponding to the arguments to the /// user-defined window function. - input_fields: &'a [Field], + input_fields: &'a [FieldRef], /// The display name of the user-defined window function. display_name: &'a str, } @@ -37,7 +37,7 @@ impl<'a> WindowUDFFieldArgs<'a> { /// * `function_name` - The qualified schema name of the /// user-defined window function expression. 
/// - pub fn new(input_fields: &'a [Field], display_name: &'a str) -> Self { + pub fn new(input_fields: &'a [FieldRef], display_name: &'a str) -> Self { WindowUDFFieldArgs { input_fields, display_name, @@ -46,7 +46,7 @@ impl<'a> WindowUDFFieldArgs<'a> { /// Returns the field of input expressions passed as arguments /// to the user-defined window function. - pub fn input_fields(&self) -> &[Field] { + pub fn input_fields(&self) -> &[FieldRef] { self.input_fields } @@ -58,7 +58,7 @@ impl<'a> WindowUDFFieldArgs<'a> { /// Returns `Some(Field)` of input expression at index, otherwise /// returns `None` if the index is out of bounds. - pub fn get_input_field(&self, index: usize) -> Option { + pub fn get_input_field(&self, index: usize) -> Option { self.input_fields.get(index).cloned() } } diff --git a/datafusion/functions-window-common/src/partition.rs b/datafusion/functions-window-common/src/partition.rs index 64c28e61a2cde..61125e596130b 100644 --- a/datafusion/functions-window-common/src/partition.rs +++ b/datafusion/functions-window-common/src/partition.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use datafusion_common::arrow::datatypes::Field; +use datafusion_common::arrow::datatypes::FieldRef; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use std::sync::Arc; @@ -28,7 +28,7 @@ pub struct PartitionEvaluatorArgs<'a> { input_exprs: &'a [Arc], /// The corresponding fields of expressions passed as arguments /// to the user-defined window function. - input_fields: &'a [Field], + input_fields: &'a [FieldRef], /// Set to `true` if the user-defined window function is reversed. is_reversed: bool, /// Set to `true` if `IGNORE NULLS` is specified. 
@@ -51,7 +51,7 @@ impl<'a> PartitionEvaluatorArgs<'a> { /// pub fn new( input_exprs: &'a [Arc], - input_fields: &'a [Field], + input_fields: &'a [FieldRef], is_reversed: bool, ignore_nulls: bool, ) -> Self { @@ -69,9 +69,9 @@ impl<'a> PartitionEvaluatorArgs<'a> { self.input_exprs } - /// Returns the [`Field`]s corresponding to the input expressions + /// Returns the [`FieldRef`]s corresponding to the input expressions /// to the user-defined window function. - pub fn input_fields(&self) -> &'a [Field] { + pub fn input_fields(&self) -> &'a [FieldRef] { self.input_fields } diff --git a/datafusion/functions-window/Cargo.toml b/datafusion/functions-window/Cargo.toml index e0c17c579b196..23ee608a82675 100644 --- a/datafusion/functions-window/Cargo.toml +++ b/datafusion/functions-window/Cargo.toml @@ -38,6 +38,7 @@ workspace = true name = "datafusion_functions_window" [dependencies] +arrow = { workspace = true } datafusion-common = { workspace = true } datafusion-doc = { workspace = true } datafusion-expr = { workspace = true } @@ -47,6 +48,3 @@ datafusion-physical-expr = { workspace = true } datafusion-physical-expr-common = { workspace = true } log = { workspace = true } paste = "1.0.15" - -[dev-dependencies] -arrow = { workspace = true } diff --git a/datafusion/functions-window/src/cume_dist.rs b/datafusion/functions-window/src/cume_dist.rs index d156416a82a4b..ed8669948188d 100644 --- a/datafusion/functions-window/src/cume_dist.rs +++ b/datafusion/functions-window/src/cume_dist.rs @@ -17,6 +17,7 @@ //! 
`cume_dist` window function implementation +use arrow::datatypes::FieldRef; use datafusion_common::arrow::array::{ArrayRef, Float64Array}; use datafusion_common::arrow::datatypes::DataType; use datafusion_common::arrow::datatypes::Field; @@ -101,8 +102,8 @@ impl WindowUDFImpl for CumeDist { Ok(Box::::default()) } - fn field(&self, field_args: WindowUDFFieldArgs) -> Result { - Ok(Field::new(field_args.name(), DataType::Float64, false)) + fn field(&self, field_args: WindowUDFFieldArgs) -> Result { + Ok(Field::new(field_args.name(), DataType::Float64, false).into()) } fn documentation(&self) -> Option<&Documentation> { diff --git a/datafusion/functions-window/src/lead_lag.rs b/datafusion/functions-window/src/lead_lag.rs index 6ebbceaced5ea..e2a755371ebc8 100644 --- a/datafusion/functions-window/src/lead_lag.rs +++ b/datafusion/functions-window/src/lead_lag.rs @@ -18,6 +18,7 @@ //! `lead` and `lag` window function implementations use crate::utils::{get_scalar_value_from_args, get_signed_integer}; +use arrow::datatypes::FieldRef; use datafusion_common::arrow::array::ArrayRef; use datafusion_common::arrow::datatypes::DataType; use datafusion_common::arrow::datatypes::Field; @@ -274,10 +275,14 @@ impl WindowUDFImpl for WindowShift { })) } - fn field(&self, field_args: WindowUDFFieldArgs) -> Result { + fn field(&self, field_args: WindowUDFFieldArgs) -> Result { let return_field = parse_expr_field(field_args.input_fields())?; - Ok(return_field.with_name(field_args.name())) + Ok(return_field + .as_ref() + .clone() + .with_name(field_args.name()) + .into()) } fn reverse_expr(&self) -> ReversedUDWF { @@ -309,7 +314,7 @@ impl WindowUDFImpl for WindowShift { /// For more details see: fn parse_expr( input_exprs: &[Arc], - input_fields: &[Field], + input_fields: &[FieldRef], ) -> Result> { assert!(!input_exprs.is_empty()); assert!(!input_fields.is_empty()); @@ -331,31 +336,35 @@ fn parse_expr( }) } -static NULL_FIELD: LazyLock = - LazyLock::new(|| Field::new("value", 
DataType::Null, true)); +static NULL_FIELD: LazyLock = + LazyLock::new(|| Field::new("value", DataType::Null, true).into()); /// Returns the field of the default value(if provided) when the /// expression is `NULL`. /// /// Otherwise, returns the expression field unchanged. -fn parse_expr_field(input_fields: &[Field]) -> Result { +fn parse_expr_field(input_fields: &[FieldRef]) -> Result { assert!(!input_fields.is_empty()); let expr_field = input_fields.first().unwrap_or(&NULL_FIELD); // Handles the most common case where NULL is unexpected if !expr_field.data_type().is_null() { - return Ok(expr_field.clone().with_nullable(true)); + return Ok(expr_field.as_ref().clone().with_nullable(true).into()); } let default_value_field = input_fields.get(2).unwrap_or(&NULL_FIELD); - Ok(default_value_field.clone().with_nullable(true)) + Ok(default_value_field + .as_ref() + .clone() + .with_nullable(true) + .into()) } /// Handles type coercion and null value refinement for default value /// argument depending on the data type of the input expression. 
fn parse_default_value( input_exprs: &[Arc], - input_types: &[Field], + input_types: &[FieldRef], ) -> Result { let expr_field = parse_expr_field(input_types)?; let unparsed = get_scalar_value_from_args(input_exprs, 2)?; @@ -710,7 +719,7 @@ mod tests { WindowShift::lead(), PartitionEvaluatorArgs::new( &[expr], - &[Field::new("f", DataType::Int32, true)], + &[Field::new("f", DataType::Int32, true).into()], false, false, ), @@ -737,7 +746,7 @@ mod tests { WindowShift::lag(), PartitionEvaluatorArgs::new( &[expr], - &[Field::new("f", DataType::Int32, true)], + &[Field::new("f", DataType::Int32, true).into()], false, false, ), @@ -768,6 +777,7 @@ mod tests { let input_fields = [DataType::Int32, DataType::Int32, DataType::Int32] .into_iter() .map(|d| Field::new("f", d, true)) + .map(Arc::new) .collect::>(); test_i32_result( diff --git a/datafusion/functions-window/src/macros.rs b/datafusion/functions-window/src/macros.rs index 27799140931e0..23414a7a7172a 100644 --- a/datafusion/functions-window/src/macros.rs +++ b/datafusion/functions-window/src/macros.rs @@ -40,6 +40,7 @@ /// /// ``` /// # use std::any::Any; +/// use arrow::datatypes::FieldRef; /// # use datafusion_common::arrow::datatypes::{DataType, Field}; /// # use datafusion_expr::{PartitionEvaluator, Signature, Volatility, WindowUDFImpl}; /// # @@ -85,8 +86,8 @@ /// # ) -> datafusion_common::Result> { /// # unimplemented!() /// # } -/// # fn field(&self, field_args: WindowUDFFieldArgs) -> datafusion_common::Result { -/// # Ok(Field::new(field_args.name(), DataType::Int64, false)) +/// # fn field(&self, field_args: WindowUDFFieldArgs) -> datafusion_common::Result { +/// # Ok(Field::new(field_args.name(), DataType::Int64, false).into()) /// # } /// # } /// # @@ -138,6 +139,7 @@ macro_rules! get_or_init_udwf { /// 1. 
With Zero Parameters /// ``` /// # use std::any::Any; +/// use arrow::datatypes::FieldRef; /// # use datafusion_common::arrow::datatypes::{DataType, Field}; /// # use datafusion_expr::{PartitionEvaluator, Signature, Volatility, WindowUDFImpl}; /// # use datafusion_functions_window::{create_udwf_expr, get_or_init_udwf}; @@ -196,8 +198,8 @@ macro_rules! get_or_init_udwf { /// # ) -> datafusion_common::Result> { /// # unimplemented!() /// # } -/// # fn field(&self, field_args: WindowUDFFieldArgs) -> datafusion_common::Result { -/// # Ok(Field::new(field_args.name(), DataType::UInt64, false)) +/// # fn field(&self, field_args: WindowUDFFieldArgs) -> datafusion_common::Result { +/// # Ok(Field::new(field_args.name(), DataType::UInt64, false).into()) /// # } /// # } /// ``` @@ -205,6 +207,7 @@ macro_rules! get_or_init_udwf { /// 2. With Multiple Parameters /// ``` /// # use std::any::Any; +/// use arrow::datatypes::FieldRef; /// # /// # use datafusion_expr::{ /// # PartitionEvaluator, Signature, TypeSignature, Volatility, WindowUDFImpl, @@ -283,12 +286,12 @@ macro_rules! get_or_init_udwf { /// # ) -> datafusion_common::Result> { /// # unimplemented!() /// # } -/// # fn field(&self, field_args: WindowUDFFieldArgs) -> datafusion_common::Result { +/// # fn field(&self, field_args: WindowUDFFieldArgs) -> datafusion_common::Result { /// # Ok(Field::new( /// # field_args.name(), /// # field_args.get_input_field(0).unwrap().data_type().clone(), /// # false, -/// # )) +/// # ).into()) /// # } /// # } /// ``` @@ -352,6 +355,7 @@ macro_rules! create_udwf_expr { /// /// ``` /// # use std::any::Any; +/// use arrow::datatypes::FieldRef; /// # use datafusion_common::arrow::datatypes::{DataType, Field}; /// # use datafusion_expr::{PartitionEvaluator, Signature, Volatility, WindowUDFImpl}; /// # @@ -404,8 +408,8 @@ macro_rules! 
create_udwf_expr { /// # ) -> datafusion_common::Result> { /// # unimplemented!() /// # } -/// # fn field(&self, field_args: WindowUDFFieldArgs) -> datafusion_common::Result { -/// # Ok(Field::new(field_args.name(), DataType::Int64, false)) +/// # fn field(&self, field_args: WindowUDFFieldArgs) -> datafusion_common::Result { +/// # Ok(Field::new(field_args.name(), DataType::Int64, false).into()) /// # } /// # } /// # @@ -415,6 +419,7 @@ macro_rules! create_udwf_expr { /// /// ``` /// # use std::any::Any; +/// use arrow::datatypes::FieldRef; /// # use datafusion_common::arrow::datatypes::{DataType, Field}; /// # use datafusion_expr::{PartitionEvaluator, Signature, Volatility, WindowUDFImpl}; /// # use datafusion_functions_window::{create_udwf_expr, define_udwf_and_expr, get_or_init_udwf}; @@ -468,8 +473,8 @@ macro_rules! create_udwf_expr { /// # ) -> datafusion_common::Result> { /// # unimplemented!() /// # } -/// # fn field(&self, field_args: WindowUDFFieldArgs) -> datafusion_common::Result { -/// # Ok(Field::new(field_args.name(), DataType::UInt64, false)) +/// # fn field(&self, field_args: WindowUDFFieldArgs) -> datafusion_common::Result { +/// # Ok(Field::new(field_args.name(), DataType::UInt64, false).into()) /// # } /// # } /// ``` @@ -479,6 +484,7 @@ macro_rules! create_udwf_expr { /// /// ``` /// # use std::any::Any; +/// use arrow::datatypes::FieldRef; /// # /// # use datafusion_expr::{ /// # PartitionEvaluator, Signature, TypeSignature, Volatility, WindowUDFImpl, @@ -554,12 +560,12 @@ macro_rules! 
create_udwf_expr { /// # ) -> datafusion_common::Result> { /// # unimplemented!() /// # } -/// # fn field(&self, field_args: WindowUDFFieldArgs) -> datafusion_common::Result { +/// # fn field(&self, field_args: WindowUDFFieldArgs) -> datafusion_common::Result { /// # Ok(Field::new( /// # field_args.name(), /// # field_args.get_input_field(0).unwrap().data_type().clone(), /// # false, -/// # )) +/// # ).into()) /// # } /// # } /// ``` @@ -567,6 +573,7 @@ macro_rules! create_udwf_expr { /// /// ``` /// # use std::any::Any; +/// use arrow::datatypes::FieldRef; /// # /// # use datafusion_expr::{ /// # PartitionEvaluator, Signature, TypeSignature, Volatility, WindowUDFImpl, @@ -643,12 +650,12 @@ macro_rules! create_udwf_expr { /// # ) -> datafusion_common::Result> { /// # unimplemented!() /// # } -/// # fn field(&self, field_args: WindowUDFFieldArgs) -> datafusion_common::Result { +/// # fn field(&self, field_args: WindowUDFFieldArgs) -> datafusion_common::Result { /// # Ok(Field::new( /// # field_args.name(), /// # field_args.get_input_field(0).unwrap().data_type().clone(), /// # false, -/// # )) +/// # ).into()) /// # } /// # } /// ``` diff --git a/datafusion/functions-window/src/nth_value.rs b/datafusion/functions-window/src/nth_value.rs index b2ecc87f4be84..0b83e1ff9f084 100644 --- a/datafusion/functions-window/src/nth_value.rs +++ b/datafusion/functions-window/src/nth_value.rs @@ -19,12 +19,7 @@ use crate::utils::{get_scalar_value_from_args, get_signed_integer}; -use std::any::Any; -use std::cmp::Ordering; -use std::fmt::Debug; -use std::ops::Range; -use std::sync::LazyLock; - +use arrow::datatypes::FieldRef; use datafusion_common::arrow::array::ArrayRef; use datafusion_common::arrow::datatypes::{DataType, Field}; use datafusion_common::{exec_datafusion_err, exec_err, Result, ScalarValue}; @@ -37,6 +32,11 @@ use datafusion_expr::{ use datafusion_functions_window_common::field; use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; use 
field::WindowUDFFieldArgs; +use std::any::Any; +use std::cmp::Ordering; +use std::fmt::Debug; +use std::ops::Range; +use std::sync::LazyLock; get_or_init_udwf!( First, @@ -309,7 +309,7 @@ impl WindowUDFImpl for NthValue { })) } - fn field(&self, field_args: WindowUDFFieldArgs) -> Result { + fn field(&self, field_args: WindowUDFFieldArgs) -> Result { let return_type = field_args .input_fields() .first() @@ -317,7 +317,7 @@ impl WindowUDFImpl for NthValue { .cloned() .unwrap_or(DataType::Null); - Ok(Field::new(field_args.name(), return_type, true)) + Ok(Field::new(field_args.name(), return_type, true).into()) } fn reverse_expr(&self) -> ReversedUDWF { @@ -557,7 +557,7 @@ mod tests { NthValue::first(), PartitionEvaluatorArgs::new( &[expr], - &[Field::new("f", DataType::Int32, true)], + &[Field::new("f", DataType::Int32, true).into()], false, false, ), @@ -572,7 +572,7 @@ mod tests { NthValue::last(), PartitionEvaluatorArgs::new( &[expr], - &[Field::new("f", DataType::Int32, true)], + &[Field::new("f", DataType::Int32, true).into()], false, false, ), @@ -599,7 +599,7 @@ mod tests { NthValue::nth(), PartitionEvaluatorArgs::new( &[expr, n_value], - &[Field::new("f", DataType::Int32, true)], + &[Field::new("f", DataType::Int32, true).into()], false, false, ), @@ -618,7 +618,7 @@ mod tests { NthValue::nth(), PartitionEvaluatorArgs::new( &[expr, n_value], - &[Field::new("f", DataType::Int32, true)], + &[Field::new("f", DataType::Int32, true).into()], false, false, ), diff --git a/datafusion/functions-window/src/ntile.rs b/datafusion/functions-window/src/ntile.rs index d2e6fadb002ee..6b4c0960e695c 100644 --- a/datafusion/functions-window/src/ntile.rs +++ b/datafusion/functions-window/src/ntile.rs @@ -17,13 +17,10 @@ //! 
`ntile` window function implementation -use std::any::Any; -use std::fmt::Debug; -use std::sync::Arc; - use crate::utils::{ get_scalar_value_from_args, get_signed_integer, get_unsigned_integer, }; +use arrow::datatypes::FieldRef; use datafusion_common::arrow::array::{ArrayRef, UInt64Array}; use datafusion_common::arrow::datatypes::{DataType, Field}; use datafusion_common::{exec_err, DataFusionError, Result}; @@ -34,6 +31,9 @@ use datafusion_functions_window_common::field; use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; use datafusion_macros::user_doc; use field::WindowUDFFieldArgs; +use std::any::Any; +use std::fmt::Debug; +use std::sync::Arc; get_or_init_udwf!( Ntile, @@ -149,10 +149,10 @@ impl WindowUDFImpl for Ntile { Ok(Box::new(NtileEvaluator { n: n as u64 })) } } - fn field(&self, field_args: WindowUDFFieldArgs) -> Result { + fn field(&self, field_args: WindowUDFFieldArgs) -> Result { let nullable = false; - Ok(Field::new(field_args.name(), DataType::UInt64, nullable)) + Ok(Field::new(field_args.name(), DataType::UInt64, nullable).into()) } fn documentation(&self) -> Option<&Documentation> { diff --git a/datafusion/functions-window/src/rank.rs b/datafusion/functions-window/src/rank.rs index e814a9691f4fe..969a957cddd9c 100644 --- a/datafusion/functions-window/src/rank.rs +++ b/datafusion/functions-window/src/rank.rs @@ -18,13 +18,8 @@ //! Implementation of `rank`, `dense_rank`, and `percent_rank` window functions, //! which can be evaluated at runtime during query execution. 
-use std::any::Any; -use std::fmt::Debug; -use std::iter; -use std::ops::Range; -use std::sync::{Arc, LazyLock}; - use crate::define_udwf_and_expr; +use arrow::datatypes::FieldRef; use datafusion_common::arrow::array::ArrayRef; use datafusion_common::arrow::array::{Float64Array, UInt64Array}; use datafusion_common::arrow::compute::SortOptions; @@ -39,6 +34,11 @@ use datafusion_expr::{ use datafusion_functions_window_common::field; use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; use field::WindowUDFFieldArgs; +use std::any::Any; +use std::fmt::Debug; +use std::iter; +use std::ops::Range; +use std::sync::{Arc, LazyLock}; define_udwf_and_expr!( Rank, @@ -218,14 +218,14 @@ impl WindowUDFImpl for Rank { })) } - fn field(&self, field_args: WindowUDFFieldArgs) -> Result { + fn field(&self, field_args: WindowUDFFieldArgs) -> Result { let return_type = match self.rank_type { RankType::Basic | RankType::Dense => DataType::UInt64, RankType::Percent => DataType::Float64, }; let nullable = false; - Ok(Field::new(field_args.name(), return_type, nullable)) + Ok(Field::new(field_args.name(), return_type, nullable).into()) } fn sort_options(&self) -> Option { diff --git a/datafusion/functions-window/src/row_number.rs b/datafusion/functions-window/src/row_number.rs index 330aed131fb13..ba8627dd86d79 100644 --- a/datafusion/functions-window/src/row_number.rs +++ b/datafusion/functions-window/src/row_number.rs @@ -17,6 +17,7 @@ //! 
`row_number` window function implementation +use arrow::datatypes::FieldRef; use datafusion_common::arrow::array::ArrayRef; use datafusion_common::arrow::array::UInt64Array; use datafusion_common::arrow::compute::SortOptions; @@ -106,8 +107,8 @@ impl WindowUDFImpl for RowNumber { Ok(Box::::default()) } - fn field(&self, field_args: WindowUDFFieldArgs) -> Result { - Ok(Field::new(field_args.name(), DataType::UInt64, false)) + fn field(&self, field_args: WindowUDFFieldArgs) -> Result { + Ok(Field::new(field_args.name(), DataType::UInt64, false).into()) } fn sort_options(&self) -> Option { diff --git a/datafusion/functions/benches/ascii.rs b/datafusion/functions/benches/ascii.rs index 229e888558095..1c7023f4497e6 100644 --- a/datafusion/functions/benches/ascii.rs +++ b/datafusion/functions/benches/ascii.rs @@ -22,6 +22,7 @@ use arrow::datatypes::{DataType, Field}; use criterion::{black_box, criterion_group, criterion_main, Criterion}; use datafusion_expr::ScalarFunctionArgs; use helper::gen_string_array; +use std::sync::Arc; fn criterion_benchmark(c: &mut Criterion) { let ascii = datafusion_functions::string::ascii(); @@ -41,19 +42,20 @@ fn criterion_benchmark(c: &mut Criterion) { UTF8_DENSITY_OF_ALL_ASCII, false, ); + + let arg_fields = + vec![Field::new("a", args_string_ascii[0].data_type(), true).into()]; + let return_field = Field::new("f", DataType::Utf8, true).into(); + c.bench_function( format!("ascii/string_ascii_only (null_density={null_density})").as_str(), |b| { b.iter(|| { black_box(ascii.invoke_with_args(ScalarFunctionArgs { args: args_string_ascii.clone(), - arg_fields: vec![&Field::new( - "a", - args_string_ascii[0].data_type(), - true, - )], + arg_fields: arg_fields.clone(), number_rows: N_ROWS, - return_field: &Field::new("f", DataType::Utf8, true), + return_field: Arc::clone(&return_field), })) }) }, @@ -62,19 +64,18 @@ fn criterion_benchmark(c: &mut Criterion) { // StringArray UTF8 let args_string_utf8 = gen_string_array(N_ROWS, STR_LEN, 
null_density, NORMAL_UTF8_DENSITY, false); + let arg_fields = + vec![Field::new("a", args_string_utf8[0].data_type(), true).into()]; + let return_field = Field::new("f", DataType::Utf8, true).into(); c.bench_function( format!("ascii/string_utf8 (null_density={null_density})").as_str(), |b| { b.iter(|| { black_box(ascii.invoke_with_args(ScalarFunctionArgs { args: args_string_utf8.clone(), - arg_fields: vec![&Field::new( - "a", - args_string_utf8[0].data_type(), - true, - )], + arg_fields: arg_fields.clone(), number_rows: N_ROWS, - return_field: &Field::new("f", DataType::Utf8, true), + return_field: Arc::clone(&return_field), })) }) }, @@ -88,6 +89,9 @@ fn criterion_benchmark(c: &mut Criterion) { UTF8_DENSITY_OF_ALL_ASCII, true, ); + let arg_fields = + vec![Field::new("a", args_string_view_ascii[0].data_type(), true).into()]; + let return_field = Field::new("f", DataType::Utf8, true).into(); c.bench_function( format!("ascii/string_view_ascii_only (null_density={null_density})") .as_str(), @@ -95,13 +99,9 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| { black_box(ascii.invoke_with_args(ScalarFunctionArgs { args: args_string_view_ascii.clone(), - arg_fields: vec![&Field::new( - "a", - args_string_view_ascii[0].data_type(), - true, - )], + arg_fields: arg_fields.clone(), number_rows: N_ROWS, - return_field: &Field::new("f", DataType::Utf8, true), + return_field: Arc::clone(&return_field), })) }) }, @@ -110,19 +110,18 @@ fn criterion_benchmark(c: &mut Criterion) { // StringViewArray UTF8 let args_string_view_utf8 = gen_string_array(N_ROWS, STR_LEN, null_density, NORMAL_UTF8_DENSITY, true); + let arg_fields = + vec![Field::new("a", args_string_view_utf8[0].data_type(), true).into()]; + let return_field = Field::new("f", DataType::Utf8, true).into(); c.bench_function( format!("ascii/string_view_utf8 (null_density={null_density})").as_str(), |b| { b.iter(|| { black_box(ascii.invoke_with_args(ScalarFunctionArgs { args: args_string_view_utf8.clone(), - arg_fields: 
vec![&Field::new( - "a", - args_string_view_utf8[0].data_type(), - true, - )], + arg_fields: arg_fields.clone(), number_rows: N_ROWS, - return_field: &Field::new("f", DataType::Utf8, true), + return_field: Arc::clone(&return_field), })) }) }, diff --git a/datafusion/functions/benches/character_length.rs b/datafusion/functions/benches/character_length.rs index 270ee57f6429f..b4a9e917f4160 100644 --- a/datafusion/functions/benches/character_length.rs +++ b/datafusion/functions/benches/character_length.rs @@ -21,6 +21,7 @@ use arrow::datatypes::{DataType, Field}; use criterion::{black_box, criterion_group, criterion_main, Criterion}; use datafusion_expr::ScalarFunctionArgs; use helper::gen_string_array; +use std::sync::Arc; mod helper; @@ -28,18 +29,19 @@ fn criterion_benchmark(c: &mut Criterion) { // All benches are single batch run with 8192 rows let character_length = datafusion_functions::unicode::character_length(); - let return_field = Field::new("f", DataType::Utf8, true); + let return_field = Arc::new(Field::new("f", DataType::Utf8, true)); let n_rows = 8192; for str_len in [8, 32, 128, 4096] { // StringArray ASCII only let args_string_ascii = gen_string_array(n_rows, str_len, 0.1, 0.0, false); - let arg_fields_owned = args_string_ascii + let arg_fields = args_string_ascii .iter() .enumerate() - .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true)) + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) .collect::>(); - let arg_fields = arg_fields_owned.iter().collect::>(); c.bench_function( &format!("character_length_StringArray_ascii_str_len_{str_len}"), |b| { @@ -48,7 +50,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: args_string_ascii.clone(), arg_fields: arg_fields.clone(), number_rows: n_rows, - return_field: &return_field, + return_field: Arc::clone(&return_field), })) }) }, @@ -56,12 +58,13 @@ fn criterion_benchmark(c: &mut Criterion) { // StringArray UTF8 let args_string_utf8 = 
gen_string_array(n_rows, str_len, 0.1, 0.5, false); - let arg_fields_owned = args_string_utf8 + let arg_fields = args_string_utf8 .iter() .enumerate() - .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true)) + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) .collect::>(); - let arg_fields = arg_fields_owned.iter().collect::>(); c.bench_function( &format!("character_length_StringArray_utf8_str_len_{str_len}"), |b| { @@ -70,7 +73,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: args_string_utf8.clone(), arg_fields: arg_fields.clone(), number_rows: n_rows, - return_field: &return_field, + return_field: Arc::clone(&return_field), })) }) }, @@ -78,12 +81,13 @@ fn criterion_benchmark(c: &mut Criterion) { // StringViewArray ASCII only let args_string_view_ascii = gen_string_array(n_rows, str_len, 0.1, 0.0, true); - let arg_fields_owned = args_string_view_ascii + let arg_fields = args_string_view_ascii .iter() .enumerate() - .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true)) + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) .collect::>(); - let arg_fields = arg_fields_owned.iter().collect::>(); c.bench_function( &format!("character_length_StringViewArray_ascii_str_len_{str_len}"), |b| { @@ -92,7 +96,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: args_string_view_ascii.clone(), arg_fields: arg_fields.clone(), number_rows: n_rows, - return_field: &return_field, + return_field: Arc::clone(&return_field), })) }) }, @@ -100,12 +104,13 @@ fn criterion_benchmark(c: &mut Criterion) { // StringViewArray UTF8 let args_string_view_utf8 = gen_string_array(n_rows, str_len, 0.1, 0.5, true); - let arg_fields_owned = args_string_view_utf8 + let arg_fields = args_string_view_utf8 .iter() .enumerate() - .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true)) + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), 
true).into() + }) .collect::>(); - let arg_fields = arg_fields_owned.iter().collect::>(); c.bench_function( &format!("character_length_StringViewArray_utf8_str_len_{str_len}"), |b| { @@ -114,7 +119,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: args_string_view_utf8.clone(), arg_fields: arg_fields.clone(), number_rows: n_rows, - return_field: &return_field, + return_field: Arc::clone(&return_field), })) }) }, diff --git a/datafusion/functions/benches/chr.rs b/datafusion/functions/benches/chr.rs index 2236424542f84..6a956bb788127 100644 --- a/datafusion/functions/benches/chr.rs +++ b/datafusion/functions/benches/chr.rs @@ -50,12 +50,11 @@ fn criterion_benchmark(c: &mut Criterion) { }; let input = Arc::new(input); let args = vec![ColumnarValue::Array(input)]; - let arg_fields_owned = args + let arg_fields = args .iter() .enumerate() - .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true)) + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true).into()) .collect::>(); - let arg_fields = arg_fields_owned.iter().collect::>(); c.bench_function("chr", |b| { b.iter(|| { @@ -65,7 +64,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: args.clone(), arg_fields: arg_fields.clone(), number_rows: size, - return_field: &Field::new("f", DataType::Utf8, true), + return_field: Field::new("f", DataType::Utf8, true).into(), }) .unwrap(), ) diff --git a/datafusion/functions/benches/concat.rs b/datafusion/functions/benches/concat.rs index ea6a9c2eaf638..d350c03c497bb 100644 --- a/datafusion/functions/benches/concat.rs +++ b/datafusion/functions/benches/concat.rs @@ -37,12 +37,13 @@ fn create_args(size: usize, str_len: usize) -> Vec { fn criterion_benchmark(c: &mut Criterion) { for size in [1024, 4096, 8192] { let args = create_args(size, 32); - let arg_fields_owned = args + let arg_fields = args .iter() .enumerate() - .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true)) + .map(|(idx, arg)| { + 
Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) .collect::>(); - let arg_fields = arg_fields_owned.iter().collect::>(); let mut group = c.benchmark_group("concat function"); group.bench_function(BenchmarkId::new("concat", size), |b| { @@ -54,7 +55,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: args_cloned, arg_fields: arg_fields.clone(), number_rows: size, - return_field: &Field::new("f", DataType::Utf8, true), + return_field: Field::new("f", DataType::Utf8, true).into(), }) .unwrap(), ) diff --git a/datafusion/functions/benches/cot.rs b/datafusion/functions/benches/cot.rs index 634768b3f0303..a32e0d834672c 100644 --- a/datafusion/functions/benches/cot.rs +++ b/datafusion/functions/benches/cot.rs @@ -33,12 +33,13 @@ fn criterion_benchmark(c: &mut Criterion) { for size in [1024, 4096, 8192] { let f32_array = Arc::new(create_primitive_array::(size, 0.2)); let f32_args = vec![ColumnarValue::Array(f32_array)]; - let arg_fields_owned = f32_args + let arg_fields = f32_args .iter() .enumerate() - .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true)) + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) .collect::>(); - let arg_fields = arg_fields_owned.iter().collect::>(); c.bench_function(&format!("cot f32 array: {size}"), |b| { b.iter(|| { @@ -48,7 +49,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: f32_args.clone(), arg_fields: arg_fields.clone(), number_rows: size, - return_field: &Field::new("f", DataType::Float32, true), + return_field: Field::new("f", DataType::Float32, true).into(), }) .unwrap(), ) @@ -56,12 +57,14 @@ fn criterion_benchmark(c: &mut Criterion) { }); let f64_array = Arc::new(create_primitive_array::(size, 0.2)); let f64_args = vec![ColumnarValue::Array(f64_array)]; - let arg_fields_owned = f64_args + let arg_fields = f64_args .iter() .enumerate() - .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true)) + .map(|(idx, arg)| { + 
Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) .collect::>(); - let arg_fields = arg_fields_owned.iter().collect::>(); + let return_field = Arc::new(Field::new("f", DataType::Float64, true)); c.bench_function(&format!("cot f64 array: {size}"), |b| { b.iter(|| { @@ -71,7 +74,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: f64_args.clone(), arg_fields: arg_fields.clone(), number_rows: size, - return_field: &Field::new("f", DataType::Float64, true), + return_field: Arc::clone(&return_field), }) .unwrap(), ) diff --git a/datafusion/functions/benches/date_bin.rs b/datafusion/functions/benches/date_bin.rs index 7cc1989dc8f81..ac766a002576c 100644 --- a/datafusion/functions/benches/date_bin.rs +++ b/datafusion/functions/benches/date_bin.rs @@ -49,18 +49,19 @@ fn criterion_benchmark(c: &mut Criterion) { let return_type = udf .return_type(&[interval.data_type(), timestamps.data_type()]) .unwrap(); - let return_field = Field::new("f", return_type, true); + let return_field = Arc::new(Field::new("f", return_type, true)); + let arg_fields = vec![ + Field::new("a", interval.data_type(), true).into(), + Field::new("b", timestamps.data_type(), true).into(), + ]; b.iter(|| { black_box( udf.invoke_with_args(ScalarFunctionArgs { args: vec![interval.clone(), timestamps.clone()], - arg_fields: vec![ - &Field::new("a", interval.data_type(), true), - &Field::new("b", timestamps.data_type(), true), - ], + arg_fields: arg_fields.clone(), number_rows: batch_len, - return_field: &return_field, + return_field: Arc::clone(&return_field), }) .expect("date_bin should work on valid values"), ) diff --git a/datafusion/functions/benches/date_trunc.rs b/datafusion/functions/benches/date_trunc.rs index e40b7a0ad5e14..ad4d0d0fbb796 100644 --- a/datafusion/functions/benches/date_trunc.rs +++ b/datafusion/functions/benches/date_trunc.rs @@ -48,24 +48,25 @@ fn criterion_benchmark(c: &mut Criterion) { let timestamps = ColumnarValue::Array(timestamps_array); let udf = 
date_trunc(); let args = vec![precision, timestamps]; - let arg_fields_owned = args + let arg_fields = args .iter() .enumerate() - .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true)) + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) .collect::>(); - let arg_fields = arg_fields_owned.iter().collect::>(); let return_type = udf .return_type(&args.iter().map(|arg| arg.data_type()).collect::>()) .unwrap(); - let return_field = Field::new("f", return_type, true); + let return_field = Arc::new(Field::new("f", return_type, true)); b.iter(|| { black_box( udf.invoke_with_args(ScalarFunctionArgs { args: args.clone(), arg_fields: arg_fields.clone(), number_rows: batch_len, - return_field: &return_field, + return_field: Arc::clone(&return_field), }) .expect("date_trunc should work on valid values"), ) diff --git a/datafusion/functions/benches/encoding.rs b/datafusion/functions/benches/encoding.rs index 8210de82ef49c..830e0324766f7 100644 --- a/datafusion/functions/benches/encoding.rs +++ b/datafusion/functions/benches/encoding.rs @@ -35,17 +35,17 @@ fn criterion_benchmark(c: &mut Criterion) { .invoke_with_args(ScalarFunctionArgs { args: vec![ColumnarValue::Array(str_array.clone()), method.clone()], arg_fields: vec![ - &Field::new("a", str_array.data_type().to_owned(), true), - &Field::new("b", method.data_type().to_owned(), true), + Field::new("a", str_array.data_type().to_owned(), true).into(), + Field::new("b", method.data_type().to_owned(), true).into(), ], number_rows: size, - return_field: &Field::new("f", DataType::Utf8, true), + return_field: Field::new("f", DataType::Utf8, true).into(), }) .unwrap(); let arg_fields = vec![ - Field::new("a", encoded.data_type().to_owned(), true), - Field::new("b", method.data_type().to_owned(), true), + Field::new("a", encoded.data_type().to_owned(), true).into(), + Field::new("b", method.data_type().to_owned(), true).into(), ]; let args = vec![encoded, method]; @@ -54,9 
+54,9 @@ fn criterion_benchmark(c: &mut Criterion) { decode .invoke_with_args(ScalarFunctionArgs { args: args.clone(), - arg_fields: arg_fields.iter().collect(), + arg_fields: arg_fields.clone(), number_rows: size, - return_field: &Field::new("f", DataType::Utf8, true), + return_field: Field::new("f", DataType::Utf8, true).into(), }) .unwrap(), ) @@ -66,22 +66,23 @@ fn criterion_benchmark(c: &mut Criterion) { c.bench_function(&format!("hex_decode/{size}"), |b| { let method = ColumnarValue::Scalar("hex".into()); let arg_fields = vec![ - Field::new("a", str_array.data_type().to_owned(), true), - Field::new("b", method.data_type().to_owned(), true), + Field::new("a", str_array.data_type().to_owned(), true).into(), + Field::new("b", method.data_type().to_owned(), true).into(), ]; let encoded = encoding::encode() .invoke_with_args(ScalarFunctionArgs { args: vec![ColumnarValue::Array(str_array.clone()), method.clone()], - arg_fields: arg_fields.iter().collect(), + arg_fields, number_rows: size, - return_field: &Field::new("f", DataType::Utf8, true), + return_field: Field::new("f", DataType::Utf8, true).into(), }) .unwrap(); let arg_fields = vec![ - Field::new("a", encoded.data_type().to_owned(), true), - Field::new("b", method.data_type().to_owned(), true), + Field::new("a", encoded.data_type().to_owned(), true).into(), + Field::new("b", method.data_type().to_owned(), true).into(), ]; + let return_field = Field::new("f", DataType::Utf8, true).into(); let args = vec![encoded, method]; b.iter(|| { @@ -89,9 +90,9 @@ fn criterion_benchmark(c: &mut Criterion) { decode .invoke_with_args(ScalarFunctionArgs { args: args.clone(), - arg_fields: arg_fields.iter().collect(), + arg_fields: arg_fields.clone(), number_rows: size, - return_field: &Field::new("f", DataType::Utf8, true), + return_field: Arc::clone(&return_field), }) .unwrap(), ) diff --git a/datafusion/functions/benches/find_in_set.rs b/datafusion/functions/benches/find_in_set.rs index d1cb0f7025b37..bad540f049e28 100644 
--- a/datafusion/functions/benches/find_in_set.rs +++ b/datafusion/functions/benches/find_in_set.rs @@ -153,35 +153,35 @@ fn criterion_benchmark(c: &mut Criterion) { group.measurement_time(Duration::from_secs(10)); let args = gen_args_array(n_rows, str_len, 0.1, 0.5, false); - let arg_fields_owned = args + let arg_fields = args .iter() - .map(|arg| Field::new("a", arg.data_type().clone(), true)) + .map(|arg| Field::new("a", arg.data_type().clone(), true).into()) .collect::>(); - let arg_fields = arg_fields_owned.iter().collect::>(); + let return_field = Field::new("f", DataType::Int32, true).into(); group.bench_function(format!("string_len_{str_len}"), |b| { b.iter(|| { black_box(find_in_set.invoke_with_args(ScalarFunctionArgs { args: args.clone(), arg_fields: arg_fields.clone(), number_rows: n_rows, - return_field: &Field::new("f", DataType::Int32, true), + return_field: Arc::clone(&return_field), })) }) }); let args = gen_args_array(n_rows, str_len, 0.1, 0.5, true); - let arg_fields_owned = args + let arg_fields = args .iter() - .map(|arg| Field::new("a", arg.data_type().clone(), true)) + .map(|arg| Field::new("a", arg.data_type().clone(), true).into()) .collect::>(); - let arg_fields = arg_fields_owned.iter().collect::>(); + let return_field = Arc::new(Field::new("f", DataType::Int32, true)); group.bench_function(format!("string_view_len_{str_len}"), |b| { b.iter(|| { black_box(find_in_set.invoke_with_args(ScalarFunctionArgs { args: args.clone(), arg_fields: arg_fields.clone(), number_rows: n_rows, - return_field: &Field::new("f", DataType::Int32, true), + return_field: Arc::clone(&return_field), })) }) }); @@ -191,35 +191,35 @@ fn criterion_benchmark(c: &mut Criterion) { let mut group = c.benchmark_group("find_in_set_scalar"); let args = gen_args_scalar(n_rows, str_len, 0.1, false); - let arg_fields_owned = args + let arg_fields = args .iter() - .map(|arg| Field::new("a", arg.data_type().clone(), true)) + .map(|arg| Field::new("a", arg.data_type().clone(), 
true).into()) .collect::>(); - let arg_fields = arg_fields_owned.iter().collect::>(); + let return_field = Arc::new(Field::new("f", DataType::Int32, true)); group.bench_function(format!("string_len_{str_len}"), |b| { b.iter(|| { black_box(find_in_set.invoke_with_args(ScalarFunctionArgs { args: args.clone(), arg_fields: arg_fields.clone(), number_rows: n_rows, - return_field: &Field::new("f", DataType::Int32, true), + return_field: Arc::clone(&return_field), })) }) }); let args = gen_args_scalar(n_rows, str_len, 0.1, true); - let arg_fields_owned = args + let arg_fields = args .iter() - .map(|arg| Field::new("a", arg.data_type().clone(), true)) + .map(|arg| Field::new("a", arg.data_type().clone(), true).into()) .collect::>(); - let arg_fields = arg_fields_owned.iter().collect::>(); + let return_field = Arc::new(Field::new("f", DataType::Int32, true)); group.bench_function(format!("string_view_len_{str_len}"), |b| { b.iter(|| { black_box(find_in_set.invoke_with_args(ScalarFunctionArgs { args: args.clone(), arg_fields: arg_fields.clone(), number_rows: n_rows, - return_field: &Field::new("f", DataType::Int32, true), + return_field: Arc::clone(&return_field), })) }) }); diff --git a/datafusion/functions/benches/gcd.rs b/datafusion/functions/benches/gcd.rs index 44a6e35931772..f700d31123a9d 100644 --- a/datafusion/functions/benches/gcd.rs +++ b/datafusion/functions/benches/gcd.rs @@ -49,11 +49,11 @@ fn criterion_benchmark(c: &mut Criterion) { udf.invoke_with_args(ScalarFunctionArgs { args: vec![array_a.clone(), array_b.clone()], arg_fields: vec![ - &Field::new("a", array_a.data_type(), true), - &Field::new("b", array_b.data_type(), true), + Field::new("a", array_a.data_type(), true).into(), + Field::new("b", array_b.data_type(), true).into(), ], number_rows: 0, - return_field: &Field::new("f", DataType::Int64, true), + return_field: Field::new("f", DataType::Int64, true).into(), }) .expect("date_bin should work on valid values"), ) @@ -69,11 +69,11 @@ fn 
criterion_benchmark(c: &mut Criterion) { udf.invoke_with_args(ScalarFunctionArgs { args: vec![array_a.clone(), scalar_b.clone()], arg_fields: vec![ - &Field::new("a", array_a.data_type(), true), - &Field::new("b", scalar_b.data_type(), true), + Field::new("a", array_a.data_type(), true).into(), + Field::new("b", scalar_b.data_type(), true).into(), ], number_rows: 0, - return_field: &Field::new("f", DataType::Int64, true), + return_field: Field::new("f", DataType::Int64, true).into(), }) .expect("date_bin should work on valid values"), ) @@ -89,11 +89,11 @@ fn criterion_benchmark(c: &mut Criterion) { udf.invoke_with_args(ScalarFunctionArgs { args: vec![scalar_a.clone(), scalar_b.clone()], arg_fields: vec![ - &Field::new("a", scalar_a.data_type(), true), - &Field::new("b", scalar_b.data_type(), true), + Field::new("a", scalar_a.data_type(), true).into(), + Field::new("b", scalar_b.data_type(), true).into(), ], number_rows: 0, - return_field: &Field::new("f", DataType::Int64, true), + return_field: Field::new("f", DataType::Int64, true).into(), }) .expect("date_bin should work on valid values"), ) diff --git a/datafusion/functions/benches/initcap.rs b/datafusion/functions/benches/initcap.rs index fbec2c4b28489..f89b11dff8fbe 100644 --- a/datafusion/functions/benches/initcap.rs +++ b/datafusion/functions/benches/initcap.rs @@ -49,12 +49,13 @@ fn criterion_benchmark(c: &mut Criterion) { let initcap = unicode::initcap(); for size in [1024, 4096] { let args = create_args::(size, 8, true); - let arg_fields_owned = args + let arg_fields = args .iter() .enumerate() - .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true)) + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) .collect::>(); - let arg_fields = arg_fields_owned.iter().collect::>(); c.bench_function( format!("initcap string view shorter than 12 [size={size}]").as_str(), @@ -64,7 +65,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: args.clone(), 
arg_fields: arg_fields.clone(), number_rows: size, - return_field: &Field::new("f", DataType::Utf8View, true), + return_field: Field::new("f", DataType::Utf8View, true).into(), })) }) }, @@ -79,7 +80,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: args.clone(), arg_fields: arg_fields.clone(), number_rows: size, - return_field: &Field::new("f", DataType::Utf8View, true), + return_field: Field::new("f", DataType::Utf8View, true).into(), })) }) }, @@ -92,7 +93,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: args.clone(), arg_fields: arg_fields.clone(), number_rows: size, - return_field: &Field::new("f", DataType::Utf8, true), + return_field: Field::new("f", DataType::Utf8, true).into(), })) }) }); diff --git a/datafusion/functions/benches/isnan.rs b/datafusion/functions/benches/isnan.rs index abbf8ef789a62..49d0a9e326dd7 100644 --- a/datafusion/functions/benches/isnan.rs +++ b/datafusion/functions/benches/isnan.rs @@ -32,12 +32,13 @@ fn criterion_benchmark(c: &mut Criterion) { for size in [1024, 4096, 8192] { let f32_array = Arc::new(create_primitive_array::(size, 0.2)); let f32_args = vec![ColumnarValue::Array(f32_array)]; - let arg_fields_owned = f32_args + let arg_fields = f32_args .iter() .enumerate() - .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true)) + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) .collect::>(); - let arg_fields = arg_fields_owned.iter().collect::>(); c.bench_function(&format!("isnan f32 array: {size}"), |b| { b.iter(|| { @@ -47,7 +48,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: f32_args.clone(), arg_fields: arg_fields.clone(), number_rows: size, - return_field: &Field::new("f", DataType::Boolean, true), + return_field: Field::new("f", DataType::Boolean, true).into(), }) .unwrap(), ) @@ -55,12 +56,13 @@ fn criterion_benchmark(c: &mut Criterion) { }); let f64_array = Arc::new(create_primitive_array::(size, 0.2)); let f64_args = 
vec![ColumnarValue::Array(f64_array)]; - let arg_fields_owned = f64_args + let arg_fields = f64_args .iter() .enumerate() - .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true)) + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) .collect::>(); - let arg_fields = arg_fields_owned.iter().collect::>(); c.bench_function(&format!("isnan f64 array: {size}"), |b| { b.iter(|| { black_box( @@ -69,7 +71,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: f64_args.clone(), arg_fields: arg_fields.clone(), number_rows: size, - return_field: &Field::new("f", DataType::Boolean, true), + return_field: Field::new("f", DataType::Boolean, true).into(), }) .unwrap(), ) diff --git a/datafusion/functions/benches/iszero.rs b/datafusion/functions/benches/iszero.rs index b35099738de55..6d1d34c7a8320 100644 --- a/datafusion/functions/benches/iszero.rs +++ b/datafusion/functions/benches/iszero.rs @@ -33,12 +33,14 @@ fn criterion_benchmark(c: &mut Criterion) { let f32_array = Arc::new(create_primitive_array::(size, 0.2)); let batch_len = f32_array.len(); let f32_args = vec![ColumnarValue::Array(f32_array)]; - let arg_fields_owned = f32_args + let arg_fields = f32_args .iter() .enumerate() - .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true)) + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) .collect::>(); - let arg_fields = arg_fields_owned.iter().collect::>(); + let return_field = Arc::new(Field::new("f", DataType::Boolean, true)); c.bench_function(&format!("iszero f32 array: {size}"), |b| { b.iter(|| { @@ -48,7 +50,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: f32_args.clone(), arg_fields: arg_fields.clone(), number_rows: batch_len, - return_field: &Field::new("f", DataType::Boolean, true), + return_field: Arc::clone(&return_field), }) .unwrap(), ) @@ -57,12 +59,14 @@ fn criterion_benchmark(c: &mut Criterion) { let f64_array = 
Arc::new(create_primitive_array::(size, 0.2)); let batch_len = f64_array.len(); let f64_args = vec![ColumnarValue::Array(f64_array)]; - let arg_fields_owned = f64_args + let arg_fields = f64_args .iter() .enumerate() - .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true)) + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) .collect::>(); - let arg_fields = arg_fields_owned.iter().collect::>(); + let return_field = Arc::new(Field::new("f", DataType::Boolean, true)); c.bench_function(&format!("iszero f64 array: {size}"), |b| { b.iter(|| { @@ -72,7 +76,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: f64_args.clone(), arg_fields: arg_fields.clone(), number_rows: batch_len, - return_field: &Field::new("f", DataType::Boolean, true), + return_field: Arc::clone(&return_field), }) .unwrap(), ) diff --git a/datafusion/functions/benches/lower.rs b/datafusion/functions/benches/lower.rs index 5a21fbf52e20e..cdf1529c108c0 100644 --- a/datafusion/functions/benches/lower.rs +++ b/datafusion/functions/benches/lower.rs @@ -124,12 +124,13 @@ fn criterion_benchmark(c: &mut Criterion) { let lower = string::lower(); for size in [1024, 4096, 8192] { let args = create_args1(size, 32); - let arg_fields_owned = args + let arg_fields = args .iter() .enumerate() - .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true)) + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) .collect::>(); - let arg_fields = arg_fields_owned.iter().collect::>(); c.bench_function(&format!("lower_all_values_are_ascii: {size}"), |b| { b.iter(|| { @@ -138,18 +139,19 @@ fn criterion_benchmark(c: &mut Criterion) { args: args_cloned, arg_fields: arg_fields.clone(), number_rows: size, - return_field: &Field::new("f", DataType::Utf8, true), + return_field: Field::new("f", DataType::Utf8, true).into(), })) }) }); let args = create_args2(size); - let arg_fields_owned = args + let arg_fields = args 
.iter() .enumerate() - .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true)) + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) .collect::>(); - let arg_fields = arg_fields_owned.iter().collect::>(); c.bench_function(&format!("lower_the_first_value_is_nonascii: {size}"), |b| { b.iter(|| { @@ -158,18 +160,19 @@ fn criterion_benchmark(c: &mut Criterion) { args: args_cloned, arg_fields: arg_fields.clone(), number_rows: size, - return_field: &Field::new("f", DataType::Utf8, true), + return_field: Field::new("f", DataType::Utf8, true).into(), })) }) }); let args = create_args3(size); - let arg_fields_owned = args + let arg_fields = args .iter() .enumerate() - .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true)) + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) .collect::>(); - let arg_fields = arg_fields_owned.iter().collect::>(); c.bench_function( &format!("lower_the_middle_value_is_nonascii: {size}"), @@ -180,7 +183,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: args_cloned, arg_fields: arg_fields.clone(), number_rows: size, - return_field: &Field::new("f", DataType::Utf8, true), + return_field: Field::new("f", DataType::Utf8, true).into(), })) }) }, @@ -197,14 +200,13 @@ fn criterion_benchmark(c: &mut Criterion) { for &str_len in &str_lens { for &size in &sizes { let args = create_args4(size, str_len, *null_density, mixed); - let arg_fields_owned = args + let arg_fields = args .iter() .enumerate() .map(|(idx, arg)| { - Field::new(format!("arg_{idx}"), arg.data_type(), true) + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() }) .collect::>(); - let arg_fields = arg_fields_owned.iter().collect::>(); c.bench_function( &format!("lower_all_values_are_ascii_string_views: size: {size}, str_len: {str_len}, null_density: {null_density}, mixed: {mixed}"), @@ -214,7 +216,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: 
args_cloned, arg_fields: arg_fields.clone(), number_rows: size, - return_field: &Field::new("f", DataType::Utf8, true), + return_field: Field::new("f", DataType::Utf8, true).into(), })) }), ); @@ -228,7 +230,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: args_cloned, arg_fields: arg_fields.clone(), number_rows: size, - return_field: &Field::new("f", DataType::Utf8, true), + return_field: Field::new("f", DataType::Utf8, true).into(), })) }), ); @@ -243,7 +245,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: args_cloned, arg_fields: arg_fields.clone(), number_rows: size, - return_field: &Field::new("f", DataType::Utf8, true), + return_field: Field::new("f", DataType::Utf8, true).into(), })) }), ); diff --git a/datafusion/functions/benches/ltrim.rs b/datafusion/functions/benches/ltrim.rs index 6b2b7625ca73f..7a44f40a689a4 100644 --- a/datafusion/functions/benches/ltrim.rs +++ b/datafusion/functions/benches/ltrim.rs @@ -132,12 +132,11 @@ fn run_with_string_type( string_type: StringArrayType, ) { let args = create_args(size, characters, trimmed, remaining_len, string_type); - let arg_fields_owned = args + let arg_fields = args .iter() .enumerate() - .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true)) + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true).into()) .collect::>(); - let arg_fields = arg_fields_owned.iter().collect::>(); group.bench_function( format!( "{string_type} [size={size}, len_before={len}, len_after={remaining_len}]", @@ -149,7 +148,7 @@ fn run_with_string_type( args: args_cloned, arg_fields: arg_fields.clone(), number_rows: size, - return_field: &Field::new("f", DataType::Utf8, true), + return_field: Field::new("f", DataType::Utf8, true).into(), })) }) }, diff --git a/datafusion/functions/benches/make_date.rs b/datafusion/functions/benches/make_date.rs index 18d4e24cc7fc6..e1f609fbb35c0 100644 --- a/datafusion/functions/benches/make_date.rs +++ b/datafusion/functions/benches/make_date.rs @@ 
-63,19 +63,21 @@ fn criterion_benchmark(c: &mut Criterion) { let years = ColumnarValue::Array(years_array); let months = ColumnarValue::Array(Arc::new(months(&mut rng)) as ArrayRef); let days = ColumnarValue::Array(Arc::new(days(&mut rng)) as ArrayRef); + let arg_fields = vec![ + Field::new("a", years.data_type(), true).into(), + Field::new("a", months.data_type(), true).into(), + Field::new("a", days.data_type(), true).into(), + ]; + let return_field = Field::new("f", DataType::Date32, true).into(); b.iter(|| { black_box( make_date() .invoke_with_args(ScalarFunctionArgs { args: vec![years.clone(), months.clone(), days.clone()], - arg_fields: vec![ - &Field::new("a", years.data_type(), true), - &Field::new("a", months.data_type(), true), - &Field::new("a", days.data_type(), true), - ], + arg_fields: arg_fields.clone(), number_rows: batch_len, - return_field: &Field::new("f", DataType::Date32, true), + return_field: Arc::clone(&return_field), }) .expect("make_date should work on valid values"), ) @@ -89,19 +91,20 @@ fn criterion_benchmark(c: &mut Criterion) { let batch_len = months_arr.len(); let months = ColumnarValue::Array(months_arr); let days = ColumnarValue::Array(Arc::new(days(&mut rng)) as ArrayRef); - + let arg_fields = vec![ + Field::new("a", year.data_type(), true).into(), + Field::new("a", months.data_type(), true).into(), + Field::new("a", days.data_type(), true).into(), + ]; + let return_field = Field::new("f", DataType::Date32, true).into(); b.iter(|| { black_box( make_date() .invoke_with_args(ScalarFunctionArgs { args: vec![year.clone(), months.clone(), days.clone()], - arg_fields: vec![ - &Field::new("a", year.data_type(), true), - &Field::new("a", months.data_type(), true), - &Field::new("a", days.data_type(), true), - ], + arg_fields: arg_fields.clone(), number_rows: batch_len, - return_field: &Field::new("f", DataType::Date32, true), + return_field: Arc::clone(&return_field), }) .expect("make_date should work on valid values"), ) @@ -115,19 
+118,20 @@ fn criterion_benchmark(c: &mut Criterion) { let day_arr = Arc::new(days(&mut rng)); let batch_len = day_arr.len(); let days = ColumnarValue::Array(day_arr); - + let arg_fields = vec![ + Field::new("a", year.data_type(), true).into(), + Field::new("a", month.data_type(), true).into(), + Field::new("a", days.data_type(), true).into(), + ]; + let return_field = Field::new("f", DataType::Date32, true).into(); b.iter(|| { black_box( make_date() .invoke_with_args(ScalarFunctionArgs { args: vec![year.clone(), month.clone(), days.clone()], - arg_fields: vec![ - &Field::new("a", year.data_type(), true), - &Field::new("a", month.data_type(), true), - &Field::new("a", days.data_type(), true), - ], + arg_fields: arg_fields.clone(), number_rows: batch_len, - return_field: &Field::new("f", DataType::Date32, true), + return_field: Arc::clone(&return_field), }) .expect("make_date should work on valid values"), ) @@ -138,19 +142,21 @@ fn criterion_benchmark(c: &mut Criterion) { let year = ColumnarValue::Scalar(ScalarValue::Int32(Some(2025))); let month = ColumnarValue::Scalar(ScalarValue::Int32(Some(11))); let day = ColumnarValue::Scalar(ScalarValue::Int32(Some(26))); + let arg_fields = vec![ + Field::new("a", year.data_type(), true).into(), + Field::new("a", month.data_type(), true).into(), + Field::new("a", day.data_type(), true).into(), + ]; + let return_field = Field::new("f", DataType::Date32, true).into(); b.iter(|| { black_box( make_date() .invoke_with_args(ScalarFunctionArgs { args: vec![year.clone(), month.clone(), day.clone()], - arg_fields: vec![ - &Field::new("a", year.data_type(), true), - &Field::new("a", month.data_type(), true), - &Field::new("a", day.data_type(), true), - ], + arg_fields: arg_fields.clone(), number_rows: 1, - return_field: &Field::new("f", DataType::Date32, true), + return_field: Arc::clone(&return_field), }) .expect("make_date should work on valid values"), ) diff --git a/datafusion/functions/benches/nullif.rs 
b/datafusion/functions/benches/nullif.rs index f754d0cbb9e39..4ac977af9d428 100644 --- a/datafusion/functions/benches/nullif.rs +++ b/datafusion/functions/benches/nullif.rs @@ -33,12 +33,13 @@ fn criterion_benchmark(c: &mut Criterion) { ColumnarValue::Scalar(ScalarValue::Utf8(Some("abcd".to_string()))), ColumnarValue::Array(array), ]; - let arg_fields_owned = args + let arg_fields = args .iter() .enumerate() - .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true)) + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) .collect::>(); - let arg_fields = arg_fields_owned.iter().collect::>(); c.bench_function(&format!("nullif scalar array: {size}"), |b| { b.iter(|| { @@ -48,7 +49,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: args.clone(), arg_fields: arg_fields.clone(), number_rows: size, - return_field: &Field::new("f", DataType::Utf8, true), + return_field: Field::new("f", DataType::Utf8, true).into(), }) .unwrap(), ) diff --git a/datafusion/functions/benches/pad.rs b/datafusion/functions/benches/pad.rs index 87ff345443b63..d954ff452ed56 100644 --- a/datafusion/functions/benches/pad.rs +++ b/datafusion/functions/benches/pad.rs @@ -101,18 +101,17 @@ fn invoke_pad_with_args( number_rows: usize, left_pad: bool, ) -> Result { - let arg_fields_owned = args + let arg_fields = args .iter() .enumerate() - .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true)) + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true).into()) .collect::>(); - let arg_fields = arg_fields_owned.iter().collect::>(); let scalar_args = ScalarFunctionArgs { args: args.clone(), arg_fields, number_rows, - return_field: &Field::new("f", DataType::Utf8, true), + return_field: Field::new("f", DataType::Utf8, true).into(), }; if left_pad { diff --git a/datafusion/functions/benches/random.rs b/datafusion/functions/benches/random.rs index 5a80bbd1f11b4..dc1e280b93b13 100644 --- 
a/datafusion/functions/benches/random.rs +++ b/datafusion/functions/benches/random.rs @@ -21,10 +21,12 @@ use arrow::datatypes::{DataType, Field}; use criterion::{black_box, criterion_group, criterion_main, Criterion}; use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl}; use datafusion_functions::math::random::RandomFunc; +use std::sync::Arc; fn criterion_benchmark(c: &mut Criterion) { let random_func = RandomFunc::new(); + let return_field = Field::new("f", DataType::Float64, true).into(); // Benchmark to evaluate 1M rows in batch size 8192 let iterations = 1_000_000 / 8192; // Calculate how many iterations are needed to reach approximately 1M rows c.bench_function("random_1M_rows_batch_8192", |b| { @@ -36,7 +38,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: vec![], arg_fields: vec![], number_rows: 8192, - return_field: &Field::new("f", DataType::Float64, true), + return_field: Arc::clone(&return_field), }) .unwrap(), ); @@ -44,6 +46,7 @@ fn criterion_benchmark(c: &mut Criterion) { }) }); + let return_field = Field::new("f", DataType::Float64, true).into(); // Benchmark to evaluate 1M rows in batch size 128 let iterations_128 = 1_000_000 / 128; // Calculate how many iterations are needed to reach approximately 1M rows with batch size 128 c.bench_function("random_1M_rows_batch_128", |b| { @@ -55,7 +58,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: vec![], arg_fields: vec![], number_rows: 128, - return_field: &Field::new("f", DataType::Float64, true), + return_field: Arc::clone(&return_field), }) .unwrap(), ); diff --git a/datafusion/functions/benches/repeat.rs b/datafusion/functions/benches/repeat.rs index 310e934f3a559..175933f5f745f 100644 --- a/datafusion/functions/benches/repeat.rs +++ b/datafusion/functions/benches/repeat.rs @@ -61,18 +61,17 @@ fn invoke_repeat_with_args( args: Vec, repeat_times: i64, ) -> Result { - let arg_fields_owned = args + let arg_fields = args .iter() .enumerate() - .map(|(idx, arg)| 
Field::new(format!("arg_{idx}"), arg.data_type(), true)) + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true).into()) .collect::>(); - let arg_fields = arg_fields_owned.iter().collect::>(); string::repeat().invoke_with_args(ScalarFunctionArgs { args, arg_fields, number_rows: repeat_times as usize, - return_field: &Field::new("f", DataType::Utf8, true), + return_field: Field::new("f", DataType::Utf8, true).into(), }) } diff --git a/datafusion/functions/benches/reverse.rs b/datafusion/functions/benches/reverse.rs index c45572b9eeebd..6403660113051 100644 --- a/datafusion/functions/benches/reverse.rs +++ b/datafusion/functions/benches/reverse.rs @@ -46,13 +46,13 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| { black_box(reverse.invoke_with_args(ScalarFunctionArgs { args: args_string_ascii.clone(), - arg_fields: vec![&Field::new( + arg_fields: vec![Field::new( "a", args_string_ascii[0].data_type(), true, - )], + ).into()], number_rows: N_ROWS, - return_field: &Field::new("f", DataType::Utf8, true), + return_field: Field::new("f", DataType::Utf8, true).into(), })) }) }, @@ -69,13 +69,11 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| { black_box(reverse.invoke_with_args(ScalarFunctionArgs { args: args_string_utf8.clone(), - arg_fields: vec![&Field::new( - "a", - args_string_utf8[0].data_type(), - true, - )], + arg_fields: vec![ + Field::new("a", args_string_utf8[0].data_type(), true).into(), + ], number_rows: N_ROWS, - return_field: &Field::new("f", DataType::Utf8, true), + return_field: Field::new("f", DataType::Utf8, true).into(), })) }) }, @@ -95,13 +93,13 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| { black_box(reverse.invoke_with_args(ScalarFunctionArgs { args: args_string_view_ascii.clone(), - arg_fields: vec![&Field::new( + arg_fields: vec![Field::new( "a", args_string_view_ascii[0].data_type(), true, - )], + ).into()], number_rows: N_ROWS, - return_field: &Field::new("f", DataType::Utf8, true), + 
return_field: Field::new("f", DataType::Utf8, true).into(), })) }) }, @@ -118,13 +116,13 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| { black_box(reverse.invoke_with_args(ScalarFunctionArgs { args: args_string_view_utf8.clone(), - arg_fields: vec![&Field::new( + arg_fields: vec![Field::new( "a", args_string_view_utf8[0].data_type(), true, - )], + ).into()], number_rows: N_ROWS, - return_field: &Field::new("f", DataType::Utf8, true), + return_field: Field::new("f", DataType::Utf8, true).into(), })) }) }, diff --git a/datafusion/functions/benches/signum.rs b/datafusion/functions/benches/signum.rs index 2cfc9ce103044..10079bcc81c7d 100644 --- a/datafusion/functions/benches/signum.rs +++ b/datafusion/functions/benches/signum.rs @@ -33,12 +33,14 @@ fn criterion_benchmark(c: &mut Criterion) { let f32_array = Arc::new(create_primitive_array::(size, 0.2)); let batch_len = f32_array.len(); let f32_args = vec![ColumnarValue::Array(f32_array)]; - let arg_fields_owned = f32_args + let arg_fields = f32_args .iter() .enumerate() - .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true)) + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) .collect::>(); - let arg_fields = arg_fields_owned.iter().collect::>(); + let return_field = Field::new("f", DataType::Float32, true).into(); c.bench_function(&format!("signum f32 array: {size}"), |b| { b.iter(|| { @@ -48,7 +50,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: f32_args.clone(), arg_fields: arg_fields.clone(), number_rows: batch_len, - return_field: &Field::new("f", DataType::Float32, true), + return_field: Arc::clone(&return_field), }) .unwrap(), ) @@ -58,12 +60,14 @@ fn criterion_benchmark(c: &mut Criterion) { let batch_len = f64_array.len(); let f64_args = vec![ColumnarValue::Array(f64_array)]; - let arg_fields_owned = f64_args + let arg_fields = f64_args .iter() .enumerate() - .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true)) + 
.map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) .collect::>(); - let arg_fields = arg_fields_owned.iter().collect::>(); + let return_field = Field::new("f", DataType::Float64, true).into(); c.bench_function(&format!("signum f64 array: {size}"), |b| { b.iter(|| { @@ -73,7 +77,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: f64_args.clone(), arg_fields: arg_fields.clone(), number_rows: batch_len, - return_field: &Field::new("f", DataType::Float64, true), + return_field: Arc::clone(&return_field), }) .unwrap(), ) diff --git a/datafusion/functions/benches/strpos.rs b/datafusion/functions/benches/strpos.rs index 23fd634b65cf9..df32db1182f1f 100644 --- a/datafusion/functions/benches/strpos.rs +++ b/datafusion/functions/benches/strpos.rs @@ -111,19 +111,18 @@ fn criterion_benchmark(c: &mut Criterion) { for str_len in [8, 32, 128, 4096] { // StringArray ASCII only let args_string_ascii = gen_string_array(n_rows, str_len, 0.1, 0.0, false); + let arg_fields = + vec![Field::new("a", args_string_ascii[0].data_type(), true).into()]; + let return_field = Field::new("f", DataType::Int32, true).into(); c.bench_function( &format!("strpos_StringArray_ascii_str_len_{str_len}"), |b| { b.iter(|| { black_box(strpos.invoke_with_args(ScalarFunctionArgs { args: args_string_ascii.clone(), - arg_fields: vec![&Field::new( - "a", - args_string_ascii[0].data_type(), - true, - )], + arg_fields: arg_fields.clone(), number_rows: n_rows, - return_field: &Field::new("f", DataType::Int32, true), + return_field: Arc::clone(&return_field), })) }) }, @@ -131,36 +130,34 @@ fn criterion_benchmark(c: &mut Criterion) { // StringArray UTF8 let args_string_utf8 = gen_string_array(n_rows, str_len, 0.1, 0.5, false); + let arg_fields = + vec![Field::new("a", args_string_utf8[0].data_type(), true).into()]; + let return_field = Field::new("f", DataType::Int32, true).into(); c.bench_function(&format!("strpos_StringArray_utf8_str_len_{str_len}"), |b| { b.iter(|| { 
black_box(strpos.invoke_with_args(ScalarFunctionArgs { args: args_string_utf8.clone(), - arg_fields: vec![&Field::new( - "a", - args_string_utf8[0].data_type(), - true, - )], + arg_fields: arg_fields.clone(), number_rows: n_rows, - return_field: &Field::new("f", DataType::Int32, true), + return_field: Arc::clone(&return_field), })) }) }); // StringViewArray ASCII only let args_string_view_ascii = gen_string_array(n_rows, str_len, 0.1, 0.0, true); + let arg_fields = + vec![Field::new("a", args_string_view_ascii[0].data_type(), true).into()]; + let return_field = Field::new("f", DataType::Int32, true).into(); c.bench_function( &format!("strpos_StringViewArray_ascii_str_len_{str_len}"), |b| { b.iter(|| { black_box(strpos.invoke_with_args(ScalarFunctionArgs { args: args_string_view_ascii.clone(), - arg_fields: vec![&Field::new( - "a", - args_string_view_ascii[0].data_type(), - true, - )], + arg_fields: arg_fields.clone(), number_rows: n_rows, - return_field: &Field::new("f", DataType::Int32, true), + return_field: Arc::clone(&return_field), })) }) }, @@ -168,19 +165,18 @@ fn criterion_benchmark(c: &mut Criterion) { // StringViewArray UTF8 let args_string_view_utf8 = gen_string_array(n_rows, str_len, 0.1, 0.5, true); + let arg_fields = + vec![Field::new("a", args_string_view_utf8[0].data_type(), true).into()]; + let return_field = Field::new("f", DataType::Int32, true).into(); c.bench_function( &format!("strpos_StringViewArray_utf8_str_len_{str_len}"), |b| { b.iter(|| { black_box(strpos.invoke_with_args(ScalarFunctionArgs { args: args_string_view_utf8.clone(), - arg_fields: vec![&Field::new( - "a", - args_string_view_utf8[0].data_type(), - true, - )], + arg_fields: arg_fields.clone(), number_rows: n_rows, - return_field: &Field::new("f", DataType::Int32, true), + return_field: Arc::clone(&return_field), })) }) }, diff --git a/datafusion/functions/benches/substr.rs b/datafusion/functions/benches/substr.rs index 76e82e112e104..342e18b0d9a2e 100644 --- 
a/datafusion/functions/benches/substr.rs +++ b/datafusion/functions/benches/substr.rs @@ -101,18 +101,17 @@ fn invoke_substr_with_args( args: Vec, number_rows: usize, ) -> Result { - let arg_fields_owned = args + let arg_fields = args .iter() .enumerate() - .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true)) + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true).into()) .collect::>(); - let arg_fields = arg_fields_owned.iter().collect::>(); unicode::substr().invoke_with_args(ScalarFunctionArgs { args: args.clone(), arg_fields, number_rows, - return_field: &Field::new("f", DataType::Utf8View, true), + return_field: Field::new("f", DataType::Utf8View, true).into(), }) } diff --git a/datafusion/functions/benches/substr_index.rs b/datafusion/functions/benches/substr_index.rs index 536d8c019db5b..e772fb38fc400 100644 --- a/datafusion/functions/benches/substr_index.rs +++ b/datafusion/functions/benches/substr_index.rs @@ -91,12 +91,13 @@ fn criterion_benchmark(c: &mut Criterion) { let counts = ColumnarValue::Array(Arc::new(counts) as ArrayRef); let args = vec![strings, delimiters, counts]; - let arg_fields_owned = args + let arg_fields = args .iter() .enumerate() - .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true)) + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) .collect::>(); - let arg_fields = arg_fields_owned.iter().collect::>(); b.iter(|| { black_box( @@ -105,7 +106,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: args.clone(), arg_fields: arg_fields.clone(), number_rows: batch_len, - return_field: &Field::new("f", DataType::Utf8, true), + return_field: Field::new("f", DataType::Utf8, true).into(), }) .expect("substr_index should work on valid values"), ) diff --git a/datafusion/functions/benches/to_char.rs b/datafusion/functions/benches/to_char.rs index ef51a017a2fba..d19714ce61664 100644 --- a/datafusion/functions/benches/to_char.rs +++ 
b/datafusion/functions/benches/to_char.rs @@ -94,11 +94,11 @@ fn criterion_benchmark(c: &mut Criterion) { .invoke_with_args(ScalarFunctionArgs { args: vec![data.clone(), patterns.clone()], arg_fields: vec![ - &Field::new("a", data.data_type(), true), - &Field::new("b", patterns.data_type(), true), + Field::new("a", data.data_type(), true).into(), + Field::new("b", patterns.data_type(), true).into(), ], number_rows: batch_len, - return_field: &Field::new("f", DataType::Utf8, true), + return_field: Field::new("f", DataType::Utf8, true).into(), }) .expect("to_char should work on valid values"), ) @@ -119,11 +119,11 @@ fn criterion_benchmark(c: &mut Criterion) { .invoke_with_args(ScalarFunctionArgs { args: vec![data.clone(), patterns.clone()], arg_fields: vec![ - &Field::new("a", data.data_type(), true), - &Field::new("b", patterns.data_type(), true), + Field::new("a", data.data_type(), true).into(), + Field::new("b", patterns.data_type(), true).into(), ], number_rows: batch_len, - return_field: &Field::new("f", DataType::Utf8, true), + return_field: Field::new("f", DataType::Utf8, true).into(), }) .expect("to_char should work on valid values"), ) @@ -150,11 +150,11 @@ fn criterion_benchmark(c: &mut Criterion) { .invoke_with_args(ScalarFunctionArgs { args: vec![data.clone(), pattern.clone()], arg_fields: vec![ - &Field::new("a", data.data_type(), true), - &Field::new("b", pattern.data_type(), true), + Field::new("a", data.data_type(), true).into(), + Field::new("b", pattern.data_type(), true).into(), ], number_rows: 1, - return_field: &Field::new("f", DataType::Utf8, true), + return_field: Field::new("f", DataType::Utf8, true).into(), }) .expect("to_char should work on valid values"), ) diff --git a/datafusion/functions/benches/to_hex.rs b/datafusion/functions/benches/to_hex.rs index ba104457a15cb..4a02b74ca42d1 100644 --- a/datafusion/functions/benches/to_hex.rs +++ b/datafusion/functions/benches/to_hex.rs @@ -36,9 +36,9 @@ fn criterion_benchmark(c: &mut Criterion) { 
black_box( hex.invoke_with_args(ScalarFunctionArgs { args: args_cloned, - arg_fields: vec![&Field::new("a", DataType::Int32, false)], + arg_fields: vec![Field::new("a", DataType::Int32, false).into()], number_rows: batch_len, - return_field: &Field::new("f", DataType::Utf8, true), + return_field: Field::new("f", DataType::Utf8, true).into(), }) .unwrap(), ) @@ -53,9 +53,9 @@ fn criterion_benchmark(c: &mut Criterion) { black_box( hex.invoke_with_args(ScalarFunctionArgs { args: args_cloned, - arg_fields: vec![&Field::new("a", DataType::Int64, false)], + arg_fields: vec![Field::new("a", DataType::Int64, false).into()], number_rows: batch_len, - return_field: &Field::new("f", DataType::Utf8, true), + return_field: Field::new("f", DataType::Utf8, true).into(), }) .unwrap(), ) diff --git a/datafusion/functions/benches/to_timestamp.rs b/datafusion/functions/benches/to_timestamp.rs index ae38b3ae9df68..d898113484899 100644 --- a/datafusion/functions/benches/to_timestamp.rs +++ b/datafusion/functions/benches/to_timestamp.rs @@ -110,9 +110,9 @@ fn data_with_formats() -> (StringArray, StringArray, StringArray, StringArray) { } fn criterion_benchmark(c: &mut Criterion) { let return_field = - &Field::new("f", DataType::Timestamp(TimeUnit::Nanosecond, None), true); - let arg_field = Field::new("a", DataType::Utf8, false); - let arg_fields = vec![&arg_field]; + Field::new("f", DataType::Timestamp(TimeUnit::Nanosecond, None), true).into(); + let arg_field = Field::new("a", DataType::Utf8, false).into(); + let arg_fields = vec![arg_field]; c.bench_function("to_timestamp_no_formats_utf8", |b| { let arr_data = data(); let batch_len = arr_data.len(); @@ -125,7 +125,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: vec![string_array.clone()], arg_fields: arg_fields.clone(), number_rows: batch_len, - return_field, + return_field: Arc::clone(&return_field), }) .expect("to_timestamp should work on valid values"), ) @@ -144,7 +144,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: 
vec![string_array.clone()], arg_fields: arg_fields.clone(), number_rows: batch_len, - return_field, + return_field: Arc::clone(&return_field), }) .expect("to_timestamp should work on valid values"), ) @@ -163,7 +163,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: vec![string_array.clone()], arg_fields: arg_fields.clone(), number_rows: batch_len, - return_field, + return_field: Arc::clone(&return_field), }) .expect("to_timestamp should work on valid values"), ) @@ -180,12 +180,13 @@ fn criterion_benchmark(c: &mut Criterion) { ColumnarValue::Array(Arc::new(format2) as ArrayRef), ColumnarValue::Array(Arc::new(format3) as ArrayRef), ]; - let arg_fields_owned = args + let arg_fields = args .iter() .enumerate() - .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true)) + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) .collect::>(); - let arg_fields = arg_fields_owned.iter().collect::>(); b.iter(|| { black_box( @@ -194,7 +195,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: args.clone(), arg_fields: arg_fields.clone(), number_rows: batch_len, - return_field, + return_field: Arc::clone(&return_field), }) .expect("to_timestamp should work on valid values"), ) @@ -219,12 +220,13 @@ fn criterion_benchmark(c: &mut Criterion) { Arc::new(cast(&format3, &DataType::LargeUtf8).unwrap()) as ArrayRef ), ]; - let arg_fields_owned = args + let arg_fields = args .iter() .enumerate() - .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true)) + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) .collect::>(); - let arg_fields = arg_fields_owned.iter().collect::>(); b.iter(|| { black_box( @@ -233,7 +235,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: args.clone(), arg_fields: arg_fields.clone(), number_rows: batch_len, - return_field, + return_field: Arc::clone(&return_field), }) .expect("to_timestamp should work on valid values"), ) @@ -259,12 +261,13 
@@ fn criterion_benchmark(c: &mut Criterion) { Arc::new(cast(&format3, &DataType::Utf8View).unwrap()) as ArrayRef ), ]; - let arg_fields_owned = args + let arg_fields = args .iter() .enumerate() - .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true)) + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) .collect::>(); - let arg_fields = arg_fields_owned.iter().collect::>(); b.iter(|| { black_box( @@ -273,7 +276,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: args.clone(), arg_fields: arg_fields.clone(), number_rows: batch_len, - return_field, + return_field: Arc::clone(&return_field), }) .expect("to_timestamp should work on valid values"), ) diff --git a/datafusion/functions/benches/trunc.rs b/datafusion/functions/benches/trunc.rs index e55e4dcd0805b..897e21c1e1d94 100644 --- a/datafusion/functions/benches/trunc.rs +++ b/datafusion/functions/benches/trunc.rs @@ -33,15 +33,17 @@ fn criterion_benchmark(c: &mut Criterion) { for size in [1024, 4096, 8192] { let f32_array = Arc::new(create_primitive_array::(size, 0.2)); let f32_args = vec![ColumnarValue::Array(f32_array)]; + let arg_fields = vec![Field::new("a", DataType::Float32, false).into()]; + let return_field = Field::new("f", DataType::Float32, true).into(); c.bench_function(&format!("trunc f32 array: {size}"), |b| { b.iter(|| { black_box( trunc .invoke_with_args(ScalarFunctionArgs { args: f32_args.clone(), - arg_fields: vec![&Field::new("a", DataType::Float32, false)], + arg_fields: arg_fields.clone(), number_rows: size, - return_field: &Field::new("f", DataType::Float32, true), + return_field: Arc::clone(&return_field), }) .unwrap(), ) @@ -49,15 +51,17 @@ fn criterion_benchmark(c: &mut Criterion) { }); let f64_array = Arc::new(create_primitive_array::(size, 0.2)); let f64_args = vec![ColumnarValue::Array(f64_array)]; + let arg_fields = vec![Field::new("a", DataType::Float64, true).into()]; + let return_field = Field::new("f", DataType::Float64, 
true).into(); c.bench_function(&format!("trunc f64 array: {size}"), |b| { b.iter(|| { black_box( trunc .invoke_with_args(ScalarFunctionArgs { args: f64_args.clone(), - arg_fields: vec![&Field::new("a", DataType::Float64, false)], + arg_fields: arg_fields.clone(), number_rows: size, - return_field: &Field::new("f", DataType::Float64, true), + return_field: Arc::clone(&return_field), }) .unwrap(), ) diff --git a/datafusion/functions/benches/upper.rs b/datafusion/functions/benches/upper.rs index e218f6d0372a8..bf2c4161001e8 100644 --- a/datafusion/functions/benches/upper.rs +++ b/datafusion/functions/benches/upper.rs @@ -42,9 +42,9 @@ fn criterion_benchmark(c: &mut Criterion) { let args_cloned = args.clone(); black_box(upper.invoke_with_args(ScalarFunctionArgs { args: args_cloned, - arg_fields: vec![&Field::new("a", DataType::Utf8, true)], + arg_fields: vec![Field::new("a", DataType::Utf8, true).into()], number_rows: size, - return_field: &Field::new("f", DataType::Utf8, true), + return_field: Field::new("f", DataType::Utf8, true).into(), })) }) }); diff --git a/datafusion/functions/benches/uuid.rs b/datafusion/functions/benches/uuid.rs index dfed6871d6c21..942af122562ab 100644 --- a/datafusion/functions/benches/uuid.rs +++ b/datafusion/functions/benches/uuid.rs @@ -30,7 +30,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: vec![], arg_fields: vec![], number_rows: 1024, - return_field: &Field::new("f", DataType::Utf8, true), + return_field: Field::new("f", DataType::Utf8, true).into(), })) }) }); diff --git a/datafusion/functions/src/core/arrow_cast.rs b/datafusion/functions/src/core/arrow_cast.rs index 0e18ec180cefc..2d769dfa56579 100644 --- a/datafusion/functions/src/core/arrow_cast.rs +++ b/datafusion/functions/src/core/arrow_cast.rs @@ -17,7 +17,7 @@ //! 
[`ArrowCastFunc`]: Implementation of the `arrow_cast` -use arrow::datatypes::{DataType, Field}; +use arrow::datatypes::{DataType, Field, FieldRef}; use arrow::error::ArrowError; use datafusion_common::{ arrow_datafusion_err, exec_err, internal_err, Result, ScalarValue, @@ -116,7 +116,7 @@ impl ScalarUDFImpl for ArrowCastFunc { internal_err!("return_field_from_args should be called instead") } - fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { let nullable = args.arg_fields.iter().any(|f| f.is_nullable()); let [_, type_arg] = take_function_args(self.name(), args.scalar_arguments)?; @@ -131,7 +131,7 @@ impl ScalarUDFImpl for ArrowCastFunc { ) }, |casted_type| match casted_type.parse::() { - Ok(data_type) => Ok(Field::new(self.name(), data_type, nullable)), + Ok(data_type) => Ok(Field::new(self.name(), data_type, nullable).into()), Err(ArrowError::ParseError(e)) => Err(exec_datafusion_err!("{e}")), Err(e) => Err(arrow_datafusion_err!(e)), }, diff --git a/datafusion/functions/src/core/coalesce.rs b/datafusion/functions/src/core/coalesce.rs index b2ca3692c1d38..12a4bef247393 100644 --- a/datafusion/functions/src/core/coalesce.rs +++ b/datafusion/functions/src/core/coalesce.rs @@ -18,7 +18,7 @@ use arrow::array::{new_null_array, BooleanArray}; use arrow::compute::kernels::zip::zip; use arrow::compute::{and, is_not_null, is_null}; -use arrow::datatypes::{DataType, Field}; +use arrow::datatypes::{DataType, Field, FieldRef}; use datafusion_common::{exec_err, internal_err, Result}; use datafusion_expr::binary::try_type_union_resolution; use datafusion_expr::{ @@ -82,7 +82,7 @@ impl ScalarUDFImpl for CoalesceFunc { internal_err!("return_field_from_args should be called instead") } - fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { // If any the arguments in coalesce is non-null, the result is non-null 
let nullable = args.arg_fields.iter().all(|f| f.is_nullable()); let return_type = args @@ -92,7 +92,7 @@ impl ScalarUDFImpl for CoalesceFunc { .find_or_first(|d| !d.is_null()) .unwrap() .clone(); - Ok(Field::new(self.name(), return_type, nullable)) + Ok(Field::new(self.name(), return_type, nullable).into()) } /// coalesce evaluates to the first value which is not NULL diff --git a/datafusion/functions/src/core/getfield.rs b/datafusion/functions/src/core/getfield.rs index 97df76eaac58a..de87308ef3c49 100644 --- a/datafusion/functions/src/core/getfield.rs +++ b/datafusion/functions/src/core/getfield.rs @@ -20,7 +20,7 @@ use arrow::array::{ Scalar, }; use arrow::compute::SortOptions; -use arrow::datatypes::{DataType, Field}; +use arrow::datatypes::{DataType, Field, FieldRef}; use arrow_buffer::NullBuffer; use datafusion_common::cast::{as_map_array, as_struct_array}; use datafusion_common::{ @@ -133,7 +133,7 @@ impl ScalarUDFImpl for GetFieldFunc { internal_err!("return_field_from_args should be called instead") } - fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { // Length check handled in the signature debug_assert_eq!(args.scalar_arguments.len(), 2); @@ -147,7 +147,7 @@ impl ScalarUDFImpl for GetFieldFunc { // execution. 
let value_field = fields.get(1).expect("fields should have exactly two members"); - Ok(value_field.as_ref().clone().with_nullable(true)) + Ok(value_field.as_ref().clone().with_nullable(true).into()) }, _ => exec_err!("Map fields must contain a Struct with exactly 2 fields"), } @@ -168,11 +168,11 @@ impl ScalarUDFImpl for GetFieldFunc { if args.arg_fields[0].is_nullable() { child_field = child_field.with_nullable(true); } - child_field + Arc::new(child_field) }) }) }, - (DataType::Null, _) => Ok(Field::new(self.name(), DataType::Null, true)), + (DataType::Null, _) => Ok(Field::new(self.name(), DataType::Null, true).into()), (other, _) => exec_err!("The expression to get an indexed field is only valid for `Struct`, `Map` or `Null` types, got {other}"), } } diff --git a/datafusion/functions/src/core/named_struct.rs b/datafusion/functions/src/core/named_struct.rs index 9346b62b90ed8..115f4a8aba225 100644 --- a/datafusion/functions/src/core/named_struct.rs +++ b/datafusion/functions/src/core/named_struct.rs @@ -16,7 +16,7 @@ // under the License. use arrow::array::StructArray; -use arrow::datatypes::{DataType, Field, Fields}; +use arrow::datatypes::{DataType, Field, FieldRef, Fields}; use datafusion_common::{exec_err, internal_err, Result}; use datafusion_expr::{ ColumnarValue, Documentation, ReturnFieldArgs, ScalarFunctionArgs, @@ -96,7 +96,7 @@ impl ScalarUDFImpl for NamedStructFunc { ) } - fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { // do not accept 0 arguments. 
if args.scalar_arguments.is_empty() { return exec_err!( @@ -146,7 +146,8 @@ impl ScalarUDFImpl for NamedStructFunc { self.name(), DataType::Struct(Fields::from(return_fields)), true, - )) + ) + .into()) } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { diff --git a/datafusion/functions/src/core/union_extract.rs b/datafusion/functions/src/core/union_extract.rs index b1544a9b357b7..be49f82267121 100644 --- a/datafusion/functions/src/core/union_extract.rs +++ b/datafusion/functions/src/core/union_extract.rs @@ -86,7 +86,7 @@ impl ScalarUDFImpl for UnionExtractFun { internal_err!("union_extract should return type from args") } - fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { if args.arg_fields.len() != 2 { return exec_err!( "union_extract expects 2 arguments, got {} instead", @@ -110,7 +110,7 @@ impl ScalarUDFImpl for UnionExtractFun { let field = find_field(fields, field_name)?.1; - Ok(Field::new(self.name(), field.data_type().clone(), true)) + Ok(Field::new(self.name(), field.data_type().clone(), true).into()) } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { @@ -199,14 +199,14 @@ mod tests { ]; let arg_fields = args .iter() - .map(|arg| Field::new("a", arg.data_type().clone(), true)) + .map(|arg| Field::new("a", arg.data_type().clone(), true).into()) .collect::>(); let result = fun.invoke_with_args(ScalarFunctionArgs { args, - arg_fields: arg_fields.iter().collect(), + arg_fields, number_rows: 1, - return_field: &Field::new("f", DataType::Utf8, true), + return_field: Field::new("f", DataType::Utf8, true).into(), })?; assert_scalar(result, ScalarValue::Utf8(None)); @@ -221,14 +221,14 @@ mod tests { ]; let arg_fields = args .iter() - .map(|arg| Field::new("a", arg.data_type().clone(), true)) + .map(|arg| Field::new("a", arg.data_type().clone(), true).into()) .collect::>(); let result = fun.invoke_with_args(ScalarFunctionArgs { args, - 
arg_fields: arg_fields.iter().collect(), + arg_fields, number_rows: 1, - return_field: &Field::new("f", DataType::Utf8, true), + return_field: Field::new("f", DataType::Utf8, true).into(), })?; assert_scalar(result, ScalarValue::Utf8(None)); @@ -243,13 +243,13 @@ mod tests { ]; let arg_fields = args .iter() - .map(|arg| Field::new("a", arg.data_type().clone(), true)) + .map(|arg| Field::new("a", arg.data_type().clone(), true).into()) .collect::>(); let result = fun.invoke_with_args(ScalarFunctionArgs { args, - arg_fields: arg_fields.iter().collect(), + arg_fields, number_rows: 1, - return_field: &Field::new("f", DataType::Utf8, true), + return_field: Field::new("f", DataType::Utf8, true).into(), })?; assert_scalar(result, ScalarValue::new_utf8("42")); diff --git a/datafusion/functions/src/core/union_tag.rs b/datafusion/functions/src/core/union_tag.rs index 2997313f9efea..3a4d96de2bc03 100644 --- a/datafusion/functions/src/core/union_tag.rs +++ b/datafusion/functions/src/core/union_tag.rs @@ -180,7 +180,7 @@ mod tests { .invoke_with_args(ScalarFunctionArgs { args: vec![ColumnarValue::Scalar(scalar)], number_rows: 1, - return_field: &Field::new("res", return_type, true), + return_field: Field::new("res", return_type, true).into(), arg_fields: vec![], }) .unwrap(); @@ -202,7 +202,7 @@ mod tests { .invoke_with_args(ScalarFunctionArgs { args: vec![ColumnarValue::Scalar(scalar)], number_rows: 1, - return_field: &Field::new("res", return_type, true), + return_field: Field::new("res", return_type, true).into(), arg_fields: vec![], }) .unwrap(); diff --git a/datafusion/functions/src/core/version.rs b/datafusion/functions/src/core/version.rs index 9e243dd0adb87..b3abe246b4b3f 100644 --- a/datafusion/functions/src/core/version.rs +++ b/datafusion/functions/src/core/version.rs @@ -108,7 +108,7 @@ mod test { args: vec![], arg_fields: vec![], number_rows: 0, - return_field: &Field::new("f", DataType::Utf8, true), + return_field: Field::new("f", DataType::Utf8, true).into(), }) 
.unwrap(); diff --git a/datafusion/functions/src/datetime/date_bin.rs b/datafusion/functions/src/datetime/date_bin.rs index ea9e3d091860a..1c801dfead723 100644 --- a/datafusion/functions/src/datetime/date_bin.rs +++ b/datafusion/functions/src/datetime/date_bin.rs @@ -505,7 +505,7 @@ mod tests { use arrow::array::types::TimestampNanosecondType; use arrow::array::{Array, IntervalDayTimeArray, TimestampNanosecondArray}; use arrow::compute::kernels::cast_utils::string_to_timestamp_nanos; - use arrow::datatypes::{DataType, Field, TimeUnit}; + use arrow::datatypes::{DataType, Field, FieldRef, TimeUnit}; use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano}; use datafusion_common::{DataFusionError, ScalarValue}; @@ -516,26 +516,29 @@ mod tests { fn invoke_date_bin_with_args( args: Vec, number_rows: usize, - return_field: &Field, + return_field: &FieldRef, ) -> Result { let arg_fields = args .iter() - .map(|arg| Field::new("a", arg.data_type(), true)) + .map(|arg| Field::new("a", arg.data_type(), true).into()) .collect::>(); let args = datafusion_expr::ScalarFunctionArgs { args, - arg_fields: arg_fields.iter().collect(), + arg_fields, number_rows, - return_field, + return_field: Arc::clone(return_field), }; DateBinFunc::new().invoke_with_args(args) } #[test] fn test_date_bin() { - let return_field = - &Field::new("f", DataType::Timestamp(TimeUnit::Nanosecond, None), true); + let return_field = &Arc::new(Field::new( + "f", + DataType::Timestamp(TimeUnit::Nanosecond, None), + true, + )); let mut args = vec![ ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some(IntervalDayTime { @@ -853,11 +856,11 @@ mod tests { tz_opt.clone(), )), ]; - let return_field = &Field::new( + let return_field = &Arc::new(Field::new( "f", DataType::Timestamp(TimeUnit::Nanosecond, tz_opt.clone()), true, - ); + )); let result = invoke_date_bin_with_args(args, batch_len, return_field).unwrap(); diff --git a/datafusion/functions/src/datetime/date_part.rs 
b/datafusion/functions/src/datetime/date_part.rs index 91f983e0acc3f..021000dc100b8 100644 --- a/datafusion/functions/src/datetime/date_part.rs +++ b/datafusion/functions/src/datetime/date_part.rs @@ -26,7 +26,7 @@ use arrow::datatypes::DataType::{ Date32, Date64, Duration, Interval, Time32, Time64, Timestamp, }; use arrow::datatypes::TimeUnit::{Microsecond, Millisecond, Nanosecond, Second}; -use arrow::datatypes::{DataType, Field, TimeUnit}; +use arrow::datatypes::{DataType, Field, FieldRef, TimeUnit}; use datafusion_common::types::{logical_date, NativeType}; use datafusion_common::{ @@ -145,7 +145,7 @@ impl ScalarUDFImpl for DatePartFunc { internal_err!("return_field_from_args should be called instead") } - fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { let [field, _] = take_function_args(self.name(), args.scalar_arguments)?; field @@ -161,6 +161,7 @@ impl ScalarUDFImpl for DatePartFunc { } }) }) + .map(Arc::new) .map_or_else( || exec_err!("{} requires non-empty constant string", self.name()), Ok, diff --git a/datafusion/functions/src/datetime/date_trunc.rs b/datafusion/functions/src/datetime/date_trunc.rs index 3d29dc45a920c..8963ef77a53b9 100644 --- a/datafusion/functions/src/datetime/date_trunc.rs +++ b/datafusion/functions/src/datetime/date_trunc.rs @@ -727,21 +727,22 @@ mod tests { .with_timezone_opt(tz_opt.clone()); let batch_len = input.len(); let arg_fields = vec![ - Field::new("a", DataType::Utf8, false), - Field::new("b", input.data_type().clone(), false), + Field::new("a", DataType::Utf8, false).into(), + Field::new("b", input.data_type().clone(), false).into(), ]; let args = datafusion_expr::ScalarFunctionArgs { args: vec![ ColumnarValue::Scalar(ScalarValue::from("day")), ColumnarValue::Array(Arc::new(input)), ], - arg_fields: arg_fields.iter().collect(), + arg_fields, number_rows: batch_len, - return_field: &Field::new( + return_field: Field::new( "f", 
DataType::Timestamp(TimeUnit::Nanosecond, tz_opt.clone()), true, - ), + ) + .into(), }; let result = DateTruncFunc::new().invoke_with_args(args).unwrap(); if let ColumnarValue::Array(result) = result { @@ -898,21 +899,22 @@ mod tests { .with_timezone_opt(tz_opt.clone()); let batch_len = input.len(); let arg_fields = vec![ - Field::new("a", DataType::Utf8, false), - Field::new("b", input.data_type().clone(), false), + Field::new("a", DataType::Utf8, false).into(), + Field::new("b", input.data_type().clone(), false).into(), ]; let args = datafusion_expr::ScalarFunctionArgs { args: vec![ ColumnarValue::Scalar(ScalarValue::from("hour")), ColumnarValue::Array(Arc::new(input)), ], - arg_fields: arg_fields.iter().collect(), + arg_fields, number_rows: batch_len, - return_field: &Field::new( + return_field: Field::new( "f", DataType::Timestamp(TimeUnit::Nanosecond, tz_opt.clone()), true, - ), + ) + .into(), }; let result = DateTruncFunc::new().invoke_with_args(args).unwrap(); if let ColumnarValue::Array(result) = result { diff --git a/datafusion/functions/src/datetime/from_unixtime.rs b/datafusion/functions/src/datetime/from_unixtime.rs index d6d7878932bf3..c1497040261ca 100644 --- a/datafusion/functions/src/datetime/from_unixtime.rs +++ b/datafusion/functions/src/datetime/from_unixtime.rs @@ -20,7 +20,7 @@ use std::sync::Arc; use arrow::datatypes::DataType::{Int64, Timestamp, Utf8}; use arrow::datatypes::TimeUnit::Second; -use arrow::datatypes::{DataType, Field}; +use arrow::datatypes::{DataType, Field, FieldRef}; use datafusion_common::{exec_err, internal_err, Result, ScalarValue}; use datafusion_expr::TypeSignature::Exact; use datafusion_expr::{ @@ -81,12 +81,12 @@ impl ScalarUDFImpl for FromUnixtimeFunc { &self.signature } - fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { // Length check handled in the signature debug_assert!(matches!(args.scalar_arguments.len(), 1 | 2)); if 
args.scalar_arguments.len() == 1 { - Ok(Field::new(self.name(), Timestamp(Second, None), true)) + Ok(Field::new(self.name(), Timestamp(Second, None), true).into()) } else { args.scalar_arguments[1] .and_then(|sv| { @@ -101,6 +101,7 @@ impl ScalarUDFImpl for FromUnixtimeFunc { ) }) }) + .map(Arc::new) .map_or_else( || { exec_err!( @@ -170,12 +171,12 @@ mod test { #[test] fn test_without_timezone() { - let arg_field = Field::new("a", DataType::Int64, true); + let arg_field = Arc::new(Field::new("a", DataType::Int64, true)); let args = datafusion_expr::ScalarFunctionArgs { args: vec![ColumnarValue::Scalar(Int64(Some(1729900800)))], - arg_fields: vec![&arg_field], + arg_fields: vec![arg_field], number_rows: 1, - return_field: &Field::new("f", DataType::Timestamp(Second, None), true), + return_field: Field::new("f", DataType::Timestamp(Second, None), true).into(), }; let result = FromUnixtimeFunc::new().invoke_with_args(args).unwrap(); @@ -190,8 +191,8 @@ mod test { #[test] fn test_with_timezone() { let arg_fields = vec![ - Field::new("a", DataType::Int64, true), - Field::new("a", DataType::Utf8, true), + Field::new("a", DataType::Int64, true).into(), + Field::new("a", DataType::Utf8, true).into(), ]; let args = datafusion_expr::ScalarFunctionArgs { args: vec![ @@ -200,13 +201,14 @@ mod test { "America/New_York".to_string(), ))), ], - arg_fields: arg_fields.iter().collect(), + arg_fields, number_rows: 2, - return_field: &Field::new( + return_field: Field::new( "f", DataType::Timestamp(Second, Some(Arc::from("America/New_York"))), true, - ), + ) + .into(), }; let result = FromUnixtimeFunc::new().invoke_with_args(args).unwrap(); diff --git a/datafusion/functions/src/datetime/make_date.rs b/datafusion/functions/src/datetime/make_date.rs index ed901258cd62b..daa9bd83971f9 100644 --- a/datafusion/functions/src/datetime/make_date.rs +++ b/datafusion/functions/src/datetime/make_date.rs @@ -234,13 +234,13 @@ mod tests { ) -> Result { let arg_fields = args .iter() - .map(|arg| 
Field::new("a", arg.data_type(), true)) + .map(|arg| Field::new("a", arg.data_type(), true).into()) .collect::>(); let args = datafusion_expr::ScalarFunctionArgs { args, - arg_fields: arg_fields.iter().collect(), + arg_fields, number_rows, - return_field: &Field::new("f", DataType::Date32, true), + return_field: Field::new("f", DataType::Date32, true).into(), }; MakeDateFunc::new().invoke_with_args(args) } diff --git a/datafusion/functions/src/datetime/now.rs b/datafusion/functions/src/datetime/now.rs index 867442df45ad9..30b4d4ca9c76f 100644 --- a/datafusion/functions/src/datetime/now.rs +++ b/datafusion/functions/src/datetime/now.rs @@ -17,7 +17,7 @@ use arrow::datatypes::DataType::Timestamp; use arrow::datatypes::TimeUnit::Nanosecond; -use arrow::datatypes::{DataType, Field}; +use arrow::datatypes::{DataType, Field, FieldRef}; use std::any::Any; use datafusion_common::{internal_err, Result, ScalarValue}; @@ -77,12 +77,13 @@ impl ScalarUDFImpl for NowFunc { &self.signature } - fn return_field_from_args(&self, _args: ReturnFieldArgs) -> Result { + fn return_field_from_args(&self, _args: ReturnFieldArgs) -> Result { Ok(Field::new( self.name(), Timestamp(Nanosecond, Some("+00:00".into())), false, - )) + ) + .into()) } fn return_type(&self, _arg_types: &[DataType]) -> Result { diff --git a/datafusion/functions/src/datetime/to_char.rs b/datafusion/functions/src/datetime/to_char.rs index be3917092ba9d..3e89242aba263 100644 --- a/datafusion/functions/src/datetime/to_char.rs +++ b/datafusion/functions/src/datetime/to_char.rs @@ -386,14 +386,14 @@ mod tests { for (value, format, expected) in scalar_data { let arg_fields = vec![ - Field::new("a", value.data_type(), false), - Field::new("a", format.data_type(), false), + Field::new("a", value.data_type(), false).into(), + Field::new("a", format.data_type(), false).into(), ]; let args = datafusion_expr::ScalarFunctionArgs { args: vec![ColumnarValue::Scalar(value), ColumnarValue::Scalar(format)], - arg_fields: 
arg_fields.iter().collect(), + arg_fields, number_rows: 1, - return_field: &Field::new("f", DataType::Utf8, true), + return_field: Field::new("f", DataType::Utf8, true).into(), }; let result = ToCharFunc::new() .invoke_with_args(args) @@ -471,17 +471,17 @@ mod tests { for (value, format, expected) in scalar_array_data { let batch_len = format.len(); let arg_fields = vec![ - Field::new("a", value.data_type(), false), - Field::new("a", format.data_type().to_owned(), false), + Field::new("a", value.data_type(), false).into(), + Field::new("a", format.data_type().to_owned(), false).into(), ]; let args = datafusion_expr::ScalarFunctionArgs { args: vec![ ColumnarValue::Scalar(value), ColumnarValue::Array(Arc::new(format) as ArrayRef), ], - arg_fields: arg_fields.iter().collect(), + arg_fields, number_rows: batch_len, - return_field: &Field::new("f", DataType::Utf8, true), + return_field: Field::new("f", DataType::Utf8, true).into(), }; let result = ToCharFunc::new() .invoke_with_args(args) @@ -607,17 +607,17 @@ mod tests { for (value, format, expected) in array_scalar_data { let batch_len = value.len(); let arg_fields = vec![ - Field::new("a", value.data_type().clone(), false), - Field::new("a", format.data_type(), false), + Field::new("a", value.data_type().clone(), false).into(), + Field::new("a", format.data_type(), false).into(), ]; let args = datafusion_expr::ScalarFunctionArgs { args: vec![ ColumnarValue::Array(value as ArrayRef), ColumnarValue::Scalar(format), ], - arg_fields: arg_fields.iter().collect(), + arg_fields, number_rows: batch_len, - return_field: &Field::new("f", DataType::Utf8, true), + return_field: Field::new("f", DataType::Utf8, true).into(), }; let result = ToCharFunc::new() .invoke_with_args(args) @@ -634,17 +634,17 @@ mod tests { for (value, format, expected) in array_array_data { let batch_len = value.len(); let arg_fields = vec![ - Field::new("a", value.data_type().clone(), false), - Field::new("a", format.data_type().clone(), false), + 
Field::new("a", value.data_type().clone(), false).into(), + Field::new("a", format.data_type().clone(), false).into(), ]; let args = datafusion_expr::ScalarFunctionArgs { args: vec![ ColumnarValue::Array(value), ColumnarValue::Array(Arc::new(format) as ArrayRef), ], - arg_fields: arg_fields.iter().collect(), + arg_fields, number_rows: batch_len, - return_field: &Field::new("f", DataType::Utf8, true), + return_field: Field::new("f", DataType::Utf8, true).into(), }; let result = ToCharFunc::new() .invoke_with_args(args) @@ -663,12 +663,12 @@ mod tests { // // invalid number of arguments - let arg_field = Field::new("a", DataType::Int32, true); + let arg_field = Field::new("a", DataType::Int32, true).into(); let args = datafusion_expr::ScalarFunctionArgs { args: vec![ColumnarValue::Scalar(ScalarValue::Int32(Some(1)))], - arg_fields: vec![&arg_field], + arg_fields: vec![arg_field], number_rows: 1, - return_field: &Field::new("f", DataType::Utf8, true), + return_field: Field::new("f", DataType::Utf8, true).into(), }; let result = ToCharFunc::new().invoke_with_args(args); assert_eq!( @@ -678,17 +678,17 @@ mod tests { // invalid type let arg_fields = vec![ - Field::new("a", DataType::Utf8, true), - Field::new("a", DataType::Timestamp(TimeUnit::Nanosecond, None), true), + Field::new("a", DataType::Utf8, true).into(), + Field::new("a", DataType::Timestamp(TimeUnit::Nanosecond, None), true).into(), ]; let args = datafusion_expr::ScalarFunctionArgs { args: vec![ ColumnarValue::Scalar(ScalarValue::Int32(Some(1))), ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), ], - arg_fields: arg_fields.iter().collect(), + arg_fields, number_rows: 1, - return_field: &Field::new("f", DataType::Utf8, true), + return_field: Field::new("f", DataType::Utf8, true).into(), }; let result = ToCharFunc::new().invoke_with_args(args); assert_eq!( diff --git a/datafusion/functions/src/datetime/to_date.rs b/datafusion/functions/src/datetime/to_date.rs index 
372d0e7be9a52..c9fd17dbef11f 100644 --- a/datafusion/functions/src/datetime/to_date.rs +++ b/datafusion/functions/src/datetime/to_date.rs @@ -177,14 +177,14 @@ mod tests { ) -> Result { let arg_fields = args .iter() - .map(|arg| Field::new("a", arg.data_type(), true)) + .map(|arg| Field::new("a", arg.data_type(), true).into()) .collect::>(); let args = datafusion_expr::ScalarFunctionArgs { args, - arg_fields: arg_fields.iter().collect(), + arg_fields, number_rows, - return_field: &Field::new("f", DataType::Date32, true), + return_field: Field::new("f", DataType::Date32, true).into(), }; ToDateFunc::new().invoke_with_args(args) } diff --git a/datafusion/functions/src/datetime/to_local_time.rs b/datafusion/functions/src/datetime/to_local_time.rs index 5cf9b785b5035..b9ebe537d459b 100644 --- a/datafusion/functions/src/datetime/to_local_time.rs +++ b/datafusion/functions/src/datetime/to_local_time.rs @@ -538,13 +538,13 @@ mod tests { } fn test_to_local_time_helper(input: ScalarValue, expected: ScalarValue) { - let arg_field = Field::new("a", input.data_type(), true); + let arg_field = Field::new("a", input.data_type(), true).into(); let res = ToLocalTimeFunc::new() .invoke_with_args(ScalarFunctionArgs { args: vec![ColumnarValue::Scalar(input)], - arg_fields: vec![&arg_field], + arg_fields: vec![arg_field], number_rows: 1, - return_field: &Field::new("f", expected.data_type(), true), + return_field: Field::new("f", expected.data_type(), true).into(), }) .unwrap(); match res { @@ -604,16 +604,17 @@ mod tests { .map(|s| Some(string_to_timestamp_nanos(s).unwrap())) .collect::(); let batch_size = input.len(); - let arg_field = Field::new("a", input.data_type().clone(), true); + let arg_field = Field::new("a", input.data_type().clone(), true).into(); let args = ScalarFunctionArgs { args: vec![ColumnarValue::Array(Arc::new(input))], - arg_fields: vec![&arg_field], + arg_fields: vec![arg_field], number_rows: batch_size, - return_field: &Field::new( + return_field: Field::new( 
"f", DataType::Timestamp(TimeUnit::Nanosecond, None), true, - ), + ) + .into(), }; let result = ToLocalTimeFunc::new().invoke_with_args(args).unwrap(); if let ColumnarValue::Array(result) = result { diff --git a/datafusion/functions/src/datetime/to_timestamp.rs b/datafusion/functions/src/datetime/to_timestamp.rs index c6aab61328eb6..8b26a1c259505 100644 --- a/datafusion/functions/src/datetime/to_timestamp.rs +++ b/datafusion/functions/src/datetime/to_timestamp.rs @@ -1012,13 +1012,13 @@ mod tests { for udf in &udfs { for array in arrays { let rt = udf.return_type(&[array.data_type()]).unwrap(); - let arg_field = Field::new("arg", array.data_type().clone(), true); + let arg_field = Field::new("arg", array.data_type().clone(), true).into(); assert!(matches!(rt, Timestamp(_, Some(_)))); let args = datafusion_expr::ScalarFunctionArgs { args: vec![array.clone()], - arg_fields: vec![&arg_field], + arg_fields: vec![arg_field], number_rows: 4, - return_field: &Field::new("f", rt, true), + return_field: Field::new("f", rt, true).into(), }; let res = udf .invoke_with_args(args) @@ -1062,12 +1062,12 @@ mod tests { for array in arrays { let rt = udf.return_type(&[array.data_type()]).unwrap(); assert!(matches!(rt, Timestamp(_, None))); - let arg_field = Field::new("arg", array.data_type().clone(), true); + let arg_field = Field::new("arg", array.data_type().clone(), true).into(); let args = datafusion_expr::ScalarFunctionArgs { args: vec![array.clone()], - arg_fields: vec![&arg_field], + arg_fields: vec![arg_field], number_rows: 5, - return_field: &Field::new("f", rt, true), + return_field: Field::new("f", rt, true).into(), }; let res = udf .invoke_with_args(args) diff --git a/datafusion/functions/src/math/log.rs b/datafusion/functions/src/math/log.rs index d1f40e3b1ad1b..ee52c035ac81d 100644 --- a/datafusion/functions/src/math/log.rs +++ b/datafusion/functions/src/math/log.rs @@ -266,8 +266,8 @@ mod tests { #[should_panic] fn test_log_invalid_base_type() { let arg_fields = 
vec![ - Field::new("a", DataType::Float64, false), - Field::new("a", DataType::Int64, false), + Field::new("a", DataType::Float64, false).into(), + Field::new("a", DataType::Int64, false).into(), ]; let args = ScalarFunctionArgs { args: vec![ @@ -276,23 +276,23 @@ mod tests { ]))), // num ColumnarValue::Array(Arc::new(Int64Array::from(vec![5, 10, 15, 20]))), ], - arg_fields: arg_fields.iter().collect(), + arg_fields, number_rows: 4, - return_field: &Field::new("f", DataType::Float64, true), + return_field: Field::new("f", DataType::Float64, true).into(), }; let _ = LogFunc::new().invoke_with_args(args); } #[test] fn test_log_invalid_value() { - let arg_field = Field::new("a", DataType::Int64, false); + let arg_field = Field::new("a", DataType::Int64, false).into(); let args = ScalarFunctionArgs { args: vec![ ColumnarValue::Array(Arc::new(Int64Array::from(vec![10]))), // num ], - arg_fields: vec![&arg_field], + arg_fields: vec![arg_field], number_rows: 1, - return_field: &Field::new("f", DataType::Float64, true), + return_field: Field::new("f", DataType::Float64, true).into(), }; let result = LogFunc::new().invoke_with_args(args); @@ -301,14 +301,14 @@ mod tests { #[test] fn test_log_scalar_f32_unary() { - let arg_field = Field::new("a", DataType::Float32, false); + let arg_field = Field::new("a", DataType::Float32, false).into(); let args = ScalarFunctionArgs { args: vec![ ColumnarValue::Scalar(ScalarValue::Float32(Some(10.0))), // num ], - arg_fields: vec![&arg_field], + arg_fields: vec![arg_field], number_rows: 1, - return_field: &Field::new("f", DataType::Float32, true), + return_field: Field::new("f", DataType::Float32, true).into(), }; let result = LogFunc::new() .invoke_with_args(args) @@ -330,14 +330,14 @@ mod tests { #[test] fn test_log_scalar_f64_unary() { - let arg_field = Field::new("a", DataType::Float64, false); + let arg_field = Field::new("a", DataType::Float64, false).into(); let args = ScalarFunctionArgs { args: vec![ 
ColumnarValue::Scalar(ScalarValue::Float64(Some(10.0))), // num ], - arg_fields: vec![&arg_field], + arg_fields: vec![arg_field], number_rows: 1, - return_field: &Field::new("f", DataType::Float64, true), + return_field: Field::new("f", DataType::Float64, true).into(), }; let result = LogFunc::new() .invoke_with_args(args) @@ -360,17 +360,17 @@ mod tests { #[test] fn test_log_scalar_f32() { let arg_fields = vec![ - Field::new("a", DataType::Float32, false), - Field::new("a", DataType::Float32, false), + Field::new("a", DataType::Float32, false).into(), + Field::new("a", DataType::Float32, false).into(), ]; let args = ScalarFunctionArgs { args: vec![ ColumnarValue::Scalar(ScalarValue::Float32(Some(2.0))), // num ColumnarValue::Scalar(ScalarValue::Float32(Some(32.0))), // num ], - arg_fields: arg_fields.iter().collect(), + arg_fields, number_rows: 1, - return_field: &Field::new("f", DataType::Float32, true), + return_field: Field::new("f", DataType::Float32, true).into(), }; let result = LogFunc::new() .invoke_with_args(args) @@ -393,17 +393,17 @@ mod tests { #[test] fn test_log_scalar_f64() { let arg_fields = vec![ - Field::new("a", DataType::Float64, false), - Field::new("a", DataType::Float64, false), + Field::new("a", DataType::Float64, false).into(), + Field::new("a", DataType::Float64, false).into(), ]; let args = ScalarFunctionArgs { args: vec![ ColumnarValue::Scalar(ScalarValue::Float64(Some(2.0))), // num ColumnarValue::Scalar(ScalarValue::Float64(Some(64.0))), // num ], - arg_fields: arg_fields.iter().collect(), + arg_fields, number_rows: 1, - return_field: &Field::new("f", DataType::Float64, true), + return_field: Field::new("f", DataType::Float64, true).into(), }; let result = LogFunc::new() .invoke_with_args(args) @@ -425,16 +425,16 @@ mod tests { #[test] fn test_log_f64_unary() { - let arg_field = Field::new("a", DataType::Float64, false); + let arg_field = Field::new("a", DataType::Float64, false).into(); let args = ScalarFunctionArgs { args: vec![ 
ColumnarValue::Array(Arc::new(Float64Array::from(vec![ 10.0, 100.0, 1000.0, 10000.0, ]))), // num ], - arg_fields: vec![&arg_field], + arg_fields: vec![arg_field], number_rows: 4, - return_field: &Field::new("f", DataType::Float64, true), + return_field: Field::new("f", DataType::Float64, true).into(), }; let result = LogFunc::new() .invoke_with_args(args) @@ -459,16 +459,16 @@ mod tests { #[test] fn test_log_f32_unary() { - let arg_field = Field::new("a", DataType::Float32, false); + let arg_field = Field::new("a", DataType::Float32, false).into(); let args = ScalarFunctionArgs { args: vec![ ColumnarValue::Array(Arc::new(Float32Array::from(vec![ 10.0, 100.0, 1000.0, 10000.0, ]))), // num ], - arg_fields: vec![&arg_field], + arg_fields: vec![arg_field], number_rows: 4, - return_field: &Field::new("f", DataType::Float32, true), + return_field: Field::new("f", DataType::Float32, true).into(), }; let result = LogFunc::new() .invoke_with_args(args) @@ -494,8 +494,8 @@ mod tests { #[test] fn test_log_f64() { let arg_fields = vec![ - Field::new("a", DataType::Float64, false), - Field::new("a", DataType::Float64, false), + Field::new("a", DataType::Float64, false).into(), + Field::new("a", DataType::Float64, false).into(), ]; let args = ScalarFunctionArgs { args: vec![ @@ -506,9 +506,9 @@ mod tests { 8.0, 4.0, 81.0, 625.0, ]))), // num ], - arg_fields: arg_fields.iter().collect(), + arg_fields, number_rows: 4, - return_field: &Field::new("f", DataType::Float64, true), + return_field: Field::new("f", DataType::Float64, true).into(), }; let result = LogFunc::new() .invoke_with_args(args) @@ -534,8 +534,8 @@ mod tests { #[test] fn test_log_f32() { let arg_fields = vec![ - Field::new("a", DataType::Float32, false), - Field::new("a", DataType::Float32, false), + Field::new("a", DataType::Float32, false).into(), + Field::new("a", DataType::Float32, false).into(), ]; let args = ScalarFunctionArgs { args: vec![ @@ -546,9 +546,9 @@ mod tests { 8.0, 4.0, 81.0, 625.0, ]))), // num 
], - arg_fields: arg_fields.iter().collect(), + arg_fields, number_rows: 4, - return_field: &Field::new("f", DataType::Float32, true), + return_field: Field::new("f", DataType::Float32, true).into(), }; let result = LogFunc::new() .invoke_with_args(args) diff --git a/datafusion/functions/src/math/power.rs b/datafusion/functions/src/math/power.rs index 8876e3fe27875..bd1ae7c316c1a 100644 --- a/datafusion/functions/src/math/power.rs +++ b/datafusion/functions/src/math/power.rs @@ -195,8 +195,8 @@ mod tests { #[test] fn test_power_f64() { let arg_fields = vec![ - Field::new("a", DataType::Float64, true), - Field::new("a", DataType::Float64, true), + Field::new("a", DataType::Float64, true).into(), + Field::new("a", DataType::Float64, true).into(), ]; let args = ScalarFunctionArgs { args: vec![ @@ -207,9 +207,9 @@ mod tests { 3.0, 2.0, 4.0, 4.0, ]))), // exponent ], - arg_fields: arg_fields.iter().collect(), + arg_fields, number_rows: 4, - return_field: &Field::new("f", DataType::Float64, true), + return_field: Field::new("f", DataType::Float64, true).into(), }; let result = PowerFunc::new() .invoke_with_args(args) @@ -234,17 +234,17 @@ mod tests { #[test] fn test_power_i64() { let arg_fields = vec![ - Field::new("a", DataType::Int64, true), - Field::new("a", DataType::Int64, true), + Field::new("a", DataType::Int64, true).into(), + Field::new("a", DataType::Int64, true).into(), ]; let args = ScalarFunctionArgs { args: vec![ ColumnarValue::Array(Arc::new(Int64Array::from(vec![2, 2, 3, 5]))), // base ColumnarValue::Array(Arc::new(Int64Array::from(vec![3, 2, 4, 4]))), // exponent ], - arg_fields: arg_fields.iter().collect(), + arg_fields, number_rows: 4, - return_field: &Field::new("f", DataType::Int64, true), + return_field: Field::new("f", DataType::Int64, true).into(), }; let result = PowerFunc::new() .invoke_with_args(args) diff --git a/datafusion/functions/src/math/signum.rs b/datafusion/functions/src/math/signum.rs index 7414e6e138abe..ec6ef5a78c6a7 100644 --- 
a/datafusion/functions/src/math/signum.rs +++ b/datafusion/functions/src/math/signum.rs @@ -157,12 +157,12 @@ mod test { f32::INFINITY, f32::NEG_INFINITY, ])); - let arg_fields = [Field::new("a", DataType::Float32, false)]; + let arg_fields = vec![Field::new("a", DataType::Float32, false).into()]; let args = ScalarFunctionArgs { args: vec![ColumnarValue::Array(Arc::clone(&array) as ArrayRef)], - arg_fields: arg_fields.iter().collect(), + arg_fields, number_rows: array.len(), - return_field: &Field::new("f", DataType::Float32, true), + return_field: Field::new("f", DataType::Float32, true).into(), }; let result = SignumFunc::new() .invoke_with_args(args) @@ -203,12 +203,12 @@ mod test { f64::INFINITY, f64::NEG_INFINITY, ])); - let arg_fields = [Field::new("a", DataType::Float64, false)]; + let arg_fields = vec![Field::new("a", DataType::Float64, false).into()]; let args = ScalarFunctionArgs { args: vec![ColumnarValue::Array(Arc::clone(&array) as ArrayRef)], - arg_fields: arg_fields.iter().collect(), + arg_fields, number_rows: array.len(), - return_field: &Field::new("f", DataType::Float64, true), + return_field: Field::new("f", DataType::Float64, true).into(), }; let result = SignumFunc::new() .invoke_with_args(args) diff --git a/datafusion/functions/src/regex/regexpcount.rs b/datafusion/functions/src/regex/regexpcount.rs index 8f53bf8eb1587..52ab3d489ee31 100644 --- a/datafusion/functions/src/regex/regexpcount.rs +++ b/datafusion/functions/src/regex/regexpcount.rs @@ -651,18 +651,17 @@ mod tests { .map(|sv| ColumnarValue::Scalar(sv.clone())) .collect(); - let arg_fields_owned = args + let arg_fields = args .iter() .enumerate() - .map(|(idx, a)| Field::new(format!("arg_{idx}"), a.data_type(), true)) - .collect::>(); - let arg_fields = arg_fields_owned.iter().collect::>(); + .map(|(idx, a)| Field::new(format!("arg_{idx}"), a.data_type(), true).into()) + .collect::>(); RegexpCountFunc::new().invoke_with_args(ScalarFunctionArgs { args: args_values, arg_fields, 
number_rows: args.len(), - return_field: &Field::new("f", Int64, true), + return_field: Field::new("f", Int64, true).into(), }) } diff --git a/datafusion/functions/src/string/concat.rs b/datafusion/functions/src/string/concat.rs index fe0a5915fe20c..773c316422b70 100644 --- a/datafusion/functions/src/string/concat.rs +++ b/datafusion/functions/src/string/concat.rs @@ -475,13 +475,16 @@ mod tests { Field::new("a", Utf8, true), Field::new("a", Utf8View, true), Field::new("a", Utf8View, true), - ]; + ] + .into_iter() + .map(Arc::new) + .collect::>(); let args = ScalarFunctionArgs { args: vec![c0, c1, c2, c3, c4], - arg_fields: arg_fields.iter().collect(), + arg_fields, number_rows: 3, - return_field: &Field::new("f", Utf8, true), + return_field: Field::new("f", Utf8, true).into(), }; let result = ConcatFunc::new().invoke_with_args(args)?; diff --git a/datafusion/functions/src/string/concat_ws.rs b/datafusion/functions/src/string/concat_ws.rs index 79a5d34fb4c4e..2a2f9429f8fc3 100644 --- a/datafusion/functions/src/string/concat_ws.rs +++ b/datafusion/functions/src/string/concat_ws.rs @@ -482,15 +482,15 @@ mod tests { ]))); let arg_fields = vec![ - Field::new("a", Utf8, true), - Field::new("a", Utf8, true), - Field::new("a", Utf8, true), + Field::new("a", Utf8, true).into(), + Field::new("a", Utf8, true).into(), + Field::new("a", Utf8, true).into(), ]; let args = ScalarFunctionArgs { args: vec![c0, c1, c2], - arg_fields: arg_fields.iter().collect(), + arg_fields, number_rows: 3, - return_field: &Field::new("f", Utf8, true), + return_field: Field::new("f", Utf8, true).into(), }; let result = ConcatWsFunc::new().invoke_with_args(args)?; @@ -518,15 +518,15 @@ mod tests { ]))); let arg_fields = vec![ - Field::new("a", Utf8, true), - Field::new("a", Utf8, true), - Field::new("a", Utf8, true), + Field::new("a", Utf8, true).into(), + Field::new("a", Utf8, true).into(), + Field::new("a", Utf8, true).into(), ]; let args = ScalarFunctionArgs { args: vec![c0, c1, c2], - 
arg_fields: arg_fields.iter().collect(), + arg_fields, number_rows: 3, - return_field: &Field::new("f", Utf8, true), + return_field: Field::new("f", Utf8, true).into(), }; let result = ConcatWsFunc::new().invoke_with_args(args)?; diff --git a/datafusion/functions/src/string/contains.rs b/datafusion/functions/src/string/contains.rs index b4c9cd90ca5d4..b74be15466265 100644 --- a/datafusion/functions/src/string/contains.rs +++ b/datafusion/functions/src/string/contains.rs @@ -166,15 +166,15 @@ mod test { ]))); let scalar = ColumnarValue::Scalar(ScalarValue::Utf8(Some("x?(".to_string()))); let arg_fields = vec![ - Field::new("a", DataType::Utf8, true), - Field::new("a", DataType::Utf8, true), + Field::new("a", DataType::Utf8, true).into(), + Field::new("a", DataType::Utf8, true).into(), ]; let args = ScalarFunctionArgs { args: vec![array, scalar], - arg_fields: arg_fields.iter().collect(), + arg_fields, number_rows: 2, - return_field: &Field::new("f", DataType::Boolean, true), + return_field: Field::new("f", DataType::Boolean, true).into(), }; let actual = udf.invoke_with_args(args).unwrap(); diff --git a/datafusion/functions/src/string/lower.rs b/datafusion/functions/src/string/lower.rs index 1dc6e9d283677..536c29a7cb253 100644 --- a/datafusion/functions/src/string/lower.rs +++ b/datafusion/functions/src/string/lower.rs @@ -104,13 +104,13 @@ mod tests { fn to_lower(input: ArrayRef, expected: ArrayRef) -> Result<()> { let func = LowerFunc::new(); - let arg_fields = [Field::new("a", input.data_type().clone(), true)]; + let arg_fields = vec![Field::new("a", input.data_type().clone(), true).into()]; let args = ScalarFunctionArgs { number_rows: input.len(), args: vec![ColumnarValue::Array(input)], - arg_fields: arg_fields.iter().collect(), - return_field: &Field::new("f", Utf8, true), + arg_fields, + return_field: Field::new("f", Utf8, true).into(), }; let result = match func.invoke_with_args(args)? 
{ diff --git a/datafusion/functions/src/string/upper.rs b/datafusion/functions/src/string/upper.rs index 06a9bd9720d6c..882fb45eda4af 100644 --- a/datafusion/functions/src/string/upper.rs +++ b/datafusion/functions/src/string/upper.rs @@ -104,12 +104,12 @@ mod tests { fn to_upper(input: ArrayRef, expected: ArrayRef) -> Result<()> { let func = UpperFunc::new(); - let arg_field = Field::new("a", input.data_type().clone(), true); + let arg_field = Field::new("a", input.data_type().clone(), true).into(); let args = ScalarFunctionArgs { number_rows: input.len(), args: vec![ColumnarValue::Array(input)], - arg_fields: vec![&arg_field], - return_field: &Field::new("f", Utf8, true), + arg_fields: vec![arg_field], + return_field: Field::new("f", Utf8, true).into(), }; let result = match func.invoke_with_args(args)? { diff --git a/datafusion/functions/src/unicode/find_in_set.rs b/datafusion/functions/src/unicode/find_in_set.rs index 6b5df89e860f8..8b00c7be1ccf8 100644 --- a/datafusion/functions/src/unicode/find_in_set.rs +++ b/datafusion/functions/src/unicode/find_in_set.rs @@ -471,17 +471,18 @@ mod tests { }) .unwrap_or(1); let return_type = fis.return_type(&type_array)?; - let arg_fields_owned = args + let arg_fields = args .iter() .enumerate() - .map(|(idx, a)| Field::new(format!("arg_{idx}"), a.data_type(), true)) + .map(|(idx, a)| { + Field::new(format!("arg_{idx}"), a.data_type(), true).into() + }) .collect::>(); - let arg_fields = arg_fields_owned.iter().collect::>(); let result = fis.invoke_with_args(ScalarFunctionArgs { args, arg_fields, number_rows: cardinality, - return_field: &Field::new("f", return_type, true), + return_field: Field::new("f", return_type, true).into(), }); assert!(result.is_ok()); diff --git a/datafusion/functions/src/unicode/strpos.rs b/datafusion/functions/src/unicode/strpos.rs index b33a1ca7713af..1c81b46ec78ea 100644 --- a/datafusion/functions/src/unicode/strpos.rs +++ b/datafusion/functions/src/unicode/strpos.rs @@ -22,7 +22,9 @@ use 
crate::utils::{make_scalar_function, utf8_to_int_type}; use arrow::array::{ ArrayRef, ArrowPrimitiveType, AsArray, PrimitiveArray, StringArrayType, }; -use arrow::datatypes::{ArrowNativeType, DataType, Field, Int32Type, Int64Type}; +use arrow::datatypes::{ + ArrowNativeType, DataType, Field, FieldRef, Int32Type, Int64Type, +}; use datafusion_common::types::logical_string; use datafusion_common::{exec_err, internal_err, Result}; use datafusion_expr::{ @@ -94,7 +96,7 @@ impl ScalarUDFImpl for StrposFunc { fn return_field_from_args( &self, args: datafusion_expr::ReturnFieldArgs, - ) -> Result { + ) -> Result { utf8_to_int_type(args.arg_fields[0].data_type(), "strpos/instr/position").map( |data_type| { Field::new( @@ -102,6 +104,7 @@ impl ScalarUDFImpl for StrposFunc { data_type, args.arg_fields.iter().any(|x| x.is_nullable()), ) + .into() }, ) } @@ -329,8 +332,8 @@ mod tests { let strpos = StrposFunc::new(); let args = datafusion_expr::ReturnFieldArgs { arg_fields: &[ - Field::new("f1", DataType::Utf8, string_array_nullable), - Field::new("f2", DataType::Utf8, substring_nullable), + Field::new("f1", DataType::Utf8, string_array_nullable).into(), + Field::new("f2", DataType::Utf8, substring_nullable).into(), ], scalar_arguments: &[None::<&ScalarValue>, None::<&ScalarValue>], }; diff --git a/datafusion/functions/src/utils.rs b/datafusion/functions/src/utils.rs index 6557f7da81ce2..583ff48bff39d 100644 --- a/datafusion/functions/src/utils.rs +++ b/datafusion/functions/src/utils.rs @@ -155,15 +155,16 @@ pub mod test { let field_array = data_array.into_iter().zip(nullables).enumerate() .map(|(idx, (data_type, nullable))| arrow::datatypes::Field::new(format!("field_{idx}"), data_type, nullable)) + .map(std::sync::Arc::new) .collect::>(); let return_field = func.return_field_from_args(datafusion_expr::ReturnFieldArgs { arg_fields: &field_array, scalar_arguments: &scalar_arguments_refs, }); - let arg_fields_owned = $ARGS.iter() + let arg_fields = $ARGS.iter() .enumerate() - 
.map(|(idx, arg)| arrow::datatypes::Field::new(format!("f_{idx}"), arg.data_type(), true)) + .map(|(idx, arg)| arrow::datatypes::Field::new(format!("f_{idx}"), arg.data_type(), true).into()) .collect::>(); match expected { @@ -173,8 +174,7 @@ pub mod test { let return_type = return_field.data_type(); assert_eq!(return_type, &$EXPECTED_DATA_TYPE); - let arg_fields = arg_fields_owned.iter().collect::>(); - let result = func.invoke_with_args(datafusion_expr::ScalarFunctionArgs{args: $ARGS, arg_fields, number_rows: cardinality, return_field: &return_field}); + let result = func.invoke_with_args(datafusion_expr::ScalarFunctionArgs{args: $ARGS, arg_fields, number_rows: cardinality, return_field}); assert_eq!(result.is_ok(), true, "function returned an error: {}", result.unwrap_err()); let result = result.unwrap().to_array(cardinality).expect("Failed to convert to array"); @@ -197,9 +197,8 @@ pub mod test { else { let return_field = return_field.unwrap(); - let arg_fields = arg_fields_owned.iter().collect::>(); // invoke is expected error - cannot use .expect_err() due to Debug not being implemented - match func.invoke_with_args(datafusion_expr::ScalarFunctionArgs{args: $ARGS, arg_fields, number_rows: cardinality, return_field: &return_field}) { + match func.invoke_with_args(datafusion_expr::ScalarFunctionArgs{args: $ARGS, arg_fields, number_rows: cardinality, return_field}) { Ok(_) => assert!(false, "expected error"), Err(error) => { assert!(expected_error.strip_backtrace().starts_with(&error.strip_backtrace())); diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index c17e6c766cc3f..c8246ecebd543 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -813,7 +813,7 @@ fn coerce_arguments_for_signature_with_aggregate_udf( let current_fields = expressions .iter() - .map(|e| e.to_field(schema).map(|(_, f)| f.as_ref().clone())) + .map(|e| 
e.to_field(schema).map(|(_, f)| f)) .collect::>>()?; let new_types = fields_with_aggregate_udf(¤t_fields, func)? @@ -1622,8 +1622,8 @@ mod test { return_type, accumulator, vec![ - Field::new("count", DataType::UInt64, true), - Field::new("avg", DataType::Float64, true), + Field::new("count", DataType::UInt64, true).into(), + Field::new("avg", DataType::Float64, true).into(), ], )); let udaf = Expr::AggregateFunction(expr::AggregateFunction::new_udf( diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index d526b63ae5d2c..6a49e5d22087f 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -904,7 +904,7 @@ mod test { Signature::exact(vec![DataType::UInt32], Volatility::Stable), return_type.clone(), Arc::clone(&accumulator), - vec![Field::new("value", DataType::UInt32, true)], + vec![Field::new("value", DataType::UInt32, true).into()], ))), vec![inner], false, diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 04ca471309984..4e4e3d316c268 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -2144,8 +2144,10 @@ fn simplify_null_div_other_case( #[cfg(test)] mod tests { + use super::*; use crate::simplify_expressions::SimplifyContext; use crate::test::test_table_scan_with_name; + use arrow::datatypes::FieldRef; use datafusion_common::{assert_contains, DFSchemaRef, ToDFSchema}; use datafusion_expr::{ function::{ @@ -2163,8 +2165,6 @@ mod tests { sync::Arc, }; - use super::*; - // ------------------------------ // --- ExprSimplifier tests ----- // ------------------------------ @@ -4451,7 +4451,7 @@ mod tests { unimplemented!("not needed for tests") } - fn field(&self, _field_args: WindowUDFFieldArgs) -> Result { + fn field(&self, 
_field_args: WindowUDFFieldArgs) -> Result { unimplemented!("not needed for tests") } } diff --git a/datafusion/physical-expr-common/src/physical_expr.rs b/datafusion/physical-expr-common/src/physical_expr.rs index b0c02f8761d41..7be132fa61238 100644 --- a/datafusion/physical-expr-common/src/physical_expr.rs +++ b/datafusion/physical-expr-common/src/physical_expr.rs @@ -25,7 +25,7 @@ use crate::utils::scatter; use arrow::array::BooleanArray; use arrow::compute::filter_record_batch; -use arrow::datatypes::{DataType, Field, Schema}; +use arrow::datatypes::{DataType, Field, FieldRef, Schema}; use arrow::record_batch::RecordBatch; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::{internal_err, not_impl_err, Result, ScalarValue}; @@ -81,12 +81,12 @@ pub trait PhysicalExpr: Send + Sync + Display + Debug + DynEq + DynHash { /// Evaluate an expression against a RecordBatch fn evaluate(&self, batch: &RecordBatch) -> Result; /// The output field associated with this expression - fn return_field(&self, input_schema: &Schema) -> Result { - Ok(Field::new( + fn return_field(&self, input_schema: &Schema) -> Result { + Ok(Arc::new(Field::new( format!("{self}"), self.data_type(input_schema)?, self.nullable(input_schema)?, - )) + ))) } /// Evaluate an expression against a RecordBatch after first applying a /// validity array @@ -469,7 +469,7 @@ where /// # use std::fmt::Formatter; /// # use std::sync::Arc; /// # use arrow::array::RecordBatch; -/// # use arrow::datatypes::{DataType, Field, Schema}; +/// # use arrow::datatypes::{DataType, Field, FieldRef, Schema}; /// # use datafusion_common::Result; /// # use datafusion_expr_common::columnar_value::ColumnarValue; /// # use datafusion_physical_expr_common::physical_expr::{fmt_sql, DynEq, PhysicalExpr}; @@ -479,7 +479,7 @@ where /// # fn data_type(&self, input_schema: &Schema) -> Result { unimplemented!() } /// # fn nullable(&self, input_schema: &Schema) -> Result { unimplemented!() } 
/// # fn evaluate(&self, batch: &RecordBatch) -> Result { unimplemented!() } -/// # fn return_field(&self, input_schema: &Schema) -> Result { unimplemented!() } +/// # fn return_field(&self, input_schema: &Schema) -> Result { unimplemented!() } /// # fn children(&self) -> Vec<&Arc>{ unimplemented!() } /// # fn with_new_children(self: Arc, children: Vec>) -> Result> { unimplemented!() } /// # fn fmt_sql(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!(f, "CASE a > b THEN 1 ELSE 0 END") } diff --git a/datafusion/physical-expr-common/src/sort_expr.rs b/datafusion/physical-expr-common/src/sort_expr.rs index a1227c7f11433..2572e8679484f 100644 --- a/datafusion/physical-expr-common/src/sort_expr.rs +++ b/datafusion/physical-expr-common/src/sort_expr.rs @@ -44,7 +44,7 @@ use itertools::Itertools; /// # use arrow::array::RecordBatch; /// # use datafusion_common::Result; /// # use arrow::compute::SortOptions; -/// # use arrow::datatypes::{DataType, Field, Schema}; +/// # use arrow::datatypes::{DataType, Field, FieldRef, Schema}; /// # use datafusion_expr_common::columnar_value::ColumnarValue; /// # use datafusion_physical_expr_common::physical_expr::PhysicalExpr; /// # use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; @@ -57,7 +57,7 @@ use itertools::Itertools; /// # fn data_type(&self, input_schema: &Schema) -> Result {todo!()} /// # fn nullable(&self, input_schema: &Schema) -> Result {todo!() } /// # fn evaluate(&self, batch: &RecordBatch) -> Result {todo!() } -/// # fn return_field(&self, input_schema: &Schema) -> Result { unimplemented!() } +/// # fn return_field(&self, input_schema: &Schema) -> Result { unimplemented!() } /// # fn children(&self) -> Vec<&Arc> {todo!()} /// # fn with_new_children(self: Arc, children: Vec>) -> Result> {todo!()} /// # fn fmt_sql(&self, f: &mut Formatter<'_>) -> std::fmt::Result { todo!() } diff --git a/datafusion/physical-expr/src/aggregate.rs b/datafusion/physical-expr/src/aggregate.rs index 
867b4e0fc9555..be04b9c6b8ea8 100644 --- a/datafusion/physical-expr/src/aggregate.rs +++ b/datafusion/physical-expr/src/aggregate.rs @@ -41,7 +41,7 @@ use std::sync::Arc; use crate::expressions::Column; use arrow::compute::SortOptions; -use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use arrow::datatypes::{DataType, FieldRef, Schema, SchemaRef}; use datafusion_common::{internal_err, not_impl_err, Result, ScalarValue}; use datafusion_expr::{AggregateUDF, ReversedUDAF, SetMonotonicity}; use datafusion_expr_common::accumulator::Accumulator; @@ -106,7 +106,7 @@ impl AggregateExprBuilder { /// ``` /// # use std::any::Any; /// # use std::sync::Arc; - /// # use arrow::datatypes::DataType; + /// # use arrow::datatypes::{DataType, FieldRef}; /// # use datafusion_common::{Result, ScalarValue}; /// # use datafusion_expr::{col, ColumnarValue, Documentation, Signature, Volatility, Expr}; /// # use datafusion_expr::{AggregateUDFImpl, AggregateUDF, Accumulator, function::{AccumulatorArgs, StateFieldsArgs}}; @@ -143,7 +143,7 @@ impl AggregateExprBuilder { /// # unimplemented!() /// # } /// # - /// # fn state_fields(&self, args: StateFieldsArgs) -> Result> { + /// # fn state_fields(&self, args: StateFieldsArgs) -> Result> { /// # unimplemented!() /// # } /// # @@ -311,7 +311,7 @@ pub struct AggregateFunctionExpr { fun: AggregateUDF, args: Vec>, /// Output / return field of this aggregate - return_field: Field, + return_field: FieldRef, /// Output column name that this expression creates name: String, /// Simplified name for `tree` explain. @@ -322,10 +322,10 @@ pub struct AggregateFunctionExpr { // Whether to ignore null values ignore_nulls: bool, // fields used for order sensitive aggregation functions - ordering_fields: Vec, + ordering_fields: Vec, is_distinct: bool, is_reversed: bool, - input_fields: Vec, + input_fields: Vec, is_nullable: bool, } @@ -372,8 +372,12 @@ impl AggregateFunctionExpr { } /// the field of the final result of this aggregation. 
- pub fn field(&self) -> Field { - self.return_field.clone().with_name(&self.name) + pub fn field(&self) -> FieldRef { + self.return_field + .as_ref() + .clone() + .with_name(&self.name) + .into() } /// the accumulator used to accumulate values from the expressions. @@ -381,7 +385,7 @@ impl AggregateFunctionExpr { /// return states with the same description as `state_fields` pub fn create_accumulator(&self) -> Result> { let acc_args = AccumulatorArgs { - return_field: &self.return_field, + return_field: Arc::clone(&self.return_field), schema: &self.schema, ignore_nulls: self.ignore_nulls, ordering_req: self.ordering_req.as_ref(), @@ -395,11 +399,11 @@ impl AggregateFunctionExpr { } /// the field of the final result of this aggregation. - pub fn state_fields(&self) -> Result> { + pub fn state_fields(&self) -> Result> { let args = StateFieldsArgs { name: &self.name, input_fields: &self.input_fields, - return_field: &self.return_field, + return_field: Arc::clone(&self.return_field), ordering_fields: &self.ordering_fields, is_distinct: self.is_distinct, }; @@ -472,7 +476,7 @@ impl AggregateFunctionExpr { /// Creates accumulator implementation that supports retract pub fn create_sliding_accumulator(&self) -> Result> { let args = AccumulatorArgs { - return_field: &self.return_field, + return_field: Arc::clone(&self.return_field), schema: &self.schema, ignore_nulls: self.ignore_nulls, ordering_req: self.ordering_req.as_ref(), @@ -541,7 +545,7 @@ impl AggregateFunctionExpr { /// `[Self::create_groups_accumulator`] will be called. pub fn groups_accumulator_supported(&self) -> bool { let args = AccumulatorArgs { - return_field: &self.return_field, + return_field: Arc::clone(&self.return_field), schema: &self.schema, ignore_nulls: self.ignore_nulls, ordering_req: self.ordering_req.as_ref(), @@ -560,7 +564,7 @@ impl AggregateFunctionExpr { /// implemented in addition to [`Accumulator`]. 
pub fn create_groups_accumulator(&self) -> Result> { let args = AccumulatorArgs { - return_field: &self.return_field, + return_field: Arc::clone(&self.return_field), schema: &self.schema, ignore_nulls: self.ignore_nulls, ordering_req: self.ordering_req.as_ref(), @@ -640,7 +644,7 @@ impl AggregateFunctionExpr { /// output_field is the name of the column produced by this aggregate /// /// Note: this is used to use special aggregate implementations in certain conditions - pub fn get_minmax_desc(&self) -> Option<(Field, bool)> { + pub fn get_minmax_desc(&self) -> Option<(FieldRef, bool)> { self.fun.is_descending().map(|flag| (self.field(), flag)) } diff --git a/datafusion/physical-expr/src/equivalence/properties/dependency.rs b/datafusion/physical-expr/src/equivalence/properties/dependency.rs index 39a01391e3260..fa52ae8686f76 100644 --- a/datafusion/physical-expr/src/equivalence/properties/dependency.rs +++ b/datafusion/physical-expr/src/equivalence/properties/dependency.rs @@ -1224,7 +1224,7 @@ mod tests { "concat", concat(), vec![Arc::clone(&col_a), Arc::clone(&col_b)], - Field::new("f", DataType::Utf8, true), + Field::new("f", DataType::Utf8, true).into(), )); // Assume existing ordering is [c ASC, a ASC, b ASC] @@ -1315,7 +1315,7 @@ mod tests { "concat", concat(), vec![Arc::clone(&col_a), Arc::clone(&col_b)], - Field::new("f", DataType::Utf8, true), + Field::new("f", DataType::Utf8, true).into(), )); // Assume existing ordering is [concat(a, b) ASC, a ASC, b ASC] diff --git a/datafusion/physical-expr/src/expressions/cast.rs b/datafusion/physical-expr/src/expressions/cast.rs index 88923d9c6ceea..7e345e60271fd 100644 --- a/datafusion/physical-expr/src/expressions/cast.rs +++ b/datafusion/physical-expr/src/expressions/cast.rs @@ -23,7 +23,7 @@ use std::sync::Arc; use crate::physical_expr::PhysicalExpr; use arrow::compute::{can_cast_types, CastOptions}; -use arrow::datatypes::{DataType, DataType::*, Field, Schema}; +use arrow::datatypes::{DataType, DataType::*, 
FieldRef, Schema}; use arrow::record_batch::RecordBatch; use datafusion_common::format::DEFAULT_FORMAT_OPTIONS; use datafusion_common::{not_impl_err, Result}; @@ -144,11 +144,14 @@ impl PhysicalExpr for CastExpr { value.cast_to(&self.cast_type, Some(&self.cast_options)) } - fn return_field(&self, input_schema: &Schema) -> Result { + fn return_field(&self, input_schema: &Schema) -> Result { Ok(self .expr .return_field(input_schema)? - .with_data_type(self.cast_type.clone())) + .as_ref() + .clone() + .with_data_type(self.cast_type.clone()) + .into()) } fn children(&self) -> Vec<&Arc> { diff --git a/datafusion/physical-expr/src/expressions/column.rs b/datafusion/physical-expr/src/expressions/column.rs index 80af0b84c5d17..5a11783a87e90 100644 --- a/datafusion/physical-expr/src/expressions/column.rs +++ b/datafusion/physical-expr/src/expressions/column.rs @@ -22,8 +22,9 @@ use std::hash::Hash; use std::sync::Arc; use crate::physical_expr::PhysicalExpr; +use arrow::datatypes::FieldRef; use arrow::{ - datatypes::{DataType, Field, Schema, SchemaRef}, + datatypes::{DataType, Schema, SchemaRef}, record_batch::RecordBatch, }; use datafusion_common::tree_node::{Transformed, TreeNode}; @@ -127,8 +128,8 @@ impl PhysicalExpr for Column { Ok(ColumnarValue::Array(Arc::clone(batch.column(self.index)))) } - fn return_field(&self, input_schema: &Schema) -> Result { - Ok(input_schema.field(self.index).clone()) + fn return_field(&self, input_schema: &Schema) -> Result { + Ok(input_schema.field(self.index).clone().into()) } fn children(&self) -> Vec<&Arc> { diff --git a/datafusion/physical-expr/src/expressions/is_not_null.rs b/datafusion/physical-expr/src/expressions/is_not_null.rs index 1de8c17a373a7..ff05dab40126a 100644 --- a/datafusion/physical-expr/src/expressions/is_not_null.rs +++ b/datafusion/physical-expr/src/expressions/is_not_null.rs @@ -18,7 +18,7 @@ //! 
IS NOT NULL expression use crate::PhysicalExpr; -use arrow::datatypes::Field; +use arrow::datatypes::FieldRef; use arrow::{ datatypes::{DataType, Schema}, record_batch::RecordBatch, @@ -94,7 +94,7 @@ impl PhysicalExpr for IsNotNullExpr { } } - fn return_field(&self, input_schema: &Schema) -> Result { + fn return_field(&self, input_schema: &Schema) -> Result { self.arg.return_field(input_schema) } diff --git a/datafusion/physical-expr/src/expressions/is_null.rs b/datafusion/physical-expr/src/expressions/is_null.rs index 7707075ce653f..15c7c645bda09 100644 --- a/datafusion/physical-expr/src/expressions/is_null.rs +++ b/datafusion/physical-expr/src/expressions/is_null.rs @@ -18,8 +18,9 @@ //! IS NULL expression use crate::PhysicalExpr; +use arrow::datatypes::FieldRef; use arrow::{ - datatypes::{DataType, Field, Schema}, + datatypes::{DataType, Schema}, record_batch::RecordBatch, }; use datafusion_common::Result; @@ -92,7 +93,7 @@ impl PhysicalExpr for IsNullExpr { } } - fn return_field(&self, input_schema: &Schema) -> Result { + fn return_field(&self, input_schema: &Schema) -> Result { self.arg.return_field(input_schema) } diff --git a/datafusion/physical-expr/src/expressions/negative.rs b/datafusion/physical-expr/src/expressions/negative.rs index 597cbf1dac9ec..fa7224768a777 100644 --- a/datafusion/physical-expr/src/expressions/negative.rs +++ b/datafusion/physical-expr/src/expressions/negative.rs @@ -23,7 +23,7 @@ use std::sync::Arc; use crate::PhysicalExpr; -use arrow::datatypes::Field; +use arrow::datatypes::FieldRef; use arrow::{ compute::kernels::numeric::neg_wrapping, datatypes::{DataType, Schema}, @@ -104,7 +104,7 @@ impl PhysicalExpr for NegativeExpr { } } - fn return_field(&self, input_schema: &Schema) -> Result { + fn return_field(&self, input_schema: &Schema) -> Result { self.arg.return_field(input_schema) } diff --git a/datafusion/physical-expr/src/expressions/not.rs b/datafusion/physical-expr/src/expressions/not.rs index 1f3ae9e25ffb5..8184ef601e543 
100644 --- a/datafusion/physical-expr/src/expressions/not.rs +++ b/datafusion/physical-expr/src/expressions/not.rs @@ -24,7 +24,7 @@ use std::sync::Arc; use crate::PhysicalExpr; -use arrow::datatypes::{DataType, Field, Schema}; +use arrow::datatypes::{DataType, FieldRef, Schema}; use arrow::record_batch::RecordBatch; use datafusion_common::{cast::as_boolean_array, internal_err, Result, ScalarValue}; use datafusion_expr::interval_arithmetic::Interval; @@ -101,7 +101,7 @@ impl PhysicalExpr for NotExpr { } } - fn return_field(&self, input_schema: &Schema) -> Result { + fn return_field(&self, input_schema: &Schema) -> Result { self.arg.return_field(input_schema) } diff --git a/datafusion/physical-expr/src/expressions/try_cast.rs b/datafusion/physical-expr/src/expressions/try_cast.rs index e4fe027c79180..b593dfe83209d 100644 --- a/datafusion/physical-expr/src/expressions/try_cast.rs +++ b/datafusion/physical-expr/src/expressions/try_cast.rs @@ -23,7 +23,7 @@ use std::sync::Arc; use crate::PhysicalExpr; use arrow::compute; use arrow::compute::{cast_with_options, CastOptions}; -use arrow::datatypes::{DataType, Field, Schema}; +use arrow::datatypes::{DataType, FieldRef, Schema}; use arrow::record_batch::RecordBatch; use compute::can_cast_types; use datafusion_common::format::DEFAULT_FORMAT_OPTIONS; @@ -110,10 +110,11 @@ impl PhysicalExpr for TryCastExpr { } } - fn return_field(&self, input_schema: &Schema) -> Result { + fn return_field(&self, input_schema: &Schema) -> Result { self.expr .return_field(input_schema) - .map(|f| f.with_data_type(self.cast_type.clone())) + .map(|f| f.as_ref().clone().with_data_type(self.cast_type.clone())) + .map(Arc::new) } fn children(&self) -> Vec<&Arc> { diff --git a/datafusion/physical-expr/src/scalar_function.rs b/datafusion/physical-expr/src/scalar_function.rs index d6e070e389484..d014bbb74caa1 100644 --- a/datafusion/physical-expr/src/scalar_function.rs +++ b/datafusion/physical-expr/src/scalar_function.rs @@ -38,7 +38,7 @@ use 
crate::expressions::Literal; use crate::PhysicalExpr; use arrow::array::{Array, RecordBatch}; -use arrow::datatypes::{DataType, Field, Schema}; +use arrow::datatypes::{DataType, FieldRef, Schema}; use datafusion_common::{internal_err, Result, ScalarValue}; use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::sort_properties::ExprProperties; @@ -53,7 +53,7 @@ pub struct ScalarFunctionExpr { fun: Arc, name: String, args: Vec>, - return_field: Field, + return_field: FieldRef, } impl Debug for ScalarFunctionExpr { @@ -73,7 +73,7 @@ impl ScalarFunctionExpr { name: &str, fun: Arc, args: Vec>, - return_field: Field, + return_field: FieldRef, ) -> Self { Self { fun, @@ -144,7 +144,12 @@ impl ScalarFunctionExpr { } pub fn with_nullable(mut self, nullable: bool) -> Self { - self.return_field = self.return_field.with_nullable(nullable); + self.return_field = self + .return_field + .as_ref() + .clone() + .with_nullable(nullable) + .into(); self } @@ -180,12 +185,11 @@ impl PhysicalExpr for ScalarFunctionExpr { .map(|e| e.evaluate(batch)) .collect::>>()?; - let arg_fields_owned = self + let arg_fields = self .args .iter() .map(|e| e.return_field(batch.schema_ref())) .collect::>>()?; - let arg_fields = arg_fields_owned.iter().collect::>(); let input_empty = args.is_empty(); let input_all_scalar = args @@ -197,7 +201,7 @@ impl PhysicalExpr for ScalarFunctionExpr { args, arg_fields, number_rows: batch.num_rows(), - return_field: &self.return_field, + return_field: Arc::clone(&self.return_field), })?; if let ColumnarValue::Array(array) = &output { @@ -217,8 +221,8 @@ impl PhysicalExpr for ScalarFunctionExpr { Ok(output) } - fn return_field(&self, _input_schema: &Schema) -> Result { - Ok(self.return_field.clone()) + fn return_field(&self, _input_schema: &Schema) -> Result { + Ok(Arc::clone(&self.return_field)) } fn children(&self) -> Vec<&Arc> { @@ -233,7 +237,7 @@ impl PhysicalExpr for ScalarFunctionExpr { &self.name, Arc::clone(&self.fun), children, - 
self.return_field.clone(), + Arc::clone(&self.return_field), ))) } diff --git a/datafusion/physical-expr/src/window/aggregate.rs b/datafusion/physical-expr/src/window/aggregate.rs index a94d5b1212f52..9b959796136a9 100644 --- a/datafusion/physical-expr/src/window/aggregate.rs +++ b/datafusion/physical-expr/src/window/aggregate.rs @@ -30,8 +30,9 @@ use crate::window::{ use crate::{reverse_order_bys, EquivalenceProperties, PhysicalExpr}; use arrow::array::Array; +use arrow::array::ArrayRef; +use arrow::datatypes::FieldRef; use arrow::record_batch::RecordBatch; -use arrow::{array::ArrayRef, datatypes::Field}; use datafusion_common::{DataFusionError, Result, ScalarValue}; use datafusion_expr::{Accumulator, WindowFrame}; use datafusion_physical_expr_common::sort_expr::LexOrdering; @@ -95,7 +96,7 @@ impl WindowExpr for PlainAggregateWindowExpr { self } - fn field(&self) -> Result { + fn field(&self) -> Result { Ok(self.aggregate.field()) } diff --git a/datafusion/physical-expr/src/window/sliding_aggregate.rs b/datafusion/physical-expr/src/window/sliding_aggregate.rs index 23967e78f07a7..2b22299f9386b 100644 --- a/datafusion/physical-expr/src/window/sliding_aggregate.rs +++ b/datafusion/physical-expr/src/window/sliding_aggregate.rs @@ -29,7 +29,7 @@ use crate::window::{ use crate::{expressions::PhysicalSortExpr, reverse_order_bys, PhysicalExpr}; use arrow::array::{Array, ArrayRef}; -use arrow::datatypes::Field; +use arrow::datatypes::FieldRef; use arrow::record_batch::RecordBatch; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::{Accumulator, WindowFrame}; @@ -80,7 +80,7 @@ impl WindowExpr for SlidingAggregateWindowExpr { self } - fn field(&self) -> Result { + fn field(&self) -> Result { Ok(self.aggregate.field()) } diff --git a/datafusion/physical-expr/src/window/standard.rs b/datafusion/physical-expr/src/window/standard.rs index 22e8aea83fe78..73f47b0b68632 100644 --- a/datafusion/physical-expr/src/window/standard.rs +++ 
b/datafusion/physical-expr/src/window/standard.rs @@ -27,7 +27,7 @@ use crate::window::{PartitionBatches, PartitionWindowAggStates, WindowState}; use crate::{reverse_order_bys, EquivalenceProperties, PhysicalExpr}; use arrow::array::{new_empty_array, ArrayRef}; use arrow::compute::SortOptions; -use arrow::datatypes::Field; +use arrow::datatypes::FieldRef; use arrow::record_batch::RecordBatch; use datafusion_common::utils::evaluate_partition_ranges; use datafusion_common::{Result, ScalarValue}; @@ -92,7 +92,7 @@ impl WindowExpr for StandardWindowExpr { self.expr.name() } - fn field(&self) -> Result { + fn field(&self) -> Result { self.expr.field() } diff --git a/datafusion/physical-expr/src/window/standard_window_function_expr.rs b/datafusion/physical-expr/src/window/standard_window_function_expr.rs index 624b747d93f9a..871f735e9a963 100644 --- a/datafusion/physical-expr/src/window/standard_window_function_expr.rs +++ b/datafusion/physical-expr/src/window/standard_window_function_expr.rs @@ -18,7 +18,7 @@ use crate::{PhysicalExpr, PhysicalSortExpr}; use arrow::array::ArrayRef; -use arrow::datatypes::{Field, SchemaRef}; +use arrow::datatypes::{FieldRef, SchemaRef}; use arrow::record_batch::RecordBatch; use datafusion_common::Result; use datafusion_expr::PartitionEvaluator; @@ -41,7 +41,7 @@ pub trait StandardWindowFunctionExpr: Send + Sync + std::fmt::Debug { fn as_any(&self) -> &dyn Any; /// The field of the final result of evaluating this window function. - fn field(&self) -> Result; + fn field(&self) -> Result; /// Expressions that are passed to the [`PartitionEvaluator`]. 
fn expressions(&self) -> Vec>; diff --git a/datafusion/physical-expr/src/window/window_expr.rs b/datafusion/physical-expr/src/window/window_expr.rs index 793f2e5ee5867..8d72604a6af50 100644 --- a/datafusion/physical-expr/src/window/window_expr.rs +++ b/datafusion/physical-expr/src/window/window_expr.rs @@ -25,7 +25,7 @@ use crate::{LexOrdering, PhysicalExpr}; use arrow::array::{new_empty_array, Array, ArrayRef}; use arrow::compute::kernels::sort::SortColumn; use arrow::compute::SortOptions; -use arrow::datatypes::Field; +use arrow::datatypes::FieldRef; use arrow::record_batch::RecordBatch; use datafusion_common::utils::compare_rows; use datafusion_common::{internal_err, DataFusionError, Result, ScalarValue}; @@ -67,7 +67,7 @@ pub trait WindowExpr: Send + Sync + Debug { fn as_any(&self) -> &dyn Any; /// The field of the final result of this window function. - fn field(&self) -> Result; + fn field(&self) -> Result; /// Human readable name such as `"MIN(c2)"` or `"RANK()"`. The default /// implementation returns placeholder text. diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index 2062e2208b40e..fdae2aa945069 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -36,6 +36,7 @@ use crate::{ use arrow::array::{ArrayRef, UInt16Array, UInt32Array, UInt64Array, UInt8Array}; use arrow::datatypes::{Field, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; +use arrow_schema::FieldRef; use datafusion_common::stats::Precision; use datafusion_common::{internal_err, not_impl_err, Constraint, Constraints, Result}; use datafusion_execution::TaskContext; @@ -273,7 +274,7 @@ impl PhysicalGroupBy { } /// Returns the fields that are used as the grouping keys. 
- fn group_fields(&self, input_schema: &Schema) -> Result> { + fn group_fields(&self, input_schema: &Schema) -> Result> { let mut fields = Vec::with_capacity(self.num_group_exprs()); for ((expr, name), group_expr_nullable) in self.expr.iter().zip(self.exprs_nullable().into_iter()) @@ -284,15 +285,19 @@ impl PhysicalGroupBy { expr.data_type(input_schema)?, group_expr_nullable || expr.nullable(input_schema)?, ) - .with_metadata(expr.return_field(input_schema)?.metadata().clone()), + .with_metadata(expr.return_field(input_schema)?.metadata().clone()) + .into(), ); } if !self.is_single() { - fields.push(Field::new( - Aggregate::INTERNAL_GROUPING_ID, - Aggregate::grouping_id_type(self.expr.len()), - false, - )); + fields.push( + Field::new( + Aggregate::INTERNAL_GROUPING_ID, + Aggregate::grouping_id_type(self.expr.len()), + false, + ) + .into(), + ); } Ok(fields) } @@ -301,7 +306,7 @@ impl PhysicalGroupBy { /// /// This might be different from the `group_fields` that might contain internal expressions that /// should not be part of the output schema. 
- fn output_fields(&self, input_schema: &Schema) -> Result> { + fn output_fields(&self, input_schema: &Schema) -> Result> { let mut fields = self.group_fields(input_schema)?; fields.truncate(self.num_output_exprs()); Ok(fields) @@ -621,7 +626,7 @@ impl AggregateExec { } /// Finds the DataType and SortDirection for this Aggregate, if there is one - pub fn get_minmax_desc(&self) -> Option<(Field, bool)> { + pub fn get_minmax_desc(&self) -> Option<(FieldRef, bool)> { let agg_expr = self.aggr_expr.iter().exactly_one().ok()?; agg_expr.get_minmax_desc() } diff --git a/datafusion/physical-plan/src/windows/mod.rs b/datafusion/physical-plan/src/windows/mod.rs index f773391a6a704..d2b7e0a49e951 100644 --- a/datafusion/physical-plan/src/windows/mod.rs +++ b/datafusion/physical-plan/src/windows/mod.rs @@ -30,8 +30,8 @@ use crate::{ InputOrderMode, PhysicalExpr, }; -use arrow::datatypes::{Field, Schema, SchemaRef}; -use arrow_schema::SortOptions; +use arrow::datatypes::{Schema, SchemaRef}; +use arrow_schema::{FieldRef, SortOptions}; use datafusion_common::{exec_err, Result}; use datafusion_expr::{ PartitionEvaluator, ReversedUDWF, SetMonotonicity, WindowFrame, @@ -84,7 +84,10 @@ pub fn schema_add_window_field( if let WindowFunctionDefinition::AggregateUDF(_) = window_fn { Ok(Arc::new(Schema::new(window_fields))) } else { - window_fields.extend_from_slice(&[window_expr_return_field.with_name(fn_name)]); + window_fields.extend_from_slice(&[window_expr_return_field + .as_ref() + .clone() + .with_name(fn_name)]); Ok(Arc::new(Schema::new(window_fields))) } } @@ -199,7 +202,7 @@ pub struct WindowUDFExpr { /// Display name name: String, /// Fields of input expressions - input_fields: Vec, + input_fields: Vec, /// This is set to `true` only if the user-defined window function /// expression supports evaluation in reverse order, and the /// evaluation order is reversed. 
@@ -219,7 +222,7 @@ impl StandardWindowFunctionExpr for WindowUDFExpr { self } - fn field(&self) -> Result { + fn field(&self) -> Result { self.fun .field(WindowUDFFieldArgs::new(&self.input_fields, &self.name)) } @@ -637,7 +640,7 @@ mod tests { use crate::test::exec::{assert_strong_count_converges_to_zero, BlockingExec}; use arrow::compute::SortOptions; - use arrow_schema::DataType; + use arrow_schema::{DataType, Field}; use datafusion_execution::TaskContext; use datafusion_functions_aggregate::count::count_udaf; diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs index ef79f4b43e66a..5024bb558a65a 100644 --- a/datafusion/proto/src/physical_plan/from_proto.rs +++ b/datafusion/proto/src/physical_plan/from_proto.rs @@ -368,7 +368,7 @@ pub fn parse_physical_expr( e.name.as_str(), scalar_fun_def, args, - Field::new("f", convert_required!(e.return_type)?, true), + Field::new("f", convert_required!(e.return_type)?, true).into(), ) .with_nullable(e.nullable), ) diff --git a/datafusion/proto/tests/cases/mod.rs b/datafusion/proto/tests/cases/mod.rs index 92d961fc75562..4c7da2768e744 100644 --- a/datafusion/proto/tests/cases/mod.rs +++ b/datafusion/proto/tests/cases/mod.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. 
-use arrow::datatypes::{DataType, Field}; +use arrow::datatypes::{DataType, Field, FieldRef}; use datafusion::logical_expr::ColumnarValue; use datafusion_common::plan_err; use datafusion_expr::function::AccumulatorArgs; @@ -166,8 +166,11 @@ impl WindowUDFImpl for CustomUDWF { Ok(Box::new(CustomUDWFEvaluator {})) } - fn field(&self, field_args: WindowUDFFieldArgs) -> datafusion_common::Result { - Ok(Field::new(field_args.name(), DataType::UInt64, false)) + fn field( + &self, + field_args: WindowUDFFieldArgs, + ) -> datafusion_common::Result { + Ok(Field::new(field_args.name(), DataType::UInt64, false).into()) } } diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 369700bded04e..b515ef6e38de2 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -19,9 +19,9 @@ use arrow::array::{ ArrayRef, FixedSizeListArray, Int32Builder, MapArray, MapBuilder, StringBuilder, }; use arrow::datatypes::{ - DataType, Field, Fields, Int32Type, IntervalDayTimeType, IntervalMonthDayNanoType, - IntervalUnit, Schema, SchemaRef, TimeUnit, UnionFields, UnionMode, - DECIMAL256_MAX_PRECISION, + DataType, Field, FieldRef, Fields, Int32Type, IntervalDayTimeType, + IntervalMonthDayNanoType, IntervalUnit, Schema, SchemaRef, TimeUnit, UnionFields, + UnionMode, DECIMAL256_MAX_PRECISION, }; use arrow::util::pretty::pretty_format_batches; use datafusion::datasource::file_format::json::{JsonFormat, JsonFormatFactory}; @@ -2516,9 +2516,13 @@ fn roundtrip_window() { make_partition_evaluator() } - fn field(&self, field_args: WindowUDFFieldArgs) -> Result { + fn field(&self, field_args: WindowUDFFieldArgs) -> Result { if let Some(return_field) = field_args.get_input_field(0) { - Ok(return_field.with_name(field_args.name())) + Ok(return_field + .as_ref() + .clone() + .with_name(field_args.name()) + .into()) } else { plan_err!( "dummy_udwf expects 1 
argument, got {}: {:?}", diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index ad4c695b9ef16..7d56bb6c5db1b 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -595,7 +595,7 @@ fn roundtrip_aggregate_udaf() -> Result<()> { Signature::exact(vec![DataType::Int64], Volatility::Immutable), return_type, accumulator, - vec![Field::new("value", DataType::Int64, true)], + vec![Field::new("value", DataType::Int64, true).into()], )); let ctx = SessionContext::new(); @@ -981,7 +981,7 @@ fn roundtrip_scalar_udf() -> Result<()> { "dummy", fun_def, vec![col("a", &schema)?], - Field::new("f", DataType::Int64, true), + Field::new("f", DataType::Int64, true).into(), ); let project = @@ -1109,7 +1109,7 @@ fn roundtrip_scalar_udf_extension_codec() -> Result<()> { "regex_udf", Arc::new(ScalarUDF::from(MyRegexUdf::new(".*".to_string()))), vec![col("text", &schema)?], - Field::new("f", DataType::Int64, true), + Field::new("f", DataType::Int64, true).into(), )); let filter = Arc::new(FilterExec::try_new( @@ -1211,7 +1211,7 @@ fn roundtrip_aggregate_udf_extension_codec() -> Result<()> { "regex_udf", Arc::new(ScalarUDF::from(MyRegexUdf::new(".*".to_string()))), vec![col("text", &schema)?], - Field::new("f", DataType::Int64, true), + Field::new("f", DataType::Int64, true).into(), )); let udaf = Arc::new(AggregateUDF::from(MyAggregateUDF::new( diff --git a/datafusion/spark/src/function/utils.rs b/datafusion/spark/src/function/utils.rs index 67ec76ba524e9..85af4bb927ca5 100644 --- a/datafusion/spark/src/function/utils.rs +++ b/datafusion/spark/src/function/utils.rs @@ -28,7 +28,7 @@ pub mod test { let expected: datafusion_common::Result> = $EXPECTED; let func = $FUNC; - let arg_fields_owned: Vec = $ARGS + let arg_fields: Vec = $ARGS .iter() .enumerate() .map(|(idx, arg)| { @@ -38,12 +38,10 @@ pub mod test { 
datafusion_expr::ColumnarValue::Array(a) => a.null_count() > 0, }; - arrow::datatypes::Field::new(format!("arg_{idx}"), arg.data_type(), nullable) + std::sync::Arc::new(arrow::datatypes::Field::new(format!("arg_{idx}"), arg.data_type(), nullable)) }) .collect::>(); - let arg_fields: Vec<&arrow::datatypes::Field> = arg_fields_owned.iter().collect(); - let cardinality = $ARGS .iter() .fold(Option::::None, |acc, arg| match arg { @@ -60,7 +58,7 @@ pub mod test { let return_field = func.return_field_from_args(datafusion_expr::ReturnFieldArgs { - arg_fields: &arg_fields_owned, + arg_fields: &arg_fields, scalar_arguments: &scalar_arguments_refs }); @@ -72,7 +70,7 @@ pub mod test { let result = func.invoke_with_args(datafusion_expr::ScalarFunctionArgs{ args: $ARGS, number_rows: cardinality, - return_field: &return_field, + return_field, arg_fields: arg_fields.clone(), }); assert_eq!(result.is_ok(), true, "function returned an error: {}", result.unwrap_err()); @@ -101,7 +99,7 @@ pub mod test { match func.invoke_with_args(datafusion_expr::ScalarFunctionArgs{ args: $ARGS, number_rows: cardinality, - return_field: &return_field, + return_field, arg_fields, }) { Ok(_) => assert!(false, "expected error"), diff --git a/docs/source/library-user-guide/upgrading.md b/docs/source/library-user-guide/upgrading.md index a6ff73e13dc90..ed8fdadab2373 100644 --- a/docs/source/library-user-guide/upgrading.md +++ b/docs/source/library-user-guide/upgrading.md @@ -36,17 +36,17 @@ ListingOptions::new(Arc::new(ParquetFormat::default())) # */ ``` -### Processing `Field` instead of `DataType` for user defined functions +### Processing `FieldRef` instead of `DataType` for user defined functions In order to support metadata handling and extension types, user defined functions are -now switching to traits which use `Field` rather than a `DataType` and nullability. +now switching to traits which use `FieldRef` rather than a `DataType` and nullability. 
This gives a single interface to both of these parameters and additionally allows access to metadata fields, which can be used for extension types. To upgrade structs which implement `ScalarUDFImpl`, if you have implemented `return_type_from_args` you need instead to implement `return_field_from_args`. If your functions do not need to handle metadata, this should be straightforward -repackaging of the output data into a `Field`. The name you specify on the +repackaging of the output data into a `FieldRef`. The name you specify on the field is not important. It will be overwritten during planning. `ReturnInfo` has been removed, so you will need to remove all references to it. @@ -59,7 +59,7 @@ your function. You are not required to implement this if you do not need to handle metadata. The largest change to aggregate functions happens in the accumulator arguments. -Both the `AccumulatorArgs` and `StateFieldsArgs` now contain `Field` rather +Both the `AccumulatorArgs` and `StateFieldsArgs` now contain `FieldRef` rather than `DataType`. To upgrade window functions, `ExpressionArgs` now contains input fields instead