Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Access a Map with Primitive type keys #12259

Draft
wants to merge 14 commits into
base: main
Choose a base branch
from
13 changes: 12 additions & 1 deletion datafusion/functions-nested/src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ use datafusion_expr::{
planner::{ExprPlanner, PlannerResult, RawBinaryExpr, RawFieldAccessExpr},
sqlparser, Expr, ExprSchemable, GetFieldAccess,
};
use datafusion_functions::expr_fn::get_field;
use datafusion_functions::expr_fn::{_get_field, get_field};
use datafusion_functions_aggregate::nth_value::nth_value_udaf;

use crate::map::map_udf;
Expand Down Expand Up @@ -148,6 +148,13 @@ impl ExprPlanner for FieldAccessPlanner {
// expr[idx] ==> array_element(expr, idx)
GetFieldAccess::ListIndex { key: index } => {
match expr {
// Special case for accessing map value with non-string values
Expr::ScalarFunction(scalar_func) if is_map(&scalar_func) => {
Ok(PlannerResult::Planned(_get_field(
Expr::ScalarFunction(scalar_func),
*index,
dharanad marked this conversation as resolved.
Show resolved Hide resolved
)))
}
// Special case for array_agg(expr)[index] to NTH_VALUE(expr, index)
Expr::AggregateFunction(agg_func) if is_array_agg(&agg_func) => {
Ok(PlannerResult::Planned(Expr::AggregateFunction(
Expand Down Expand Up @@ -186,3 +193,7 @@ impl ExprPlanner for FieldAccessPlanner {
/// Returns `true` when the given aggregate function is `array_agg`.
///
/// The check compares the name reported by `agg_func.func.name()` against the
/// literal `"array_agg"`; the planner uses it to special-case
/// `array_agg(expr)[index]`.
fn is_array_agg(agg_func: &datafusion_expr::expr::AggregateFunction) -> bool {
    agg_func.func.name() == "array_agg"
}

/// Returns `true` when the given scalar function is the `map` function.
///
/// The check compares the name reported by `scalar_func.func.name()` against
/// the literal `"map"`; the planner uses it to special-case index access on a
/// map expression.
fn is_map(scalar_func: &datafusion_expr::expr::ScalarFunction) -> bool {
    scalar_func.func.name() == "map"
}
36 changes: 28 additions & 8 deletions datafusion/functions/src/core/getfield.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
// under the License.

use arrow::array::{
make_array, Array, Capacities, MutableArrayData, Scalar, StringArray,
make_array, Array, ArrayRef, Capacities, Float64Array, Int64Array, MutableArrayData,
Scalar, StringArray,
};
use arrow::datatypes::DataType;
use datafusion_common::cast::{as_map_array, as_struct_array};
Expand Down Expand Up @@ -47,7 +48,6 @@ impl GetFieldFunc {
}
}

// get_field(struct_array, field_name)
impl ScalarUDFImpl for GetFieldFunc {
fn as_any(&self) -> &dyn Any {
self
Expand Down Expand Up @@ -184,9 +184,29 @@ impl ScalarUDFImpl for GetFieldFunc {
};

match (array.data_type(), name) {
(DataType::Map(_, _), ScalarValue::Utf8(Some(k))) => {
(DataType::Map(_, _), name) => {
let map_array = as_map_array(array.as_ref())?;
let key_scalar: Scalar<arrow::array::GenericByteArray<arrow::datatypes::GenericStringType<i32>>> = Scalar::new(StringArray::from(vec![k.clone()]));
if !matches!(name, ScalarValue::Utf8(_) | ScalarValue::Int64(_) | ScalarValue::Float64(_)) {
return exec_err!(
"get indexed field is only possible on map with utf8, int64 and float64 indexes. \
Tried with {name:?} index")
}
let key_array: ArrayRef = match name {
dharanad marked this conversation as resolved.
Show resolved Hide resolved
ScalarValue::Int64(Some(k)) => {
Arc::new(Int64Array::from(vec![*k]))
}
ScalarValue::Utf8(Some(k)) => {
Arc::new(StringArray::from(vec![k.clone()]))
}
ScalarValue::Float64(Some(k)) => {
Arc::new(Float64Array::from(vec![*k]))
}
_ => {
unreachable!();
dharanad marked this conversation as resolved.
Show resolved Hide resolved
}
};

let key_scalar = Scalar::new(key_array);
let keys = arrow::compute::kernels::cmp::eq(&key_scalar, map_array.keys())?;

// note that this array has more entries than the expected output/input size
Expand All @@ -195,16 +215,16 @@ impl ScalarUDFImpl for GetFieldFunc {
let capacity = Capacities::Array(original_data.len());
let mut mutable =
MutableArrayData::with_capacities(vec![&original_data], true,
capacity);
capacity);

for entry in 0..map_array.len(){
let start = map_array.value_offsets()[entry] as usize;
let end = map_array.value_offsets()[entry + 1] as usize;

let maybe_matched =
keys.slice(start, end-start).
iter().enumerate().
find(|(_, t)| t.unwrap());
keys.slice(start, end-start).
iter().enumerate().
find(|(_, t)| t.unwrap());
if maybe_matched.is_none(){
mutable.extend_nulls(1);
continue
Expand Down
6 changes: 6 additions & 0 deletions datafusion/functions/src/core/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,12 @@ pub mod expr_fn {
pub fn get_field(arg1: Expr, arg2: impl Literal) -> Expr {
    // Look up the `get_field` UDF exposed by the parent module.
    let udf = super::get_field();
    // The second argument is turned into a literal expression via `.lit()`
    // before being handed to the UDF call.
    let args = vec![arg1, arg2.lit()];
    udf.call(args)
}

/// Builds a call to the `get_field` function from two arbitrary expressions.
///
/// Both arguments are passed to the UDF unchanged, so the second argument may
/// be any [`Expr`] rather than only a literal.
///
/// **Internal use only.** This function was added to support the map access
/// use case.
pub fn _get_field(arg1: Expr, arg2: Expr) -> Expr {
    let udf = super::get_field();
    let args = vec![arg1, arg2];
    udf.call(args)
}
dharanad marked this conversation as resolved.
Show resolved Hide resolved
}

/// Returns all DataFusion functions defined in this package
Expand Down
27 changes: 16 additions & 11 deletions datafusion/sqllogictest/test_files/map.slt
Original file line number Diff line number Diff line change
Expand Up @@ -482,23 +482,28 @@ SELECT MAP { 'a': 1, 2: 3 };
----
{a: 1, 2: 3}

# TODO(https://github.com/apache/datafusion/issues/11785): fix accessing map with non-string key
# query ?
# SELECT MAP { 1: 'a', 2: 'b', 3: 'c' }[1];
# ----
# a

query T
SELECT MAP { 1: 'a', 2: 'b', 3: 'c' }[1];
----
a

# TODO(https://github.com/apache/datafusion/issues/11785): fix accessing map with non-string key
# query ?
# SELECT MAP { MAP {1:'a', 2:'b'}:1, MAP {1:'c', 2:'d'}:2 }[MAP {1:'a', 2:'b'}];
# ----
# 1

# TODO(https://github.com/apache/datafusion/issues/11785): fix accessing map with non-string key
# query ?
# SELECT MAKE_MAP(1, null, 2, 33, 3, null)[2];
# ----
# 33
query I
SELECT MAKE_MAP(1, null, 2, 33, 3, null)[2];
----
33

query T
SELECT MAP { 1.0: 'a', 2.0: 'b', 3.0: 'c' }[2.0];
----
b

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what if SELECT MAP { 1.0: 'a', 2.0: 'b', 3.0: 'c' }[2];?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The current implementation will throw an error. I just checked DuckDB: it tries to cast the index to the key's data type.


## cardinality

Expand Down Expand Up @@ -572,4 +577,4 @@ statement ok
drop table map_array_table_1;

statement ok
drop table map_array_table_2;
drop table map_array_table_2;