Skip to content

Commit 1b9b48b

Browse files
committed
[ML] [Data Frame] add support for weighted_avg agg (elastic#42646)
1 parent 5a76f46 commit 1b9b48b

File tree

5 files changed

+56
-1
lines changed

5 files changed

+56
-1
lines changed

x-pack/plugin/data-frame/qa/single-node-tests/src/test/java/org/elasticsearch/xpack/dataframe/integration/DataFramePivotRestIT.java

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -473,6 +473,45 @@ public void testPivotWithGeoCentroidAgg() throws Exception {
473473
assertEquals((4 + 15), Double.valueOf(latlon[1]), 0.000001);
474474
}
475475

476+
public void testPivotWithWeightedAvgAgg() throws Exception {
477+
String transformId = "weightedAvgAggTransform";
478+
String dataFrameIndex = "weighted_avg_pivot_reviews";
479+
setupDataAccessRole(DATA_ACCESS_ROLE, REVIEWS_INDEX_NAME, dataFrameIndex);
480+
481+
final Request createDataframeTransformRequest = createRequestWithAuth("PUT", DATAFRAME_ENDPOINT + transformId,
482+
BASIC_AUTH_VALUE_DATA_FRAME_ADMIN_WITH_SOME_DATA_ACCESS);
483+
484+
String config = "{"
485+
+ " \"source\": {\"index\":\"" + REVIEWS_INDEX_NAME + "\"},"
486+
+ " \"dest\": {\"index\":\"" + dataFrameIndex + "\"},";
487+
488+
config += " \"pivot\": {"
489+
+ " \"group_by\": {"
490+
+ " \"reviewer\": {"
491+
+ " \"terms\": {"
492+
+ " \"field\": \"user_id\""
493+
+ " } } },"
494+
+ " \"aggregations\": {"
495+
+ " \"avg_rating\": {"
496+
+ " \"weighted_avg\": {"
497+
+ " \"value\": {\"field\": \"stars\"},"
498+
+ " \"weight\": {\"field\": \"stars\"}"
499+
+ "} } } }"
500+
+ "}";
501+
502+
createDataframeTransformRequest.setJsonEntity(config);
503+
Map<String, Object> createDataframeTransformResponse = entityAsMap(client().performRequest(createDataframeTransformRequest));
504+
assertThat(createDataframeTransformResponse.get("acknowledged"), equalTo(Boolean.TRUE));
505+
506+
startAndWaitForTransform(transformId, dataFrameIndex, BASIC_AUTH_VALUE_DATA_FRAME_ADMIN_WITH_SOME_DATA_ACCESS);
507+
assertTrue(indexExists(dataFrameIndex));
508+
509+
Map<String, Object> searchResult = getAsMap(dataFrameIndex + "/_search?q=reviewer:user_4");
510+
assertEquals(1, XContentMapValues.extractValue("hits.total.value", searchResult));
511+
Number actual = (Number) ((List<?>) XContentMapValues.extractValue("hits.hits._source.avg_rating", searchResult)).get(0);
512+
assertEquals(4.47169811, actual.doubleValue(), 0.000001);
513+
}
514+
476515
private void assertOnePivotValue(String query, double expected) throws IOException {
477516
Map<String, Object> searchResult = getAsMap(query);
478517

x-pack/plugin/data-frame/src/main/java/org/elasticsearch/xpack/dataframe/transforms/pivot/Aggregations.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ enum AggregationType {
3737
SUM("sum", SOURCE),
3838
GEO_CENTROID("geo_centroid", "geo_point"),
3939
SCRIPTED_METRIC("scripted_metric", DYNAMIC),
40+
WEIGHTED_AVG("weighted_avg", DYNAMIC),
4041
BUCKET_SCRIPT("bucket_script", DYNAMIC);
4142

4243
private final String aggregationType;

x-pack/plugin/data-frame/src/main/java/org/elasticsearch/xpack/dataframe/transforms/pivot/SchemaUtil.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import org.elasticsearch.search.aggregations.AggregationBuilder;
1818
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
1919
import org.elasticsearch.search.aggregations.metrics.ScriptedMetricAggregationBuilder;
20+
import org.elasticsearch.search.aggregations.support.MultiValuesSourceAggregationBuilder;
2021
import org.elasticsearch.search.aggregations.support.ValuesSourceAggregationBuilder;
2122
import org.elasticsearch.xpack.core.ClientHelper;
2223
import org.elasticsearch.xpack.core.dataframe.transforms.pivot.PivotConfig;
@@ -77,7 +78,7 @@ public static void deduceMappings(final Client client,
7778
ValuesSourceAggregationBuilder<?, ?> valueSourceAggregation = (ValuesSourceAggregationBuilder<?, ?>) agg;
7879
aggregationSourceFieldNames.put(valueSourceAggregation.getName(), valueSourceAggregation.field());
7980
aggregationTypes.put(valueSourceAggregation.getName(), valueSourceAggregation.getType());
80-
} else if(agg instanceof ScriptedMetricAggregationBuilder) {
81+
} else if(agg instanceof ScriptedMetricAggregationBuilder || agg instanceof MultiValuesSourceAggregationBuilder) {
8182
aggregationTypes.put(agg.getName(), agg.getType());
8283
} else {
8384
// execution should not reach this point

x-pack/plugin/data-frame/src/test/java/org/elasticsearch/xpack/dataframe/transforms/pivot/AggregationsTests.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,5 +49,9 @@ public void testResolveTargetMapping() {
4949
// bucket_script
5050
assertEquals("_dynamic", Aggregations.resolveTargetMapping("bucket_script", null));
5151
assertEquals("_dynamic", Aggregations.resolveTargetMapping("bucket_script", "int"));
52+
53+
// weighted_avg
54+
assertEquals("_dynamic", Aggregations.resolveTargetMapping("weighted_avg", null));
55+
assertEquals("_dynamic", Aggregations.resolveTargetMapping("weighted_avg", "double"));
5256
}
5357
}

x-pack/plugin/data-frame/src/test/java/org/elasticsearch/xpack/dataframe/transforms/pivot/PivotTests.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,16 @@ private AggregationConfig getAggregationConfig(String agg) throws IOException {
215215
"\"buckets_path\":{\"param_1\":\"other_bucket\"}," +
216216
"\"script\":\"return params.param_1\"}}}");
217217
}
218+
if (agg.equals(AggregationType.WEIGHTED_AVG.getName())) {
219+
return parseAggregations("{\n" +
220+
"\"pivot_weighted_avg\": {\n" +
221+
" \"weighted_avg\": {\n" +
222+
" \"value\": {\"field\": \"values\"},\n" +
223+
" \"weight\": {\"field\": \"weights\"}\n" +
224+
" }\n" +
225+
"}\n" +
226+
"}");
227+
}
218228
return parseAggregations("{\n" + " \"pivot_" + agg + "\": {\n" + " \"" + agg + "\": {\n" + " \"field\": \"values\"\n"
219229
+ " }\n" + " }" + "}");
220230
}

0 commit comments

Comments
 (0)