Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 103 additions & 0 deletions rust/arrow/src/compute/kernels/concat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ pub fn concat(array_list: &[ArrayRef]) -> Result<ArrayRef> {
DataType::Duration(TimeUnit::Nanosecond) => {
concat_primitive::<DurationNanosecondType>(array_data_list)
}
DataType::List(nested_type) => concat_list(array_data_list, *nested_type.clone()),
t => Err(ArrowError::ComputeError(format!(
"Concat not supported for data type {:?}",
t
Expand All @@ -131,6 +132,37 @@ where
Ok(ArrayBuilder::finish(&mut builder))
}

#[inline]
fn concat_primitive_list<T>(array_data_list: &[ArrayDataRef]) -> Result<ArrayRef>
where
T: ArrowNumericType,
{
let mut builder = ListBuilder::new(PrimitiveArray::<T>::builder(0));
builder.append_data(array_data_list)?;
Ok(ArrayBuilder::finish(&mut builder))
}

#[inline]
fn concat_list(
array_data_list: &[ArrayDataRef],
data_type: DataType,
) -> Result<ArrayRef> {
match data_type {
DataType::Int8 => concat_primitive_list::<Int8Type>(array_data_list),
DataType::Int16 => concat_primitive_list::<Int16Type>(array_data_list),
DataType::Int32 => concat_primitive_list::<Int32Type>(array_data_list),
DataType::Int64 => concat_primitive_list::<Int64Type>(array_data_list),
DataType::UInt8 => concat_primitive_list::<UInt8Type>(array_data_list),
DataType::UInt16 => concat_primitive_list::<UInt16Type>(array_data_list),
DataType::UInt32 => concat_primitive_list::<UInt32Type>(array_data_list),
DataType::UInt64 => concat_primitive_list::<UInt64Type>(array_data_list),
t => Err(ArrowError::ComputeError(format!(
"Concat not supported for list with data type {:?}",
t
))),
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down Expand Up @@ -285,4 +317,75 @@ mod tests {

Ok(())
}

#[test]
fn test_concat_primitive_list_arrays() -> Result<()> {
fn populate_list1(
b: &mut ListBuilder<PrimitiveBuilder<Int64Type>>,
) -> Result<()> {
b.values().append_value(-1)?;
b.values().append_value(-1)?;
b.values().append_value(2)?;
b.values().append_null()?;
b.values().append_null()?;
b.append(true)?;
b.append(true)?;
b.append(false)?;
b.values().append_value(10)?;
b.append(true)?;
Ok(())
}

fn populate_list2(
b: &mut ListBuilder<PrimitiveBuilder<Int64Type>>,
) -> Result<()> {
b.append(false)?;
b.values().append_value(100)?;
b.values().append_null()?;
b.values().append_value(101)?;
b.append(true)?;
b.values().append_value(102)?;
b.append(true)?;
Ok(())
}

fn populate_list3(
b: &mut ListBuilder<PrimitiveBuilder<Int64Type>>,
) -> Result<()> {
b.values().append_value(1000)?;
b.values().append_value(1001)?;
b.append(true)?;
Ok(())
}

let mut builder_in1 = ListBuilder::new(PrimitiveArray::<Int64Type>::builder(0));
let mut builder_in2 = ListBuilder::new(PrimitiveArray::<Int64Type>::builder(0));
let mut builder_in3 = ListBuilder::new(PrimitiveArray::<Int64Type>::builder(0));
populate_list1(&mut builder_in1)?;
populate_list2(&mut builder_in2)?;
populate_list3(&mut builder_in3)?;

let mut builder_expected =
ListBuilder::new(PrimitiveArray::<Int64Type>::builder(0));
populate_list1(&mut builder_expected)?;
populate_list2(&mut builder_expected)?;
populate_list3(&mut builder_expected)?;

let array_result = concat(&[
Arc::new(builder_in1.finish()),
Arc::new(builder_in2.finish()),
Arc::new(builder_in3.finish()),
])?;

let array_expected = builder_expected.finish();

assert!(
array_result.equals(&array_expected),
"expect {:#?} to be: {:#?}",
array_result,
&array_expected
);

Ok(())
}
}
127 changes: 127 additions & 0 deletions rust/datafusion/src/execution/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -967,6 +967,133 @@ mod tests {
Ok(())
}

async fn run_count_distinct_integers_aggregated_scenario(
partitions: Vec<Vec<(&str, u64)>>,
) -> Result<Vec<RecordBatch>> {
let tmp_dir = TempDir::new()?;
let mut ctx = ExecutionContext::new();
let schema = Arc::new(Schema::new(vec![
Field::new("c_group", DataType::Utf8, false),
Field::new("c_int8", DataType::Int8, false),
Field::new("c_int16", DataType::Int16, false),
Field::new("c_int32", DataType::Int32, false),
Field::new("c_int64", DataType::Int64, false),
Field::new("c_uint8", DataType::UInt8, false),
Field::new("c_uint16", DataType::UInt16, false),
Field::new("c_uint32", DataType::UInt32, false),
Field::new("c_uint64", DataType::UInt64, false),
]));

for (i, partition) in partitions.iter().enumerate() {
let filename = format!("partition-{}.csv", i);
let file_path = tmp_dir.path().join(&filename);
let mut file = File::create(file_path)?;
for row in partition {
let row_str = format!(
"{},{}\n",
row.0,
// Populate values for each of the integer fields in the
// schema.
(0..8)
.map(|_| { row.1.to_string() })
.collect::<Vec<_>>()
.join(","),
);
file.write_all(row_str.as_bytes())?;
}
}
ctx.register_csv(
"test",
tmp_dir.path().to_str().unwrap(),
CsvReadOptions::new().schema(&schema).has_header(false),
)?;

let results = collect(
&mut ctx,
"
SELECT
c_group,
COUNT(c_uint64),
COUNT(DISTINCT c_int8),
COUNT(DISTINCT c_int16),
COUNT(DISTINCT c_int32),
COUNT(DISTINCT c_int64),
COUNT(DISTINCT c_uint8),
COUNT(DISTINCT c_uint16),
COUNT(DISTINCT c_uint32),
COUNT(DISTINCT c_uint64)
FROM test
GROUP BY c_group
",
)
.await?;

Ok(results)
}

#[tokio::test]
async fn count_distinct_integers_aggregated_single_partition() -> Result<()> {
let partitions = vec![
// The first member of each tuple will be the value for the
// `c_group` column, and the second member will be the value for
// each of the int/uint fields.
vec![
("a", 1),
("a", 1),
("a", 2),
("b", 9),
("c", 9),
("c", 10),
("c", 9),
],
];

let results = run_count_distinct_integers_aggregated_scenario(partitions).await?;
assert_eq!(results.len(), 1);

let batch = &results[0];
assert_eq!(batch.num_rows(), 3);
assert_eq!(batch.num_columns(), 10);
assert_eq!(
test::format_batch(&batch),
vec![
"a,3,2,2,2,2,2,2,2,2",
"c,3,2,2,2,2,2,2,2,2",
"b,1,1,1,1,1,1,1,1,1",
],
);

Ok(())
}

#[tokio::test]
async fn count_distinct_integers_aggregated_multiple_partitions() -> Result<()> {
let partitions = vec![
// The first member of each tuple will be the value for the
// `c_group` column, and the second member will be the value for
// each of the int/uint fields.
vec![("a", 1), ("a", 1), ("a", 2), ("b", 9), ("c", 9)],
vec![("a", 1), ("a", 3), ("b", 8), ("b", 9), ("b", 10), ("b", 11)],
];

let results = run_count_distinct_integers_aggregated_scenario(partitions).await?;
assert_eq!(results.len(), 1);

let batch = &results[0];
assert_eq!(batch.num_rows(), 3);
assert_eq!(batch.num_columns(), 10);
assert_eq!(
test::format_batch(&batch),
vec![
"a,5,3,3,3,3,3,3,3,3",
"c,1,1,1,1,1,1,1,1,1",
"b,5,4,4,4,4,4,4,4,4",
],
);

Ok(())
}

#[test]
fn aggregate_with_alias() -> Result<()> {
let tmp_dir = TempDir::new()?;
Expand Down
57 changes: 44 additions & 13 deletions rust/datafusion/src/logical_plan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,14 +50,19 @@ pub use operators::Operator;

fn create_function_name(
fun: &String,
distinct: bool,
args: &[Expr],
input_schema: &Schema,
) -> Result<String> {
let names: Vec<String> = args
.iter()
.map(|e| create_name(e, input_schema))
.collect::<Result<_>>()?;
Ok(format!("{}({})", fun, names.join(",")))
let distinct_str = match distinct {
true => "DISTINCT ",
false => "",
};
Ok(format!("{}({}{})", fun, distinct_str, names.join(",")))
}

/// Returns a readable name of an expression based on the input schema.
Expand Down Expand Up @@ -90,14 +95,17 @@ fn create_name(e: &Expr, input_schema: &Schema) -> Result<String> {
Ok(format!("{} IS NOT NULL", expr))
}
Expr::ScalarFunction { fun, args, .. } => {
create_function_name(&fun.to_string(), args, input_schema)
create_function_name(&fun.to_string(), false, args, input_schema)
}
Expr::ScalarUDF { fun, args, .. } => {
create_function_name(&fun.name, args, input_schema)
}
Expr::AggregateFunction { fun, args, .. } => {
create_function_name(&fun.to_string(), args, input_schema)
create_function_name(&fun.name, false, args, input_schema)
}
Expr::AggregateFunction {
fun,
distinct,
args,
..
} => create_function_name(&fun.to_string(), *distinct, args, input_schema),
Expr::AggregateUDF { fun, args } => {
let mut names = Vec::with_capacity(args.len());
for e in args {
Expand Down Expand Up @@ -195,6 +203,8 @@ pub enum Expr {
fun: aggregates::AggregateFunction,
/// List of expressions to feed to the functions as arguments
args: Vec<Expr>,
/// Whether this is a DISTINCT aggregation or not
distinct: bool,
},
/// aggregate function
AggregateUDF {
Expand Down Expand Up @@ -447,6 +457,7 @@ pub fn col(name: &str) -> Expr {
pub fn min(expr: Expr) -> Expr {
Expr::AggregateFunction {
fun: aggregates::AggregateFunction::Min,
distinct: false,
args: vec![expr],
}
}
Expand All @@ -455,6 +466,7 @@ pub fn min(expr: Expr) -> Expr {
pub fn max(expr: Expr) -> Expr {
Expr::AggregateFunction {
fun: aggregates::AggregateFunction::Max,
distinct: false,
args: vec![expr],
}
}
Expand All @@ -463,6 +475,7 @@ pub fn max(expr: Expr) -> Expr {
pub fn sum(expr: Expr) -> Expr {
Expr::AggregateFunction {
fun: aggregates::AggregateFunction::Sum,
distinct: false,
args: vec![expr],
}
}
Expand All @@ -471,6 +484,7 @@ pub fn sum(expr: Expr) -> Expr {
pub fn avg(expr: Expr) -> Expr {
Expr::AggregateFunction {
fun: aggregates::AggregateFunction::Avg,
distinct: false,
args: vec![expr],
}
}
Expand All @@ -479,6 +493,7 @@ pub fn avg(expr: Expr) -> Expr {
pub fn count(expr: Expr) -> Expr {
Expr::AggregateFunction {
fun: aggregates::AggregateFunction::Count,
distinct: false,
args: vec![expr],
}
}
Expand Down Expand Up @@ -620,9 +635,18 @@ pub fn create_udaf(
)
}

fn fmt_function(f: &mut fmt::Formatter, fun: &String, args: &Vec<Expr>) -> fmt::Result {
fn fmt_function(
f: &mut fmt::Formatter,
fun: &String,
distinct: bool,
args: &Vec<Expr>,
) -> fmt::Result {
let args: Vec<String> = args.iter().map(|arg| format!("{:?}", arg)).collect();
write!(f, "{}({})", fun, args.join(", "))
let distinct_str = match distinct {
true => "DISTINCT ",
false => "",
};
write!(f, "{}({}{})", fun, distinct_str, args.join(", "))
}

impl fmt::Debug for Expr {
Expand Down Expand Up @@ -658,13 +682,20 @@ impl fmt::Debug for Expr {
}
}
Expr::ScalarFunction { fun, args, .. } => {
fmt_function(f, &fun.to_string(), args)
fmt_function(f, &fun.to_string(), false, args)
}
Expr::ScalarUDF { fun, ref args, .. } => {
fmt_function(f, &fun.name, false, args)
}
Expr::ScalarUDF { fun, ref args, .. } => fmt_function(f, &fun.name, args),
Expr::AggregateFunction { fun, ref args, .. } => {
fmt_function(f, &fun.to_string(), args)
Expr::AggregateFunction {
fun,
distinct,
ref args,
..
} => fmt_function(f, &fun.to_string(), *distinct, args),
Expr::AggregateUDF { fun, ref args, .. } => {
fmt_function(f, &fun.name, false, args)
}
Expr::AggregateUDF { fun, ref args, .. } => fmt_function(f, &fun.name, args),
Expr::Wildcard => write!(f, "*"),
Expr::Nested(expr) => write!(f, "({:?})", expr),
}
Expand Down
Loading