diff --git a/rust/datafusion/src/execution/context.rs b/rust/datafusion/src/execution/context.rs index 20d0067350c..a22b5d3c1e7 100644 --- a/rust/datafusion/src/execution/context.rs +++ b/rust/datafusion/src/execution/context.rs @@ -228,6 +228,12 @@ impl ExecutionContext { } /// Registers a scalar UDF within this context. + /// + /// Note in SQL queries, function names are looked up using + /// lowercase unless the query uses quotes. For example, + /// + /// `SELECT MY_FUNC(x)...` will look for a function named `"my_func"` + /// `SELECT "my_FUNC"(x)` will look for a function named `"my_FUNC"` pub fn register_udf(&mut self, f: ScalarUDF) { self.state .lock() @@ -237,6 +243,12 @@ impl ExecutionContext { } /// Registers an aggregate UDF within this context. + /// + /// Note in SQL queries, aggregate names are looked up using + /// lowercase unless the query uses quotes. For example, + /// + /// `SELECT MY_UDAF(x)...` will look for an aggregate named `"my_udaf"` + /// `SELECT "my_UDAF"(x)` will look for an aggregate named `"my_UDAF"` pub fn register_udaf(&mut self, f: AggregateUDF) { self.state .lock() @@ -1709,6 +1721,167 @@ mod tests { Ok(()) } + #[tokio::test] + async fn case_sensitive_identifiers_functions() { + let mut ctx = ExecutionContext::new(); + ctx.register_table("t", table_with_sequence(1, 1).unwrap()) + .unwrap(); + + let expected = vec![ + "+---------+", + "| sqrt(i) |", + "+---------+", + "| 1 |", + "+---------+", + ]; + + let results = plan_and_collect(&mut ctx, "SELECT sqrt(i) FROM t") + .await + .unwrap(); + + assert_batches_sorted_eq!(expected, &results); + + let results = plan_and_collect(&mut ctx, "SELECT SQRT(i) FROM t") + .await + .unwrap(); + assert_batches_sorted_eq!(expected, &results); + + // Using double quotes allows specifying the function name with capitalization + let err = plan_and_collect(&mut ctx, "SELECT \"SQRT\"(i) FROM t") + .await + .unwrap_err(); + assert_eq!( + err.to_string(), + "Error during planning: Invalid function 'SQRT'" + ); + + let results = plan_and_collect(&mut ctx, "SELECT \"sqrt\"(i) FROM t") + .await + .unwrap(); + assert_batches_sorted_eq!(expected, &results); + } + + #[tokio::test] + async fn case_sensitive_identifiers_user_defined_functions() -> Result<()> { + let mut ctx = ExecutionContext::new(); + ctx.register_table("t", table_with_sequence(1, 1).unwrap()) + .unwrap(); + + let myfunc = |args: &[ArrayRef]| Ok(Arc::clone(&args[0])); + let myfunc = make_scalar_function(myfunc); + + ctx.register_udf(create_udf( + "MY_FUNC", + vec![DataType::Int32], + Arc::new(DataType::Int32), + myfunc, + )); + + // doesn't work as it was registered with non lowercase + let err = plan_and_collect(&mut ctx, "SELECT MY_FUNC(i) FROM t") + .await + .unwrap_err(); + assert_eq!( + err.to_string(), + "Error during planning: Invalid function \'my_func\'" + ); + + // Can call it if you put quotes + let result = plan_and_collect(&mut ctx, "SELECT \"MY_FUNC\"(i) FROM t").await?; + + let expected = vec![ + "+------------+", + "| MY_FUNC(i) |", + "+------------+", + "| 1 |", + "+------------+", + ]; + assert_batches_eq!(expected, &result); + + Ok(()) + } + + #[tokio::test] + async fn case_sensitive_identifiers_aggregates() { + let mut ctx = ExecutionContext::new(); + ctx.register_table("t", table_with_sequence(1, 1).unwrap()) + .unwrap(); + + let expected = vec![ + "+--------+", + "| MAX(i) |", + "+--------+", + "| 1 |", + "+--------+", + ]; + + let results = plan_and_collect(&mut ctx, "SELECT max(i) FROM t") + .await + .unwrap(); + + assert_batches_sorted_eq!(expected, &results); + + let results = plan_and_collect(&mut ctx, "SELECT MAX(i) FROM t") + .await + .unwrap(); + assert_batches_sorted_eq!(expected, &results); + + // Using double quotes allows specifying the function name with capitalization + let err = plan_and_collect(&mut ctx, "SELECT \"MAX\"(i) FROM t") + .await + .unwrap_err(); + assert_eq!( + err.to_string(), + "Error during planning: Invalid function 'MAX'" + ); + + let results = plan_and_collect(&mut ctx, "SELECT \"max\"(i) FROM t") + .await + .unwrap(); + assert_batches_sorted_eq!(expected, &results); + } + + #[tokio::test] + async fn case_sensitive_identifiers_user_defined_aggregates() -> Result<()> { + let mut ctx = ExecutionContext::new(); + ctx.register_table("t", table_with_sequence(1, 1).unwrap()) + .unwrap(); + + // Note capitalizaton + let my_avg = create_udaf( + "MY_AVG", + DataType::Float64, + Arc::new(DataType::Float64), + Arc::new(|| Ok(Box::new(AvgAccumulator::try_new(&DataType::Float64)?))), + Arc::new(vec![DataType::UInt64, DataType::Float64]), + ); + + ctx.register_udaf(my_avg); + + // doesn't work as it was registered as non lowercase + let err = plan_and_collect(&mut ctx, "SELECT MY_AVG(i) FROM t") + .await + .unwrap_err(); + assert_eq!( + err.to_string(), + "Error during planning: Invalid function \'my_avg\'" + ); + + // Can call it if you put quotes + let result = plan_and_collect(&mut ctx, "SELECT \"MY_AVG\"(i) FROM t").await?; + + let expected = vec![ + "+-----------+", + "| MY_AVG(i) |", + "+-----------+", + "| 1 |", + "+-----------+", + ]; + assert_batches_eq!(expected, &result); + + Ok(()) + } + #[tokio::test] async fn write_csv_results() -> Result<()> { // create partitioned input file and context @@ -2035,7 +2208,7 @@ mod tests { // define a udaf, using a DataFusion's accumulator let my_avg = create_udaf( - "MY_AVG", + "my_avg", DataType::Float64, Arc::new(DataType::Float64), Arc::new(|| Ok(Box::new(AvgAccumulator::try_new(&DataType::Float64)?))), @@ -2048,7 +2221,7 @@ mod tests { let expected = vec![ "+-----------+", - "| MY_AVG(a) |", + "| my_avg(a) |", "+-----------+", "| 3 |", "+-----------+", diff --git a/rust/datafusion/src/physical_plan/aggregates.rs b/rust/datafusion/src/physical_plan/aggregates.rs index 59aa730fdea..be90daa954d 100644 --- a/rust/datafusion/src/physical_plan/aggregates.rs +++ b/rust/datafusion/src/physical_plan/aggregates.rs @@ -72,12 +72,12 @@ impl fmt::Display for AggregateFunction { impl FromStr for AggregateFunction { type Err = DataFusionError; fn from_str(name: &str) -> Result { - Ok(match &*name.to_uppercase() { - "MIN" => AggregateFunction::Min, - "MAX" => AggregateFunction::Max, - "COUNT" => AggregateFunction::Count, - "AVG" => AggregateFunction::Avg, - "SUM" => AggregateFunction::Sum, + Ok(match name { + "min" => AggregateFunction::Min, + "max" => AggregateFunction::Max, + "count" => AggregateFunction::Count, + "avg" => AggregateFunction::Avg, + "sum" => AggregateFunction::Sum, _ => { return Err(DataFusionError::Plan(format!( "There is no built-in function named {}", diff --git a/rust/datafusion/src/sql/planner.rs b/rust/datafusion/src/sql/planner.rs index 5d638a3e449..39fb305b4a2 100644 --- a/rust/datafusion/src/sql/planner.rs +++ b/rust/datafusion/src/sql/planner.rs @@ -1028,7 +1028,19 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } SQLExpr::Function(function) => { - let name: String = function.name.to_string(); + let name = if function.name.0.len() > 1 { + // DF doesn't handle compound identifiers + // (e.g. "foo.bar") for function names yet + function.name.to_string() + } else { + // if there is a quote style, then don't normalize + // the name, otherwise normalize to lowercase + let ident = &function.name.0[0]; + match ident.quote_style { + Some(_) => ident.value.clone(), + None => ident.value.to_ascii_lowercase(), + } + }; // first, scalar built-in if let Ok(fun) = functions::BuiltinScalarFunction::from_str(&name) {