Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
228 changes: 225 additions & 3 deletions datafusion/physical-expr/src/expressions/in_list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,11 +99,16 @@ impl StaticFilter for ArrayStaticFilter {
));
}

// Unwrap dictionary-encoded needles when the value type matches
// in_array, evaluating against distinct values and mapping back
// via keys.
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// Unwrap dictionary-encoded needles when the value type matches
// in_array, evaluating against distinct values and mapping back
// via keys.
// Unwrap dictionary-encoded needles when the value type matches
// in_array, evaluating against the dictionary values and mapping back
// via keys.

downcast_dictionary_array! {
v => {
let values_contains = self.contains(v.values().as_ref(), negated)?;
let result = take(&values_contains, v.keys(), None)?;
return Ok(downcast_array(result.as_ref()))
if v.values().data_type() == self.in_array.data_type() {
Comment thread
alamb marked this conversation as resolved.
let values_contains = self.contains(v.values().as_ref(), negated)?;
let result = take(&values_contains, v.keys(), None)?;
return Ok(downcast_array(result.as_ref()));
}
Comment thread
alamb marked this conversation as resolved.
}
_ => {}
}
Expand Down Expand Up @@ -3878,4 +3883,221 @@ mod tests {
);
Ok(())
}

// -----------------------------------------------------------------------
// Tests for try_new_from_array covering all (in_array, needle) type
// combinations that occur in HashJoin dynamic filter pushdown.
//
// try_new (used by SQL IN expressions) always produces a non-Dictionary
// in_array because evaluate_list() flattens Dictionary scalars to their
// value type. try_new_from_array passes the array directly, so it is
// the only path that can produce a Dictionary in_array.
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am a little confused about this comment -- it implies that the Dictionary is on the in_list side, but the code fix above is handling a Dictionary needle (v) and a non Dictionary Haystack 🤔

I think it may be worthwhile to consider changing the code so that it doesn't support a Dictionary in in_array -- but rather normalizes the haystack (we can do this as a follow on PR)

// -----------------------------------------------------------------------

fn wrap_in_dict(array: ArrayRef) -> ArrayRef {
let keys = Int32Array::from((0..array.len() as i32).collect::<Vec<_>>());
Arc::new(DictionaryArray::new(keys, array))
}

fn eval_in_list_from_array(
Comment thread
alamb marked this conversation as resolved.
needle_type: DataType,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if you need the type of needle can't w just call needle.data_type and that way be sure it is consistent with the needle array? That would make the tests less fragile, likely shorter, and easier to validate they are doing the right thing

needle: ArrayRef,
in_array: ArrayRef,
) -> Result<BooleanArray> {
let schema = Schema::new(vec![Field::new("a", needle_type, false)]);
let col_a = col("a", &schema)?;
let expr = Arc::new(InListExpr::try_new_from_array(col_a, in_array, false)?)
as Arc<dyn PhysicalExpr>;
let batch = RecordBatch::try_new(Arc::new(schema), vec![needle])?;
let result = expr.evaluate(&batch)?.into_array(batch.num_rows())?;
Ok(as_boolean_array(&result).clone())
}

#[test]
fn test_in_list_from_array_type_combinations() -> Result<()> {
use arrow::compute::cast;

// All cases: needle[0] and needle[2] match, needle[1] does not.
let expected = BooleanArray::from(vec![Some(true), Some(false), Some(true)]);

// Base arrays cast to each target type
let base_in = Arc::new(Int64Array::from(vec![1i64, 2, 3])) as ArrayRef;
let base_needle = Arc::new(Int64Array::from(vec![1i64, 4, 2])) as ArrayRef;

// Test all specializations in instantiate_static_filter
let primitive_types = vec![
DataType::Int8,
DataType::Int16,
DataType::Int32,
DataType::Int64,
DataType::UInt8,
DataType::UInt16,
DataType::UInt32,
DataType::UInt64,
DataType::Float32,
DataType::Float64,
];

for dt in &primitive_types {
let in_array = cast(&base_in, dt)?;
let needle = cast(&base_needle, dt)?;

// T in_array, T needle
assert_eq!(
expected,
eval_in_list_from_array(
dt.clone(),
Arc::clone(&needle),
Arc::clone(&in_array),
)?,
"same-type failed for {dt:?}"
);

// T in_array, Dict(Int32, T) needle
let dict_dt =
DataType::Dictionary(Box::new(DataType::Int32), Box::new(dt.clone()));
assert_eq!(
expected,
eval_in_list_from_array(dict_dt, wrap_in_dict(needle), in_array,)?,
"dict-needle failed for {dt:?}"
);
}

// Utf8 (falls through to ArrayStaticFilter)
let utf8_in = Arc::new(StringArray::from(vec!["a", "b", "c"])) as ArrayRef;
let utf8_needle = Arc::new(StringArray::from(vec!["a", "d", "b"])) as ArrayRef;
let dict_utf8 =
DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8));

