Skip to content

Commit

Permalink
migrate regr_* functions to UDAF
Browse files Browse the repository at this point in the history
  • Loading branch information
Michael-J-Ward committed Jul 25, 2024
1 parent 40d9d3e commit 86d9d9b
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 18 deletions.
18 changes: 9 additions & 9 deletions python/datafusion/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1398,50 +1398,50 @@ def regr_avgx(y: Expr, x: Expr, distinct: bool = False) -> Expr:
Only non-null pairs of the inputs are evaluated.
"""
return Expr(f.regr_avgx[y.expr, x.expr], distinct)
return Expr(f.regr_avgx(y.expr, x.expr, distinct))


def regr_avgy(y: Expr, x: Expr, distinct: bool = False) -> Expr:
"""Computes the average of the dependent variable ``y``.
Only non-null pairs of the inputs are evaluated.
"""
return Expr(f.regr_avgy[y.expr, x.expr], distinct)
return Expr(f.regr_avgy(y.expr, x.expr, distinct))


def regr_count(y: Expr, x: Expr, distinct: bool = False) -> Expr:
"""Counts the number of rows in which both expressions are not null."""
return Expr(f.regr_count[y.expr, x.expr], distinct)
return Expr(f.regr_count(y.expr, x.expr, distinct))


def regr_intercept(y: Expr, x: Expr, distinct: bool = False) -> Expr:
"""Computes the intercept from the linear regression."""
return Expr(f.regr_intercept[y.expr, x.expr], distinct)
return Expr(f.regr_intercept(y.expr, x.expr, distinct))


def regr_r2(y: Expr, x: Expr, distinct: bool = False) -> Expr:
"""Computes the R-squared value from linear regression."""
return Expr(f.regr_r2[y.expr, x.expr], distinct)
return Expr(f.regr_r2(y.expr, x.expr, distinct))


def regr_slope(y: Expr, x: Expr, distinct: bool = False) -> Expr:
"""Computes the slope from linear regression."""
return Expr(f.regr_slope[y.expr, x.expr], distinct)
return Expr(f.regr_slope(y.expr, x.expr, distinct))


def regr_sxx(y: Expr, x: Expr, distinct: bool = False) -> Expr:
"""Computes the sum of squares of the independent variable `x`."""
return Expr(f.regr_sxx[y.expr, x.expr], distinct)
return Expr(f.regr_sxx(y.expr, x.expr, distinct))


def regr_sxy(y: Expr, x: Expr, distinct: bool = False) -> Expr:
"""Computes the sum of products of pairs of numbers."""
return Expr(f.regr_sxy[y.expr, x.expr], distinct)
return Expr(f.regr_sxy(y.expr, x.expr, distinct))


def regr_syy(y: Expr, x: Expr, distinct: bool = False) -> Expr:
"""Computes the sum of squares of the dependent variable `y`."""
return Expr(f.regr_syy[y.expr, x.expr], distinct)
return Expr(f.regr_syy(y.expr, x.expr, distinct))


def first_value(
Expand Down
99 changes: 90 additions & 9 deletions src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,96 @@ pub fn var_pop(expression: PyExpr, distinct: bool) -> PyResult<PyExpr> {
}
}

#[pyfunction]
pub fn regr_avgx(expr_y: PyExpr, expr_x: PyExpr, distinct: bool) -> PyResult<PyExpr> {
let expr = functions_aggregate::expr_fn::regr_avgx(expr_y.expr, expr_x.expr);
if distinct {
Ok(expr.distinct().build()?.into())
} else {
Ok(expr.into())
}
}

#[pyfunction]
pub fn regr_avgy(expr_y: PyExpr, expr_x: PyExpr, distinct: bool) -> PyResult<PyExpr> {
let expr = functions_aggregate::expr_fn::regr_avgy(expr_y.expr, expr_x.expr);
if distinct {
Ok(expr.distinct().build()?.into())
} else {
Ok(expr.into())
}
}

#[pyfunction]
pub fn regr_count(expr_y: PyExpr, expr_x: PyExpr, distinct: bool) -> PyResult<PyExpr> {
let expr = functions_aggregate::expr_fn::regr_count(expr_y.expr, expr_x.expr);
if distinct {
Ok(expr.distinct().build()?.into())
} else {
Ok(expr.into())
}
}

#[pyfunction]
pub fn regr_intercept(expr_y: PyExpr, expr_x: PyExpr, distinct: bool) -> PyResult<PyExpr> {
let expr = functions_aggregate::expr_fn::regr_intercept(expr_y.expr, expr_x.expr);
if distinct {
Ok(expr.distinct().build()?.into())
} else {
Ok(expr.into())
}
}

#[pyfunction]
pub fn regr_r2(expr_y: PyExpr, expr_x: PyExpr, distinct: bool) -> PyResult<PyExpr> {
let expr = functions_aggregate::expr_fn::regr_r2(expr_y.expr, expr_x.expr);
if distinct {
Ok(expr.distinct().build()?.into())
} else {
Ok(expr.into())
}
}

#[pyfunction]
pub fn regr_slope(expr_y: PyExpr, expr_x: PyExpr, distinct: bool) -> PyResult<PyExpr> {
let expr = functions_aggregate::expr_fn::regr_slope(expr_y.expr, expr_x.expr);
if distinct {
Ok(expr.distinct().build()?.into())
} else {
Ok(expr.into())
}
}

#[pyfunction]
pub fn regr_sxx(expr_y: PyExpr, expr_x: PyExpr, distinct: bool) -> PyResult<PyExpr> {
let expr = functions_aggregate::expr_fn::regr_sxx(expr_y.expr, expr_x.expr);
if distinct {
Ok(expr.distinct().build()?.into())
} else {
Ok(expr.into())
}
}

#[pyfunction]
pub fn regr_sxy(expr_y: PyExpr, expr_x: PyExpr, distinct: bool) -> PyResult<PyExpr> {
let expr = functions_aggregate::expr_fn::regr_sxy(expr_y.expr, expr_x.expr);
if distinct {
Ok(expr.distinct().build()?.into())
} else {
Ok(expr.into())
}
}

#[pyfunction]
pub fn regr_syy(expr_y: PyExpr, expr_x: PyExpr, distinct: bool) -> PyResult<PyExpr> {
let expr = functions_aggregate::expr_fn::regr_syy(expr_y.expr, expr_x.expr);
if distinct {
Ok(expr.distinct().build()?.into())
} else {
Ok(expr.into())
}
}

#[pyfunction]
#[pyo3(signature = (expr, distinct = false, filter = None, order_by = None, null_treatment = None))]
pub fn first_value(
Expand Down Expand Up @@ -847,15 +937,6 @@ array_fn!(range, start stop step);
aggregate_function!(array_agg, ArrayAgg);
aggregate_function!(max, Max);
aggregate_function!(min, Min);
aggregate_function!(regr_avgx, RegrAvgx);
aggregate_function!(regr_avgy, RegrAvgy);
aggregate_function!(regr_count, RegrCount);
aggregate_function!(regr_intercept, RegrIntercept);
aggregate_function!(regr_r2, RegrR2);
aggregate_function!(regr_slope, RegrSlope);
aggregate_function!(regr_sxx, RegrSXX);
aggregate_function!(regr_sxy, RegrSXY);
aggregate_function!(regr_syy, RegrSYY);
aggregate_function!(bit_and, BitAnd);
aggregate_function!(bit_or, BitOr);
aggregate_function!(bit_xor, BitXor);
Expand Down

0 comments on commit 86d9d9b

Please sign in to comment.