Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Access a Map with Primitive type keys #12259

Draft
wants to merge 13 commits into
base: main
Choose a base branch
from
Draft
14 changes: 14 additions & 0 deletions datafusion/functions-nested/src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,16 @@ impl ExprPlanner for FieldAccessPlanner {
// expr[idx] ==> array_element(expr, idx)
GetFieldAccess::ListIndex { key: index } => {
match expr {
// Special case for accessing map value with non-string values
Expr::ScalarFunction(scalar_func) if is_map(&scalar_func) => {
let Expr::Literal(name) = *index else {
return plan_err!("index should be a literal");
};
Ok(PlannerResult::Planned(get_field(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we use map_extract here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

map_extract returns a list containing the value. Maybe it can be wrapped around a get_field call to extract the value, but I am not sure; let me try it out.

Expr::ScalarFunction(scalar_func),
name,
)))
}
// Special case for array_agg(expr)[index] to NTH_VALUE(expr, index)
Expr::AggregateFunction(agg_func) if is_array_agg(&agg_func) => {
Ok(PlannerResult::Planned(Expr::AggregateFunction(
Expand Down Expand Up @@ -186,3 +196,7 @@ impl ExprPlanner for FieldAccessPlanner {
/// Reports whether `agg_func` is the `array_agg` aggregate.
///
/// The check is a plain string comparison of the function's registered name
/// against `"array_agg"`.
fn is_array_agg(agg_func: &datafusion_expr::expr::AggregateFunction) -> bool {
    let func_name = agg_func.func.name();
    func_name == "array_agg"
}

/// Reports whether `scalar_func` is the `map` scalar function.
///
/// The check is a plain string comparison of the function's registered name
/// against `"map"`.
fn is_map(scalar_func: &datafusion_expr::expr::ScalarFunction) -> bool {
    let func_name = scalar_func.func.name();
    func_name == "map"
}
40 changes: 28 additions & 12 deletions datafusion/functions/src/core/getfield.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,19 @@
// specific language governing permissions and limitations
// under the License.

use arrow::array::{
make_array, Array, Capacities, MutableArrayData, Scalar, StringArray,
};
use std::any::Any;
use std::sync::Arc;

use arrow::array::{make_array, Array, ArrayRef, Capacities, MutableArrayData, Scalar};
use arrow::datatypes::DataType;

use datafusion_common::cast::{as_map_array, as_struct_array};
use datafusion_common::format::DEFAULT_CAST_OPTIONS;
use datafusion_common::{
exec_err, plan_datafusion_err, plan_err, ExprSchema, Result, ScalarValue,
};
use datafusion_expr::{ColumnarValue, Expr, ExprSchemable};
use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
use std::any::Any;
use std::sync::Arc;

#[derive(Debug)]
pub struct GetFieldFunc {
Expand All @@ -47,7 +48,6 @@ impl GetFieldFunc {
}
}

// get_field(struct_array, field_name)
impl ScalarUDFImpl for GetFieldFunc {
fn as_any(&self) -> &dyn Any {
self
Expand Down Expand Up @@ -184,9 +184,25 @@ impl ScalarUDFImpl for GetFieldFunc {
};

match (array.data_type(), name) {
(DataType::Map(_, _), ScalarValue::Utf8(Some(k))) => {
(DataType::Map(_, _), name) => {
let map_array = as_map_array(array.as_ref())?;
let key_scalar: Scalar<arrow::array::GenericByteArray<arrow::datatypes::GenericStringType<i32>>> = Scalar::new(StringArray::from(vec![k.clone()]));
if !matches!(name, ScalarValue::Utf8(_) | ScalarValue::Int64(_) | ScalarValue::Float64(_)) {
return exec_err!(
"get indexed field is only possible on map with utf8, int64 and float64 indexes. \
Tried with {name:?} index")
}

let mut key_array: ArrayRef = name.to_array()?;
if key_array.data_type() != map_array.key_type() {
let pre_cast_dt = key_array.data_type().clone();
if arrow::compute::kernels::cast::can_cast_types(key_array.data_type(), map_array.key_type()) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Casting and type check should be handled in Signature not invoke.

key_array = arrow::compute::kernels::cast::cast_with_options(&key_array, map_array.key_type(), &DEFAULT_CAST_OPTIONS)?;
}
if key_array.null_count() > 0{
return exec_err!("Could not convert {} {} to {}", pre_cast_dt, name, map_array.key_type())
}
}
let key_scalar = Scalar::new(key_array);
let keys = arrow::compute::kernels::cmp::eq(&key_scalar, map_array.keys())?;

// note that this array has more entries than the expected output/input size
Expand All @@ -195,16 +211,16 @@ impl ScalarUDFImpl for GetFieldFunc {
let capacity = Capacities::Array(original_data.len());
let mut mutable =
MutableArrayData::with_capacities(vec![&original_data], true,
capacity);
capacity);

for entry in 0..map_array.len(){
let start = map_array.value_offsets()[entry] as usize;
let end = map_array.value_offsets()[entry + 1] as usize;

let maybe_matched =
keys.slice(start, end-start).
iter().enumerate().
find(|(_, t)| t.unwrap());
keys.slice(start, end-start).
iter().enumerate().
find(|(_, t)| t.unwrap());
if maybe_matched.is_none(){
mutable.extend_nulls(1);
continue
Expand Down
54 changes: 43 additions & 11 deletions datafusion/sqllogictest/test_files/map.slt
Original file line number Diff line number Diff line change
Expand Up @@ -482,23 +482,55 @@ SELECT MAP { 'a': 1, 2: 3 };
----
{a: 1, 2: 3}

# TODO(https://github.com/apache/datafusion/issues/11785): fix accessing map with non-string key
# query ?
# SELECT MAP { 1: 'a', 2: 'b', 3: 'c' }[1];
# ----
# a

query T
SELECT MAP { 1: 'a', 2: 'b', 3: 'c' }[1];
----
a

# TODO(https://github.com/apache/datafusion/issues/11785): fix accessing map with non-string key
# query ?
# SELECT MAP { MAP {1:'a', 2:'b'}:1, MAP {1:'c', 2:'d'}:2 }[MAP {1:'a', 2:'b'}];
# ----
# 1

# TODO(https://github.com/apache/datafusion/issues/11785): fix accessing map with non-string key
# query ?
# SELECT MAKE_MAP(1, null, 2, 33, 3, null)[2];
# ----
# 33
query I
SELECT MAKE_MAP(1, null, 2, 33, 3, null)[2];
----
33

query T
SELECT MAP { 1.0: 'a', 2.0: 'b', 3.0: 'c' }[2.0];
----
b

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what if SELECT MAP { 1.0: 'a', 2.0: 'b', 3.0: 'c' }[2];?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The current implementation will throw an error. I just checked DuckDB: it tries to cast the index to the key's data type.

query T
SELECT MAP { 1.0: 'a', 2.0: 'b', 3.0: 'c' }['2.0'];
----
b

query T
SELECT MAP { 1.0: 'a', 2.0: 'b', 3.0: 'c' }[2];
----
b

query T
SELECT MAP { 1.0: 'a', 2.0: 'b', 3.0: 'c' }['2'];
----
b

query T
SELECT MAP { 1: 'a', 2: 'b', 3: 'c' }[3];
----
c

query T
SELECT MAP { 1: 'a', 2: 'b', 3: 'c' }['3'];
----
c

query error DataFusion error: Execution error: Could not convert Utf8 3\.0 to Int64
SELECT MAP { 1: 'a', 2: 'b', 3: 'c' }['3.0'];

## cardinality

Expand Down Expand Up @@ -572,4 +604,4 @@ statement ok
drop table map_array_table_1;

statement ok
drop table map_array_table_2;
drop table map_array_table_2;
Loading