From 0b8a3ccbea1b89cc29f6e6ef22aa05ba1c721089 Mon Sep 17 00:00:00 2001 From: cli2 Date: Mon, 20 Nov 2023 15:02:00 +0800 Subject: [PATCH 01/11] feat:implement sql style 'substr_index' string function --- datafusion/expr/src/built_in_function.rs | 15 +++++ datafusion/expr/src/expr_fn.rs | 2 + datafusion/physical-expr/src/functions.rs | 23 +++++++ .../physical-expr/src/unicode_expressions.rs | 45 +++++++++++++ datafusion/proto/proto/datafusion.proto | 1 + datafusion/proto/src/generated/pbjson.rs | 3 + datafusion/proto/src/generated/prost.rs | 3 + .../proto/src/logical_plan/from_proto.rs | 12 +++- datafusion/proto/src/logical_plan/to_proto.rs | 1 + .../sqllogictest/test_files/functions.slt | 66 +++++++++++++++---- .../source/user-guide/sql/scalar_functions.md | 18 +++++ 11 files changed, 175 insertions(+), 14 deletions(-) diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index e9030ebcc00f..8893224ae9fc 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -302,6 +302,8 @@ pub enum BuiltinScalarFunction { OverLay, /// levenshtein Levenshtein, + /// substr_index + SubstrIndex, } /// Maps the sql function name to `BuiltinScalarFunction` @@ -470,6 +472,7 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::ArrowTypeof => Volatility::Immutable, BuiltinScalarFunction::OverLay => Volatility::Immutable, BuiltinScalarFunction::Levenshtein => Volatility::Immutable, + BuiltinScalarFunction::SubstrIndex => Volatility::Immutable, // Stable builtin functions BuiltinScalarFunction::Now => Volatility::Stable, @@ -761,6 +764,9 @@ impl BuiltinScalarFunction { return plan_err!("The to_hex function can only accept integers."); } }), + BuiltinScalarFunction::SubstrIndex => { + utf8_to_str_type(&input_expr_types[0], "substr_index") + } BuiltinScalarFunction::ToTimestamp => Ok(match &input_expr_types[0] { Int64 => Timestamp(Second, None), _ => Timestamp(Nanosecond, None), @@ -1223,6 +1229,14 
@@ impl BuiltinScalarFunction { self.volatility(), ), + BuiltinScalarFunction::SubstrIndex => Signature::one_of( + vec![ + Exact(vec![Utf8, Utf8, Int64]), + Exact(vec![LargeUtf8, LargeUtf8, Int64]), + ], + self.volatility(), + ), + BuiltinScalarFunction::Replace | BuiltinScalarFunction::Translate => { Signature::one_of(vec![Exact(vec![Utf8, Utf8, Utf8])], self.volatility()) } @@ -1474,6 +1488,7 @@ fn aliases(func: &BuiltinScalarFunction) -> &'static [&'static str] { BuiltinScalarFunction::Upper => &["upper"], BuiltinScalarFunction::Uuid => &["uuid"], BuiltinScalarFunction::Levenshtein => &["levenshtein"], + BuiltinScalarFunction::SubstrIndex => &["substr_index", "substring_index"], // regex functions BuiltinScalarFunction::RegexpMatch => &["regexp_match"], diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 674d2a34df38..af80e8ec0fe8 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -916,6 +916,7 @@ scalar_expr!( scalar_expr!(ArrowTypeof, arrow_typeof, val, "data type"); scalar_expr!(Levenshtein, levenshtein, string1 string2, "Returns the Levenshtein distance between the two given strings"); +scalar_expr!(SubstrIndex, substr_index, string delimiter count, "Returns the substring from str before count occurrences of the delimiter"); scalar_expr!( Struct, @@ -1203,6 +1204,7 @@ mod test { test_nary_scalar_expr!(OverLay, overlay, string, characters, position, len); test_nary_scalar_expr!(OverLay, overlay, string, characters, position); test_scalar_expr!(Levenshtein, levenshtein, string1, string2); + test_scalar_expr!(SubstrIndex, substr_index, string, delimiter, count); } #[test] diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index 5a1a68dd2127..40b21347edf5 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -862,6 +862,29 @@ pub fn create_physical_fun( ))), }) } + BuiltinScalarFunction::SubstrIndex 
=> { + Arc::new(|args| match args[0].data_type() { + DataType::Utf8 => { + let func = invoke_if_unicode_expressions_feature_flag!( + substr_index, + i32, + "substr_index" + ); + make_scalar_function(func)(args) + } + DataType::LargeUtf8 => { + let func = invoke_if_unicode_expressions_feature_flag!( + substr_index, + i64, + "substr_index" + ); + make_scalar_function(func)(args) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {other:?} for function substr_index", + ))), + }) + } }) } diff --git a/datafusion/physical-expr/src/unicode_expressions.rs b/datafusion/physical-expr/src/unicode_expressions.rs index e28700a25ce4..09d5ed276a9d 100644 --- a/datafusion/physical-expr/src/unicode_expressions.rs +++ b/datafusion/physical-expr/src/unicode_expressions.rs @@ -455,3 +455,48 @@ pub fn translate(args: &[ArrayRef]) -> Result { Ok(Arc::new(result) as ArrayRef) } + +/// Returns the substring from str before count occurrences of the delimiter delim. If count is positive, everything to the left of the final delimiter (counting from the left) is returned. If count is negative, everything to the right of the final delimiter (counting from the right) is returned. 
+/// SUBSTRING_INDEX('www.apache.org', '.', 1) = www +/// SUBSTRING_INDEX('www.apache.org', '.', 2) = www.apache +/// SUBSTRING_INDEX('www.apache.org', '.', -2) = apache.org +/// SUBSTRING_INDEX('www.apache.org', '.', -1) = org +pub fn substr_index(args: &[ArrayRef]) -> Result { + let string_array = as_generic_string_array::(&args[0])?; + let delimiter_array = as_generic_string_array::(&args[1])?; + let count_array = as_int64_array(&args[2])?; + + let result = string_array + .iter() + .zip(delimiter_array.iter()) + .zip(count_array.iter()) + .map(|((string, delimiter), n)| match (string, delimiter, n) { + (Some(string), Some(delimiter), Some(n)) => { + let mut res = String::new(); + if n == 0 { + Some("".to_string()) + } else { + let mut start = 0; + let mut count = 0; + while let Some(idx) = string[start..].find(delimiter) { + count += if n > 0 { 1 } else { -1 }; + start += idx + delimiter.len(); + if count == n { + if n > 0 { + start -= delimiter.len(); + res.push_str(&string[0..start]); + } else { + res.push_str(&string[start..]); + } + break; + } + } + Some(res) + } + } + _ => None, + }) + .collect::>(); + + Ok(Arc::new(result) as ArrayRef) +} \ No newline at end of file diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 9d508078c705..94e96dcf6b80 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -641,6 +641,7 @@ enum ScalarFunction { ArrayExcept = 123; ArrayPopFront = 124; Levenshtein = 125; + SubstrIndex = 126; } message ScalarFunctionNode { diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 0a8f415e20c5..b3ab2458ded5 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -20849,6 +20849,7 @@ impl serde::Serialize for ScalarFunction { Self::ArrayExcept => "ArrayExcept", Self::ArrayPopFront => "ArrayPopFront", Self::Levenshtein => "Levenshtein", + 
Self::SubstrIndex => "SubstrIndex", }; serializer.serialize_str(variant) } @@ -20986,6 +20987,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "ArrayExcept", "ArrayPopFront", "Levenshtein", + "SubstrIndex", ]; struct GeneratedVisitor; @@ -21152,6 +21154,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "ArrayExcept" => Ok(ScalarFunction::ArrayExcept), "ArrayPopFront" => Ok(ScalarFunction::ArrayPopFront), "Levenshtein" => Ok(ScalarFunction::Levenshtein), + "SubstrIndex" => Ok(ScalarFunction::SubstrIndex), _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 84fb84b9487e..dc0439a0aab3 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -2572,6 +2572,7 @@ pub enum ScalarFunction { ArrayExcept = 123, ArrayPopFront = 124, Levenshtein = 125, + SubstrIndex = 126, } impl ScalarFunction { /// String value of the enum field names used in the ProtoBuf definition. @@ -2706,6 +2707,7 @@ impl ScalarFunction { ScalarFunction::ArrayExcept => "ArrayExcept", ScalarFunction::ArrayPopFront => "ArrayPopFront", ScalarFunction::Levenshtein => "Levenshtein", + ScalarFunction::SubstrIndex => "SubstrIndex", } } /// Creates an enum from field names used in the ProtoBuf definition. 
@@ -2837,6 +2839,7 @@ impl ScalarFunction { "ArrayExcept" => Some(Self::ArrayExcept), "ArrayPopFront" => Some(Self::ArrayPopFront), "Levenshtein" => Some(Self::Levenshtein), + "SubstrIndex" => Some(Self::SubstrIndex), _ => None, } } diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 4ae45fa52162..1b8bf5fea1de 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -55,9 +55,9 @@ use datafusion_expr::{ lower, lpad, ltrim, md5, nanvl, now, nullif, octet_length, overlay, pi, power, radians, random, regexp_match, regexp_replace, repeat, replace, reverse, right, round, rpad, rtrim, sha224, sha256, sha384, sha512, signum, sin, sinh, split_part, - sqrt, starts_with, string_to_array, strpos, struct_fun, substr, substring, tan, tanh, - to_hex, to_timestamp_micros, to_timestamp_millis, to_timestamp_nanos, - to_timestamp_seconds, translate, trim, trunc, upper, uuid, + sqrt, starts_with, string_to_array, strpos, struct_fun, substr, substr_index, + substring, tan, tanh, to_hex, to_timestamp_micros, to_timestamp_millis, + to_timestamp_nanos, to_timestamp_seconds, translate, trim, trunc, upper, uuid, window_frame::regularize, AggregateFunction, Between, BinaryExpr, BuiltInWindowFunction, BuiltinScalarFunction, Case, Cast, Expr, GetFieldAccess, GetIndexedField, GroupingSet, @@ -551,6 +551,7 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::ArrowTypeof => Self::ArrowTypeof, ScalarFunction::OverLay => Self::OverLay, ScalarFunction::Levenshtein => Self::Levenshtein, + ScalarFunction::SubstrIndex => Self::SubstrIndex, } } } @@ -1713,6 +1714,11 @@ pub fn parse_expr( .map(|expr| parse_expr(expr, registry)) .collect::, _>>()?, )), + ScalarFunction::SubstrIndex => Ok(substr_index( + parse_expr(&args[0], registry)?, + parse_expr(&args[1], registry)?, + parse_expr(&args[2], registry)?, + )), ScalarFunction::StructFun => { 
Ok(struct_fun(parse_expr(&args[0], registry)?)) } diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index cf66e3ddd5b5..653918a1fa0e 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -1566,6 +1566,7 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::ArrowTypeof => Self::ArrowTypeof, BuiltinScalarFunction::OverLay => Self::OverLay, BuiltinScalarFunction::Levenshtein => Self::Levenshtein, + BuiltinScalarFunction::SubstrIndex => Self::SubstrIndex, }; Ok(scalar_function) diff --git a/datafusion/sqllogictest/test_files/functions.slt b/datafusion/sqllogictest/test_files/functions.slt index 9c8bb2c5f844..ca2f88e17a5b 100644 --- a/datafusion/sqllogictest/test_files/functions.slt +++ b/datafusion/sqllogictest/test_files/functions.slt @@ -532,7 +532,7 @@ statement ok drop table simple_struct_test # create aggregate_test_100 table for functions test -statement ok +statement error DataFusion error: IO error: No such file or directory \(os error 2\) CREATE EXTERNAL TABLE aggregate_test_100 ( c1 VARCHAR NOT NULL, c2 TINYINT NOT NULL, @@ -554,22 +554,16 @@ LOCATION '../../testing/data/csv/aggregate_test_100.csv' # sqrt_f32_vs_f64 -query R +query error DataFusion error: Error during planning: table 'datafusion\.public\.aggregate_test_100' not found SELECT avg(sqrt(c11)) FROM aggregate_test_100 ----- -0.658440848589 -query R +query error DataFusion error: Error during planning: table 'datafusion\.public\.aggregate_test_100' not found SELECT avg(CAST(sqrt(c11) AS double)) FROM aggregate_test_100 ----- -0.658440848589 -query R +query error DataFusion error: Error during planning: table 'datafusion\.public\.aggregate_test_100' not found SELECT avg(sqrt(CAST(c11 AS double))) FROM aggregate_test_100 ----- -0.658440848342 -statement ok +statement error DataFusion error: Execution error: Table 'aggregate_test_100' doesn't 
exist\. drop table aggregate_test_100 @@ -877,3 +871,53 @@ query ? SELECT levenshtein(NULL, NULL) ---- NULL + +query T +SELECT substr_index('www.apache.org', '.', 1) +---- +www + +query T +SELECT substr_index('www.apache.org', '.', 2) +---- +www.apache + +query T +SELECT substr_index('www.apache.org', '.', -1) +---- +apache.org + +query T +SELECT substr_index('www.apache.org', '.', -2) +---- +org + +query T +SELECT substr_index('www.apache.org', 'ac', 1) +---- +www.ap + +query T +SELECT substr_index('www.apache.org', 'ac', -1) +---- +he.org + +query ? +SELECT substr_index(NULL, 'ac', 1) +---- +NULL + +query T +SELECT substr_index('www.apache.org', NULL, 1) +---- +NULL + +query T +SELECT substr_index('www.apache.org', 'ac', NULL) +---- +NULL + +query ? +SELECT substr_index(NULL, NULL, NULL) +---- +NULL diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index eda46ef8a73b..c88b4edbd5f4 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -637,6 +637,7 @@ nullif(expression1, expression2) - [uuid](#uuid) - [overlay](#overlay) - [levenshtein](#levenshtein) +- [substr_index](#substr_index) ### `ascii` @@ -1152,6 +1153,23 @@ levenshtein(str1, str2) - **str1**: String expression to compute Levenshtein distance with str2. - **str2**: String expression to compute Levenshtein distance with str1. +### `substr_index` + +Returns the substring from str before count occurrences of the delimiter delim. +If count is positive, everything to the left of the final delimiter (counting from the left) is returned. +If count is negative, everything to the right of the final delimiter (counting from the right) is returned. +For example, `substr_index('www.apache.org', '.', 1) = www`, `substr_index('www.apache.org', '.', -1) = apache.org` + +``` +substr_index(str, delim, count) +``` + +#### Arguments + +- **str**: String expression to operate on. 
+- **delim**: The string to find in str to split str. +- **count**: The number of times to search for the delimiter. Can be either a positive or a negative number. + ## Binary String Functions - [decode](#decode) From 03b8d1535f0e776790a88978714d7703485fdd84 Mon Sep 17 00:00:00 2001 From: cli2 Date: Mon, 20 Nov 2023 15:29:52 +0800 Subject: [PATCH 02/11] code format --- datafusion/physical-expr/src/unicode_expressions.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/physical-expr/src/unicode_expressions.rs b/datafusion/physical-expr/src/unicode_expressions.rs index 09d5ed276a9d..e09a70ee6637 100644 --- a/datafusion/physical-expr/src/unicode_expressions.rs +++ b/datafusion/physical-expr/src/unicode_expressions.rs @@ -499,4 +499,4 @@ pub fn substr_index(args: &[ArrayRef]) -> Result { .collect::>(); Ok(Arc::new(result) as ArrayRef) -} \ No newline at end of file +} From eb206bde00f14fda1e626a120e22190628b85ac3 Mon Sep 17 00:00:00 2001 From: cli2 Date: Mon, 20 Nov 2023 16:18:46 +0800 Subject: [PATCH 03/11] code format --- datafusion/sqllogictest/test_files/functions.slt | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/datafusion/sqllogictest/test_files/functions.slt b/datafusion/sqllogictest/test_files/functions.slt index ca2f88e17a5b..e36b72d4a5de 100644 --- a/datafusion/sqllogictest/test_files/functions.slt +++ b/datafusion/sqllogictest/test_files/functions.slt @@ -532,7 +532,7 @@ statement ok drop table simple_struct_test # create aggregate_test_100 table for functions test -statement error DataFusion error: IO error: No such file or directory \(os error 2\) +statement ok CREATE EXTERNAL TABLE aggregate_test_100 ( c1 VARCHAR NOT NULL, c2 TINYINT NOT NULL, @@ -554,16 +554,22 @@ LOCATION '../../testing/data/csv/aggregate_test_100.csv' # sqrt_f32_vs_f64 -query error DataFusion error: Error during planning: table 'datafusion\.public\.aggregate_test_100' not found +query R SELECT avg(sqrt(c11)) FROM 
aggregate_test_100 +---- +0.658440848589 -query error DataFusion error: Error during planning: table 'datafusion\.public\.aggregate_test_100' not found +query R SELECT avg(CAST(sqrt(c11) AS double)) FROM aggregate_test_100 +---- +0.658440848589 -query error DataFusion error: Error during planning: table 'datafusion\.public\.aggregate_test_100' not found +query R SELECT avg(sqrt(CAST(c11 AS double))) FROM aggregate_test_100 +---- +0.658440848342 -statement error DataFusion error: Execution error: Table 'aggregate_test_100' doesn't exist\. +statement ok drop table aggregate_test_100 From d5be3827014d3286c378f76e7305383b2193f9f0 Mon Sep 17 00:00:00 2001 From: cli2 Date: Mon, 20 Nov 2023 17:16:12 +0800 Subject: [PATCH 04/11] code format --- .../physical-expr/src/unicode_expressions.rs | 31 ++++++++++--------- .../sqllogictest/test_files/functions.slt | 4 +-- 2 files changed, 18 insertions(+), 17 deletions(-) diff --git a/datafusion/physical-expr/src/unicode_expressions.rs b/datafusion/physical-expr/src/unicode_expressions.rs index e09a70ee6637..654cddf4650f 100644 --- a/datafusion/physical-expr/src/unicode_expressions.rs +++ b/datafusion/physical-expr/src/unicode_expressions.rs @@ -31,6 +31,7 @@ use datafusion_common::{ }; use hashbrown::HashMap; use std::cmp::{max, Ordering}; +use std::process::id; use std::sync::Arc; use unicode_segmentation::UnicodeSegmentation; @@ -476,21 +477,21 @@ pub fn substr_index(args: &[ArrayRef]) -> Result { if n == 0 { Some("".to_string()) } else { - let mut start = 0; - let mut count = 0; - while let Some(idx) = string[start..].find(delimiter) { - count += if n > 0 { 1 } else { -1 }; - start += idx + delimiter.len(); - if count == n { - if n > 0 { - start -= delimiter.len(); - res.push_str(&string[0..start]); - } else { - res.push_str(&string[start..]); - } - break; - } - } + if n > 0 { + let idx = string + .split(delimiter) + .take(n as usize) + .fold(0, |len, x| len + x.len() + delimiter.len()) + - delimiter.len(); + 
res.push_str(&string[..idx]) + } else { + let idx = string + .split(delimiter) + .take((-n) as usize) + .fold(string.len(), |len, x| len - x.len() - delimiter.len()) + + delimiter.len(); + res.push_str(&string[idx..]) + }; Some(res) } } diff --git a/datafusion/sqllogictest/test_files/functions.slt b/datafusion/sqllogictest/test_files/functions.slt index e36b72d4a5de..c1c68ebca678 100644 --- a/datafusion/sqllogictest/test_files/functions.slt +++ b/datafusion/sqllogictest/test_files/functions.slt @@ -891,12 +891,12 @@ www.apache query T SELECT substr_index('www.apache.org', '.', -1) ---- -apache.org +org query T SELECT substr_index('www.apache.org', '.', -2) ---- -org +apache.org query T SELECT substr_index('www.apache.org', 'ac', 1) From 5fccd49823a2ff0ca61178935e291739ce4194bc Mon Sep 17 00:00:00 2001 From: cli2 Date: Mon, 20 Nov 2023 17:53:12 +0800 Subject: [PATCH 05/11] fix index bound issue --- .../physical-expr/src/unicode_expressions.rs | 26 ++++++++++++------- .../sqllogictest/test_files/functions.slt | 26 +++++++++++-------- 2 files changed, 32 insertions(+), 20 deletions(-) diff --git a/datafusion/physical-expr/src/unicode_expressions.rs b/datafusion/physical-expr/src/unicode_expressions.rs index 654cddf4650f..7f065a895d29 100644 --- a/datafusion/physical-expr/src/unicode_expressions.rs +++ b/datafusion/physical-expr/src/unicode_expressions.rs @@ -30,6 +30,7 @@ use datafusion_common::{ exec_err, internal_err, DataFusionError, Result, }; use hashbrown::HashMap; +use libc::socket; use std::cmp::{max, Ordering}; use std::process::id; use std::sync::Arc; @@ -475,25 +476,32 @@ pub fn substr_index(args: &[ArrayRef]) -> Result { (Some(string), Some(delimiter), Some(n)) => { let mut res = String::new(); if n == 0 { - Some("".to_string()) + Some("".to_string()); } else { if n > 0 { - let idx = string + let mut idx = string .split(delimiter) .take(n as usize) .fold(0, |len, x| len + x.len() + delimiter.len()) - delimiter.len(); - res.push_str(&string[..idx]) + 
res.push_str(if idx < 0 { &string } else { &string[..idx] }); } else { - let idx = string + let mut idx = (string .split(delimiter) .take((-n) as usize) - .fold(string.len(), |len, x| len - x.len() - delimiter.len()) - + delimiter.len(); - res.push_str(&string[idx..]) - }; - Some(res) + .fold(string.len() as isize, |len, x| { + len - x.len() as isize - delimiter.len() as isize + }) + + delimiter.len() as isize) + as usize; + res.push_str(if idx >= string.len() { + &string + } else { + &string[idx..] + }); + } } + Some(res) } _ => None, }) diff --git a/datafusion/sqllogictest/test_files/functions.slt b/datafusion/sqllogictest/test_files/functions.slt index c1c68ebca678..547e4eddd224 100644 --- a/datafusion/sqllogictest/test_files/functions.slt +++ b/datafusion/sqllogictest/test_files/functions.slt @@ -532,7 +532,7 @@ statement ok drop table simple_struct_test # create aggregate_test_100 table for functions test -statement ok +statement error DataFusion error: IO error: No such file or directory \(os error 2\) CREATE EXTERNAL TABLE aggregate_test_100 ( c1 VARCHAR NOT NULL, c2 TINYINT NOT NULL, @@ -554,22 +554,16 @@ LOCATION '../../testing/data/csv/aggregate_test_100.csv' # sqrt_f32_vs_f64 -query R +query error DataFusion error: Error during planning: table 'datafusion\.public\.aggregate_test_100' not found SELECT avg(sqrt(c11)) FROM aggregate_test_100 ----- -0.658440848589 -query R +query error DataFusion error: Error during planning: table 'datafusion\.public\.aggregate_test_100' not found SELECT avg(CAST(sqrt(c11) AS double)) FROM aggregate_test_100 ----- -0.658440848589 -query R +query error DataFusion error: Error during planning: table 'datafusion\.public\.aggregate_test_100' not found SELECT avg(sqrt(CAST(c11 AS double))) FROM aggregate_test_100 ----- -0.658440848342 -statement ok +statement error DataFusion error: Execution error: Table 'aggregate_test_100' doesn't exist\. 
drop table aggregate_test_100 @@ -908,6 +902,16 @@ SELECT substr_index('www.apache.org', 'ac', -1) ---- he.org +query T +SELECT substr_index('www.apache.org', 'ac', 2) +---- +www.apache.org + +query T +SELECT substr_index('www.apache.org', 'ac', -2) +---- +www.apache.org + query ? SELECT substr_index(NULL, 'ac', 1) ---- From 407cf4b0d67583f562819b8733d55782dd34e446 Mon Sep 17 00:00:00 2001 From: cli2 Date: Mon, 20 Nov 2023 18:11:05 +0800 Subject: [PATCH 06/11] code format --- .../physical-expr/src/unicode_expressions.rs | 51 ++++++++++--------- .../sqllogictest/test_files/functions.slt | 16 ++++-- 2 files changed, 37 insertions(+), 30 deletions(-) diff --git a/datafusion/physical-expr/src/unicode_expressions.rs b/datafusion/physical-expr/src/unicode_expressions.rs index 7f065a895d29..d568973cb8ee 100644 --- a/datafusion/physical-expr/src/unicode_expressions.rs +++ b/datafusion/physical-expr/src/unicode_expressions.rs @@ -30,9 +30,7 @@ use datafusion_common::{ exec_err, internal_err, DataFusionError, Result, }; use hashbrown::HashMap; -use libc::socket; use std::cmp::{max, Ordering}; -use std::process::id; use std::sync::Arc; use unicode_segmentation::UnicodeSegmentation; @@ -475,30 +473,33 @@ pub fn substr_index(args: &[ArrayRef]) -> Result { .map(|((string, delimiter), n)| match (string, delimiter, n) { (Some(string), Some(delimiter), Some(n)) => { let mut res = String::new(); - if n == 0 { - Some("".to_string()); - } else { - if n > 0 { - let mut idx = string - .split(delimiter) - .take(n as usize) - .fold(0, |len, x| len + x.len() + delimiter.len()) - - delimiter.len(); - res.push_str(if idx < 0 { &string } else { &string[..idx] }); - } else { - let mut idx = (string - .split(delimiter) - .take((-n) as usize) - .fold(string.len() as isize, |len, x| { - len - x.len() as isize - delimiter.len() as isize - }) - + delimiter.len() as isize) - as usize; - res.push_str(if idx >= string.len() { - &string + match n { + 0 => { + "".to_string(); + } + _other => { + if n > 
0 { + let idx = string + .split(delimiter) + .take(n as usize) + .fold(0, |len, x| len + x.len() + delimiter.len()) + - delimiter.len(); + res.push_str(if idx >= string.len() { string } else { &string[..idx] }); } else { - &string[idx..] - }); + let idx = (string + .split(delimiter) + .take((-n) as usize) + .fold(string.len() as isize, |len, x| { + len - x.len() as isize - delimiter.len() as isize + }) + + delimiter.len() as isize) + as usize; + res.push_str(if idx >= string.len() { + string + } else { + &string[idx..] + }); + } } } Some(res) diff --git a/datafusion/sqllogictest/test_files/functions.slt b/datafusion/sqllogictest/test_files/functions.slt index 547e4eddd224..d72f1b69a597 100644 --- a/datafusion/sqllogictest/test_files/functions.slt +++ b/datafusion/sqllogictest/test_files/functions.slt @@ -532,7 +532,7 @@ statement ok drop table simple_struct_test # create aggregate_test_100 table for functions test -statement error DataFusion error: IO error: No such file or directory \(os error 2\) +statement ok CREATE EXTERNAL TABLE aggregate_test_100 ( c1 VARCHAR NOT NULL, c2 TINYINT NOT NULL, @@ -554,16 +554,22 @@ LOCATION '../../testing/data/csv/aggregate_test_100.csv' # sqrt_f32_vs_f64 -query error DataFusion error: Error during planning: table 'datafusion\.public\.aggregate_test_100' not found +query R SELECT avg(sqrt(c11)) FROM aggregate_test_100 +---- +0.658440848589 -query error DataFusion error: Error during planning: table 'datafusion\.public\.aggregate_test_100' not found +query R SELECT avg(CAST(sqrt(c11) AS double)) FROM aggregate_test_100 +---- +0.658440848589 -query error DataFusion error: Error during planning: table 'datafusion\.public\.aggregate_test_100' not found +query R SELECT avg(sqrt(CAST(c11 AS double))) FROM aggregate_test_100 +---- +0.658440848342 -statement error DataFusion error: Execution error: Table 'aggregate_test_100' doesn't exist\. 
+statement ok drop table aggregate_test_100 From b296f194b20de9474ae3f9fdf4df54f4af27b50b Mon Sep 17 00:00:00 2001 From: cli2 Date: Mon, 20 Nov 2023 18:13:05 +0800 Subject: [PATCH 07/11] code format --- .../physical-expr/src/unicode_expressions.rs | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/datafusion/physical-expr/src/unicode_expressions.rs b/datafusion/physical-expr/src/unicode_expressions.rs index d568973cb8ee..76acc442560e 100644 --- a/datafusion/physical-expr/src/unicode_expressions.rs +++ b/datafusion/physical-expr/src/unicode_expressions.rs @@ -484,15 +484,18 @@ pub fn substr_index(args: &[ArrayRef]) -> Result { .take(n as usize) .fold(0, |len, x| len + x.len() + delimiter.len()) - delimiter.len(); - res.push_str(if idx >= string.len() { string } else { &string[..idx] }); + res.push_str(if idx >= string.len() { + string + } else { + &string[..idx] + }); } else { - let idx = (string - .split(delimiter) - .take((-n) as usize) - .fold(string.len() as isize, |len, x| { + let idx = (string.split(delimiter).take((-n) as usize).fold( + string.len() as isize, + |len, x| { len - x.len() as isize - delimiter.len() as isize - }) - + delimiter.len() as isize) + }, + ) + delimiter.len() as isize) as usize; res.push_str(if idx >= string.len() { string From b5286fa9bb332bbffcdda451c9dae6ecd9a9e917 Mon Sep 17 00:00:00 2001 From: cli2 Date: Tue, 21 Nov 2023 15:06:12 +0800 Subject: [PATCH 08/11] add args len check --- datafusion/physical-expr/src/unicode_expressions.rs | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/datafusion/physical-expr/src/unicode_expressions.rs b/datafusion/physical-expr/src/unicode_expressions.rs index 76acc442560e..29328d72ce17 100644 --- a/datafusion/physical-expr/src/unicode_expressions.rs +++ b/datafusion/physical-expr/src/unicode_expressions.rs @@ -462,6 +462,13 @@ pub fn translate(args: &[ArrayRef]) -> Result { /// SUBSTRING_INDEX('www.apache.org', '.', -2) = apache.org /// 
SUBSTRING_INDEX('www.apache.org', '.', -1) = org pub fn substr_index(args: &[ArrayRef]) -> Result { + if args.len() != 3 { + return Err(DataFusionError::Internal(format!( + "substr_index function requires three arguments, got {}", + args.len() + ))); + } + let string_array = as_generic_string_array::(&args[0])?; let delimiter_array = as_generic_string_array::(&args[1])?; let count_array = as_int64_array(&args[2])?; From c0ceb377fd4bd7f488045ac47e5dbc5122364682 Mon Sep 17 00:00:00 2001 From: cli2 Date: Tue, 21 Nov 2023 15:09:59 +0800 Subject: [PATCH 09/11] add sql tests --- datafusion/sqllogictest/test_files/functions.slt | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/datafusion/sqllogictest/test_files/functions.slt b/datafusion/sqllogictest/test_files/functions.slt index d72f1b69a597..91072a49cd46 100644 --- a/datafusion/sqllogictest/test_files/functions.slt +++ b/datafusion/sqllogictest/test_files/functions.slt @@ -933,6 +933,21 @@ SELECT substr_index('www.apache.org', 'ac', NULL) ---- NULL +query T +SELECT substr_index('', 'ac', 1) +---- +(empty) + +query T +SELECT substr_index('www.apache.org', '', 1) +---- +(empty) + +query T +SELECT substr_index('www.apache.org', 'ac', 0) +---- +(empty) + query ? 
SELECT substr_index(NULL, NULL, NULL) ---- From ec5f723e50670f3c3f5e6c214ccc908bdfa7259c Mon Sep 17 00:00:00 2001 From: cli2 Date: Wed, 22 Nov 2023 15:39:19 +0800 Subject: [PATCH 10/11] code format --- datafusion/physical-expr/src/unicode_expressions.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/datafusion/physical-expr/src/unicode_expressions.rs b/datafusion/physical-expr/src/unicode_expressions.rs index 29328d72ce17..f27b3c157741 100644 --- a/datafusion/physical-expr/src/unicode_expressions.rs +++ b/datafusion/physical-expr/src/unicode_expressions.rs @@ -463,10 +463,10 @@ pub fn translate(args: &[ArrayRef]) -> Result { /// SUBSTRING_INDEX('www.apache.org', '.', -1) = org pub fn substr_index(args: &[ArrayRef]) -> Result { if args.len() != 3 { - return Err(DataFusionError::Internal(format!( - "substr_index function requires three arguments, got {}", + return internal_err!( + "substr_index was called with {} arguments. It requires 3.", args.len() - ))); + ); } let string_array = as_generic_string_array::(&args[0])?; From 5a349842e6a3cbbcef2a08f94c9fe5f5a142809d Mon Sep 17 00:00:00 2001 From: cli2 Date: Thu, 23 Nov 2023 17:01:41 +0800 Subject: [PATCH 11/11] doc format --- docs/source/user-guide/sql/scalar_functions.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index c88b4edbd5f4..e7ebbc9f1fe7 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -1158,7 +1158,7 @@ levenshtein(str1, str2) Returns the substring from str before count occurrences of the delimiter delim. If count is positive, everything to the left of the final delimiter (counting from the left) is returned. If count is negative, everything to the right of the final delimiter (counting from the right) is returned. 
-For example, `substr_index('www.apache.org', '.', 1) = www`, `substr_index('www.apache.org', '.', -1) = apache.org` +For example, `substr_index('www.apache.org', '.', 1) = www`, `substr_index('www.apache.org', '.', -1) = org` ``` substr_index(str, delim, count)