// Utf8 in_array, Utf8 needle
assert_eq!(
expected,
eval_in_list_from_array(
DataType::Utf8,
Arc::clone(&utf8_needle),
Arc::clone(&utf8_in),
)?
);

// Utf8 in_array, Dict(Utf8) needle
assert_eq!(
expected,
eval_in_list_from_array(
dict_utf8.clone(),
wrap_in_dict(Arc::clone(&utf8_needle)),
Arc::clone(&utf8_in),
)?
);

// Dict(Utf8) in_array, Dict(Utf8) needle: the #20937 bug
assert_eq!(
expected,
eval_in_list_from_array(
dict_utf8,
wrap_in_dict(Arc::clone(&utf8_needle)),
wrap_in_dict(Arc::clone(&utf8_in)),
)?
);

// Struct in_array, Struct needle: multi-column join
let struct_fields = Fields::from(vec![
Field::new("c0", DataType::Utf8, true),
Field::new("c1", DataType::Int64, true),
]);
let struct_type = DataType::Struct(struct_fields.clone());
let make_struct = |c0: ArrayRef, c1: ArrayRef| -> ArrayRef {
let pairs: Vec<(FieldRef, ArrayRef)> =
struct_fields.iter().cloned().zip([c0, c1]).collect();
Arc::new(StructArray::from(pairs))
};
assert_eq!(
expected,
eval_in_list_from_array(
struct_type,
make_struct(
Arc::clone(&utf8_needle),
Arc::new(Int64Array::from(vec![1, 4, 2])),
),
make_struct(
Arc::clone(&utf8_in),
Arc::new(Int64Array::from(vec![1, 2, 3])),
),
)?
);

// Struct with Dict fields: multi-column Dict join
let dict_struct_fields = Fields::from(vec![
Field::new(
"c0",
DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
true,
),
Field::new("c1", DataType::Int64, true),
]);
let dict_struct_type = DataType::Struct(dict_struct_fields.clone());
let make_dict_struct = |c0: ArrayRef, c1: ArrayRef| -> ArrayRef {
let pairs: Vec<(FieldRef, ArrayRef)> =
dict_struct_fields.iter().cloned().zip([c0, c1]).collect();
Arc::new(StructArray::from(pairs))
};
assert_eq!(
expected,
eval_in_list_from_array(
dict_struct_type,
make_dict_struct(
wrap_in_dict(Arc::clone(&utf8_needle)),
Arc::new(Int64Array::from(vec![1, 4, 2])),
),
make_dict_struct(
wrap_in_dict(Arc::clone(&utf8_in)),
Arc::new(Int64Array::from(vec![1, 2, 3])),
),
)?
);

Ok(())
}

#[test]
fn test_in_list_from_array_type_mismatch_errors() -> Result<()> {
// Utf8 needle, Dict(Utf8) in_array
let err = eval_in_list_from_array(
DataType::Utf8,
Arc::new(StringArray::from(vec!["a", "d", "b"])),
wrap_in_dict(Arc::new(StringArray::from(vec!["a", "b", "c"]))),
)
.unwrap_err()
.to_string();
assert!(
err.contains("Can't compare arrays of different types"),
"{err}"
);

// Dict(Utf8) needle, Int64 in_array: specialized Int64StaticFilter
// rejects the Utf8 dictionary values at construction time
let err = eval_in_list_from_array(
DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
wrap_in_dict(Arc::new(StringArray::from(vec!["a", "d", "b"]))),
Arc::new(Int64Array::from(vec![1, 2, 3])),
)
.unwrap_err()
.to_string();
assert!(err.contains("Failed to downcast"), "{err}");

// Dict(Int64) needle, Dict(Utf8) in_array: both Dict but different
// value types, make_comparator rejects the comparison
let err = eval_in_list_from_array(
DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Int64)),
wrap_in_dict(Arc::new(Int64Array::from(vec![1, 4, 2]))),
wrap_in_dict(Arc::new(StringArray::from(vec!["a", "b", "c"]))),
)
.unwrap_err()
.to_string();
assert!(
err.contains("Can't compare arrays of different types"),
"{err}"
);

Ok(())
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -918,13 +918,18 @@ CREATE EXTERNAL TABLE dict_filter_bug
STORED AS PARQUET
LOCATION 'test_files/scratch/parquet_filter_pushdown/dict_filter_bug.parquet';

query error Can't compare arrays of different types
query TR
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

SELECT t.tag1, t.value
FROM dict_filter_bug t
JOIN (VALUES ('A'), ('B')) AS v(c1)
ON t.tag1 = v.c1
ORDER BY t.tag1, t.value
LIMIT 4;
----
A 0
A 26
A 52
A 78

# Cleanup
statement ok
Expand Down
Loading