Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat:implement sql style 'substr_index' string function #8272

Merged
merged 14 commits into from
Nov 26, 2023
15 changes: 15 additions & 0 deletions datafusion/expr/src/built_in_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,8 @@ pub enum BuiltinScalarFunction {
OverLay,
/// levenshtein
Levenshtein,
/// substr_index
SubstrIndex,
}

/// Maps the sql function name to `BuiltinScalarFunction`
Expand Down Expand Up @@ -470,6 +472,7 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::ArrowTypeof => Volatility::Immutable,
BuiltinScalarFunction::OverLay => Volatility::Immutable,
BuiltinScalarFunction::Levenshtein => Volatility::Immutable,
BuiltinScalarFunction::SubstrIndex => Volatility::Immutable,

// Stable builtin functions
BuiltinScalarFunction::Now => Volatility::Stable,
Expand Down Expand Up @@ -773,6 +776,9 @@ impl BuiltinScalarFunction {
return plan_err!("The to_hex function can only accept integers.");
}
}),
BuiltinScalarFunction::SubstrIndex => {
utf8_to_str_type(&input_expr_types[0], "substr_index")
}
BuiltinScalarFunction::ToTimestamp => Ok(match &input_expr_types[0] {
Int64 => Timestamp(Second, None),
_ => Timestamp(Nanosecond, None),
Expand Down Expand Up @@ -1235,6 +1241,14 @@ impl BuiltinScalarFunction {
self.volatility(),
),

BuiltinScalarFunction::SubstrIndex => Signature::one_of(
vec![
Exact(vec![Utf8, Utf8, Int64]),
Exact(vec![LargeUtf8, LargeUtf8, Int64]),
],
self.volatility(),
),

BuiltinScalarFunction::Replace | BuiltinScalarFunction::Translate => {
Signature::one_of(vec![Exact(vec![Utf8, Utf8, Utf8])], self.volatility())
}
Expand Down Expand Up @@ -1486,6 +1500,7 @@ fn aliases(func: &BuiltinScalarFunction) -> &'static [&'static str] {
BuiltinScalarFunction::Upper => &["upper"],
BuiltinScalarFunction::Uuid => &["uuid"],
BuiltinScalarFunction::Levenshtein => &["levenshtein"],
BuiltinScalarFunction::SubstrIndex => &["substr_index", "substring_index"],

// regex functions
BuiltinScalarFunction::RegexpMatch => &["regexp_match"],
Expand Down
2 changes: 2 additions & 0 deletions datafusion/expr/src/expr_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -916,6 +916,7 @@ scalar_expr!(

scalar_expr!(ArrowTypeof, arrow_typeof, val, "data type");
scalar_expr!(Levenshtein, levenshtein, string1 string2, "Returns the Levenshtein distance between the two given strings");
scalar_expr!(SubstrIndex, substr_index, string delimiter count, "Returns the substring from str before count occurrences of the delimiter");

scalar_expr!(
Struct,
Expand Down Expand Up @@ -1203,6 +1204,7 @@ mod test {
test_nary_scalar_expr!(OverLay, overlay, string, characters, position, len);
test_nary_scalar_expr!(OverLay, overlay, string, characters, position);
test_scalar_expr!(Levenshtein, levenshtein, string1, string2);
test_scalar_expr!(SubstrIndex, substr_index, string, delimiter, count);
}

#[test]
Expand Down
23 changes: 23 additions & 0 deletions datafusion/physical-expr/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -862,6 +862,29 @@ pub fn create_physical_fun(
))),
})
}
BuiltinScalarFunction::SubstrIndex => {
Arc::new(|args| match args[0].data_type() {
DataType::Utf8 => {
let func = invoke_if_unicode_expressions_feature_flag!(
substr_index,
i32,
"substr_index"
);
make_scalar_function(func)(args)
}
DataType::LargeUtf8 => {
let func = invoke_if_unicode_expressions_feature_flag!(
substr_index,
i64,
"substr_index"
);
make_scalar_function(func)(args)
}
other => Err(DataFusionError::Internal(format!(
"Unsupported data type {other:?} for function substr_index",
))),
})
}
})
}

