diff --git a/CHANGELOG.md b/CHANGELOG.md
index 3254312c437a3..93b210d910597 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -36,6 +36,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 ### Changed
 - Create generic DocRequest to better categorize ActionRequests ([#18269](https://github.com/opensearch-project/OpenSearch/pull/18269))
+- Change implementation for `percentiles` aggregation for latency improvement ([#18124](https://github.com/opensearch-project/OpenSearch/pull/18124))
 
 ### Dependencies
 - Update Apache Lucene from 10.1.0 to 10.2.1 ([#17961](https://github.com/opensearch-project/OpenSearch/pull/17961))
diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml
index 838c27e6c2c31..025ead2845d68 100644
--- a/gradle/libs.versions.toml
+++ b/gradle/libs.versions.toml
@@ -25,7 +25,7 @@ protobuf = "3.25.5"
 jakarta_annotation = "1.3.5"
 google_http_client = "1.44.1"
 google_auth = "1.29.0"
-tdigest = "3.3"
+tdigest = "3.3" # Warning: before updating tdigest, ensure its serialization code for MergingDigest hasn't changed
 hdrhistogram = "2.2.2"
 grpc = "1.68.2"
 json_smart = "2.5.2"
diff --git a/server/src/main/java/org/opensearch/search/aggregations/metrics/InternalMedianAbsoluteDeviation.java b/server/src/main/java/org/opensearch/search/aggregations/metrics/InternalMedianAbsoluteDeviation.java
index 1f9e6b0050420..31f17b5237ded 100644
--- a/server/src/main/java/org/opensearch/search/aggregations/metrics/InternalMedianAbsoluteDeviation.java
+++ b/server/src/main/java/org/opensearch/search/aggregations/metrics/InternalMedianAbsoluteDeviation.java
@@ -43,6 +43,8 @@
 import java.util.Map;
 import java.util.Objects;
 
+import com.tdunning.math.stats.Centroid;
+
 /**
  * Implementation of median absolute deviation agg
  *
  * @opensearch.internal
  */
@@ -57,11 +59,14 @@ static double computeMedianAbsoluteDeviation(TDigestState valuesSketch) {
         } else {
             final double approximateMedian = valuesSketch.quantile(0.5);
             final TDigestState approximatedDeviationsSketch = new TDigestState(valuesSketch.compression());
-            valuesSketch.centroids().forEach(centroid -> {
+            for (Centroid centroid : valuesSketch.centroids()) {
                 final double deviation = Math.abs(approximateMedian - centroid.mean());
-                approximatedDeviationsSketch.add(deviation, centroid.count());
-            });
-
+                // Weighted add() isn't supported by the faster MergingDigest implementation, so add each value iteratively
+                // instead. See https://github.com/tdunning/t-digest/issues/167
+                for (int i = 0; i < centroid.count(); i++) {
+                    approximatedDeviationsSketch.add(deviation);
+                }
+            }
             return approximatedDeviationsSketch.quantile(0.5);
         }
     }
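For reference, the new deviation-sketch path can be reproduced outside the aggregation framework. The sketch below is a minimal, self-contained illustration against the t-digest 3.3 API that this patch targets; the MadSketch class and approximateMad method are illustrative names, not part of the change:

import com.tdunning.math.stats.Centroid;
import com.tdunning.math.stats.MergingDigest;

public class MadSketch {
    // Approximate MAD: the median of |x - median(x)|, taken over the digest's centroids.
    static double approximateMad(MergingDigest values, double compression) {
        double median = values.quantile(0.5);
        MergingDigest deviations = new MergingDigest(compression);
        for (Centroid c : values.centroids()) {
            double deviation = Math.abs(median - c.mean());
            // Expand each centroid's weight with repeated unweighted adds, mirroring the
            // workaround in the patch above for the missing weighted add() (t-digest issue #167).
            for (int i = 0; i < c.count(); i++) {
                deviations.add(deviation);
            }
        }
        return deviations.quantile(0.5);
    }

    public static void main(String[] args) {
        MergingDigest values = new MergingDigest(100);
        for (double v : new double[] { 1, 1, 2, 2, 4, 6, 9 }) {
            values.add(v);
        }
        System.out.println("approximate MAD = " + approximateMad(values, 100));
    }
}

Note that the expansion loop is linear in the total sample count rather than in the number of centroids, which is the cost the upstream issue tracks.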
 *
 * @opensearch.internal
 */
-public class TDigestState extends AVLTreeDigest {
+public class TDigestState extends MergingDigest {
 
     private final double compression;
 
@@ -54,28 +58,64 @@ public TDigestState(double compression) {
         this.compression = compression;
     }
 
+    private TDigestState(double compression, MergingDigest in) {
+        super(compression);
+        this.compression = compression;
+        this.add(List.of(in));
+    }
+
     @Override
     public double compression() {
         return compression;
     }
 
     public static void write(TDigestState state, StreamOutput out) throws IOException {
-        out.writeDouble(state.compression);
-        out.writeVInt(state.centroidCount());
-        for (Centroid centroid : state.centroids()) {
-            out.writeDouble(centroid.mean());
-            out.writeVLong(centroid.count());
+        if (out.getVersion().before(Version.V_3_1_0)) {
+            out.writeDouble(state.compression);
+            out.writeVInt(state.centroidCount());
+            for (Centroid centroid : state.centroids()) {
+                out.writeDouble(centroid.mean());
+                out.writeVLong(centroid.count());
+            }
+        } else {
+            int byteSize = state.byteSize();
+            out.writeVInt(byteSize);
+            ByteBuffer buf = ByteBuffer.allocate(byteSize);
+            state.asBytes(buf);
+            out.writeBytes(buf.array());
         }
     }
 
     public static TDigestState read(StreamInput in) throws IOException {
-        double compression = in.readDouble();
-        TDigestState state = new TDigestState(compression);
-        int n = in.readVInt();
-        for (int i = 0; i < n; i++) {
-            state.add(in.readDouble(), in.readVInt());
+        if (in.getVersion().before(Version.V_3_1_0)) {
+            // In older versions TDigestState was based on AVLTreeDigest. Load the centroids into an AVLTreeDigest first,
+            // then merge it into this MergingDigest-backed class.
+            double compression = in.readDouble();
+            AVLTreeDigest treeDigest = new AVLTreeDigest(compression);
+            int n = in.readVInt();
+            if (n > 0) {
+                for (int i = 0; i < n; i++) {
+                    treeDigest.add(in.readDouble(), in.readVInt());
+                }
+                TDigestState state = new TDigestState(compression);
+                state.add(List.of(treeDigest));
+                return state;
+            }
+            return new TDigestState(compression);
+        } else {
+            // For MergingDigest, adding the original centroids in ascending order to a new, empty MergingDigest isn't
+            // guaranteed to produce a MergingDigest whose centroids are exactly equal to the originals.
+            // So, use the library's serialization code to ensure we get exactly the same centroids back, allowing
+            // comparison with equals(). AVLTreeDigest had the same limitation: equals() was only guaranteed to return
+            // true if one object was produced by de/serializing the other, so this should be fine.
+            int byteSize = in.readVInt();
+            byte[] bytes = new byte[byteSize];
+            in.readBytes(bytes, 0, byteSize);
+            MergingDigest mergingDigest = MergingDigest.fromBytes(ByteBuffer.wrap(bytes));
+            if (mergingDigest.centroids().isEmpty()) {
+                return new TDigestState(mergingDigest.compression());
+            }
+            return new TDigestState(mergingDigest.compression(), mergingDigest);
         }
-        return state;
     }
 
     @Override
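Since the class now extends MergingDigest, the 3.1.0+ wire format above is just the library's own codec wrapped in a length prefix. Below is a minimal round-trip sketch using the same t-digest 3.3 calls the patch relies on (byteSize(), asBytes(), MergingDigest.fromBytes()); the DigestWireFormat class name is hypothetical:

import java.nio.ByteBuffer;

import com.tdunning.math.stats.MergingDigest;

public class DigestWireFormat {
    // Serialize with the library codec, as the new write() does for 3.1.0+ streams.
    static byte[] toBytes(MergingDigest digest) {
        ByteBuffer buf = ByteBuffer.allocate(digest.byteSize());
        digest.asBytes(buf);
        return buf.array();
    }

    static MergingDigest fromWire(byte[] bytes) {
        return MergingDigest.fromBytes(ByteBuffer.wrap(bytes));
    }

    public static void main(String[] args) {
        MergingDigest original = new MergingDigest(100);
        for (int i = 0; i < 10_000; i++) {
            original.add(Math.random());
        }
        MergingDigest copy = fromWire(toBytes(original));
        // The library codec reproduces the centroids exactly, which is what makes
        // equals()-style comparison after deserialization safe.
        System.out.println(original.quantile(0.99) + " == " + copy.quantile(0.99));
    }
}

This is also why the version catalog entry above warns against bumping tdigest blindly: an upstream codec change would silently break cross-version cluster communication.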
diff --git a/server/src/test/java/org/opensearch/search/aggregations/metrics/InternalTDigestPercentilesRanksTests.java b/server/src/test/java/org/opensearch/search/aggregations/metrics/InternalTDigestPercentilesRanksTests.java
index 78296eddbdc2c..0f446ebd55130 100644
--- a/server/src/test/java/org/opensearch/search/aggregations/metrics/InternalTDigestPercentilesRanksTests.java
+++ b/server/src/test/java/org/opensearch/search/aggregations/metrics/InternalTDigestPercentilesRanksTests.java
@@ -54,7 +54,7 @@ protected InternalTDigestPercentileRanks createTestInstance(
         Arrays.stream(values).forEach(state::add);
 
         // the number of centroids is defined as <= the number of samples inserted
-        assertTrue(state.centroidCount() <= values.length);
+        assertTrue(state.centroids().size() <= values.length);
         return new InternalTDigestPercentileRanks(name, percents, state, keyed, format, metadata);
     }
 
@@ -66,7 +66,7 @@ protected void assertReduced(InternalTDigestPercentileRanks reduced, List<InternalTDigestPercentileRanks> inputs) {
diff --git a/server/src/test/java/org/opensearch/search/aggregations/metrics/TDigestPercentilesAggregatorTests.java b/server/src/test/java/org/opensearch/search/aggregations/metrics/TDigestPercentilesAggregatorTests.java
--- a/server/src/test/java/org/opensearch/search/aggregations/metrics/TDigestPercentilesAggregatorTests.java
+++ b/server/src/test/java/org/opensearch/search/aggregations/metrics/TDigestPercentilesAggregatorTests.java
         }, tdigest -> {
             assertEquals(7L, tdigest.state.size());
-            assertEquals(7L, tdigest.state.centroidCount());
+            assertEquals(7L, tdigest.state.centroids().size());
             assertEquals(5.0d, tdigest.percentile(75), 0.0d);
             assertEquals("5.0", tdigest.percentileAsString(75));
             assertEquals(3.0d, tdigest.percentile(71), 0.0d);
@@ -128,7 +128,7 @@ public void testSomeMatchesNumericDocValues() throws IOException {
             iw.addDocument(singleton(new NumericDocValuesField("number", 0)));
         }, tdigest -> {
             assertEquals(tdigest.state.size(), 7L);
-            assertTrue(tdigest.state.centroidCount() <= 7L);
+            assertTrue(tdigest.state.centroids().size() <= 7L);
             assertEquals(8.0d, tdigest.percentile(100), 0.0d);
             assertEquals("8.0", tdigest.percentileAsString(100));
             assertEquals(8.0d, tdigest.percentile(88), 0.0d);
@@ -156,7 +156,7 @@ public void testQueryFiltering() throws IOException {
 
         testCase(LongPoint.newRangeQuery("row", 1, 4), docs, tdigest -> {
             assertEquals(4L, tdigest.state.size());
-            assertEquals(4L, tdigest.state.centroidCount());
+            assertEquals(4L, tdigest.state.centroids().size());
             assertEquals(2.0d, tdigest.percentile(100), 0.0d);
             assertEquals(1.0d, tdigest.percentile(50), 0.0d);
             assertEquals(1.0d, tdigest.percentile(25), 0.0d);
@@ -165,7 +165,7 @@ public void testQueryFiltering() throws IOException {
 
         testCase(LongPoint.newRangeQuery("row", 100, 110), docs, tdigest -> {
             assertEquals(0L, tdigest.state.size());
-            assertEquals(0L, tdigest.state.centroidCount());
+            assertEquals(0L, tdigest.state.centroids().size());
             assertFalse(AggregationInspectionHelper.hasValue(tdigest));
         });
     }
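Finally, the backward-compatible read path above (pre-V_3_1_0 streams) boils down to rebuilding the old AVLTreeDigest and bulk-merging it into a MergingDigest. A self-contained sketch of that conversion, using only t-digest 3.3 calls the patch itself exercises; LegacyDigestMigration is an illustrative name:

import java.util.List;

import com.tdunning.math.stats.AVLTreeDigest;
import com.tdunning.math.stats.MergingDigest;

public class LegacyDigestMigration {
    // Merge an AVLTreeDigest (the old TDigestState base class) into a fresh MergingDigest,
    // the same move the new read() makes for streams older than V_3_1_0.
    static MergingDigest toMergingDigest(AVLTreeDigest legacy, double compression) {
        MergingDigest merged = new MergingDigest(compression);
        merged.add(List.of(legacy)); // bulk merge; preserves total weight, may re-cluster centroids
        return merged;
    }

    public static void main(String[] args) {
        AVLTreeDigest legacy = new AVLTreeDigest(100);
        for (int i = 0; i < 1000; i++) {
            legacy.add(i);
        }
        MergingDigest migrated = toMergingDigest(legacy, 100);
        // The total observation count survives the merge; individual centroids may differ,
        // which is why the tests above assert centroids().size() bounds rather than positions.
        System.out.println(legacy.size() + " == " + migrated.size());
        System.out.println("median ~ " + migrated.quantile(0.5));
    }
}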