Skip to content

Commit 234810d

Browse files
authored
fix: Fix decimal arithmetic schema (pola-rs#20398)
1 parent 15b8981 commit 234810d

File tree

4 files changed

+69
-5
lines changed

4 files changed

+69
-5
lines changed

Diff for: crates/polars-core/src/chunked_array/arithmetic/decimal.rs

+18-5
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ impl Add for &DecimalChunked {
44
type Output = PolarsResult<DecimalChunked>;
55

66
fn add(self, rhs: Self) -> Self::Output {
7-
let scale = self.scale().max(rhs.scale());
7+
let scale = _get_decimal_scale_add_sub(self.scale(), rhs.scale());
88
let lhs = self.to_scale(scale)?;
99
let rhs = rhs.to_scale(scale)?;
1010
Ok((&lhs.0 + &rhs.0).into_decimal_unchecked(None, scale))
@@ -15,7 +15,7 @@ impl Sub for &DecimalChunked {
1515
type Output = PolarsResult<DecimalChunked>;
1616

1717
fn sub(self, rhs: Self) -> Self::Output {
18-
let scale = self.scale().max(rhs.scale());
18+
let scale = _get_decimal_scale_add_sub(self.scale(), rhs.scale());
1919
let lhs = self.to_scale(scale)?;
2020
let rhs = rhs.to_scale(scale)?;
2121
Ok((&lhs.0 - &rhs.0).into_decimal_unchecked(None, scale))
@@ -26,7 +26,7 @@ impl Mul for &DecimalChunked {
2626
type Output = PolarsResult<DecimalChunked>;
2727

2828
fn mul(self, rhs: Self) -> Self::Output {
29-
let scale = self.scale() + rhs.scale();
29+
let scale = _get_decimal_scale_mul(self.scale(), rhs.scale());
3030
Ok((&self.0 * &rhs.0).into_decimal_unchecked(None, scale))
3131
}
3232
}
@@ -35,9 +35,22 @@ impl Div for &DecimalChunked {
3535
type Output = PolarsResult<DecimalChunked>;
3636

3737
fn div(self, rhs: Self) -> Self::Output {
38-
// Follow postgres and MySQL adding a fixed scale increment of 4
39-
let scale = self.scale() + 4;
38+
let scale = _get_decimal_scale_div(self.scale());
4039
let lhs = self.to_scale(scale + rhs.scale())?;
4140
Ok((&lhs.0 / &rhs.0).into_decimal_unchecked(None, scale))
4241
}
4342
}
43+
44+
// Used by polars-plan to determine schema.
45+
pub fn _get_decimal_scale_add_sub(scale_left: usize, scale_right: usize) -> usize {
46+
scale_left.max(scale_right)
47+
}
48+
49+
pub fn _get_decimal_scale_mul(scale_left: usize, scale_right: usize) -> usize {
50+
scale_left + scale_right
51+
}
52+
53+
pub fn _get_decimal_scale_div(scale_left: usize) -> usize {
54+
// Follow postgres and MySQL adding a fixed scale increment of 4
55+
scale_left + 4
56+
}

Diff for: crates/polars-core/src/chunked_array/arithmetic/mod.rs

+2
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ mod numeric;
66
use std::ops::{Add, Div, Mul, Rem, Sub};
77

88
use arrow::compute::utils::combine_validities_and;
9+
#[cfg(feature = "dtype-decimal")]
10+
pub use decimal::{_get_decimal_scale_add_sub, _get_decimal_scale_div, _get_decimal_scale_mul};
911
use num_traits::{Num, NumCast, ToPrimitive};
1012
pub use numeric::ArithmeticChunked;
1113

Diff for: crates/polars-plan/src/plans/aexpr/schema.rs

+36
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
#[cfg(feature = "dtype-decimal")]
2+
use polars_core::chunked_array::arithmetic::{
3+
_get_decimal_scale_add_sub, _get_decimal_scale_div, _get_decimal_scale_mul,
4+
};
15
use recursive::recursive;
26

37
use super::*;
@@ -500,6 +504,11 @@ fn get_arithmetic_field(
500504
other_dtype.leaf_dtype(),
501505
)?)
502506
},
507+
#[cfg(feature = "dtype-decimal")]
508+
(Decimal(_, Some(scale_left)), Decimal(_, Some(scale_right))) => {
509+
let scale = _get_decimal_scale_add_sub(*scale_left, *scale_right);
510+
Decimal(None, Some(scale))
511+
},
503512
(left, right) => try_get_supertype(left, right)?,
504513
}
505514
},
@@ -549,6 +558,11 @@ fn get_arithmetic_field(
549558
other_dtype.leaf_dtype(),
550559
)?)
551560
},
561+
#[cfg(feature = "dtype-decimal")]
562+
(Decimal(_, Some(scale_left)), Decimal(_, Some(scale_right))) => {
563+
let scale = _get_decimal_scale_add_sub(*scale_left, *scale_right);
564+
Decimal(None, Some(scale))
565+
},
552566
(left, right) => try_get_supertype(left, right)?,
553567
}
554568
},
@@ -581,6 +595,23 @@ fn get_arithmetic_field(
581595
polars_bail!(InvalidOperation: "{} not allowed on {} and {}", op, left_field.dtype, right_type)
582596
},
583597
},
598+
#[cfg(feature = "dtype-decimal")]
599+
(Decimal(_, Some(scale_left)), Decimal(_, Some(scale_right))) => {
600+
let scale = match op {
601+
Operator::Multiply => _get_decimal_scale_mul(*scale_left, *scale_right),
602+
Operator::Divide | Operator::TrueDivide => {
603+
_get_decimal_scale_div(*scale_left)
604+
},
605+
_ => {
606+
debug_assert!(false);
607+
*scale_left
608+
},
609+
};
610+
let dtype = Decimal(None, Some(scale));
611+
left_field.coerce(dtype);
612+
return Ok(left_field);
613+
},
614+
584615
(l @ List(a), r @ List(b))
585616
if ![a, b]
586617
.into_iter()
@@ -684,6 +715,11 @@ fn get_truediv_field(
684715
})
685716
},
686717
(Float32, _) => Float32,
718+
#[cfg(feature = "dtype-decimal")]
719+
(Decimal(_, Some(scale_left)), Decimal(_, _)) => {
720+
let scale = _get_decimal_scale_div(*scale_left);
721+
Decimal(None, Some(scale))
722+
},
687723
(dt, _) if dt.is_numeric() => Float64,
688724
#[cfg(feature = "dtype-duration")]
689725
(Duration(_), Duration(_)) => Float64,

Diff for: py-polars/tests/unit/datatypes/test_decimal.py

+13
Original file line numberDiff line numberDiff line change
@@ -540,3 +540,16 @@ def test_decimal_round() -> None:
540540
expected_s = pl.Series("a", [round(v, decimals) for v in values], dtype)
541541

542542
assert_series_equal(got_s, expected_s)
543+
544+
545+
def test_decimal_arithmetic_schema() -> None:
546+
q = pl.LazyFrame({"x": [1.0]}, schema={"x": pl.Decimal(15, 2)})
547+
548+
q1 = q.select(pl.col.x * pl.col.x)
549+
assert q1.collect_schema() == q1.collect().schema
550+
q1 = q.select(pl.col.x / pl.col.x)
551+
assert q1.collect_schema() == q1.collect().schema
552+
q1 = q.select(pl.col.x - pl.col.x)
553+
assert q1.collect_schema() == q1.collect().schema
554+
q1 = q.select(pl.col.x + pl.col.x)
555+
assert q1.collect_schema() == q1.collect().schema

0 commit comments

Comments
 (0)