Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 57 additions & 3 deletions datafusion/src/physical_plan/hash_aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,10 @@ fn group_aggregate_batch(
// We can safely unwrap here as we checked we can create an accumulator before
let accumulator_set = create_accumulators(aggr_expr).unwrap();
batch_keys.push(key.clone());
let _ = create_group_by_values(&group_values, row, &mut group_by_values);
// Note it would be nice to make this a real error (rather than panic)
// but it is better than silently ignoring the issue and getting wrong results
create_group_by_values(&group_values, row, &mut group_by_values)
.expect("can not create group by value");
(
key.clone(),
(group_by_values.clone(), accumulator_set, vec![row as u32]),
Expand Down Expand Up @@ -508,7 +511,9 @@ fn dictionary_create_key_for_col<K: ArrowDictionaryKeyType>(
}

/// Appends a sequence of [u8] bytes for the value in `col[row]` to
/// `vec` to be used as a key into the hash map
/// `vec` to be used as a key into the hash map.
///
/// NOTE: This function does not check col.is_valid(). Caller must do so
fn create_key_for_col(col: &ArrayRef, row: usize, vec: &mut Vec<u8>) -> Result<()> {
match col.data_type() {
DataType::Boolean => {
Expand Down Expand Up @@ -640,14 +645,63 @@ fn create_key_for_col(col: &ArrayRef, row: usize, vec: &mut Vec<u8>) -> Result<(
}

/// Create a key `Vec<u8>` that is used as key for the hashmap
///
/// This looks like
/// [null_byte][col_value_bytes][null_byte][col_value_bytes]
///
/// Note that relatively uncommon patterns (e.g. not 0x00) are chosen
/// for the null_byte to make debugging easier. The actual values are
/// arbitrary.
///
/// For a NULL value in a column, the key looks like
/// [0xFE]
///
/// For a Non-NULL value in a column, this looks like:
/// [0xFF][byte representation of column value]
///
/// Example of a key with no NULL values:
/// ```text
/// 0xFF byte at the start of each column
/// signifies the value is non-null
/// │
///
/// ┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┴ ─ ─ ─ ─ ─ ─ ─ ┐
///
/// │ string len │ 0x1234
/// { ▼ (as usize le) "foo" ▼(as u16 le)
/// k1: "foo" ╔ ═┌──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──╦ ═┌──┬──┐
/// k2: 0x1234u16 FF║03│00│00│00│00│00│00│00│"f│"o│"o│FF║34│12│
/// } ╚ ═└──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──╩ ═└──┴──┘
/// 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14
/// ```
///
/// Example of a key with NULL values:
///
///```text
/// 0xFE byte at the start of k1 column
/// ┌ ─ signifies the value is NULL
///
/// └ ┐
/// 0x1234
/// { ▼ (as u16 le)
/// k1: NULL ╔ ═╔ ═┌──┬──┐
/// k2: 0x1234u16 FE║FF║12│34│
/// } ╚ ═╚ ═└──┴──┘
/// 0 1 2 3
///```
pub(crate) fn create_key(
group_by_keys: &[ArrayRef],
row: usize,
vec: &mut Vec<u8>,
) -> Result<()> {
vec.clear();
for col in group_by_keys {
create_key_for_col(col, row, vec)?
if !col.is_valid(row) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure if it makes sense to improve performance here, but an optimization might be to check on null-count==0 outside of this function to avoid the is_valid call and just always add an 0xFF

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the suggestion.

If you don't mind I would like to spend time on #790 which, if successful, I expect to significantly remove all this code.

I will attempt to add that optimization at a later date.

vec.push(0xFE);
} else {
vec.push(0xFF);
create_key_for_col(col, row, vec)?
}
}
Ok(())
}
Expand Down
37 changes: 35 additions & 2 deletions datafusion/src/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ use arrow::{
},
};
use ordered_float::OrderedFloat;
use std::convert::Infallible;
use std::convert::{Infallible, TryInto};
use std::str::FromStr;
use std::{convert::TryFrom, fmt, iter::repeat, sync::Arc};

Expand Down Expand Up @@ -796,6 +796,11 @@ impl ScalarValue {

/// Converts a value in `array` at `index` into a ScalarValue
pub fn try_from_array(array: &ArrayRef, index: usize) -> Result<Self> {
// handle NULL value
if !array.is_valid(index) {
return array.data_type().try_into();
}

Ok(match array.data_type() {
DataType::Boolean => typed_cast!(array, index, BooleanArray, Boolean),
DataType::Float64 => typed_cast!(array, index, Float64Array, Float64),
Expand Down Expand Up @@ -897,6 +902,7 @@ impl ScalarValue {
let dict_array = array.as_any().downcast_ref::<DictionaryArray<K>>().unwrap();

// look up the index in the values dictionary
// (note validity was previously checked in `try_from_array`)
let keys_col = dict_array.keys();
let values_index = keys_col.value(index).to_usize().ok_or_else(|| {
DataFusionError::Internal(format!(
Expand Down Expand Up @@ -1132,6 +1138,7 @@ impl_try_from!(Boolean, bool);
impl TryFrom<&DataType> for ScalarValue {
type Error = DataFusionError;

/// Create a Null instance of ScalarValue for this datatype
fn try_from(datatype: &DataType) -> Result<Self> {
Ok(match datatype {
DataType::Boolean => ScalarValue::Boolean(None),
Expand Down Expand Up @@ -1161,12 +1168,15 @@ impl TryFrom<&DataType> for ScalarValue {
DataType::Timestamp(TimeUnit::Nanosecond, _) => {
ScalarValue::TimestampNanosecond(None)
}
DataType::Dictionary(_index_type, value_type) => {
value_type.as_ref().try_into()?
}
DataType::List(ref nested_type) => {
ScalarValue::List(None, Box::new(nested_type.data_type().clone()))
}
_ => {
return Err(DataFusionError::NotImplemented(format!(
"Can't create a scalar of type \"{:?}\"",
"Can't create a scalar from data_type \"{:?}\"",
datatype
)))
}
Expand Down Expand Up @@ -1535,6 +1545,29 @@ mod tests {
"{}", result);
}

#[test]
fn scalar_try_from_array_null() {
let array = vec![Some(33), None].into_iter().collect::<Int64Array>();
let array: ArrayRef = Arc::new(array);

assert_eq!(
ScalarValue::Int64(Some(33)),
ScalarValue::try_from_array(&array, 0).unwrap()
);
assert_eq!(
ScalarValue::Int64(None),
ScalarValue::try_from_array(&array, 1).unwrap()
);
}

#[test]
fn scalar_try_from_dict_datatype() {
let data_type =
DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8));
let data_type = &data_type;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🥳

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Amusingly, supporting this behavior ended up causing a test to fail when I brought the code into IOx and I think I traced the problem to an issue in parquet file statistics: apache/arrow-rs#641 🤣 this was not a side effect I had anticipated

assert_eq!(ScalarValue::Utf8(None), data_type.try_into().unwrap())
}

#[test]
fn size_of_scalar() {
// Since ScalarValues are used in a non trivial number of places,
Expand Down
110 changes: 110 additions & 0 deletions datafusion/tests/sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3014,6 +3014,109 @@ async fn query_count_distinct() -> Result<()> {
Ok(())
}

#[tokio::test]
async fn query_group_on_null() -> Result<()> {
let schema = Arc::new(Schema::new(vec![Field::new("c1", DataType::Int32, true)]));

let data = RecordBatch::try_new(
schema.clone(),
vec![Arc::new(Int32Array::from(vec![
Some(0),
Some(3),
None,
Some(1),
Some(3),
]))],
)?;

let table = MemTable::try_new(schema, vec![vec![data]])?;

let mut ctx = ExecutionContext::new();
ctx.register_table("test", Arc::new(table))?;
let sql = "SELECT COUNT(*), c1 FROM test GROUP BY c1";

let actual = execute_to_batches(&mut ctx, sql).await;

// Note that the results also
// include a row for NULL (c1=NULL, count = 1)
let expected = vec![
"+-----------------+----+",
"| COUNT(UInt8(1)) | c1 |",
"+-----------------+----+",
"| 1 | |",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

"| 1 | 0 |",
"| 1 | 1 |",
"| 2 | 3 |",
"+-----------------+----+",
];
assert_batches_sorted_eq!(expected, &actual);
Ok(())
}

#[tokio::test]
async fn query_group_on_null_multi_col() -> Result<()> {
let schema = Arc::new(Schema::new(vec![
Field::new("c1", DataType::Int32, true),
Field::new("c2", DataType::Utf8, true),
]));

let data = RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(Int32Array::from(vec![
Some(0),
Some(0),
Some(3),
None,
None,
Some(3),
Some(0),
None,
Some(3),
])),
Arc::new(StringArray::from(vec![
None,
None,
Some("foo"),
None,
Some("bar"),
Some("foo"),
None,
Some("bar"),
Some("foo"),
])),
],
)?;

let table = MemTable::try_new(schema, vec![vec![data]])?;

let mut ctx = ExecutionContext::new();
ctx.register_table("test", Arc::new(table))?;
let sql = "SELECT COUNT(*), c1, c2 FROM test GROUP BY c1, c2";

let actual = execute_to_batches(&mut ctx, sql).await;

    // Note that the results also include rows for NULL group values
    // (e.g. c1=NULL, c2=NULL, count = 1)
let expected = vec![
"+-----------------+----+-----+",
"| COUNT(UInt8(1)) | c1 | c2 |",
"+-----------------+----+-----+",
"| 1 | | |",
"| 2 | | bar |",
"| 3 | 0 | |",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

"| 3 | 3 | foo |",
"+-----------------+----+-----+",
];
assert_batches_sorted_eq!(expected, &actual);

    // Also run query with group columns reversed (results should be the same)
let sql = "SELECT COUNT(*), c1, c2 FROM test GROUP BY c2, c1";
let actual = execute_to_batches(&mut ctx, sql).await;
assert_batches_sorted_eq!(expected, &actual);
Ok(())
}

#[tokio::test]
async fn query_on_string_dictionary() -> Result<()> {
// Test to ensure DataFusion can operate on dictionary types
Expand Down Expand Up @@ -3067,6 +3170,13 @@ async fn query_on_string_dictionary() -> Result<()> {
let expected = vec![vec!["2"]];
assert_eq!(expected, actual);

// grouping
let sql = "SELECT d1, COUNT(*) FROM test group by d1";
let mut actual = execute(&mut ctx, sql).await;
actual.sort();
let expected = vec![vec!["NULL", "1"], vec!["one", "1"], vec!["three", "1"]];
assert_eq!(expected, actual);

Ok(())
}

Expand Down