From 6f86bfad2fa12478c29eaa14355d0801b4ebf489 Mon Sep 17 00:00:00 2001 From: Arttu Date: Sat, 6 Jul 2024 13:27:18 +0200 Subject: [PATCH] feat: enable "substring" as a UDF in addition to "substr" (#11277) * feat: enable "substring" as a UDF in addition to "substr" Substrait uses the name "substring", and it already exists in DF SQL The setup here is a bit weird; I'd have added substring as an alias for substr, but then we have here this "substring" version being created as udf already and exported through the export_functions, with slightly different args than substr (even though in reality the underlying function for both is the same substr impl). I think this PR should work, but if you have suggestions on how to make the situation here cleaner, I'd be happy to! * okay redo everything: add an alias instead, and add renaming in the substrait producer * add alias into scalar_functions.md --- datafusion/functions/src/unicode/substr.rs | 6 ++ .../substrait/src/logical_plan/consumer.rs | 60 +++++++++---------- .../substrait/src/logical_plan/producer.rs | 32 ++++++---- .../tests/cases/roundtrip_logical_plan.rs | 2 +- .../source/user-guide/sql/scalar_functions.md | 8 +++ 5 files changed, 65 insertions(+), 43 deletions(-) diff --git a/datafusion/functions/src/unicode/substr.rs b/datafusion/functions/src/unicode/substr.rs index c297182057fe..9d15920bb655 100644 --- a/datafusion/functions/src/unicode/substr.rs +++ b/datafusion/functions/src/unicode/substr.rs @@ -32,6 +32,7 @@ use crate::utils::{make_scalar_function, utf8_to_str_type}; #[derive(Debug)] pub struct SubstrFunc { signature: Signature, + aliases: Vec, } impl Default for SubstrFunc { @@ -53,6 +54,7 @@ impl SubstrFunc { ], Volatility::Immutable, ), + aliases: vec![String::from("substring")], } } } @@ -81,6 +83,10 @@ impl ScalarUDFImpl for SubstrFunc { other => exec_err!("Unsupported data type {other:?} for function substr"), } } + + fn aliases(&self) -> &[String] { + &self.aliases + } } /// Extracts the substring of string starting at the start'th character, and extending for count characters if that is specified. (Same as substring(string from start for count).) diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index cc10ea0619c1..c65943643e8c 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -88,36 +88,36 @@ use substrait::proto::{ }; use substrait::proto::{FunctionArgument, SortField}; -pub fn name_to_op(name: &str) -> Result { +pub fn name_to_op(name: &str) -> Option { match name { - "equal" => Ok(Operator::Eq), - "not_equal" => Ok(Operator::NotEq), - "lt" => Ok(Operator::Lt), - "lte" => Ok(Operator::LtEq), - "gt" => Ok(Operator::Gt), - "gte" => Ok(Operator::GtEq), - "add" => Ok(Operator::Plus), - "subtract" => Ok(Operator::Minus), - "multiply" => Ok(Operator::Multiply), - "divide" => Ok(Operator::Divide), - "mod" => Ok(Operator::Modulo), - "and" => Ok(Operator::And), - "or" => Ok(Operator::Or), - "is_distinct_from" => Ok(Operator::IsDistinctFrom), - "is_not_distinct_from" => Ok(Operator::IsNotDistinctFrom), - "regex_match" => Ok(Operator::RegexMatch), - "regex_imatch" => Ok(Operator::RegexIMatch), - "regex_not_match" => Ok(Operator::RegexNotMatch), - "regex_not_imatch" => Ok(Operator::RegexNotIMatch), - "bitwise_and" => Ok(Operator::BitwiseAnd), - "bitwise_or" => Ok(Operator::BitwiseOr), - "str_concat" => Ok(Operator::StringConcat), - "at_arrow" => Ok(Operator::AtArrow), - "arrow_at" => Ok(Operator::ArrowAt), - "bitwise_xor" => Ok(Operator::BitwiseXor), - "bitwise_shift_right" => Ok(Operator::BitwiseShiftRight), - "bitwise_shift_left" => Ok(Operator::BitwiseShiftLeft), - _ => not_impl_err!("Unsupported function name: {name:?}"), + "equal" => Some(Operator::Eq), + "not_equal" => Some(Operator::NotEq), + "lt" => Some(Operator::Lt), + "lte" => Some(Operator::LtEq), + "gt" => Some(Operator::Gt), + "gte" => Some(Operator::GtEq), + "add" => Some(Operator::Plus), + "subtract" => Some(Operator::Minus), + "multiply" => Some(Operator::Multiply), + "divide" => Some(Operator::Divide), + "mod" => Some(Operator::Modulo), + "and" => Some(Operator::And), + "or" => Some(Operator::Or), + "is_distinct_from" => Some(Operator::IsDistinctFrom), + "is_not_distinct_from" => Some(Operator::IsNotDistinctFrom), + "regex_match" => Some(Operator::RegexMatch), + "regex_imatch" => Some(Operator::RegexIMatch), + "regex_not_match" => Some(Operator::RegexNotMatch), + "regex_not_imatch" => Some(Operator::RegexNotIMatch), + "bitwise_and" => Some(Operator::BitwiseAnd), + "bitwise_or" => Some(Operator::BitwiseOr), + "str_concat" => Some(Operator::StringConcat), + "at_arrow" => Some(Operator::AtArrow), + "arrow_at" => Some(Operator::ArrowAt), + "bitwise_xor" => Some(Operator::BitwiseXor), + "bitwise_shift_right" => Some(Operator::BitwiseShiftRight), + "bitwise_shift_left" => Some(Operator::BitwiseShiftLeft), + _ => None, } } @@ -1124,7 +1124,7 @@ pub async fn from_substrait_rex( Ok(Arc::new(Expr::ScalarFunction( expr::ScalarFunction::new_udf(func.to_owned(), args), ))) - } else if let Ok(op) = name_to_op(fn_name) { + } else if let Some(op) = name_to_op(fn_name) { if f.arguments.len() < 2 { return not_impl_err!( "Expect at least two arguments for binary operator {op:?}, the provided number of operators is {:?}", diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index c3bef1689d14..899fec21f8bb 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -818,7 +818,7 @@ pub fn to_substrait_agg_measure( for arg in args { arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex(ctx, arg, schema, 0, extension_info)?)) }); } - let function_anchor = _register_function(fun.to_string(), extension_info); + let function_anchor = register_function(fun.to_string(), extension_info); Ok(Measure { measure: Some(AggregateFunction { function_reference: function_anchor, @@ -849,7 +849,7 @@ pub fn to_substrait_agg_measure( for arg in args { arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex(ctx, arg, schema, 0, extension_info)?)) }); } - let function_anchor = _register_function(fun.name().to_string(), extension_info); + let function_anchor = register_function(fun.name().to_string(), extension_info); Ok(Measure { measure: Some(AggregateFunction { function_reference: function_anchor, @@ -917,7 +917,7 @@ fn to_substrait_sort_field( } } -fn _register_function( +fn register_function( function_name: String, extension_info: &mut ( Vec, @@ -926,6 +926,14 @@ fn _register_function( ) -> u32 { let (function_extensions, function_set) = extension_info; let function_name = function_name.to_lowercase(); + + // Some functions are named differently in Substrait default extensions than in DF + // Rename those to match the Substrait extensions for interoperability + let function_name = match function_name.as_str() { + "substr" => "substring".to_string(), + _ => function_name, + }; + // To prevent ambiguous references between ScalarFunctions and AggregateFunctions, // a plan-relative identifier starting from 0 is used as the function_anchor. // The consumer is responsible for correctly registering @@ -969,7 +977,7 @@ pub fn make_binary_op_scalar_func( ), ) -> Expression { let function_anchor = - _register_function(operator_to_name(op).to_string(), extension_info); + register_function(operator_to_name(op).to_string(), extension_info); Expression { rex_type: Some(RexType::ScalarFunction(ScalarFunction { function_reference: function_anchor, @@ -1044,7 +1052,7 @@ pub fn to_substrait_rex( if *negated { let function_anchor = - _register_function("not".to_string(), extension_info); + register_function("not".to_string(), extension_info); Ok(Expression { rex_type: Some(RexType::ScalarFunction(ScalarFunction { @@ -1076,7 +1084,7 @@ pub fn to_substrait_rex( } let function_anchor = - _register_function(fun.name().to_string(), extension_info); + register_function(fun.name().to_string(), extension_info); Ok(Expression { rex_type: Some(RexType::ScalarFunction(ScalarFunction { function_reference: function_anchor, @@ -1252,7 +1260,7 @@ pub fn to_substrait_rex( null_treatment: _, }) => { // function reference - let function_anchor = _register_function(fun.to_string(), extension_info); + let function_anchor = register_function(fun.to_string(), extension_info); // arguments let mut arguments: Vec = vec![]; for arg in args { @@ -1330,7 +1338,7 @@ pub fn to_substrait_rex( }; if *negated { let function_anchor = - _register_function("not".to_string(), extension_info); + register_function("not".to_string(), extension_info); Ok(Expression { rex_type: Some(RexType::ScalarFunction(ScalarFunction { @@ -1727,9 +1735,9 @@ fn make_substrait_like_expr( ), ) -> Result { let function_anchor = if ignore_case { - _register_function("ilike".to_string(), extension_info) + register_function("ilike".to_string(), extension_info) } else { - _register_function("like".to_string(), extension_info) + register_function("like".to_string(), extension_info) }; let expr = to_substrait_rex(ctx, expr, schema, col_ref_offset, extension_info)?; let pattern = to_substrait_rex(ctx, pattern, schema, col_ref_offset, extension_info)?; @@ -1759,7 +1767,7 @@ fn make_substrait_like_expr( }; if negated { - let function_anchor = _register_function("not".to_string(), extension_info); + let function_anchor = register_function("not".to_string(), extension_info); Ok(Expression { rex_type: Some(RexType::ScalarFunction(ScalarFunction { @@ -2128,7 +2136,7 @@ fn to_substrait_unary_scalar_fn( HashMap, ), ) -> Result { - let function_anchor = _register_function(fn_name.to_string(), extension_info); + let function_anchor = register_function(fn_name.to_string(), extension_info); let substrait_expr = to_substrait_rex(ctx, arg, schema, col_ref_offset, extension_info)?; diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index 7ed376f62ba0..dbc2e404bf56 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -327,7 +327,7 @@ async fn simple_scalar_function_pow() -> Result<()> { #[tokio::test] async fn simple_scalar_function_substr() -> Result<()> { - roundtrip("SELECT * FROM data WHERE a = SUBSTR('datafusion', 0, 3)").await + roundtrip("SELECT SUBSTR(f, 1, 3) FROM data").await } #[tokio::test] diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index ec34dbf9ba6c..d636726b45fe 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -1132,6 +1132,14 @@ substr(str, start_pos[, length]) - **length**: Number of characters to extract. If not specified, returns the rest of the string after the start position. +#### Aliases + +- substring + +### `substring` + +_Alias of [substr](#substr)._ + ### `translate` Translates characters in a string to specified translation characters.