-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Produce correct answers for Group BY NULL (Option 1) #793
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -28,7 +28,7 @@ use arrow::{ | |
| }, | ||
| }; | ||
| use ordered_float::OrderedFloat; | ||
| use std::convert::Infallible; | ||
| use std::convert::{Infallible, TryInto}; | ||
| use std::str::FromStr; | ||
| use std::{convert::TryFrom, fmt, iter::repeat, sync::Arc}; | ||
|
|
||
|
|
@@ -796,6 +796,11 @@ impl ScalarValue { | |
|
|
||
| /// Converts a value in `array` at `index` into a ScalarValue | ||
| pub fn try_from_array(array: &ArrayRef, index: usize) -> Result<Self> { | ||
| // handle NULL value | ||
| if !array.is_valid(index) { | ||
| return array.data_type().try_into(); | ||
| } | ||
|
|
||
| Ok(match array.data_type() { | ||
| DataType::Boolean => typed_cast!(array, index, BooleanArray, Boolean), | ||
| DataType::Float64 => typed_cast!(array, index, Float64Array, Float64), | ||
|
|
@@ -897,6 +902,7 @@ impl ScalarValue { | |
| let dict_array = array.as_any().downcast_ref::<DictionaryArray<K>>().unwrap(); | ||
|
|
||
| // look up the index in the values dictionary | ||
| // (note validity was previously checked in `try_from_array`) | ||
| let keys_col = dict_array.keys(); | ||
| let values_index = keys_col.value(index).to_usize().ok_or_else(|| { | ||
| DataFusionError::Internal(format!( | ||
|
|
@@ -1132,6 +1138,7 @@ impl_try_from!(Boolean, bool); | |
| impl TryFrom<&DataType> for ScalarValue { | ||
| type Error = DataFusionError; | ||
|
|
||
| /// Create a Null instance of ScalarValue for this datatype | ||
| fn try_from(datatype: &DataType) -> Result<Self> { | ||
| Ok(match datatype { | ||
| DataType::Boolean => ScalarValue::Boolean(None), | ||
|
|
@@ -1161,12 +1168,15 @@ impl TryFrom<&DataType> for ScalarValue { | |
| DataType::Timestamp(TimeUnit::Nanosecond, _) => { | ||
| ScalarValue::TimestampNanosecond(None) | ||
| } | ||
| DataType::Dictionary(_index_type, value_type) => { | ||
| value_type.as_ref().try_into()? | ||
| } | ||
| DataType::List(ref nested_type) => { | ||
| ScalarValue::List(None, Box::new(nested_type.data_type().clone())) | ||
| } | ||
| _ => { | ||
| return Err(DataFusionError::NotImplemented(format!( | ||
| "Can't create a scalar of type \"{:?}\"", | ||
| "Can't create a scalar from data_type \"{:?}\"", | ||
| datatype | ||
| ))) | ||
| } | ||
|
|
@@ -1535,6 +1545,29 @@ mod tests { | |
| "{}", result); | ||
| } | ||
|
|
||
| #[test] | ||
| fn scalar_try_from_array_null() { | ||
| let array = vec![Some(33), None].into_iter().collect::<Int64Array>(); | ||
| let array: ArrayRef = Arc::new(array); | ||
|
|
||
| assert_eq!( | ||
| ScalarValue::Int64(Some(33)), | ||
| ScalarValue::try_from_array(&array, 0).unwrap() | ||
| ); | ||
| assert_eq!( | ||
| ScalarValue::Int64(None), | ||
| ScalarValue::try_from_array(&array, 1).unwrap() | ||
| ); | ||
| } | ||
|
|
||
| #[test] | ||
| fn scalar_try_from_dict_datatype() { | ||
| let data_type = | ||
| DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8)); | ||
| let data_type = &data_type; | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 🥳
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Amusingly, supporting this behavior ended up causing a test to fail when I brought the code into IOx, and I think I traced the problem to an issue in parquet file statistics: apache/arrow-rs#641 🤣 This was not a side effect I had anticipated. |
||
| assert_eq!(ScalarValue::Utf8(None), data_type.try_into().unwrap()) | ||
| } | ||
|
|
||
| #[test] | ||
| fn size_of_scalar() { | ||
| // Since ScalarValues are used in a non trivial number of places, | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3014,6 +3014,109 @@ async fn query_count_distinct() -> Result<()> { | |
| Ok(()) | ||
| } | ||
|
|
||
| #[tokio::test] | ||
| async fn query_group_on_null() -> Result<()> { | ||
| let schema = Arc::new(Schema::new(vec![Field::new("c1", DataType::Int32, true)])); | ||
|
|
||
| let data = RecordBatch::try_new( | ||
| schema.clone(), | ||
| vec![Arc::new(Int32Array::from(vec![ | ||
| Some(0), | ||
| Some(3), | ||
| None, | ||
| Some(1), | ||
| Some(3), | ||
| ]))], | ||
| )?; | ||
|
|
||
| let table = MemTable::try_new(schema, vec![vec![data]])?; | ||
|
|
||
| let mut ctx = ExecutionContext::new(); | ||
| ctx.register_table("test", Arc::new(table))?; | ||
| let sql = "SELECT COUNT(*), c1 FROM test GROUP BY c1"; | ||
|
|
||
| let actual = execute_to_batches(&mut ctx, sql).await; | ||
|
|
||
| // Note that the results also | ||
| // include a row for NULL (c1=NULL, count = 1) | ||
| let expected = vec![ | ||
| "+-----------------+----+", | ||
| "| COUNT(UInt8(1)) | c1 |", | ||
| "+-----------------+----+", | ||
| "| 1 | |", | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 👍 |
||
| "| 1 | 0 |", | ||
| "| 1 | 1 |", | ||
| "| 2 | 3 |", | ||
| "+-----------------+----+", | ||
| ]; | ||
| assert_batches_sorted_eq!(expected, &actual); | ||
| Ok(()) | ||
| } | ||
|
|
||
| #[tokio::test] | ||
| async fn query_group_on_null_multi_col() -> Result<()> { | ||
| let schema = Arc::new(Schema::new(vec![ | ||
| Field::new("c1", DataType::Int32, true), | ||
| Field::new("c2", DataType::Utf8, true), | ||
| ])); | ||
|
|
||
| let data = RecordBatch::try_new( | ||
| schema.clone(), | ||
| vec![ | ||
| Arc::new(Int32Array::from(vec![ | ||
| Some(0), | ||
| Some(0), | ||
| Some(3), | ||
| None, | ||
| None, | ||
| Some(3), | ||
| Some(0), | ||
| None, | ||
| Some(3), | ||
| ])), | ||
| Arc::new(StringArray::from(vec![ | ||
| None, | ||
| None, | ||
| Some("foo"), | ||
| None, | ||
| Some("bar"), | ||
| Some("foo"), | ||
| None, | ||
| Some("bar"), | ||
| Some("foo"), | ||
| ])), | ||
| ], | ||
| )?; | ||
|
|
||
| let table = MemTable::try_new(schema, vec![vec![data]])?; | ||
|
|
||
| let mut ctx = ExecutionContext::new(); | ||
| ctx.register_table("test", Arc::new(table))?; | ||
| let sql = "SELECT COUNT(*), c1, c2 FROM test GROUP BY c1, c2"; | ||
|
|
||
| let actual = execute_to_batches(&mut ctx, sql).await; | ||
|
|
||
| // Note that the results also include rows for NULL group keys | ||
| // (e.g. c1=NULL, c2=NULL, count = 1) | ||
| let expected = vec![ | ||
| "+-----------------+----+-----+", | ||
| "| COUNT(UInt8(1)) | c1 | c2 |", | ||
| "+-----------------+----+-----+", | ||
| "| 1 | | |", | ||
| "| 2 | | bar |", | ||
| "| 3 | 0 | |", | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 👍 |
||
| "| 3 | 3 | foo |", | ||
| "+-----------------+----+-----+", | ||
| ]; | ||
| assert_batches_sorted_eq!(expected, &actual); | ||
|
|
||
| // Also run query with group columns reversed (results should be the same) | ||
| let sql = "SELECT COUNT(*), c1, c2 FROM test GROUP BY c2, c1"; | ||
| let actual = execute_to_batches(&mut ctx, sql).await; | ||
| assert_batches_sorted_eq!(expected, &actual); | ||
| Ok(()) | ||
| } | ||
|
|
||
| #[tokio::test] | ||
| async fn query_on_string_dictionary() -> Result<()> { | ||
| // Test to ensure DataFusion can operate on dictionary types | ||
|
|
@@ -3067,6 +3170,13 @@ async fn query_on_string_dictionary() -> Result<()> { | |
| let expected = vec![vec!["2"]]; | ||
| assert_eq!(expected, actual); | ||
|
|
||
| // grouping | ||
| let sql = "SELECT d1, COUNT(*) FROM test group by d1"; | ||
| let mut actual = execute(&mut ctx, sql).await; | ||
| actual.sort(); | ||
| let expected = vec![vec!["NULL", "1"], vec!["one", "1"], vec!["three", "1"]]; | ||
| assert_eq!(expected, actual); | ||
|
|
||
| Ok(()) | ||
| } | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not sure if it makes sense to improve performance here, but an optimization might be to check on
`null-count==0` outside of this function to avoid the `is_valid` call and just always add an `0xFF`.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for the suggestion.
If you don't mind, I would like to spend time on #790 which, if successful, I expect will remove nearly all of this code.
I will attempt to add that optimization at a later date.