Skip to content

Commit

Permalink
Add string aggregate grouping fuzz test (#9190)
Browse files Browse the repository at this point in the history
  • Loading branch information
alamb authored Feb 12, 2024
1 parent d7dcb12 commit 3c2b542
Show file tree
Hide file tree
Showing 7 changed files with 381 additions and 118 deletions.
60 changes: 53 additions & 7 deletions datafusion/core/src/datasource/memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,10 @@ use arrow::datatypes::SchemaRef;
use arrow::record_batch::RecordBatch;
use async_trait::async_trait;
use datafusion_common::{
not_impl_err, plan_err, Constraints, DataFusionError, SchemaExt,
not_impl_err, plan_err, Constraints, DFSchema, DataFusionError, SchemaExt,
};
use datafusion_execution::TaskContext;
use parking_lot::Mutex;
use tokio::sync::RwLock;
use tokio::task::JoinSet;

Expand All @@ -44,6 +45,7 @@ use crate::physical_plan::memory::MemoryExec;
use crate::physical_plan::{common, SendableRecordBatchStream};
use crate::physical_plan::{repartition::RepartitionExec, Partitioning};
use crate::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan};
use crate::physical_planner::create_physical_sort_expr;

/// Type alias for partition data
pub type PartitionData = Arc<RwLock<Vec<RecordBatch>>>;
Expand All @@ -58,6 +60,9 @@ pub struct MemTable {
pub(crate) batches: Vec<PartitionData>,
constraints: Constraints,
column_defaults: HashMap<String, Expr>,
/// Optional pre-known sort order(s). Must be `SortExpr`s.
/// Inserting data into this table removes the order.
pub sort_order: Arc<Mutex<Vec<Vec<Expr>>>>,
}

impl MemTable {
Expand All @@ -82,6 +87,7 @@ impl MemTable {
.collect::<Vec<_>>(),
constraints: Constraints::empty(),
column_defaults: HashMap::new(),
sort_order: Arc::new(Mutex::new(vec![])),
})
}

Expand All @@ -100,6 +106,21 @@ impl MemTable {
self
}

/// Declare one or more sort orders that the table's data is already known
/// to satisfy. Every expression must be a `SortExpr`.
///
/// The supplied orderings replace whatever was previously stored.
///
/// DataFusion does not check the claim: if the data is not actually sorted
/// this way, queries may return incorrect results. When the claim holds,
/// DataFusion may skip sorts or choose more efficient algorithms.
///
/// Several orderings may be given when they are known to be equivalent.
/// Inserting data into the table clears the stored sort order.
pub fn with_sort_order(self, sort_order: Vec<Vec<Expr>>) -> Self {
    // Overwrite the shared value in place through the mutex guard.
    *self.sort_order.lock() = sort_order;
    self
}

