From 72b6a49e7420fe2750bda88fabea90995079de9f Mon Sep 17 00:00:00 2001
From: Alex Huang
Date: Sat, 17 Aug 2024 17:46:09 +0800
Subject: [PATCH] feat: Add map_extract module and function (#11969)

* feat: Add map_extract module and function

* chore: Fix fmt

* chore: Add tests

* chore: Simplify

* chore: Simplify

* chore: Fix clippy

* doc: Add user doc

* feat: use Signature::user_defined

* chore: Update tests

* chore: Fix fmt

* chore: Fix clippy

* chore

* chore: typo

* chore: Check args len in return_type

* doc: Update doc

* chore: Simplify logic

* chore: check args earlier

* feat: Support UTF8VIEW

* chore: Update doc

* chore: Fic clippy

* refacotr: Use MutableArrayData

* chore

* refactor: Avoid type conversion

* chore: Fix clippy

* chore: Follow DuckDB

* Update datafusion/functions-nested/src/map_extract.rs

Co-authored-by: Jay Zhan

* chore: Fix fmt

---------

Co-authored-by: Jay Zhan
---
 datafusion/common/src/utils/mod.rs            |  17 +-
 datafusion/functions-nested/src/lib.rs        |   3 +
 .../functions-nested/src/map_extract.rs       | 173 ++++++++++++++++++
 datafusion/sqllogictest/test_files/map.slt    |  81 ++++++++
 .../source/user-guide/sql/scalar_functions.md |  29 +++
 5 files changed, 302 insertions(+), 1 deletion(-)
 create mode 100644 datafusion/functions-nested/src/map_extract.rs

diff --git a/datafusion/common/src/utils/mod.rs b/datafusion/common/src/utils/mod.rs
index bf506c0551eb..d7059e882e55 100644
--- a/datafusion/common/src/utils/mod.rs
+++ b/datafusion/common/src/utils/mod.rs
@@ -34,7 +34,7 @@ use arrow_array::{
     Array, FixedSizeListArray, LargeListArray, ListArray, OffsetSizeTrait,
     RecordBatchOptions,
 };
-use arrow_schema::DataType;
+use arrow_schema::{DataType, Fields};
 use sqlparser::ast::Ident;
 use sqlparser::dialect::GenericDialect;
 use sqlparser::parser::Parser;
@@ -753,6 +753,21 @@ pub fn combine_limit(
     (combined_skip, combined_fetch)
 }
 
+pub fn get_map_entry_field(data_type: &DataType) -> Result<&Fields> {
+    match data_type {
+        DataType::Map(field, _) => {
+            let field_data_type = field.data_type();
+            match field_data_type {
+                DataType::Struct(fields) => Ok(fields),
+                _ => {
+                    _internal_err!("Expected a Struct type, got {:?}", field_data_type)
+                }
+            }
+        }
+        _ => _internal_err!("Expected a Map type, got {:?}", data_type),
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use crate::ScalarValue::Null;
diff --git a/datafusion/functions-nested/src/lib.rs b/datafusion/functions-nested/src/lib.rs
index ef2c5e709bc1..cc0a7b55cf86 100644
--- a/datafusion/functions-nested/src/lib.rs
+++ b/datafusion/functions-nested/src/lib.rs
@@ -42,6 +42,7 @@ pub mod flatten;
 pub mod length;
 pub mod make_array;
 pub mod map;
+pub mod map_extract;
 pub mod planner;
 pub mod position;
 pub mod range;
@@ -81,6 +82,7 @@ pub mod expr_fn {
     pub use super::flatten::flatten;
     pub use super::length::array_length;
     pub use super::make_array::make_array;
+    pub use super::map_extract::map_extract;
     pub use super::position::array_position;
     pub use super::position::array_positions;
     pub use super::range::gen_series;
@@ -143,6 +145,7 @@ pub fn all_default_nested_functions() -> Vec<Arc<ScalarUDF>> {
         replace::array_replace_all_udf(),
         replace::array_replace_udf(),
         map::map_udf(),
+        map_extract::map_extract_udf(),
     ]
 }
 
diff --git a/datafusion/functions-nested/src/map_extract.rs b/datafusion/functions-nested/src/map_extract.rs
new file mode 100644
index 000000000000..82f0d8d6c15e
--- /dev/null
+++ b/datafusion/functions-nested/src/map_extract.rs
@@ -0,0 +1,173 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! [`ScalarUDFImpl`] definitions for map_extract functions.
+
+use arrow::array::{ArrayRef, Capacities, MutableArrayData};
+use arrow_array::{make_array, ListArray};
+
+use arrow::datatypes::DataType;
+use arrow_array::{Array, MapArray};
+use arrow_buffer::OffsetBuffer;
+use arrow_schema::Field;
+use datafusion_common::utils::get_map_entry_field;
+
+use datafusion_common::{cast::as_map_array, exec_err, Result};
+use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
+use std::any::Any;
+use std::sync::Arc;
+use std::vec;
+
+use crate::utils::make_scalar_function;
+
+// Create static instances of ScalarUDFs for each function
+make_udf_expr_and_func!(
+    MapExtract,
+    map_extract,
+    map key,
+    "Return a list containing the value for a given key or an empty list if the key is not contained in the map.",
+    map_extract_udf
+);
+
+#[derive(Debug)]
+pub(super) struct MapExtract {
+    signature: Signature,
+    aliases: Vec<String>,
+}
+
+impl MapExtract {
+    pub fn new() -> Self {
+        Self {
+            signature: Signature::user_defined(Volatility::Immutable),
+            aliases: vec![String::from("element_at")],
+        }
+    }
+}
+
+impl ScalarUDFImpl for MapExtract {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+    fn name(&self) -> &str {
+        "map_extract"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
+        if arg_types.len() != 2 {
+            return exec_err!("map_extract expects two arguments");
+        }
+        let map_type = &arg_types[0];
+        let map_fields = get_map_entry_field(map_type)?;
+        Ok(DataType::List(Arc::new(Field::new(
+            "item",
+            map_fields.last().unwrap().data_type().clone(),
+            true,
+        ))))
+    }
+
+    fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
+        make_scalar_function(map_extract_inner)(args)
+    }
+
+    fn aliases(&self) -> &[String] {
+        &self.aliases
+    }
+
+    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
+        if arg_types.len() != 2 {
+            return exec_err!("map_extract expects two arguments");
+        }
+
+        let field = get_map_entry_field(&arg_types[0])?;
+        Ok(vec![
+            arg_types[0].clone(),
+            field.first().unwrap().data_type().clone(),
+        ])
+    }
+}
+
+fn general_map_extract_inner(
+    map_array: &MapArray,
+    query_keys_array: &dyn Array,
+) -> Result<ArrayRef> {
+    let keys = map_array.keys();
+    let mut offsets = vec![0_i32];
+
+    let values = map_array.values();
+    let original_data = values.to_data();
+    let capacity = Capacities::Array(original_data.len());
+
+    let mut mutable =
+        MutableArrayData::with_capacities(vec![&original_data], true, capacity);
+
+    for (row_index, offset_window) in map_array.value_offsets().windows(2).enumerate() {
+        let start = offset_window[0] as usize;
+        let end = offset_window[1] as usize;
+        let len = end - start;
+
+        let query_key = query_keys_array.slice(row_index, 1);
+
+        let value_index =
+            (0..len).find(|&i| keys.slice(start + i, 1).as_ref() == query_key.as_ref());
+
+        match value_index {
+            Some(index) => {
+                mutable.extend(0, start + index, start + index + 1);
+            }
+            None => {
+                mutable.extend_nulls(1);
+            }
+        }
+        offsets.push(offsets[row_index] + 1);
+    }
+
+    let data = mutable.freeze();
+
+    Ok(Arc::new(ListArray::new(
+        Arc::new(Field::new("item", map_array.value_type().clone(), true)),
+        OffsetBuffer::<i32>::new(offsets.into()),
+        Arc::new(make_array(data)),
+        None,
+    )))
+}
+
+fn map_extract_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
+    if args.len() != 2 {
+        return exec_err!("map_extract expects two arguments");
+    }
+
+    let map_array = match args[0].data_type() {
+        DataType::Map(_, _) => as_map_array(&args[0])?,
+        _ => return exec_err!("The first argument in map_extract must be a map"),
+    };
+
+    let key_type = map_array.key_type();
+
+    if key_type != args[1].data_type() {
+        return exec_err!(
+            "The key type {} does not match the map key type {}",
+            args[1].data_type(),
+            key_type
+        );
+    }
+
+    general_map_extract_inner(map_array, &args[1])
+}
diff --git a/datafusion/sqllogictest/test_files/map.slt b/datafusion/sqllogictest/test_files/map.slt
index 0dc37c68bca4..b7a0a74913b0 100644
--- a/datafusion/sqllogictest/test_files/map.slt
+++ b/datafusion/sqllogictest/test_files/map.slt
@@ -15,6 +15,22 @@
 # specific language governing permissions and limitations
 # under the License.
 
