-
Notifications
You must be signed in to change notification settings - Fork 2.2k
fix: InList Dictionary filter pushdown type mismatch #20962
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 3 commits
2bb9617
91e1e59
1a4764c
4fcc823
59f50cb
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -99,11 +99,16 @@ impl StaticFilter for ArrayStaticFilter { | |
| )); | ||
| } | ||
|
|
||
| // Unwrap dictionary-encoded needles when the value type matches | ||
| // in_array, evaluating against distinct values and mapping back | ||
| // via keys. | ||
| downcast_dictionary_array! { | ||
| v => { | ||
| let values_contains = self.contains(v.values().as_ref(), negated)?; | ||
| let result = take(&values_contains, v.keys(), None)?; | ||
| return Ok(downcast_array(result.as_ref())) | ||
| if v.values().data_type() == self.in_array.data_type() { | ||
|
alamb marked this conversation as resolved.
|
||
| let values_contains = self.contains(v.values().as_ref(), negated)?; | ||
| let result = take(&values_contains, v.keys(), None)?; | ||
| return Ok(downcast_array(result.as_ref())); | ||
| } | ||
|
alamb marked this conversation as resolved.
|
||
| } | ||
| _ => {} | ||
| } | ||
|
|
@@ -3878,4 +3883,221 @@ mod tests { | |
| ); | ||
| Ok(()) | ||
| } | ||
|
|
||
| // ----------------------------------------------------------------------- | ||
| // Tests for try_new_from_array covering all (in_array, needle) type | ||
| // combinations that occur in HashJoin dynamic filter pushdown. | ||
| // | ||
| // try_new (used by SQL IN expressions) always produces a non-Dictionary | ||
| // in_array because evaluate_list() flattens Dictionary scalars to their | ||
| // value type. try_new_from_array passes the array directly, so it is | ||
| // the only path that can produce a Dictionary in_array. | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am a little confused about this comment -- it implies that the Dictionary is on the in_list side, but the code fix above is handling a Dictionary needle ( I think it may be worthwhile to consider changing the code so that it doesn't support a Dictionary in in_array -- but rather normalizes the haystack (we can do this as a follow on PR) |
||
| // ----------------------------------------------------------------------- | ||
|
|
||
| fn wrap_in_dict(array: ArrayRef) -> ArrayRef { | ||
| let keys = Int32Array::from((0..array.len() as i32).collect::<Vec<_>>()); | ||
| Arc::new(DictionaryArray::new(keys, array)) | ||
| } | ||
|
|
||
| fn eval_in_list_from_array( | ||
|
alamb marked this conversation as resolved.
|
||
| needle_type: DataType, | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. if you need the type of |
||
| needle: ArrayRef, | ||
| in_array: ArrayRef, | ||
| ) -> Result<BooleanArray> { | ||
| let schema = Schema::new(vec![Field::new("a", needle_type, false)]); | ||
| let col_a = col("a", &schema)?; | ||
| let expr = Arc::new(InListExpr::try_new_from_array(col_a, in_array, false)?) | ||
| as Arc<dyn PhysicalExpr>; | ||
| let batch = RecordBatch::try_new(Arc::new(schema), vec![needle])?; | ||
| let result = expr.evaluate(&batch)?.into_array(batch.num_rows())?; | ||
| Ok(as_boolean_array(&result).clone()) | ||
| } | ||
|
|
||
| #[test] | ||
| fn test_in_list_from_array_type_combinations() -> Result<()> { | ||
| use arrow::compute::cast; | ||
|
|
||
| // All cases: needle[0] and needle[2] match, needle[1] does not. | ||
| let expected = BooleanArray::from(vec![Some(true), Some(false), Some(true)]); | ||
|
|
||
| // Base arrays cast to each target type | ||
| let base_in = Arc::new(Int64Array::from(vec![1i64, 2, 3])) as ArrayRef; | ||
| let base_needle = Arc::new(Int64Array::from(vec![1i64, 4, 2])) as ArrayRef; | ||
|
|
||
| // Test all specializations in instantiate_static_filter | ||
| let primitive_types = vec![ | ||
| DataType::Int8, | ||
| DataType::Int16, | ||
| DataType::Int32, | ||
| DataType::Int64, | ||
| DataType::UInt8, | ||
| DataType::UInt16, | ||
| DataType::UInt32, | ||
| DataType::UInt64, | ||
| DataType::Float32, | ||
| DataType::Float64, | ||
| ]; | ||
|
|
||
| for dt in &primitive_types { | ||
| let in_array = cast(&base_in, dt)?; | ||
| let needle = cast(&base_needle, dt)?; | ||
|
|
||
| // T in_array, T needle | ||
| assert_eq!( | ||
| expected, | ||
| eval_in_list_from_array( | ||
| dt.clone(), | ||
| Arc::clone(&needle), | ||
| Arc::clone(&in_array), | ||
| )?, | ||
| "same-type failed for {dt:?}" | ||
| ); | ||
|
|
||
| // T in_array, Dict(Int32, T) needle | ||
| let dict_dt = | ||
| DataType::Dictionary(Box::new(DataType::Int32), Box::new(dt.clone())); | ||
| assert_eq!( | ||
| expected, | ||
| eval_in_list_from_array(dict_dt, wrap_in_dict(needle), in_array,)?, | ||
| "dict-needle failed for {dt:?}" | ||
| ); | ||
| } | ||
|
|
||
| // Utf8 (falls through to ArrayStaticFilter) | ||
| let utf8_in = Arc::new(StringArray::from(vec!["a", "b", "c"])) as ArrayRef; | ||
| let utf8_needle = Arc::new(StringArray::from(vec!["a", "d", "b"])) as ArrayRef; | ||
| let dict_utf8 = | ||
| DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)); | ||
|
|
||
| // Utf8 in_array, Utf8 needle | ||
| assert_eq!( | ||
| expected, | ||
| eval_in_list_from_array( | ||
| DataType::Utf8, | ||
| Arc::clone(&utf8_needle), | ||
| Arc::clone(&utf8_in), | ||
| )? | ||
| ); | ||
|
|
||
| // Utf8 in_array, Dict(Utf8) needle | ||
| assert_eq!( | ||
| expected, | ||
| eval_in_list_from_array( | ||
| dict_utf8.clone(), | ||
| wrap_in_dict(Arc::clone(&utf8_needle)), | ||
| Arc::clone(&utf8_in), | ||
| )? | ||
| ); | ||
|
|
||
| // Dict(Utf8) in_array, Dict(Utf8) needle: the #20937 bug | ||
| assert_eq!( | ||
| expected, | ||
| eval_in_list_from_array( | ||
| dict_utf8, | ||
| wrap_in_dict(Arc::clone(&utf8_needle)), | ||
| wrap_in_dict(Arc::clone(&utf8_in)), | ||
| )? | ||
| ); | ||
|
|
||
| // Struct in_array, Struct needle: multi-column join | ||
| let struct_fields = Fields::from(vec![ | ||
| Field::new("c0", DataType::Utf8, true), | ||
| Field::new("c1", DataType::Int64, true), | ||
| ]); | ||
| let struct_type = DataType::Struct(struct_fields.clone()); | ||
| let make_struct = |c0: ArrayRef, c1: ArrayRef| -> ArrayRef { | ||
| let pairs: Vec<(FieldRef, ArrayRef)> = | ||
| struct_fields.iter().cloned().zip([c0, c1]).collect(); | ||
| Arc::new(StructArray::from(pairs)) | ||
| }; | ||
| assert_eq!( | ||
| expected, | ||
| eval_in_list_from_array( | ||
| struct_type, | ||
| make_struct( | ||
| Arc::clone(&utf8_needle), | ||
| Arc::new(Int64Array::from(vec![1, 4, 2])), | ||
| ), | ||
| make_struct( | ||
| Arc::clone(&utf8_in), | ||
| Arc::new(Int64Array::from(vec![1, 2, 3])), | ||
| ), | ||
| )? | ||
| ); | ||
|
|
||
| // Struct with Dict fields: multi-column Dict join | ||
| let dict_struct_fields = Fields::from(vec![ | ||
| Field::new( | ||
| "c0", | ||
| DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), | ||
| true, | ||
| ), | ||
| Field::new("c1", DataType::Int64, true), | ||
| ]); | ||
| let dict_struct_type = DataType::Struct(dict_struct_fields.clone()); | ||
| let make_dict_struct = |c0: ArrayRef, c1: ArrayRef| -> ArrayRef { | ||
| let pairs: Vec<(FieldRef, ArrayRef)> = | ||
| dict_struct_fields.iter().cloned().zip([c0, c1]).collect(); | ||
| Arc::new(StructArray::from(pairs)) | ||
| }; | ||
| assert_eq!( | ||
| expected, | ||
| eval_in_list_from_array( | ||
| dict_struct_type, | ||
| make_dict_struct( | ||
| wrap_in_dict(Arc::clone(&utf8_needle)), | ||
| Arc::new(Int64Array::from(vec![1, 4, 2])), | ||
| ), | ||
| make_dict_struct( | ||
| wrap_in_dict(Arc::clone(&utf8_in)), | ||
| Arc::new(Int64Array::from(vec![1, 2, 3])), | ||
| ), | ||
| )? | ||
| ); | ||
|
|
||
| Ok(()) | ||
| } | ||
|
|
||
| #[test] | ||
| fn test_in_list_from_array_type_mismatch_errors() -> Result<()> { | ||
| // Utf8 needle, Dict(Utf8) in_array | ||
| let err = eval_in_list_from_array( | ||
| DataType::Utf8, | ||
| Arc::new(StringArray::from(vec!["a", "d", "b"])), | ||
| wrap_in_dict(Arc::new(StringArray::from(vec!["a", "b", "c"]))), | ||
| ) | ||
| .unwrap_err() | ||
| .to_string(); | ||
| assert!( | ||
| err.contains("Can't compare arrays of different types"), | ||
| "{err}" | ||
| ); | ||
|
|
||
| // Dict(Utf8) needle, Int64 in_array: specialized Int64StaticFilter | ||
| // rejects the Utf8 dictionary values at construction time | ||
| let err = eval_in_list_from_array( | ||
| DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), | ||
| wrap_in_dict(Arc::new(StringArray::from(vec!["a", "d", "b"]))), | ||
| Arc::new(Int64Array::from(vec![1, 2, 3])), | ||
| ) | ||
| .unwrap_err() | ||
| .to_string(); | ||
| assert!(err.contains("Failed to downcast"), "{err}"); | ||
|
|
||
| // Dict(Int64) needle, Dict(Utf8) in_array: both Dict but different | ||
| // value types, make_comparator rejects the comparison | ||
| let err = eval_in_list_from_array( | ||
| DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Int64)), | ||
| wrap_in_dict(Arc::new(Int64Array::from(vec![1, 4, 2]))), | ||
| wrap_in_dict(Arc::new(StringArray::from(vec!["a", "b", "c"]))), | ||
| ) | ||
| .unwrap_err() | ||
| .to_string(); | ||
| assert!( | ||
| err.contains("Can't compare arrays of different types"), | ||
| "{err}" | ||
| ); | ||
|
|
||
| Ok(()) | ||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -918,13 +918,18 @@ CREATE EXTERNAL TABLE dict_filter_bug | |
| STORED AS PARQUET | ||
| LOCATION 'test_files/scratch/parquet_filter_pushdown/dict_filter_bug.parquet'; | ||
|
|
||
| query error Can't compare arrays of different types | ||
| query TR | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 👍 |
||
| SELECT t.tag1, t.value | ||
| FROM dict_filter_bug t | ||
| JOIN (VALUES ('A'), ('B')) AS v(c1) | ||
| ON t.tag1 = v.c1 | ||
| ORDER BY t.tag1, t.value | ||
| LIMIT 4; | ||
| ---- | ||
| A 0 | ||
| A 26 | ||
| A 52 | ||
| A 78 | ||
|
|
||
| # Cleanup | ||
| statement ok | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.