/// Create a mem table by reading from another data source
pub async fn load(
t: Arc<dyn TableProvider>,
Expand Down Expand Up @@ -184,7 +205,7 @@ impl TableProvider for MemTable {

async fn scan(
&self,
_state: &SessionState,
state: &SessionState,
projection: Option<&Vec<usize>>,
_filters: &[Expr],
_limit: Option<usize>,
Expand All @@ -194,11 +215,33 @@ impl TableProvider for MemTable {
let inner_vec = arc_inner_vec.read().await;
partitions.push(inner_vec.clone())
}
Ok(Arc::new(MemoryExec::try_new(
&partitions,
self.schema(),
projection.cloned(),
)?))
let mut exec =
MemoryExec::try_new(&partitions, self.schema(), projection.cloned())?;

// add sort information if present
let sort_order = self.sort_order.lock();
if !sort_order.is_empty() {
let df_schema = DFSchema::try_from(self.schema.as_ref().clone())?;

let file_sort_order = sort_order
.iter()
.map(|sort_exprs| {
sort_exprs
.iter()
.map(|expr| {
create_physical_sort_expr(
expr,
&df_schema,
state.execution_props(),
)
})
.collect::<Result<Vec<_>>>()
})
.collect::<Result<Vec<_>>>()?;
exec = exec.with_sort_information(file_sort_order);
}

Ok(Arc::new(exec))
}

/// Returns an ExecutionPlan that inserts the execution results of a given [`ExecutionPlan`] into this [`MemTable`].
Expand All @@ -219,6 +262,9 @@ impl TableProvider for MemTable {
input: Arc<dyn ExecutionPlan>,
overwrite: bool,
) -> Result<Arc<dyn ExecutionPlan>> {
// If we are inserting into the table, any sort order may be messed up so reset it here
*self.sort_order.lock() = vec![];

// Create a physical plan from the logical plan.
// Check that the schema of the plan matches the schema of this table.
if !self
Expand Down
190 changes: 178 additions & 12 deletions datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,21 +22,32 @@ use arrow::compute::{concat_batches, SortOptions};
use arrow::datatypes::DataType;
use arrow::record_batch::RecordBatch;
use arrow::util::pretty::pretty_format_batches;
use datafusion::physical_plan::aggregates::{
AggregateExec, AggregateMode, PhysicalGroupBy,
};
use arrow_array::cast::AsArray;
use arrow_array::types::Int64Type;
use arrow_array::Array;
use hashbrown::HashMap;
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
use tokio::task::JoinSet;

use datafusion::common::Result;
use datafusion::datasource::MemTable;
use datafusion::physical_plan::aggregates::{
AggregateExec, AggregateMode, PhysicalGroupBy,
};
use datafusion::physical_plan::memory::MemoryExec;
use datafusion::physical_plan::{collect, displayable, ExecutionPlan};
use datafusion::prelude::{SessionConfig, SessionContext};
use datafusion::prelude::{DataFrame, SessionConfig, SessionContext};
use datafusion_common::tree_node::{TreeNode, TreeNodeVisitor, VisitRecursion};
use datafusion_physical_expr::expressions::{col, Sum};
use datafusion_physical_expr::{AggregateExpr, PhysicalSortExpr};
use test_utils::add_empty_batches;
use datafusion_physical_plan::InputOrderMode;
use test_utils::{add_empty_batches, StringBatchGenerator};

#[tokio::test(flavor = "multi_thread", worker_threads = 8)]
async fn aggregate_test() {
/// Tests that streaming aggregate and batch (non streaming) aggregate produce
/// same results
#[tokio::test(flavor = "multi_thread")]
async fn streaming_aggregate_test() {
let test_cases = vec![
vec!["a"],
vec!["b", "a"],
Expand All @@ -50,18 +61,18 @@ async fn aggregate_test() {
let n = 300;
let distincts = vec![10, 20];
for distinct in distincts {
let mut handles = Vec::new();
let mut join_set = JoinSet::new();
for i in 0..n {
let test_idx = i % test_cases.len();
let group_by_columns = test_cases[test_idx].clone();
let job = tokio::spawn(run_aggregate_test(
join_set.spawn(run_aggregate_test(
make_staggered_batches::<true>(1000, distinct, i as u64),
group_by_columns,
));
handles.push(job);
}
for job in handles {
job.await.unwrap();
while let Some(join_handle) = join_set.join_next().await {
// propagate errors
join_handle.unwrap();
}
}
}
Expand Down Expand Up @@ -234,3 +245,158 @@ pub(crate) fn make_staggered_batches<const STREAM: bool>(
}
add_empty_batches(batches, &mut rng)
}

/// Fuzz test for GROUP BY on string / large string columns.
///
/// Spawns one `group_by_string_test` task for every combination of
/// `large`, `sorted` and each generator returned by
/// `StringBatchGenerator::interesting_cases()`, then waits for all of them.
#[tokio::test(flavor = "multi_thread")]
async fn group_by_strings() {
    let mut tasks = JoinSet::new();

    // Same visiting order as nesting `large` (outer) around `sorted` (inner).
    let combinations = [(true, true), (true, false), (false, true), (false, false)];
    for (large, sorted) in combinations {
        for generator in StringBatchGenerator::interesting_cases() {
            tasks.spawn(group_by_string_test(generator, sorted, large));
        }
    }

    // Unwrap every join result so that a panic inside a task fails the test.
    while let Some(outcome) = tasks.join_next().await {
        outcome.unwrap();
    }
}

/// Runs `SELECT a, COUNT(*) FROM t GROUP BY a` through SQL over generated
/// batches and checks the per-group counts against `compute_counts`.
///
/// * `sorted`: the batches come from `make_sorted_input_batches`, the table
///   is registered with a declared sort order on column "a", and the plan
///   is expected to use an ordered aggregate (streaming group by case).
///   Otherwise the batches come from `make_input_batches` and no ordering
///   is declared.
/// * `large`: requests `LargeStringArray` input batches.
///   NOTE(review): `large` is only forwarded to `make_sorted_input_batches`;
///   the unsorted path calls `make_input_batches()`, which takes no such
///   flag -- confirm the unsorted/large combination is covered as intended.
async fn group_by_string_test(
    mut generator: StringBatchGenerator,
    sorted: bool,
    large: bool,
) {
    let column_name = "a";
    let batches = if sorted {
        generator.make_sorted_input_batches(large)
    } else {
        generator.make_input_batches()
    };

    // Reference answer, computed before the batches are moved into the table.
    let expected = compute_counts(&batches, column_name);

    let schema = batches[0].schema();
    let mut table = MemTable::try_new(schema, vec![batches]).unwrap();
    if sorted {
        // `sort(true, true)`: ascending, nulls first.
        let ordering = datafusion::prelude::col("a").sort(true, true);
        table = table.with_sort_order(vec![vec![ordering]]);
    }

    // A small batch size so the query is processed in many batches.
    let config = SessionConfig::new().with_batch_size(50);
    let ctx = SessionContext::new_with_config(config);
    ctx.register_table("t", Arc::new(table)).unwrap();

    let df = ctx
        .sql("SELECT a, COUNT(*) FROM t GROUP BY a")
        .await
        .unwrap();

    // The plan must use an ordered aggregate exactly when an order was declared.
    verify_ordered_aggregate(&df, sorted).await;

    // Compare the query output with the reference answer.
    let batches_out = df.collect().await.unwrap();
    let actual = extract_result_counts(batches_out);
    assert_eq!(expected, actual);
}
/// Builds the physical plan for `frame` and asserts on the input order mode
/// of every `AggregateExec` in it.
///
/// With `expected_sort` set, each aggregate must report
/// `InputOrderMode::Sorted` or `InputOrderMode::PartiallySorted`; otherwise
/// each must report `InputOrderMode::Linear`.
async fn verify_ordered_aggregate(frame: &DataFrame, expected_sort: bool) {
    /// Plan visitor that performs the check on each node it is shown.
    struct Visitor {
        expected_sort: bool,
    }

    impl TreeNodeVisitor for Visitor {
        type N = Arc<dyn ExecutionPlan>;

        fn pre_visit(&mut self, node: &Self::N) -> Result<VisitRecursion> {
            // Only aggregate nodes are checked; every other node is passed over.
            if let Some(exec) = node.as_any().downcast_ref::<AggregateExec>() {
                match self.expected_sort {
                    true => assert!(matches!(
                        exec.input_order_mode(),
                        InputOrderMode::PartiallySorted(_) | InputOrderMode::Sorted
                    )),
                    false => {
                        assert!(matches!(exec.input_order_mode(), InputOrderMode::Linear))
                    }
                }
            }
            Ok(VisitRecursion::Continue)
        }
    }

    let physical_plan = frame.clone().create_physical_plan().await.unwrap();
    physical_plan.visit(&mut Visitor { expected_sort }).unwrap();
}

/// Counts how many times each distinct value (including null, represented
/// as `None`) occurs in column `col` across all of `batches`.
///
/// Panics if a batch has no column named `col`.
///
/// Example of the input shape:
///
/// ```text
/// +---------------+---------------+
/// | a | b |
/// +---------------+---------------+
/// | 𭏷񬝜󓴻𼇪󄶛𑩁򽵐󦊟 | 󺚤𘱦𫎛񐕿 |
/// | 󂌿󶴬񰶨񺹭𿑵󖺉 | 񥼧􋽐󮋋󑤐𬿪𜋃 |
/// ```
fn compute_counts(batches: &[RecordBatch], col: &str) -> HashMap<Option<String>, i64> {
    let mut counts = HashMap::new();
    for batch in batches {
        let column = batch.column_by_name(col).unwrap();
        for key in to_str_vec(column) {
            // Start a new key at zero, then count this occurrence.
            *counts.entry(key).or_insert(0) += 1;
        }
    }
    counts
}

/// Copies a `Utf8` or `LargeUtf8` array into a vector of owned optional
/// strings, one entry per row (`None` for a null row).
///
/// Panics with "unexpected type" for any other data type.
fn to_str_vec(array: &ArrayRef) -> Vec<Option<String>> {
    // Shared row conversion: borrowed optional value -> owned optional value.
    let to_owned = |value: Option<&str>| value.map(str::to_owned);

    match array.data_type() {
        DataType::LargeUtf8 => array.as_string::<i64>().iter().map(to_owned).collect(),
        DataType::Utf8 => array.as_string::<i32>().iter().map(to_owned).collect(),
        _ => panic!("unexpected type"),
    }
}

/// Turns query output into a map from group value (column 0, a string
/// column) to its count (column 1, read as `Int64`).
///
/// Panics if a group value appears more than once or if a count is null.
///
/// Example of the expected result shape:
///
/// ```text
/// +----------------+----------+
/// | a | COUNT(*) |
/// +----------------+----------+
/// | 񩢰񴠍 | 8 |
/// | 󇿺򷜄򩨝񜖫𑟑񣶏󣥽𹕉 | 11 |
/// ```
fn extract_result_counts(results: Vec<RecordBatch>) -> HashMap<Option<String>, i64> {
    let mut output = HashMap::new();

    for batch in &results {
        let group_arr = batch.column(0);
        let count_arr = batch.column(1).as_primitive::<Int64Type>();
        assert_eq!(group_arr.len(), count_arr.len());

        let rows = to_str_vec(group_arr).into_iter().zip(count_arr.iter());
        for (group, maybe_count) in rows {
            // each group must be reported exactly once across all batches
            assert!(output.get(&group).is_none());
            let count = maybe_count.unwrap(); // counts can never be null
            output.insert(group, count);
        }
    }

    output
}
Loading

0 comments on commit 3c2b542

Please sign in to comment.