diff --git a/rust/benchmarks/src/bin/tpch.rs b/rust/benchmarks/src/bin/tpch.rs index b6a3a4161ee..0a1ef0b6a7d 100644 --- a/rust/benchmarks/src/bin/tpch.rs +++ b/rust/benchmarks/src/bin/tpch.rs @@ -1636,7 +1636,7 @@ mod tests { .file_extension(".out"); let df = ctx.read_csv(&format!("{}/answers/q{}.out", path, n), options)?; let df = df.select( - &get_answer_schema(n) + get_answer_schema(n) .fields() .iter() .map(|field| { diff --git a/rust/datafusion/README.md b/rust/datafusion/README.md index 4dd0c3e3f7e..23cf5ef4db7 100644 --- a/rust/datafusion/README.md +++ b/rust/datafusion/README.md @@ -100,7 +100,7 @@ async fn main() -> datafusion::error::Result<()> { let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new())?; let df = df.filter(col("a").lt_eq(col("b")))? - .aggregate(&[col("a")], &[min(col("b"))])? + .aggregate(vec![col("a")], vec![min(col("b"))])? .limit(100)?; // execute and print results diff --git a/rust/datafusion/examples/simple_udaf.rs b/rust/datafusion/examples/simple_udaf.rs index 55aa350b13d..8086dfc47de 100644 --- a/rust/datafusion/examples/simple_udaf.rs +++ b/rust/datafusion/examples/simple_udaf.rs @@ -148,7 +148,7 @@ async fn main() -> Result<()> { let df = ctx.table("t")?; // perform the aggregation - let df = df.aggregate(&[], &[geometric_mean.call(vec![col("a")])])?; + let df = df.aggregate(vec![], vec![geometric_mean.call(vec![col("a")])])?; // note that "a" is f32, not f64. DataFusion coerces it to match the UDAF's signature. 
diff --git a/rust/datafusion/examples/simple_udf.rs b/rust/datafusion/examples/simple_udf.rs index 00debdbddac..bfef1089a63 100644 --- a/rust/datafusion/examples/simple_udf.rs +++ b/rust/datafusion/examples/simple_udf.rs @@ -133,7 +133,7 @@ async fn main() -> Result<()> { let expr1 = pow.call(vec![col("a"), col("b")]); // equivalent to `'SELECT pow(a, b), pow(a, b) AS pow1 FROM t'` - let df = df.select(&[ + let df = df.select(vec![ expr, // alias so that they have different column names expr1.alias("pow1"), diff --git a/rust/datafusion/src/dataframe.rs b/rust/datafusion/src/dataframe.rs index b3e561100c7..9c7c2ef96d6 100644 --- a/rust/datafusion/src/dataframe.rs +++ b/rust/datafusion/src/dataframe.rs @@ -44,7 +44,7 @@ use async_trait::async_trait; /// let mut ctx = ExecutionContext::new(); /// let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new())?; /// let df = df.filter(col("a").lt_eq(col("b")))? -/// .aggregate(&[col("a")], &[min(col("b"))])? +/// .aggregate(vec![col("a")], vec![min(col("b"))])? /// .limit(100)?; /// let results = df.collect(); /// # Ok(()) @@ -75,11 +75,11 @@ pub trait DataFrame: Send + Sync { /// # fn main() -> Result<()> { /// let mut ctx = ExecutionContext::new(); /// let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new())?; - /// let df = df.select(&[col("a") * col("b"), col("c")])?; + /// let df = df.select(vec![col("a") * col("b"), col("c")])?; /// # Ok(()) /// # } /// ``` - fn select(&self, expr: &[Expr]) -> Result>; + fn select(&self, expr: Vec) -> Result>; /// Filter a DataFrame to only include rows that match the specified filter expression. 
/// @@ -105,17 +105,17 @@ pub trait DataFrame: Send + Sync { /// let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new())?; /// /// // The following use is the equivalent of "SELECT MIN(b) GROUP BY a" - /// let _ = df.aggregate(&[col("a")], &[min(col("b"))])?; + /// let _ = df.aggregate(vec![col("a")], vec![min(col("b"))])?; /// /// // The following use is the equivalent of "SELECT MIN(b)" - /// let _ = df.aggregate(&[], &[min(col("b"))])?; + /// let _ = df.aggregate(vec![], vec![min(col("b"))])?; /// # Ok(()) /// # } /// ``` fn aggregate( &self, - group_expr: &[Expr], - aggr_expr: &[Expr], + group_expr: Vec, + aggr_expr: Vec, ) -> Result>; /// Limit the number of rows returned from this DataFrame. @@ -155,11 +155,11 @@ pub trait DataFrame: Send + Sync { /// # fn main() -> Result<()> { /// let mut ctx = ExecutionContext::new(); /// let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new())?; - /// let df = df.sort(&[col("a").sort(true, true), col("b").sort(false, false)])?; + /// let df = df.sort(vec![col("a").sort(true, true), col("b").sort(false, false)])?; /// # Ok(()) /// # } /// ``` - fn sort(&self, expr: &[Expr]) -> Result>; + fn sort(&self, expr: Vec) -> Result>; /// Join this DataFrame with another DataFrame using the specified columns as join keys /// @@ -171,7 +171,7 @@ pub trait DataFrame: Send + Sync { /// let mut ctx = ExecutionContext::new(); /// let left = ctx.read_csv("tests/example.csv", CsvReadOptions::new())?; /// let right = ctx.read_csv("tests/example.csv", CsvReadOptions::new())? 
- /// .select(&[ + /// .select(vec![ /// col("a").alias("a2"), /// col("b").alias("b2"), /// col("c").alias("c2")])?; diff --git a/rust/datafusion/src/execution/context.rs b/rust/datafusion/src/execution/context.rs index f0902a995a1..01f4204dbae 100644 --- a/rust/datafusion/src/execution/context.rs +++ b/rust/datafusion/src/execution/context.rs @@ -82,7 +82,7 @@ use parquet::file::properties::WriterProperties; /// let mut ctx = ExecutionContext::new(); /// let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new())?; /// let df = df.filter(col("a").lt_eq(col("b")))? -/// .aggregate(&[col("a")], &[min(col("b"))])? +/// .aggregate(vec![col("a")], vec![min(col("b"))])? /// .limit(100)?; /// let results = df.collect(); /// # Ok(()) @@ -954,7 +954,7 @@ mod tests { let table = ctx.table("test")?; let logical_plan = LogicalPlanBuilder::from(&table.to_logical_plan()) - .project(&[col("c2")])? + .project(vec![col("c2")])? .build()?; let optimized_plan = ctx.optimize(&logical_plan)?; @@ -999,7 +999,7 @@ mod tests { assert_eq!(schema.field_with_name("c1")?.is_nullable(), false); let plan = LogicalPlanBuilder::scan_empty("", &schema, None)? - .project(&[col("c1")])? + .project(vec![col("c1")])? .build()?; let plan = ctx.optimize(&plan)?; @@ -1030,7 +1030,7 @@ mod tests { )?]]; let plan = LogicalPlanBuilder::scan_memory(partitions, schema, None)? - .project(&[col("b")])? + .project(vec![col("b")])? .build()?; assert_fields_eq(&plan, vec!["b"]); @@ -1660,8 +1660,8 @@ mod tests { ])); let plan = LogicalPlanBuilder::scan_empty("", schema.as_ref(), None)? - .aggregate(&[col("c1")], &[sum(col("c2"))])? - .project(&[col("c1"), col("SUM(c2)").alias("total_salary")])? + .aggregate(vec![col("c1")], vec![sum(col("c2"))])? + .project(vec![col("c1"), col("SUM(c2)").alias("total_salary")])? 
.build()?; let plan = ctx.optimize(&plan)?; @@ -1886,7 +1886,7 @@ mod tests { let t = ctx.table("t")?; let plan = LogicalPlanBuilder::from(&t.to_logical_plan()) - .project(&[ + .project(vec![ col("a"), col("b"), ctx.udf("my_add")?.call(vec![col("a"), col("b")]), diff --git a/rust/datafusion/src/execution/dataframe_impl.rs b/rust/datafusion/src/execution/dataframe_impl.rs index 7c3beb954d4..62c18ebc985 100644 --- a/rust/datafusion/src/execution/dataframe_impl.rs +++ b/rust/datafusion/src/execution/dataframe_impl.rs @@ -58,11 +58,11 @@ impl DataFrame for DataFrameImpl { .map(|name| self.plan.schema().field_with_unqualified_name(name)) .collect::>>()?; let expr: Vec = fields.iter().map(|f| col(f.name())).collect(); - self.select(&expr) + self.select(expr) } /// Create a projection based on arbitrary expressions - fn select(&self, expr_list: &[Expr]) -> Result> { + fn select(&self, expr_list: Vec) -> Result> { let plan = LogicalPlanBuilder::from(&self.plan) .project(expr_list)? .build()?; @@ -80,8 +80,8 @@ impl DataFrame for DataFrameImpl { /// Perform an aggregate query fn aggregate( &self, - group_expr: &[Expr], - aggr_expr: &[Expr], + group_expr: Vec, + aggr_expr: Vec, ) -> Result> { let plan = LogicalPlanBuilder::from(&self.plan) .aggregate(group_expr, aggr_expr)? 
@@ -96,7 +96,7 @@ impl DataFrame for DataFrameImpl { } /// Sort by specified sorting expressions - fn sort(&self, expr: &[Expr]) -> Result> { + fn sort(&self, expr: Vec) -> Result> { let plan = LogicalPlanBuilder::from(&self.plan).sort(expr)?.build()?; Ok(Arc::new(DataFrameImpl::new(self.ctx_state.clone(), &plan))) } @@ -204,7 +204,7 @@ mod tests { fn select_expr() -> Result<()> { // build plan using Table API let t = test_table()?; - let t2 = t.select(&[col("c1"), col("c2"), col("c11")])?; + let t2 = t.select(vec![col("c1"), col("c2"), col("c11")])?; let plan = t2.to_logical_plan(); // build query using SQL @@ -220,8 +220,8 @@ mod tests { fn aggregate() -> Result<()> { // build plan using DataFrame API let df = test_table()?; - let group_expr = &[col("c1")]; - let aggr_expr = &[ + let group_expr = vec![col("c1")]; + let aggr_expr = vec![ min(col("c12")), max(col("c12")), avg(col("c12")), @@ -322,7 +322,7 @@ mod tests { let f = df.registry(); - let df = df.select(&[f.udf("my_fn")?.call(vec![col("c12")])])?; + let df = df.select(vec![f.udf("my_fn")?.call(vec![col("c12")])])?; let plan = df.to_logical_plan(); // build query using SQL diff --git a/rust/datafusion/src/lib.rs b/rust/datafusion/src/lib.rs index c419dee88b0..3e1e1e2b126 100644 --- a/rust/datafusion/src/lib.rs +++ b/rust/datafusion/src/lib.rs @@ -48,7 +48,7 @@ //! //! // create a plan //! let df = df.filter(col("a").lt_eq(col("b")))? -//! .aggregate(&[col("a")], &[min(col("b"))])? +//! .aggregate(vec![col("a")], vec![min(col("b"))])? //! .limit(100)?; //! //! 
// execute the plan diff --git a/rust/datafusion/src/logical_plan/builder.rs b/rust/datafusion/src/logical_plan/builder.rs index a89e797c7b6..aa0380e071f 100644 --- a/rust/datafusion/src/logical_plan/builder.rs +++ b/rust/datafusion/src/logical_plan/builder.rs @@ -39,6 +39,43 @@ use crate::logical_plan::{DFField, DFSchema, DFSchemaRef, Partitioning}; use std::collections::HashSet; /// Builder for logical plans +/// +/// ``` +/// # use datafusion::prelude::*; +/// # use datafusion::logical_plan::LogicalPlanBuilder; +/// # use datafusion::error::Result; +/// # use arrow::datatypes::{Schema, DataType, Field}; +/// # +/// # fn main() -> Result<()> { +/// # +/// # fn employee_schema() -> Schema { +/// # Schema::new(vec![ +/// # Field::new("id", DataType::Int32, false), +/// # Field::new("first_name", DataType::Utf8, false), +/// # Field::new("last_name", DataType::Utf8, false), +/// # Field::new("state", DataType::Utf8, false), +/// # Field::new("salary", DataType::Int32, false), +/// # ]) +/// # } +/// # +/// // Create a plan similar to +/// // SELECT last_name +/// // FROM employees +/// // WHERE salary <= 1000 +/// let plan = LogicalPlanBuilder::scan_empty( +/// "employee.csv", +/// &employee_schema(), +/// None, +/// )? +/// // Keep only rows where salary <= 1000 +/// .filter(col("salary").lt_eq(lit(1000)))? +/// // only show "last_name" in the final results +/// .project(vec![col("last_name")])? +/// .build()?; +/// +/// # Ok(()) +/// # } +/// ``` pub struct LogicalPlanBuilder { plan: LogicalPlan, } @@ -132,18 +169,21 @@ impl LogicalPlanBuilder { /// This function errors under any of the following conditions: /// * Two or more expressions have the same name /// * An invalid expression is used (e.g.
a `sort` expression) - pub fn project(&self, expr: &[Expr]) -> Result { + pub fn project(&self, expr: impl IntoIterator) -> Result { let input_schema = self.plan.schema(); let mut projected_expr = vec![]; - (0..expr.len()).for_each(|i| match &expr[i] { - Expr::Wildcard => { - (0..input_schema.fields().len()) - .for_each(|i| projected_expr.push(col(input_schema.field(i).name()))); - } - _ => projected_expr.push(expr[i].clone()), - }); + for e in expr { + match e { + Expr::Wildcard => { + (0..input_schema.fields().len()).for_each(|i| { + projected_expr.push(col(input_schema.field(i).name())) + }); + } + _ => projected_expr.push(e), + }; + } - validate_unique_names("Projections", &projected_expr, input_schema)?; + validate_unique_names("Projections", projected_expr.iter(), input_schema)?; let schema = DFSchema::new(exprlist_to_fields(&projected_expr, input_schema)?)?; @@ -171,9 +211,9 @@ impl LogicalPlanBuilder { } /// Apply a sort - pub fn sort(&self, expr: &[Expr]) -> Result { + pub fn sort(&self, expr: impl IntoIterator) -> Result { Ok(Self::from(&LogicalPlan::Sort { - expr: expr.to_vec(), + expr: expr.into_iter().collect(), input: Arc::new(self.plan.clone()), })) } @@ -243,20 +283,28 @@ impl LogicalPlanBuilder { })) } - /// Apply an aggregate - pub fn aggregate(&self, group_expr: &[Expr], aggr_expr: &[Expr]) -> Result { - let mut all_expr = group_expr.to_vec(); - all_expr.extend_from_slice(aggr_expr); + /// Apply an aggregate: grouping on the `group_expr` expressions + /// and calculating `aggr_expr` aggregates for each distinct + /// value of the `group_expr`; + pub fn aggregate( + &self, + group_expr: impl IntoIterator, + aggr_expr: impl IntoIterator, + ) -> Result { + let group_expr = group_expr.into_iter().collect::>(); + let aggr_expr = aggr_expr.into_iter().collect::>(); + + let all_expr = group_expr.iter().chain(aggr_expr.iter()); - validate_unique_names("Aggregations", &all_expr, self.plan.schema())?; + validate_unique_names("Aggregations", 
all_expr.clone(), self.plan.schema())?; let aggr_schema = - DFSchema::new(exprlist_to_fields(&all_expr, self.plan.schema())?)?; + DFSchema::new(exprlist_to_fields(all_expr, self.plan.schema())?)?; Ok(Self::from(&LogicalPlan::Aggregate { input: Arc::new(self.plan.clone()), - group_expr: group_expr.to_vec(), - aggr_expr: aggr_expr.to_vec(), + group_expr: group_expr, + aggr_expr: aggr_expr, schema: DFSchemaRef::new(aggr_schema), })) } @@ -334,13 +382,13 @@ fn build_join_schema( } /// Errors if one or more expressions have equal names. -fn validate_unique_names( +fn validate_unique_names<'a>( node_name: &str, - expressions: &[Expr], + expressions: impl IntoIterator, input_schema: &DFSchema, ) -> Result<()> { let mut unique_names = HashMap::new(); - expressions.iter().enumerate().try_for_each(|(position, expr)| { + expressions.into_iter().enumerate().try_for_each(|(position, expr)| { let name = expr.name(input_schema)?; match unique_names.get(&name) { None => { @@ -375,7 +423,7 @@ mod tests { Some(vec![0, 3]), )? .filter(col("state").eq(lit("CO")))? - .project(&[col("id")])? + .project(vec![col("id")])? .build()?; let expected = "Projection: #id\ @@ -394,8 +442,11 @@ mod tests { &employee_schema(), Some(vec![3, 4]), )? - .aggregate(&[col("state")], &[sum(col("salary")).alias("total_salary")])? - .project(&[col("state"), col("total_salary")])? + .aggregate( + vec![col("state")], + vec![sum(col("salary")).alias("total_salary")], + )? + .project(vec![col("state"), col("total_salary")])? .build()?; let expected = "Projection: #state, #total_salary\ @@ -414,7 +465,7 @@ mod tests { &employee_schema(), Some(vec![3, 4]), )? - .sort(&[ + .sort(vec![ Expr::Sort { expr: Box::new(col("state")), asc: true, @@ -470,7 +521,7 @@ mod tests { Some(vec![0, 3]), )? 
// two columns with the same name => error - .project(&[col("id"), col("first_name").alias("id")]); + .project(vec![col("id"), col("first_name").alias("id")]); match plan { Err(DataFusionError::Plan(e)) => { @@ -496,7 +547,7 @@ mod tests { Some(vec![0, 3]), )? // two columns with the same name => error - .aggregate(&[col("state")], &[sum(col("salary")).alias("state")]); + .aggregate(vec![col("state")], vec![sum(col("salary")).alias("state")]); match plan { Err(DataFusionError::Plan(e)) => { diff --git a/rust/datafusion/src/logical_plan/expr.rs b/rust/datafusion/src/logical_plan/expr.rs index 1eaa02b1e41..314f5d477b3 100644 --- a/rust/datafusion/src/logical_plan/expr.rs +++ b/rust/datafusion/src/logical_plan/expr.rs @@ -1370,11 +1370,11 @@ fn create_name(e: &Expr, input_schema: &DFSchema) -> Result { } /// Create field meta-data from an expression, for use in a result set schema -pub fn exprlist_to_fields( - expr: &[Expr], +pub fn exprlist_to_fields<'a>( + expr: impl IntoIterator, input_schema: &DFSchema, ) -> Result> { - expr.iter().map(|e| e.to_field(input_schema)).collect() + expr.into_iter().map(|e| e.to_field(input_schema)).collect() } #[cfg(test)] diff --git a/rust/datafusion/src/logical_plan/plan.rs b/rust/datafusion/src/logical_plan/plan.rs index 00d25fb0f3f..a698b26ea66 100644 --- a/rust/datafusion/src/logical_plan/plan.rs +++ b/rust/datafusion/src/logical_plan/plan.rs @@ -788,7 +788,7 @@ mod tests { .unwrap() .filter(col("state").eq(lit("CO"))) .unwrap() - .project(&[col("id")]) + .project(vec![col("id")]) .unwrap() .build() .unwrap() @@ -1089,7 +1089,7 @@ mod tests { .unwrap() .filter(col("state").eq(lit("CO"))) .unwrap() - .project(&[col("id")]) + .project(vec![col("id")]) .unwrap() .build() .unwrap() diff --git a/rust/datafusion/src/optimizer/constant_folding.rs b/rust/datafusion/src/optimizer/constant_folding.rs index ec4dfd4b011..2fa03eb5c70 100644 --- a/rust/datafusion/src/optimizer/constant_folding.rs +++ 
b/rust/datafusion/src/optimizer/constant_folding.rs @@ -469,7 +469,7 @@ mod tests { let plan = LogicalPlanBuilder::from(&table_scan) .filter(col("b").eq(lit(true)))? .filter(col("c").eq(lit(false)))? - .project(&[col("a")])? + .project(vec![col("a")])? .build()?; let expected = "\ @@ -489,7 +489,7 @@ mod tests { .filter(col("b").not_eq(lit(true)))? .filter(col("c").not_eq(lit(false)))? .limit(1)? - .project(&[col("a")])? + .project(vec![col("a")])? .build()?; let expected = "\ @@ -508,7 +508,7 @@ mod tests { let table_scan = test_table_scan()?; let plan = LogicalPlanBuilder::from(&table_scan) .filter(col("b").not_eq(lit(true)).and(col("c").eq(lit(true))))? - .project(&[col("a")])? + .project(vec![col("a")])? .build()?; let expected = "\ @@ -525,7 +525,7 @@ mod tests { let table_scan = test_table_scan()?; let plan = LogicalPlanBuilder::from(&table_scan) .filter(col("b").not_eq(lit(true)).or(col("c").eq(lit(false))))? - .project(&[col("a")])? + .project(vec![col("a")])? .build()?; let expected = "\ @@ -542,7 +542,7 @@ mod tests { let table_scan = test_table_scan()?; let plan = LogicalPlanBuilder::from(&table_scan) .filter(col("b").eq(lit(false)).not())? - .project(&[col("a")])? + .project(vec![col("a")])? .build()?; let expected = "\ @@ -558,7 +558,7 @@ mod tests { fn optimize_plan_support_projection() -> Result<()> { let table_scan = test_table_scan()?; let plan = LogicalPlanBuilder::from(&table_scan) - .project(&[col("a"), col("d"), col("b").eq(lit(false))])? + .project(vec![col("a"), col("d"), col("b").eq(lit(false))])? .build()?; let expected = "\ @@ -573,10 +573,10 @@ mod tests { fn optimize_plan_support_aggregate() -> Result<()> { let table_scan = test_table_scan()?; let plan = LogicalPlanBuilder::from(&table_scan) - .project(&[col("a"), col("c"), col("b")])? + .project(vec![col("a"), col("c"), col("b")])? 
.aggregate( - &[col("a"), col("c")], - &[max(col("b").eq(lit(true))), min(col("b"))], + vec![col("a"), col("c")], + vec![max(col("b").eq(lit(true))), min(col("b"))], )? .build()?; diff --git a/rust/datafusion/src/optimizer/filter_push_down.rs b/rust/datafusion/src/optimizer/filter_push_down.rs index 0ae8e06015d..ec260a41dc5 100644 --- a/rust/datafusion/src/optimizer/filter_push_down.rs +++ b/rust/datafusion/src/optimizer/filter_push_down.rs @@ -451,7 +451,7 @@ mod tests { fn filter_before_projection() -> Result<()> { let table_scan = test_table_scan()?; let plan = LogicalPlanBuilder::from(&table_scan) - .project(&[col("a"), col("b")])? + .project(vec![col("a"), col("b")])? .filter(col("a").eq(lit(1i64)))? .build()?; // filter is before projection @@ -467,7 +467,7 @@ mod tests { fn filter_after_limit() -> Result<()> { let table_scan = test_table_scan()?; let plan = LogicalPlanBuilder::from(&table_scan) - .project(&[col("a"), col("b")])? + .project(vec![col("a"), col("b")])? .limit(10)? .filter(col("a").eq(lit(1i64)))? .build()?; @@ -485,8 +485,8 @@ mod tests { fn filter_jump_2_plans() -> Result<()> { let table_scan = test_table_scan()?; let plan = LogicalPlanBuilder::from(&table_scan) - .project(&[col("a"), col("b"), col("c")])? - .project(&[col("c"), col("b")])? + .project(vec![col("a"), col("b"), col("c")])? + .project(vec![col("c"), col("b")])? .filter(col("a").eq(lit(1i64)))? .build()?; // filter is before double projection @@ -503,7 +503,7 @@ mod tests { fn filter_move_agg() -> Result<()> { let table_scan = test_table_scan()?; let plan = LogicalPlanBuilder::from(&table_scan) - .aggregate(&[col("a")], &[sum(col("b")).alias("total_salary")])? + .aggregate(vec![col("a")], vec![sum(col("b")).alias("total_salary")])? .filter(col("a").gt(lit(10i64)))? 
.build()?; // filter of key aggregation is commutative @@ -519,7 +519,7 @@ mod tests { fn filter_keep_agg() -> Result<()> { let table_scan = test_table_scan()?; let plan = LogicalPlanBuilder::from(&table_scan) - .aggregate(&[col("a")], &[sum(col("b")).alias("b")])? + .aggregate(vec![col("a")], vec![sum(col("b")).alias("b")])? .filter(col("b").gt(lit(10i64)))? .build()?; // filter of aggregate is after aggregation since they are non-commutative @@ -536,7 +536,7 @@ mod tests { fn alias() -> Result<()> { let table_scan = test_table_scan()?; let plan = LogicalPlanBuilder::from(&table_scan) - .project(&[col("a").alias("b"), col("c")])? + .project(vec![col("a").alias("b"), col("c")])? .filter(col("b").eq(lit(1i64)))? .build()?; // filter is before projection @@ -569,7 +569,7 @@ mod tests { fn complex_expression() -> Result<()> { let table_scan = test_table_scan()?; let plan = LogicalPlanBuilder::from(&table_scan) - .project(&[ + .project(vec![ add(multiply(col("a"), lit(2)), col("c")).alias("b"), col("c"), ])? @@ -599,12 +599,12 @@ mod tests { fn complex_plan() -> Result<()> { let table_scan = test_table_scan()?; let plan = LogicalPlanBuilder::from(&table_scan) - .project(&[ + .project(vec![ add(multiply(col("a"), lit(2)), col("c")).alias("b"), col("c"), ])? // second projection where we rename columns, just to make it difficult - .project(&[multiply(col("b"), lit(3)).alias("a"), col("c")])? + .project(vec![multiply(col("b"), lit(3)).alias("a"), col("c")])? .filter(col("a").eq(lit(1i64)))? .build()?; @@ -635,8 +635,8 @@ mod tests { // the aggregation allows one filter to pass (b), and the other one to not pass (SUM(c)) let table_scan = test_table_scan()?; let plan = LogicalPlanBuilder::from(&table_scan) - .project(&[col("a").alias("b"), col("c")])? - .aggregate(&[col("b")], &[sum(col("c"))])? + .project(vec![col("a").alias("b"), col("c")])? + .aggregate(vec![col("b")], vec![sum(col("c"))])? .filter(col("b").gt(lit(10i64)))? .filter(col("SUM(c)").gt(lit(10i64)))? 
.build()?; @@ -671,8 +671,8 @@ mod tests { // the aggregation allows one filter to pass (b), and the other one to not pass (SUM(c)) let table_scan = test_table_scan()?; let plan = LogicalPlanBuilder::from(&table_scan) - .project(&[col("a").alias("b"), col("c")])? - .aggregate(&[col("b")], &[sum(col("c"))])? + .project(vec![col("a").alias("b"), col("c")])? + .aggregate(vec![col("b")], vec![sum(col("c"))])? .filter(and( col("SUM(c)").gt(lit(10i64)), and(col("b").gt(lit(10i64)), col("SUM(c)").lt(lit(20i64))), @@ -706,10 +706,10 @@ mod tests { fn double_limit() -> Result<()> { let table_scan = test_table_scan()?; let plan = LogicalPlanBuilder::from(&table_scan) - .project(&[col("a"), col("b")])? + .project(vec![col("a"), col("b")])? .limit(20)? .limit(10)? - .project(&[col("a"), col("b")])? + .project(vec![col("a"), col("b")])? .filter(col("a").eq(lit(1i64)))? .build()?; // filter does not just any of the limits @@ -729,10 +729,10 @@ mod tests { fn filter_2_breaks_limits() -> Result<()> { let table_scan = test_table_scan()?; let plan = LogicalPlanBuilder::from(&table_scan) - .project(&[col("a")])? + .project(vec![col("a")])? .filter(col("a").lt_eq(lit(1i64)))? .limit(1)? - .project(&[col("a")])? + .project(vec![col("a")])? .filter(col("a").gt_eq(lit(1i64)))? .build()?; // Should be able to move both filters below the projections @@ -768,7 +768,7 @@ mod tests { .limit(1)? .filter(col("a").lt_eq(lit(1i64)))? .filter(col("a").gt_eq(lit(1i64)))? - .project(&[col("a")])? + .project(vec![col("a")])? .build()?; // not part of the test @@ -820,7 +820,7 @@ mod tests { let table_scan = test_table_scan()?; let left = LogicalPlanBuilder::from(&table_scan).build()?; let right = LogicalPlanBuilder::from(&table_scan) - .project(&[col("a")])? + .project(vec![col("a")])? .build()?; let plan = LogicalPlanBuilder::from(&left) .join(&right, JoinType::Inner, &["a"], &["a"])? 
@@ -855,10 +855,10 @@ mod tests { fn filter_join_on_common_dependent() -> Result<()> { let table_scan = test_table_scan()?; let left = LogicalPlanBuilder::from(&table_scan) - .project(&[col("a"), col("c")])? + .project(vec![col("a"), col("c")])? .build()?; let right = LogicalPlanBuilder::from(&table_scan) - .project(&[col("a"), col("b")])? + .project(vec![col("a"), col("b")])? .build()?; let plan = LogicalPlanBuilder::from(&left) .join(&right, JoinType::Inner, &["a"], &["a"])? @@ -889,10 +889,10 @@ mod tests { fn filter_join_on_one_side() -> Result<()> { let table_scan = test_table_scan()?; let left = LogicalPlanBuilder::from(&table_scan) - .project(&[col("a"), col("b")])? + .project(vec![col("a"), col("b")])? .build()?; let right = LogicalPlanBuilder::from(&table_scan) - .project(&[col("a"), col("c")])? + .project(vec![col("a"), col("c")])? .build()?; let plan = LogicalPlanBuilder::from(&left) .join(&right, JoinType::Inner, &["a"], &["a"])? diff --git a/rust/datafusion/src/optimizer/limit_push_down.rs b/rust/datafusion/src/optimizer/limit_push_down.rs index fee03988c06..73a231f2248 100644 --- a/rust/datafusion/src/optimizer/limit_push_down.rs +++ b/rust/datafusion/src/optimizer/limit_push_down.rs @@ -153,7 +153,7 @@ mod test { let table_scan = test_table_scan()?; let plan = LogicalPlanBuilder::from(&table_scan) - .project(&[col("a")])? + .project(vec![col("a")])? .limit(1000)? .build()?; @@ -193,7 +193,7 @@ mod test { let table_scan = test_table_scan()?; let plan = LogicalPlanBuilder::from(&table_scan) - .aggregate(&[col("a")], &[max(col("b"))])? + .aggregate(vec![col("a")], vec![max(col("b"))])? .limit(1000)? .build()?; @@ -235,7 +235,7 @@ mod test { let plan = LogicalPlanBuilder::from(&table_scan) .limit(1000)? - .aggregate(&[col("a")], &[max(col("b"))])? + .aggregate(vec![col("a")], vec![max(col("b"))])? .limit(10)? 
.build()?; diff --git a/rust/datafusion/src/optimizer/projection_push_down.rs b/rust/datafusion/src/optimizer/projection_push_down.rs index 84523217574..6b1cdfe18ca 100644 --- a/rust/datafusion/src/optimizer/projection_push_down.rs +++ b/rust/datafusion/src/optimizer/projection_push_down.rs @@ -303,7 +303,7 @@ mod tests { let table_scan = test_table_scan()?; let plan = LogicalPlanBuilder::from(&table_scan) - .aggregate(&[], &[max(col("b"))])? + .aggregate(vec![], vec![max(col("b"))])? .build()?; let expected = "Aggregate: groupBy=[[]], aggr=[[MAX(#b)]]\ @@ -319,7 +319,7 @@ mod tests { let table_scan = test_table_scan()?; let plan = LogicalPlanBuilder::from(&table_scan) - .aggregate(&[col("c")], &[max(col("b"))])? + .aggregate(vec![col("c")], vec![max(col("b"))])? .build()?; let expected = "Aggregate: groupBy=[[#c]], aggr=[[MAX(#b)]]\ @@ -336,7 +336,7 @@ mod tests { let plan = LogicalPlanBuilder::from(&table_scan) .filter(col("c"))? - .aggregate(&[], &[max(col("b"))])? + .aggregate(vec![], vec![max(col("b"))])? .build()?; let expected = "Aggregate: groupBy=[[]], aggr=[[MAX(#b)]]\ @@ -353,7 +353,7 @@ mod tests { let table_scan = test_table_scan()?; let projection = LogicalPlanBuilder::from(&table_scan) - .project(&[Expr::Cast { + .project(vec![Expr::Cast { expr: Box::new(col("c")), data_type: DataType::Float64, }])? @@ -374,7 +374,7 @@ mod tests { assert_fields_eq(&table_scan, vec!["a", "b", "c"]); let plan = LogicalPlanBuilder::from(&table_scan) - .project(&[col("a"), col("b")])? + .project(vec![col("a"), col("b")])? .build()?; assert_fields_eq(&plan, vec!["a", "b"]); @@ -394,7 +394,7 @@ mod tests { assert_fields_eq(&table_scan, vec!["a", "b", "c"]); let plan = LogicalPlanBuilder::from(&table_scan) - .project(&[col("c"), col("a")])? + .project(vec![col("c"), col("a")])? .limit(5)? 
.build()?; @@ -423,7 +423,7 @@ mod tests { fn table_scan_with_literal_projection() -> Result<()> { let table_scan = test_table_scan()?; let plan = LogicalPlanBuilder::from(&table_scan) - .project(&[lit(1_i64), lit(2_i64)])? + .project(vec![lit(1_i64), lit(2_i64)])? .build()?; let expected = "Projection: Int64(1), Int64(2)\ \n TableScan: test projection=Some([0])"; @@ -440,9 +440,9 @@ mod tests { // we never use "b" in the first projection => remove it let plan = LogicalPlanBuilder::from(&table_scan) - .project(&[col("c"), col("a"), col("b")])? + .project(vec![col("c"), col("a"), col("b")])? .filter(col("c").gt(lit(1)))? - .aggregate(&[col("c")], &[max(col("a"))])? + .aggregate(vec![col("c")], vec![max(col("a"))])? .build()?; assert_fields_eq(&plan, vec!["c", "MAX(a)"]); @@ -467,8 +467,8 @@ mod tests { // there is no need for the first projection let plan = LogicalPlanBuilder::from(&table_scan) - .project(&[col("b")])? - .project(&[lit(1).alias("a")])? + .project(vec![col("b")])? + .project(vec![lit(1).alias("a")])? .build()?; assert_fields_eq(&plan, vec!["a"]); @@ -488,8 +488,8 @@ mod tests { let table_scan = test_table_scan()?; let plan = LogicalPlanBuilder::from(&table_scan) - .project(&[col("b")])? - .project(&[lit(1).alias("a")])? + .project(vec![col("b")])? + .project(vec![lit(1).alias("a")])? .build()?; let optimized_plan1 = optimize(&plan).expect("failed to optimize plan"); @@ -511,9 +511,9 @@ mod tests { // we never use "min(b)" => remove it let plan = LogicalPlanBuilder::from(&table_scan) - .aggregate(&[col("a"), col("c")], &[max(col("b")), min(col("b"))])? + .aggregate(vec![col("a"), col("c")], vec![max(col("b")), min(col("b"))])? .filter(col("c").gt(lit(1)))? - .project(&[col("c"), col("a"), col("MAX(b)")])? + .project(vec![col("c"), col("a"), col("MAX(b)")])? 
.build()?; assert_fields_eq(&plan, vec!["c", "a", "MAX(b)"]); diff --git a/rust/datafusion/src/physical_plan/planner.rs b/rust/datafusion/src/physical_plan/planner.rs index 81500b3522c..a9b1da179ca 100644 --- a/rust/datafusion/src/physical_plan/planner.rs +++ b/rust/datafusion/src/physical_plan/planner.rs @@ -787,9 +787,9 @@ mod tests { let logical_plan = LogicalPlanBuilder::scan_csv(&path, options, None)? // filter clause needs the type coercion rule applied .filter(col("c7").lt(lit(5_u8)))? - .project(&[col("c1"), col("c2")])? - .aggregate(&[col("c1")], &[sum(col("c2"))])? - .sort(&[col("c1").sort(true, true)])? + .project(vec![col("c1"), col("c2")])? + .aggregate(vec![col("c1")], vec![sum(col("c2"))])? + .sort(vec![col("c1").sort(true, true)])? .limit(10)? .build()?; @@ -860,7 +860,7 @@ mod tests { ]; for case in cases { let logical_plan = LogicalPlanBuilder::scan_csv(&path, options, None)? - .project(&[case.clone()]); + .project(vec![case.clone()]); let message = format!( "Expression {:?} expected to error due to impossible coercion", case @@ -951,7 +951,7 @@ mod tests { let logical_plan = LogicalPlanBuilder::scan_csv(&path, options, None)? // filter clause needs the type coercion rule applied .filter(col("c12").lt(lit(0.05)))? - .project(&[col("c1").in_list(list, false)])? + .project(vec![col("c1").in_list(list, false)])? .build()?; let execution_plan = plan(&logical_plan)?; // verify that the plan correctly adds cast from Int64(1) to Utf8 @@ -966,7 +966,7 @@ mod tests { let logical_plan = LogicalPlanBuilder::scan_csv(&path, options, None)? // filter clause needs the type coercion rule applied .filter(col("c12").lt(lit(0.05)))? - .project(&[col("c12").lt_eq(lit(0.025)).in_list(list, false)])? + .project(vec![col("c12").lt_eq(lit(0.025)).in_list(list, false)])? 
.build()?; let execution_plan = plan(&logical_plan); @@ -991,7 +991,7 @@ mod tests { let options = CsvReadOptions::new().schema_infer_max_records(100); let logical_plan = LogicalPlanBuilder::scan_csv(&path, options, None)? - .aggregate(&[col("c1")], &[sum(col("c2"))])? + .aggregate(vec![col("c1")], vec![sum(col("c2"))])? .build()?; let execution_plan = plan(&logical_plan)?; diff --git a/rust/datafusion/src/sql/planner.rs b/rust/datafusion/src/sql/planner.rs index 45ad6891866..d34fc0474b0 100644 --- a/rust/datafusion/src/sql/planner.rs +++ b/rust/datafusion/src/sql/planner.rs @@ -528,7 +528,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { &select_exprs, &having_expr_opt, &select.group_by, - &aggr_exprs, + aggr_exprs, )? } else { if let Some(having_expr) = &having_expr_opt { @@ -561,7 +561,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { plan }; - self.project(&plan, &select_exprs_post_aggr, false) + self.project(&plan, select_exprs_post_aggr, false) } /// Returns the `Expr`'s corresponding to a SQL query's SELECT expressions. @@ -592,7 +592,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { fn project( &self, input: &LogicalPlan, - expr: &[Expr], + expr: Vec, force: bool, ) -> Result { self.validate_schema_satisfies_exprs(&input.schema(), &expr)?; @@ -617,7 +617,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { select_exprs: &[Expr], having_expr_opt: &Option, group_by: &[SQLExpr], - aggr_exprs: &[Expr], + aggr_exprs: Vec, ) -> Result<(LogicalPlan, Vec, Option)> { let group_by_exprs = group_by .iter() @@ -631,7 +631,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .collect::>(); let plan = LogicalPlanBuilder::from(&input) - .aggregate(&group_by_exprs, aggr_exprs)? + .aggregate(group_by_exprs, aggr_exprs)? .build()?; // After aggregation, these are all of the columns that will be @@ -718,9 +718,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { }) .collect(); - LogicalPlanBuilder::from(&plan) - .sort(&order_by_rex?)? 
- .build() + LogicalPlanBuilder::from(&plan).sort(order_by_rex?)?.build() } /// Validate the schema provides all of the columns referenced in the expressions. diff --git a/rust/datafusion/tests/custom_sources.rs b/rust/datafusion/tests/custom_sources.rs index 0bc699ceffb..a00dd6ac282 100644 --- a/rust/datafusion/tests/custom_sources.rs +++ b/rust/datafusion/tests/custom_sources.rs @@ -162,7 +162,7 @@ async fn custom_source_dataframe() -> Result<()> { let table = ctx.read_table(Arc::new(CustomTableProvider))?; let logical_plan = LogicalPlanBuilder::from(&table.to_logical_plan()) - .project(&[col("c2")])? + .project(vec![col("c2")])? .build()?; let optimized_plan = ctx.optimize(&logical_plan)?; diff --git a/rust/datafusion/tests/provider_filter_pushdown.rs b/rust/datafusion/tests/provider_filter_pushdown.rs index f38ac59341e..0bf67bea8b9 100644 --- a/rust/datafusion/tests/provider_filter_pushdown.rs +++ b/rust/datafusion/tests/provider_filter_pushdown.rs @@ -150,7 +150,7 @@ async fn assert_provider_row_count(value: i64, expected_count: u64) -> Result<() let df = ctx .read_table(Arc::new(provider.clone()))? .filter(col("flag").eq(lit(value)))? - .aggregate(&[], &[count(col("flag"))])?; + .aggregate(vec![], vec![count(col("flag"))])?; let results = df.collect().await?; let result_col: &UInt64Array = as_primitive_array(results[0].column(0));