Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ public class AvgAggregator extends NumericMetricsAggregator.SingleValue {

LongArray counts;
DoubleArray sums;
DoubleArray compensations;
DocValueFormat format;

public AvgAggregator(String name, ValuesSource.Numeric valuesSource, DocValueFormat formatter, SearchContext context,
Expand All @@ -55,6 +56,7 @@ public AvgAggregator(String name, ValuesSource.Numeric valuesSource, DocValueFor
final BigArrays bigArrays = context.bigArrays();
counts = bigArrays.newLongArray(1, true);
sums = bigArrays.newDoubleArray(1, true);
compensations = bigArrays.newDoubleArray(1, true);
}
}

Expand All @@ -76,15 +78,24 @@ public LeafBucketCollector getLeafCollector(LeafReaderContext ctx,
public void collect(int doc, long bucket) throws IOException {
counts = bigArrays.grow(counts, bucket + 1);
sums = bigArrays.grow(sums, bucket + 1);
compensations = bigArrays.grow(compensations, bucket + 1);

if (values.advanceExact(doc)) {
final int valueCount = values.docValueCount();
counts.increment(bucket, valueCount);
double sum = 0;
// Compute the sum of double values with Kahan summation algorithm which is more
// accurate than naive summation.
double sum = sums.get(bucket);
double compensation = compensations.get(bucket);

for (int i = 0; i < valueCount; i++) {
sum += values.nextValue();
double corrected = values.nextValue() - compensation;
double newSum = sum + corrected;
compensation = (newSum - sum) - corrected;
sum = newSum;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a comment saying that this is Kahan summation?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @jpountz. I pushed 6c45e08.

}
sums.increment(bucket, sum);
sums.set(bucket, sum);
compensations.set(bucket, compensation);
}
}
};
Expand Down Expand Up @@ -113,7 +124,7 @@ public InternalAggregation buildEmptyAggregation() {

/**
 * Releases the per-bucket arrays held by this aggregator.
 *
 * Every array is passed to a single {@code Releasables.close} call so that none of them
 * is released more than once.
 */
@Override
public void doClose() {
    // counts and sums must not be closed separately beforehand: this one call covers them
    // together with the Kahan compensation terms.
    Releasables.close(counts, sums, compensations);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,15 @@ public String getWriteableName() {
/**
 * Merges the partial averages in {@code aggregations} into one {@link InternalAvg}.
 *
 * Counts are added up directly. The partial sums are combined with the Kahan summation
 * algorithm, which is more accurate than naive summation. Each partial sum contributes
 * exactly once.
 *
 * @param aggregations  the partial results to merge; every element is cast to {@link InternalAvg}
 * @param reduceContext the reduce context (not used by this implementation)
 * @return a new {@link InternalAvg} carrying the combined sum and count
 */
public InternalAvg doReduce(List<InternalAggregation> aggregations, ReduceContext reduceContext) {
    long count = 0;
    double sum = 0;
    double compensation = 0;
    for (InternalAggregation aggregation : aggregations) {
        final InternalAvg avg = (InternalAvg) aggregation;
        count += avg.count;
        final double value = avg.sum;
        if (Double.isFinite(value) == false) {
            // Infinity/NaN: fall back to plain addition. Running them through the
            // compensation step would compute (Inf - sum) - Inf = NaN and poison the result.
            sum += value;
        } else if (Double.isFinite(sum)) {
            // Kahan step: fold the previously lost low-order bits back in, then record
            // what this addition lost. Skipped once sum is already non-finite.
            double corrected = value - compensation;
            double newSum = sum + corrected;
            compensation = (newSum - sum) - corrected;
            sum = newSum;
        }
    }
    return new InternalAvg(getName(), sum, count, format, pipelineAggregators(), getMetaData());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,12 +152,18 @@ public InternalStats doReduce(List<InternalAggregation> aggregations, ReduceCont
double min = Double.POSITIVE_INFINITY;
double max = Double.NEGATIVE_INFINITY;
double sum = 0;
double compensation = 0;
for (InternalAggregation aggregation : aggregations) {
InternalStats stats = (InternalStats) aggregation;
count += stats.getCount();
min = Math.min(min, stats.getMin());
max = Math.max(max, stats.getMax());
sum += stats.getSum();
// Compute the sum of double values with Kahan summation algorithm which is more
// accurate than naive summation.
double corrected = stats.getSum() - compensation;
double newSum = sum + corrected;
compensation = (newSum - sum) - corrected;
sum = newSum;
}
return new InternalStats(name, count, sum, min, max, format, pipelineAggregators(), getMetaData());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ public class StatsAggregator extends NumericMetricsAggregator.MultiValue {
DoubleArray mins;
DoubleArray maxes;

private DoubleArray compensations;


public StatsAggregator(String name, ValuesSource.Numeric valuesSource, DocValueFormat format,
SearchContext context,
Expand All @@ -59,6 +61,7 @@ public StatsAggregator(String name, ValuesSource.Numeric valuesSource, DocValueF
final BigArrays bigArrays = context.bigArrays();
counts = bigArrays.newLongArray(1, true);
sums = bigArrays.newDoubleArray(1, true);
compensations = bigArrays.newDoubleArray(1, true);
mins = bigArrays.newDoubleArray(1, false);
mins.fill(0, mins.size(), Double.POSITIVE_INFINITY);
maxes = bigArrays.newDoubleArray(1, false);
Expand Down Expand Up @@ -88,6 +91,7 @@ public void collect(int doc, long bucket) throws IOException {
final long overSize = BigArrays.overSize(bucket + 1);
counts = bigArrays.resize(counts, overSize);
sums = bigArrays.resize(sums, overSize);
compensations = bigArrays.resize(compensations, overSize);
mins = bigArrays.resize(mins, overSize);
maxes = bigArrays.resize(maxes, overSize);
mins.fill(from, overSize, Double.POSITIVE_INFINITY);
Expand All @@ -97,16 +101,24 @@ public void collect(int doc, long bucket) throws IOException {
if (values.advanceExact(doc)) {
final int valuesCount = values.docValueCount();
counts.increment(bucket, valuesCount);
double sum = 0;
double min = mins.get(bucket);
double max = maxes.get(bucket);
// Compute the sum of double values with Kahan summation algorithm which is more
// accurate than naive summation.
double sum = sums.get(bucket);
double compensation = compensations.get(bucket);

for (int i = 0; i < valuesCount; i++) {
double value = values.nextValue();
sum += value;
double corrected = value - compensation;
double newSum = sum + corrected;
compensation = (newSum - sum) - corrected;
sum = newSum;
min = Math.min(min, value);
max = Math.max(max, value);
}
sums.increment(bucket, sum);
sums.set(bucket, sum);
compensations.set(bucket, compensation);
mins.set(bucket, min);
maxes.set(bucket, max);
}
Expand Down Expand Up @@ -164,6 +176,6 @@ public InternalAggregation buildEmptyAggregation() {

/**
 * Releases the per-bucket arrays held by this aggregator.
 *
 * Every array is passed to a single {@code Releasables.close} call so that none of them
 * is released more than once.
 */
@Override
public void doClose() {
    // One call covers counts, maxes, mins and sums together with the Kahan compensation terms.
    Releasables.close(counts, maxes, mins, sums, compensations);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,11 @@ public class ExtendedStatsAggregator extends NumericMetricsAggregator.MultiValue

LongArray counts;
DoubleArray sums;
DoubleArray compensations;
DoubleArray mins;
DoubleArray maxes;
DoubleArray sumOfSqrs;
DoubleArray compensationOfSqrs;

public ExtendedStatsAggregator(String name, ValuesSource.Numeric valuesSource, DocValueFormat formatter,
SearchContext context, Aggregator parent, double sigma, List<PipelineAggregator> pipelineAggregators,
Expand All @@ -65,11 +67,13 @@ public ExtendedStatsAggregator(String name, ValuesSource.Numeric valuesSource, D
final BigArrays bigArrays = context.bigArrays();
counts = bigArrays.newLongArray(1, true);
sums = bigArrays.newDoubleArray(1, true);
compensations = bigArrays.newDoubleArray(1, true);
mins = bigArrays.newDoubleArray(1, false);
mins.fill(0, mins.size(), Double.POSITIVE_INFINITY);
maxes = bigArrays.newDoubleArray(1, false);
maxes.fill(0, maxes.size(), Double.NEGATIVE_INFINITY);
sumOfSqrs = bigArrays.newDoubleArray(1, true);
compensationOfSqrs = bigArrays.newDoubleArray(1, true);
}
}

Expand All @@ -95,29 +99,44 @@ public void collect(int doc, long bucket) throws IOException {
final long overSize = BigArrays.overSize(bucket + 1);
counts = bigArrays.resize(counts, overSize);
sums = bigArrays.resize(sums, overSize);
compensations = bigArrays.resize(compensations, overSize);
mins = bigArrays.resize(mins, overSize);
maxes = bigArrays.resize(maxes, overSize);
sumOfSqrs = bigArrays.resize(sumOfSqrs, overSize);
compensationOfSqrs = bigArrays.resize(compensationOfSqrs, overSize);
mins.fill(from, overSize, Double.POSITIVE_INFINITY);
maxes.fill(from, overSize, Double.NEGATIVE_INFINITY);
}

if (values.advanceExact(doc)) {
final int valuesCount = values.docValueCount();
counts.increment(bucket, valuesCount);
double sum = 0;
double sumOfSqr = 0;
double min = mins.get(bucket);
double max = maxes.get(bucket);
// Compute the sum and the sum of squares of the double values with the Kahan summation
// algorithm, which is more accurate than naive summation.
double sum = sums.get(bucket);
double compensation = compensations.get(bucket);
double sumOfSqr = sumOfSqrs.get(bucket);
double compensationOfSqr = compensationOfSqrs.get(bucket);
for (int i = 0; i < valuesCount; i++) {
double value = values.nextValue();
sum += value;
sumOfSqr += value * value;
double corrected = value - compensation;
double newSum = sum + corrected;
compensation = (newSum - sum) - corrected;
sum = newSum;

double correctedOfSqr = value * value - compensationOfSqr;
double newSumOfSqr = sumOfSqr + correctedOfSqr;
compensationOfSqr = (newSumOfSqr - sumOfSqr) - correctedOfSqr;
sumOfSqr = newSumOfSqr;
min = Math.min(min, value);
max = Math.max(max, value);
}
sums.increment(bucket, sum);
sumOfSqrs.increment(bucket, sumOfSqr);
sums.set(bucket, sum);
compensations.set(bucket, compensation);
sumOfSqrs.set(bucket, sumOfSqr);
compensationOfSqrs.set(bucket, compensationOfSqr);
mins.set(bucket, min);
maxes.set(bucket, max);
}
Expand Down Expand Up @@ -196,6 +215,6 @@ public InternalAggregation buildEmptyAggregation() {

/**
 * Releases the per-bucket arrays held by this aggregator.
 *
 * Every array is passed to a single {@code Releasables.close} call so that none of them
 * is released more than once.
 */
@Override
public void doClose() {
    // One call covers the value arrays and both Kahan compensation arrays
    // (for the sums and for the sums of squares).
    Releasables.close(counts, maxes, mins, sumOfSqrs, compensationOfSqrs, sums, compensations);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.elasticsearch.search.DocValueFormat;
import org.elasticsearch.search.aggregations.InternalAggregation;
import org.elasticsearch.search.aggregations.metrics.InternalNumericMetricsAggregation;
import org.elasticsearch.search.aggregations.metrics.avg.InternalAvg;
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;

import java.io.IOException;
Expand Down Expand Up @@ -73,9 +74,15 @@ public double getValue() {

/**
 * Merges the partial sums in {@code aggregations} into one {@link InternalSum}.
 *
 * The partial sums are combined with the Kahan summation algorithm, which is more
 * accurate than naive summation. Each partial sum contributes exactly once.
 *
 * @param aggregations  the partial results to merge; every element is cast to {@link InternalSum}
 * @param reduceContext the reduce context (not used by this implementation)
 * @return a new {@link InternalSum} carrying the combined sum
 */
@Override
public InternalSum doReduce(List<InternalAggregation> aggregations, ReduceContext reduceContext) {
    double sum = 0;
    double compensation = 0;
    for (InternalAggregation aggregation : aggregations) {
        final double value = ((InternalSum) aggregation).sum;
        if (Double.isFinite(value) == false) {
            // Infinity/NaN: fall back to plain addition. Running them through the
            // compensation step would compute (Inf - sum) - Inf = NaN and poison the result.
            sum += value;
        } else if (Double.isFinite(sum)) {
            // Kahan step: fold the previously lost low-order bits back in, then record
            // what this addition lost. Skipped once sum is already non-finite.
            double corrected = value - compensation;
            double newSum = sum + corrected;
            compensation = (newSum - sum) - corrected;
            sum = newSum;
        }
    }
    return new InternalSum(name, sum, format, pipelineAggregators(), getMetaData());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ public class SumAggregator extends NumericMetricsAggregator.SingleValue {
private final DocValueFormat format;

private DoubleArray sums;
private DoubleArray compensations;

SumAggregator(String name, ValuesSource.Numeric valuesSource, DocValueFormat formatter, SearchContext context,
Aggregator parent, List<PipelineAggregator> pipelineAggregators, Map<String, Object> metaData) throws IOException {
Expand All @@ -51,6 +52,7 @@ public class SumAggregator extends NumericMetricsAggregator.SingleValue {
this.format = formatter;
if (valuesSource != null) {
sums = context.bigArrays().newDoubleArray(1, true);
compensations = context.bigArrays().newDoubleArray(1, true);
}
}

Expand All @@ -71,13 +73,22 @@ public LeafBucketCollector getLeafCollector(LeafReaderContext ctx,
/**
 * Adds the values of {@code doc} to the running sum of {@code bucket}.
 *
 * The running sum and its Kahan compensation term are read from the per-bucket arrays,
 * updated for every value of the document and written back, so the compensation carries
 * over from one document to the next.
 *
 * @param doc    the document whose values are collected
 * @param bucket the bucket ordinal the values are accumulated into
 * @throws IOException if reading the document's values fails
 */
@Override
public void collect(int doc, long bucket) throws IOException {
    // Make sure both arrays have a slot for this bucket before touching it.
    sums = bigArrays.grow(sums, bucket + 1);
    compensations = bigArrays.grow(compensations, bucket + 1);

    if (values.advanceExact(doc)) {
        final int valuesCount = values.docValueCount();
        // Compute the sum of double values with the Kahan summation algorithm, which is
        // more accurate than naive summation. Start from the stored state (not from 0)
        // because the result is written back with set(), not increment().
        double sum = sums.get(bucket);
        double compensation = compensations.get(bucket);
        for (int i = 0; i < valuesCount; i++) {
            final double value = values.nextValue();
            if (Double.isFinite(value) == false) {
                // Infinity/NaN: plain addition, otherwise the compensation term would
                // become NaN and turn an infinite sum into NaN.
                sum += value;
            } else if (Double.isFinite(sum)) {
                double corrected = value - compensation;
                double newSum = sum + corrected;
                compensation = (newSum - sum) - corrected;
                sum = newSum;
            }
        }
        compensations.set(bucket, compensation);
        sums.set(bucket, sum);
    }
}
};
Expand Down Expand Up @@ -106,6 +117,6 @@ public InternalAggregation buildEmptyAggregation() {

/**
 * Releases the per-bucket arrays held by this aggregator.
 *
 * Both arrays are passed to a single {@code Releasables.close} call so that neither of
 * them is released more than once.
 */
@Override
public void doClose() {
    // sums must not be closed separately beforehand: this one call covers it together
    // with the Kahan compensation terms.
    Releasables.close(sums, compensations);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
package org.elasticsearch.search.aggregations.metrics;

import org.apache.lucene.document.Document;
import org.apache.lucene.document.DoubleDocValuesField;
import org.apache.lucene.document.SortedNumericDocValuesField;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.RandomIndexWriter;
Expand All @@ -38,6 +39,8 @@
import java.io.IOException;
import java.util.function.Consumer;

import static java.util.Collections.singleton;

public class ExtendedStatsAggregatorTests extends AggregatorTestCase {
private static final double TOLERANCE = 1e-5;

Expand Down Expand Up @@ -132,6 +135,37 @@ public void testRandomLongs() throws IOException {
);
}

/**
 * Indexes fixed sets of double values, one per document, and checks the aggregated
 * statistics against exact expected values (delta {@code 0d}).
 *
 * The first data set verifies count, average, sum, max and min; the second verifies
 * the sum of squares.
 */
public void testSummationAccuracy() throws IOException {
    MappedFieldType ft = new NumberFieldMapper.NumberFieldType(NumberFieldMapper.NumberType.DOUBLE);
    final String fieldName = "field";
    ft.setName(fieldName);

    // Sum and average accuracy.
    testCase(ft,
        iw -> {
            double[] values = new double[]{0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.9, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7};
            for (double value : values) {
                iw.addDocument(singleton(new DoubleDocValuesField(fieldName, value)));
            }
        },
        stats -> {
            assertEquals(15, stats.getCount());
            assertEquals(0.9, stats.getAvg(), 0d);
            assertEquals(13.5, stats.getSum(), 0d);
            assertEquals(1.7, stats.getMax(), 0d);
            assertEquals(0.1, stats.getMin(), 0d);
        }
    );

    // Sum-of-squares accuracy.
    testCase(ft,
        iw -> {
            double[] values = new double[]{2.1, 0.4, 0.4, 0.5, 0.5, 0.7, 0.9, 1.001, 1.222, 1.3, 1.4, 1.5, 1.6, 1.9};
            for (double value : values) {
                iw.addDocument(singleton(new DoubleDocValuesField(fieldName, value)));
            }
        },
        stats -> assertEquals(21.095285, stats.getSumOfSquares(), 0d)
    );
}

public void testCase(MappedFieldType ft,
CheckedConsumer<RandomIndexWriter, IOException> buildIndex,
Consumer<InternalExtendedStats> verify) throws IOException {
Expand Down
Loading