Skip to content

Commit 9e3d614

Browse files
authored
perf: use is_sorted in ewm_mean_by, deprecate check_sorted (pola-rs#16335)
1 parent 7cf55e0 commit 9e3d614

File tree

9 files changed (+79 / -61 lines changed)

Diff for: crates/polars-ops/src/series/ops/ewm_by.rs

+45 -19
Original file line number | Diff line number | Diff line change
@@ -7,42 +7,68 @@ pub fn ewm_mean_by(
77
s: &Series,
88
times: &Series,
99
half_life: i64,
10-
assume_sorted: bool,
10+
times_is_sorted: bool,
1111
) -> PolarsResult<Series> {
12-
match (s.dtype(), times.dtype()) {
13-
(DataType::Float64, DataType::Int64) => Ok((if assume_sorted {
14-
ewm_mean_by_impl_sorted(s.f64().unwrap(), times.i64().unwrap(), half_life)
15-
} else {
16-
ewm_mean_by_impl(s.f64().unwrap(), times.i64().unwrap(), half_life)
17-
})
18-
.into_series()),
19-
(DataType::Float32, DataType::Int64) => Ok((if assume_sorted {
20-
ewm_mean_by_impl_sorted(s.f32().unwrap(), times.i64().unwrap(), half_life)
12+
fn func<T>(
13+
values: &ChunkedArray<T>,
14+
times: &Int64Chunked,
15+
half_life: i64,
16+
times_is_sorted: bool,
17+
) -> PolarsResult<Series>
18+
where
19+
T: PolarsFloatType,
20+
T::Native: Float + Zero + One,
21+
ChunkedArray<T>: IntoSeries,
22+
{
23+
if times_is_sorted {
24+
Ok(ewm_mean_by_impl_sorted(values, times, half_life).into_series())
2125
} else {
22-
ewm_mean_by_impl(s.f32().unwrap(), times.i64().unwrap(), half_life)
23-
})
24-
.into_series()),
26+
Ok(ewm_mean_by_impl(values, times, half_life).into_series())
27+
}
28+
}
29+
30+
match (s.dtype(), times.dtype()) {
31+
(DataType::Float64, DataType::Int64) => func(
32+
s.f64().unwrap(),
33+
times.i64().unwrap(),
34+
half_life,
35+
times_is_sorted,
36+
),
37+
(DataType::Float32, DataType::Int64) => func(
38+
s.f32().unwrap(),
39+
times.i64().unwrap(),
40+
half_life,
41+
times_is_sorted,
42+
),
2543
#[cfg(feature = "dtype-datetime")]
2644
(_, DataType::Datetime(time_unit, _)) => {
2745
let half_life = adjust_half_life_to_time_unit(half_life, time_unit);
28-
ewm_mean_by(s, &times.cast(&DataType::Int64)?, half_life, assume_sorted)
46+
ewm_mean_by(
47+
s,
48+
&times.cast(&DataType::Int64)?,
49+
half_life,
50+
times_is_sorted,
51+
)
2952
},
3053
#[cfg(feature = "dtype-date")]
3154
(_, DataType::Date) => ewm_mean_by(
3255
s,
3356
&times.cast(&DataType::Datetime(TimeUnit::Milliseconds, None))?,
3457
half_life,
35-
assume_sorted,
58+
times_is_sorted,
59+
),
60+
(_, DataType::UInt64 | DataType::UInt32 | DataType::Int32) => ewm_mean_by(
61+
s,
62+
&times.cast(&DataType::Int64)?,
63+
half_life,
64+
times_is_sorted,
3665
),
37-
(_, DataType::UInt64 | DataType::UInt32 | DataType::Int32) => {
38-
ewm_mean_by(s, &times.cast(&DataType::Int64)?, half_life, assume_sorted)
39-
},
4066
(DataType::UInt64 | DataType::UInt32 | DataType::Int64 | DataType::Int32, _) => {
4167
ewm_mean_by(
4268
&s.cast(&DataType::Float64)?,
4369
times,
4470
half_life,
45-
assume_sorted,
71+
times_is_sorted,
4672
)
4773
},
4874
_ => {

Diff for: crates/polars-plan/src/dsl/function_expr/ewm_by.rs

+5 -7
Original file line number | Diff line number | Diff line change
@@ -1,10 +1,8 @@
1+
use polars_ops::series::SeriesMethods;
2+
13
use super::*;
24

3-
pub(super) fn ewm_mean_by(
4-
s: &[Series],
5-
half_life: Duration,
6-
check_sorted: bool,
7-
) -> PolarsResult<Series> {
5+
pub(super) fn ewm_mean_by(s: &[Series], half_life: Duration) -> PolarsResult<Series> {
86
let time_zone = match s[1].dtype() {
97
DataType::Datetime(_, Some(time_zone)) => Some(time_zone.as_str()),
108
_ => None,
@@ -15,6 +13,6 @@ pub(super) fn ewm_mean_by(
1513
let half_life = half_life.duration_ns();
1614
let values = &s[0];
1715
let times = &s[1];
18-
let assume_sorted = !check_sorted || times.is_sorted_flag() == IsSorted::Ascending;
19-
polars_ops::prelude::ewm_mean_by(values, times, half_life, assume_sorted)
16+
let times_is_sorted = times.is_sorted(Default::default())?;
17+
polars_ops::prelude::ewm_mean_by(values, times, half_life, times_is_sorted)
2018
}

Diff for: crates/polars-plan/src/dsl/function_expr/mod.rs

+2 -9
Original file line number | Diff line number | Diff line change
@@ -328,7 +328,6 @@ pub enum FunctionExpr {
328328
#[cfg(feature = "ewma_by")]
329329
EwmMeanBy {
330330
half_life: Duration,
331-
check_sorted: bool,
332331
},
333332
#[cfg(feature = "ewma")]
334333
EwmStd {
@@ -542,10 +541,7 @@ impl Hash for FunctionExpr {
542541
#[cfg(feature = "ewma")]
543542
EwmMean { options } => options.hash(state),
544543
#[cfg(feature = "ewma_by")]
545-
EwmMeanBy {
546-
half_life,
547-
check_sorted,
548-
} => (half_life, check_sorted).hash(state),
544+
EwmMeanBy { half_life } => (half_life).hash(state),
549545
#[cfg(feature = "ewma")]
550546
EwmStd { options } => options.hash(state),
551547
#[cfg(feature = "ewma")]
@@ -1118,10 +1114,7 @@ impl From<FunctionExpr> for SpecialEq<Arc<dyn SeriesUdf>> {
11181114
#[cfg(feature = "ewma")]
11191115
EwmMean { options } => map!(ewm::ewm_mean, options),
11201116
#[cfg(feature = "ewma_by")]
1121-
EwmMeanBy {
1122-
half_life,
1123-
check_sorted,
1124-
} => map_as_slice!(ewm_by::ewm_mean_by, half_life, check_sorted),
1117+
EwmMeanBy { half_life } => map_as_slice!(ewm_by::ewm_mean_by, half_life),
11251118
#[cfg(feature = "ewma")]
11261119
EwmStd { options } => map!(ewm::ewm_std, options),
11271120
#[cfg(feature = "ewma")]

Diff for: crates/polars-plan/src/dsl/mod.rs

+2 -5
Original file line number | Diff line number | Diff line change
@@ -1647,12 +1647,9 @@ impl Expr {
16471647

16481648
#[cfg(feature = "ewma_by")]
16491649
/// Calculate the exponentially-weighted moving average by a time column.
1650-
pub fn ewm_mean_by(self, times: Expr, half_life: Duration, check_sorted: bool) -> Self {
1650+
pub fn ewm_mean_by(self, times: Expr, half_life: Duration) -> Self {
16511651
self.apply_many_private(
1652-
FunctionExpr::EwmMeanBy {
1653-
half_life,
1654-
check_sorted,
1655-
},
1652+
FunctionExpr::EwmMeanBy { half_life },
16561653
&[times],
16571654
false,
16581655
false,

Diff for: py-polars/polars/expr/expr.py

+11 -2
Original file line number | Diff line number | Diff line change
@@ -10537,7 +10537,7 @@ def ewm_mean_by(
1053710537
by: str | IntoExpr,
1053810538
*,
1053910539
half_life: str | timedelta,
10540-
check_sorted: bool = True,
10540+
check_sorted: bool | None = None,
1054110541
) -> Self:
1054210542
r"""
1054310543
Calculate time-based exponentially weighted moving average.
@@ -10587,6 +10587,10 @@ def ewm_mean_by(
1058710587
Check whether `by` column is sorted.
1058810588
Incorrectly setting this to `False` will lead to incorrect output.
1058910589
10590+
.. deprecated:: 0.20.27
10591+
Sortedness is now verified in a quick manner, you can safely remove
10592+
this argument.
10593+
1059010594
Returns
1059110595
-------
1059210596
Expr
@@ -10625,7 +10629,12 @@ def ewm_mean_by(
1062510629
"""
1062610630
by = parse_as_expression(by)
1062710631
half_life = parse_as_duration_string(half_life)
10628-
return self._from_pyexpr(self._pyexpr.ewm_mean_by(by, half_life, check_sorted))
10632+
if check_sorted is not None:
10633+
issue_deprecation_warning(
10634+
"`check_sorted` is now deprecated in `ewm_mean_by`, you can safely remove this argument.",
10635+
version="0.20.27",
10636+
)
10637+
return self._from_pyexpr(self._pyexpr.ewm_mean_by(by, half_life))
1062910638

1063010639
@deprecate_nonkeyword_arguments(version="0.19.10")
1063110640
def ewm_std(

Diff for: py-polars/src/expr/general.rs

+2 -2
Original file line number | Diff line number | Diff line change
@@ -858,11 +858,11 @@ impl PyExpr {
858858
};
859859
self.inner.clone().ewm_mean(options).into()
860860
}
861-
fn ewm_mean_by(&self, times: PyExpr, half_life: &str, check_sorted: bool) -> Self {
861+
fn ewm_mean_by(&self, times: PyExpr, half_life: &str) -> Self {
862862
let half_life = Duration::parse(half_life);
863863
self.inner
864864
.clone()
865-
.ewm_mean_by(times.inner, half_life, check_sorted)
865+
.ewm_mean_by(times.inner, half_life)
866866
.into()
867867
}
868868

Diff for: py-polars/src/lazyframe/visitor/expr_nodes.rs

+3 -4
Original file line number | Diff line number | Diff line change
@@ -1020,10 +1020,9 @@ pub(crate) fn into_py(py: Python<'_>, expr: &AExpr) -> PyResult<PyObject> {
10201020
FunctionExpr::TopKBy { sort_options: _ } => {
10211021
return Err(PyNotImplementedError::new_err("top_k_by"))
10221022
},
1023-
FunctionExpr::EwmMeanBy {
1024-
half_life: _,
1025-
check_sorted: _,
1026-
} => return Err(PyNotImplementedError::new_err("ewm_mean_by")),
1023+
FunctionExpr::EwmMeanBy { half_life: _ } => {
1024+
return Err(PyNotImplementedError::new_err("ewm_mean_by"))
1025+
},
10271026
},
10281027
options: py.None(),
10291028
}

Diff for: py-polars/tests/unit/functions/test_ewm_by.py

+1 -3
Original file line number | Diff line number | Diff line change
@@ -27,9 +27,7 @@ def test_ewm_by(data: st.DataObject, half_life: int) -> None:
2727
)
2828
)
2929
result = df.with_row_index().select(
30-
pl.col("values").ewm_mean_by(
31-
by="index", half_life=f"{half_life}i", check_sorted=False
32-
)
30+
pl.col("values").ewm_mean_by(by="index", half_life=f"{half_life}i")
3331
)
3432
expected = df.select(
3533
pl.col("values").ewm_mean(half_life=half_life, ignore_nulls=False, adjust=False)

Diff for: py-polars/tests/unit/operations/test_ewm_by.py

+8 -10
Original file line number | Diff line number | Diff line change
@@ -173,22 +173,20 @@ def test_ewma_by_empty() -> None:
173173
assert_frame_equal(result, expected)
174174

175175

176-
def test_ewma_by_warn_if_unsorted() -> None:
176+
def test_ewma_by_if_unsorted() -> None:
177177
df = pl.DataFrame({"values": [3.0, 2.0], "by": [3, 1]})
178-
179-
# Check that with `check_sorted=False`, the user can get incorrect results
180-
# if they really want to.
181-
result = df.select(
182-
pl.col("values").ewm_mean_by("by", half_life="2i", check_sorted=False),
183-
)
184-
expected = pl.DataFrame({"values": [3.0, 4.0]})
185-
assert_frame_equal(result, expected)
186-
187178
result = df.with_columns(
188179
pl.col("values").ewm_mean_by("by", half_life="2i"),
189180
)
190181
expected = pl.DataFrame({"values": [2.5, 2.0], "by": [3, 1]})
191182
assert_frame_equal(result, expected)
183+
184+
with pytest.deprecated_call(match="you can safely remove this argument"):
185+
result = df.with_columns(
186+
pl.col("values").ewm_mean_by("by", half_life="2i", check_sorted=False),
187+
)
188+
assert_frame_equal(result, expected)
189+
192190
result = df.sort("by").with_columns(
193191
pl.col("values").ewm_mean_by("by", half_life="2i"),
194192
)

0 commit comments

Comments
 (0)