diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index 9dae1c632ef81..a84814b2c3fad 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -25,7 +25,7 @@ protobuf = "3.25.5" jakarta_annotation = "1.3.5" google_http_client = "1.44.1" google_auth = "1.29.0" -tdigest = "3.3" # Warning: Before updating tdigest, ensure its serialization code for MergingDigest hasn't changed +tdigest = "3.3" hdrhistogram = "2.2.2" grpc = "1.68.2" json_smart = "2.5.2" diff --git a/server/src/main/java/org/opensearch/search/aggregations/metrics/InternalMedianAbsoluteDeviation.java b/server/src/main/java/org/opensearch/search/aggregations/metrics/InternalMedianAbsoluteDeviation.java index 31f17b5237ded..1f9e6b0050420 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/metrics/InternalMedianAbsoluteDeviation.java +++ b/server/src/main/java/org/opensearch/search/aggregations/metrics/InternalMedianAbsoluteDeviation.java @@ -43,8 +43,6 @@ import java.util.Map; import java.util.Objects; -import com.tdunning.math.stats.Centroid; - /** * Implementation of median absolute deviation agg * @@ -59,14 +57,11 @@ static double computeMedianAbsoluteDeviation(TDigestState valuesSketch) { } else { final double approximateMedian = valuesSketch.quantile(0.5); final TDigestState approximatedDeviationsSketch = new TDigestState(valuesSketch.compression()); - for (Centroid centroid : valuesSketch.centroids()) { + valuesSketch.centroids().forEach(centroid -> { final double deviation = Math.abs(approximateMedian - centroid.mean()); - // Weighted add() isn't supported for faster MergingDigest implementation, so add iteratively instead. see - // https://github.com/tdunning/t-digest/issues/167 - for (int i = 0; i < centroid.count(); i++) { - approximatedDeviationsSketch.add(deviation); - } - } + approximatedDeviationsSketch.add(deviation, centroid.count()); + }); + return approximatedDeviationsSketch.quantile(0.5); } } diff --git a/server/src/main/java/org/opensearch/search/aggregations/metrics/TDigestState.java b/server/src/main/java/org/opensearch/search/aggregations/metrics/TDigestState.java index b32026a46db80..b61bbcfe1cbbf 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/metrics/TDigestState.java +++ b/server/src/main/java/org/opensearch/search/aggregations/metrics/TDigestState.java @@ -31,25 +31,21 @@ package org.opensearch.search.aggregations.metrics; -import org.opensearch.Version; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import java.io.IOException; -import java.nio.ByteBuffer; import java.util.Iterator; -import java.util.List; import com.tdunning.math.stats.AVLTreeDigest; import com.tdunning.math.stats.Centroid; -import com.tdunning.math.stats.MergingDigest; /** * Extension of {@link com.tdunning.math.stats.TDigest} with custom serialization. * * @opensearch.internal */ -public class TDigestState extends MergingDigest { +public class TDigestState extends AVLTreeDigest { private final double compression; @@ -58,64 +54,28 @@ public TDigestState(double compression) { this.compression = compression; } - private TDigestState(double compression, MergingDigest in) { - super(compression); - this.compression = compression; - this.add(List.of(in)); - } - @Override public double compression() { return compression; } public static void write(TDigestState state, StreamOutput out) throws IOException { - if (out.getVersion().before(Version.V_3_1_0)) { - out.writeDouble(state.compression); - out.writeVInt(state.centroidCount()); - for (Centroid centroid : state.centroids()) { - out.writeDouble(centroid.mean()); - out.writeVLong(centroid.count()); - } - } else { - int byteSize = state.byteSize(); - out.writeVInt(byteSize); - ByteBuffer buf = ByteBuffer.allocate(byteSize); - state.asBytes(buf); - out.writeBytes(buf.array()); + out.writeDouble(state.compression); + out.writeVInt(state.centroidCount()); + for (Centroid centroid : state.centroids()) { + out.writeDouble(centroid.mean()); + out.writeVLong(centroid.count()); } } public static TDigestState read(StreamInput in) throws IOException { - if (in.getVersion().before(Version.V_3_1_0)) { - // In older versions TDigestState was based on AVLTreeDigest. Load centroids into this class, then add it to MergingDigest. - double compression = in.readDouble(); - AVLTreeDigest treeDigest = new AVLTreeDigest(compression); - int n = in.readVInt(); - if (n > 0) { - for (int i = 0; i < n; i++) { - treeDigest.add(in.readDouble(), in.readVInt()); - } - TDigestState state = new TDigestState(compression); - state.add(List.of(treeDigest)); - return state; - } - return new TDigestState(compression); - } else { - // For MergingDigest, adding the original centroids in ascending order to a new, empty MergingDigest isn't guaranteed - // to produce a MergingDigest whose centroids are exactly equal to the originals. - // So, use the library's serialization code to ensure we get the exact same centroids, allowing us to compare with equals(). - // The AVLTreeDigest had the same limitation for equals() where it was only guaranteed to return true if the other object was - // produced by de/serializing the object, so this should be fine. - int byteSize = in.readVInt(); - byte[] bytes = new byte[byteSize]; - in.readBytes(bytes, 0, byteSize); - MergingDigest mergingDigest = MergingDigest.fromBytes(ByteBuffer.wrap(bytes)); - if (mergingDigest.centroids().isEmpty()) { - return new TDigestState(mergingDigest.compression()); - } - return new TDigestState(mergingDigest.compression(), mergingDigest); + double compression = in.readDouble(); + TDigestState state = new TDigestState(compression); + int n = in.readVInt(); + for (int i = 0; i < n; i++) { + state.add(in.readDouble(), in.readVInt()); } + return state; } @Override diff --git a/server/src/test/java/org/opensearch/search/aggregations/metrics/InternalTDigestPercentilesRanksTests.java b/server/src/test/java/org/opensearch/search/aggregations/metrics/InternalTDigestPercentilesRanksTests.java index 0f446ebd55130..78296eddbdc2c 100644 --- a/server/src/test/java/org/opensearch/search/aggregations/metrics/InternalTDigestPercentilesRanksTests.java +++ b/server/src/test/java/org/opensearch/search/aggregations/metrics/InternalTDigestPercentilesRanksTests.java @@ -54,7 +54,7 @@ protected InternalTDigestPercentileRanks createTestInstance( Arrays.stream(values).forEach(state::add); // the number of centroids is defined as <= the number of samples inserted - assertTrue(state.centroids().size() <= values.length); + assertTrue(state.centroidCount() <= values.length); return new InternalTDigestPercentileRanks(name, percents, state, keyed, format, metadata); } @@ -66,7 +66,7 @@ protected void assertReduced(InternalTDigestPercentileRanks reduced, List { assertEquals(7L, tdigest.state.size()); - assertEquals(7L, tdigest.state.centroids().size()); + assertEquals(7L, tdigest.state.centroidCount()); assertEquals(5.0d, tdigest.percentile(75), 0.0d); assertEquals("5.0", tdigest.percentileAsString(75)); assertEquals(3.0d, tdigest.percentile(71), 0.0d); @@ -128,7 +128,7 @@ public void testSomeMatchesNumericDocValues() throws IOException { iw.addDocument(singleton(new NumericDocValuesField("number", 0))); }, tdigest -> { assertEquals(tdigest.state.size(), 7L); - assertTrue(tdigest.state.centroids().size() <= 7L); + assertTrue(tdigest.state.centroidCount() <= 7L); assertEquals(8.0d, tdigest.percentile(100), 0.0d); assertEquals("8.0", tdigest.percentileAsString(100)); assertEquals(8.0d, tdigest.percentile(88), 0.0d); @@ -156,7 +156,7 @@ public void testQueryFiltering() throws IOException { testCase(LongPoint.newRangeQuery("row", 1, 4), docs, tdigest -> { assertEquals(4L, tdigest.state.size()); - assertEquals(4L, tdigest.state.centroids().size()); + assertEquals(4L, tdigest.state.centroidCount()); assertEquals(2.0d, tdigest.percentile(100), 0.0d); assertEquals(1.0d, tdigest.percentile(50), 0.0d); assertEquals(1.0d, tdigest.percentile(25), 0.0d); @@ -165,7 +165,7 @@ public void testQueryFiltering() throws IOException { testCase(LongPoint.newRangeQuery("row", 100, 110), docs, tdigest -> { assertEquals(0L, tdigest.state.size()); - assertEquals(0L, tdigest.state.centroids().size()); + assertEquals(0L, tdigest.state.centroidCount()); assertFalse(AggregationInspectionHelper.hasValue(tdigest)); }); }