Skip to content

Commit

Permalink
feat: update first_value and last_value wrappers.
Browse files Browse the repository at this point in the history
Upstream signatures were changed for the new `AggregateBuilder` API [0].

This simply gets the code to work. We should better incorporate that API into `datafusion-python`.

[0] apache/datafusion#10560
  • Loading branch information
Michael-J-Ward committed Jul 24, 2024
1 parent 1836085 commit d29f0ed
Showing 1 changed file with 61 additions and 27 deletions.
88 changes: 61 additions & 27 deletions src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
// under the License.

use datafusion::functions_aggregate::all_default_aggregate_functions;
use datafusion_expr::AggregateExt;
use pyo3::{prelude::*, wrap_pyfunction};

use crate::common::data_type::NullTreatment;
Expand Down Expand Up @@ -75,47 +76,80 @@ pub fn var(y: PyExpr) -> PyExpr {
}

#[pyfunction]
#[pyo3(signature = (*args, distinct = false, filter = None, order_by = None, null_treatment = None))]
#[pyo3(signature = (expr, distinct = false, filter = None, order_by = None, null_treatment = None))]
pub fn first_value(
args: Vec<PyExpr>,
expr: PyExpr,
distinct: bool,
filter: Option<PyExpr>,
order_by: Option<Vec<PyExpr>>,
null_treatment: Option<NullTreatment>,
) -> PyExpr {
let null_treatment = null_treatment.map(Into::into);
let args = args.into_iter().map(|x| x.expr).collect::<Vec<_>>();
) -> PyResult<PyExpr> {
let order_by = order_by.map(|x| x.into_iter().map(|x| x.expr).collect::<Vec<_>>());
functions_aggregate::expr_fn::first_value(
args,
distinct,
filter.map(|x| Box::new(x.expr)),
order_by,
null_treatment,
)
.into()

// TODO: add `builder()` to `AggregateExt` to avoid this boilerplate
let builder = functions_aggregate::expr_fn::first_value(expr.expr, order_by);

let builder = if let Some(filter) = filter {
let filter = filter.expr;
builder.filter(filter).build()?
} else {
builder
};

let builder = if distinct {
builder.distinct().build()?
} else {
builder
};

let builder = if let Some(null_treatment) = null_treatment {
builder.null_treatment(null_treatment.into()).build()?
} else {
builder
};

Ok(builder.into())
}

#[pyfunction]
#[pyo3(signature = (*args, distinct = false, filter = None, order_by = None, null_treatment = None))]
#[pyo3(signature = (expr, distinct = false, filter = None, order_by = None, null_treatment = None))]
pub fn last_value(
args: Vec<PyExpr>,
expr: PyExpr,
distinct: bool,
filter: Option<PyExpr>,
order_by: Option<Vec<PyExpr>>,
null_treatment: Option<NullTreatment>,
) -> PyExpr {
let null_treatment = null_treatment.map(Into::into);
let args = args.into_iter().map(|x| x.expr).collect::<Vec<_>>();
let order_by = order_by.map(|x| x.into_iter().map(|x| x.expr).collect::<Vec<_>>());
functions_aggregate::expr_fn::last_value(
args,
distinct,
filter.map(|x| Box::new(x.expr)),
order_by,
null_treatment,
)
.into()
) -> PyResult<PyExpr> {
// TODO: add `builder()` to `AggregateExt` to avoid this boilerplate
let builder = functions_aggregate::expr_fn::last_value(vec![expr.expr]);

let builder = if distinct {
builder.distinct().build()?
} else {
builder
};

let builder = if let Some(filter) = filter {
let filter = filter.expr;
builder.filter(filter).build()?
} else {
builder
};

let builder = if let Some(order_by) = order_by {
let order_by = order_by.into_iter().map(|x| x.expr).collect::<Vec<_>>();
builder.order_by(order_by).build()?
} else {
builder
};

let builder = if let Some(null_treatment) = null_treatment {
builder.null_treatment(null_treatment.into()).build()?
} else {
builder
};

Ok(builder.into())
}

#[pyfunction]
Expand Down

0 comments on commit d29f0ed

Please sign in to comment.