diff --git a/datafusion/expr/src/signature.rs b/datafusion/expr/src/signature.rs index b1cec3bad774..577c663142a1 100644 --- a/datafusion/expr/src/signature.rs +++ b/datafusion/expr/src/signature.rs @@ -145,6 +145,9 @@ pub enum ArrayFunctionSignature { /// The function takes a single argument that must be a List/LargeList/FixedSizeList /// or something that can be coerced to one of those types. Array, + /// Specialized Signature for MapArray + /// The function takes a single argument that must be a MapArray + MapArray, } impl std::fmt::Display for ArrayFunctionSignature { @@ -165,6 +168,9 @@ impl std::fmt::Display for ArrayFunctionSignature { ArrayFunctionSignature::Array => { write!(f, "array") } + ArrayFunctionSignature::MapArray => { + write!(f, "map_array") + } } } } diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index ef52a01e0598..66807c3f446c 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -378,6 +378,16 @@ fn get_valid_types( array(¤t_types[0]) .map_or_else(|| vec![vec![]], |array_type| vec![vec![array_type]]) } + ArrayFunctionSignature::MapArray => { + if current_types.len() != 1 { + return Ok(vec![vec![]]); + } + + match ¤t_types[0] { + DataType::Map(_, _) => vec![vec![current_types[0].clone()]], + _ => vec![vec![]], + } + } }, TypeSignature::Any(number) => { if current_types.len() != *number { diff --git a/datafusion/functions-nested/src/cardinality.rs b/datafusion/functions-nested/src/cardinality.rs index f6755c344768..ea07ac381aff 100644 --- a/datafusion/functions-nested/src/cardinality.rs +++ b/datafusion/functions-nested/src/cardinality.rs @@ -18,13 +18,18 @@ //! [`ScalarUDFImpl`] definitions for cardinality function. use crate::utils::make_scalar_function; -use arrow_array::{ArrayRef, GenericListArray, OffsetSizeTrait, UInt64Array}; +use arrow_array::{ + Array, ArrayRef, GenericListArray, MapArray, OffsetSizeTrait, UInt64Array, +}; use arrow_schema::DataType; -use arrow_schema::DataType::{FixedSizeList, LargeList, List, UInt64}; -use datafusion_common::cast::{as_large_list_array, as_list_array}; +use arrow_schema::DataType::{FixedSizeList, LargeList, List, Map, UInt64}; +use datafusion_common::cast::{as_large_list_array, as_list_array, as_map_array}; use datafusion_common::Result; use datafusion_common::{exec_err, plan_err}; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::{ + ArrayFunctionSignature, ColumnarValue, ScalarUDFImpl, Signature, TypeSignature, + Volatility, +}; use std::any::Any; use std::sync::Arc; @@ -32,14 +37,20 @@ make_udf_expr_and_func!( Cardinality, cardinality, array, - "returns the total number of elements in the array.", + "returns the total number of elements in the array or map.", cardinality_udf ); impl Cardinality { pub fn new() -> Self { Self { - signature: Signature::array(Volatility::Immutable), + signature: Signature::one_of( + vec![ + TypeSignature::ArraySignature(ArrayFunctionSignature::Array), + TypeSignature::ArraySignature(ArrayFunctionSignature::MapArray), + ], + Volatility::Immutable, + ), aliases: vec![], } } @@ -64,9 +75,9 @@ impl ScalarUDFImpl for Cardinality { fn return_type(&self, arg_types: &[DataType]) -> Result { Ok(match arg_types[0] { - List(_) | LargeList(_) | FixedSizeList(_, _) => UInt64, + List(_) | LargeList(_) | FixedSizeList(_, _) | Map(_, _) => UInt64, _ => { - return plan_err!("The cardinality function can only accept List/LargeList/FixedSizeList."); + return plan_err!("The cardinality function can only accept List/LargeList/FixedSizeList/Map."); } }) } @@ -95,12 +106,24 @@ pub fn cardinality_inner(args: &[ArrayRef]) -> Result { let list_array = as_large_list_array(&args[0])?; generic_list_cardinality::(list_array) } + Map(_, _) => { + let map_array = as_map_array(&args[0])?; + generic_map_cardinality(map_array) + } other => { exec_err!("cardinality does not support type '{:?}'", other) } } } +fn generic_map_cardinality(array: &MapArray) -> Result { + let result: UInt64Array = array + .iter() + .map(|opt_arr| opt_arr.map(|arr| arr.len() as u64)) + .collect(); + Ok(Arc::new(result)) +} + fn generic_list_cardinality( array: &GenericListArray, ) -> Result { diff --git a/datafusion/sqllogictest/test_files/map.slt b/datafusion/sqllogictest/test_files/map.slt index 11998eea9044..eb350c22bb5d 100644 --- a/datafusion/sqllogictest/test_files/map.slt +++ b/datafusion/sqllogictest/test_files/map.slt @@ -459,3 +459,12 @@ SELECT MAP { 'a': 1, 2: 3 }; # SELECT MAKE_MAP(1, null, 2, 33, 3, null)[2]; # ---- # 33 + +## cardinality + +# cardinality scalar function +query IIII +select cardinality(map([1, 2, 3], ['a', 'b', 'c'])), cardinality(MAP {'a': 1, 'b': null}), cardinality(MAP([],[])), + cardinality(MAP {'a': MAP {1:'a', 2:'b', 3:'c'}, 'b': MAP {2:'c', 4:'d'} }); +---- +3 2 0 2 diff --git a/docs/source/user-guide/expressions.md b/docs/source/user-guide/expressions.md index 60036e440ffb..ad5a9cb75152 100644 --- a/docs/source/user-guide/expressions.md +++ b/docs/source/user-guide/expressions.md @@ -243,7 +243,7 @@ select log(-1), log(0), sqrt(-1); | array_except(array1, array2) | Returns an array of the elements that appear in the first array but not in the second. `array_except([1, 2, 3, 4], [5, 6, 3, 4]) -> [1, 2]` | | array_resize(array, size, value) | Resizes the list to contain size elements. Initializes new elements with value or empty if value is not set. `array_resize([1, 2, 3], 5, 0) -> [1, 2, 3, 0, 0]` | | array_sort(array, desc, null_first) | Returns sorted array. `array_sort([3, 1, 2, 5, 4]) -> [1, 2, 3, 4, 5]` | -| cardinality(array) | Returns the total number of elements in the array. `cardinality([[1, 2, 3], [4, 5, 6]]) -> 6` | +| cardinality(array/map) | Returns the total number of elements in the array or map. `cardinality([[1, 2, 3], [4, 5, 6]]) -> 6` | | make_array(value1, [value2 [, ...]]) | Returns an Arrow array using the specified input expressions. `make_array(1, 2, 3) -> [1, 2, 3]` | | range(start [, stop, step]) | Returns an Arrow array between start and stop with step. `SELECT range(2, 10, 3) -> [2, 5, 8]` | | string_to_array(array, delimiter, null_string) | Splits a `string` based on a `delimiter` and returns an array of parts. Any parts matching the optional `null_string` will be replaced with `NULL`. `string_to_array('abc#def#ghi', '#', ' ') -> ['abc', 'def', 'ghi']` |