Skip to content

Commit

Permalink
Add array_distance function (#12211)
Browse files Browse the repository at this point in the history
* Add `distance` aggregation function

Signed-off-by: Austin Liu <[email protected]>

Add `distance` aggregation function

Signed-off-by: Austin Liu <[email protected]>

* Add sql logic test for `distance`

Signed-off-by: Austin Liu <[email protected]>

* Simplify diff calculation

Signed-off-by: Austin Liu <[email protected]>

* Add `array_distance`/`list_distance` as list function in functions-nested

Signed-off-by: Austin Liu <[email protected]>

* Remove aggregate function `distance`

Signed-off-by: Austin Liu <[email protected]>

* format

Signed-off-by: Austin Liu <[email protected]>

* clean up error handling

Signed-off-by: Austin Liu <[email protected]>

* Add `array_distance` in scalar array functions docs

Signed-off-by: Austin Liu <[email protected]>

* Update bulletin

Signed-off-by: Austin Liu <[email protected]>

* Prettify example

Signed-off-by: Austin Liu <[email protected]>

---------

Signed-off-by: Austin Liu <[email protected]>
  • Loading branch information
austin362667 committed Aug 29, 2024
1 parent 1fce2a9 commit bd50698
Show file tree
Hide file tree
Showing 4 changed files with 308 additions and 0 deletions.
215 changes: 215 additions & 0 deletions datafusion/functions-nested/src/distance.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,215 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! [ScalarUDFImpl] definitions for array_distance function.

use crate::utils::{downcast_arg, make_scalar_function};
use arrow_array::{
Array, ArrayRef, Float64Array, LargeListArray, ListArray, OffsetSizeTrait,
};
use arrow_schema::DataType;
use arrow_schema::DataType::{FixedSizeList, Float64, LargeList, List};
use core::any::type_name;
use datafusion_common::cast::{
as_float32_array, as_float64_array, as_generic_list_array, as_int32_array,
as_int64_array,
};
use datafusion_common::DataFusionError;
use datafusion_common::{exec_err, Result};
use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
use std::any::Any;
use std::sync::Arc;

make_udf_expr_and_func!(
ArrayDistance,
array_distance,
array,
"returns the Euclidean distance between two numeric arrays.",
array_distance_udf
);

#[derive(Debug)]
pub(super) struct ArrayDistance {
signature: Signature,
aliases: Vec<String>,
}

impl ArrayDistance {
pub fn new() -> Self {
Self {
signature: Signature::variadic_any(Volatility::Immutable),
aliases: vec!["list_distance".to_string()],
}
}
}

impl ScalarUDFImpl for ArrayDistance {
fn as_any(&self) -> &dyn Any {
self
}

fn name(&self) -> &str {
"array_distance"
}

fn signature(&self) -> &Signature {
&self.signature
}

fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
match arg_types[0] {
List(_) | LargeList(_) | FixedSizeList(_, _) => Ok(Float64),
_ => exec_err!("The array_distance function can only accept List/LargeList/FixedSizeList."),
}
}

fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
make_scalar_function(array_distance_inner)(args)
}

fn aliases(&self) -> &[String] {
&self.aliases
}
}

pub fn array_distance_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
if args.len() != 2 {
return exec_err!("array_distance expects exactly two arguments");
}

match (&args[0].data_type(), &args[1].data_type()) {
(List(_), List(_)) => general_array_distance::<i32>(args),
(LargeList(_), LargeList(_)) => general_array_distance::<i64>(args),
(array_type1, array_type2) => {
exec_err!("array_distance does not support types '{array_type1:?}' and '{array_type2:?}'")
}
}
}

fn general_array_distance<O: OffsetSizeTrait>(arrays: &[ArrayRef]) -> Result<ArrayRef> {
let list_array1 = as_generic_list_array::<O>(&arrays[0])?;
let list_array2 = as_generic_list_array::<O>(&arrays[1])?;

let result = list_array1
.iter()
.zip(list_array2.iter())
.map(|(arr1, arr2)| compute_array_distance(arr1, arr2))
.collect::<Result<Float64Array>>()?;

Ok(Arc::new(result) as ArrayRef)
}

/// Computes the Euclidean distance between two arrays
fn compute_array_distance(
arr1: Option<ArrayRef>,
arr2: Option<ArrayRef>,
) -> Result<Option<f64>> {
let value1 = match arr1 {
Some(arr) => arr,
None => return Ok(None),
};
let value2 = match arr2 {
Some(arr) => arr,
None => return Ok(None),
};

let mut value1 = value1;
let mut value2 = value2;

loop {
match value1.data_type() {
List(_) => {
if downcast_arg!(value1, ListArray).null_count() > 0 {
return Ok(None);
}
value1 = downcast_arg!(value1, ListArray).value(0);
}
LargeList(_) => {
if downcast_arg!(value1, LargeListArray).null_count() > 0 {
return Ok(None);
}
value1 = downcast_arg!(value1, LargeListArray).value(0);
}
_ => break,
}

match value2.data_type() {
List(_) => {
if downcast_arg!(value2, ListArray).null_count() > 0 {
return Ok(None);
}
value2 = downcast_arg!(value2, ListArray).value(0);
}
LargeList(_) => {
if downcast_arg!(value2, LargeListArray).null_count() > 0 {
return Ok(None);
}
value2 = downcast_arg!(value2, LargeListArray).value(0);
}
_ => break,
}
}

// Check for NULL values inside the arrays
if value1.null_count() != 0 || value2.null_count() != 0 {
return Ok(None);
}

let values1 = convert_to_f64_array(&value1)?;
let values2 = convert_to_f64_array(&value2)?;

if values1.len() != values2.len() {
return exec_err!("Both arrays must have the same length");
}

let sum_squares: f64 = values1
.iter()
.zip(values2.iter())
.map(|(v1, v2)| {
let diff = v1.unwrap_or(0.0) - v2.unwrap_or(0.0);
diff * diff
})
.sum();

Ok(Some(sum_squares.sqrt()))
}