Expand Down
65 changes: 65 additions & 0 deletions datafusion/physical-expr/src/unicode_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -455,3 +455,68 @@ pub fn translate<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {

Ok(Arc::new(result) as ArrayRef)
}

/// Returns the substring from str before count occurrences of the delimiter delim. If count is positive, everything to the left of the final delimiter (counting from the left) is returned. If count is negative, everything to the right of the final delimiter (counting from the right) is returned.
/// SUBSTRING_INDEX('www.apache.org', '.', 1) = www
/// SUBSTRING_INDEX('www.apache.org', '.', 2) = www.apache
/// SUBSTRING_INDEX('www.apache.org', '.', -2) = apache.org
/// SUBSTRING_INDEX('www.apache.org', '.', -1) = org
pub fn substr_index<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
if args.len() != 3 {
return internal_err!(
"substr_index was called with {} arguments. It requires 3.",
args.len()
);
}

let string_array = as_generic_string_array::<T>(&args[0])?;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we need to add a defense check args is exactly 3 elements

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks, I add the args len check

let delimiter_array = as_generic_string_array::<T>(&args[1])?;
let count_array = as_int64_array(&args[2])?;

let result = string_array
.iter()
.zip(delimiter_array.iter())
.zip(count_array.iter())
.map(|((string, delimiter), n)| match (string, delimiter, n) {
(Some(string), Some(delimiter), Some(n)) => {
let mut res = String::new();
match n {
0 => {
"".to_string();
}
_other => {
if n > 0 {
let idx = string
.split(delimiter)
.take(n as usize)
.fold(0, |len, x| len + x.len() + delimiter.len())
- delimiter.len();
res.push_str(if idx >= string.len() {
string
} else {
&string[..idx]
});
} else {
let idx = (string.split(delimiter).take((-n) as usize).fold(
string.len() as isize,
|len, x| {
len - x.len() as isize - delimiter.len() as isize
},
) + delimiter.len() as isize)
as usize;
res.push_str(if idx >= string.len() {
string
} else {
&string[idx..]
});
}
}
}
Some(res)
}
_ => None,
})
.collect::<GenericStringArray<T>>();

Ok(Arc::new(result) as ArrayRef)
}
1 change: 1 addition & 0 deletions datafusion/proto/proto/datafusion.proto
Original file line number Diff line number Diff line change
Expand Up @@ -641,6 +641,7 @@ enum ScalarFunction {
ArrayExcept = 123;
ArrayPopFront = 124;
Levenshtein = 125;
SubstrIndex = 126;
}

message ScalarFunctionNode {
Expand Down
3 changes: 3 additions & 0 deletions datafusion/proto/src/generated/pbjson.rs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions datafusion/proto/src/generated/prost.rs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 9 additions & 3 deletions datafusion/proto/src/logical_plan/from_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,9 @@ use datafusion_expr::{
lower, lpad, ltrim, md5, nanvl, now, nullif, octet_length, overlay, pi, power,
radians, random, regexp_match, regexp_replace, repeat, replace, reverse, right,
round, rpad, rtrim, sha224, sha256, sha384, sha512, signum, sin, sinh, split_part,
sqrt, starts_with, string_to_array, strpos, struct_fun, substr, substring, tan, tanh,
to_hex, to_timestamp_micros, to_timestamp_millis, to_timestamp_nanos,
to_timestamp_seconds, translate, trim, trunc, upper, uuid,
sqrt, starts_with, string_to_array, strpos, struct_fun, substr, substr_index,
substring, tan, tanh, to_hex, to_timestamp_micros, to_timestamp_millis,
to_timestamp_nanos, to_timestamp_seconds, translate, trim, trunc, upper, uuid,
window_frame::regularize,
AggregateFunction, Between, BinaryExpr, BuiltInWindowFunction, BuiltinScalarFunction,
Case, Cast, Expr, GetFieldAccess, GetIndexedField, GroupingSet,
Expand Down Expand Up @@ -551,6 +551,7 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction {
ScalarFunction::ArrowTypeof => Self::ArrowTypeof,
ScalarFunction::OverLay => Self::OverLay,
ScalarFunction::Levenshtein => Self::Levenshtein,
ScalarFunction::SubstrIndex => Self::SubstrIndex,
}
}
}
Expand Down Expand Up @@ -1716,6 +1717,11 @@ pub fn parse_expr(
.map(|expr| parse_expr(expr, registry))
.collect::<Result<Vec<_>, _>>()?,
)),
ScalarFunction::SubstrIndex => Ok(substr_index(
parse_expr(&args[0], registry)?,
parse_expr(&args[1], registry)?,
parse_expr(&args[2], registry)?,
)),
ScalarFunction::StructFun => {
Ok(struct_fun(parse_expr(&args[0], registry)?))
}
Expand Down
1 change: 1 addition & 0 deletions datafusion/proto/src/logical_plan/to_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1573,6 +1573,7 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction {
BuiltinScalarFunction::ArrowTypeof => Self::ArrowTypeof,
BuiltinScalarFunction::OverLay => Self::OverLay,
BuiltinScalarFunction::Levenshtein => Self::Levenshtein,
BuiltinScalarFunction::SubstrIndex => Self::SubstrIndex,
};

