-
Notifications
You must be signed in to change notification settings - Fork 1.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Access a Map
with Primitive type keys
#12259
base: main
Are you sure you want to change the base?
Changes from all commits
72119dd
ebf663e
95b1868
afd7f11
4f850c6
c32fedb
0b342a4
a05143a
494c384
1b19eb9
41b2b59
568223e
28a7dff
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,18 +15,19 @@ | |
// specific language governing permissions and limitations | ||
// under the License. | ||
|
||
use arrow::array::{ | ||
make_array, Array, Capacities, MutableArrayData, Scalar, StringArray, | ||
}; | ||
use std::any::Any; | ||
use std::sync::Arc; | ||
|
||
use arrow::array::{make_array, Array, ArrayRef, Capacities, MutableArrayData, Scalar}; | ||
use arrow::datatypes::DataType; | ||
|
||
use datafusion_common::cast::{as_map_array, as_struct_array}; | ||
use datafusion_common::format::DEFAULT_CAST_OPTIONS; | ||
use datafusion_common::{ | ||
exec_err, plan_datafusion_err, plan_err, ExprSchema, Result, ScalarValue, | ||
}; | ||
use datafusion_expr::{ColumnarValue, Expr, ExprSchemable}; | ||
use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; | ||
use std::any::Any; | ||
use std::sync::Arc; | ||
|
||
#[derive(Debug)] | ||
pub struct GetFieldFunc { | ||
|
@@ -47,7 +48,6 @@ impl GetFieldFunc { | |
} | ||
} | ||
|
||
// get_field(struct_array, field_name) | ||
impl ScalarUDFImpl for GetFieldFunc { | ||
fn as_any(&self) -> &dyn Any { | ||
self | ||
|
@@ -184,9 +184,25 @@ impl ScalarUDFImpl for GetFieldFunc { | |
}; | ||
|
||
match (array.data_type(), name) { | ||
(DataType::Map(_, _), ScalarValue::Utf8(Some(k))) => { | ||
(DataType::Map(_, _), name) => { | ||
let map_array = as_map_array(array.as_ref())?; | ||
let key_scalar: Scalar<arrow::array::GenericByteArray<arrow::datatypes::GenericStringType<i32>>> = Scalar::new(StringArray::from(vec![k.clone()])); | ||
if !matches!(name, ScalarValue::Utf8(_) | ScalarValue::Int64(_) | ScalarValue::Float64(_)) { | ||
return exec_err!( | ||
"get indexed field is only possible on map with utf8, int64 and float64 indexes. \ | ||
Tried with {name:?} index") | ||
} | ||
|
||
let mut key_array: ArrayRef = name.to_array()?; | ||
if key_array.data_type() != map_array.key_type() { | ||
let pre_cast_dt = key_array.data_type().clone(); | ||
if arrow::compute::kernels::cast::can_cast_types(key_array.data_type(), map_array.key_type()) { | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Casting and type checking should be handled in the Signature, not |
||
key_array = arrow::compute::kernels::cast::cast_with_options(&key_array, map_array.key_type(), &DEFAULT_CAST_OPTIONS)?; | ||
} | ||
if key_array.null_count() > 0{ | ||
return exec_err!("Could not convert {} {} to {}", pre_cast_dt, name, map_array.key_type()) | ||
} | ||
} | ||
let key_scalar = Scalar::new(key_array); | ||
let keys = arrow::compute::kernels::cmp::eq(&key_scalar, map_array.keys())?; | ||
|
||
// note that this array has more entries than the expected output/input size | ||
|
@@ -195,16 +211,16 @@ impl ScalarUDFImpl for GetFieldFunc { | |
let capacity = Capacities::Array(original_data.len()); | ||
let mut mutable = | ||
MutableArrayData::with_capacities(vec![&original_data], true, | ||
capacity); | ||
capacity); | ||
|
||
for entry in 0..map_array.len(){ | ||
let start = map_array.value_offsets()[entry] as usize; | ||
let end = map_array.value_offsets()[entry + 1] as usize; | ||
|
||
let maybe_matched = | ||
keys.slice(start, end-start). | ||
iter().enumerate(). | ||
find(|(_, t)| t.unwrap()); | ||
keys.slice(start, end-start). | ||
iter().enumerate(). | ||
find(|(_, t)| t.unwrap()); | ||
if maybe_matched.is_none(){ | ||
mutable.extend_nulls(1); | ||
continue | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -482,23 +482,55 @@ SELECT MAP { 'a': 1, 2: 3 }; | |
---- | ||
{a: 1, 2: 3} | ||
|
||
# TODO(https://github.com/apache/datafusion/issues/11785): fix accessing map with non-string key | ||
# query ? | ||
# SELECT MAP { 1: 'a', 2: 'b', 3: 'c' }[1]; | ||
# ---- | ||
# a | ||
|
||
query T | ||
SELECT MAP { 1: 'a', 2: 'b', 3: 'c' }[1]; | ||
---- | ||
a | ||
|
||
# TODO(https://github.com/apache/datafusion/issues/11785): fix accessing map with non-string key | ||
# query ? | ||
# SELECT MAP { MAP {1:'a', 2:'b'}:1, MAP {1:'c', 2:'d'}:2 }[MAP {1:'a', 2:'b'}]; | ||
# ---- | ||
# 1 | ||
|
||
# TODO(https://github.com/apache/datafusion/issues/11785): fix accessing map with non-string key | ||
# query ? | ||
# SELECT MAKE_MAP(1, null, 2, 33, 3, null)[2]; | ||
# ---- | ||
# 33 | ||
query I | ||
SELECT MAKE_MAP(1, null, 2, 33, 3, null)[2]; | ||
---- | ||
33 | ||
|
||
query T | ||
SELECT MAP { 1.0: 'a', 2.0: 'b', 3.0: 'c' }[2.0]; | ||
---- | ||
b | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. what if There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The current implementation will throw an error. I just checked DuckDB: it tries to cast the index to the key's data type. |
||
query T | ||
SELECT MAP { 1.0: 'a', 2.0: 'b', 3.0: 'c' }['2.0']; | ||
---- | ||
b | ||
|
||
query T | ||
SELECT MAP { 1.0: 'a', 2.0: 'b', 3.0: 'c' }[2]; | ||
---- | ||
b | ||
|
||
query T | ||
SELECT MAP { 1.0: 'a', 2.0: 'b', 3.0: 'c' }['2']; | ||
---- | ||
b | ||
|
||
query T | ||
SELECT MAP { 1: 'a', 2: 'b', 3: 'c' }[3]; | ||
---- | ||
c | ||
|
||
query T | ||
SELECT MAP { 1: 'a', 2: 'b', 3: 'c' }['3']; | ||
---- | ||
c | ||
|
||
query error DataFusion error: Execution error: Could not convert Utf8 3\.0 to Int64 | ||
SELECT MAP { 1: 'a', 2: 'b', 3: 'c' }['3.0']; | ||
|
||
## cardinality | ||
|
||
|
@@ -572,4 +604,4 @@ statement ok | |
drop table map_array_table_1; | ||
|
||
statement ok | ||
drop table map_array_table_2; | ||
drop table map_array_table_2; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could we use
map_extract
here? There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
map_extract
returns a list containing the value. Maybe we can wrap it around the get_field
call and extract the value. But I am not sure; let me try it out.