diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/RateDoubleGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/RateDoubleGroupingAggregatorFunction.java
index fb66fd1a5593a..2587716c51dd0 100644
--- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/RateDoubleGroupingAggregatorFunction.java
+++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/RateDoubleGroupingAggregatorFunction.java
@@ -89,6 +89,11 @@ public String describe() {
     private final boolean isRateOverTime;
     private final double dateFactor;
 
+    // Track the min/max group ids that own a raw buffer so the buffers can be flushed when the slice index advances
+    private int minRawInputGroupId = Integer.MAX_VALUE;
+    private int maxRawInputGroupId = Integer.MIN_VALUE;
+    private int lastSliceIndex = -1;
+
     public RateDoubleGroupingAggregatorFunction(
         List<Integer> channels,
         DriverContext driverContext,
@@ -151,7 +156,12 @@ public void close() {
         assert sliceIndices != null : "expected slice indices vector in time-series aggregation";
         LongVector futureMaxTimestamps = ((LongBlock) page.getBlock(channels.get(3))).asVector();
         assert futureMaxTimestamps != null : "expected future max timestamps vector in time-series aggregation";
-
+        // Only position 0 is read: this assumes every position of the page belongs to the same slice.
+        int sliceIndex = sliceIndices.getInt(0);
+        if (sliceIndex > lastSliceIndex) {
+            flushRawBuffers();
+            lastSliceIndex = sliceIndex;
+        }
         return new AddInput() {
             @Override
             public void add(int positionOffset, IntArrayBlock groupIds) {
@@ -400,12 +410,40 @@ private Buffer getBuffer(int groupId, int newElements, long firstTimestamp) {
         if (buffer == null) {
             buffer = new Buffer(bigArrays, newElements);
             buffers.set(groupId, buffer);
+            minRawInputGroupId = Math.min(minRawInputGroupId, groupId);
+            maxRawInputGroupId = Math.max(maxRawInputGroupId, groupId);
         } else {
             buffer.ensureCapacity(bigArrays, newElements, firstTimestamp);
         }
         return buffer;
     }
 
+    /**
+     * Flushes every raw buffer in the tracked group id range into the reduced state of its group, closes the buffer,
+     * and resets the tracked range. Does nothing when no buffer was created since the previous flush.
+     */
+    void flushRawBuffers() {
+        if (minRawInputGroupId > maxRawInputGroupId) {
+            return;
+        }
+        reducedStates = bigArrays.grow(reducedStates, maxRawInputGroupId + 1);
+        for (int groupId = minRawInputGroupId; groupId <= maxRawInputGroupId; groupId++) {
+            Buffer buffer = buffers.getAndSet(groupId, null);
+            if (buffer != null) {
+                try (buffer) {
+                    ReducedState state = reducedStates.get(groupId);
+                    if (state == null) {
+                        state = new ReducedState();
+                        reducedStates.set(groupId, state);
+                    }
+                    buffer.flush(state);
+                }
+            }
+        }
+        minRawInputGroupId = Integer.MAX_VALUE;
+        maxRawInputGroupId = Integer.MIN_VALUE;
+    }
+
     /**
      * Buffers data points in two arrays: one for timestamps and one for values, partitioned into multiple slices.
      * Each slice is sorted in descending order of timestamp. A new slice is created when a data point has a
diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/RateIntGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/RateIntGroupingAggregatorFunction.java
index 9ee9527f1f999..68ce05ee48515 100644
--- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/RateIntGroupingAggregatorFunction.java
+++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/RateIntGroupingAggregatorFunction.java
@@ -89,6 +89,11 @@ public String describe() {
     private final boolean isRateOverTime;
     private final double dateFactor;
 
+    // Track the min/max group ids that own a raw buffer so the buffers can be flushed when the slice index advances
+    private int minRawInputGroupId = Integer.MAX_VALUE;
+    private int maxRawInputGroupId = Integer.MIN_VALUE;
+    private int lastSliceIndex = -1;
+
     public RateIntGroupingAggregatorFunction(
         List<Integer> channels,
         DriverContext driverContext,
@@ -151,7 +156,12 @@ public void close() {
         assert sliceIndices != null : "expected slice indices vector in time-series aggregation";
         LongVector futureMaxTimestamps = ((LongBlock) page.getBlock(channels.get(3))).asVector();
         assert futureMaxTimestamps != null : "expected future max timestamps vector in time-series aggregation";
-
+        // Only position 0 is read: this assumes every position of the page belongs to the same slice.
+        int sliceIndex = sliceIndices.getInt(0);
+        if (sliceIndex > lastSliceIndex) {
+            flushRawBuffers();
+            lastSliceIndex = sliceIndex;
+        }
         return new AddInput() {
             @Override
             public void add(int positionOffset, IntArrayBlock groupIds) {
@@ -400,12 +410,40 @@ private Buffer getBuffer(int groupId, int newElements, long firstTimestamp) {
         if (buffer == null) {
             buffer = new Buffer(bigArrays, newElements);
             buffers.set(groupId, buffer);
+            minRawInputGroupId = Math.min(minRawInputGroupId, groupId);
+            maxRawInputGroupId = Math.max(maxRawInputGroupId, groupId);
         } else {
             buffer.ensureCapacity(bigArrays, newElements, firstTimestamp);
         }
         return buffer;
     }
 
+    /**
+     * Flushes every raw buffer in the tracked group id range into the reduced state of its group, closes the buffer,
+     * and resets the tracked range. Does nothing when no buffer was created since the previous flush.
+     */
+    void flushRawBuffers() {
+        if (minRawInputGroupId > maxRawInputGroupId) {
+            return;
+        }
+        reducedStates = bigArrays.grow(reducedStates, maxRawInputGroupId + 1);
+        for (int groupId = minRawInputGroupId; groupId <= maxRawInputGroupId; groupId++) {
+            Buffer buffer = buffers.getAndSet(groupId, null);
+            if (buffer != null) {
+                try (buffer) {
+                    ReducedState state = reducedStates.get(groupId);
+                    if (state == null) {
+                        state = new ReducedState();
+                        reducedStates.set(groupId, state);
+                    }
+                    buffer.flush(state);
+                }
+            }
+        }
+        minRawInputGroupId = Integer.MAX_VALUE;
+        maxRawInputGroupId = Integer.MIN_VALUE;
+    }
+
     /**
      * Buffers data points in two arrays: one for timestamps and one for values, partitioned into multiple slices.
      * Each slice is sorted in descending order of timestamp. A new slice is created when a data point has a
diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/RateLongGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/RateLongGroupingAggregatorFunction.java
index 49d34a2d200cb..42962db7af0fe 100644
--- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/RateLongGroupingAggregatorFunction.java
+++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/RateLongGroupingAggregatorFunction.java
@@ -89,6 +89,11 @@ public String describe() {
     private final boolean isRateOverTime;
     private final double dateFactor;
 
+    // Track the min/max group ids that own a raw buffer so the buffers can be flushed when the slice index advances
+    private int minRawInputGroupId = Integer.MAX_VALUE;
+    private int maxRawInputGroupId = Integer.MIN_VALUE;
+    private int lastSliceIndex = -1;
+
     public RateLongGroupingAggregatorFunction(
         List<Integer> channels,
         DriverContext driverContext,
@@ -151,7 +156,12 @@ public void close() {
         assert sliceIndices != null : "expected slice indices vector in time-series aggregation";
         LongVector futureMaxTimestamps = ((LongBlock) page.getBlock(channels.get(3))).asVector();
         assert futureMaxTimestamps != null : "expected future max timestamps vector in time-series aggregation";
-
+        // Only position 0 is read: this assumes every position of the page belongs to the same slice.
+        int sliceIndex = sliceIndices.getInt(0);
+        if (sliceIndex > lastSliceIndex) {
+            flushRawBuffers();
+            lastSliceIndex = sliceIndex;
+        }
         return new AddInput() {
             @Override
             public void add(int positionOffset, IntArrayBlock groupIds) {
@@ -400,12 +410,40 @@ private Buffer getBuffer(int groupId, int newElements, long firstTimestamp) {
         if (buffer == null) {
             buffer = new Buffer(bigArrays, newElements);
             buffers.set(groupId, buffer);
+            minRawInputGroupId = Math.min(minRawInputGroupId, groupId);
+            maxRawInputGroupId = Math.max(maxRawInputGroupId, groupId);
         } else {
             buffer.ensureCapacity(bigArrays, newElements, firstTimestamp);
         }
         return buffer;
     }
 
+    /**
+     * Flushes every raw buffer in the tracked group id range into the reduced state of its group, closes the buffer,
+     * and resets the tracked range. Does nothing when no buffer was created since the previous flush.
+     */
+    void flushRawBuffers() {
+        if (minRawInputGroupId > maxRawInputGroupId) {
+            return;
+        }
+        reducedStates = bigArrays.grow(reducedStates, maxRawInputGroupId + 1);
+        for (int groupId = minRawInputGroupId; groupId <= maxRawInputGroupId; groupId++) {
+            Buffer buffer = buffers.getAndSet(groupId, null);
+            if (buffer != null) {
+                try (buffer) {
+                    ReducedState state = reducedStates.get(groupId);
+                    if (state == null) {
+                        state = new ReducedState();
+                        reducedStates.set(groupId, state);
+                    }
+                    buffer.flush(state);
+                }
+            }
+        }
+        minRawInputGroupId = Integer.MAX_VALUE;
+        maxRawInputGroupId = Integer.MIN_VALUE;
+    }
+
     /**
      * Buffers data points in two arrays: one for timestamps and one for values, partitioned into multiple slices.
      * Each slice is sorted in descending order of timestamp. A new slice is created when a data point has a
diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-RateGroupingAggregatorFunction.java.st b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-RateGroupingAggregatorFunction.java.st
index 47f6c1e164871..c830b1ac0e7ad 100644
--- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-RateGroupingAggregatorFunction.java.st
+++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-RateGroupingAggregatorFunction.java.st
@@ -89,6 +89,11 @@ public final class Rate$Type$GroupingAggregatorFunction implements GroupingAggre
     private final boolean isRateOverTime;
     private final double dateFactor;
 
+    // Track the min/max group ids that own a raw buffer so the buffers can be flushed when the slice index advances
+    private int minRawInputGroupId = Integer.MAX_VALUE;
+    private int maxRawInputGroupId = Integer.MIN_VALUE;
+    private int lastSliceIndex = -1;
+
     public Rate$Type$GroupingAggregatorFunction(
         List<Integer> channels,
         DriverContext driverContext,
@@ -151,7 +156,12 @@ public final class Rate$Type$GroupingAggregatorFunction implements GroupingAggre
         assert sliceIndices != null : "expected slice indices vector in time-series aggregation";
         LongVector futureMaxTimestamps = ((LongBlock) page.getBlock(channels.get(3))).asVector();
         assert futureMaxTimestamps != null : "expected future max timestamps vector in time-series aggregation";
-
+        // Only position 0 is read: this assumes every position of the page belongs to the same slice.
+        int sliceIndex = sliceIndices.getInt(0);
+        if (sliceIndex > lastSliceIndex) {
+            flushRawBuffers();
+            lastSliceIndex = sliceIndex;
+        }
         return new AddInput() {
             @Override
             public void add(int positionOffset, IntArrayBlock groupIds) {
@@ -400,12 +410,40 @@ public final class Rate$Type$GroupingAggregatorFunction implements GroupingAggre
         if (buffer == null) {
             buffer = new Buffer(bigArrays, newElements);
             buffers.set(groupId, buffer);
+            minRawInputGroupId = Math.min(minRawInputGroupId, groupId);
+            maxRawInputGroupId = Math.max(maxRawInputGroupId, groupId);
         } else {
             buffer.ensureCapacity(bigArrays, newElements, firstTimestamp);
         }
         return buffer;
     }
 
+    /**
+     * Flushes every raw buffer in the tracked group id range into the reduced state of its group, closes the buffer,
+     * and resets the tracked range. Does nothing when no buffer was created since the previous flush.
+     */
+    void flushRawBuffers() {
+        if (minRawInputGroupId > maxRawInputGroupId) {
+            return;
+        }
+        reducedStates = bigArrays.grow(reducedStates, maxRawInputGroupId + 1);
+        for (int groupId = minRawInputGroupId; groupId <= maxRawInputGroupId; groupId++) {
+            Buffer buffer = buffers.getAndSet(groupId, null);
+            if (buffer != null) {
+                try (buffer) {
+                    ReducedState state = reducedStates.get(groupId);
+                    if (state == null) {
+                        state = new ReducedState();
+                        reducedStates.set(groupId, state);
+                    }
+                    buffer.flush(state);
+                }
+            }
+        }
+        minRawInputGroupId = Integer.MAX_VALUE;
+        maxRawInputGroupId = Integer.MIN_VALUE;
+    }
+
     /**
      * Buffers data points in two arrays: one for timestamps and one for values, partitioned into multiple slices.
      * Each slice is sorted in descending order of timestamp. A new slice is created when a data point has a
diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/RateDoubleGroupingAggregatorFunctionTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/RateDoubleGroupingAggregatorFunctionTests.java
new file mode 100644
index 0000000000000..d512350d897d5
--- /dev/null
+++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/RateDoubleGroupingAggregatorFunctionTests.java
@@ -0,0 +1,114 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.compute.aggregation;
+
+import org.elasticsearch.compute.aggregation.blockhash.BlockHash;
+import org.elasticsearch.compute.data.BlockFactory;
+import org.elasticsearch.compute.data.DoubleBlock;
+import org.elasticsearch.compute.data.ElementType;
+import org.elasticsearch.compute.data.LongBlock;
+import org.elasticsearch.compute.data.Page;
+import org.elasticsearch.compute.operator.Driver;
+import org.elasticsearch.compute.operator.DriverContext;
+import org.elasticsearch.compute.operator.HashAggregationOperator;
+import org.elasticsearch.compute.operator.PageConsumerOperator;
+import org.elasticsearch.compute.test.CannedSourceOperator;
+import org.elasticsearch.compute.test.ComputeTestCase;
+import org.elasticsearch.compute.test.OperatorTestCase;
+import org.elasticsearch.compute.test.TestDriverFactory;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import static org.hamcrest.Matchers.equalTo;
+
+public class RateDoubleGroupingAggregatorFunctionTests extends ComputeTestCase {
+    protected final DriverContext driverContext() {
+        BlockFactory blockFactory = blockFactory();
+        return new DriverContext(blockFactory.bigArrays(), blockFactory);
+    }
+
+    /**
+     * Feeds one page per slice index, all for a single group, and checks that the intermediate output holds exactly two
+     * samples per slice (the latest and the earliest), which is what flushing on a slice change is expected to produce.
+     */
+    public void testFlushOnSliceChanged() {
+        DriverContext driverContext = driverContext();
+        List<Page> pages = new ArrayList<>();
+        int numIntervals = between(1, 10);
+        record Interval(long t1, double v1, long t2, double v2) {}
+        List<Interval> intervals = new ArrayList<>();
+        for (int interval = 0; interval < numIntervals; interval++) {
+            int positions = between(1, 100);
+            long timestamp = between(1, 1000);
+            long value = between(1, 10);
+            long[] values = new long[positions];
+            long[] timestamps = new long[positions];
+            for (int p = 0; p < positions; p++) {
+                values[p] = value;
+                timestamps[p] = timestamp;
+                value += between(1, 10);
+                timestamp += between(1, 10);
+            }
+            intervals.add(new Interval(timestamps[positions - 1], values[positions - 1], timestamps[0], values[0]));
+            BlockFactory blockFactory = blockFactory();
+            try (
+                var valuesBuilder = blockFactory.newDoubleBlockBuilder(positions);
+                var timestampsBuilder = blockFactory.newLongBlockBuilder(positions)
+            ) {
+                for (int p = 0; p < positions; p++) {
+                    valuesBuilder.appendDouble(values[positions - p - 1]);
+                    timestampsBuilder.appendLong(timestamps[positions - p - 1]);
+                }
+                pages.add(
+                    new Page(
+                        blockFactory.newConstantIntBlockWith(0, positions),
+                        valuesBuilder.build(),
+                        timestampsBuilder.build(),
+                        blockFactory.newConstantIntBlockWith(interval, positions),
+                        blockFactory.newConstantLongBlockWith(Long.MAX_VALUE, positions)
+                    )
+                );
+            }
+        }
+        // values, timestamps, slice, future_timestamps
+        var aggregatorFactory = new RateDoubleGroupingAggregatorFunction.FunctionSupplier(false, false).groupingAggregatorFactory(
+            AggregatorMode.INITIAL,
+            List.of(1, 2, 3, 4)
+        );
+        final List<BlockHash.GroupSpec> groupSpecs = List.of(new BlockHash.GroupSpec(0, ElementType.INT));
+        HashAggregationOperator hashAggregationOperator = new HashAggregationOperator(
+            List.of(aggregatorFactory),
+            () -> BlockHash.build(groupSpecs, driverContext.blockFactory(), randomIntBetween(1, 1024), randomBoolean()),
+            driverContext
+        );
+        List<Page> outputPages = new ArrayList<>();
+        Driver driver = TestDriverFactory.create(
+            driverContext,
+            new CannedSourceOperator(pages.iterator()),
+            List.of(hashAggregationOperator),
+            new PageConsumerOperator(outputPages::add)
+        );
+        OperatorTestCase.runDriver(driver);
+        for (Page out : outputPages) {
+            assertThat(out.getPositionCount(), equalTo(1));
+            LongBlock timestamps = out.getBlock(1);
+            DoubleBlock values = out.getBlock(2);
+            assertThat(values.getValueCount(0), equalTo(numIntervals * 2));
+            assertThat(timestamps.getValueCount(0), equalTo(numIntervals * 2));
+            for (int i = 0; i < numIntervals; i++) {
+                Interval interval = intervals.get(i);
+                assertThat(timestamps.getLong(2 * i), equalTo(interval.t1));
+                assertThat(values.getDouble(2 * i), equalTo(interval.v1));
+                assertThat(timestamps.getLong(2 * i + 1), equalTo(interval.t2));
+                assertThat(values.getDouble(2 * i + 1), equalTo(interval.v2));
+            }
+            out.close();
+        }
+    }
+}