Skip to content

Commit

Permalink
feat: support Utf8View type in starts_with function (#11787)
Browse files Browse the repository at this point in the history
* feat: support `Utf8View` for `starts_with`

* style: clippy

* simplify string view handling

* fix: allow utf8 and largeutf8 to be cast into utf8view

* fix: fix test

* Apply suggestions from code review

Co-authored-by: Yongting You <[email protected]>
Co-authored-by: Andrew Lamb <[email protected]>

* style: fix format

* feat: add addiontal tests

* tests: improve tests

* fix: fix null case

* tests: one more null test

* Test comments and execution tests

---------

Co-authored-by: Yongting You <[email protected]>
Co-authored-by: Andrew Lamb <[email protected]>
  • Loading branch information
3 people committed Aug 6, 2024
1 parent 1c98e6e commit 3d76aa2
Show file tree
Hide file tree
Showing 4 changed files with 158 additions and 21 deletions.
1 change: 1 addition & 0 deletions datafusion/expr/src/expr_schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ impl ExprSchemable for Expr {
.iter()
.map(|e| e.get_type(schema))
.collect::<Result<Vec<_>>>()?;

// verify that function is invoked with correct number and type of arguments as defined in `TypeSignature`
data_types_with_scalar_udf(&arg_data_types, func).map_err(|err| {
plan_datafusion_err!(
Expand Down
16 changes: 16 additions & 0 deletions datafusion/expr/src/type_coercion/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -583,6 +583,10 @@ fn coerced_from<'a>(
(Interval(_), _) if matches!(type_from, Utf8 | LargeUtf8) => {
Some(type_into.clone())
}
// We can go into a Utf8View from a Utf8 or LargeUtf8
(Utf8View, _) if matches!(type_from, Utf8 | LargeUtf8 | Null) => {
Some(type_into.clone())
}
// Any type can be coerced into strings
(Utf8 | LargeUtf8, _) => Some(type_into.clone()),
(Null, _) if can_cast_types(type_from, type_into) => Some(type_into.clone()),
Expand Down Expand Up @@ -646,6 +650,18 @@ mod tests {
use super::*;
use arrow::datatypes::Field;

#[test]
fn test_string_conversion() {
let cases = vec![
(DataType::Utf8View, DataType::Utf8, true),
(DataType::Utf8View, DataType::LargeUtf8, true),
];

for case in cases {
assert_eq!(can_coerce_from(&case.0, &case.1), case.2);
}
}

#[test]
fn test_maybe_data_types() {
// this vec contains: arg1, arg2, expected result
Expand Down
92 changes: 72 additions & 20 deletions datafusion/functions/src/string/starts_with.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@
use std::any::Any;
use std::sync::Arc;

use arrow::array::{ArrayRef, OffsetSizeTrait};
use arrow::array::ArrayRef;
use arrow::datatypes::DataType;

use datafusion_common::{cast::as_generic_string_array, internal_err, Result};
use datafusion_common::{internal_err, Result};
use datafusion_expr::ColumnarValue;
use datafusion_expr::TypeSignature::*;
use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
Expand All @@ -30,12 +30,8 @@ use crate::utils::make_scalar_function;

/// Returns true if string starts with prefix.
/// starts_with('alphabet', 'alph') = 't'
pub fn starts_with<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
let left = as_generic_string_array::<T>(&args[0])?;
let right = as_generic_string_array::<T>(&args[1])?;

let result = arrow::compute::kernels::comparison::starts_with(left, right)?;

pub fn starts_with(args: &[ArrayRef]) -> Result<ArrayRef> {
let result = arrow::compute::kernels::comparison::starts_with(&args[0], &args[1])?;
Ok(Arc::new(result) as ArrayRef)
}

Expand All @@ -52,14 +48,15 @@ impl Default for StartsWithFunc {

impl StartsWithFunc {
pub fn new() -> Self {
use DataType::*;
Self {
signature: Signature::one_of(
vec![
Exact(vec![Utf8, Utf8]),
Exact(vec![Utf8, LargeUtf8]),
Exact(vec![LargeUtf8, Utf8]),
Exact(vec![LargeUtf8, LargeUtf8]),
// Planner attempts coercion to the target type starting with the most preferred candidate.
// For example, given input `(Utf8View, Utf8)`, it first tries coercing to `(Utf8View, Utf8View)`.
// If that fails, it proceeds to `(Utf8, Utf8)`.
Exact(vec![DataType::Utf8View, DataType::Utf8View]),
Exact(vec![DataType::Utf8, DataType::Utf8]),
Exact(vec![DataType::LargeUtf8, DataType::LargeUtf8]),
],
Volatility::Immutable,
),
Expand All @@ -81,18 +78,73 @@ impl ScalarUDFImpl for StartsWithFunc {
}

fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
use DataType::*;

Ok(Boolean)
Ok(DataType::Boolean)
}

fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
match args[0].data_type() {
DataType::Utf8 => make_scalar_function(starts_with::<i32>, vec![])(args),
DataType::LargeUtf8 => {
return make_scalar_function(starts_with::<i64>, vec![])(args);
DataType::Utf8View | DataType::Utf8 | DataType::LargeUtf8 => {
make_scalar_function(starts_with, vec![])(args)
}
_ => internal_err!("Unsupported data type"),
_ => internal_err!("Unsupported data types for starts_with. Expected Utf8, LargeUtf8 or Utf8View")?,
}
}
}

#[cfg(test)]
mod tests {
use crate::utils::test::test_function;
use arrow::array::{Array, BooleanArray};
use arrow::datatypes::DataType::Boolean;
use datafusion_common::{Result, ScalarValue};
use datafusion_expr::{ColumnarValue, ScalarUDFImpl};

use super::*;

#[test]
fn test_functions() -> Result<()> {
// Generate test cases for starts_with
let test_cases = vec![
(Some("alphabet"), Some("alph"), Some(true)),
(Some("alphabet"), Some("bet"), Some(false)),
(
Some("somewhat large string"),
Some("somewhat large"),
Some(true),
),
(Some("somewhat large string"), Some("large"), Some(false)),
]
.into_iter()
.flat_map(|(a, b, c)| {
let utf_8_args = vec![
ColumnarValue::Scalar(ScalarValue::Utf8(a.map(|s| s.to_string()))),
ColumnarValue::Scalar(ScalarValue::Utf8(b.map(|s| s.to_string()))),
];

let large_utf_8_args = vec![
ColumnarValue::Scalar(ScalarValue::LargeUtf8(a.map(|s| s.to_string()))),
ColumnarValue::Scalar(ScalarValue::LargeUtf8(b.map(|s| s.to_string()))),
];

let utf_8_view_args = vec![
ColumnarValue::Scalar(ScalarValue::Utf8View(a.map(|s| s.to_string()))),
ColumnarValue::Scalar(ScalarValue::Utf8View(b.map(|s| s.to_string()))),
];

vec![(utf_8_args, c), (large_utf_8_args, c), (utf_8_view_args, c)]
});

for (args, expected) in test_cases {
test_function!(
StartsWithFunc::new(),
&args,
Ok(expected),
bool,
Boolean,
BooleanArray
);
}

Ok(())
}
}
70 changes: 69 additions & 1 deletion datafusion/sqllogictest/test_files/string_view.slt
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,75 @@ logical_plan
01)Aggregate: groupBy=[[]], aggr=[[count(DISTINCT test.column1_utf8), count(DISTINCT test.column1_utf8view), count(DISTINCT test.column1_dict)]]
02)--TableScan: test projection=[column1_utf8, column1_utf8view, column1_dict]

### `STARTS_WITH`

# Test STARTS_WITH with utf8view against utf8view, utf8, and largeutf8
# (should be no casts)
query TT
EXPLAIN SELECT
STARTS_WITH(column1_utf8view, column2_utf8view) as c1,
STARTS_WITH(column1_utf8view, column2_utf8) as c2,
STARTS_WITH(column1_utf8view, column2_large_utf8) as c3
FROM test;
----
logical_plan
01)Projection: starts_with(test.column1_utf8view, test.column2_utf8view) AS c1, starts_with(test.column1_utf8view, CAST(test.column2_utf8 AS Utf8View)) AS c2, starts_with(test.column1_utf8view, CAST(test.column2_large_utf8 AS Utf8View)) AS c3
02)--TableScan: test projection=[column2_utf8, column2_large_utf8, column1_utf8view, column2_utf8view]

query BBB
SELECT
STARTS_WITH(column1_utf8view, column2_utf8view) as c1,
STARTS_WITH(column1_utf8view, column2_utf8) as c2,
STARTS_WITH(column1_utf8view, column2_large_utf8) as c3
FROM test;
----
false false false
true true true
true true true
NULL NULL NULL

# Test STARTS_WITH with utf8 against utf8view, utf8, and largeutf8
# Should work, but will have to cast to common types
# should cast utf8 -> utf8view and largeutf8 -> utf8view
query TT
EXPLAIN SELECT
STARTS_WITH(column1_utf8, column2_utf8view) as c1,
STARTS_WITH(column1_utf8, column2_utf8) as c3,
STARTS_WITH(column1_utf8, column2_large_utf8) as c4
FROM test;
----
logical_plan
01)Projection: starts_with(__common_expr_1, test.column2_utf8view) AS c1, starts_with(test.column1_utf8, test.column2_utf8) AS c3, starts_with(__common_expr_1, CAST(test.column2_large_utf8 AS Utf8View)) AS c4
02)--Projection: CAST(test.column1_utf8 AS Utf8View) AS __common_expr_1, test.column1_utf8, test.column2_utf8, test.column2_large_utf8, test.column2_utf8view
03)----TableScan: test projection=[column1_utf8, column2_utf8, column2_large_utf8, column2_utf8view]

query BBB
SELECT
STARTS_WITH(column1_utf8, column2_utf8view) as c1,
STARTS_WITH(column1_utf8, column2_utf8) as c3,
STARTS_WITH(column1_utf8, column2_large_utf8) as c4
FROM test;
----
false false false
true true true
true true true
NULL NULL NULL


# Test STARTS_WITH with utf8view against literals
# In this case, the literals should be cast to utf8view. The columns
# should not be cast to utf8.
query TT
EXPLAIN SELECT
STARTS_WITH(column1_utf8view, 'äöüß') as c1,
STARTS_WITH(column1_utf8view, '') as c2,
STARTS_WITH(column1_utf8view, NULL) as c3,
STARTS_WITH(NULL, column1_utf8view) as c4
FROM test;
----
logical_plan
01)Projection: starts_with(test.column1_utf8view, Utf8View("äöüß")) AS c1, starts_with(test.column1_utf8view, Utf8View("")) AS c2, starts_with(test.column1_utf8view, Utf8View(NULL)) AS c3, starts_with(Utf8View(NULL), test.column1_utf8view) AS c4
02)--TableScan: test projection=[column1_utf8view]

statement ok
drop table test;
Expand All @@ -376,6 +445,5 @@ select t.dt from dates t where arrow_cast('2024-01-01', 'Utf8View') < t.dt;
----
2024-01-23


statement ok
drop table dates;

0 comments on commit 3d76aa2

Please sign in to comment.