Skip to content

Commit cf13a80

Browse files
[Bug] Star tree indexing - kahan summation fix (#18462)
* Star tree kahan summation fix Signed-off-by: bharath-techie <[email protected]>
1 parent 6fd9329 commit cf13a80

File tree

15 files changed

+519
-106
lines changed

15 files changed

+519
-106
lines changed

server/src/main/java/org/opensearch/index/compositeindex/datacube/startree/aggregators/SumValueAggregator.java

Lines changed: 51 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
*/
88
package org.opensearch.index.compositeindex.datacube.startree.aggregators;
99

10+
import org.opensearch.index.compositeindex.datacube.startree.utils.CompensatedSumType;
1011
import org.opensearch.index.mapper.FieldValueConverter;
1112
import org.opensearch.index.mapper.NumberFieldMapper;
1213
import org.opensearch.search.aggregations.metrics.CompensatedSum;
@@ -21,87 +22,107 @@
2122
*
2223
* @opensearch.experimental
2324
*/
24-
class SumValueAggregator implements ValueAggregator<Double> {
25+
class SumValueAggregator implements ValueAggregator<CompensatedSum> {
2526

2627
private final FieldValueConverter fieldValueConverter;
28+
private final CompensatedSumType compensatedSumConverter;
2729
private static final FieldValueConverter VALUE_AGGREGATOR_TYPE = NumberFieldMapper.NumberType.DOUBLE;
2830

29-
private CompensatedSum kahanSummation = new CompensatedSum(0, 0);
30-
3131
public SumValueAggregator(FieldValueConverter fieldValueConverter) {
3232
this.fieldValueConverter = fieldValueConverter;
33+
this.compensatedSumConverter = new CompensatedSumType();
3334
}
3435

3536
@Override
3637
public FieldValueConverter getAggregatedValueType() {
37-
return VALUE_AGGREGATOR_TYPE;
38+
return compensatedSumConverter;
3839
}
3940

4041
@Override
41-
public Double getInitialAggregatedValueForSegmentDocValue(Long segmentDocValue) {
42-
kahanSummation.reset(0, 0);
42+
public CompensatedSum getInitialAggregatedValueForSegmentDocValue(Long segmentDocValue) {
43+
CompensatedSum kahanSummation = new CompensatedSum(0, 0);
4344
// reset seeds the running sum with the first value and a zero compensation term
4445
if (segmentDocValue != null) {
45-
kahanSummation.add(fieldValueConverter.toDoubleValue(segmentDocValue));
46+
kahanSummation.reset(fieldValueConverter.toDoubleValue(segmentDocValue), 0);
4647
} else {
47-
kahanSummation.add(getIdentityMetricValue());
48+
kahanSummation.reset(getIdentityMetricDoubleValue(), 0);
4849
}
49-
return kahanSummation.value();
50+
return kahanSummation;
5051
}
5152

5253
// we have overridden this method because the reset with sum and compensation helps us keep
5354
// track of precision and avoids a potential loss in accuracy of sums.
5455
@Override
55-
public Double mergeAggregatedValueAndSegmentValue(Double value, Long segmentDocValue) {
56-
assert value == null || kahanSummation.value() == value;
56+
public CompensatedSum mergeAggregatedValueAndSegmentValue(CompensatedSum value, Long segmentDocValue) {
57+
CompensatedSum kahanSummation = new CompensatedSum(0, 0);
58+
if (value != null) {
59+
kahanSummation.reset(value.value(), value.delta());
60+
}
5761
// add takes care of the sum and compensation internally
5862
if (segmentDocValue != null) {
5963
kahanSummation.add(fieldValueConverter.toDoubleValue(segmentDocValue));
6064
} else {
61-
kahanSummation.add(getIdentityMetricValue());
65+
kahanSummation.add(getIdentityMetricDoubleValue());
6266
}
63-
return kahanSummation.value();
67+
return kahanSummation;
6468
}
6569

6670
@Override
67-
public Double mergeAggregatedValues(Double value, Double aggregatedValue) {
68-
assert aggregatedValue == null || kahanSummation.value() == aggregatedValue;
69-
// add takes care of the sum and compensation internally
71+
public CompensatedSum mergeAggregatedValues(CompensatedSum value, CompensatedSum aggregatedValue) {
72+
CompensatedSum kahanSummation = new CompensatedSum(0, 0);
73+
if (aggregatedValue != null) {
74+
kahanSummation.reset(aggregatedValue.value(), aggregatedValue.delta());
75+
}
7076
if (value != null) {
71-
kahanSummation.add(value);
77+
kahanSummation.add(value.value(), value.delta());
7278
} else {
73-
kahanSummation.add(getIdentityMetricValue());
79+
kahanSummation.add(getIdentityMetricDoubleValue());
7480
}
75-
return kahanSummation.value();
81+
return kahanSummation;
7682
}
7783

7884
@Override
79-
public Double getInitialAggregatedValue(Double value) {
80-
kahanSummation.reset(0, 0);
85+
public CompensatedSum getInitialAggregatedValue(CompensatedSum value) {
86+
CompensatedSum kahanSummation = new CompensatedSum(0, 0);
8187
// reset seeds the running sum and compensation directly; nothing is added here
82-
if (value != null) {
83-
kahanSummation.add(value);
88+
if (value == null) {
89+
kahanSummation.reset(getIdentityMetricDoubleValue(), 0);
8490
} else {
85-
kahanSummation.add(getIdentityMetricValue());
91+
kahanSummation.reset(value.value(), value.delta());
8692
}
87-
return kahanSummation.value();
93+
return kahanSummation;
8894
}
8995

9096
@Override
91-
public Double toAggregatedValueType(Long value) {
97+
public CompensatedSum toAggregatedValueType(Long value) {
98+
CompensatedSum kahanSummation = new CompensatedSum(0, 0);
9299
try {
93100
if (value == null) {
94-
return getIdentityMetricValue();
101+
kahanSummation.reset(getIdentityMetricDoubleValue(), 0);
102+
return kahanSummation;
95103
}
96-
return VALUE_AGGREGATOR_TYPE.toDoubleValue(value);
104+
kahanSummation.reset(VALUE_AGGREGATOR_TYPE.toDoubleValue(value), 0);
105+
return kahanSummation;
97106
} catch (Exception e) {
98107
throw new IllegalStateException("Cannot convert " + value + " to sortable aggregation type", e);
99108
}
100109
}
101110

111+
/**
112+
* Returns the identity value for sum (0.0) as a primitive double. getIdentityMetricValue creates a new object on each call
113+
* and would otherwise be invoked for every null document, so this class calls getIdentityMetricDoubleValue to avoid that allocation.
114+
*/
115+
private double getIdentityMetricDoubleValue() {
116+
return 0.0;
117+
}
118+
119+
/**
120+
* Returns the identity metric value for sum: a new CompensatedSum with value 0 and delta 0.
122+
* A new object is created on every call; inside this class getIdentityMetricDoubleValue is used instead to avoid that allocation.
122+
*/
102123
@Override
103-
public Double getIdentityMetricValue() {
124+
public CompensatedSum getIdentityMetricValue() {
104125
// in present aggregations, if the metric behind sum is missing, we treat it as 0
105-
return 0D;
126+
return new CompensatedSum(0, 0);
106127
}
107128
}

server/src/main/java/org/opensearch/index/compositeindex/datacube/startree/builder/AbstractDocumentsFileManager.java

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,10 @@
2121
import org.opensearch.index.compositeindex.datacube.startree.StarTreeDocument;
2222
import org.opensearch.index.compositeindex.datacube.startree.StarTreeField;
2323
import org.opensearch.index.compositeindex.datacube.startree.aggregators.MetricAggregatorInfo;
24+
import org.opensearch.index.compositeindex.datacube.startree.utils.CompensatedSumType;
2425
import org.opensearch.index.compositeindex.datacube.startree.utils.StarTreeDocumentBitSetUtil;
2526
import org.opensearch.index.mapper.FieldValueConverter;
27+
import org.opensearch.search.aggregations.metrics.CompensatedSum;
2628

2729
import java.io.Closeable;
2830
import java.io.IOException;
@@ -125,6 +127,15 @@ protected void writeMetrics(StarTreeDocument starTreeDocument, ByteBuffer buffer
125127
} else {
126128
buffer.putLong(starTreeDocument.metrics[i] == null ? 0L : (Long) starTreeDocument.metrics[i]);
127129
}
130+
} else if (aggregatedValueType instanceof CompensatedSumType) {
131+
if (isAggregatedDoc) {
132+
long val = NumericUtils.doubleToSortableLong(
133+
starTreeDocument.metrics[i] == null ? 0.0 : ((CompensatedSum) starTreeDocument.metrics[i]).value()
134+
);
135+
buffer.putLong(val);
136+
} else {
137+
buffer.putLong(starTreeDocument.metrics[i] == null ? 0L : (Long) starTreeDocument.metrics[i]);
138+
}
128139
} else {
129140
throw new IllegalStateException("Unsupported metric type");
130141
}
@@ -232,6 +243,14 @@ private long readMetrics(RandomAccessInput input, long offset, int numMetrics, O
232243
metrics[i] = val;
233244
}
234245
offset += Long.BYTES;
246+
} else if (aggregatedValueType instanceof CompensatedSumType) {
247+
long val = input.readLong(offset);
248+
if (isAggregatedDoc) {
249+
metrics[i] = new CompensatedSum(aggregatedValueType.toDoubleValue(val), 0);
250+
} else {
251+
metrics[i] = val;
252+
}
253+
offset += Long.BYTES;
235254
} else {
236255
throw new IllegalStateException("Unsupported metric type");
237256
}

server/src/main/java/org/opensearch/index/compositeindex/datacube/startree/builder/BaseStarTreeBuilder.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
import org.opensearch.index.compositeindex.datacube.startree.index.StarTreeValues;
4343
import org.opensearch.index.compositeindex.datacube.startree.node.InMemoryTreeNode;
4444
import org.opensearch.index.compositeindex.datacube.startree.node.StarTreeNodeType;
45+
import org.opensearch.index.compositeindex.datacube.startree.utils.CompensatedSumType;
4546
import org.opensearch.index.compositeindex.datacube.startree.utils.SequentialDocValuesIterator;
4647
import org.opensearch.index.compositeindex.datacube.startree.utils.iterator.SortedNumericStarTreeValuesIterator;
4748
import org.opensearch.index.compositeindex.datacube.startree.utils.iterator.SortedSetStarTreeValuesIterator;
@@ -50,6 +51,7 @@
5051
import org.opensearch.index.mapper.FieldValueConverter;
5152
import org.opensearch.index.mapper.Mapper;
5253
import org.opensearch.index.mapper.MapperService;
54+
import org.opensearch.search.aggregations.metrics.CompensatedSum;
5355

5456
import java.io.IOException;
5557
import java.util.ArrayList;
@@ -490,6 +492,13 @@ private void createSortedDocValuesIndices(DocValuesConsumer docValuesConsumer, A
490492
NumericUtils.doubleToSortableLong((Double) starTreeDocument.metrics[i])
491493
);
492494
}
495+
} else if (aggregatedValueType instanceof CompensatedSumType) {
496+
if (starTreeDocument.metrics[i] != null) {
497+
((SortedNumericDocValuesWriterWrapper) (metricWriters.get(i))).addValue(
498+
docId,
499+
NumericUtils.doubleToSortableLong(((CompensatedSum) starTreeDocument.metrics[i]).value())
500+
);
501+
}
493502
} else {
494503
throw new IllegalStateException("Unknown metric doc value type");
495504
}
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
/*
2+
* SPDX-License-Identifier: Apache-2.0
3+
*
4+
* The OpenSearch Contributors require contributions made to
5+
* this file be licensed under the Apache-2.0 license or a
6+
* compatible open source license.
7+
*/
8+
9+
package org.opensearch.index.compositeindex.datacube.startree.utils;
10+
11+
import org.opensearch.index.mapper.FieldValueConverter;
12+
13+
import static org.opensearch.index.mapper.NumberFieldMapper.NumberType.DOUBLE;
14+
15+
/**
16+
* Field value converter for CompensatedSum metrics; it delegates to the DOUBLE number type to convert a long into a double
17+
*
18+
* @opensearch.internal
19+
*/
20+
public class CompensatedSumType implements FieldValueConverter {
21+
22+
public CompensatedSumType() {}
23+
24+
@Override
25+
public double toDoubleValue(long value) {
26+
return DOUBLE.toDoubleValue(value);
27+
}
28+
}

server/src/main/java/org/opensearch/search/aggregations/metrics/CompensatedSum.java

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@
3232

3333
package org.opensearch.search.aggregations.metrics;
3434

35+
import java.util.Objects;
36+
3537
/**
3638
* Used to calculate sums using the Kahan summation algorithm.
3739
*
@@ -110,4 +112,22 @@ public CompensatedSum add(double value, double delta) {
110112
return this;
111113
}
112114

115+
@Override
116+
public boolean equals(Object o) {
117+
if (this == o) return true;
118+
if (o == null || getClass() != o.getClass()) return false;
119+
CompensatedSum that = (CompensatedSum) o;
120+
return Double.compare(that.value, value) == 0 && Double.compare(that.delta, delta) == 0;
121+
}
122+
123+
@Override
124+
public int hashCode() {
125+
return Objects.hash(value, delta);
126+
}
127+
128+
@Override
129+
public String toString() {
130+
return value + " " + delta;
131+
}
132+
113133
}

server/src/test/java/org/opensearch/index/compositeindex/datacube/startree/StarTreeTestUtils.java

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import org.opensearch.index.compositeindex.datacube.startree.utils.SequentialDocValuesIterator;
2222
import org.opensearch.index.mapper.CompositeMappedFieldType;
2323
import org.opensearch.index.mapper.FieldValueConverter;
24+
import org.opensearch.search.aggregations.metrics.CompensatedSum;
2425

2526
import java.io.IOException;
2627
import java.util.ArrayDeque;
@@ -150,6 +151,18 @@ public static void assertStarTreeDocuments(StarTreeDocument[] starTreeDocuments,
150151
for (int mi = 0; mi < resultStarTreeDocument.metrics.length; mi++) {
151152
if (expectedStarTreeDocument.metrics[mi] instanceof Long) {
152153
assertEquals(((Long) expectedStarTreeDocument.metrics[mi]).doubleValue(), resultStarTreeDocument.metrics[mi]);
154+
} else if (resultStarTreeDocument.metrics[mi] instanceof CompensatedSum) {
155+
if (expectedStarTreeDocument.metrics[mi] instanceof CompensatedSum) {
156+
assertEquals(expectedStarTreeDocument.metrics[mi], resultStarTreeDocument.metrics[mi]);
157+
} else {
158+
assertEquals((expectedStarTreeDocument.metrics[mi]), ((CompensatedSum) resultStarTreeDocument.metrics[mi]).value());
159+
}
160+
} else if (expectedStarTreeDocument.metrics[mi] instanceof CompensatedSum) {
161+
if (resultStarTreeDocument.metrics[mi] instanceof CompensatedSum) {
162+
assertEquals(expectedStarTreeDocument.metrics[mi], resultStarTreeDocument.metrics[mi]);
163+
} else {
164+
assertEquals(((CompensatedSum) expectedStarTreeDocument.metrics[mi]).value(), resultStarTreeDocument.metrics[mi]);
165+
}
153166
} else {
154167
assertEquals(expectedStarTreeDocument.metrics[mi], resultStarTreeDocument.metrics[mi]);
155168
}

server/src/test/java/org/opensearch/index/compositeindex/datacube/startree/aggregators/AbstractValueAggregatorTests.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
import org.opensearch.index.mapper.FieldValueConverter;
1414
import org.opensearch.index.mapper.NumberFieldMapper;
15+
import org.opensearch.search.aggregations.metrics.CompensatedSum;
1516
import org.opensearch.test.OpenSearchTestCase;
1617
import org.junit.Before;
1718

@@ -64,6 +65,10 @@ public void testGetInitialAggregatedValueForSegmentDocValue() {
6465
long randomLong = randomLong();
6566
if (aggregator instanceof CountValueAggregator) {
6667
assertEquals(CountValueAggregator.DEFAULT_INITIAL_VALUE, aggregator.getInitialAggregatedValueForSegmentDocValue(randomLong()));
68+
} else if (aggregator instanceof SumValueAggregator) {
69+
CompensatedSum sum = new CompensatedSum(0, 0);
70+
sum.add(fieldValueConverter.toDoubleValue(randomLong));
71+
assertEquals(sum, aggregator.getInitialAggregatedValueForSegmentDocValue(randomLong));
6772
} else {
6873
assertEquals(fieldValueConverter.toDoubleValue(randomLong), aggregator.getInitialAggregatedValueForSegmentDocValue(randomLong));
6974
}

server/src/test/java/org/opensearch/index/compositeindex/datacube/startree/aggregators/StaticValueAggregatorTests.java

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -21,28 +21,28 @@ public void testKahanSummation() {
2121
double expected = 1;
2222

2323
// initializing our sum aggregator to derive exact sum using kahan summation
24-
double aggregatedValue = getAggregatedValue(numbers);
25-
assertEquals(expected, aggregatedValue, 0);
24+
CompensatedSum aggregatedSum = getAggregatedValue(numbers);
25+
assertEquals(expected, aggregatedSum.value(), 0);
2626

2727
// assert kahan summation plain logic with our aggregated value
2828
double actual = kahanSum(numbers);
29-
assertEquals(actual, aggregatedValue, 0);
29+
assertEquals(actual, aggregatedSum.value(), 0);
3030

3131
// assert that normal sum fails for this case
3232
double normalSum = normalSum(numbers);
3333
assertNotEquals(expected, normalSum, 0);
3434
assertNotEquals(actual, normalSum, 0);
35-
assertNotEquals(aggregatedValue, normalSum, 0);
36-
35+
assertNotEquals(aggregatedSum.value(), normalSum, 0);
3736
}
3837

39-
private static double getAggregatedValue(double[] numbers) {
40-
// explicitly took double to test for most precision
41-
// hard to run similar tests for different data types dynamically as inputs and precision vary
38+
private static CompensatedSum getAggregatedValue(double[] numbers) {
4239
SumValueAggregator aggregator = new SumValueAggregator(NumberFieldMapper.NumberType.DOUBLE);
43-
double aggregatedValue = aggregator.getInitialAggregatedValueForSegmentDocValue(NumericUtils.doubleToSortableLong(numbers[0]));
44-
aggregatedValue = aggregator.mergeAggregatedValueAndSegmentValue(aggregatedValue, NumericUtils.doubleToSortableLong(numbers[1]));
45-
aggregatedValue = aggregator.mergeAggregatedValueAndSegmentValue(aggregatedValue, NumericUtils.doubleToSortableLong(numbers[2]));
40+
long sortableLong1 = NumericUtils.doubleToSortableLong(numbers[0]);
41+
CompensatedSum aggregatedValue = aggregator.getInitialAggregatedValueForSegmentDocValue(sortableLong1);
42+
long sortableLong2 = NumericUtils.doubleToSortableLong(numbers[1]);
43+
aggregatedValue = aggregator.mergeAggregatedValueAndSegmentValue(aggregatedValue, sortableLong2);
44+
long sortableLong3 = NumericUtils.doubleToSortableLong(numbers[2]);
45+
aggregatedValue = aggregator.mergeAggregatedValueAndSegmentValue(aggregatedValue, sortableLong3);
4646
return aggregatedValue;
4747
}
4848

@@ -129,5 +129,4 @@ public void testMinAggregatorExtremeValues_Infinity() {
129129
}
130130
assertEquals(expected, aggregatedValue, 0);
131131
}
132-
133132
}

0 commit comments

Comments
 (0)