/// Converts an array of any numeric type to a Float64Array.
fn convert_to_f64_array(array: &ArrayRef) -> Result<Float64Array> {
match array.data_type() {
DataType::Float64 => Ok(as_float64_array(array)?.clone()),
DataType::Float32 => {
let array = as_float32_array(array)?;
let converted: Float64Array =
array.iter().map(|v| v.map(|v| v as f64)).collect();
Ok(converted)
}
DataType::Int64 => {
let array = as_int64_array(array)?;
let converted: Float64Array =
array.iter().map(|v| v.map(|v| v as f64)).collect();
Ok(converted)
}
DataType::Int32 => {
let array = as_int32_array(array)?;
let converted: Float64Array =
array.iter().map(|v| v.map(|v| v as f64)).collect();
Ok(converted)
}
_ => exec_err!("Unsupported array type for conversion to Float64Array"),
}
}
3 changes: 3 additions & 0 deletions datafusion/functions-nested/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ pub mod array_has;
pub mod cardinality;
pub mod concat;
pub mod dimension;
pub mod distance;
pub mod empty;
pub mod except;
pub mod expr_ext;
Expand Down Expand Up @@ -73,6 +74,7 @@ pub mod expr_fn {
pub use super::concat::array_prepend;
pub use super::dimension::array_dims;
pub use super::dimension::array_ndims;
pub use super::distance::array_distance;
pub use super::empty::array_empty;
pub use super::except::array_except;
pub use super::extract::array_element;
Expand Down Expand Up @@ -128,6 +130,7 @@ pub fn all_default_nested_functions() -> Vec<Arc<ScalarUDF>> {
array_has::array_has_any_udf(),
empty::array_empty_udf(),
length::array_length_udf(),
distance::array_distance_udf(),
flatten::flatten_udf(),
sort::array_sort_udf(),
repeat::array_repeat_udf(),
Expand Down
54 changes: 54 additions & 0 deletions datafusion/sqllogictest/test_files/array.slt
Original file line number Diff line number Diff line change
Expand Up @@ -4715,6 +4715,60 @@ NULL 10
NULL 10
NULL 10

query RRR
select array_distance([2], [3]), list_distance([1], [2]), list_distance([1], [-2]);
----
1 1 3

query error
select list_distance([1], [1, 2]);

query R
select array_distance([[1, 1]], [1, 2]);
----
1

query R
select array_distance([[1, 1]], [[1, 2]]);
----
1

query R
select array_distance([[1, 1]], [[1, 2]]);
----
1

query RR
select array_distance([1, 1, 0, 0], [2, 2, 1, 1]), list_distance([1, 2, 3], [1, 2, 3]);
----
2 0

query RR
select array_distance([1.0, 1, 0, 0], [2, 2.0, 1, 1]), list_distance([1, 2.0, 3], [1, 2, 3]);
----
2 0

query R
select list_distance([1, 1, NULL, 0], [2, 2, NULL, NULL]);
----
NULL

query R
select list_distance([NULL, NULL], [NULL, NULL]);
----
NULL

query R
select list_distance([1.0, 2.0, 3.0], [1.0, 2.0, 3.5]) AS distance;
----
0.5

query R
select list_distance([1, 2, 3], [1, 2, 3]) AS distance;
----
0


## array_dims (aliases: `list_dims`)

# array dims error
Expand Down
36 changes: 36 additions & 0 deletions docs/source/user-guide/sql/scalar_functions.md
Original file line number Diff line number Diff line change
Expand Up @@ -2093,6 +2093,7 @@ to_unixtime(expression[, ..., format_n])
- [array_concat](#array_concat)
- [array_contains](#array_contains)
- [array_dims](#array_dims)
- [array_distance](#array_distance)
- [array_distinct](#array_distinct)
- [array_has](#array_has)
- [array_has_all](#array_has_all)
Expand Down Expand Up @@ -2135,6 +2136,7 @@ to_unixtime(expression[, ..., format_n])
- [list_cat](#list_cat)
- [list_concat](#list_concat)
- [list_dims](#list_dims)
- [list_distance](#list_distance)
- [list_distinct](#list_distinct)
- [list_element](#list_element)
- [list_except](#list_except)
Expand Down Expand Up @@ -2388,6 +2390,36 @@ array_dims(array)

- list_dims

### `array_distance`

Returns the Euclidean distance between two input arrays of equal length.

```
array_distance(array1, array2)
```

#### Arguments

- **array1**: Array expression.
Can be a constant, column, or function, and any combination of array operators.
- **array2**: Array expression.
Can be a constant, column, or function, and any combination of array operators.

#### Example

```
> select array_distance([1, 2], [1, 4]);
+------------------------------------+
| array_distance(List([1,2], [1,4])) |
+------------------------------------+
| 2.0 |
+------------------------------------+
```

#### Aliases

- list_distance

### `array_distinct`

Returns distinct values from the array after removing duplicates.
Expand Down Expand Up @@ -3224,6 +3256,10 @@ _Alias of [array_concat](#array_concat)._

_Alias of [array_dims](#array_dims)._

### `list_distance`

_Alias of [array_distance](#array_distance)._

### `list_distinct`

_Alias of [array_dims](#array_distinct)._
Expand Down

0 comments on commit bd50698

Please sign in to comment.