Skip to content

Commit

Permalink
Update SPLIT_PART scalar function to support Utf8View
Browse files Browse the repository at this point in the history
  • Loading branch information
Lordworms committed Aug 14, 2024
1 parent 69c99a7 commit 8af587e
Show file tree
Hide file tree
Showing 3 changed files with 128 additions and 37 deletions.
128 changes: 93 additions & 35 deletions datafusion/functions/src/string/split_part.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ use std::sync::Arc;
use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait};
use arrow::datatypes::DataType;

use datafusion_common::cast::{as_generic_string_array, as_int64_array};
use datafusion_common::cast::{
as_generic_string_array, as_int64_array, as_string_view_array,
};
use datafusion_common::{exec_err, Result};
use datafusion_expr::TypeSignature::*;
use datafusion_expr::{ColumnarValue, Volatility};
Expand All @@ -46,7 +48,12 @@ impl SplitPartFunc {
Self {
signature: Signature::one_of(
vec![
Exact(vec![Utf8View, Utf8View, Int64]),
Exact(vec![Utf8View, Utf8, Int64]),
Exact(vec![Utf8View, LargeUtf8, Int64]),
Exact(vec![Utf8, Utf8View, Int64]),
Exact(vec![Utf8, Utf8, Int64]),
Exact(vec![LargeUtf8, Utf8View, Int64]),
Exact(vec![LargeUtf8, Utf8, Int64]),
Exact(vec![Utf8, LargeUtf8, Int64]),
Exact(vec![LargeUtf8, LargeUtf8, Int64]),
Expand Down Expand Up @@ -75,50 +82,101 @@ impl ScalarUDFImpl for SplitPartFunc {
}

/// Dispatches `split_part` on the Arrow data types of its arguments and
/// selects the offset-width type parameters passed to `split_part`.
///
/// `args[0]` is the string to split and `args[1]` is the delimiter; an
/// unsupported type combination produces an execution error.
fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
// NOTE(review): two dispatch tables follow back to back and the braces of
// the first one are never closed. This looks like a diff rendered without
// +/- markers: the single-type `match` is presumably the pre-change code and
// the tuple `match` the post-change code — confirm against the real file.
//
// Single-type dispatch: only the first argument's type is inspected.
match args[0].data_type() {
DataType::Utf8 => make_scalar_function(split_part::<i32>, vec![])(args),
DataType::LargeUtf8 => make_scalar_function(split_part::<i64>, vec![])(args),
other => {
exec_err!("Unsupported data type {other:?} for function split_part")
// Two-type dispatch: the pair (string type, delimiter type) selects the
// `<StringLen, DelimiterLen>` parameters of `split_part`.
match (args[0].data_type(), args[1].data_type()) {
// Utf8 and Utf8View are grouped together and both map to `i32`.
(
DataType::Utf8 | DataType::Utf8View,
DataType::Utf8 | DataType::Utf8View,
) => make_scalar_function(split_part::<i32, i32>, vec![])(args),
(DataType::LargeUtf8, DataType::LargeUtf8) => {
make_scalar_function(split_part::<i64, i64>, vec![])(args)
}
// The `_` wildcard accepts ANY first-argument type not matched above, not
// only Utf8/Utf8View, so a non-string first argument paired with a
// LargeUtf8 delimiter is routed here rather than to the error arm below.
(_, DataType::LargeUtf8) => {
make_scalar_function(split_part::<i32, i64>, vec![])(args)
}
// Same caveat as above, mirrored: any delimiter type is accepted here.
(DataType::LargeUtf8, _) => {
make_scalar_function(split_part::<i64, i32>, vec![])(args)
}
// Only reached when neither argument is LargeUtf8 and the pair is not
// made of Utf8/Utf8View.
(first_type, second_type) => exec_err!(
"unsupported first type {} and second type {} for split_part function",
first_type,
second_type
),
}
}
}

/// Expands to the row-wise `split_part` kernel shared by every combination of
/// string and delimiter array types.
///
/// Arguments:
/// * `$string_array` – array of strings to split; must provide `.iter()`
///   yielding `Option<&str>`.
/// * `$delimiter_array` – array of delimiters; same iteration contract.
/// * `$n_array` – array of 1-based field positions; `.iter()` must yield
///   `Option<i64>`. Negative positions count from the last field.
///
/// For each row the expansion produces:
/// * `NULL` if any of the three inputs is `NULL`,
/// * an execution error if the position is zero,
/// * the requested field, or the empty string when the position is out of
///   range.
///
/// The expansion evaluates to `Result<ArrayRef>` and uses `?`, so it must be
/// expanded inside a function returning a compatible `Result`. It is not
/// self-contained: `StringLen` (the output offset width), `exec_err!`,
/// `Result`, `GenericStringArray`, `Arc` and `ArrayRef` are resolved at the
/// expansion site.
macro_rules! process_split_part {
    ($string_array: expr, $delimiter_array: expr, $n_array: expr) => {{
        let result = $string_array
            .iter()
            .zip($delimiter_array.iter())
            .zip($n_array.iter())
            .map(|((string, delimiter), n)| match (string, delimiter, n) {
                (Some(string), Some(delimiter), Some(n)) => {
                    let field = match n.cmp(&0) {
                        std::cmp::Ordering::Equal => {
                            return exec_err!("field position must not be zero");
                        }
                        // Counting from the front: `nth` stops scanning as soon
                        // as the field is found, so no per-row `Vec` is built.
                        // `try_from` (instead of `as usize`) keeps positions
                        // that do not fit in `usize` out of range rather than
                        // truncating them onto a valid index.
                        std::cmp::Ordering::Greater => usize::try_from(n - 1)
                            .ok()
                            .and_then(|index| string.split(delimiter).nth(index)),
                        // Counting from the back needs the total field count.
                        // The forward split is kept (rather than `rsplit`) so
                        // results for self-overlapping delimiters stay the same.
                        std::cmp::Ordering::Less => {
                            let split_string: Vec<&str> = string.split(delimiter).collect();
                            // A negative sum means the position lies before the
                            // first field; `try_from` rejects it explicitly.
                            usize::try_from(split_string.len() as i64 + n)
                                .ok()
                                .and_then(|index| split_string.get(index).copied())
                        }
                    };
                    // Out-of-range positions yield the empty string, not NULL.
                    Ok(Some(field.unwrap_or("")))
                }
                _ => Ok(None),
            })
            .collect::<Result<GenericStringArray<StringLen>>>()?;
        Ok(Arc::new(result) as ArrayRef)
    }};
}

/// Splits string at occurrences of delimiter and returns the n'th field (counting from one).
/// split_part('abc~@~def~@~ghi', '~@~', 2) = 'def'
///
/// `args[0]` is the string array, `args[1]` the delimiter array and `args[2]`
/// the Int64 position array. In the two-parameter form, `StringLen` and
/// `DelimiterLen` are the offset widths used when the respective argument is
/// downcast with `as_generic_string_array`; Utf8View arguments are downcast
/// with `as_string_view_array` instead.
// NOTE(review): two signatures and two bodies are interleaved below and the
// braces do not balance. This looks like a diff rendered without +/- markers:
// the version generic over a single `T` is presumably the pre-change code and
// the `<StringLen, DelimiterLen>` version the post-change code — confirm
// against the real file.
fn split_part<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
let string_array = as_generic_string_array::<T>(&args[0])?;
let delimiter_array = as_generic_string_array::<T>(&args[1])?;
fn split_part<StringLen: OffsetSizeTrait, DelimiterLen: OffsetSizeTrait>(
args: &[ArrayRef],
) -> Result<ArrayRef> {
let n_array = as_int64_array(&args[2])?;
// Inline kernel using the single type parameter `T`; the same logic appears
// in the `process_split_part!` macro.
let result = string_array
.iter()
.zip(delimiter_array.iter())
.zip(n_array.iter())
.map(|((string, delimiter), n)| match (string, delimiter, n) {
(Some(string), Some(delimiter), Some(n)) => {
let split_string: Vec<&str> = string.split(delimiter).collect();
let len = split_string.len();

// Negative positions count from the end; zero is rejected.
let index = match n.cmp(&0) {
std::cmp::Ordering::Less => len as i64 + n,
std::cmp::Ordering::Equal => {
return exec_err!("field position must not be zero");
}
std::cmp::Ordering::Greater => n - 1,
} as usize;

// Out-of-range positions yield the empty string rather than NULL.
if index < len {
Ok(Some(split_string[index]))
} else {
Ok(Some(""))
// Type-directed downcasting: each arm picks the concrete array types and then
// expands `process_split_part!`, which collects into
// `GenericStringArray<StringLen>` whatever the input array types are.
match (args[0].data_type(), args[1].data_type()) {
// String argument is a view array; the delimiter may be a view or an
// offset-based string array.
(DataType::Utf8View, _) => {
let string_array = as_string_view_array(&args[0])?;
match args[1].data_type() {
DataType::Utf8View => {
let delimiter_array = as_string_view_array(&args[1])?;
process_split_part!(string_array, delimiter_array, n_array)
}
_ => {
let delimiter_array =
as_generic_string_array::<DelimiterLen>(&args[1])?;
process_split_part!(string_array, delimiter_array, n_array)
}
}
_ => Ok(None),
})
.collect::<Result<GenericStringArray<T>>>()?;

Ok(Arc::new(result) as ArrayRef)
}
// Delimiter is a view array. The outer arm above already captured every case
// where the string argument is Utf8View, so the inner `DataType::Utf8View`
// arm below cannot be reached; only the `_` arm is live.
(_, DataType::Utf8View) => {
let delimiter_array = as_string_view_array(&args[1])?;
match args[0].data_type() {
DataType::Utf8View => {
let string_array = as_string_view_array(&args[0])?;
process_split_part!(string_array, delimiter_array, n_array)
}
_ => {
let string_array = as_generic_string_array::<StringLen>(&args[0])?;
process_split_part!(string_array, delimiter_array, n_array)
}
}
}
// Neither argument is a view array: both are downcast as offset-based
// string arrays with their own offset widths.
(_, _) => {
let string_array = as_generic_string_array::<StringLen>(&args[0])?;
let delimiter_array = as_generic_string_array::<DelimiterLen>(&args[1])?;
process_split_part!(string_array, delimiter_array, n_array)
}
}
}

#[cfg(test)]
Expand Down
32 changes: 32 additions & 0 deletions datafusion/sqllogictest/test_files/functions.slt
Original file line number Diff line number Diff line change
Expand Up @@ -788,6 +788,38 @@ SELECT split_part(arrow_cast('foo_bar', 'Dictionary(Int32, Utf8)'), '_', 2)
----
bar

# test largeutf8, utf8view for split_part
query T
SELECT split_part(arrow_cast('large_apple_large_orange_large_banana', 'LargeUtf8'), '_', 3)
----
large

query T
SELECT split_part(arrow_cast('view_apple_view_orange_view_banana', 'Utf8View'), '_', 3);
----
view

query T
SELECT split_part('test_large_split_large_case', arrow_cast('_large', 'LargeUtf8'), 2)
----
_split

query T
SELECT split_part(arrow_cast('huge_large_apple_large_orange_large_banana', 'LargeUtf8'), arrow_cast('_', 'Utf8View'), 2)
----
large

query T
SELECT split_part(arrow_cast('view_apple_view_large_banana', 'Utf8View'), arrow_cast('_large', 'LargeUtf8'), 2)
----
_banana

query T
SELECT split_part(NULL, '_', 2)
----
NULL


query B
SELECT starts_with('foobar', 'foo')
----
Expand Down
5 changes: 3 additions & 2 deletions datafusion/sqllogictest/test_files/string_view.slt
Original file line number Diff line number Diff line change
Expand Up @@ -936,11 +936,12 @@ logical_plan
## TODO file ticket
query TT
EXPLAIN SELECT
SPLIT_PART(column1_utf8view, 'f', 1) as c
SPLIT_PART(column1_utf8view, 'f', 1) as c1,
SPLIT_PART('testtesttest',column1_utf8view, 1) as c2
FROM test;
----
logical_plan
01)Projection: split_part(CAST(test.column1_utf8view AS Utf8), Utf8("f"), Int64(1)) AS c
01)Projection: split_part(test.column1_utf8view, Utf8("f"), Int64(1)) AS c1, split_part(Utf8("testtesttest"), test.column1_utf8view, Int64(1)) AS c2
02)--TableScan: test projection=[column1_utf8view]

## Ensure no casts for STRPOS
Expand Down

0 comments on commit 8af587e

Please sign in to comment.