Skip to content

Commit

Permalink
refactor: Make pl.repeat part of the IR (#19152)
Browse files Browse the repository at this point in the history
  • Loading branch information
coastalwhite authored Oct 9, 2024
1 parent 688e9b4 commit 2e0438b
Show file tree
Hide file tree
Showing 7 changed files with 51 additions and 20 deletions.
17 changes: 10 additions & 7 deletions crates/polars-core/src/frame/column/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -372,15 +372,18 @@ impl Column {

#[inline]
pub fn new_from_index(&self, index: usize, length: usize) -> Self {
if index >= self.len() {
return Self::full_null(self.name().clone(), length, self.dtype());
}

match self {
Column::Series(s) => s.new_from_index(index, length).into(),
Column::Scalar(s) => {
if index >= s.len() {
Self::full_null(s.name().clone(), length, s.dtype())
} else {
s.resize(length).into()
}
Column::Series(s) => {
// SAFETY: Bounds check done before.
let av = unsafe { s.get_unchecked(index) };
let scalar = Scalar::new(self.dtype().clone(), av.into_static());
Self::new_scalar(self.name().clone(), scalar, length)
},
Column::Scalar(s) => s.resize(length).into(),
}
}

Expand Down
5 changes: 5 additions & 0 deletions crates/polars-plan/src/dsl/function_expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ pub mod pow;
mod random;
#[cfg(feature = "range")]
mod range;
mod repeat;
#[cfg(feature = "rolling_window")]
pub mod rolling;
#[cfg(feature = "rolling_window_by")]
Expand Down Expand Up @@ -189,6 +190,7 @@ pub enum FunctionExpr {
options: RankOptions,
seed: Option<u64>,
},
Repeat,
#[cfg(feature = "round_series")]
Clip {
has_min: bool,
Expand Down Expand Up @@ -452,6 +454,7 @@ impl Hash for FunctionExpr {
a.hash(state);
b.hash(state);
},
Repeat => {},
#[cfg(feature = "rank")]
Rank { options, seed } => {
options.hash(state);
Expand Down Expand Up @@ -651,6 +654,7 @@ impl Display for FunctionExpr {
#[cfg(feature = "moment")]
Kurtosis(..) => "kurtosis",
ArgUnique => "arg_unique",
Repeat => "repeat",
#[cfg(feature = "rank")]
Rank { .. } => "rank",
#[cfg(feature = "round_series")]
Expand Down Expand Up @@ -996,6 +1000,7 @@ impl From<FunctionExpr> for SpecialEq<Arc<dyn ColumnsUdf>> {
#[cfg(feature = "moment")]
Kurtosis(fisher, bias) => map!(dispatch::kurtosis, fisher, bias),
ArgUnique => map!(dispatch::arg_unique),
Repeat => map_as_slice!(repeat::repeat),
#[cfg(feature = "rank")]
Rank { options, seed } => map!(dispatch::rank, options, seed),
#[cfg(feature = "dtype-struct")]
Expand Down
18 changes: 18 additions & 0 deletions crates/polars-plan/src/dsl/function_expr/repeat.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
use polars_core::prelude::{polars_ensure, polars_err, Column, PolarsResult};

/// Implementation of the `repeat` function expression.
///
/// `args` must hold two columns: `args[0]` is the value column and `args[1]`
/// is the repeat-count column. Only the first entry of the count column is
/// read. The result is produced by `Column::new_from_index(0, n)` on the
/// value column, where `n` is that first entry converted to `usize`.
///
/// # Errors
///
/// * `SchemaMismatch` if the count column does not have an integer dtype.
/// * Any error returned by `get(0)` on the count column.
/// * `ComputeError` if the first count value cannot be extracted as a `usize`.
///
/// # Panics
///
/// Panics if `args` contains fewer than two columns.
pub fn repeat(args: &[Column]) -> PolarsResult<Column> {
    let (value, count) = (&args[0], &args[1]);

    // The count has to be an integer column before we try to read it.
    let count_dtype = count.dtype();
    polars_ensure!(
        count_dtype.is_integer(),
        SchemaMismatch: "expected expression of dtype 'integer', got '{}'", count_dtype
    );

    // Only the first entry of the count column decides the output length.
    let raw_len = count.get(0)?;
    let len = match raw_len.extract::<usize>() {
        Some(len) => len,
        None => {
            return Err(
                polars_err!(ComputeError: "could not parse value '{}' as a size.", raw_len),
            )
        },
    };

    Ok(value.new_from_index(0, len))
}
1 change: 1 addition & 0 deletions crates/polars-plan/src/dsl/function_expr/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ impl FunctionExpr {
#[cfg(feature = "moment")]
Kurtosis(..) => mapper.with_dtype(DataType::Float64),
ArgUnique => mapper.with_dtype(IDX_DTYPE),
Repeat => mapper.with_same_dtype(),
#[cfg(feature = "rank")]
Rank { options, .. } => mapper.with_dtype(match options.method {
RankMethod::Average => DataType::Float64,
Expand Down
27 changes: 15 additions & 12 deletions crates/polars-plan/src/dsl/functions/repeat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,20 @@ use super::*;
/// Generally you won't need this function, as `lit(value)` already represents a column containing
/// only `value` whose length is automatically set to the correct number of rows.
pub fn repeat<E: Into<Expr>>(value: E, n: Expr) -> Expr {
let function = |s: Column, n: Column| {
polars_ensure!(
n.dtype().is_integer(),
SchemaMismatch: "expected expression of dtype 'integer', got '{}'", n.dtype()
);
let first_value = n.get(0)?;
let n = first_value.extract::<usize>().ok_or_else(
|| polars_err!(ComputeError: "could not parse value '{}' as a size.", first_value),
)?;
Ok(Some(s.new_from_index(0, n)))
let input = vec![value.into(), n];

let expr = Expr::Function {
input,
function: FunctionExpr::Repeat,
options: FunctionOptions {
flags: FunctionFlags::default()
| FunctionFlags::ALLOW_RENAME
| FunctionFlags::CHANGES_LENGTH,
..Default::default()
},
};
apply_binary(value.into(), n, function, GetOutput::same_type())
.alias(PlSmallStr::from_static("repeat"))

// @NOTE: This alias should probably not be here for consistency, but it is here for backwards
// compatibility until 2.0.
expr.alias(PlSmallStr::from_static("repeat"))
}
2 changes: 1 addition & 1 deletion crates/polars-python/src/lazyframe/visit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ impl NodeTraverser {
// Increment major on breaking changes to the IR (e.g. renaming
// fields, reordering tuples), minor on backwards compatible
// changes (e.g. exposing a new expression node).
const VERSION: Version = (2, 2);
const VERSION: Version = (2, 3);

pub fn new(root: Node, lp_arena: Arena<IR>, expr_arena: Arena<AExpr>) -> Self {
Self {
Expand Down
1 change: 1 addition & 0 deletions crates/polars-python/src/lazyframe/visitor/expr_nodes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1202,6 +1202,7 @@ pub(crate) fn into_py(py: Python<'_>, expr: &AExpr) -> PyResult<PyObject> {
#[cfg(feature = "repeat_by")]
FunctionExpr::RepeatBy => ("repeat_by",).to_object(py),
FunctionExpr::ArgUnique => ("arg_unique",).to_object(py),
FunctionExpr::Repeat => ("repeat",).to_object(py),
FunctionExpr::Rank {
options: _,
seed: _,
Expand Down

0 comments on commit 2e0438b

Please sign in to comment.