Skip to content

Commit

Permalink
Support map_keys & map_values for MAP type (#12194)
Browse files Browse the repository at this point in the history
* impl map_keys

* rename field name

* add logic tests

* one more

* owned to clone

* more tests

* typo

* impl

* add logic tests

* chore

* add docs

* trying to make prettier happy

* Update scalar_functions.md

Co-authored-by: Alex Huang <[email protected]>

* reface signature

* format docs

* Update map_values.rs

Co-authored-by: Alex Huang <[email protected]>

---------

Co-authored-by: Alex Huang <[email protected]>
  • Loading branch information
dharanad and Weijun-H committed Sep 1, 2024
1 parent 016ed03 commit 8746e07
Show file tree
Hide file tree
Showing 8 changed files with 377 additions and 21 deletions.
17 changes: 1 addition & 16 deletions datafusion/common/src/utils/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ use arrow_array::{
Array, FixedSizeListArray, LargeListArray, ListArray, OffsetSizeTrait,
RecordBatchOptions,
};
use arrow_schema::{DataType, Fields};
use arrow_schema::DataType;
use sqlparser::ast::Ident;
use sqlparser::dialect::GenericDialect;
use sqlparser::parser::Parser;
Expand Down Expand Up @@ -754,21 +754,6 @@ pub fn combine_limit(
(combined_skip, combined_fetch)
}

pub fn get_map_entry_field(data_type: &DataType) -> Result<&Fields> {
match data_type {
DataType::Map(field, _) => {
let field_data_type = field.data_type();
match field_data_type {
DataType::Struct(fields) => Ok(fields),
_ => {
_internal_err!("Expected a Struct type, got {:?}", field_data_type)
}
}
}
_ => _internal_err!("Expected a Map type, got {:?}", data_type),
}
}

#[cfg(test)]
mod tests {
use crate::ScalarValue::Null;
Expand Down
6 changes: 6 additions & 0 deletions datafusion/functions-nested/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ pub mod length;
pub mod make_array;
pub mod map;
pub mod map_extract;
pub mod map_keys;
pub mod map_values;
pub mod planner;
pub mod position;
pub mod range;
Expand Down Expand Up @@ -85,6 +87,8 @@ pub mod expr_fn {
pub use super::length::array_length;
pub use super::make_array::make_array;
pub use super::map_extract::map_extract;
pub use super::map_keys::map_keys;
pub use super::map_values::map_values;
pub use super::position::array_position;
pub use super::position::array_positions;
pub use super::range::gen_series;
Expand Down Expand Up @@ -149,6 +153,8 @@ pub fn all_default_nested_functions() -> Vec<Arc<ScalarUDF>> {
replace::array_replace_udf(),
map::map_udf(),
map_extract::map_extract_udf(),
map_keys::map_keys_udf(),
map_values::map_values_udf(),
]
}

Expand Down
3 changes: 1 addition & 2 deletions datafusion/functions-nested/src/map_extract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,14 @@ use arrow::datatypes::DataType;
use arrow_array::{Array, MapArray};
use arrow_buffer::OffsetBuffer;
use arrow_schema::Field;
use datafusion_common::utils::get_map_entry_field;

use datafusion_common::{cast::as_map_array, exec_err, Result};
use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
use std::any::Any;
use std::sync::Arc;
use std::vec;

use crate::utils::make_scalar_function;
use crate::utils::{get_map_entry_field, make_scalar_function};

// Create static instances of ScalarUDFs for each function
make_udf_expr_and_func!(
Expand Down
102 changes: 102 additions & 0 deletions datafusion/functions-nested/src/map_keys.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! [`ScalarUDFImpl`] definitions for map_keys function.

use crate::utils::{get_map_entry_field, make_scalar_function};
use arrow_array::{Array, ArrayRef, ListArray};
use arrow_schema::{DataType, Field};
use datafusion_common::{cast::as_map_array, exec_err, Result};
use datafusion_expr::{
ArrayFunctionSignature, ColumnarValue, ScalarUDFImpl, Signature, TypeSignature,
Volatility,
};
use std::any::Any;
use std::sync::Arc;

make_udf_expr_and_func!(
MapKeysFunc,
map_keys,
map,
"Return a list of all keys in the map.",
map_keys_udf
);

#[derive(Debug)]
pub(crate) struct MapKeysFunc {
signature: Signature,
}

impl MapKeysFunc {
pub fn new() -> Self {
Self {
signature: Signature::new(
TypeSignature::ArraySignature(ArrayFunctionSignature::MapArray),
Volatility::Immutable,
),
}
}
}

impl ScalarUDFImpl for MapKeysFunc {
fn as_any(&self) -> &dyn Any {
self
}

fn name(&self) -> &str {
"map_keys"
}

fn signature(&self) -> &Signature {
&self.signature
}

fn return_type(&self, arg_types: &[DataType]) -> datafusion_common::Result<DataType> {
if arg_types.len() != 1 {
return exec_err!("map_keys expects single argument");
}
let map_type = &arg_types[0];
let map_fields = get_map_entry_field(map_type)?;
Ok(DataType::List(Arc::new(Field::new(
"item",
map_fields.first().unwrap().data_type().clone(),
false,
))))
}

fn invoke(&self, args: &[ColumnarValue]) -> datafusion_common::Result<ColumnarValue> {
make_scalar_function(map_keys_inner)(args)
}
}

fn map_keys_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
if args.len() != 1 {
return exec_err!("map_keys expects single argument");
}

let map_array = match args[0].data_type() {
DataType::Map(_, _) => as_map_array(&args[0])?,
_ => return exec_err!("Argument for map_keys should be a map"),
};

Ok(Arc::new(ListArray::new(
Arc::new(Field::new("item", map_array.key_type().clone(), false)),
map_array.offsets().clone(),
Arc::clone(map_array.keys()),
None,
)))
}
102 changes: 102 additions & 0 deletions datafusion/functions-nested/src/map_values.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! [`ScalarUDFImpl`] definitions for map_values function.

use crate::utils::{get_map_entry_field, make_scalar_function};
use arrow_array::{Array, ArrayRef, ListArray};
use arrow_schema::{DataType, Field};
use datafusion_common::{cast::as_map_array, exec_err, Result};
use datafusion_expr::{
ArrayFunctionSignature, ColumnarValue, ScalarUDFImpl, Signature, TypeSignature,
Volatility,
};
use std::any::Any;
use std::sync::Arc;

make_udf_expr_and_func!(
MapValuesFunc,
map_values,
map,
"Return a list of all values in the map.",
map_values_udf
);

#[derive(Debug)]
pub(crate) struct MapValuesFunc {
signature: Signature,
}

impl MapValuesFunc {
pub fn new() -> Self {
Self {
signature: Signature::new(
TypeSignature::ArraySignature(ArrayFunctionSignature::MapArray),
Volatility::Immutable,
),
}
}
}

impl ScalarUDFImpl for MapValuesFunc {
fn as_any(&self) -> &dyn Any {
self
}

fn name(&self) -> &str {
"map_values"
}

fn signature(&self) -> &Signature {
&self.signature
}

fn return_type(&self, arg_types: &[DataType]) -> datafusion_common::Result<DataType> {
if arg_types.len() != 1 {
return exec_err!("map_values expects single argument");
}
let map_type = &arg_types[0];
let map_fields = get_map_entry_field(map_type)?;
Ok(DataType::List(Arc::new(Field::new(
"item",
map_fields.last().unwrap().data_type().clone(),
true,
))))
}

fn invoke(&self, args: &[ColumnarValue]) -> datafusion_common::Result<ColumnarValue> {
make_scalar_function(map_values_inner)(args)
}
}

fn map_values_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
if args.len() != 1 {
return exec_err!("map_values expects single argument");
}

let map_array = match args[0].data_type() {
DataType::Map(_, _) => as_map_array(&args[0])?,
_ => return exec_err!("Argument for map_values should be a map"),
};

Ok(Arc::new(ListArray::new(
Arc::new(Field::new("item", map_array.value_type().clone(), true)),
map_array.offsets().clone(),
Arc::clone(map_array.values()),
None,
)))
}
19 changes: 17 additions & 2 deletions datafusion/functions-nested/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@ use arrow_array::{
UInt32Array,
};
use arrow_buffer::OffsetBuffer;
use arrow_schema::Field;
use arrow_schema::{Field, Fields};
use datafusion_common::cast::{as_large_list_array, as_list_array};
use datafusion_common::{exec_err, plan_err, Result, ScalarValue};
use datafusion_common::{exec_err, internal_err, plan_err, Result, ScalarValue};

use core::any::type_name;
use datafusion_common::DataFusionError;
Expand Down Expand Up @@ -253,6 +253,21 @@ pub(crate) fn compute_array_dims(
}
}

pub(crate) fn get_map_entry_field(data_type: &DataType) -> Result<&Fields> {
match data_type {
DataType::Map(field, _) => {
let field_data_type = field.data_type();
match field_data_type {
DataType::Struct(fields) => Ok(fields),
_ => {
internal_err!("Expected a Struct type, got {:?}", field_data_type)
}
}
}
_ => internal_err!("Expected a Map type, got {:?}", data_type),
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
Loading

0 comments on commit 8746e07

Please sign in to comment.