Skip to content

Commit

Permalink
Implement native support StringView for overlay (apache#11968)
Browse files Browse the repository at this point in the history
* Implement native support StringView for overlay

Signed-off-by: Chojan Shang <[email protected]>

* Re-write impl of overlay

Signed-off-by: Chojan Shang <[email protected]>

* Minor update

Signed-off-by: Chojan Shang <[email protected]>

* Add more tests

Signed-off-by: Chojan Shang <[email protected]>

---------

Signed-off-by: Chojan Shang <[email protected]>
  • Loading branch information
PsiACE authored and samuelcolvin committed Aug 15, 2024
1 parent 65159e1 commit dff44ea
Show file tree
Hide file tree
Showing 3 changed files with 153 additions and 67 deletions.
182 changes: 118 additions & 64 deletions datafusion/functions/src/string/overlay.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ use std::sync::Arc;
use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait};
use arrow::datatypes::DataType;

use datafusion_common::cast::{as_generic_string_array, as_int64_array};
use datafusion_common::cast::{
as_generic_string_array, as_int64_array, as_string_view_array,
};
use datafusion_common::{exec_err, Result};
use datafusion_expr::TypeSignature::*;
use datafusion_expr::{ColumnarValue, Volatility};
Expand All @@ -46,8 +48,10 @@ impl OverlayFunc {
Self {
signature: Signature::one_of(
vec![
Exact(vec![Utf8View, Utf8View, Int64, Int64]),
Exact(vec![Utf8, Utf8, Int64, Int64]),
Exact(vec![LargeUtf8, LargeUtf8, Int64, Int64]),
Exact(vec![Utf8View, Utf8View, Int64]),
Exact(vec![Utf8, Utf8, Int64]),
Exact(vec![LargeUtf8, LargeUtf8, Int64]),
],
Expand Down Expand Up @@ -76,54 +80,107 @@ impl ScalarUDFImpl for OverlayFunc {

fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
match args[0].data_type() {
DataType::Utf8 => make_scalar_function(overlay::<i32>, vec![])(args),
DataType::Utf8View | DataType::Utf8 => {
make_scalar_function(overlay::<i32>, vec![])(args)
}
DataType::LargeUtf8 => make_scalar_function(overlay::<i64>, vec![])(args),
other => exec_err!("Unsupported data type {other:?} for function overlay"),
}
}
}

/// Overlays `characters` onto `string` and returns the resulting string.
///
/// `start_pos` is the 1-based position (counted in characters) at which
/// `characters` is placed, and `replace_len` is the number of characters of
/// `string` that the placed text stands in for. All positions are counted in
/// Unicode scalar values and converted to byte offsets before slicing, so
/// multi-byte UTF-8 input is never cut in the middle of a character.
fn overlay_chars(
    string: &str,
    characters: &str,
    start_pos: i64,
    replace_len: i64,
) -> String {
    let string_len = string.chars().count() as i64;

    // Translates a 0-based character index into the byte offset where that
    // character starts. Negative indices clamp to the start of the string and
    // indices past the last character clamp to the end of the string.
    let byte_offset = |char_idx: i64| -> usize {
        string
            .char_indices()
            .nth(char_idx.max(0) as usize)
            .map_or(string.len(), |(offset, _)| offset)
    };

    let mut res = String::with_capacity(string.len() + characters.len());

    // SQL positions start at 1 while string indices start at 0. The prefix is
    // only kept when the start position falls inside the string.
    if start_pos > 1 && start_pos - 1 < string_len {
        res.push_str(&string[..byte_offset(start_pos - 1)]);
    }
    res.push_str(characters);

    // 0-based character index of the first character kept after the replaced
    // range. Saturating arithmetic keeps extreme positions/lengths from
    // overflowing; when the index is at or past the end nothing is appended.
    let end = start_pos.saturating_add(replace_len).saturating_sub(1);
    if end < string_len {
        res.push_str(&string[byte_offset(end)..]);
    }
    res
}

/// Applies the overlay operation row by row and collects the rows into a
/// `Result<GenericStringArray<T>>`.
///
/// Each argument must provide an `iter()` that yields `Option` values. A row
/// is `None` (NULL) in the output whenever any of its inputs is `None`.
/// `T` is resolved at the call site.
macro_rules! process_overlay {
    // For the three-argument case: the replaced length is the character
    // count of the placed string.
    ($string_array:expr, $characters_array:expr, $pos_num:expr) => {{
        $string_array
            .iter()
            .zip($characters_array.iter())
            .zip($pos_num.iter())
            .map(|((string, characters), start_pos)| {
                match (string, characters, start_pos) {
                    (Some(string), Some(characters), Some(start_pos)) => {
                        let replace_len = characters.chars().count() as i64;
                        Ok(Some(overlay_chars(
                            string,
                            characters,
                            start_pos,
                            replace_len,
                        )))
                    }
                    _ => Ok(None),
                }
            })
            .collect::<Result<GenericStringArray<T>>>()
    }};

    // For the four-argument case: the replaced length is the explicit length,
    // capped at the character count of the source string.
    ($string_array:expr, $characters_array:expr, $pos_num:expr, $len_num:expr) => {{
        $string_array
            .iter()
            .zip($characters_array.iter())
            .zip($pos_num.iter())
            .zip($len_num.iter())
            .map(|(((string, characters), start_pos), len)| {
                match (string, characters, start_pos, len) {
                    (Some(string), Some(characters), Some(start_pos), Some(len)) => {
                        let replace_len = len.min(string.chars().count() as i64);
                        Ok(Some(overlay_chars(
                            string,
                            characters,
                            start_pos,
                            replace_len,
                        )))
                    }
                    _ => Ok(None),
                }
            })
            .collect::<Result<GenericStringArray<T>>>()
    }};
}

/// OVERLAY(string1 PLACING string2 FROM integer FOR integer2)
/// Replaces a substring of string1 with string2, starting at the 1-based position given by `integer`.
/// pgsql overlay('Txxxxas' placing 'hom' from 2 for 4) → Thomas
/// overlay('Txxxxas' placing 'hom' from 2) -> Thomxas; without the FOR option, the length of string2 is used instead
pub fn overlay<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
fn overlay<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
let use_string_view = args[0].data_type() == &DataType::Utf8View;
if use_string_view {
string_view_overlay::<T>(args)
} else {
string_overlay::<T>(args)
}
}

pub fn string_overlay<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
match args.len() {
3 => {
let string_array = as_generic_string_array::<T>(&args[0])?;
let characters_array = as_generic_string_array::<T>(&args[1])?;
let pos_num = as_int64_array(&args[2])?;

let result = string_array
.iter()
.zip(characters_array.iter())
.zip(pos_num.iter())
.map(|((string, characters), start_pos)| {
match (string, characters, start_pos) {
(Some(string), Some(characters), Some(start_pos)) => {
let string_len = string.chars().count();
let characters_len = characters.chars().count();
let replace_len = characters_len as i64;
let mut res =
String::with_capacity(string_len.max(characters_len));

//as sql replace index start from 1 while string index start from 0
if start_pos > 1 && start_pos - 1 < string_len as i64 {
let start = (start_pos - 1) as usize;
res.push_str(&string[..start]);
}
res.push_str(characters);
// if start + replace_len - 1 >= string_length, just to string end
if start_pos + replace_len - 1 < string_len as i64 {
let end = (start_pos + replace_len - 1) as usize;
res.push_str(&string[end..]);
}
Ok(Some(res))
}
_ => Ok(None),
}
})
.collect::<Result<GenericStringArray<T>>>()?;
let result = process_overlay!(string_array, characters_array, pos_num)?;
Ok(Arc::new(result) as ArrayRef)
}
4 => {
Expand All @@ -132,37 +189,34 @@ pub fn overlay<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
let pos_num = as_int64_array(&args[2])?;
let len_num = as_int64_array(&args[3])?;

let result = string_array
.iter()
.zip(characters_array.iter())
.zip(pos_num.iter())
.zip(len_num.iter())
.map(|(((string, characters), start_pos), len)| {
match (string, characters, start_pos, len) {
(Some(string), Some(characters), Some(start_pos), Some(len)) => {
let string_len = string.chars().count();
let characters_len = characters.chars().count();
let replace_len = len.min(string_len as i64);
let mut res =
String::with_capacity(string_len.max(characters_len));

//as sql replace index start from 1 while string index start from 0
if start_pos > 1 && start_pos - 1 < string_len as i64 {
let start = (start_pos - 1) as usize;
res.push_str(&string[..start]);
}
res.push_str(characters);
// if start + replace_len - 1 >= string_length, just to string end
if start_pos + replace_len - 1 < string_len as i64 {
let end = (start_pos + replace_len - 1) as usize;
res.push_str(&string[end..]);
}
Ok(Some(res))
}
_ => Ok(None),
}
})
.collect::<Result<GenericStringArray<T>>>()?;
let result =
process_overlay!(string_array, characters_array, pos_num, len_num)?;
Ok(Arc::new(result) as ArrayRef)
}
other => {
exec_err!("overlay was called with {other} arguments. It requires 3 or 4.")
}
}
}