Ok(scalar_function)
Expand Down
75 changes: 75 additions & 0 deletions datafusion/sqllogictest/test_files/functions.slt
Original file line number Diff line number Diff line change
Expand Up @@ -877,3 +877,78 @@ query ?
SELECT levenshtein(NULL, NULL)
----
NULL

query T
SELECT substr_index('www.apache.org', '.', 1)
----
www

query T
SELECT substr_index('www.apache.org', '.', 2)
----
www.apache

query T
SELECT substr_index('www.apache.org', '.', -1)
----
org

query T
SELECT substr_index('www.apache.org', '.', -2)
----
apache.org

query T
SELECT substr_index('www.apache.org', 'ac', 1)
----
www.ap

query T
SELECT substr_index('www.apache.org', 'ac', -1)
----
he.org

query T
SELECT substr_index('www.apache.org', 'ac', 2)
----
www.apache.org

query T
SELECT substr_index('www.apache.org', 'ac', -2)
----
www.apache.org

query ?
SELECT substr_index(NULL, 'ac', 1)
----
NULL

query T
SELECT substr_index('www.apache.org', NULL, 1)
----
NULL

query T
SELECT substr_index('www.apache.org', 'ac', NULL)
----
NULL

query T
SELECT substr_index('', 'ac', 1)
----
(empty)

query T
SELECT substr_index('www.apache.org', '', 1)
----
(empty)

query T
SELECT substr_index('www.apache.org', 'ac', 0)
----
(empty)

query ?
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

awesome, can we also have the same tests with empty strings as input and search token?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add empty string tests and 0 count tests

SELECT substr_index(NULL, NULL, NULL)
----
NULL
18 changes: 18 additions & 0 deletions docs/source/user-guide/sql/scalar_functions.md
Original file line number Diff line number Diff line change
Expand Up @@ -637,6 +637,7 @@ nullif(expression1, expression2)
- [uuid](#uuid)
- [overlay](#overlay)
- [levenshtein](#levenshtein)
- [substr_index](#substr_index)

### `ascii`

Expand Down Expand Up @@ -1152,6 +1153,23 @@ levenshtein(str1, str2)
- **str1**: String expression to compute Levenshtein distance with str2.
- **str2**: String expression to compute Levenshtein distance with str1.

### `substr_index`

Returns the substring of `str` that precedes `count` occurrences of the delimiter `delim`.
If `count` is positive, everything to the left of the final delimiter (counting from the left) is returned.
If `count` is negative, everything to the right of the final delimiter (counting from the right) is returned.
If `count` is 0, an empty string is returned.
For example, `substr_index('www.apache.org', '.', 1) = www` and `substr_index('www.apache.org', '.', -1) = org`.

```
substr_index(str, delim, count)
```

#### Arguments

- **str**: String expression to operate on.
- **delim**: The delimiter string to search for in str.
- **count**: The number of delimiter occurrences to search for. Can be either a positive or a negative number.

## Binary String Functions

- [decode](#decode)
Expand Down