4 changes: 4 additions & 0 deletions dask_planner/src/dialect.rs
@@ -33,4 +33,8 @@ impl Dialect for DaskDialect {
fn is_proper_identifier_inside_quotes(&self, mut _chars: Peekable<Chars<'_>>) -> bool {
true
}
/// Determine if FILTER (WHERE ...) filters are allowed during aggregations
fn supports_filter_during_aggregation(&self) -> bool {
true
}
}
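With the dialect reporting support for `FILTER` during aggregation, `FILTER (WHERE ...)` clauses now survive parsing and reach the planner. For orientation, this is the kind of query the change enables; a minimal sketch assuming a `dask_sql.Context` (table and column names are illustrative):

```python
import pandas as pd
import dask.dataframe as dd
from dask_sql import Context

c = Context()
ddf = dd.from_pandas(
    pd.DataFrame({"user_id": [1, 2, 2, 3], "b": [10, 20, 30, 40]}),
    npartitions=1,
)
c.create_table("user_table_1", ddf)

# SUM(b) is computed only over rows satisfying the FILTER predicate,
# independently of the GROUP BY
result = c.sql(
    """
    SELECT user_id,
           SUM(b) FILTER (WHERE b > 15) AS filtered_sum
    FROM user_table_1
    GROUP BY user_id
    """
)
print(result.compute())
```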
17 changes: 17 additions & 0 deletions dask_planner/src/expression.rs
@@ -526,6 +526,23 @@ impl PyExpr {
}))
}

/// Extract the `FILTER (WHERE ...)` expression attached to an aggregation, if any
#[pyo3(name = "getFilterExpr")]
pub fn get_filter_expr(&self) -> PyResult<Option<PyExpr>> {
match &self.expr {
Expr::AggregateFunction { filter, .. } | Expr::AggregateUDF { filter, .. } => {
match filter {
Some(filter) => {
Ok(Some(PyExpr::from(*filter.clone(), self.input_plan.clone())))
}
None => Ok(None),
}
}
_ => Err(py_type_err(
"getFilterExpr() - Non-aggregate expression encountered",
)),
}
}

/// TODO: I can't express how much I dislike explicitly listing out all of these methods,
/// but PyO3 makes it necessary since its annotations cannot be used in trait impl blocks
#[pyo3(name = "getFloat32Value")]
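The binding mirrors DataFusion's `Expr::AggregateFunction { filter, .. }` field: it returns `None` when the call carries no `FILTER` clause and errors on non-aggregate expressions. A sketch of how a consumer resolves the filter to a backend column, modeled on the `aggregate.py` change below (`get_filter_backend_column` is a hypothetical helper, not part of the PR):

```python
def get_filter_backend_column(expr, input_rel, cc):
    # expr: a PyExpr from agg.getNamedAggCalls(); cc: the ColumnContainer.
    # getFilterExpr() returns None when the aggregation has no FILTER clause
    # and raises (py_type_err on the Rust side) for non-aggregate expressions.
    filter_expr = expr.getFilterExpr()
    if filter_expr is None:
        return None
    return cc.get_backend_by_frontend_name(filter_expr.column_name(input_rel))
```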
61 changes: 32 additions & 29 deletions dask_sql/physical/rel/logical/aggregate.py
@@ -333,39 +333,39 @@ def _collect_aggregations(

collected_aggregations = defaultdict(list)

# convert and assign any input/filter columns that don't currently exist
new_columns = {}
new_mappings = {}

# convert and assign any input columns that don't currently exist
for expr in agg.getNamedAggCalls():
for expr in agg.getArgs(expr):
key = expr.column_name(input_rel)
if key in cc._frontend_backend_mapping:
continue
random_name = new_temporary_column(df)
new_columns[random_name] = RexConverter.convert(
input_rel, expr, dc, context=context
)
new_mappings[key] = random_name

if new_columns:
df = df.assign(**new_columns)

for key, backend_column_name in new_mappings.items():
cc = cc.add(key, backend_column_name)

for expr in agg.getNamedAggCalls():
# Determine the aggregation function to use
assert expr.getExprType() in {
"Alias",
"AggregateFunction",
"AggregateUDF",
}, "Do not know how to handle this case!"
for input_expr in agg.getArgs(expr):
input_col = input_expr.column_name(input_rel)
if input_col not in cc._frontend_backend_mapping:
random_name = new_temporary_column(df)
new_columns[random_name] = RexConverter.convert(
input_rel, input_expr, dc, context=context
)
cc = cc.add(input_col, random_name)
filter_expr = expr.getFilterExpr()
if filter_expr is not None:
filter_col = filter_expr.column_name(input_rel)
if filter_col not in cc._frontend_backend_mapping:
random_name = new_temporary_column(df)
new_columns[random_name] = RexConverter.convert(
input_rel, filter_expr, dc, context=context
)
cc = cc.add(filter_col, random_name)
if new_columns:
df = df.assign(**new_columns)

for expr in agg.getNamedAggCalls():
schema_name = context.schema_name
aggregation_name = agg.getAggregationFuncName(expr).lower()

# Gather information about the input column
# Gather information about input columns
inputs = agg.getArgs(expr)

# TODO: This if statement is likely no longer needed but left here for the time being just in case
@@ -397,10 +397,13 @@
else:
raise NotImplementedError("Can not cope with more than one input")

# TODO: DataFusion does not yet have the concept of "filters" in aggregations
filter_column = None
# if expr.hasFilter():
# filter_column = cc.get_backend_by_frontend_index(expr.filterArg)
filter_expr = expr.getFilterExpr()
if filter_expr is not None:
filter_backend_col = cc.get_backend_by_frontend_name(
filter_expr.column_name(input_rel)
)
else:
filter_backend_col = None

try:
aggregation_function = self.AGGREGATION_MAPPING[aggregation_name]
@@ -423,9 +426,9 @@
output_col = expr.toString()

# Store the aggregation
key = filter_column
value = (input_col, output_col, aggregation_function)
collected_aggregations[key].append(value)
collected_aggregations[filter_backend_col].append(
(input_col, output_col, aggregation_function)
)
output_column_order.append(output_col)

return collected_aggregations, output_column_order, df, cc
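Keying `collected_aggregations` by the filter's backend column lets every aggregation that shares a predicate run in one pass over a pre-filtered frame. A minimal pandas sketch of that downstream idea, assuming the boolean filter column has already been materialized into the frame (names are illustrative; the real work happens in dask-sql's aggregation plumbing):

```python
import pandas as pd

df = pd.DataFrame(
    {
        "user_id": [1, 2, 2, 3],
        "b": [10, 20, 30, 40],
        "_filter_tmp": [False, True, True, True],  # materialized FILTER predicate
    }
)

# None keys the unfiltered aggregations; each filter column keys the group
# of aggregations that share its predicate
collected_aggregations = {
    None: [("b", "S2", "sum")],
    "_filter_tmp": [("b", "S1", "sum")],
}

parts = []
for filter_col, aggs in collected_aggregations.items():
    # apply the shared predicate once, then run every aggregation in the group
    grouped = (df[df[filter_col]] if filter_col else df).groupby("user_id")
    for input_col, output_col, func in aggs:
        parts.append(grouped[input_col].agg(func).rename(output_col))

print(pd.concat(parts, axis=1))
```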
3 changes: 0 additions & 3 deletions tests/integration/test_groupby.py
@@ -110,9 +110,6 @@ def test_group_by_all(c, df):
assert_eq(result_df, expected_df)


@pytest.mark.skip(
reason="WIP DataFusion - https://github.com/dask-contrib/dask-sql/issues/463"
)
def test_group_by_filtered(c):
return_df = c.sql(
"""