Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
177 changes: 175 additions & 2 deletions rust/datafusion/src/execution/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,12 @@ impl ExecutionContext {
}

/// Registers a scalar UDF within this context.
///
/// Note that in SQL queries, function names are looked up using
/// lowercase unless the query uses quotes. For example:
///
/// - `SELECT MY_FUNC(x)...` will look for a function named `"my_func"`
/// - `SELECT "my_FUNC"(x)` will look for a function named `"my_FUNC"`
pub fn register_udf(&mut self, f: ScalarUDF) {
self.state
.lock()
Expand All @@ -237,6 +243,12 @@ impl ExecutionContext {
}

/// Registers an aggregate UDF within this context.
///
/// Note that in SQL queries, aggregate names are looked up using
/// lowercase unless the query uses quotes. For example:
///
/// - `SELECT MY_UDAF(x)...` will look for an aggregate named `"my_udaf"`
/// - `SELECT "my_UDAF"(x)` will look for an aggregate named `"my_UDAF"`
pub fn register_udaf(&mut self, f: AggregateUDF) {
self.state
.lock()
Expand Down Expand Up @@ -1709,6 +1721,167 @@ mod tests {
Ok(())
}

#[tokio::test]
async fn case_sensitive_identifiers_functions() {
    // Checks how the SQL planner resolves the scalar function `sqrt`
    // when its name is written unquoted (any case) versus double-quoted.
    let mut ctx = ExecutionContext::new();
    let table = table_with_sequence(1, 1).unwrap();
    ctx.register_table("t", table).unwrap();

    let expected = vec![
        "+---------+",
        "| sqrt(i) |",
        "+---------+",
        "| 1 |",
        "+---------+",
    ];

    // Unquoted names: both spellings must produce the same output.
    for &sql in &["SELECT sqrt(i) FROM t", "SELECT SQRT(i) FROM t"] {
        let batches = plan_and_collect(&mut ctx, sql).await.unwrap();
        assert_batches_sorted_eq!(expected, &batches);
    }

    // A double-quoted name keeps its capitalization, so the upper-case
    // spelling is reported as an invalid function.
    let message = plan_and_collect(&mut ctx, "SELECT \"SQRT\"(i) FROM t")
        .await
        .unwrap_err()
        .to_string();
    assert_eq!(message, "Error during planning: Invalid function 'SQRT'");

    // The quoted lower-case spelling still resolves.
    let batches = plan_and_collect(&mut ctx, "SELECT \"sqrt\"(i) FROM t")
        .await
        .unwrap();
    assert_batches_sorted_eq!(expected, &batches);
}

#[tokio::test]
async fn case_sensitive_identifiers_user_defined_functions() -> Result<()> {
    // Registers a scalar UDF under an upper-case name and checks that it is
    // only reachable from SQL when the name is double-quoted.
    let mut ctx = ExecutionContext::new();
    let table = table_with_sequence(1, 1).unwrap();
    ctx.register_table("t", table).unwrap();

    // The UDF simply hands back its first argument.
    let passthrough = |args: &[ArrayRef]| Ok(Arc::clone(&args[0]));
    let passthrough = make_scalar_function(passthrough);

    let udf = create_udf(
        "MY_FUNC",
        vec![DataType::Int32],
        Arc::new(DataType::Int32),
        passthrough,
    );
    ctx.register_udf(udf);

    // The unquoted call is looked up as `my_func`, which was never registered.
    let message = plan_and_collect(&mut ctx, "SELECT MY_FUNC(i) FROM t")
        .await
        .unwrap_err()
        .to_string();
    assert_eq!(
        message,
        "Error during planning: Invalid function \'my_func\'"
    );

    let expected = vec![
        "+------------+",
        "| MY_FUNC(i) |",
        "+------------+",
        "| 1 |",
        "+------------+",
    ];

    // Quoting the name preserves its case, so the lookup succeeds.
    let batches = plan_and_collect(&mut ctx, "SELECT \"MY_FUNC\"(i) FROM t").await?;
    assert_batches_eq!(expected, &batches);

    Ok(())
}

#[tokio::test]
async fn case_sensitive_identifiers_aggregates() {
    // Checks how the SQL planner resolves the aggregate `max` when its name
    // is written unquoted (any case) versus double-quoted.
    let mut ctx = ExecutionContext::new();
    let table = table_with_sequence(1, 1).unwrap();
    ctx.register_table("t", table).unwrap();

    let expected = vec![
        "+--------+",
        "| MAX(i) |",
        "+--------+",
        "| 1 |",
        "+--------+",
    ];

    // Unquoted names: both spellings must produce the same output.
    for &sql in &["SELECT max(i) FROM t", "SELECT MAX(i) FROM t"] {
        let batches = plan_and_collect(&mut ctx, sql).await.unwrap();
        assert_batches_sorted_eq!(expected, &batches);
    }

    // A double-quoted name keeps its capitalization, so the upper-case
    // spelling is reported as an invalid function.
    let message = plan_and_collect(&mut ctx, "SELECT \"MAX\"(i) FROM t")
        .await
        .unwrap_err()
        .to_string();
    assert_eq!(message, "Error during planning: Invalid function 'MAX'");

    // The quoted lower-case spelling still resolves.
    let batches = plan_and_collect(&mut ctx, "SELECT \"max\"(i) FROM t")
        .await
        .unwrap();
    assert_batches_sorted_eq!(expected, &batches);
}

#[tokio::test]
async fn case_sensitive_identifiers_user_defined_aggregates() -> Result<()> {
    // Registers an aggregate UDF under an upper-case name and checks that it
    // is only reachable from SQL when the name is double-quoted.
    let mut ctx = ExecutionContext::new();
    let table = table_with_sequence(1, 1).unwrap();
    ctx.register_table("t", table).unwrap();

    // Note the capitalization of the registered name.
    let udaf = create_udaf(
        "MY_AVG",
        DataType::Float64,
        Arc::new(DataType::Float64),
        Arc::new(|| Ok(Box::new(AvgAccumulator::try_new(&DataType::Float64)?))),
        Arc::new(vec![DataType::UInt64, DataType::Float64]),
    );
    ctx.register_udaf(udaf);

    // The unquoted call is looked up as `my_avg`, which was never registered.
    let message = plan_and_collect(&mut ctx, "SELECT MY_AVG(i) FROM t")
        .await
        .unwrap_err()
        .to_string();
    assert_eq!(
        message,
        "Error during planning: Invalid function \'my_avg\'"
    );

    let expected = vec![
        "+-----------+",
        "| MY_AVG(i) |",
        "+-----------+",
        "| 1 |",
        "+-----------+",
    ];

    // Quoting the name preserves its case, so the lookup succeeds.
    let batches = plan_and_collect(&mut ctx, "SELECT \"MY_AVG\"(i) FROM t").await?;
    assert_batches_eq!(expected, &batches);

    Ok(())
}

#[tokio::test]
async fn write_csv_results() -> Result<()> {
// create partitioned input file and context
Expand Down Expand Up @@ -2035,7 +2208,7 @@ mod tests {

// define a udaf, using a DataFusion's accumulator
let my_avg = create_udaf(
"MY_AVG",
"my_avg",
DataType::Float64,
Arc::new(DataType::Float64),
Arc::new(|| Ok(Box::new(AvgAccumulator::try_new(&DataType::Float64)?))),
Expand All @@ -2048,7 +2221,7 @@ mod tests {

let expected = vec![
"+-----------+",
"| MY_AVG(a) |",
"| my_avg(a) |",
"+-----------+",
"| 3 |",
"+-----------+",
Expand Down
12 changes: 6 additions & 6 deletions rust/datafusion/src/physical_plan/aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,12 +72,12 @@ impl fmt::Display for AggregateFunction {
impl FromStr for AggregateFunction {
type Err = DataFusionError;
fn from_str(name: &str) -> Result<AggregateFunction> {
Ok(match &*name.to_uppercase() {
"MIN" => AggregateFunction::Min,
"MAX" => AggregateFunction::Max,
"COUNT" => AggregateFunction::Count,
"AVG" => AggregateFunction::Avg,
"SUM" => AggregateFunction::Sum,
Ok(match name {
"min" => AggregateFunction::Min,
"max" => AggregateFunction::Max,
"count" => AggregateFunction::Count,
"avg" => AggregateFunction::Avg,
"sum" => AggregateFunction::Sum,
_ => {
return Err(DataFusionError::Plan(format!(
"There is no built-in function named {}",
Expand Down
14 changes: 13 additions & 1 deletion rust/datafusion/src/sql/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1028,7 +1028,19 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
}

SQLExpr::Function(function) => {
let name: String = function.name.to_string();
let name = if function.name.0.len() > 1 {
// DF doesn't handle compound identifiers
// (e.g. "foo.bar") for function names yet
function.name.to_string()
} else {
// if there is a quote style, then don't normalize
// the name, otherwise normalize to lowercase
let ident = &function.name.0[0];
match ident.quote_style {
Some(_) => ident.value.clone(),
None => ident.value.to_ascii_lowercase(),
}
};

// first, scalar built-in
if let Ok(fun) = functions::BuiltinScalarFunction::from_str(&name) {
Expand Down