Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 41 additions & 2 deletions datafusion/substrait/src/logical_plan/consumer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1975,6 +1975,13 @@ pub async fn from_substrait_agg_func(

let args = from_substrait_func_args(consumer, &f.arguments, input_schema).await?;

// deal with situation that count(*) got no arguments

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you just explain in the comment why we need count() to have arguments?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pretty much just rolled back whatever was removed in #14824, I would like some input first from @jayzhan211 about why was this removed and if it might be breaking something else that the tests are not catching

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I remove it because I don't why we need to convert count() to count(1), is there any reason we need count(1) in subtrait?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If there is reason, it would be nice to document the reason besides the conversion

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The main reason is that DataFusion does not support aggregate functions with no arguments:

pub fn build(self) -> Result<AggregateFunctionExpr> {
let Self {
fun,
args,
alias,
human_display,
schema,
ordering_req,
ignore_nulls,
is_distinct,
is_reversed,
} = self;
if args.is_empty() {
return internal_err!("args should not be empty");
}

I imagine that this is the reason why it was there before.

If you do a similar operation and analyze DataFusion's logical plan (link), a similar Int64(1) argument is injected.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll add a comment explaining that

let args = if udaf.name() == "count" && args.is_empty() {
vec![Expr::Literal(ScalarValue::Int64(Some(1)))]
} else {
args
};

Ok(Arc::new(Expr::AggregateFunction(
expr::AggregateFunction::new_udf(udaf, args, distinct, filter, order_by, None),
)))
Expand Down Expand Up @@ -2248,11 +2255,16 @@ pub async fn from_window_function(

window_frame.regularize_order_bys(&mut order_by)?;

let args = if fun.name() == "count" && window.arguments.is_empty() {
vec![Expr::Literal(ScalarValue::Int64(Some(1)))]
} else {
from_substrait_func_args(consumer, &window.arguments, input_schema).await?
};

Ok(Expr::WindowFunction(expr::WindowFunction {
fun,
params: WindowFunctionParams {
args: from_substrait_func_args(consumer, &window.arguments, input_schema)
.await?,
args,
partition_by: from_substrait_rex_vec(
consumer,
&window.partitions,
Expand Down Expand Up @@ -3406,4 +3418,31 @@ mod test {

Ok(())
}

#[tokio::test]
async fn window_function_with_count() -> Result<()> {
let substrait = substrait::proto::Expression {
rex_type: Some(substrait::proto::expression::RexType::WindowFunction(
substrait::proto::expression::WindowFunction {
function_reference: 0,
..Default::default()
},
)),
};

let mut consumer = test_consumer();

let mut extensions = Extensions::default();
extensions.register_function("count".to_string());
consumer.extensions = &extensions;

match from_substrait_rex(&consumer, &substrait, &DFSchema::empty()).await? {
Expr::WindowFunction(window_function) => {
assert_eq!(window_function.params.args.len(), 1)
}
_ => panic!("expr was not a WindowFunction"),
};

Ok(())
}
}
66 changes: 53 additions & 13 deletions datafusion/substrait/tests/cases/consumer_integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ mod tests {

let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto)?;
let plan = from_substrait_plan(&ctx.state(), &proto).await?;
ctx.state().create_physical_plan(&plan).await?;
Ok(format!("{}", plan))
}

Expand All @@ -50,9 +51,9 @@ mod tests {
let plan_str = tpch_plan_to_string(1).await?;
assert_eq!(
plan_str,
"Projection: LINEITEM.L_RETURNFLAG, LINEITEM.L_LINESTATUS, sum(LINEITEM.L_QUANTITY) AS SUM_QTY, sum(LINEITEM.L_EXTENDEDPRICE) AS SUM_BASE_PRICE, sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT) AS SUM_DISC_PRICE, sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT * Int32(1) + LINEITEM.L_TAX) AS SUM_CHARGE, avg(LINEITEM.L_QUANTITY) AS AVG_QTY, avg(LINEITEM.L_EXTENDEDPRICE) AS AVG_PRICE, avg(LINEITEM.L_DISCOUNT) AS AVG_DISC, count() AS COUNT_ORDER\
"Projection: LINEITEM.L_RETURNFLAG, LINEITEM.L_LINESTATUS, sum(LINEITEM.L_QUANTITY) AS SUM_QTY, sum(LINEITEM.L_EXTENDEDPRICE) AS SUM_BASE_PRICE, sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT) AS SUM_DISC_PRICE, sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT * Int32(1) + LINEITEM.L_TAX) AS SUM_CHARGE, avg(LINEITEM.L_QUANTITY) AS AVG_QTY, avg(LINEITEM.L_EXTENDEDPRICE) AS AVG_PRICE, avg(LINEITEM.L_DISCOUNT) AS AVG_DISC, count(Int64(1)) AS COUNT_ORDER\
\n Sort: LINEITEM.L_RETURNFLAG ASC NULLS LAST, LINEITEM.L_LINESTATUS ASC NULLS LAST\
\n Aggregate: groupBy=[[LINEITEM.L_RETURNFLAG, LINEITEM.L_LINESTATUS]], aggr=[[sum(LINEITEM.L_QUANTITY), sum(LINEITEM.L_EXTENDEDPRICE), sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT), sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT * Int32(1) + LINEITEM.L_TAX), avg(LINEITEM.L_QUANTITY), avg(LINEITEM.L_EXTENDEDPRICE), avg(LINEITEM.L_DISCOUNT), count()]]\
\n Aggregate: groupBy=[[LINEITEM.L_RETURNFLAG, LINEITEM.L_LINESTATUS]], aggr=[[sum(LINEITEM.L_QUANTITY), sum(LINEITEM.L_EXTENDEDPRICE), sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT), sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT * Int32(1) + LINEITEM.L_TAX), avg(LINEITEM.L_QUANTITY), avg(LINEITEM.L_EXTENDEDPRICE), avg(LINEITEM.L_DISCOUNT), count(Int64(1))]]\
\n Projection: LINEITEM.L_RETURNFLAG, LINEITEM.L_LINESTATUS, LINEITEM.L_QUANTITY, LINEITEM.L_EXTENDEDPRICE, LINEITEM.L_EXTENDEDPRICE * (CAST(Int32(1) AS Decimal128(15, 2)) - LINEITEM.L_DISCOUNT), LINEITEM.L_EXTENDEDPRICE * (CAST(Int32(1) AS Decimal128(15, 2)) - LINEITEM.L_DISCOUNT) * (CAST(Int32(1) AS Decimal128(15, 2)) + LINEITEM.L_TAX), LINEITEM.L_DISCOUNT\
\n Filter: LINEITEM.L_SHIPDATE <= Date32(\"1998-12-01\") - IntervalDayTime(\"IntervalDayTime { days: 0, milliseconds: 10368000 }\")\
\n TableScan: LINEITEM"
Expand Down Expand Up @@ -119,9 +120,9 @@ mod tests {
let plan_str = tpch_plan_to_string(4).await?;
assert_eq!(
plan_str,
"Projection: ORDERS.O_ORDERPRIORITY, count() AS ORDER_COUNT\
"Projection: ORDERS.O_ORDERPRIORITY, count(Int64(1)) AS ORDER_COUNT\
\n Sort: ORDERS.O_ORDERPRIORITY ASC NULLS LAST\
\n Aggregate: groupBy=[[ORDERS.O_ORDERPRIORITY]], aggr=[[count()]]\
\n Aggregate: groupBy=[[ORDERS.O_ORDERPRIORITY]], aggr=[[count(Int64(1))]]\
\n Projection: ORDERS.O_ORDERPRIORITY\
\n Filter: ORDERS.O_ORDERDATE >= CAST(Utf8(\"1993-07-01\") AS Date32) AND ORDERS.O_ORDERDATE < CAST(Utf8(\"1993-10-01\") AS Date32) AND EXISTS (<subquery>)\
\n Subquery:\
Expand Down Expand Up @@ -269,10 +270,10 @@ mod tests {
let plan_str = tpch_plan_to_string(13).await?;
assert_eq!(
plan_str,
"Projection: count(ORDERS.O_ORDERKEY) AS C_COUNT, count() AS CUSTDIST\
\n Sort: count() DESC NULLS FIRST, count(ORDERS.O_ORDERKEY) DESC NULLS FIRST\
\n Projection: count(ORDERS.O_ORDERKEY), count()\
\n Aggregate: groupBy=[[count(ORDERS.O_ORDERKEY)]], aggr=[[count()]]\
"Projection: count(ORDERS.O_ORDERKEY) AS C_COUNT, count(Int64(1)) AS CUSTDIST\
\n Sort: count(Int64(1)) DESC NULLS FIRST, count(ORDERS.O_ORDERKEY) DESC NULLS FIRST\
\n Projection: count(ORDERS.O_ORDERKEY), count(Int64(1))\
\n Aggregate: groupBy=[[count(ORDERS.O_ORDERKEY)]], aggr=[[count(Int64(1))]]\
\n Projection: count(ORDERS.O_ORDERKEY)\
\n Aggregate: groupBy=[[CUSTOMER.C_CUSTKEY]], aggr=[[count(ORDERS.O_ORDERKEY)]]\
\n Projection: CUSTOMER.C_CUSTKEY, ORDERS.O_ORDERKEY\
Expand Down Expand Up @@ -410,10 +411,10 @@ mod tests {
let plan_str = tpch_plan_to_string(21).await?;
assert_eq!(
plan_str,
"Projection: SUPPLIER.S_NAME, count() AS NUMWAIT\
"Projection: SUPPLIER.S_NAME, count(Int64(1)) AS NUMWAIT\
\n Limit: skip=0, fetch=100\
\n Sort: count() DESC NULLS FIRST, SUPPLIER.S_NAME ASC NULLS LAST\
\n Aggregate: groupBy=[[SUPPLIER.S_NAME]], aggr=[[count()]]\
\n Sort: count(Int64(1)) DESC NULLS FIRST, SUPPLIER.S_NAME ASC NULLS LAST\
\n Aggregate: groupBy=[[SUPPLIER.S_NAME]], aggr=[[count(Int64(1))]]\
\n Projection: SUPPLIER.S_NAME\
\n Filter: SUPPLIER.S_SUPPKEY = LINEITEM.L_SUPPKEY AND ORDERS.O_ORDERKEY = LINEITEM.L_ORDERKEY AND ORDERS.O_ORDERSTATUS = Utf8(\"F\") AND LINEITEM.L_RECEIPTDATE > LINEITEM.L_COMMITDATE AND EXISTS (<subquery>) AND NOT EXISTS (<subquery>) AND SUPPLIER.S_NATIONKEY = NATION.N_NATIONKEY AND NATION.N_NAME = Utf8(\"SAUDI ARABIA\")\
\n Subquery:\
Expand All @@ -438,9 +439,9 @@ mod tests {
let plan_str = tpch_plan_to_string(22).await?;
assert_eq!(
plan_str,
"Projection: substr(CUSTOMER.C_PHONE,Int32(1),Int32(2)) AS CNTRYCODE, count() AS NUMCUST, sum(CUSTOMER.C_ACCTBAL) AS TOTACCTBAL\
"Projection: substr(CUSTOMER.C_PHONE,Int32(1),Int32(2)) AS CNTRYCODE, count(Int64(1)) AS NUMCUST, sum(CUSTOMER.C_ACCTBAL) AS TOTACCTBAL\
\n Sort: substr(CUSTOMER.C_PHONE,Int32(1),Int32(2)) ASC NULLS LAST\
\n Aggregate: groupBy=[[substr(CUSTOMER.C_PHONE,Int32(1),Int32(2))]], aggr=[[count(), sum(CUSTOMER.C_ACCTBAL)]]\
\n Aggregate: groupBy=[[substr(CUSTOMER.C_PHONE,Int32(1),Int32(2))]], aggr=[[count(Int64(1)), sum(CUSTOMER.C_ACCTBAL)]]\
\n Projection: substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)), CUSTOMER.C_ACCTBAL\
\n Filter: (substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8(\"13\") AS Utf8) OR substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8(\"31\") AS Utf8) OR substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8(\"23\") AS Utf8) OR substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8(\"29\") AS Utf8) OR substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8(\"30\") AS Utf8) OR substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8(\"18\") AS Utf8) OR substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8(\"17\") AS Utf8)) AND CUSTOMER.C_ACCTBAL > (<subquery>) AND NOT EXISTS (<subquery>)\
\n Subquery:\
Expand All @@ -455,4 +456,43 @@ mod tests {
);
Ok(())
}

async fn test_plan_to_string(name: &str) -> Result<String> {
let path = format!("tests/testdata/test_plans/{name}");
let proto = serde_json::from_reader::<_, Plan>(BufReader::new(
File::open(path).expect("file not found"),
))
.expect("failed to parse json");

let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto)?;
let plan = from_substrait_plan(&ctx.state(), &proto).await?;
ctx.state().create_physical_plan(&plan).await?;
Ok(format!("{}", plan))
}

#[tokio::test]
async fn test_select_count_from_select_1() -> Result<()> {
let plan_str =
test_plan_to_string("select_count_from_select_1.substrait.json").await?;

assert_eq!(
plan_str,
"Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count(*)]]\
\n Values: (Int64(0))"
);
Ok(())
}

#[tokio::test]
async fn test_select_window_count() -> Result<()> {
let plan_str = test_plan_to_string("select_window_count.substrait.json").await?;

assert_eq!(
plan_str,
"Projection: count(Int64(1)) PARTITION BY [DATA.PART] ORDER BY [DATA.ORD ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING AS LEAD_EXPR\
\n WindowAggr: windowExpr=[[count(Int64(1)) PARTITION BY [DATA.PART] ORDER BY [DATA.ORD ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING]]\
\n TableScan: DATA"
);
Ok(())
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
{
"extensionUris": [
{
"uri": "https://github.com/substrait-io/substrait/blob/main/extensions/functions_aggregate_generic.yaml"
}
],
"extensions": [
{
"extensionFunction": {
"functionAnchor": 185,
"name": "count:any"
}
}
],
"relations": [
{
"root": {
"input": {
"aggregate": {
"common": {
"direct": {
}
},
"input": {
"read": {
"common": {
"direct": {
}
},
"baseSchema": {
"names": [
"dummy"
],
"struct": {
"types": [
{
"i64": {
"nullability": "NULLABILITY_REQUIRED"
}
}
],
"nullability": "NULLABILITY_REQUIRED"
}
},
"virtualTable": {
"values": [
{
"fields": [
{
"i64": "0",
"nullable": false
}
]
}
]
}
}
},
"groupings": [
{
"groupingExpressions": [],
"expressionReferences": []
}
],
"measures": [
{
"measure": {
"functionReference": 185,
"args": [],
"sorts": [],
"phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT",
"outputType": {
"i64": {
"nullability": "NULLABILITY_REQUIRED"
}
},
"invocation": "AGGREGATION_INVOCATION_ALL",
"arguments": [],
"options": []
}
}
],
"groupingExpressions": []
}
},
"names": [
"count(*)"
]
}
}
]
}
Loading