pub fn string_view_overlay<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
match args.len() {
3 => {
let string_array = as_string_view_array(&args[0])?;
let characters_array = as_string_view_array(&args[1])?;
let pos_num = as_int64_array(&args[2])?;

let result = process_overlay!(string_array, characters_array, pos_num)?;
Ok(Arc::new(result) as ArrayRef)
}
4 => {
let string_array = as_string_view_array(&args[0])?;
let characters_array = as_string_view_array(&args[1])?;
let pos_num = as_int64_array(&args[2])?;
let len_num = as_int64_array(&args[3])?;

let result =
process_overlay!(string_array, characters_array, pos_num, len_num)?;
Ok(Arc::new(result) as ArrayRef)
}
other => {
Expand Down
27 changes: 26 additions & 1 deletion datafusion/sqllogictest/test_files/functions.slt
Original file line number Diff line number Diff line change
Expand Up @@ -925,7 +925,7 @@ SELECT products.* REPLACE (price*2 AS price, product_id+1000 AS product_id) FROM
1003 OldBrand Product 3 79.98
1004 OldBrand Product 4 99.98

#overlay tests
# overlay tests
statement ok
CREATE TABLE over_test(
str TEXT,
Expand Down Expand Up @@ -967,6 +967,31 @@ NULL
Thomxas
NULL

# overlay tests with utf8view
query T
SELECT overlay(arrow_cast(str, 'Utf8View') placing arrow_cast(characters, 'Utf8View') from pos for len) from over_test
----
abc
qwertyasdfg
ijkz
Thomas
NULL
NULL
NULL
NULL

query T
SELECT overlay(arrow_cast(str, 'Utf8View') placing arrow_cast(characters, 'Utf8View') from pos) from over_test
----
abc
qwertyasdfg
ijk
Thomxas
NULL
NULL
Thomxas
NULL

query I
SELECT levenshtein('kitten', 'sitting')
----
Expand Down
11 changes: 9 additions & 2 deletions datafusion/sqllogictest/test_files/string_view.slt
Original file line number Diff line number Diff line change
Expand Up @@ -818,16 +818,23 @@ logical_plan
02)--TableScan: test projection=[column1_utf8view]

## Ensure no casts for OVERLAY
## TODO file ticket
query TT
EXPLAIN SELECT
OVERLAY(column1_utf8view PLACING 'foo' FROM 2 ) as c1
FROM test;
----
logical_plan
01)Projection: overlay(CAST(test.column1_utf8view AS Utf8), Utf8("foo"), Int64(2)) AS c1
01)Projection: overlay(test.column1_utf8view, Utf8View("foo"), Int64(2)) AS c1
02)--TableScan: test projection=[column1_utf8view]

query T
SELECT OVERLAY(column1_utf8view PLACING 'foo' FROM 2 ) as c1 FROM test;
----
Afooew
Xfoogpeng
Rfooael
NULL

## Ensure no casts for REGEXP_LIKE
query TT
EXPLAIN SELECT
Expand Down

0 comments on commit dff44ea

Please sign in to comment.