+statement ok
+CREATE TABLE map_array_table_1
+AS VALUES
+    (MAP {1: [1, NULL, 3], 2: [4, NULL, 6], 3: [7, 8, 9]}, 1, 1.0, '1'),
+    (MAP {4: [1, NULL, 3], 5: [4, NULL, 6], 6: [7, 8, 9]}, 5, 5.0, '5'),
+    (MAP {7: [1, NULL, 3], 8: [9, NULL, 6], 9: [7, 8, 9]}, 4, 4.0, '4')
+;
+
+statement ok
+CREATE TABLE map_array_table_2
+AS VALUES
+    (MAP {'1': [1, NULL, 3], '2': [4, NULL, 6], '3': [7, 8, 9]}, 1, 1.0, '1'),
+    (MAP {'4': [1, NULL, 3], '5': [4, NULL, 6], '6': [7, 8, 9]}, 5, 5.0, '5'),
+    (MAP {'7': [1, NULL, 3], '8': [9, NULL, 6], '9': [7, 8, 9]}, 4, 4.0, '4')
+;
+
 statement ok
 CREATE EXTERNAL TABLE data
 STORED AS PARQUET
@@ -493,3 +509,68 @@ select cardinality(map([1, 2, 3], ['a', 'b', 'c'])), cardinality(MAP {'a': 1, 'b
        cardinality(MAP {'a': MAP {1:'a', 2:'b', 3:'c'}, 'b': MAP {2:'c', 4:'d'} });
 ----
 3 2 0 2
+
+# map_extract
+# key is string
+query ????
+select map_extract(MAP {'a': 1, 'b': NULL, 'c': 3}, 'a'), map_extract(MAP {'a': 1, 'b': NULL, 'c': 3}, 'b'),
+       map_extract(MAP {'a': 1, 'b': NULL, 'c': 3}, 'c'), map_extract(MAP {'a': 1, 'b': NULL, 'c': 3}, 'd');
+----
+[1] [] [3] []
+
+# key is integer
+query ????
+select map_extract(MAP {1: 1, 2: NULL, 3:3}, 1), map_extract(MAP {1: 1, 2: NULL, 3:3}, 2),
+       map_extract(MAP {1: 1, 2: NULL, 3:3}, 3), map_extract(MAP {1: 1, 2: NULL, 3:3}, 4);
+----
+[1] [] [3] []
+
+# value is list
+query ????
+select map_extract(MAP {1: [1, 2], 2: NULL, 3:[3]}, 1), map_extract(MAP {1: [1, 2], 2: NULL, 3:[3]}, 2),
+       map_extract(MAP {1: [1, 2], 2: NULL, 3:[3]}, 3), map_extract(MAP {1: [1, 2], 2: NULL, 3:[3]}, 4);
+----
+[[1, 2]] [] [[3]] []
+
+# key in map and query key are different types
+query ?????
+select map_extract(MAP {1: 1, 2: 2, 3:3}, '1'), map_extract(MAP {1: 1, 2: 2, 3:3}, 1.0),
+       map_extract(MAP {1.0: 1, 2: 2, 3:3}, '1'), map_extract(MAP {'1': 1, '2': 2, '3':3}, 1.0),
+       map_extract(MAP {arrow_cast('1', 'Utf8View'): 1, arrow_cast('2', 'Utf8View'): 2, arrow_cast('3', 'Utf8View'):3}, '1');
+----
+[1] [1] [1] [] [1]
+
+# map_extract with columns
+query ???
+select map_extract(column1, 1), map_extract(column1, 5), map_extract(column1, 7) from map_array_table_1;
+----
+[[1, , 3]] [] []
+[] [[4, , 6]] []
+[] [] [[1, , 3]]
+
+query ???
+select map_extract(column1, column2), map_extract(column1, column3), map_extract(column1, column4) from map_array_table_1;
+----
+[[1, , 3]] [[1, , 3]] [[1, , 3]]
+[[4, , 6]] [[4, , 6]] [[4, , 6]]
+[] [] []
+
+query ???
+select map_extract(column1, column2), map_extract(column1, column3), map_extract(column1, column4) from map_array_table_2;
+----
+[[1, , 3]] [] [[1, , 3]]
+[[4, , 6]] [] [[4, , 6]]
+[] [] []
+
+query ???
+select map_extract(column1, 1), map_extract(column1, 5), map_extract(column1, 7) from map_array_table_2;
+----
+[[1, , 3]] [] []
+[] [[4, , 6]] []
+[] [] [[1, , 3]]
+
+statement ok
+drop table map_array_table_1;
+
+statement ok
+drop table map_array_table_2;
\ No newline at end of file
diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md
index c7490df04983..c7b3409ba7cd 100644
--- a/docs/source/user-guide/sql/scalar_functions.md
+++ b/docs/source/user-guide/sql/scalar_functions.md
@@ -3640,6 +3640,7 @@ Unwraps struct fields into columns.
 
 - [map](#map)
 - [make_map](#make_map)
+- [map_extract](#map_extract)
 
 ### `map`
 
@@ -3700,6 +3701,34 @@ SELECT MAKE_MAP('POST', 41, 'HEAD', 33, 'PATCH', null);
 {POST: 41, HEAD: 33, PATCH: }
 ```
 
+### `map_extract`
+
+Return a list containing the value for a given key or an empty list if the key is not contained in the map.
+
+```
+map_extract(map, key)
+```
+
+#### Arguments
+
+- `map`: Map expression.
+  Can be a constant, column, or function, and any combination of map operators.
+- `key`: Key to extract from the map.
+  Can be a constant, column, or function, any combination of arithmetic or
+  string operators, or a named expression of previous listed.
+
+#### Example
+
+```
+SELECT map_extract(MAP {'a': 1, 'b': NULL, 'c': 3}, 'a');
+----
+[1]
+```
+
+#### Aliases
+
+- element_at
+
 ## Hashing Functions
 
 - [digest](#digest)