diff --git a/rust/datafusion/src/execution/context.rs b/rust/datafusion/src/execution/context.rs
index 4b8a4f71fe6..8e827afe19b 100644
--- a/rust/datafusion/src/execution/context.rs
+++ b/rust/datafusion/src/execution/context.rs
@@ -619,7 +619,9 @@ mod tests {
         datasource::MemTable, logical_plan::create_udaf,
         physical_plan::expressions::AvgAccumulator,
     };
-    use arrow::array::{ArrayRef, Float64Array, Int32Array};
+    use arrow::array::{
+        Array, ArrayRef, DictionaryArray, Float64Array, Int32Array, Int64Array,
+    };
     use arrow::compute::add;
     use arrow::datatypes::*;
     use arrow::record_batch::RecordBatch;
@@ -1271,6 +1273,83 @@ mod tests {
         Ok(())
     }
 
+    #[tokio::test]
+    async fn group_by_dictionary() {
+        async fn run_test_case<K: ArrowDictionaryKeyType>() {
+            let mut ctx = ExecutionContext::new();
+
+            // input data looks like:
+            // A, 1
+            // B, 2
+            // A, 2
+            // A, 4
+            // C, 1
+            // A, 1
+
+            let dict_array: DictionaryArray<K> =
+                vec!["A", "B", "A", "A", "C", "A"].into_iter().collect();
+            let dict_array = Arc::new(dict_array);
+
+            let val_array: Int64Array = vec![1, 2, 2, 4, 1, 1].into();
+            let val_array = Arc::new(val_array);
+
+            let schema = Arc::new(Schema::new(vec![
+                Field::new("dict", dict_array.data_type().clone(), false),
+                Field::new("val", val_array.data_type().clone(), false),
+            ]));
+
+            let batch = RecordBatch::try_new(schema.clone(), vec![dict_array, val_array])
+                .unwrap();
+
+            let provider = MemTable::try_new(schema.clone(), vec![vec![batch]]).unwrap();
+            ctx.register_table("t", Box::new(provider));
+
+            let results = plan_and_collect(
+                &mut ctx,
+                "SELECT dict, count(val) FROM t GROUP BY dict",
+            )
+            .await
+            .expect("ran plan correctly");
+
+            let expected = vec![
+                "+------+------------+",
+                "| dict | COUNT(val) |",
+                "+------+------------+",
+                "| A    | 4          |",
+                "| B    | 1          |",
+                "| C    | 1          |",
+                "+------+------------+",
+            ];
+            assert_batches_sorted_eq!(expected, &results);
+
+            // Now, use dict as an aggregate
+            let results =
+                plan_and_collect(&mut ctx, "SELECT val, count(dict) FROM t GROUP BY val")
+                    .await
+                    .expect("ran plan correctly");
+
+            let expected = vec![
+                "+-----+-------------+",
+                "| val | COUNT(dict) |",
+                "+-----+-------------+",
+                "| 1   | 3           |",
+                "| 2   | 2           |",
+                "| 4   | 1           |",
+                "+-----+-------------+",
+            ];
+            assert_batches_sorted_eq!(expected, &results);
+        }
+
+        run_test_case::<Int8Type>().await;
+        run_test_case::<Int16Type>().await;
+        run_test_case::<Int32Type>().await;
+        run_test_case::<Int64Type>().await;
+        run_test_case::<UInt8Type>().await;
+        run_test_case::<UInt16Type>().await;
+        run_test_case::<UInt32Type>().await;
+        run_test_case::<UInt64Type>().await;
+    }
+
     async fn run_count_distinct_integers_aggregated_scenario(
         partitions: Vec<Vec<RecordBatch>>,
     ) -> Result<Vec<RecordBatch>> {
diff --git a/rust/datafusion/src/physical_plan/hash_aggregate.rs b/rust/datafusion/src/physical_plan/hash_aggregate.rs
index 54d25f17dba..4f885cd75dc 100644
--- a/rust/datafusion/src/physical_plan/hash_aggregate.rs
+++ b/rust/datafusion/src/physical_plan/hash_aggregate.rs
@@ -31,7 +31,6 @@ use crate::error::{DataFusionError, Result};
 use crate::physical_plan::{Accumulator, AggregateExpr};
 use crate::physical_plan::{Distribution, ExecutionPlan, Partitioning, PhysicalExpr};
 
-use arrow::array::{BooleanArray, Date32Array};
 use arrow::{
     array::{Array, UInt32Builder},
     error::{ArrowError, Result as ArrowResult},
@@ -43,6 +42,14 @@ use arrow::{
     },
     compute,
 };
+use arrow::{
+    array::{BooleanArray, Date32Array, DictionaryArray},
+    compute::cast,
+    datatypes::{
+        ArrowDictionaryKeyType, ArrowNativeType, Int16Type, Int32Type, Int64Type,
+        Int8Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type,
+    },
+};
 use arrow::{
     datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit},
     record_batch::RecordBatch,
@@ -398,97 +405,165 @@ fn group_aggregate_batch(
     Ok(accumulators)
 }
 
-/// Create a key `Vec<u8>` that is used as key for the hashmap
-pub(crate) fn create_key(
-    group_by_keys: &[ArrayRef],
+/// Appends a sequence of [u8] bytes for the value in `col[row]` to
+/// `vec` to be used as a key into the hash map for a dictionary type
+///
+/// Note that ideally, for dictionary encoded columns, we would be
+/// able to simply use the dictionary indices themselves (no need to
+/// look up values) or possibly simply build the hash table entirely
+/// on the dictionary indexes.
+///
+/// This approach would likely work (very) well for the common case,
+/// but it also has to handle the case where the dictionary itself
+/// is not the same across all record batches (and thus indexes in one
+/// record batch may not correspond to the same index in another)
+fn dictionary_create_key_for_col<K: ArrowDictionaryKeyType>(
+    col: &ArrayRef,
     row: usize,
     vec: &mut Vec<u8>,
 ) -> Result<()> {
-    vec.clear();
-    for col in group_by_keys {
-        match col.data_type() {
-            DataType::Boolean => {
-                let array = col.as_any().downcast_ref::<BooleanArray>().unwrap();
-                vec.extend_from_slice(&[array.value(row) as u8]);
-            }
-            DataType::Float32 => {
-                let array = col.as_any().downcast_ref::<Float32Array>().unwrap();
-                vec.extend_from_slice(&array.value(row).to_le_bytes());
-            }
-            DataType::Float64 => {
-                let array = col.as_any().downcast_ref::<Float64Array>().unwrap();
-                vec.extend_from_slice(&array.value(row).to_le_bytes());
-            }
-            DataType::UInt8 => {
-                let array = col.as_any().downcast_ref::<UInt8Array>().unwrap();
-                vec.extend_from_slice(&array.value(row).to_le_bytes());
-            }
-            DataType::UInt16 => {
-                let array = col.as_any().downcast_ref::<UInt16Array>().unwrap();
-                vec.extend_from_slice(&array.value(row).to_le_bytes());
-            }
-            DataType::UInt32 => {
-                let array = col.as_any().downcast_ref::<UInt32Array>().unwrap();
-                vec.extend_from_slice(&array.value(row).to_le_bytes());
-            }
-            DataType::UInt64 => {
-                let array = col.as_any().downcast_ref::<UInt64Array>().unwrap();
-                vec.extend_from_slice(&array.value(row).to_le_bytes());
-            }
+    let dict_col = col.as_any().downcast_ref::<DictionaryArray<K>>().unwrap();
+
+    // look up the index in the values dictionary
+    let keys_col = dict_col.keys_array();
+    let values_index = keys_col.value(row).to_usize().ok_or_else(|| {
+        DataFusionError::Internal(format!(
+            "Can not convert index to usize in dictionary of type creating group by value {:?}",
+            keys_col.data_type()
+        ))
+    })?;
+
+    create_key_for_col(&dict_col.values(), values_index, vec)
+}
+
+/// Appends a sequence of [u8] bytes for the value in `col[row]` to
+/// `vec` to be used as a key into the hash map
+fn create_key_for_col(col: &ArrayRef, row: usize, vec: &mut Vec<u8>) -> Result<()> {
+    match col.data_type() {
+        DataType::Boolean => {
+            let array = col.as_any().downcast_ref::<BooleanArray>().unwrap();
+            vec.extend_from_slice(&[array.value(row) as u8]);
+        }
+        DataType::Float32 => {
+            let array = col.as_any().downcast_ref::<Float32Array>().unwrap();
+            vec.extend_from_slice(&array.value(row).to_le_bytes());
+        }
+        DataType::Float64 => {
+            let array = col.as_any().downcast_ref::<Float64Array>().unwrap();
+            vec.extend_from_slice(&array.value(row).to_le_bytes());
+        }
+        DataType::UInt8 => {
+            let array = col.as_any().downcast_ref::<UInt8Array>().unwrap();
+            vec.extend_from_slice(&array.value(row).to_le_bytes());
+        }
+        DataType::UInt16 => {
+            let array = col.as_any().downcast_ref::<UInt16Array>().unwrap();
+            vec.extend_from_slice(&array.value(row).to_le_bytes());
+        }
+        DataType::UInt32 => {
+            let array = col.as_any().downcast_ref::<UInt32Array>().unwrap();
+            vec.extend_from_slice(&array.value(row).to_le_bytes());
+        }
+        DataType::UInt64 => {
+            let array = col.as_any().downcast_ref::<UInt64Array>().unwrap();
+            vec.extend_from_slice(&array.value(row).to_le_bytes());
+        }
+        DataType::Int8 => {
+            let array = col.as_any().downcast_ref::<Int8Array>().unwrap();
+            vec.extend_from_slice(&array.value(row).to_le_bytes());
+        }
+        DataType::Int16 => {
+            let array = col.as_any().downcast_ref::<Int16Array>().unwrap();
+            vec.extend(array.value(row).to_le_bytes().iter());
+        }
+        DataType::Int32 => {
+            let array = col.as_any().downcast_ref::<Int32Array>().unwrap();
+            vec.extend_from_slice(&array.value(row).to_le_bytes());
+        }
+        DataType::Int64 => {
+            let array = col.as_any().downcast_ref::<Int64Array>().unwrap();
+            vec.extend_from_slice(&array.value(row).to_le_bytes());
+        }
+        DataType::Timestamp(TimeUnit::Microsecond, None) => {
+            let array = col
+                .as_any()
+                .downcast_ref::<TimestampMicrosecondArray>()
+                .unwrap();
+            vec.extend_from_slice(&array.value(row).to_le_bytes());
+        }
+        DataType::Timestamp(TimeUnit::Nanosecond, None) => {
+            let array = col
+                .as_any()
+                .downcast_ref::<TimestampNanosecondArray>()
+                .unwrap();
+            vec.extend_from_slice(&array.value(row).to_le_bytes());
+        }
+        DataType::Utf8 => {
+            let array = col.as_any().downcast_ref::<StringArray>().unwrap();
+            let value = array.value(row);
+            // store the size
+            vec.extend_from_slice(&value.len().to_le_bytes());
+            // store the string value
+            vec.extend_from_slice(value.as_bytes());
+        }
+        DataType::Date32 => {
+            let array = col.as_any().downcast_ref::<Date32Array>().unwrap();
+            vec.extend_from_slice(&array.value(row).to_le_bytes());
+        }
+        DataType::Dictionary(index_type, _) => match **index_type {
             DataType::Int8 => {
-                let array = col.as_any().downcast_ref::<Int8Array>().unwrap();
-                vec.extend_from_slice(&array.value(row).to_le_bytes());
+                dictionary_create_key_for_col::<Int8Type>(col, row, vec)?;
             }
             DataType::Int16 => {
-                let array = col.as_any().downcast_ref::<Int16Array>().unwrap();
-                vec.extend(array.value(row).to_le_bytes().iter());
+                dictionary_create_key_for_col::<Int16Type>(col, row, vec)?;
             }
             DataType::Int32 => {
-                let array = col.as_any().downcast_ref::<Int32Array>().unwrap();
-                vec.extend_from_slice(&array.value(row).to_le_bytes());
+                dictionary_create_key_for_col::<Int32Type>(col, row, vec)?;
             }
             DataType::Int64 => {
-                let array = col.as_any().downcast_ref::<Int64Array>().unwrap();
-                vec.extend_from_slice(&array.value(row).to_le_bytes());
+                dictionary_create_key_for_col::<Int64Type>(col, row, vec)?;
             }
-            DataType::Timestamp(TimeUnit::Microsecond, None) => {
-                let array = col
-                    .as_any()
-                    .downcast_ref::<TimestampMicrosecondArray>()
-                    .unwrap();
-                vec.extend_from_slice(&array.value(row).to_le_bytes());
-            }
-            DataType::Timestamp(TimeUnit::Nanosecond, None) => {
-                let array = col
-                    .as_any()
-                    .downcast_ref::<TimestampNanosecondArray>()
-                    .unwrap();
-                vec.extend_from_slice(&array.value(row).to_le_bytes());
+            DataType::UInt8 => {
+                dictionary_create_key_for_col::<UInt8Type>(col, row, vec)?;
             }
-            DataType::Utf8 => {
-                let array = col.as_any().downcast_ref::<StringArray>().unwrap();
-                let value = array.value(row);
-                // store the size
-                vec.extend_from_slice(&value.len().to_le_bytes());
-                // store the string value
-                vec.extend_from_slice(value.as_bytes());
+            DataType::UInt16 => {
+                dictionary_create_key_for_col::<UInt16Type>(col, row, vec)?;
             }
-            DataType::Date32 => {
-                let array = col.as_any().downcast_ref::<Date32Array>().unwrap();
-                vec.extend_from_slice(&array.value(row).to_le_bytes());
+            DataType::UInt32 => {
+                dictionary_create_key_for_col::<UInt32Type>(col, row, vec)?;
             }
-            _ => {
-                // This is internal because we should have caught this before.
-                return Err(DataFusionError::Internal(format!(
-                    "Unsupported GROUP BY for {}",
-                    col.data_type(),
-                )));
+            DataType::UInt64 => {
+                dictionary_create_key_for_col::<UInt64Type>(col, row, vec)?;
             }
+            _ => return Err(DataFusionError::Internal(format!(
+                "Unsupported GROUP BY type (dictionary index type not supported creating key) {}",
+                col.data_type(),
+            ))),
+        },
+        _ => {
+            // This is internal because we should have caught this before.
+            return Err(DataFusionError::Internal(format!(
+                "Unsupported GROUP BY type creating key {}",
+                col.data_type(),
+            )));
         }
     }
     Ok(())
 }
 
+/// Create a key `Vec<u8>` that is used as key for the hashmap
+pub(crate) fn create_key(
+    group_by_keys: &[ArrayRef],
+    row: usize,
+    vec: &mut Vec<u8>,
+) -> Result<()> {
+    vec.clear();
+    for col in group_by_keys {
+        create_key_for_col(col, row, vec)?
+    }
+    Ok(())
+}
+
 async fn compute_grouped_hash_aggregate(
     mode: AggregateMode,
     schema: SchemaRef,
@@ -860,6 +935,16 @@ fn create_batch_from_map(
     let batch = if !arrays.is_empty() {
         // 5.
         let columns = concatenate(arrays)?;
+
+        // cast output if needed (e.g. for types like Dictionary where
+        // the intermediate GroupByScalar type was not the same as the
+        // output)
+        let columns = columns
+            .iter()
+            .zip(output_schema.fields().iter())
+            .map(|(col, desired_field)| cast(col, desired_field.data_type()))
+            .collect::<ArrowResult<Vec<ArrayRef>>>()?;
+
         RecordBatch::try_new(Arc::new(output_schema.to_owned()), columns)?
     } else {
         RecordBatch::new_empty(Arc::new(output_schema.to_owned()))
@@ -906,90 +991,124 @@ fn finalize_aggregation(
     }
 }
 
-/// Create a Box<[GroupByScalar]> for the group by values
+/// Extract the value in `col[row]` from a dictionary as a GroupByScalar
+fn dictionary_create_group_by_value<K: ArrowDictionaryKeyType>(
+    col: &ArrayRef,
+    row: usize,
+) -> Result<GroupByScalar> {
+    let dict_col = col.as_any().downcast_ref::<DictionaryArray<K>>().unwrap();
+
+    // look up the index in the values dictionary
+    let keys_col = dict_col.keys_array();
+    let values_index = keys_col.value(row).to_usize().ok_or_else(|| {
+        DataFusionError::Internal(format!(
+            "Can not convert index to usize in dictionary of type creating group by value {:?}",
+            keys_col.data_type()
+        ))
+    })?;
+
+    create_group_by_value(&dict_col.values(), values_index)
+}
+
+/// Extract the value in `col[row]` as a GroupByScalar
+fn create_group_by_value(col: &ArrayRef, row: usize) -> Result<GroupByScalar> {
+    match col.data_type() {
+        DataType::Float32 => {
+            let array = col.as_any().downcast_ref::<Float32Array>().unwrap();
+            Ok(GroupByScalar::Float32(OrderedFloat::from(array.value(row))))
+        }
+        DataType::Float64 => {
+            let array = col.as_any().downcast_ref::<Float64Array>().unwrap();
+            Ok(GroupByScalar::Float64(OrderedFloat::from(array.value(row))))
+        }
+        DataType::UInt8 => {
+            let array = col.as_any().downcast_ref::<UInt8Array>().unwrap();
+            Ok(GroupByScalar::UInt8(array.value(row)))
+        }
+        DataType::UInt16 => {
+            let array = col.as_any().downcast_ref::<UInt16Array>().unwrap();
+            Ok(GroupByScalar::UInt16(array.value(row)))
+        }
+        DataType::UInt32 => {
+            let array = col.as_any().downcast_ref::<UInt32Array>().unwrap();
+            Ok(GroupByScalar::UInt32(array.value(row)))
+        }
+        DataType::UInt64 => {
+            let array = col.as_any().downcast_ref::<UInt64Array>().unwrap();
+            Ok(GroupByScalar::UInt64(array.value(row)))
+        }
+        DataType::Int8 => {
+            let array = col.as_any().downcast_ref::<Int8Array>().unwrap();
+            Ok(GroupByScalar::Int8(array.value(row)))
+        }
+        DataType::Int16 => {
+            let array = col.as_any().downcast_ref::<Int16Array>().unwrap();
+            Ok(GroupByScalar::Int16(array.value(row)))
+        }
+        DataType::Int32 => {
+            let array = col.as_any().downcast_ref::<Int32Array>().unwrap();
+            Ok(GroupByScalar::Int32(array.value(row)))
+        }
+        DataType::Int64 => {
+            let array = col.as_any().downcast_ref::<Int64Array>().unwrap();
+            Ok(GroupByScalar::Int64(array.value(row)))
+        }
+        DataType::Utf8 => {
+            let array = col.as_any().downcast_ref::<StringArray>().unwrap();
+            Ok(GroupByScalar::Utf8(Box::new(array.value(row).into())))
+        }
+        DataType::Boolean => {
+            let array = col.as_any().downcast_ref::<BooleanArray>().unwrap();
+            Ok(GroupByScalar::Boolean(array.value(row)))
+        }
+        DataType::Timestamp(TimeUnit::Microsecond, None) => {
+            let array = col
+                .as_any()
+                .downcast_ref::<TimestampMicrosecondArray>()
+                .unwrap();
+            Ok(GroupByScalar::TimeMicrosecond(array.value(row)))
+        }
+        DataType::Timestamp(TimeUnit::Nanosecond, None) => {
+            let array = col
+                .as_any()
+                .downcast_ref::<TimestampNanosecondArray>()
+                .unwrap();
+            Ok(GroupByScalar::TimeNanosecond(array.value(row)))
+        }
+        DataType::Date32 => {
+            let array = col.as_any().downcast_ref::<Date32Array>().unwrap();
+            Ok(GroupByScalar::Date32(array.value(row)))
+        }
+        DataType::Dictionary(index_type, _) => match **index_type {
+            DataType::Int8 => dictionary_create_group_by_value::<Int8Type>(col, row),
+            DataType::Int16 => dictionary_create_group_by_value::<Int16Type>(col, row),
+            DataType::Int32 => dictionary_create_group_by_value::<Int32Type>(col, row),
+            DataType::Int64 => dictionary_create_group_by_value::<Int64Type>(col, row),
+            DataType::UInt8 => dictionary_create_group_by_value::<UInt8Type>(col, row),
+            DataType::UInt16 => dictionary_create_group_by_value::<UInt16Type>(col, row),
+            DataType::UInt32 => dictionary_create_group_by_value::<UInt32Type>(col, row),
+            DataType::UInt64 => dictionary_create_group_by_value::<UInt64Type>(col, row),
+            _ => Err(DataFusionError::NotImplemented(format!(
+                "Unsupported GROUP BY type (dictionary index type not supported) {}",
+                col.data_type(),
+            ))),
+        },
+        _ => Err(DataFusionError::NotImplemented(format!(
+            "Unsupported GROUP BY type {}",
+            col.data_type(),
+        ))),
+    }
+}
+
+/// Extract the values in `group_by_keys` arrow arrays into the target vector
+/// as GroupByScalar values
 pub(crate) fn create_group_by_values(
     group_by_keys: &[ArrayRef],
     row: usize,
     vec: &mut Box<[GroupByScalar]>,
 ) -> Result<()> {
-    for i in 0..group_by_keys.len() {
-        let col = &group_by_keys[i];
-        match col.data_type() {
-            DataType::Float32 => {
-                let array = col.as_any().downcast_ref::<Float32Array>().unwrap();
-                vec[i] = GroupByScalar::Float32(OrderedFloat::from(array.value(row)))
-            }
-            DataType::Float64 => {
-                let array = col.as_any().downcast_ref::<Float64Array>().unwrap();
-                vec[i] = GroupByScalar::Float64(OrderedFloat::from(array.value(row)))
-            }
-            DataType::UInt8 => {
-                let array = col.as_any().downcast_ref::<UInt8Array>().unwrap();
-                vec[i] = GroupByScalar::UInt8(array.value(row))
-            }
-            DataType::UInt16 => {
-                let array = col.as_any().downcast_ref::<UInt16Array>().unwrap();
-                vec[i] = GroupByScalar::UInt16(array.value(row))
-            }
-            DataType::UInt32 => {
-                let array = col.as_any().downcast_ref::<UInt32Array>().unwrap();
-                vec[i] = GroupByScalar::UInt32(array.value(row))
-            }
-            DataType::UInt64 => {
-                let array = col.as_any().downcast_ref::<UInt64Array>().unwrap();
-                vec[i] = GroupByScalar::UInt64(array.value(row))
-            }
-            DataType::Int8 => {
-                let array = col.as_any().downcast_ref::<Int8Array>().unwrap();
-                vec[i] = GroupByScalar::Int8(array.value(row))
-            }
-            DataType::Int16 => {
-                let array = col.as_any().downcast_ref::<Int16Array>().unwrap();
-                vec[i] = GroupByScalar::Int16(array.value(row))
-            }
-            DataType::Int32 => {
-                let array = col.as_any().downcast_ref::<Int32Array>().unwrap();
-                vec[i] = GroupByScalar::Int32(array.value(row))
-            }
-            DataType::Int64 => {
-                let array = col.as_any().downcast_ref::<Int64Array>().unwrap();
-                vec[i] = GroupByScalar::Int64(array.value(row))
-            }
-            DataType::Utf8 => {
-                let array = col.as_any().downcast_ref::<StringArray>().unwrap();
-                vec[i] = GroupByScalar::Utf8(Box::new(array.value(row).into()))
-            }
-            DataType::Boolean => {
-                let array = col.as_any().downcast_ref::<BooleanArray>().unwrap();
-                vec[i] = GroupByScalar::Boolean(array.value(row))
-            }
-            DataType::Timestamp(TimeUnit::Microsecond, None) => {
-                let array = col
-                    .as_any()
-                    .downcast_ref::<TimestampMicrosecondArray>()
-                    .unwrap();
-                vec[i] = GroupByScalar::TimeMicrosecond(array.value(row))
-            }
-            DataType::Timestamp(TimeUnit::Nanosecond, None) => {
-                let array = col
-                    .as_any()
-                    .downcast_ref::<TimestampNanosecondArray>()
-                    .unwrap();
-                vec[i] = GroupByScalar::TimeNanosecond(array.value(row))
-            }
-            DataType::Date32 => {
-                let array = col.as_any().downcast_ref::<Date32Array>().unwrap();
-                vec[i] = GroupByScalar::Date32(array.value(row));
-            }
-
-            _ => {
-                // This is internal because we should have caught this before.
-                return Err(DataFusionError::Internal(format!(
-                    "Unsupported GROUP BY for {}",
-                    col.data_type(),
-                )));
-            }
-        }
+    for (i, col) in group_by_keys.iter().enumerate() {
+        vec[i] = create_group_by_value(col, row)?
     }
     Ok(())
 }