diff --git a/CHANGELOG.md b/CHANGELOG.md index a6951535d4cdc..0bff5ca79df43 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Changed - Add CompletionStage variants to methods in the Client Interface and default to ActionListener impl ([#18998](https://github.com/opensearch-project/OpenSearch/pull/18998)) - IllegalArgumentException when scroll ID references a node not found in Cluster ([#19031](https://github.com/opensearch-project/OpenSearch/pull/19031)) +- Adding ScriptedAvg class to painless spi to allowlist usage from plugins ([#19006](https://github.com/opensearch-project/OpenSearch/pull/19006)) ### Fixed - Fix unnecessary refreshes on update preparation failures ([#15261](https://github.com/opensearch-project/OpenSearch/issues/15261)) diff --git a/libs/core/src/main/java/org/opensearch/core/common/io/stream/StreamOutput.java b/libs/core/src/main/java/org/opensearch/core/common/io/stream/StreamOutput.java index 6498b618b28c3..d27eb7197213f 100644 --- a/libs/core/src/main/java/org/opensearch/core/common/io/stream/StreamOutput.java +++ b/libs/core/src/main/java/org/opensearch/core/common/io/stream/StreamOutput.java @@ -789,6 +789,8 @@ public final void writeOptionalInstant(@Nullable Instant instant) throws IOExcep o.writeByte((byte) 27); o.writeSemverRange((SemverRange) v); }); + // Have registered ScriptedAvg class with byte 28 in Streamables.java, so that we do not need the implementation reside in the + // server module WRITERS = Collections.unmodifiableMap(writers); } diff --git a/modules/lang-painless/src/main/resources/org/opensearch/painless/spi/org.opensearch.txt b/modules/lang-painless/src/main/resources/org/opensearch/painless/spi/org.opensearch.txt index b91d9bb6115d4..37e769f88e010 100644 --- a/modules/lang-painless/src/main/resources/org/opensearch/painless/spi/org.opensearch.txt +++ b/modules/lang-painless/src/main/resources/org/opensearch/painless/spi/org.opensearch.txt @@ -164,3 +164,9 @@ class org.opensearch.index.query.IntervalFilterScript$Interval { class org.opensearch.script.ScoreScript$ExplanationHolder { void set(String) } + +class org.opensearch.search.aggregations.metrics.ScriptedAvg { + (double,long) + double getSum() + long getCount() +} diff --git a/server/src/main/java/org/opensearch/common/io/stream/Streamables.java b/server/src/main/java/org/opensearch/common/io/stream/Streamables.java index f1e5f5f22d527..6673276cb6d5f 100644 --- a/server/src/main/java/org/opensearch/common/io/stream/Streamables.java +++ b/server/src/main/java/org/opensearch/common/io/stream/Streamables.java @@ -12,6 +12,7 @@ import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.common.io.stream.Writeable.WriteableRegistry; +import org.opensearch.search.aggregations.metrics.ScriptedAvg; /** * This utility class registers generic types for streaming over the wire using @@ -45,6 +46,12 @@ private static void registerWriters() { o.writeByte((byte) 22); ((GeoPoint) v).writeTo(o); }); + + WriteableRegistry.registerWriter(ScriptedAvg.class, (o, v) -> { + o.writeByte((byte) 28); + ((ScriptedAvg) v).writeTo(o); + }); + } /** @@ -55,5 +62,6 @@ private static void registerWriters() { private static void registerReaders() { /* {@link GeoPoint} */ WriteableRegistry.registerReader(Byte.valueOf((byte) 22), GeoPoint::new); + WriteableRegistry.registerReader(Byte.valueOf((byte) 28), ScriptedAvg::new); } } diff --git a/server/src/main/java/org/opensearch/search/aggregations/metrics/InternalAvg.java b/server/src/main/java/org/opensearch/search/aggregations/metrics/InternalAvg.java index f7839a2b07a7a..c53c7b920d5ea 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/metrics/InternalAvg.java +++ b/server/src/main/java/org/opensearch/search/aggregations/metrics/InternalAvg.java @@ -111,19 +111,21 @@ public InternalAvg reduce(List aggregations, ReduceContext for (InternalAggregation aggregation : aggregations) { if (aggregation instanceof InternalScriptedMetric) { // If using InternalScriptedMetric in place of InternalAvg - Object value = ((InternalScriptedMetric) aggregation).aggregation(); - if (value instanceof ScriptedAvg scriptedAvg) { - count += scriptedAvg.getCount(); - kahanSummation.add(scriptedAvg.getSum()); - } else { - throw new IllegalArgumentException( - "Invalid ScriptedMetric result for [" - + getName() - + "] avg aggregation. Expected ScriptedAvg " - + "but received [" - + (value == null ? "null" : value.getClass().getName()) - + "]" - ); + List aggList = ((InternalScriptedMetric) aggregation).aggregationsList(); + for (Object value : aggList) { + if (value instanceof ScriptedAvg scriptedAvg) { + count += scriptedAvg.getCount(); + kahanSummation.add(scriptedAvg.getSum()); + } else { + throw new IllegalArgumentException( + "Invalid ScriptedMetric result for [" + + getName() + + "] avg aggregation. Expected ScriptedAvg " + + "but received [" + + (value == null ? "null" : value.getClass().getName()) + + "]" + ); + } } } else { // Original handling for InternalAvg diff --git a/server/src/main/java/org/opensearch/search/aggregations/metrics/InternalValueCount.java b/server/src/main/java/org/opensearch/search/aggregations/metrics/InternalValueCount.java index 2b8037488e428..f447e31c94486 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/metrics/InternalValueCount.java +++ b/server/src/main/java/org/opensearch/search/aggregations/metrics/InternalValueCount.java @@ -88,17 +88,19 @@ public InternalAggregation reduce(List aggregations, Reduce for (InternalAggregation aggregation : aggregations) { if (aggregation instanceof InternalScriptedMetric) { // If using InternalScriptedMetric in place of InternalValueCount - Object value = ((InternalScriptedMetric) aggregation).aggregation(); - if (value instanceof Number) { - valueCount += ((Number) value).longValue(); - } else { - throw new IllegalArgumentException( - "Invalid ScriptedMetric result for [" - + getName() - + "] valueCount aggregation. Expected numeric value from ScriptedMetric aggregation but got [" - + (value == null ? "null" : value.getClass().getName()) - + "]" - ); + List aggList = ((InternalScriptedMetric) aggregation).aggregationsList(); + for (Object value : aggList) { + if (value instanceof Number) { + valueCount += ((Number) value).longValue(); + } else { + throw new IllegalArgumentException( + "Invalid ScriptedMetric result for [" + + getName() + + "] valueCount aggregation. Expected numeric value from ScriptedMetric aggregation but got [" + + (value == null ? "null" : value.getClass().getName()) + + "]" + ); + } } } else { // Original handling for InternalValueCount diff --git a/server/src/main/java/org/opensearch/search/aggregations/metrics/ScriptedAvg.java b/server/src/main/java/org/opensearch/search/aggregations/metrics/ScriptedAvg.java index c739de94a7cf3..439e06255ead9 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/metrics/ScriptedAvg.java +++ b/server/src/main/java/org/opensearch/search/aggregations/metrics/ScriptedAvg.java @@ -4,28 +4,7 @@ * The OpenSearch Contributors require contributions made to * this file be licensed under the Apache-2.0 license or a * compatible open source license. - */ - -/* - * Licensed to Elasticsearch under one or more contributor - * license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright - * ownership. Elasticsearch licenses this file to you under - * the Apache License, Version 2.0 (the "License"); you may - * not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/* * Modifications Copyright OpenSearch Contributors. See * GitHub history for details. */ @@ -79,4 +58,5 @@ public double getSum() { public long getCount() { return count; } + } diff --git a/server/src/main/java/org/opensearch/search/aggregations/metrics/ScriptedMetricAggregator.java b/server/src/main/java/org/opensearch/search/aggregations/metrics/ScriptedMetricAggregator.java index 8e04bbae41107..0ad7f8fb2e8b6 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/metrics/ScriptedMetricAggregator.java +++ b/server/src/main/java/org/opensearch/search/aggregations/metrics/ScriptedMetricAggregator.java @@ -158,7 +158,7 @@ public void collect(int doc, long owningBucketOrd) throws IOException { @Override public InternalAggregation buildAggregation(long owningBucketOrdinal) { Object result = aggStateForResult(owningBucketOrdinal).combine(); - StreamOutput.checkWriteable(result); + if (result.getClass() != ScriptedAvg.class) StreamOutput.checkWriteable(result); return new InternalScriptedMetric(name, singletonList(result), reduceScript, metadata()); } diff --git a/server/src/test/java/org/opensearch/search/aggregations/metrics/InternalAvgTests.java b/server/src/test/java/org/opensearch/search/aggregations/metrics/InternalAvgTests.java index 60aa49018bceb..0c6462482c72a 100644 --- a/server/src/test/java/org/opensearch/search/aggregations/metrics/InternalAvgTests.java +++ b/server/src/test/java/org/opensearch/search/aggregations/metrics/InternalAvgTests.java @@ -127,8 +127,9 @@ public void testReduceWithScriptedMetric() { // Add ScriptedMetric with ScriptedAvg object InternalScriptedMetric scriptedMetric1 = mock(InternalScriptedMetric.class); when(scriptedMetric1.getName()).thenReturn(name); - ScriptedAvg scriptedAvg = new ScriptedAvg(100.0, 20L); - when(scriptedMetric1.aggregation()).thenReturn(scriptedAvg); + List aggList = new ArrayList<>(); + aggList.add(new ScriptedAvg(100.0, 20L)); + when(scriptedMetric1.aggregationsList()).thenReturn(aggList); aggregations.add(scriptedMetric1); InternalAvg avg = new InternalAvg(name, 0.0, 0L, formatter, null); @@ -175,7 +176,9 @@ public void testReduceWithScriptedMetricInvalidType() { // Add ScriptedMetric with invalid return type (String instead of double[]) InternalScriptedMetric scriptedMetric1 = mock(InternalScriptedMetric.class); when(scriptedMetric1.getName()).thenReturn(name); - when(scriptedMetric1.aggregation()).thenReturn("invalid_type"); + List aggList = new ArrayList<>(); + aggList.add("invalid_type"); + when(scriptedMetric1.aggregationsList()).thenReturn(aggList); aggregations.add(scriptedMetric1); InternalAvg avg = new InternalAvg(name, 0.0, 0L, formatter, null); @@ -199,7 +202,9 @@ public void testReduceWithScriptedMetricInvalidArrayLength() { // Add ScriptedMetric with double array of wrong length (should be 2) InternalScriptedMetric scriptedMetric = mock(InternalScriptedMetric.class); when(scriptedMetric.getName()).thenReturn(name); - when(scriptedMetric.aggregation()).thenReturn(new double[] { 100.0, 20.0, 30.0 }); // length 3 instead of 2 + List aggList = new ArrayList<>(); + aggList.add(new double[] { 100.0, 20.0, 30.0 }); // Add double array to list + when(scriptedMetric.aggregationsList()).thenReturn(aggList); aggregations.add(scriptedMetric); InternalAvg avg = new InternalAvg(name, 0.0, 0L, formatter, null); diff --git a/server/src/test/java/org/opensearch/search/aggregations/metrics/InternalValueCountTests.java b/server/src/test/java/org/opensearch/search/aggregations/metrics/InternalValueCountTests.java index fcfced6ccb0da..8a7d2ab80c33a 100644 --- a/server/src/test/java/org/opensearch/search/aggregations/metrics/InternalValueCountTests.java +++ b/server/src/test/java/org/opensearch/search/aggregations/metrics/InternalValueCountTests.java @@ -71,17 +71,23 @@ public void testReduceWithScriptedMetric() { // Add ScriptedMetric with Long value InternalScriptedMetric scriptedMetric1 = mock(InternalScriptedMetric.class); - when(scriptedMetric1.aggregation()).thenReturn(20L); + List aggList1 = new ArrayList<>(); + aggList1.add(20L); + when(scriptedMetric1.aggregationsList()).thenReturn(aggList1); aggregations.add(scriptedMetric1); // Add ScriptedMetric with Integer value InternalScriptedMetric scriptedMetric2 = mock(InternalScriptedMetric.class); - when(scriptedMetric2.aggregation()).thenReturn(30); + List aggList2 = new ArrayList<>(); + aggList2.add(30); + when(scriptedMetric2.aggregationsList()).thenReturn(aggList2); aggregations.add(scriptedMetric2); // Add ScriptedMetric with Double value InternalScriptedMetric scriptedMetric3 = mock(InternalScriptedMetric.class); - when(scriptedMetric3.aggregation()).thenReturn(10.5); + List aggList3 = new ArrayList<>(); + aggList3.add(10.5); + when(scriptedMetric3.aggregationsList()).thenReturn(aggList3); aggregations.add(scriptedMetric3); InternalValueCount valueCount = new InternalValueCount(name, 0L, null); @@ -92,6 +98,7 @@ public void testReduceWithScriptedMetric() { } public void testReduceWithInternalValueCountOnly() { + // This test remains unchanged as it doesn't use ScriptedMetric String name = "test_value_count"; List aggregations = new ArrayList<>(); @@ -116,7 +123,9 @@ public void testReduceWithScriptedMetricInvalidValue() { // Add ScriptedMetric with invalid value type (String instead of Number) InternalScriptedMetric scriptedMetric = mock(InternalScriptedMetric.class); - when(scriptedMetric.aggregation()).thenReturn("invalid_value"); + List aggList = new ArrayList<>(); + aggList.add("invalid_value"); + when(scriptedMetric.aggregationsList()).thenReturn(aggList); aggregations.add(scriptedMetric); InternalValueCount valueCount = new InternalValueCount(name, 0L, null); @@ -133,6 +142,29 @@ public void testReduceWithScriptedMetricInvalidValue() { ); } + public void testReduceWithMultipleValuesInList() { + String name = "test_scripted_metric"; + List aggregations = new ArrayList<>(); + + // Add regular InternalValueCount + aggregations.add(new InternalValueCount(name, 50L, null)); + + // Add ScriptedMetric with multiple values in the list + InternalScriptedMetric scriptedMetric = mock(InternalScriptedMetric.class); + List aggList = new ArrayList<>(); + aggList.add(20L); + aggList.add(30); + aggList.add(10.5); + when(scriptedMetric.aggregationsList()).thenReturn(aggList); + aggregations.add(scriptedMetric); + + InternalValueCount valueCount = new InternalValueCount(name, 0L, null); + InternalValueCount reduced = (InternalValueCount) valueCount.reduce(aggregations, null); + + // Expected: 50 + 20 + 30 + 10 = 110 + assertEquals(110L, reduced.getValue()); + } + @Override protected InternalValueCount mutateInstance(InternalValueCount instance) { String name = instance.getName();