# 18937: feat: Support decimal for variance #63
base: main
## Conversation
**Walkthrough:** This change extends variance computation in DataFusion to support decimal data types. The variance module now includes decimal-aware accumulators and state management using i256 for precise arithmetic.
**Claude review:**

### Code Review: Decimal Support for Variance Functions

Thank you for implementing decimal support for variance functions! This is a substantial and well-structured addition. Here's my comprehensive review:

#### ✅ Strengths

#### 🔍 Issues & Suggestions

**1. Critical: Potential precision loss in `i256_to_f64_lossy` (variance.rs:88-104)**

The conversion from `i256` to `f64` can lose significant precision for large decimal values. While the function is correctly named `_lossy`, this could lead to incorrect variance calculations for very large decimal values.

Recommendation: Add a comment explaining the precision implications and consider adding a test case with very large decimal values to demonstrate the behavior.

**2. Bug: Incorrect test in `variance_decimal_input` (variance.rs:1197-1205)**

```rust
// Current (incorrect):
let mut pop_acc = VarianceAccumulator::try_new(StatsType::Population)?;

// Should be:
let mut pop_acc =
    DecimalVarianceAccumulator::<Decimal128Type>::try_new(3, StatsType::Population)?;
```

Impact: High - the new decimal variance implementation is not actually being tested in unit tests, only through SQL logic tests.

**3. Numerical stability issue (variance.rs:326-330)**

Clamping a negative numerator to zero is mathematically correct for handling rounding errors, but it silently hides potential bugs in the calculation:

```rust
let numerator = if numerator < i256::ZERO {
    i256::ZERO
} else {
    numerator
};
```

Recommendation: Add a debug assertion or log warning when this occurs, as a negative variance (before clamping) indicates numerical issues.
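A minimal sketch of that guard, reusing the snippet above (whether to assert, log, or both is a design choice; `debug_assert!` keeps release builds unchanged):

```rust
// Sketch: make an unexpected negative numerator visible instead of silently clamping.
let numerator = if numerator < i256::ZERO {
    // Reaching this branch means n·Σx² − (Σx)² went negative, which points to a
    // numerical problem upstream; fail loudly in debug builds, clamp in release.
    debug_assert!(false, "decimal variance numerator went negative: {numerator:?}");
    i256::ZERO
} else {
    numerator
};
```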
**4. Missing edge-case tests**

The tests don't cover several important scenarios.

**5. Minor: Inconsistent error messages**

```rust
return exec_err!("Cannot retract from empty decimal variance state");
```

**6. Performance consideration**

For each value, the code converts to `i256` and performs `i256` arithmetic even for Decimal32/Decimal64, where smaller types would suffice. This is a reasonable trade-off for code simplicity, but worth documenting.

**7. Documentation needed**

The new decimal-specific functions lack doc comments.

#### 📋 Recommended Actions
#### 🔒 Security Assessment

No security concerns identified. Overflow protection is properly implemented.

#### 🎯 Test Coverage Assessment

SQL logic tests are good, but unit tests have a critical gap (testing the wrong accumulator). More edge-case coverage is needed.

#### Overall Assessment

This is solid work that correctly implements the core functionality. The main concern is the unit-test bug, which means the decimal-specific path isn't being properly unit tested. Once that's fixed and edge cases are covered, this will be ready to merge.

Issue Reference: Addresses issue apache#18937
**Augment review** (review completed, 1 suggestion posted):
```rust
self.resize(total_num_groups);
for (row, group_index) in group_indices.iter().enumerate() {
    if let Some(filter) = opt_filter {
        if !filter.value(row) {
```
In `DecimalVarianceGroupsAccumulator::update_batch`, accessing `filter.value(row)` without checking validity can mis-handle NULL filters and may read an undefined bit; consider treating NULL as false by checking validity before evaluating the filter (this also applies to other locations in the PR if similar logic appears).
value:useful; category:bug; feedback: The Augment AI reviewer is correct! Before calling `filter.value(row)` it should check whether the row is valid with `!filter.is_valid(row)`. Prevents wrong behavior of the check when the row is NULL.
```rust
let mut sample_acc = VarianceAccumulator::try_new(StatsType::Sample)?;
let sample_input = [array];
sample_acc.update_batch(&sample_input)?;
assert_variance(sample_acc.evaluate()?, 11606.257924736841);
```
**Bugbot review:**
**Bug: Unit test doesn't test the new decimal variance accumulator**

The `variance_decimal_input` test creates `VarianceAccumulator` directly instead of `DecimalVarianceAccumulator`. The `VarianceAccumulator` casts decimal inputs to Float64, so this test doesn't actually exercise the new decimal-specific variance implementation added in this PR. The test name and setup suggest it's validating decimal variance support, but it's actually testing the Float64 path. The new `DecimalVarianceAccumulator` code remains untested by this unit test.
value:useful; category:bug; feedback: The Bugbot AI reviewer is correct that the unit test uses the old accumulator (`VarianceAccumulator`) instead of the one introduced in this Pull Request. Prevents regressions due to testing the wrong implementation.
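A sketch of the corrected test body, assuming the `try_new(scale, stats_type)` constructor quoted earlier and the same `array` of Decimal128(10, 3) test values; the expected values are the ones already cited in this review thread:

```rust
// Sketch: exercise the decimal-specific accumulator rather than the Float64 path.
let mut pop_acc =
    DecimalVarianceAccumulator::<Decimal128Type>::try_new(3, StatsType::Population)?;
pop_acc.update_batch(&[Arc::clone(&array)])?;
assert_variance(pop_acc.evaluate()?, 11025.9450285);

let mut sample_acc =
    DecimalVarianceAccumulator::<Decimal128Type>::try_new(3, StatsType::Sample)?;
sample_acc.update_batch(&[array])?;
assert_variance(sample_acc.evaluate()?, 11606.257924736841);
```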
```rust
            stats_type,
            _marker: PhantomData,
        }
    }
```
**Bug: Groups accumulator missing scale validation present in regular accumulator**

`DecimalVarianceAccumulator::try_new` validates that scale doesn't exceed `DECIMAL256_MAX_SCALE` and returns an error if it does. However, `DecimalVarianceGroupsAccumulator::new` has no such validation. This inconsistency means the same invalid scale would cause an error via the regular accumulator path but silently proceed via the groups accumulator path, potentially producing incorrect results when the scale factor calculation overflows.
value:useful; category:bug; feedback: The Bugbot AI reviewer is correct! `DecimalVarianceGroupsAccumulator` should validate that `scale` does not exceed `DECIMAL256_MAX_SCALE`. Prevents wrong calculations due to an overly large scale.
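A sketch of one way to close the gap, assuming the constructor can be switched to a fallible `try_new` like its scalar counterpart (the error wording is illustrative):

```rust
// Sketch: mirror DecimalVarianceAccumulator::try_new's scale check here.
fn try_new(scale: i8, stats_type: StatsType) -> Result<Self> {
    if scale > DECIMAL256_MAX_SCALE {
        return exec_err!(
            "Decimal scale {scale} exceeds maximum supported scale {DECIMAL256_MAX_SCALE}"
        );
    }
    Ok(Self {
        states: Vec::new(),
        scale,
        stats_type,
        _marker: PhantomData,
    })
}
```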
**CodeRabbit review:**
Actionable comments posted: 1
🧹 Nitpick comments (3)

datafusion/functions-aggregate/src/variance.rs (3)

**361-463: Scalar decimal accumulator is well-structured; minor robustness tweaks possible**

`DecimalVarianceAccumulator` correctly:

- Uses the decimal-native primitive array (`as_primitive::<T>()`) without unnecessary casts.
- Implements `update_batch`, `retract_batch`, and `merge_batch` in terms of the shared `DecimalVarianceState`.
- Returns the expected Float64 scalar via `state.variance`.

You might optionally assert `values.len() == 1` in `update_batch`/`retract_batch` for defensive programming, mirroring patterns used elsewhere, but this is not required for correctness.

**743-831: VariancePopulation type checks and decimal routing look good; consider mirroring in sample**

`VariancePopulation::return_type` now explicitly accepts both numeric and decimal inputs via `is_numeric_or_decimal`, and the accumulator/group-accumulator creation mirrors the sample case with `StatsType::Population`. That's correct and keeps error messages clear for non-numeric types.

For consistency, you might optionally add the same explicit type check to `VarianceSample::return_type` (currently it ignores `arg_types`) so both variants behave identically at planning time.

**1172-1217: Decimal variance tests are correct but don't exercise the new decimal accumulator**

The new unit tests:

- Verify that `VariancePopulation::return_type` accepts a decimal type.
- Check numeric variance results for a Decimal128 array, but they call `VarianceAccumulator::try_new` (the float-based implementation) instead of the new `DecimalVarianceAccumulator` or the `VariancePopulation`/`VarianceSample` UDF paths.

The expectations (11025.9450285 and 11606.257924736841) match the analytical population and sample variance for the constructed data, which is good. To directly test the new decimal path, consider adding tests that:

- Construct a `VariancePopulation`/`VarianceSample` accumulator via their `accumulator` methods with a decimal `expr_field`, or
- Instantiate `DecimalVarianceAccumulator::<Decimal128Type>::try_new(3, StatsType::Population/Sample)` and run `update_batch`/`evaluate` on the same data.

This would guard against regressions in the decimal-specific state/merge logic.
📜 Review details

- Configuration used: CodeRabbit UI
- Review profile: CHILL
- Plan: Pro

📒 Files selected for processing (2)

- datafusion/functions-aggregate/src/variance.rs (12 hunks)
- datafusion/sqllogictest/test_files/aggregate.slt (1 hunk)

🧬 Code graph analysis (1)

- datafusion/functions-aggregate/src/variance.rs references datafusion/expr/src/utils.rs: `format_state_name` (1257-1259)

⏰ Checks skipped due to the 90000ms timeout (3)

- GitHub Check: Cursor Bugbot
- GitHub Check: claude-review
- GitHub Check: Analyze (rust)
🔇 Additional comments (5)

datafusion/sqllogictest/test_files/aggregate.slt (1)

**5632-5648: Decimal variance SQL tests look consistent and well-chosen**

The new `aggregate_decimal_variance` cases validate both population and sample variance on decimal(10,3), including grouped `var_pop` with very small within-group variance and correct `arrow_typeof` checks. The expected values match the analytical variance for the constructed data, so this should give good coverage of the new decimal variance path.

datafusion/functions-aggregate/src/variance.rs (4)
**70-135: Decimal variance helpers and state layout are sound**

The shared `variance_signature()` plus `decimal_scale`/`decimal_variance_state_fields` wiring cleanly extends variance to decimals while preserving numeric behavior. State encoding as `(UInt64, FixedSizeBinary(32), FixedSizeBinary(32))` via `i256_to_scalar`/`i256_from_bytes` is consistent and matches the 256-bit backing type, so merge paths can reliably reconstruct state. The lossy conversion to `f64` is acceptable given the Float64 return type.
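The encoding above can be sanity-checked with a standalone round trip (a sketch using `arrow_buffer::i256` directly; the PR's `i256_from_bytes` helper presumably performs the equivalent of the decode step):

```rust
use arrow_buffer::i256;

// Sketch: state values round-trip through 32 little-endian bytes, which is
// exactly the layout a FixedSizeBinary(32) state column stores per row.
let sum = i256::from_i128(-12_345_678_901_234_567_890_i128);
let bytes: [u8; 32] = sum.to_le_bytes();
assert_eq!(i256::from_le_bytes(bytes), sum);
```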
**137-221: Type gating and decimal accumulator factory behavior look correct**

`is_numeric_or_decimal` and the `create_decimal_variance_*` helpers correctly distinguish numeric vs decimal inputs and route each decimal variant (32/64/128/256) to the appropriate accumulator. Non-decimal types still fall back to the existing Welford-based variance accumulators, so existing behavior is preserved.
**222-359: DecimalVarianceState math and overflow handling are correct**

The state (`count`, `sum`, `sum_squares`) and `variance` implementation mirror the standard formulas:

- Population: (n · Σx² − (Σx)²) / n²
- Sample: (n · Σx² − (Σx)²) / (n · (n − 1))

with all intermediate arithmetic done in `i256` and guarded by `checked_*` plus a clear "Decimal variance overflow" error. Scaling by 10^(2·scale) is applied only at the end, so decimal scaling is handled exactly. Clamping negative numerators to zero is a safe guard against rare intermediate anomalies, and `count == 0` / `count <= 1` edge cases are correctly mapped to `None`.
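For reference, the same computation written as a self-contained sketch over plain `i128` (ignoring the overflow guards; the real code uses `i256` with `checked_*` arithmetic):

```rust
// Sketch: variance from the (count, sum, sum_squares) moments of decimal values
// with scale `s`, where each stored value is the true value times 10^s.
fn variance(count: u64, sum: i128, sum_squares: i128, scale: u32, sample: bool) -> Option<f64> {
    let n = count as i128;
    let divisor = if sample {
        if count <= 1 {
            return None; // sample variance is undefined for n <= 1
        }
        n * (n - 1)
    } else {
        if count == 0 {
            return None;
        }
        n * n
    };
    // Clamp tiny negative results from the subtraction, as the PR does.
    let numerator = (n * sum_squares - sum * sum).max(0);
    // Undo the decimal scaling: the numerator carries a factor of 10^(2·scale).
    Some(numerator as f64 / divisor as f64 / 10f64.powi(2 * scale as i32))
}
```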
**629-709: VarianceSample integration with decimal path is consistent**

Updating `VarianceSample` to use the shared `variance_signature()` plus:

- Decimal-aware `state_fields` (switching to the fixed-size binary state for decimal inputs), and
- Early dispatch to `create_decimal_variance_accumulator`/`create_decimal_variance_groups_accumulator`

cleanly extends `var`/`var_samp` to decimals while keeping `VAR(DISTINCT)` behavior unchanged. The Float64 return type remains appropriate.
```rust
#[derive(Debug)]
struct DecimalVarianceGroupsAccumulator<T>
where
    T: DecimalType + ArrowNumericType + Debug,
    T::Native: DecimalNative,
{
    states: Vec<DecimalVarianceState>,
    scale: i8,
    stats_type: StatsType,
    _marker: PhantomData<T>,
}

impl<T> DecimalVarianceGroupsAccumulator<T>
where
    T: DecimalType + ArrowNumericType + Debug,
    T::Native: DecimalNative,
{
    fn new(scale: i8, stats_type: StatsType) -> Self {
        Self {
            states: Vec::new(),
            scale,
            stats_type,
            _marker: PhantomData,
        }
    }

    fn resize(&mut self, total_num_groups: usize) {
        if self.states.len() < total_num_groups {
            self.states
                .resize(total_num_groups, DecimalVarianceState::default());
        }
    }
}

impl<T> GroupsAccumulator for DecimalVarianceGroupsAccumulator<T>
where
    T: DecimalType + ArrowNumericType + Debug,
    T::Native: DecimalNative,
{
    fn update_batch(
        &mut self,
        values: &[ArrayRef],
        group_indices: &[usize],
        opt_filter: Option<&BooleanArray>,
        total_num_groups: usize,
    ) -> Result<()> {
        let array = values[0].as_primitive::<T>();
        self.resize(total_num_groups);
        for (row, group_index) in group_indices.iter().enumerate() {
            if let Some(filter) = opt_filter {
                if !filter.value(row) {
                    continue;
                }
            }
            if array.is_null(row) {
                continue;
            }
            let value = array.value(row).to_i256();
            self.states[*group_index].update(value)?;
        }
        Ok(())
    }

    fn merge_batch(
        &mut self,
        values: &[ArrayRef],
        group_indices: &[usize],
        _opt_filter: Option<&BooleanArray>,
        total_num_groups: usize,
    ) -> Result<()> {
        let counts = downcast_value!(values[0], UInt64Array);
        let sums = downcast_value!(values[1], FixedSizeBinaryArray);
        let sum_squares = downcast_value!(values[2], FixedSizeBinaryArray);
        self.resize(total_num_groups);

        for (row, group_index) in group_indices.iter().enumerate() {
            if counts.is_null(row) {
                continue;
            }
            let count = counts.value(row);
            if count == 0 {
                continue;
            }
            let sum = i256_from_bytes(sums.value(row))?;
            let sum_sq = i256_from_bytes(sum_squares.value(row))?;
            let other = DecimalVarianceState {
                count,
                sum,
                sum_squares: sum_sq,
            };
            self.states[*group_index].merge(&other)?;
        }
        Ok(())
    }

    fn evaluate(&mut self, emit_to: datafusion_expr::EmitTo) -> Result<ArrayRef> {
        let states = emit_to.take_needed(&mut self.states);
        let mut builder = Float64Builder::with_capacity(states.len());
        for state in &states {
            match state.variance(self.stats_type, self.scale)? {
                Some(value) => builder.append_value(value),
                None => builder.append_null(),
            }
        }
        Ok(Arc::new(builder.finish()))
    }

    fn state(&mut self, emit_to: datafusion_expr::EmitTo) -> Result<Vec<ArrayRef>> {
        let states = emit_to.take_needed(&mut self.states);
        let mut counts = UInt64Builder::with_capacity(states.len());
        let mut sums = FixedSizeBinaryBuilder::with_capacity(
            states.len(),
            DECIMAL_VARIANCE_BINARY_SIZE,
        );
        let mut sum_squares = FixedSizeBinaryBuilder::with_capacity(
            states.len(),
            DECIMAL_VARIANCE_BINARY_SIZE,
        );

        for state in states {
            counts.append_value(state.count);
            sums.append_value(state.sum.to_le_bytes())?;
            sum_squares.append_value(state.sum_squares.to_le_bytes())?;
        }

        Ok(vec![
            Arc::new(counts.finish()),
            Arc::new(sums.finish()),
            Arc::new(sum_squares.finish()),
        ])
    }

    fn size(&self) -> usize {
        self.states.capacity() * size_of::<DecimalVarianceState>()
    }
}
```
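For orientation, a sketch of how this accumulator is driven (names and values are illustrative; error handling is elided via `?`):

```rust
// Sketch: two groups, two rows each, over Decimal128(10, 3) values.
use std::sync::Arc;
use arrow::array::{ArrayRef, Decimal128Array};
use arrow::datatypes::Decimal128Type;

let values: ArrayRef = Arc::new(
    Decimal128Array::from(vec![1_000_i128, 2_000, 3_000, 4_000]) // 1.000, 2.000, 3.000, 4.000
        .with_precision_and_scale(10, 3)?,
);
let mut acc =
    DecimalVarianceGroupsAccumulator::<Decimal128Type>::new(3, StatsType::Population);
// Rows 0-1 feed group 0, rows 2-3 feed group 1; no FILTER clause.
acc.update_batch(&[values], &[0, 0, 1, 1], None, 2)?;
// Each group is {x, x + 1.000}, so the population variance is 0.25 per group.
let result = acc.evaluate(datafusion_expr::EmitTo::All)?;
```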
**FILTER with nullable boolean can be misinterpreted in decimal groups accumulator**

In `DecimalVarianceGroupsAccumulator::update_batch`, the filter handling is:

```rust
if let Some(filter) = opt_filter {
    if !filter.value(row) {
        continue;
    }
}
```

This ignores the validity bitmap. If the filter expression can produce NULLs, those rows should be excluded (SQL `FILTER (WHERE ...)` treats NULL as false), but here a NULL row will be treated according to the underlying, undefined value bit in the data buffer, which can randomly include or exclude rows.

To align with existing aggregate semantics and the `accumulate` helper used by `VarianceGroupsAccumulator`, you should also check `filter.is_valid(row)` before using `value(row)`, for example:

```diff
- if let Some(filter) = opt_filter {
-     if !filter.value(row) {
-         continue;
-     }
- }
+ if let Some(filter) = opt_filter {
+     if !filter.is_valid(row) || !filter.value(row) {
+         continue;
+     }
+ }
```
value:useful; category:bug; feedback: The CodeRabbit AI reviewer is correct! Before calling `filter.value(row)` it should check whether the row is valid with `!filter.is_valid(row)`. Prevents wrong behavior of the check when the row is NULL.
value:good-to-have; category:documentation; feedback: The Claude AI reviewer is correct that this function would benefit from documentation about its limits. Currently its only "documentation" is the `_lossy` suffix in its name, which does not explain in which cases precision is lost. Prevents incorrect expectations about result precision.
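A sketch of the kind of doc comment being requested (the body shown is an illustrative reconstruction, not the PR's actual implementation):

```rust
/// Convert an `i256` to `f64`, losing precision for large magnitudes.
///
/// `f64` has a 53-bit significand, so any magnitude above 2^53 is rounded to
/// the nearest representable double. This is acceptable here because the
/// variance result type is Float64 regardless.
fn i256_to_f64_lossy(value: i256) -> f64 {
    // Illustrative body: split into 128-bit halves and recombine in floating point.
    let (low, high) = value.to_parts(); // low: u128, high: i128
    (high as f64) * 2f64.powi(128) + (low as f64)
}
```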
value:useful; category:bug; feedback: The Claude AI reviewer is correct that the unit test uses the old accumulator (`VarianceAccumulator`) instead of the one introduced in this Pull Request. Prevents regressions due to testing the wrong implementation.