Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Added

### Changed
- Adding ScriptedAvg class to painless spi to allowlist usage from plugins ([#19006](https://github.com/opensearch-project/OpenSearch/pull/19006))

### Fixed
- Fix unnecessary refreshes on update preparation failures ([#15261](https://github.com/opensearch-project/OpenSearch/issues/15261))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -164,3 +164,9 @@ class org.opensearch.index.query.IntervalFilterScript$Interval {
class org.opensearch.script.ScoreScript$ExplanationHolder {
void set(String)
}

class org.opensearch.search.aggregations.metrics.ScriptedAvg {
(double,long)
double getSum()
long getCount()
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.common.io.stream.Writeable.WriteableRegistry;
import org.opensearch.search.aggregations.metrics.ScriptedAvg;

/**
* This utility class registers generic types for streaming over the wire using
Expand Down Expand Up @@ -45,6 +46,12 @@ private static void registerWriters() {
o.writeByte((byte) 22);
((GeoPoint) v).writeTo(o);
});

WriteableRegistry.registerWriter(ScriptedAvg.class, (o, v) -> {
o.writeByte((byte) 28);
((ScriptedAvg) v).writeTo(o);
});

}

/**
Expand All @@ -55,5 +62,6 @@ private static void registerWriters() {
private static void registerReaders() {
/* {@link GeoPoint} */
WriteableRegistry.registerReader(Byte.valueOf((byte) 22), GeoPoint::new);
WriteableRegistry.registerReader(Byte.valueOf((byte) 28), ScriptedAvg::new);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -111,19 +111,21 @@ public InternalAvg reduce(List<InternalAggregation> aggregations, ReduceContext
for (InternalAggregation aggregation : aggregations) {
if (aggregation instanceof InternalScriptedMetric) {
// If using InternalScriptedMetric in place of InternalAvg
Object value = ((InternalScriptedMetric) aggregation).aggregation();
if (value instanceof ScriptedAvg scriptedAvg) {
count += scriptedAvg.getCount();
kahanSummation.add(scriptedAvg.getSum());
} else {
throw new IllegalArgumentException(
"Invalid ScriptedMetric result for ["
+ getName()
+ "] avg aggregation. Expected ScriptedAvg "
+ "but received ["
+ (value == null ? "null" : value.getClass().getName())
+ "]"
);
List<Object> aggList = ((InternalScriptedMetric) aggregation).aggregationsList();
for (Object value : aggList) {
if (value instanceof ScriptedAvg scriptedAvg) {
count += scriptedAvg.getCount();
kahanSummation.add(scriptedAvg.getSum());
} else {
throw new IllegalArgumentException(
"Invalid ScriptedMetric result for ["
+ getName()
+ "] avg aggregation. Expected ScriptedAvg "
+ "but received ["
+ (value == null ? "null" : value.getClass().getName())
+ "]"
);
}
}
} else {
// Original handling for InternalAvg
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,17 +88,19 @@ public InternalAggregation reduce(List<InternalAggregation> aggregations, Reduce
for (InternalAggregation aggregation : aggregations) {
if (aggregation instanceof InternalScriptedMetric) {
// If using InternalScriptedMetric in place of InternalValueCount
Object value = ((InternalScriptedMetric) aggregation).aggregation();
if (value instanceof Number) {
valueCount += ((Number) value).longValue();
} else {
throw new IllegalArgumentException(
"Invalid ScriptedMetric result for ["
+ getName()
+ "] valueCount aggregation. Expected numeric value from ScriptedMetric aggregation but got ["
+ (value == null ? "null" : value.getClass().getName())
+ "]"
);
List<Object> aggList = ((InternalScriptedMetric) aggregation).aggregationsList();
for (Object value : aggList) {
if (value instanceof Number) {
valueCount += ((Number) value).longValue();
} else {
throw new IllegalArgumentException(
"Invalid ScriptedMetric result for ["
+ getName()
+ "] valueCount aggregation. Expected numeric value from ScriptedMetric aggregation but got ["
+ (value == null ? "null" : value.getClass().getName())
+ "]"
);
}
}
} else {
// Original handling for InternalValueCount
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,28 +4,7 @@
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/
Expand All @@ -35,6 +14,9 @@
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.common.io.stream.Writeable;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.ToXContentFragment;
import org.opensearch.core.xcontent.XContentBuilder;

import java.io.IOException;

Expand All @@ -43,7 +25,7 @@
*
* @opensearch.internal
*/
public class ScriptedAvg implements Writeable {
public class ScriptedAvg implements Writeable, ToXContent, ToXContentFragment {
private double sum;
private long count;

Expand Down Expand Up @@ -79,4 +61,14 @@ public double getSum() {
public long getCount() {
return count;
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field("sum", sum);
builder.field("count", count);
builder.endObject();
return builder;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ public void collect(int doc, long owningBucketOrd) throws IOException {
@Override
public InternalAggregation buildAggregation(long owningBucketOrdinal) {
Object result = aggStateForResult(owningBucketOrdinal).combine();
StreamOutput.checkWriteable(result);
if (result.getClass() != ScriptedAvg.class) StreamOutput.checkWriteable(result);
return new InternalScriptedMetric(name, singletonList(result), reduceScript, metadata());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,9 @@ public void testReduceWithScriptedMetric() {
// Add ScriptedMetric with ScriptedAvg object
InternalScriptedMetric scriptedMetric1 = mock(InternalScriptedMetric.class);
when(scriptedMetric1.getName()).thenReturn(name);
ScriptedAvg scriptedAvg = new ScriptedAvg(100.0, 20L);
when(scriptedMetric1.aggregation()).thenReturn(scriptedAvg);
List<Object> aggList = new ArrayList<>();
aggList.add(new ScriptedAvg(100.0, 20L));
when(scriptedMetric1.aggregationsList()).thenReturn(aggList);
aggregations.add(scriptedMetric1);

InternalAvg avg = new InternalAvg(name, 0.0, 0L, formatter, null);
Expand Down Expand Up @@ -175,7 +176,9 @@ public void testReduceWithScriptedMetricInvalidType() {
// Add ScriptedMetric with invalid return type (String instead of double[])
InternalScriptedMetric scriptedMetric1 = mock(InternalScriptedMetric.class);
when(scriptedMetric1.getName()).thenReturn(name);
when(scriptedMetric1.aggregation()).thenReturn("invalid_type");
List<Object> aggList = new ArrayList<>();
aggList.add("invalid_type");
when(scriptedMetric1.aggregationsList()).thenReturn(aggList);
aggregations.add(scriptedMetric1);

InternalAvg avg = new InternalAvg(name, 0.0, 0L, formatter, null);
Expand All @@ -199,7 +202,9 @@ public void testReduceWithScriptedMetricInvalidArrayLength() {
// Add ScriptedMetric with double array of wrong length (should be 2)
InternalScriptedMetric scriptedMetric = mock(InternalScriptedMetric.class);
when(scriptedMetric.getName()).thenReturn(name);
when(scriptedMetric.aggregation()).thenReturn(new double[] { 100.0, 20.0, 30.0 }); // length 3 instead of 2
List<Object> aggList = new ArrayList<>();
aggList.add(new double[] { 100.0, 20.0, 30.0 }); // Add double array to list
when(scriptedMetric.aggregationsList()).thenReturn(aggList);
aggregations.add(scriptedMetric);

InternalAvg avg = new InternalAvg(name, 0.0, 0L, formatter, null);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,17 +71,23 @@ public void testReduceWithScriptedMetric() {

// Add ScriptedMetric with Long value
InternalScriptedMetric scriptedMetric1 = mock(InternalScriptedMetric.class);
when(scriptedMetric1.aggregation()).thenReturn(20L);
List<Object> aggList1 = new ArrayList<>();
aggList1.add(20L);
when(scriptedMetric1.aggregationsList()).thenReturn(aggList1);
aggregations.add(scriptedMetric1);

// Add ScriptedMetric with Integer value
InternalScriptedMetric scriptedMetric2 = mock(InternalScriptedMetric.class);
when(scriptedMetric2.aggregation()).thenReturn(30);
List<Object> aggList2 = new ArrayList<>();
aggList2.add(30);
when(scriptedMetric2.aggregationsList()).thenReturn(aggList2);
aggregations.add(scriptedMetric2);

// Add ScriptedMetric with Double value
InternalScriptedMetric scriptedMetric3 = mock(InternalScriptedMetric.class);
when(scriptedMetric3.aggregation()).thenReturn(10.5);
List<Object> aggList3 = new ArrayList<>();
aggList3.add(10.5);
when(scriptedMetric3.aggregationsList()).thenReturn(aggList3);
aggregations.add(scriptedMetric3);

InternalValueCount valueCount = new InternalValueCount(name, 0L, null);
Expand All @@ -92,6 +98,7 @@ public void testReduceWithScriptedMetric() {
}

public void testReduceWithInternalValueCountOnly() {
// This test remains unchanged as it doesn't use ScriptedMetric
String name = "test_value_count";
List<InternalAggregation> aggregations = new ArrayList<>();

Expand All @@ -116,7 +123,9 @@ public void testReduceWithScriptedMetricInvalidValue() {

// Add ScriptedMetric with invalid value type (String instead of Number)
InternalScriptedMetric scriptedMetric = mock(InternalScriptedMetric.class);
when(scriptedMetric.aggregation()).thenReturn("invalid_value");
List<Object> aggList = new ArrayList<>();
aggList.add("invalid_value");
when(scriptedMetric.aggregationsList()).thenReturn(aggList);
aggregations.add(scriptedMetric);

InternalValueCount valueCount = new InternalValueCount(name, 0L, null);
Expand All @@ -133,6 +142,29 @@ public void testReduceWithScriptedMetricInvalidValue() {
);
}

public void testReduceWithMultipleValuesInList() {
String name = "test_scripted_metric";
List<InternalAggregation> aggregations = new ArrayList<>();

// Add regular InternalValueCount
aggregations.add(new InternalValueCount(name, 50L, null));

// Add ScriptedMetric with multiple values in the list
InternalScriptedMetric scriptedMetric = mock(InternalScriptedMetric.class);
List<Object> aggList = new ArrayList<>();
aggList.add(20L);
aggList.add(30);
aggList.add(10.5);
when(scriptedMetric.aggregationsList()).thenReturn(aggList);
aggregations.add(scriptedMetric);

InternalValueCount valueCount = new InternalValueCount(name, 0L, null);
InternalValueCount reduced = (InternalValueCount) valueCount.reduce(aggregations, null);

// Expected: 50 + 20 + 30 + 10 = 110
assertEquals(110L, reduced.getValue());
}

@Override
protected InternalValueCount mutateInstance(InternalValueCount instance) {
String name = instance.getName();
Expand Down
Loading