Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ public class AvgAggregator extends NumericMetricsAggregator.SingleValue {
DoubleArray sums;
DocValueFormat format;

private DoubleArray compensations;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please declare it next to sums and use the same modifiers


public AvgAggregator(String name, ValuesSource.Numeric valuesSource, DocValueFormat formatter, SearchContext context,
Aggregator parent, List<PipelineAggregator> pipelineAggregators, Map<String, Object> metaData) throws IOException {
super(name, context, parent, pipelineAggregators, metaData);
Expand All @@ -55,6 +57,7 @@ public AvgAggregator(String name, ValuesSource.Numeric valuesSource, DocValueFor
final BigArrays bigArrays = context.bigArrays();
counts = bigArrays.newLongArray(1, true);
sums = bigArrays.newDoubleArray(1, true);
compensations = bigArrays.newDoubleArray(1, true);
}
}

Expand All @@ -76,15 +79,22 @@ public LeafBucketCollector getLeafCollector(LeafReaderContext ctx,
public void collect(int doc, long bucket) throws IOException {
counts = bigArrays.grow(counts, bucket + 1);
sums = bigArrays.grow(sums, bucket + 1);
compensations = bigArrays.grow(compensations, bucket + 1);

if (values.advanceExact(doc)) {
final int valueCount = values.docValueCount();
counts.increment(bucket, valueCount);
double sum = 0;
double sum = sums.get(bucket);
double compensation = compensations.get(bucket);

for (int i = 0; i < valueCount; i++) {
sum += values.nextValue();
double corrected = values.nextValue() - compensation;
double newSum = sum + corrected;
compensation = (newSum - sum) - corrected;
sum = newSum;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you add a comment saying this is Kahan summation?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @jpountz . I pushed 6c45e08

}
sums.increment(bucket, sum);
sums.set(bucket, sum);
compensations.set(bucket, compensation);
}
}
};
Expand Down Expand Up @@ -113,7 +123,7 @@ public InternalAggregation buildEmptyAggregation() {

@Override
public void doClose() {
Releasables.close(counts, sums);
Releasables.close(counts, sums, compensations);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ public class StatsAggregator extends NumericMetricsAggregator.MultiValue {
DoubleArray mins;
DoubleArray maxes;

private DoubleArray compensations;


public StatsAggregator(String name, ValuesSource.Numeric valuesSource, DocValueFormat format,
SearchContext context,
Expand All @@ -59,6 +61,7 @@ public StatsAggregator(String name, ValuesSource.Numeric valuesSource, DocValueF
final BigArrays bigArrays = context.bigArrays();
counts = bigArrays.newLongArray(1, true);
sums = bigArrays.newDoubleArray(1, true);
compensations = bigArrays.newDoubleArray(1, true);
mins = bigArrays.newDoubleArray(1, false);
mins.fill(0, mins.size(), Double.POSITIVE_INFINITY);
maxes = bigArrays.newDoubleArray(1, false);
Expand Down Expand Up @@ -88,6 +91,7 @@ public void collect(int doc, long bucket) throws IOException {
final long overSize = BigArrays.overSize(bucket + 1);
counts = bigArrays.resize(counts, overSize);
sums = bigArrays.resize(sums, overSize);
compensations = bigArrays.resize(compensations, overSize);
mins = bigArrays.resize(mins, overSize);
maxes = bigArrays.resize(maxes, overSize);
mins.fill(from, overSize, Double.POSITIVE_INFINITY);
Expand All @@ -97,16 +101,22 @@ public void collect(int doc, long bucket) throws IOException {
if (values.advanceExact(doc)) {
final int valuesCount = values.docValueCount();
counts.increment(bucket, valuesCount);
double sum = 0;
double min = mins.get(bucket);
double max = maxes.get(bucket);
double sum = sums.get(bucket);
double compensation = compensations.get(bucket);

for (int i = 0; i < valuesCount; i++) {
double value = values.nextValue();
sum += value;
double corrected = value - compensation;
double newSum = sum + corrected;
compensation = (newSum - sum) - corrected;
sum = newSum;
min = Math.min(min, value);
max = Math.max(max, value);
}
sums.increment(bucket, sum);
sums.set(bucket, sum);
compensations.set(bucket, compensation);
mins.set(bucket, min);
maxes.set(bucket, max);
}
Expand Down Expand Up @@ -164,6 +174,6 @@ public InternalAggregation buildEmptyAggregation() {

@Override
public void doClose() {
Releasables.close(counts, maxes, mins, sums);
Releasables.close(counts, maxes, mins, sums, compensations);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ public class SumAggregator extends NumericMetricsAggregator.SingleValue {
private final DocValueFormat format;

private DoubleArray sums;
private DoubleArray compensations;

SumAggregator(String name, ValuesSource.Numeric valuesSource, DocValueFormat formatter, SearchContext context,
Aggregator parent, List<PipelineAggregator> pipelineAggregators, Map<String, Object> metaData) throws IOException {
Expand All @@ -51,6 +52,7 @@ public class SumAggregator extends NumericMetricsAggregator.SingleValue {
this.format = formatter;
if (valuesSource != null) {
sums = context.bigArrays().newDoubleArray(1, true);
compensations = context.bigArrays().newDoubleArray(1, true);
}
}

Expand All @@ -71,13 +73,20 @@ public LeafBucketCollector getLeafCollector(LeafReaderContext ctx,
@Override
public void collect(int doc, long bucket) throws IOException {
sums = bigArrays.grow(sums, bucket + 1);
compensations = bigArrays.grow(compensations, bucket + 1);

if (values.advanceExact(doc)) {
final int valuesCount = values.docValueCount();
double sum = 0;
double sum = sums.get(bucket);
double compensation = compensations.get(bucket);
for (int i = 0; i < valuesCount; i++) {
sum += values.nextValue();
double corrected = values.nextValue() - compensation;
double newSum = sum + corrected;
compensation = (newSum - sum) - corrected;
sum = newSum;
}
sums.increment(bucket, sum);
compensations.set(bucket, compensation);
sums.set(bucket, sum);
}
}
};
Expand Down Expand Up @@ -106,6 +115,6 @@ public InternalAggregation buildEmptyAggregation() {

@Override
public void doClose() {
Releasables.close(sums);
Releasables.close(sums, compensations);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
package org.elasticsearch.search.aggregations.metrics;

import org.apache.lucene.document.Document;
import org.apache.lucene.document.DoubleDocValuesField;
import org.apache.lucene.document.SortedNumericDocValuesField;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.RandomIndexWriter;
Expand All @@ -36,6 +37,8 @@
import java.io.IOException;
import java.util.function.Consumer;

import static java.util.Collections.singleton;

public class StatsAggregatorTests extends AggregatorTestCase {
static final double TOLERANCE = 1e-10;

Expand Down Expand Up @@ -113,6 +116,27 @@ public void testRandomLongs() throws IOException {
);
}

public void testSummationAccuracy() throws IOException {
MappedFieldType ft = new NumberFieldMapper.NumberFieldType(NumberFieldMapper.NumberType.DOUBLE);
final String fieldName = "field";
ft.setName(fieldName);
testCase(ft,
iw -> {
double[] values = new double[]{0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.9, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7};
for (double value : values) {
iw.addDocument(singleton(new DoubleDocValuesField(fieldName, value)));
}
},
stats -> {
assertEquals(15, stats.getCount());
assertEquals(0.9, stats.getAvg(), 0d);
assertEquals(13.5, stats.getSum(), 0d);
assertEquals(1.7, stats.getMax(), 0d);
assertEquals(0.1, stats.getMin(), 0d);
}
);
}

public void testCase(MappedFieldType ft,
CheckedConsumer<RandomIndexWriter, IOException> buildIndex,
Consumer<InternalStats> verify) throws IOException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
*/
package org.elasticsearch.search.aggregations.metrics;

import org.apache.lucene.document.DoubleDocValuesField;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.NumericDocValuesField;
import org.apache.lucene.document.SortedDocValuesField;
Expand Down Expand Up @@ -116,10 +117,28 @@ public void testStringField() throws IOException {
"Re-index with correct docvalues type.", e.getMessage());
}

public void testSummationAccuracy() throws IOException {
testCase(new MatchAllDocsQuery(),
iw -> {
double[] values = new double[]{0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7};
for (double value : values) {
iw.addDocument(singleton(new DoubleDocValuesField(FIELD_NAME, value)));
}
},
count -> assertEquals(15.3, count.getValue(), 0d),
NumberFieldMapper.NumberType.DOUBLE);
}

private void testCase(Query query,
CheckedConsumer<RandomIndexWriter, IOException> indexer,
Consumer<Sum> verify) throws IOException {
testCase(query, indexer, verify, NumberFieldMapper.NumberType.LONG);
}

private void testCase(Query query,
CheckedConsumer<RandomIndexWriter, IOException> indexer,
Consumer<Sum> verify,
NumberFieldMapper.NumberType fieldNumberType) throws IOException {
try (Directory directory = newDirectory()) {
try (RandomIndexWriter indexWriter = new RandomIndexWriter(random(), directory)) {
indexer.accept(indexWriter);
Expand All @@ -128,7 +147,7 @@ private void testCase(Query query,
try (IndexReader indexReader = DirectoryReader.open(directory)) {
IndexSearcher indexSearcher = newSearcher(indexReader, true, true);

MappedFieldType fieldType = new NumberFieldMapper.NumberFieldType(NumberFieldMapper.NumberType.LONG);
MappedFieldType fieldType = new NumberFieldMapper.NumberFieldType(fieldNumberType);
fieldType.setName(FIELD_NAME);
fieldType.setHasDocValues(true);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

package org.elasticsearch.search.aggregations.metrics.avg;

import org.apache.lucene.document.DoubleDocValuesField;
import org.apache.lucene.document.IntPoint;
import org.apache.lucene.document.NumericDocValuesField;
import org.apache.lucene.document.SortedNumericDocValuesField;
Expand All @@ -34,9 +35,6 @@
import org.elasticsearch.index.mapper.MappedFieldType;
import org.elasticsearch.index.mapper.NumberFieldMapper;
import org.elasticsearch.search.aggregations.AggregatorTestCase;
import org.elasticsearch.search.aggregations.metrics.avg.AvgAggregationBuilder;
import org.elasticsearch.search.aggregations.metrics.avg.AvgAggregator;
import org.elasticsearch.search.aggregations.metrics.avg.InternalAvg;

import java.io.IOException;
import java.util.Arrays;
Expand Down Expand Up @@ -103,8 +101,28 @@ public void testQueryFiltersAll() throws IOException {
});
}

private void testCase(Query query, CheckedConsumer<RandomIndexWriter, IOException> buildIndex, Consumer<InternalAvg> verify)
throws IOException {
public void testSummationAccuracy() throws IOException {
testCase(new MatchAllDocsQuery(),
iw -> {
double[] values = new double[]{0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.9, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7};
for (double value : values) {
iw.addDocument(singleton(new DoubleDocValuesField("number", value)));
}
},
avg -> assertEquals(0.9, avg.getValue(), 0d),
NumberFieldMapper.NumberType.DOUBLE);
}

private void testCase(Query query,
CheckedConsumer<RandomIndexWriter, IOException> buildIndex,
Consumer<InternalAvg> verify) throws IOException {
testCase(query, buildIndex, verify, NumberFieldMapper.NumberType.LONG);
}

private void testCase(Query query,
CheckedConsumer<RandomIndexWriter, IOException> buildIndex,
Consumer<InternalAvg> verify,
NumberFieldMapper.NumberType fieldNumberType) throws IOException {
Directory directory = newDirectory();
RandomIndexWriter indexWriter = new RandomIndexWriter(random(), directory);
buildIndex.accept(indexWriter);
Expand All @@ -114,7 +132,7 @@ private void testCase(Query query, CheckedConsumer<RandomIndexWriter, IOExceptio
IndexSearcher indexSearcher = newSearcher(indexReader, true, true);

AvgAggregationBuilder aggregationBuilder = new AvgAggregationBuilder("_name").field("number");
MappedFieldType fieldType = new NumberFieldMapper.NumberFieldType(NumberFieldMapper.NumberType.LONG);
MappedFieldType fieldType = new NumberFieldMapper.NumberFieldType(fieldNumberType);
fieldType.setName("number");

AvgAggregator aggregator = createAggregator(aggregationBuilder, indexSearcher, fieldType);
Expand Down