From 8251782b9501b052f6adfb3d747611728f150a50 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Iv=C3=A1n=20Cea=20Fontenla?= Date: Tue, 3 Mar 2026 15:02:14 +0100 Subject: [PATCH 01/22] ESQL: Added GroupedTopNOperator for LIMIT BY, compute only --- .../org/elasticsearch/core/Releasables.java | 2 +- .../org/elasticsearch/test/ESTestCase.java | 5 + .../compute/operator/PositionKeyEncoder.java | 146 ++ .../compute/operator/topn/GroupedQueue.java | 209 +++ .../compute/operator/topn/GroupedRow.java | 114 ++ .../operator/topn/GroupedRowFiller.java | 90 ++ .../operator/topn/GroupedTopNOperator.java | 401 ++++++ .../topn/GroupedTopNOperatorStatus.java | 194 +++ .../operator/topn/GroupedQueueTests.java | 246 ++++ .../operator/topn/GroupedRowTests.java | 88 ++ .../topn/GroupedTopNOperatorStatusTests.java | 118 ++ .../topn/GroupedTopNOperatorTests.java | 683 ++++++++++ .../operator/topn/TopNOperatorTests.java | 1195 ++++++++++++----- .../compute/test/TestBlockBuilder.java | 68 + .../test/TypedAbstractBlockSourceBuilder.java | 22 + .../ListRowsBlockSourceOperator.java | 10 +- .../TupleAbstractBlockSourceOperator.java | 4 +- 17 files changed, 3284 insertions(+), 311 deletions(-) create mode 100644 x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/PositionKeyEncoder.java create mode 100644 x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/GroupedQueue.java create mode 100644 x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/GroupedRow.java create mode 100644 x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/GroupedRowFiller.java create mode 100644 x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/GroupedTopNOperator.java create mode 100644 x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/GroupedTopNOperatorStatus.java create mode 100644 
x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/GroupedQueueTests.java create mode 100644 x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/GroupedRowTests.java create mode 100644 x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/GroupedTopNOperatorStatusTests.java create mode 100644 x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/GroupedTopNOperatorTests.java create mode 100644 x-pack/plugin/esql/compute/test/src/main/java/org/elasticsearch/compute/test/TypedAbstractBlockSourceBuilder.java diff --git a/libs/core/src/main/java/org/elasticsearch/core/Releasables.java b/libs/core/src/main/java/org/elasticsearch/core/Releasables.java index 385258b5c10c3..8618d3134b8e4 100644 --- a/libs/core/src/main/java/org/elasticsearch/core/Releasables.java +++ b/libs/core/src/main/java/org/elasticsearch/core/Releasables.java @@ -67,7 +67,7 @@ public static void closeExpectNoException(Releasable... releasables) { } /** Release the provided {@link Releasable} expecting no exception to by thrown. */ - public static void closeExpectNoException(Releasable releasable) { + public static void closeExpectNoException(@Nullable Releasable releasable) { try { close(releasable); } catch (RuntimeException e) { diff --git a/test/framework/src/main/java/org/elasticsearch/test/ESTestCase.java b/test/framework/src/main/java/org/elasticsearch/test/ESTestCase.java index d7d7ffae3d414..ad64dc94509f0 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/ESTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/test/ESTestCase.java @@ -1190,6 +1190,11 @@ public static long randomLong() { return random().nextLong(); } + /** A random long from 0..max (inclusive). 
*/ + public static long randomLong(long max) { + return RandomNumbers.randomLongBetween(random(), 0L, max); + } + public static LongStream randomLongs() { return random().longs(); } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/PositionKeyEncoder.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/PositionKeyEncoder.java new file mode 100644 index 0000000000000..eb1a71553ac71 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/PositionKeyEncoder.java @@ -0,0 +1,146 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.operator; + +import org.apache.lucene.util.Accountable; +import org.apache.lucene.util.BytesRef; +import org.apache.lucene.util.BytesRefBuilder; +import org.apache.lucene.util.RamUsageEstimator; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BooleanBlock; +import org.elasticsearch.compute.data.BytesRefBlock; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.FloatBlock; +import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.Page; + +import java.util.List; + +/** + * Encodes the values at a given position across multiple blocks into a single {@link BytesRef} composite key. + * Multivalued positions are serialized with list semantics: the value count is written first, then each value + * in block iteration order. This means {@code [1, 2]} and {@code [2, 1]} produce different keys. + * Null positions are encoded as a value count of zero. 
+ */ +public class PositionKeyEncoder implements Accountable { + + private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(PositionKeyEncoder.class); + + private final int[] groupChannels; + private final ElementType[] elementTypes; + private final BytesRefBuilder scratch = new BytesRefBuilder(); + private final BytesRef scratchBytesRef = new BytesRef(); + + public PositionKeyEncoder(int[] groupChannels, List elementTypes) { + this.groupChannels = groupChannels; + this.elementTypes = new ElementType[groupChannels.length]; + for (int i = 0; i < groupChannels.length; i++) { + this.elementTypes[i] = elementTypes.get(groupChannels[i]); + } + } + + /** + * Encode the group key for the given position from the page into a {@link BytesRef}. + * The returned reference is only valid until the next call to {@code encode}. + */ + public BytesRef encode(Page page, int position) { + scratch.clear(); + for (int i = 0; i < groupChannels.length; i++) { + Block block = page.getBlock(groupChannels[i]); + encodeBlock(block, elementTypes[i], position); + } + return scratch.get(); + } + + private void encodeBlock(Block block, ElementType type, int position) { + if (block.isNull(position)) { + writeVInt(0); + return; + } + int firstValueIndex = block.getFirstValueIndex(position); + int valueCount = block.getValueCount(position); + writeVInt(valueCount); + switch (type) { + case INT -> { + IntBlock b = (IntBlock) block; + for (int v = 0; v < valueCount; v++) { + writeInt(b.getInt(firstValueIndex + v)); + } + } + case LONG -> { + LongBlock b = (LongBlock) block; + for (int v = 0; v < valueCount; v++) { + writeLong(b.getLong(firstValueIndex + v)); + } + } + case DOUBLE -> { + DoubleBlock b = (DoubleBlock) block; + for (int v = 0; v < valueCount; v++) { + writeLong(Double.doubleToLongBits(b.getDouble(firstValueIndex + v))); + } + } + case FLOAT -> { + FloatBlock b = (FloatBlock) block; + for (int v = 0; v < valueCount; v++) { + 
writeInt(Float.floatToIntBits(b.getFloat(firstValueIndex + v))); + } + } + case BOOLEAN -> { + BooleanBlock b = (BooleanBlock) block; + for (int v = 0; v < valueCount; v++) { + scratch.append((byte) (b.getBoolean(firstValueIndex + v) ? 1 : 0)); + } + } + case BYTES_REF -> { + BytesRefBlock b = (BytesRefBlock) block; + for (int v = 0; v < valueCount; v++) { + BytesRef ref = b.getBytesRef(firstValueIndex + v, scratchBytesRef); + writeVInt(ref.length); + scratch.append(ref.bytes, ref.offset, ref.length); + } + } + case NULL -> { + // already handled by isNull above; nothing extra to write + } + default -> throw new IllegalArgumentException("unsupported element type for group key encoding: " + type); + } + } + + private void writeVInt(int value) { + while ((value & ~0x7F) != 0) { + scratch.append((byte) ((value & 0x7F) | 0x80)); + value >>>= 7; + } + scratch.append((byte) value); + } + + private void writeInt(int value) { + scratch.append((byte) (value >> 24)); + scratch.append((byte) (value >> 16)); + scratch.append((byte) (value >> 8)); + scratch.append((byte) value); + } + + private void writeLong(long value) { + writeInt((int) (value >> 32)); + writeInt((int) value); + } + + @Override + public long ramBytesUsed() { + long size = SHALLOW_SIZE; + size += RamUsageEstimator.sizeOf(groupChannels); + size += RamUsageEstimator.shallowSizeOf(elementTypes); + size += RamUsageEstimator.shallowSizeOfInstance(BytesRefBuilder.class); + size += RamUsageEstimator.sizeOf(scratch.bytes()); + size += RamUsageEstimator.shallowSizeOfInstance(BytesRef.class); + return size; + } +} diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/GroupedQueue.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/GroupedQueue.java new file mode 100644 index 0000000000000..d0b0718bf5ceb --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/GroupedQueue.java @@ -0,0 +1,209 @@ +/* + 
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.operator.topn; + +import org.apache.lucene.util.Accountable; +import org.apache.lucene.util.PriorityQueue; +import org.apache.lucene.util.RamUsageEstimator; +import org.elasticsearch.common.breaker.CircuitBreaker; +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.common.util.ObjectArray; +import org.elasticsearch.core.Releasable; +import org.elasticsearch.core.Releasables; + +import java.util.ArrayList; +import java.util.List; + +import static org.apache.lucene.util.RamUsageEstimator.shallowSizeOfInstance; + +/** + * A queue that maintains a separate per-group priority queue, indexed by integer group IDs + * assigned by a {@link org.elasticsearch.compute.aggregation.blockhash.BlockHash}. + * Uses a {@link BigArrays}-backed {@link ObjectArray} for better performance and circuit + * breaker integration. 
+ */ +class GroupedQueue implements Accountable, Releasable { + private static final long SHALLOW_SIZE = shallowSizeOfInstance(GroupedQueue.class); + + private final CircuitBreaker breaker; + private final BigArrays bigArrays; + private final int topCount; + private ObjectArray queues; + + GroupedQueue(CircuitBreaker breaker, BigArrays bigArrays, int topCount) { + this.breaker = breaker; + this.bigArrays = bigArrays; + this.topCount = topCount; + this.queues = bigArrays.newObjectArray(0); + } + + @Override + public String toString() { + return size() + "/" + queues.size() + "/" + topCount; + } + + int size() { + int totalSize = 0; + for (long i = 0; i < queues.size(); i++) { + PerGroupQueue queue = queues.get(i); + if (queue != null) { + totalSize += queue.size(); + } + } + return totalSize; + } + + /** + * Attempts to add the row to the appropriate per-group queue based on {@link GroupedRow#groupId}. + * @return If the row was added and the queue was full, the evicted row. + * If the row was added and it wasn't full, {@code null}. + * If the row wasn't added, the input row. + */ + GroupedRow addRow(GroupedRow row) { + return getOrCreateQueue(row.groupId).addRow(row); + } + + private PerGroupQueue getOrCreateQueue(long groupId) { + if (groupId >= queues.size()) { + queues = bigArrays.grow(queues, groupId + 1); + } + PerGroupQueue queue = queues.get(groupId); + if (queue == null) { + queue = PerGroupQueue.build(breaker, topCount); + queues.set(groupId, queue); + } + return queue; + } + + /** + * Removes and returns all rows from all per-group queues. + * For an ascending order, the first element will be the min element (or last in the + * priority queue), and vice versa. 
+ */ + List popAll() { + List allRows = new ArrayList<>(size()); + for (long i = 0; i < queues.size(); i++) { + PerGroupQueue queue = queues.get(i); + if (queue != null) { + queue.popAllInto(allRows); + queue.close(); + queues.set(i, null); + } + } + allRows.sort((r1, r2) -> -r1.compareTo(r2)); + return allRows; + } + + @Override + public long ramBytesUsed() { + long total = SHALLOW_SIZE; + if (queues != null) { + total += queues.ramBytesUsed(); + for (long i = 0; i < queues.size(); i++) { + PerGroupQueue queue = queues.get(i); + if (queue != null) { + total += queue.ramBytesUsed(); + } + } + } + return total; + } + + @Override + public void close() { + Releasables.close(() -> { + if (queues != null) { + for (long i = 0; i < queues.size(); i++) { + PerGroupQueue queue = queues.get(i); + if (queue != null) { + queue.close(); + queues.set(i, null); + } + } + } + }, queues); + } + + /** + * A single-group priority queue backed by Lucene's PriorityQueue. + */ + static final class PerGroupQueue extends PriorityQueue implements Accountable, Releasable { + private static final long SHALLOW_SIZE = shallowSizeOfInstance(PerGroupQueue.class); + + private final CircuitBreaker breaker; + private final int topCount; + + private PerGroupQueue(CircuitBreaker breaker, int topCount) { + super(topCount); + this.topCount = topCount; + this.breaker = breaker; + } + + static PerGroupQueue build(CircuitBreaker breaker, int topCount) { + breaker.addEstimateBytesAndMaybeBreak(sizeOf(topCount), "topn"); + return new PerGroupQueue(breaker, topCount); + } + + @Override + protected boolean lessThan(GroupedRow lhs, GroupedRow rhs) { + return lhs.compareTo(rhs) < 0; + } + + GroupedRow addRow(GroupedRow row) { + if (size() < topCount) { + add(row); + return null; + } else if (lessThan(top(), row)) { + GroupedRow evicted = top(); + updateTop(row); + return evicted; + } + return row; + } + + void popAllInto(List target) { + while (size() > 0) { + target.add(pop()); + } + } + + @Override + public 
long ramBytesUsed() { + long total = SHALLOW_SIZE; + total += RamUsageEstimator.alignObjectSize( + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER + RamUsageEstimator.NUM_BYTES_OBJECT_REF * ((long) topCount + 1) + ); + for (GroupedRow r : this) { + total += r == null ? 0 : r.ramBytesUsed(); + } + return total; + } + + @Override + public void close() { + Releasables.close(() -> { + var heapArray = getHeapArray(); + for (int i = 0; i < heapArray.length; i++) { + GroupedRow row = (GroupedRow) heapArray[i]; + if (row != null) { + row.close(); + heapArray[i] = null; + } + } + }, () -> breaker.addWithoutBreaking(-sizeOf(topCount))); + } + + private static long sizeOf(int topCount) { + long total = SHALLOW_SIZE; + total += RamUsageEstimator.alignObjectSize( + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER + RamUsageEstimator.NUM_BYTES_OBJECT_REF * (topCount + 1L) + ); + return total; + } + } +} diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/GroupedRow.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/GroupedRow.java new file mode 100644 index 0000000000000..8c139f21c00ed --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/GroupedRow.java @@ -0,0 +1,114 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.compute.operator.topn; + +import org.apache.lucene.util.Accountable; +import org.apache.lucene.util.RamUsageEstimator; +import org.elasticsearch.common.breaker.CircuitBreaker; +import org.elasticsearch.compute.operator.BreakingBytesRefBuilder; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.RefCounted; +import org.elasticsearch.core.Releasable; +import org.elasticsearch.core.Releasables; + +import java.util.Arrays; + +/** + * A row that belongs to a group, identified by an integer group ID. + * Stores encoded sort keys and values for a single row within a grouped top-N operation. + */ +final class GroupedRow implements Accountable, Comparable, Releasable { + private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(GroupedRow.class); + + private final CircuitBreaker breaker; + + /** + * The sort keys, encoded into bytes so we can sort by calling {@link Arrays#compareUnsigned}. + */ + private final BreakingBytesRefBuilder keys; + + /** + * Values to reconstruct the row. When we reconstruct the row we read from both the + * {@link #keys} and the {@link #values}. So this only contains what is required to + * reconstruct the row that isn't already stored in {@link #keys}. + */ + private final BreakingBytesRefBuilder values; + + /** + * Reference counter for the shard this row belongs to, used for rows containing a DocVector + * to ensure the shard context lives until we build the final result. + */ + @Nullable + private RefCounted shardRefCounter; + + /** + * The group ID this row belongs to. 
+ */ + long groupId = -1; + + GroupedRow(CircuitBreaker breaker, int preAllocatedKeysSize, int preAllocatedValueSize) { + breaker.addEstimateBytesAndMaybeBreak(SHALLOW_SIZE, "GroupedRow"); + this.breaker = breaker; + boolean success = false; + try { + keys = new BreakingBytesRefBuilder(breaker, "topn", preAllocatedKeysSize); + values = new BreakingBytesRefBuilder(breaker, "topn", preAllocatedValueSize); + success = true; + } finally { + if (success == false) { + close(); + } + } + } + + BreakingBytesRefBuilder keys() { + return keys; + } + + BreakingBytesRefBuilder values() { + return values; + } + + void setShardRefCounted(RefCounted shardRefCounted) { + if (this.shardRefCounter != null) { + this.shardRefCounter.decRef(); + } + this.shardRefCounter = shardRefCounted; + this.shardRefCounter.mustIncRef(); + } + + void clear() { + keys.clear(); + values.clear(); + clearRefCounters(); + groupId = -1; + } + + private void clearRefCounters() { + if (shardRefCounter != null) { + shardRefCounter.decRef(); + } + shardRefCounter = null; + } + + @Override + public int compareTo(GroupedRow other) { + return -keys.bytesRefView().compareTo(other.keys.bytesRefView()); + } + + @Override + public long ramBytesUsed() { + return SHALLOW_SIZE + keys.ramBytesUsed() + values.ramBytesUsed(); + } + + @Override + public void close() { + clearRefCounters(); + Releasables.closeExpectNoException(() -> breaker.addWithoutBreaking(-SHALLOW_SIZE), keys, values); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/GroupedRowFiller.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/GroupedRowFiller.java new file mode 100644 index 0000000000000..c8c53bbb156ac --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/GroupedRowFiller.java @@ -0,0 +1,90 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. 
under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.operator.topn; + +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.BreakingBytesRefBuilder; + +import java.util.List; + +/** + * Fills {@link GroupedRow}s from page data for grouped top-N. Handles both sort-key encoding + * and value extraction. The group ID is set directly by the caller from the BlockHash callback. + */ +final class GroupedRowFiller { + private final ValueExtractor[] valueExtractors; + private final KeyExtractor[] sortKeyExtractors; + + private int keyPreAllocSize = 0; + private int valuePreAllocSize = 0; + + GroupedRowFiller( + List elementTypes, + List encoders, + List sortOrders, + boolean[] channelInKey, + Page page + ) { + valueExtractors = new ValueExtractor[page.getBlockCount()]; + for (int b = 0; b < valueExtractors.length; b++) { + valueExtractors[b] = ValueExtractor.extractorFor( + elementTypes.get(b), + encoders.get(b).toUnsortable(), + channelInKey[b], + page.getBlock(b) + ); + } + sortKeyExtractors = new KeyExtractor[sortOrders.size()]; + for (int k = 0; k < sortKeyExtractors.length; k++) { + TopNOperator.SortOrder so = sortOrders.get(k); + sortKeyExtractors[k] = KeyExtractor.extractorFor( + elementTypes.get(so.channel()), + encoders.get(so.channel()), + so.asc(), + so.nul(), + so.nonNul(), + page.getBlock(so.channel()) + ); + } + } + + int preAllocatedKeysSize() { + return keyPreAllocSize; + } + + int preAllocatedValueSize() { + return valuePreAllocSize; + } + + void writeSortKey(int position, GroupedRow row) { + for (KeyExtractor keyExtractor : sortKeyExtractors) { + keyExtractor.writeKey(row.keys(), position); + } + keyPreAllocSize = newPreAllocSize(row.keys(), keyPreAllocSize); + } + + void writeValues(int position, 
GroupedRow row) { + for (ValueExtractor e : valueExtractors) { + var refCounted = e.getRefCountedForShard(position); + if (refCounted != null) { + row.setShardRefCounted(refCounted); + } + e.writeValue(row.values(), position); + } + valuePreAllocSize = newPreAllocSize(row.values(), valuePreAllocSize); + } + + /** + * Pre-allocation size heuristic: use the larger of the current builder length and half + * the previous pre-alloc size, so the size decays after a single unusually large row. + */ + private static int newPreAllocSize(BreakingBytesRefBuilder builder, int sparePreAllocSize) { + return Math.max(builder.length(), sparePreAllocSize / 2); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/GroupedTopNOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/GroupedTopNOperator.java new file mode 100644 index 0000000000000..4a69f5d337430 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/GroupedTopNOperator.java @@ -0,0 +1,401 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.compute.operator.topn; + +import org.apache.lucene.util.Accountable; +import org.apache.lucene.util.BytesRef; +import org.apache.lucene.util.RamUsageEstimator; +import org.elasticsearch.common.breaker.CircuitBreaker; +import org.elasticsearch.common.util.BytesRefHashTable; +import org.elasticsearch.compute.aggregation.blockhash.BlockHash; +import org.elasticsearch.compute.aggregation.blockhash.HashImplFactory; +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.compute.operator.Operator; +import org.elasticsearch.compute.operator.PositionKeyEncoder; +import org.elasticsearch.core.ReleasableIterator; +import org.elasticsearch.core.Releasables; + +import java.util.Arrays; +import java.util.List; + +/** + * A top-N operator for grouped (SORT + LIMIT BY) queries. Maintains per-group priority queues + * using a {@link PositionKeyEncoder} to map group key columns to integer group IDs. + *

+ * Group keys use list semantics for multivalues: {@code [1,2]} and {@code [2,1]} are different groups. + *

+ * Unlike {@link TopNOperator}, this operator does not support sorted input optimization + * or {@link SharedMinCompetitive} tracking, as these optimizations are not applicable + * to grouped top-N. + */ +public class GroupedTopNOperator implements Operator, Accountable { + + private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(GroupedTopNOperator.class) + RamUsageEstimator + .shallowSizeOfInstance(List.class) * 3; + + private static final long SORT_ORDER_SIZE = RamUsageEstimator.shallowSizeOfInstance(TopNOperator.SortOrder.class); + + public record GroupedTopNOperatorFactory( + int topCount, + List elementTypes, + List encoders, + List sortOrders, + List groupKeys, + int maxPageSize, + long jumboPageBytes + ) implements OperatorFactory { + public GroupedTopNOperatorFactory { + for (ElementType e : elementTypes) { + if (e == null) { + throw new IllegalArgumentException("ElementType not known"); + } + } + if (groupKeys.isEmpty()) { + throw new IllegalArgumentException("GroupedTopNOperator requires at least one group key"); + } + } + + @Override + public GroupedTopNOperator get(DriverContext driverContext) { + return new GroupedTopNOperator( + driverContext.blockFactory(), + driverContext.breaker(), + topCount, + elementTypes, + encoders, + sortOrders, + groupKeys.stream().mapToInt(Integer::intValue).toArray(), + maxPageSize, + jumboPageBytes + ); + } + + @Override + public String describe() { + return "GroupedTopNOperator[count=" + + topCount + + ", elementTypes=" + + elementTypes + + ", encoders=" + + encoders + + ", sortOrders=" + + sortOrders + + ", groupKeys=" + + groupKeys + + "]"; + } + } + + private final BlockFactory blockFactory; + private final CircuitBreaker breaker; + private final int maxPageSize; + private final long jumboPageBytes; + private final int topCount; + private final List elementTypes; + private final List encoders; + private final List sortOrders; + private final int[] groupKeys; + private final boolean[] 
channelInKey; + private final PositionKeyEncoder keyEncoder; + + private BytesRefHashTable keysHash; + private GroupedQueue inputQueue; + private GroupedRow spare; + + private ReleasableIterator output; + + private long receiveNanos; + private long emitNanos; + private int pagesReceived; + private int pagesEmitted; + private long rowsReceived; + private long rowsEmitted; + + public GroupedTopNOperator( + BlockFactory blockFactory, + CircuitBreaker breaker, + int topCount, + List elementTypes, + List encoders, + List sortOrders, + int[] groupKeys, + int maxPageSize, + long jumboPageBytes + ) { + BytesRefHashTable keysHash = null; + GroupedQueue inputQueue = null; + boolean success = false; + try { + keysHash = HashImplFactory.newBytesRefHash(blockFactory); + inputQueue = new GroupedQueue(breaker, blockFactory.bigArrays(), topCount); + success = true; + } finally { + if (success == false) { + Releasables.close(keysHash, inputQueue); + } + } + this.keyEncoder = new PositionKeyEncoder(groupKeys, elementTypes); + this.keysHash = keysHash; + this.inputQueue = inputQueue; + this.blockFactory = blockFactory; + this.breaker = breaker; + this.maxPageSize = maxPageSize; + this.jumboPageBytes = jumboPageBytes; + this.topCount = topCount; + this.elementTypes = elementTypes; + this.encoders = encoders; + this.sortOrders = sortOrders; + this.groupKeys = groupKeys; + this.channelInKey = new boolean[elementTypes.size()]; + for (TopNOperator.SortOrder so : sortOrders) { + channelInKey[so.channel()] = true; + } + } + + @Override + public boolean needsInput() { + return output == null; + } + + @Override + public void addInput(Page page) { + long start = System.nanoTime(); + try { + if (this.topCount <= 0) { + return; + } + GroupedRowFiller rowFiller = new GroupedRowFiller(elementTypes, encoders, sortOrders, channelInKey, page); + for (int pos = 0; pos < page.getPositionCount(); pos++) { + BytesRef key = keyEncoder.encode(page, pos); + long hashOrd = keysHash.add(key); + long groupId = 
BlockHash.hashOrdToGroup(hashOrd); + processRow(rowFiller, pos, groupId); + } + } finally { + page.releaseBlocks(); + pagesReceived++; + rowsReceived += page.getPositionCount(); + receiveNanos += System.nanoTime() - start; + } + } + + private void processRow(GroupedRowFiller rowFiller, int position, long groupId) { + if (spare == null) { + spare = new GroupedRow(breaker, rowFiller.preAllocatedKeysSize(), rowFiller.preAllocatedValueSize()); + } else { + spare.clear(); + } + spare.groupId = groupId; + rowFiller.writeSortKey(position, spare); + + var nextSpare = inputQueue.addRow(spare); + if (nextSpare != spare) { + var insertedRow = spare; + spare = nextSpare; + rowFiller.writeValues(position, insertedRow); + } + } + + @Override + public void finish() { + if (output == null) { + long start = System.nanoTime(); + output = buildResult(); + emitNanos += System.nanoTime() - start; + } + } + + @Override + public boolean isFinished() { + return output != null && output.hasNext() == false; + } + + @Override + public boolean canProduceMoreDataWithoutExtraInput() { + return output != null && output.hasNext(); + } + + @Override + public Page getOutput() { + if (output == null || output.hasNext() == false) { + return null; + } + Page ret = output.next(); + pagesEmitted++; + rowsEmitted += ret.getPositionCount(); + return ret; + } + + @Override + public void close() { + Releasables.closeExpectNoException(spare, inputQueue, output, keysHash); + inputQueue = null; + output = null; + } + + @Override + public long ramBytesUsed() { + long arrHeader = RamUsageEstimator.NUM_BYTES_ARRAY_HEADER; + long ref = RamUsageEstimator.NUM_BYTES_OBJECT_REF; + long size = SHALLOW_SIZE; + size += RamUsageEstimator.alignObjectSize(arrHeader + ref * elementTypes.size()); + size += RamUsageEstimator.alignObjectSize(arrHeader + ref * encoders.size()); + size += RamUsageEstimator.alignObjectSize(arrHeader + ref * sortOrders.size()); + size += RamUsageEstimator.sizeOf(groupKeys); + size += 
RamUsageEstimator.sizeOf(channelInKey); + size += sortOrders.size() * SORT_ORDER_SIZE; + size += keyEncoder.ramBytesUsed(); + if (keysHash != null) { + size += keysHash.ramBytesUsed(); + } + if (inputQueue != null) { + size += inputQueue.ramBytesUsed(); + } + if (spare != null) { + size += spare.ramBytesUsed(); + } + return size; + } + + @Override + public Status status() { + return new GroupedTopNOperatorStatus( + receiveNanos, + emitNanos, + inputQueue != null ? inputQueue.size() : 0, + keysHash != null ? keysHash.size() : 0, + ramBytesUsed(), + pagesReceived, + pagesEmitted, + rowsReceived, + rowsEmitted + ); + } + + @Override + public String toString() { + return "GroupedTopNOperator[count=" + + inputQueue + + ", elementTypes=" + + elementTypes + + ", encoders=" + + encoders + + ", sortOrders=" + + sortOrders + + ", groupKeys=" + + Arrays.toString(groupKeys) + + "]"; + } + + /** + * Build the result iterator. Moves all rows from the {@link #inputQueue} and + * {@link #close}s it. + */ + private ReleasableIterator buildResult() { + if (spare != null) { + spare.close(); + spare = null; + } + + if (inputQueue.size() == 0) { + return ReleasableIterator.empty(); + } + + List rows = inputQueue.popAll(); + inputQueue.close(); + keysHash.close(); + inputQueue = null; + keysHash = null; + return new Result(rows); + } + + private class Result implements ReleasableIterator { + private final List rows; + private int r; + + private Result(List rows) { + this.rows = rows; + } + + @Override + public boolean hasNext() { + return r < rows.size(); + } + + @Override + public Page next() { + long start = System.nanoTime(); + int size = Math.min(maxPageSize, rows.size() - r); + if (size <= 0) { + throw new IllegalStateException("can't make empty pages. 
" + size + " must be > 0"); + } + ResultBuilder[] builders = new ResultBuilder[elementTypes.size()]; + try { + for (int b = 0; b < builders.length; b++) { + builders[b] = ResultBuilder.resultBuilderFor(blockFactory, elementTypes.get(b), encoders.get(b), channelInKey[b], size); + } + int rEnd = r + size; + while (r < rEnd) { + try (GroupedRow row = rows.set(r++, null)) { + readKeys(builders, row.keys().bytesRefView()); + readValues(builders, row.values().bytesRefView()); + } + if (totalSize(builders) > jumboPageBytes) { + break; + } + } + + return new Page(ResultBuilder.buildAll(builders)); + } finally { + Releasables.close(builders); + emitNanos += System.nanoTime() - start; + } + } + + private long totalSize(ResultBuilder[] builders) { + long total = 0; + for (ResultBuilder b : builders) { + total += b.estimatedBytes(); + } + return total; + } + + @Override + public void close() { + Releasables.close(rows); + } + + private void readKeys(ResultBuilder[] builders, BytesRef keys) { + for (TopNOperator.SortOrder so : sortOrders) { + if (keys.bytes[keys.offset] == so.nul()) { + keys.offset++; + keys.length--; + continue; + } + keys.offset++; + keys.length--; + builders[so.channel()].decodeKey(keys, so.asc()); + } + if (keys.length != 0) { + throw new IllegalArgumentException("didn't read all keys"); + } + } + + private void readValues(ResultBuilder[] builders, BytesRef values) { + for (ResultBuilder builder : builders) { + builder.decodeValue(values); + } + if (values.length != 0) { + throw new IllegalArgumentException("didn't read all values"); + } + } + } +} diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/GroupedTopNOperatorStatus.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/GroupedTopNOperatorStatus.java new file mode 100644 index 0000000000000..6e641c1316fa9 --- /dev/null +++ 
b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/GroupedTopNOperatorStatus.java @@ -0,0 +1,194 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.operator.topn; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.unit.ByteSizeValue; +import org.elasticsearch.compute.operator.Operator; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.Objects; + +public class GroupedTopNOperatorStatus implements Operator.Status { + public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry( + Operator.Status.class, + "groupedtopn", + GroupedTopNOperatorStatus::new + ); + + private final long receiveNanos; + private final long emitNanos; + private final int occupiedRows; + private final long groupCount; + private final long ramBytesUsed; + private final int pagesReceived; + private final int pagesEmitted; + private final long rowsReceived; + private final long rowsEmitted; + + public GroupedTopNOperatorStatus( + long receiveNanos, + long emitNanos, + int occupiedRows, + long groupCount, + long ramBytesUsed, + int pagesReceived, + int pagesEmitted, + long rowsReceived, + long rowsEmitted + ) { + this.receiveNanos = receiveNanos; + this.emitNanos = emitNanos; + this.occupiedRows = occupiedRows; + this.groupCount = groupCount; + this.ramBytesUsed = ramBytesUsed; + this.pagesReceived = pagesReceived; + this.pagesEmitted = pagesEmitted; + 
this.rowsReceived = rowsReceived; + this.rowsEmitted = rowsEmitted; + } + + GroupedTopNOperatorStatus(StreamInput in) throws IOException { + this.receiveNanos = in.readVLong(); + this.emitNanos = in.readVLong(); + this.occupiedRows = in.readVInt(); + this.groupCount = in.readVLong(); + this.ramBytesUsed = in.readVLong(); + + this.pagesReceived = in.readVInt(); + this.pagesEmitted = in.readVInt(); + this.rowsReceived = in.readVLong(); + this.rowsEmitted = in.readVLong(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeVLong(receiveNanos); + out.writeVLong(emitNanos); + + out.writeVInt(occupiedRows); + out.writeVLong(groupCount); + out.writeVLong(ramBytesUsed); + + out.writeVInt(pagesReceived); + out.writeVInt(pagesEmitted); + out.writeVLong(rowsReceived); + out.writeVLong(rowsEmitted); + } + + @Override + public String getWriteableName() { + return ENTRY.name; + } + + public long receiveNanos() { + return receiveNanos; + } + + public long emitNanos() { + return emitNanos; + } + + public int occupiedRows() { + return occupiedRows; + } + + public long groupCount() { + return groupCount; + } + + public long ramBytesUsed() { + return ramBytesUsed; + } + + public int pagesReceived() { + return pagesReceived; + } + + public int pagesEmitted() { + return pagesEmitted; + } + + public long rowsReceived() { + return rowsReceived; + } + + public long rowsEmitted() { + return rowsEmitted; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field("receive_nanos", receiveNanos); + if (builder.humanReadable()) { + builder.field("receive_time", TimeValue.timeValueNanos(receiveNanos).toString()); + } + builder.field("emit_nanos", emitNanos); + if (builder.humanReadable()) { + builder.field("emit_time", TimeValue.timeValueNanos(emitNanos).toString()); + } + builder.field("occupied_rows", occupiedRows); + builder.field("group_count", groupCount); 
+ builder.field("ram_bytes_used", ramBytesUsed); + builder.field("ram_used", ByteSizeValue.ofBytes(ramBytesUsed)); + builder.field("pages_received", pagesReceived); + builder.field("pages_emitted", pagesEmitted); + builder.field("rows_received", rowsReceived); + builder.field("rows_emitted", rowsEmitted); + return builder.endObject(); + } + + @Override + public boolean equals(Object o) { + if (o == null || getClass() != o.getClass()) { + return false; + } + GroupedTopNOperatorStatus that = (GroupedTopNOperatorStatus) o; + return receiveNanos == that.receiveNanos + && emitNanos == that.emitNanos + && occupiedRows == that.occupiedRows + && groupCount == that.groupCount + && ramBytesUsed == that.ramBytesUsed + && pagesReceived == that.pagesReceived + && pagesEmitted == that.pagesEmitted + && rowsReceived == that.rowsReceived + && rowsEmitted == that.rowsEmitted; + } + + @Override + public int hashCode() { + return Objects.hash( + receiveNanos, + emitNanos, + occupiedRows, + groupCount, + ramBytesUsed, + pagesReceived, + pagesEmitted, + rowsReceived, + rowsEmitted + ); + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersion.minimumCompatible(); + } + + @Override + public String toString() { + return Strings.toString(this); + } +} diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/GroupedQueueTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/GroupedQueueTests.java new file mode 100644 index 0000000000000..bead85e16c895 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/GroupedQueueTests.java @@ -0,0 +1,246 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.compute.operator.topn; + +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.common.breaker.CircuitBreaker; +import org.elasticsearch.common.unit.ByteSizeValue; +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.common.util.MockBigArrays; +import org.elasticsearch.common.util.PageCacheRecycler; +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.core.Tuple; +import org.elasticsearch.test.ESTestCase; +import org.junit.After; + +import java.util.List; +import java.util.stream.IntStream; + +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.greaterThan; +import static org.hamcrest.Matchers.notNullValue; +import static org.hamcrest.Matchers.nullValue; +import static org.hamcrest.Matchers.sameInstance; + +public class GroupedQueueTests extends ESTestCase { + private final BigArrays bigArrays = new MockBigArrays(PageCacheRecycler.NON_RECYCLING_INSTANCE, ByteSizeValue.ofMb(1)); + private final CircuitBreaker breaker = bigArrays.breakerService().getBreaker(CircuitBreaker.REQUEST); + private final BlockFactory blockFactory = new BlockFactory(breaker, bigArrays); + + @After + public void allMemoryReleased() throws Exception { + MockBigArrays.ensureAllArraysAreReleased(); + + assertThat("Not all memory was released", breaker.getUsed(), equalTo(0L)); + assertThat("Not all blocks were released", blockFactory.breaker().getUsed(), equalTo(0L)); + } + + public void testCleanup() { + int topCount = 5; + try (GroupedQueue queue = new GroupedQueue(breaker, bigArrays, topCount)) { + assertThat(queue.size(), equalTo(0)); + + for (int i = 0; i < topCount * 2; i++) { + addRow(queue, i % 3, i * 10); + } + } + } + + public void testAddWhenHeapNotFull() { + int topCount = 5; 
+ try (GroupedQueue queue = new GroupedQueue(breaker, bigArrays, topCount)) { + for (int i = 0; i < topCount; i++) { + GroupedRow row = createRow(breaker, i % 2, i * 10); + GroupedRow result = queue.addRow(row); + assertThat(result, nullValue()); + assertThat(queue.size(), equalTo(i + 1)); + } + } + } + + public void testAddWhenHeapFullAndRowQualifies() { + int topCount = 3; + try (GroupedQueue queue = new GroupedQueue(breaker, bigArrays, topCount)) { + fillQueueToCapacity(queue, topCount); + + try (GroupedRow evicted = queue.addRow(createRow(breaker, 0, 5))) { + assertRowValues(evicted, 0, 20, 40); + } + } + } + + public void testAddWhenHeapFullAndRowDoesNotQualify() { + try (GroupedQueue queue = new GroupedQueue(breaker, bigArrays, 3)) { + addRows(queue, 0, 30, 40, 50); + + try (GroupedRow row = createRow(breaker, 0, 60)) { + GroupedRow result = queue.addRow(row); + assertThat(result, sameInstance(row)); + } + } + } + + public void testAddWithDifferentGroupKeys() { + try (GroupedQueue queue = new GroupedQueue(breaker, bigArrays, 2)) { + assertThat(queue.addRow(createRow(breaker, 0, 10)), nullValue()); + assertThat(queue.addRow(createRow(breaker, 1, 20)), nullValue()); + assertThat(queue.addRow(createRow(breaker, 0, 30)), nullValue()); + assertThat(queue.addRow(createRow(breaker, 1, 40)), nullValue()); + assertThat(queue.size(), equalTo(4)); + + try (GroupedRow evicted = queue.addRow(createRow(breaker, 0, 5))) { + assertThat(evicted, notNullValue()); + assertRowValues(evicted, 0, 30, 60); + } + try (GroupedRow evicted = queue.addRow(createRow(breaker, 1, 15))) { + assertThat(evicted, notNullValue()); + assertRowValues(evicted, 1, 40, 80); + } + assertThat(queue.size(), equalTo(4)); + + try (GroupedRow row = queue.addRow(createRow(breaker, 0, 50))) { + assertThat(row, notNullValue()); + assertRowValues(row, 0, 50, 100); + } + try (GroupedRow row = queue.addRow(createRow(breaker, 1, 50))) { + assertThat(row, notNullValue()); + assertRowValues(row, 1, 50, 100); + } + 
assertThat(queue.size(), equalTo(4)); + } + } + + public void testRamBytesUsedEmpty() { + try (GroupedQueue queue = new GroupedQueue(breaker, bigArrays, 5)) { + assertRamBytesUsedConsistent(queue); + } + } + + /** + * Verifies that ramBytesUsed() accounts for at least the shallow size and grows with content. + * We can't use RamUsageTester for BigArrays-backed structures due to module access restrictions. + */ + private void assertRamBytesUsedConsistent(GroupedQueue queue) { + long reported = queue.ramBytesUsed(); + assertThat("ramBytesUsed should be positive", reported, greaterThan(0L)); + } + + public void testRamBytesUsedPartiallyFilled() { + try (GroupedQueue queue = new GroupedQueue(breaker, bigArrays, 5)) { + long emptySize = queue.ramBytesUsed(); + addRows(queue, 0, 10, 20, 30); + long filledSize = queue.ramBytesUsed(); + assertThat("RAM usage should grow after adding rows", filledSize, greaterThan(emptySize)); + } + } + + public void testRamBytesUsedAtCapacity() { + try (GroupedQueue queue = new GroupedQueue(breaker, bigArrays, 5)) { + long emptySize = queue.ramBytesUsed(); + addRows(queue, 0, 10, 20, 30, 40, 50); + long oneGroupSize = queue.ramBytesUsed(); + addRows(queue, 1, 10, 20, 30, 40, 50); + addRows(queue, 2, 10, 20, 30, 40, 50); + long threeGroupSize = queue.ramBytesUsed(); + assertThat("RAM should grow with first group", oneGroupSize, greaterThan(emptySize)); + assertThat("RAM should grow with more groups", threeGroupSize, greaterThan(oneGroupSize)); + } + } + + public void testPopAllSortedBySortKey() { + try (GroupedQueue queue = new GroupedQueue(breaker, bigArrays, 5)) { + addRows(queue, 0, 30, 10, 50); + addRows(queue, 1, 20, 40); + addRows(queue, 2, 15, 25, 35); + assertQueueContents( + queue, + List.of( + Tuple.tuple(0, 10), + Tuple.tuple(2, 15), + Tuple.tuple(1, 20), + Tuple.tuple(2, 25), + Tuple.tuple(0, 30), + Tuple.tuple(2, 35), + Tuple.tuple(1, 40), + Tuple.tuple(0, 50) + ) + ); + } + } + + private GroupedRow createRow(CircuitBreaker 
breaker, int groupKey, int sortKey) { + IntBlock groupKeyBlock = blockFactory.newIntBlockBuilder(1).appendInt(groupKey).build(); + IntBlock keyBlock = blockFactory.newIntBlockBuilder(1).appendInt(sortKey).build(); + IntBlock valueBlock = blockFactory.newIntBlockBuilder(1).appendInt(sortKey * 2).build(); + GroupedRow row = new GroupedRow(breaker, 32, 64); + row.groupId = groupKey; + var filler = new GroupedRowFiller( + List.of(ElementType.INT, ElementType.INT, ElementType.INT), + List.of(TopNEncoder.DEFAULT_SORTABLE, TopNEncoder.DEFAULT_SORTABLE, TopNEncoder.DEFAULT_UNSORTABLE), + SORT_ORDERS, + new boolean[] { false, true, false }, + new Page(groupKeyBlock, keyBlock, valueBlock) + ); + try { + filler.writeSortKey(0, row); + filler.writeValues(0, row); + } finally { + Releasables.close(groupKeyBlock, keyBlock, valueBlock); + } + return row; + } + + private static void assertRowValues(GroupedRow row, int expectedGroupKey, int expectedSortKey, int expectedValue) { + assertThat(row.groupId, equalTo(expectedGroupKey)); + + BytesRef keys = row.keys().bytesRefView(); + assertThat( + TopNEncoder.DEFAULT_SORTABLE.decodeInt(new BytesRef(keys.bytes, keys.offset + 1, keys.length - 1)), + equalTo(expectedSortKey) + ); + + BytesRef values = row.values().bytesRefView(); + BytesRef reader = new BytesRef(values.bytes, values.offset, values.length); + assertThat(TopNEncoder.DEFAULT_UNSORTABLE.decodeVInt(reader), equalTo(1)); + TopNEncoder.DEFAULT_UNSORTABLE.decodeInt(reader); + assertThat(TopNEncoder.DEFAULT_UNSORTABLE.decodeVInt(reader), equalTo(1)); + assertThat(TopNEncoder.DEFAULT_UNSORTABLE.decodeVInt(reader), equalTo(1)); + assertThat(TopNEncoder.DEFAULT_UNSORTABLE.decodeInt(reader), equalTo(expectedValue)); + } + + private void addRow(GroupedQueue queue, int groupKey, int value) { + GroupedRow row = createRow(breaker, groupKey, value); + Releasables.close(queue.addRow(row)); + } + + private void fillQueueToCapacity(GroupedQueue queue, int capacity) { + addRows(queue, 0, 
IntStream.range(0, capacity).map(i -> i * 10).toArray()); + } + + private void addRows(GroupedQueue queue, int groupKey, int... values) { + for (int value : values) { + addRow(queue, groupKey, value); + } + } + + private static final List SORT_ORDERS = List.of(new TopNOperator.SortOrder(1, true, false)); + + private static void assertQueueContents(GroupedQueue queue, List> groupAndSortKeys) { + assertThat(queue.size(), equalTo(groupAndSortKeys.size())); + List actual = queue.popAll(); + for (int i = 0; i < groupAndSortKeys.size(); i++) { + Tuple expectedTuple = groupAndSortKeys.get(i); + assertRowValues(actual.get(i), expectedTuple.v1(), expectedTuple.v2(), expectedTuple.v2() * 2); + } + Releasables.close(actual); + } +} diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/GroupedRowTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/GroupedRowTests.java new file mode 100644 index 0000000000000..63b4bf30dd28e --- /dev/null +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/GroupedRowTests.java @@ -0,0 +1,88 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.compute.operator.topn; + +import org.apache.lucene.tests.util.RamUsageTester; +import org.elasticsearch.common.breaker.CircuitBreaker; +import org.elasticsearch.common.breaker.NoopCircuitBreaker; +import org.elasticsearch.common.unit.ByteSizeValue; +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.common.util.MockBigArrays; +import org.elasticsearch.common.util.PageCacheRecycler; +import org.elasticsearch.compute.operator.BreakingBytesRefBuilder; +import org.elasticsearch.test.ESTestCase; + +import static org.hamcrest.Matchers.equalTo; + +public class GroupedRowTests extends ESTestCase { + private final CircuitBreaker breaker = new NoopCircuitBreaker(CircuitBreaker.REQUEST); + + public void testCloseReleasesAllTestsNoPreAllocation() throws Exception { + BigArrays bigArrays = new MockBigArrays(PageCacheRecycler.NON_RECYCLING_INSTANCE, ByteSizeValue.ofMb(1)); + CircuitBreaker breaker = bigArrays.breakerService().getBreaker(CircuitBreaker.REQUEST); + var row = new GroupedRow(breaker, 0, 0); + row.close(); + MockBigArrays.ensureAllArraysAreReleased(); + assertThat("Not all memory was released", breaker.getUsed(), equalTo(0L)); + } + + public void testCloseReleasesAllTestsWithPreAllocation() throws Exception { + BigArrays bigArrays = new MockBigArrays(PageCacheRecycler.NON_RECYCLING_INSTANCE, ByteSizeValue.ofMb(1)); + CircuitBreaker breaker = bigArrays.breakerService().getBreaker(CircuitBreaker.REQUEST); + var row = new GroupedRow(breaker, 16, 32); + row.close(); + MockBigArrays.ensureAllArraysAreReleased(); + assertThat("Not all memory was released", breaker.getUsed(), equalTo(0L)); + } + + public void testRamBytesUsedEmpty() { + var row = new GroupedRow(breaker, 0, 0); + assertThat(row.ramBytesUsed(), equalTo(expectedRamBytesUsed(row))); + } + + public void testRamBytesUsedSmall() { + var row = new GroupedRow(breaker, 0, 0); + row.keys().append(randomByte()); + row.values().append(randomByte()); + 
assertThat(row.ramBytesUsed(), equalTo(expectedRamBytesUsed(row))); + } + + public void testRamBytesUsedBig() { + var row = new GroupedRow(breaker, 0, 0); + for (int i = 0; i < 10000; i++) { + row.keys().append(randomByte()); + row.values().append(randomByte()); + } + assertThat(row.ramBytesUsed(), equalTo(expectedRamBytesUsed(row))); + } + + public void testRamBytesUsedPreAllocated() { + var row = new GroupedRow(breaker, 64, 128); + assertThat(row.ramBytesUsed(), equalTo(expectedRamBytesUsed(row))); + } + + private long expectedRamBytesUsed(GroupedRow row) { + var expected = RamUsageTester.ramUsed(row); + expected -= RamUsageTester.ramUsed(breaker); + expected -= sharedRowBytes(); + expected += undercountedBytesForRow(row); + return expected; + } + + private static long sharedRowBytes() { + return RamUsageTester.ramUsed("topn"); + } + + static long undercountedBytesForRow(GroupedRow row) { + return emptyByteArrayOverhead(row.values()); + } + + private static long emptyByteArrayOverhead(BreakingBytesRefBuilder builder) { + return builder.bytes().length == 0 ? RamUsageTester.ramUsed(new byte[0]) : 0L; + } +} diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/GroupedTopNOperatorStatusTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/GroupedTopNOperatorStatusTests.java new file mode 100644 index 0000000000000..25003e93124fc --- /dev/null +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/GroupedTopNOperatorStatusTests.java @@ -0,0 +1,118 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.compute.operator.topn; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.test.ESTestCase; + +import static org.hamcrest.Matchers.equalTo; + +public class GroupedTopNOperatorStatusTests extends AbstractWireSerializingTestCase { + public static GroupedTopNOperatorStatus simple() { + return new GroupedTopNOperatorStatus(100, 40, 10, 5, 2000, 123, 123, 111, 222); + } + + public static String simpleToJson() { + return """ + { + "receive_nanos" : 100, + "receive_time" : "100nanos", + "emit_nanos" : 40, + "emit_time" : "40nanos", + "occupied_rows" : 10, + "group_count" : 5, + "ram_bytes_used" : 2000, + "ram_used" : "1.9kb", + "pages_received" : 123, + "pages_emitted" : 123, + "rows_received" : 111, + "rows_emitted" : 222 + }"""; + } + + public void testToXContent() { + assertThat(Strings.toString(simple(), true, true), equalTo(simpleToJson())); + } + + @Override + protected Writeable.Reader instanceReader() { + return GroupedTopNOperatorStatus::new; + } + + @Override + protected GroupedTopNOperatorStatus createTestInstance() { + return new GroupedTopNOperatorStatus( + randomNonNegativeLong(), + randomNonNegativeLong(), + randomNonNegativeInt(), + randomNonNegativeLong(), + randomNonNegativeLong(), + randomNonNegativeInt(), + randomNonNegativeInt(), + randomNonNegativeLong(), + randomNonNegativeLong() + ); + } + + @Override + protected GroupedTopNOperatorStatus mutateInstance(GroupedTopNOperatorStatus instance) { + long receiveNanos = instance.receiveNanos(); + long emitNanos = instance.emitNanos(); + int occupiedRows = instance.occupiedRows(); + long groupCount = instance.groupCount(); + long ramBytesUsed = instance.ramBytesUsed(); + int pagesReceived = instance.pagesReceived(); + int pagesEmitted = instance.pagesEmitted(); + long rowsReceived = instance.rowsReceived(); + long rowsEmitted = 
instance.rowsEmitted(); + switch (between(0, 8)) { + case 0: + receiveNanos = randomValueOtherThan(receiveNanos, ESTestCase::randomNonNegativeLong); + break; + case 1: + emitNanos = randomValueOtherThan(emitNanos, ESTestCase::randomNonNegativeLong); + break; + case 2: + occupiedRows = randomValueOtherThan(occupiedRows, ESTestCase::randomNonNegativeInt); + break; + case 3: + groupCount = randomValueOtherThan(groupCount, ESTestCase::randomNonNegativeLong); + break; + case 4: + ramBytesUsed = randomValueOtherThan(ramBytesUsed, ESTestCase::randomNonNegativeLong); + break; + case 5: + pagesReceived = randomValueOtherThan(pagesReceived, ESTestCase::randomNonNegativeInt); + break; + case 6: + pagesEmitted = randomValueOtherThan(pagesEmitted, ESTestCase::randomNonNegativeInt); + break; + case 7: + rowsReceived = randomValueOtherThan(rowsReceived, ESTestCase::randomNonNegativeLong); + break; + case 8: + rowsEmitted = randomValueOtherThan(rowsEmitted, ESTestCase::randomNonNegativeLong); + break; + default: + throw new IllegalArgumentException(); + } + return new GroupedTopNOperatorStatus( + receiveNanos, + emitNanos, + occupiedRows, + groupCount, + ramBytesUsed, + pagesReceived, + pagesEmitted, + rowsReceived, + rowsEmitted + ); + } +} diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/GroupedTopNOperatorTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/GroupedTopNOperatorTests.java new file mode 100644 index 0000000000000..5edc16ce03e02 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/GroupedTopNOperatorTests.java @@ -0,0 +1,683 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.compute.operator.topn; + +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.BlockUtils; +import org.elasticsearch.compute.data.DocBlock; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.lucene.IndexedByShardIdFromList; +import org.elasticsearch.compute.operator.Driver; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.compute.operator.Operator; +import org.elasticsearch.compute.operator.PageConsumerOperator; +import org.elasticsearch.compute.operator.SourceOperator; +import org.elasticsearch.compute.operator.topn.TopNOperator.SortOrder; +import org.elasticsearch.compute.test.CannedSourceOperator; +import org.elasticsearch.compute.test.TestBlockBuilder; +import org.elasticsearch.compute.test.TestDriverFactory; +import org.elasticsearch.compute.test.TestDriverRunner; +import org.elasticsearch.compute.test.operator.blocksource.ListRowsBlockSourceOperator; +import org.elasticsearch.compute.test.operator.blocksource.TupleLongLongBlockSourceOperator; +import org.elasticsearch.core.RefCounted; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.core.SimpleRefCounted; +import org.elasticsearch.core.Tuple; +import org.elasticsearch.test.ESTestCase; +import org.hamcrest.Matcher; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Comparator; +import java.util.List; +import java.util.Map; +import java.util.function.Function; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import java.util.stream.LongStream; +import java.util.stream.Stream; + +import static org.elasticsearch.compute.data.ElementType.DOC; +import static org.elasticsearch.compute.data.ElementType.LONG; +import static 
org.elasticsearch.compute.operator.topn.TopNEncoder.DEFAULT_UNSORTABLE; +import static org.elasticsearch.compute.test.BlockTestUtils.append; +import static org.elasticsearch.compute.test.BlockTestUtils.readInto; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.greaterThan; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.lessThanOrEqualTo; + +public class GroupedTopNOperatorTests extends TopNOperatorTests { + private static final int TOP_COUNT = 4; + + @Override + protected int[] groupKeys() { + return new int[] { 0 }; + } + + @Override + protected void testRandomTopN(boolean asc, DriverContext context) { + int limit = randomIntBetween(1, 20); + List> inputValues = randomList(0, 1000, () -> Tuple.tuple(ESTestCase.randomLong(), ESTestCase.randomLong(9))); + List> expectedValues = computeTopN(inputValues, limit, asc); + List> outputValues = topNTwoLongColumns( + context, + inputValues, + limit, + List.of(DEFAULT_UNSORTABLE, DEFAULT_UNSORTABLE), + List.of(new SortOrder(0, asc, false)), + new int[] { 1 } + ); + + assertThat(outputValues, equalTo(expectedValues)); + } + + @Override + protected List> expectedTopRowOriented(List> rowOriented, List sortOrders, int topCount) { + return computeTopN(rowOriented, IntStream.of(groupKeys()).boxed().toList(), sortOrders, topCount); + } + + @Override + public void testStatus() { + BlockFactory blockFactory = driverContext().blockFactory(); + try (Operator op = simple(SimpleOptions.DEFAULT).get(driverContext())) { + Operator.Status status = op.status(); + assertThat(status, instanceOf(GroupedTopNOperatorStatus.class)); + GroupedTopNOperatorStatus groupedStatus = (GroupedTopNOperatorStatus) status; + assertThat(groupedStatus.occupiedRows(), equalTo(0)); + assertThat(groupedStatus.groupCount(), equalTo(0L)); + assertThat(groupedStatus.ramBytesUsed(), greaterThan(0L)); + assertThat(groupedStatus.pagesReceived(), equalTo(0)); + 
assertThat(groupedStatus.pagesEmitted(), equalTo(0)); + assertThat(groupedStatus.rowsReceived(), equalTo(0L)); + assertThat(groupedStatus.rowsEmitted(), equalTo(0L)); + + Page p = new Page( + blockFactory.newConstantLongBlockWith(1, 10), + blockFactory.newLongArrayVector(new long[] { 1L, 1L, 1L, 1L, 1L, 2L, 2L, 2L, 2L, 2L }, 10).asBlock() + ); + op.addInput(p); + status = op.status(); + groupedStatus = (GroupedTopNOperatorStatus) status; + assertThat(groupedStatus.receiveNanos(), greaterThan(0L)); + assertThat(groupedStatus.emitNanos(), equalTo(0L)); + assertThat(groupedStatus.occupiedRows(), equalTo(8)); + assertThat(groupedStatus.groupCount(), equalTo(2L)); + assertThat(groupedStatus.ramBytesUsed(), greaterThan(0L)); + assertThat(groupedStatus.pagesReceived(), equalTo(1)); + assertThat(groupedStatus.pagesEmitted(), equalTo(0)); + assertThat(groupedStatus.rowsReceived(), equalTo(10L)); + assertThat(groupedStatus.rowsEmitted(), equalTo(0L)); + } + } + + public void testBasicTopN() { + List values = Arrays.asList(2L, 1L, 4L, null, 4L, null); + assertThat(topNLong(values, 1, true, false), equalTo(Arrays.asList(1L, 2L, 4L, null))); + assertThat(topNLong(values, 1, false, false), equalTo(Arrays.asList(4L, 2L, 1L, null))); + assertThat(topNLong(values, 1, true, true), equalTo(Arrays.asList(null, 1L, 2L, 4L))); + assertThat(topNLong(values, 1, false, true), equalTo(Arrays.asList(null, 4L, 2L, 1L))); + assertThat(topNLong(values, 2, true, false), equalTo(Arrays.asList(1L, 2L, 4L, 4L, null, null))); + assertThat(topNLong(values, 2, false, false), equalTo(Arrays.asList(4L, 4L, 2L, 1L, null, null))); + assertThat(topNLong(values, 2, true, true), equalTo(Arrays.asList(null, null, 1L, 2L, 4L, 4L))); + assertThat(topNLong(values, 2, false, true), equalTo(Arrays.asList(null, null, 4L, 4L, 2L, 1L))); + } + + private List topNLong(List inputValues, int limit, boolean ascendingOrder, boolean nullsFirst) { + return topNLong(driverContext(), inputValues, limit, ascendingOrder, 
nullsFirst); + } + + @Override + protected SourceOperator simpleInput(BlockFactory blockFactory, int size) { + return new TupleLongLongBlockSourceOperator( + blockFactory, + LongStream.range(0, size).mapToObj(l -> Tuple.tuple(ESTestCase.randomLong(), ESTestCase.randomLong(9))), + between(1, size * 2) + ); + } + + @Override + protected Operator.OperatorFactory simple(SimpleOptions options) { + return new GroupedTopNOperator.GroupedTopNOperatorFactory( + TOP_COUNT, + List.of(LONG, LONG), + List.of(DEFAULT_UNSORTABLE, DEFAULT_UNSORTABLE), + List.of(new SortOrder(0, true, false)), + List.of(1), + pageSize, + Long.MAX_VALUE + ); + } + + @Override + protected Matcher expectedDescriptionOfSimple() { + return equalTo( + "GroupedTopNOperator[count=4, elementTypes=[LONG, LONG], encoders=[DefaultUnsortable, DefaultUnsortable], " + + "sortOrders=[SortOrder[channel=0, asc=true, nullsFirst=false]], groupKeys=[1]]" + ); + } + + @Override + protected Matcher expectedToStringOfSimple() { + return equalTo( + "GroupedTopNOperator[count=0/0/4, elementTypes=[LONG, LONG], encoders=[DefaultUnsortable, DefaultUnsortable], " + + "sortOrders=[SortOrder[channel=0, asc=true, nullsFirst=false]], groupKeys=[1]]" + ); + } + + @Override + protected void assertSimpleOutput(List input, List results) { + for (int i = 0; i < results.size() - 1; i++) { + assertThat(results.get(i).getPositionCount(), equalTo(pageSize)); + } + assertThat(results.get(results.size() - 1).getPositionCount(), lessThanOrEqualTo(pageSize)); + List> values = input.stream() + .flatMap( + page -> IntStream.range(0, page.getPositionCount()) + .filter(p -> false == page.getBlock(0).isNull(p)) + .mapToObj(p -> Tuple.tuple(((LongBlock) page.getBlock(0)).getLong(p), ((LongBlock) page.getBlock(1)).getLong(p))) + + ) + .toList(); + var expected = computeTopN(values, TOP_COUNT, true); + assertThat( + results.stream() + .flatMap( + page -> IntStream.range(0, page.getPositionCount()) + .mapToObj(i -> 
Tuple.tuple(page.getBlock(0).getLong(i), page.getBlock(1).getLong(i))) + ) + .toList(), + equalTo(expected) + ); + } + + /** + * Tests that the SORTED input ordering optimization short-circuiting addInput() doesn't incorrectly skip rows + * belonging to groups not yet populated when another group's row is rejected. + * + *
+     * <pre>{@code
+     * Scenario (SORT ASC, LIMIT 1 BY group):
+     * - Page 1: (group=0), (group=1) → both groups populated
+     * - Page 2: (group=0), (group=2) → (group=0) rejected from full group 0 → break skips (group=2), which was empty
+     * Expected groups: {0, 1, 2}
+     * Bug result: {0, 1} (group 2 missing)
+     * }</pre>
+ */ + public void testSortedInputWithMultipleGroups() { + int topCount = 1; + int[] groupKeys = new int[] { 1 }; + List elementTypes = List.of(ElementType.INT, ElementType.INT); + List encoders = List.of(TopNEncoder.DEFAULT_SORTABLE, DEFAULT_UNSORTABLE); + List sortOrders = List.of(new SortOrder(0, true, true)); + + BlockFactory bf = driverContext().blockFactory(); + + // Page 1: sorted ASC by sort key + Page page1; + try ( + Block.Builder sortCol = ElementType.INT.newBlockBuilder(2, bf); + Block.Builder groupCol = ElementType.INT.newBlockBuilder(2, bf) + ) { + append(sortCol, 1); + append(sortCol, 3); + append(groupCol, 0); + append(groupCol, 1); + page1 = new Page(sortCol.build(), groupCol.build()); + } + + // Page 2: sorted ASC by sort key + Page page2; + try ( + Block.Builder sortCol = ElementType.INT.newBlockBuilder(2, bf); + Block.Builder groupCol = ElementType.INT.newBlockBuilder(2, bf) + ) { + append(sortCol, 2); + append(sortCol, 4); + append(groupCol, 0); + append(groupCol, 2); + page2 = new Page(sortCol.build(), groupCol.build()); + } + + List> actual = new ArrayList<>(); + DriverContext driverContext = driverContext(); + + try ( + Driver driver = TestDriverFactory.create( + driverContext, + new CannedSourceOperator(List.of(page1, page2).iterator()), + List.of( + new GroupedTopNOperator( + driverContext.blockFactory(), + nonBreakingBigArrays().breakerService().getBreaker("request"), + topCount, + elementTypes, + encoders, + sortOrders, + groupKeys, + randomPageSize(), + Long.MAX_VALUE + ) + ), + new PageConsumerOperator(p -> readInto(actual, p)) + ) + ) { + new TestDriverRunner().run(driver); + } + + // 3 groups, each with 1 value, ordered ASC: [1, 3, 4] + assertThat(actual.get(0), equalTo(List.of(1, 3, 4))); + assertThat(actual.get(1), equalTo(List.of(0, 1, 2))); + } + + public void testMultivalueGroupKey() { + DriverContext driverContext = driverContext(); + BlockFactory blockFactory = driverContext.blockFactory(); + + int topCount = 1; + int[] 
groupKeys = new int[] { 2 }; // group key at channel 2 + List elementTypes = List.of(LONG, LONG, LONG); + List encoders = List.of(TopNEncoder.DEFAULT_SORTABLE, DEFAULT_UNSORTABLE, DEFAULT_UNSORTABLE); + List sortOrders = List.of(new SortOrder(0, true, false)); + + Page page = new Page( + BlockUtils.fromList( + blockFactory, + List.of( + // (To keep indentation) + List.of(10L, 100L, List.of(1L, 2L)), + List.of(20L, 200L, 1L), + List.of(30L, 300L, 3L) + ) + ) + ); + + List> actual = new ArrayList<>(); + try ( + Driver driver = TestDriverFactory.create( + driverContext, + new CannedSourceOperator(List.of(page).iterator()), + List.of( + new GroupedTopNOperator( + blockFactory, + nonBreakingBigArrays().breakerService().getBreaker("request"), + topCount, + elementTypes, + encoders, + sortOrders, + groupKeys, + randomPageSize(), + Long.MAX_VALUE + ) + ), + new PageConsumerOperator(p -> readInto(actual, p)) + ) + ) { + new TestDriverRunner().run(driver); + } + + // List semantics: [1,2] is one group, 1 is another, 3 is another. Sorted ASC by sort key. 
+ assertThat(actual.get(0), equalTo(List.of(10L, 20L, 30L))); // Sort key + assertThat(actual.get(1), equalTo(List.of(100L, 200L, 300L))); // Value + assertThat(actual.get(2), equalTo(List.of(List.of(1L, 2L), 1L, 3L))); // Group key (MV preserved) + } + + public void testMultivalueGroupKeyDuplicateWinner() { + DriverContext driverContext = driverContext(); + BlockFactory bf = driverContext.blockFactory(); + + int topCount = 1; + int[] groupKeys = new int[] { 2 }; + List elementTypes = List.of(LONG, LONG, LONG); + List encoders = List.of(TopNEncoder.DEFAULT_SORTABLE, DEFAULT_UNSORTABLE, DEFAULT_UNSORTABLE); + List sortOrders = List.of(new SortOrder(0, true, false)); + + Page page = new Page(BlockUtils.fromList(bf, List.of(List.of(5L, 50L, List.of(1L, 2L)), List.of(10L, 100L, 1L)))); + + List> actual = new ArrayList<>(); + try ( + Driver driver = TestDriverFactory.create( + driverContext, + new CannedSourceOperator(List.of(page).iterator()), + List.of( + new GroupedTopNOperator( + bf, + nonBreakingBigArrays().breakerService().getBreaker("request"), + topCount, + elementTypes, + encoders, + sortOrders, + groupKeys, + randomPageSize(), + Long.MAX_VALUE + ) + ), + new PageConsumerOperator(p -> readInto(actual, p)) + ) + ) { + new TestDriverRunner().run(driver); + } + + // List semantics: [1,2] is one group, 1 is another. Sorted ASC by sort key. + assertThat(actual.get(0), equalTo(List.of(5L, 10L))); // Sort key + assertThat(actual.get(1), equalTo(List.of(50L, 100L))); // Value + assertThat(actual.get(2), equalTo(List.of(List.of(1L, 2L), 1L))); // Group key (MV preserved) + } + + /** + * Tests list semantics with two multivalue group keys: a row with group_key1=[1, 2] + * and group_key2=[10, 20] belongs to exactly one group keyed by ([1,2], [10,20]). + * No cartesian product expansion occurs. 
+ */ + public void testMultipleMultivalueGroupKeys() { + DriverContext driverContext = driverContext(); + BlockFactory bf = driverContext.blockFactory(); + + int topCount = 1; + int[] groupKeys = new int[] { 2, 3 }; // two group keys at channels 2 and 3 + List elementTypes = List.of(LONG, LONG, LONG, LONG); + List encoders = List.of(TopNEncoder.DEFAULT_SORTABLE, DEFAULT_UNSORTABLE, DEFAULT_UNSORTABLE, DEFAULT_UNSORTABLE); + List sortOrders = List.of(new SortOrder(0, true, false)); + + Page page = new Page( + BlockUtils.fromList( + bf, + List.of(List.of(10L, 100L, List.of(1L, 2L), List.of(10L, 20L)), List.of(5L, 50L, 1L, 10L), List.of(15L, 150L, 2L, 20L)) + ) + ); + + List> actual = new ArrayList<>(); + try ( + Driver driver = TestDriverFactory.create( + driverContext, + new CannedSourceOperator(List.of(page).iterator()), + List.of( + new GroupedTopNOperator( + bf, + nonBreakingBigArrays().breakerService().getBreaker("request"), + topCount, + elementTypes, + encoders, + sortOrders, + groupKeys, + randomPageSize(), + Long.MAX_VALUE + ) + ), + new PageConsumerOperator(p -> readInto(actual, p)) + ) + ) { + new TestDriverRunner().run(driver); + } + + // List semantics: 3 distinct groups, each with 1 row, sorted ASC by sort key + assertThat(actual.get(0), equalTo(List.of(5L, 10L, 15L))); // Sort key + assertThat(actual.get(1), equalTo(List.of(50L, 100L, 150L))); // Value + assertThat(actual.get(2), equalTo(List.of(1L, List.of(1L, 2L), 2L))); // Group key 1 (MV preserved) + assertThat(actual.get(3), equalTo(List.of(10L, List.of(10L, 20L), 20L))); // Group key 2 (MV preserved) + } + + public void testShardContextManagement_limitEqualToCount_noShardContextIsReleased() { + topNShardContextManagementAux(2, Stream.generate(() -> true).limit(4).toList()); + } + + public void testShardContextManagement_notAllShardsPassTopN_shardsAreReleased() { + topNShardContextManagementAux(1, List.of(true, false, false, true)); + } + + private void topNShardContextManagementAux(int limit, 
List expectedOpenAfterTopN) { + List> values = Arrays.asList( + Arrays.asList(new BlockUtils.Doc(0, 10, 100), 1L, 1L), + Arrays.asList(new BlockUtils.Doc(1, 20, 200), 2L, 2L), + Arrays.asList(new BlockUtils.Doc(2, 30, 300), null, 1L), + Arrays.asList(new BlockUtils.Doc(3, 40, 400), -3L, 2L) + ); + + List refCountedList = Stream.generate(() -> new SimpleRefCounted()).limit(4).toList(); + var shardRefCounters = new IndexedByShardIdFromList<>(refCountedList); + var pages = topNMultipleColumns( + driverContext(), + new ListRowsBlockSourceOperator(driverContext().blockFactory(), List.of(DOC, LONG, LONG), values) { + @Override + protected TestBlockBuilder getTestBlockBuilder(int b) { + return b == 0 ? new TestBlockBuilder.DocBlockBuilder(blockFactory, shardRefCounters) : super.getTestBlockBuilder(b); + } + }, + limit, + List.of(new DocVectorEncoder(shardRefCounters), DEFAULT_UNSORTABLE, DEFAULT_UNSORTABLE), + List.of(new SortOrder(1, true, false)), + new int[] { 2 } + ); + try { + refCountedList.forEach(RefCounted::decRef); + + assertThat(refCountedList.stream().map(RefCounted::hasReferences).toList(), equalTo(expectedOpenAfterTopN)); + assertThat(pageToValues(pages), equalTo(computeTopN(values, 2, 1, limit, true))); + + for (var rc : refCountedList) { + assertFalse(rc.hasReferences()); + } + } finally { + Releasables.close(pages); + } + } + + public void testRandomMultipleColumns() { + DriverContext driverContext = driverContext(); + int rows = randomIntBetween(50, 100); + int topCount = randomIntBetween(1, 10); + int blocksCount = randomIntBetween(10, 20); + int sortingByColumns = randomIntBetween(2, 3); + int groupKeysCount = randomIntBetween(2, 3); + + RandomBlocksResult randomBlocksResult = generateRandomSingleValueBlocks(rows, blocksCount, driverContext); + + List sortColumns = new ArrayList<>(); + for (int i = 0; i < sortingByColumns; i++) { + sortColumns.add( + randomValueOtherThanMany( + c -> randomBlocksResult.validSortKeys[c] == false || 
sortColumns.contains(c), + () -> randomIntBetween(0, blocksCount - 1) + ) + ); + } + + List groupKeys = new ArrayList<>(); + for (int i = 0; i < groupKeysCount; i++) { + groupKeys.add( + randomValueOtherThanMany( + c -> sortColumns.contains(c) || groupKeys.contains(c) || randomBlocksResult.validSortKeys[c] == false, + () -> randomIntBetween(0, blocksCount - 1) + ) + ); + } + + List uniqueOrders = sortColumns.stream().map(column -> new SortOrder(column, randomBoolean(), randomBoolean())).toList(); + + List results = new TestDriverRunner().builder(driverContext) + .input(List.of(new Page(randomBlocksResult.blocks.toArray(Block[]::new))).iterator()) + .run( + new GroupedTopNOperator( + driverContext.blockFactory(), + nonBreakingBigArrays().breakerService().getBreaker("request"), + topCount, + randomBlocksResult.elementTypes, + randomBlocksResult.encoders, + uniqueOrders.stream().toList(), + groupKeys.stream().mapToInt(Integer::intValue).toArray(), + rows, + Long.MAX_VALUE + ) + ); + List> actualValues = new ArrayList<>(); + for (Page p : results) { + actualValues.addAll(readAsRowsSingleValue(p)); + p.releaseBlocks(); + } + + List> topNExpectedValues = computeTopN(randomBlocksResult.expectedValues, groupKeys, uniqueOrders, topCount); + + // We verify the output we got from the operator is sorted, but since we're asserting the results, we also need to handle ties which + // the operator doesn't actually guarantee any order for. + Comparator> sortOrderComparator = comparatorFromSortOrders(uniqueOrders); + assertThat(isSorted(actualValues, sortOrderComparator), equalTo(true)); + + // Given that sorting on repeated sort keys (especially booleans, nulls...)
may include arbitrary rows, + // we'll only assert the expected groups and keys, by including them in this signature + Function, List> signature = row -> { + List sig = new ArrayList<>(); + for (int gk : groupKeys.stream().mapToInt(Integer::intValue).toArray()) { + sig.add(row.get(gk)); + } + for (SortOrder so : uniqueOrders) { + sig.add(row.get(so.channel())); + } + return sig; + }; + + Map, Long> actualCounts = actualValues.stream() + .map(signature) + .collect(Collectors.groupingBy(s -> s, Collectors.counting())); + Map, Long> expectedCounts = topNExpectedValues.stream() + .map(signature) + .collect(Collectors.groupingBy(s -> s, Collectors.counting())); + + assertThat(actualCounts, equalTo(expectedCounts)); + } + + private static Comparator> comparatorFromSortOrders(List sortOrders) { + return (row1, row2) -> { + assertEquals(row1.size(), row2.size()); + for (SortOrder order : sortOrders) { + int cmp = compareValues(order).compare(row1.get(order.channel()), row2.get(order.channel())); + if (cmp != 0) { + return cmp; + } + } + return 0; + }; + } + + private static final Comparator> TIE_BREAKING_COMPARATOR = (row1, row2) -> { + for (int i = 0; i < row1.size(); i++) { + int cmp = compareValues(new SortOrder(i, true, true)).compare(row1.get(i), row2.get(i)); + if (cmp != 0) { + return cmp; + } + } + return 0; + }; + + private static boolean isSorted(List> values, Comparator> comparator) { + return IntStream.range(1, values.size()).allMatch(i -> comparator.compare(values.get(i - 1), values.get(i)) <= 0); + } + + private static Comparator compareValues(SortOrder order) { + Comparator baseComparator = order.asc() ? CASTING_COMPARATOR : CASTING_COMPARATOR.reversed(); + return order.nullsFirst() ? 
Comparator.nullsFirst(baseComparator) : Comparator.nullsLast(baseComparator); + } + + @SuppressWarnings("unchecked") + private static final Comparator CASTING_COMPARATOR = (o1, o2) -> ((Comparable) o1).compareTo(o2); + + private static List> computeTopN(List> inputValues, int limit, boolean ascendingOrder) { + return computeTopN(inputValues.stream().map(e -> Arrays.asList(e.v1(), e.v2())).toList(), 1, 0, limit, ascendingOrder).stream() + .map(l -> Tuple.tuple((Long) l.get(0), (Long) l.get(1))) + .toList(); + } + + private static List> computeTopN( + List> inputValues, + int groupChannel, + int sortChannel, + int limit, + boolean ascendingOrder + ) { + List> singleValueInput = new ArrayList<>(); + for (List row : inputValues) { + List rowAsObject = row.stream().map(v -> (Object) v).toList(); + singleValueInput.add(rowAsObject); + } + List sortOrders = List.of(new SortOrder(sortChannel, ascendingOrder, false)); + return new GroupedTopNOperatorTests().computeTopN(singleValueInput, List.of(groupChannel), sortOrders, limit); + } + + private List> computeTopN( + List> inputValues, + List groupChannels, + List sortOrders, + int limit + ) { + Comparator> comparator = (row1, row2) -> { + for (SortOrder order : sortOrders) { + Object v1 = row1.get(order.channel()); + Object v2 = row2.get(order.channel()); + boolean firstIsNull = v1 == null; + boolean secondIsNull = v2 == null; + + if (firstIsNull || secondIsNull) { + int nullCompare = Boolean.compare(firstIsNull, secondIsNull) * (order.nullsFirst() ? -1 : 1); + if (nullCompare != 0) { + return nullCompare; + } + continue; + } + + int cmp = CASTING_COMPARATOR.compare(v1, v2); + if (cmp != 0) { + return order.asc() ? 
cmp : -cmp; + } + } + return 0; + }; + + Map, List>> grouped = inputValues.stream() + .collect(Collectors.groupingBy(row -> groupChannels.stream().map(row::get).toList())); + + List> topNExpectedValues = new ArrayList<>(); + for (List> groupRows : grouped.values()) { + List> sortedGroup = groupRows.stream().sorted(comparator).limit(limit).toList(); + topNExpectedValues.addAll(sortedGroup); + } + topNExpectedValues.sort(comparator); + return topNExpectedValues; + } + + private static List> pageToValues(List pages) { + var result = new ArrayList>(); + for (Page page : pages) { + var blocks = IntStream.range(0, page.getBlockCount()).mapToObj(page::getBlock).toList(); + result.addAll( + IntStream.range(0, page.getPositionCount()) + .mapToObj(position -> blocks.stream().map(block -> getBlockValue(block, position)).toList()) + .toList() + ); + page.releaseBlocks(); + } + + return result; + } + + private static Object getBlockValue(Block block, int position) { + return block.isNull(position) ? 
null : switch (block) { + case LongBlock longBlock -> longBlock.getLong(position); + case DocBlock docBlock -> { + var vector = docBlock.asVector(); + yield new BlockUtils.Doc( + vector.shards().getInt(position), + vector.segments().getInt(position), + vector.docs().getInt(position) + ); + } + default -> throw new IllegalArgumentException("Unsupported block type: " + block.getClass()); + }; + } + +} diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/TopNOperatorTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/TopNOperatorTests.java index 056437c98e2ec..181c581a98ba1 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/TopNOperatorTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/TopNOperatorTests.java @@ -9,11 +9,14 @@ import org.apache.lucene.document.InetAddressPoint; import org.apache.lucene.tests.util.RamUsageTester; +import org.apache.lucene.util.Accountable; import org.apache.lucene.util.BytesRef; import org.elasticsearch.common.breaker.CircuitBreaker; import org.elasticsearch.common.breaker.CircuitBreakingException; import org.elasticsearch.common.network.NetworkAddress; import org.elasticsearch.common.unit.ByteSizeValue; +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.common.util.BytesRefHashTable; import org.elasticsearch.common.util.MockBigArrays; import org.elasticsearch.common.util.PageCacheRecycler; import org.elasticsearch.compute.data.Block; @@ -26,7 +29,6 @@ import org.elasticsearch.compute.data.LongBlock; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.lucene.AlwaysReferencedIndexedByShardId; -import org.elasticsearch.compute.lucene.IndexedByShardId; import org.elasticsearch.compute.lucene.IndexedByShardIdFromList; import org.elasticsearch.compute.operator.CountingCircuitBreaker; import 
org.elasticsearch.compute.operator.Driver; @@ -41,16 +43,16 @@ import org.elasticsearch.compute.test.TestBlockFactory; import org.elasticsearch.compute.test.TestDriverFactory; import org.elasticsearch.compute.test.TestDriverRunner; +import org.elasticsearch.compute.test.TypedAbstractBlockSourceBuilder; import org.elasticsearch.compute.test.operator.blocksource.SequenceLongBlockSourceOperator; -import org.elasticsearch.compute.test.operator.blocksource.TupleAbstractBlockSourceOperator; import org.elasticsearch.compute.test.operator.blocksource.TupleDocLongBlockSourceOperator; import org.elasticsearch.compute.test.operator.blocksource.TupleLongLongBlockSourceOperator; +import org.elasticsearch.core.Nullable; import org.elasticsearch.core.RefCounted; import org.elasticsearch.core.Releasables; import org.elasticsearch.core.SimpleRefCounted; import org.elasticsearch.core.Tuple; import org.elasticsearch.indices.CrankyCircuitBreakerService; -import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.ListMatcher; import org.elasticsearch.xpack.versionfield.Version; import org.hamcrest.Matcher; @@ -68,6 +70,7 @@ import java.util.Map; import java.util.Set; import java.util.function.BiFunction; +import java.util.function.Function; import java.util.function.Supplier; import java.util.stream.Collectors; import java.util.stream.IntStream; @@ -91,24 +94,56 @@ import static org.elasticsearch.compute.operator.topn.TopNEncoder.DEFAULT_SORTABLE; import static org.elasticsearch.compute.operator.topn.TopNEncoder.DEFAULT_UNSORTABLE; import static org.elasticsearch.compute.operator.topn.TopNEncoder.UTF8; +import static org.elasticsearch.compute.operator.topn.TopNEncoderTests.randomPointAsWKB; import static org.elasticsearch.compute.test.BlockTestUtils.append; import static org.elasticsearch.compute.test.BlockTestUtils.randomValue; import static org.elasticsearch.compute.test.BlockTestUtils.readInto; import static org.elasticsearch.core.Tuple.tuple; +import static 
org.elasticsearch.test.ESTestCase.between; import static org.elasticsearch.test.ListMatcher.matchesList; import static org.elasticsearch.test.MapMatcher.assertMap; import static org.hamcrest.Matchers.both; +import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.greaterThan; import static org.hamcrest.Matchers.greaterThanOrEqualTo; -import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.lessThan; -import static org.hamcrest.Matchers.lessThanOrEqualTo; import static org.hamcrest.Matchers.not; import static org.hamcrest.Matchers.nullValue; +import static org.hamcrest.number.OrderingComparison.lessThanOrEqualTo; //@Repeat(iterations = 100) public class TopNOperatorTests extends OperatorTestCase { + protected final int pageSize = randomPageSize(); + + /** + * Accumulator for {@link RamUsageTester} that excludes shared objects not owned by the operator. + */ + protected static final RamUsageTester.Accumulator RAM_USAGE_ACCUMULATOR = new RamUsageTester.Accumulator() { + @Override + public long accumulateObject(Object o, long shallowSize, Map fieldValues, Collection queue) { + if (o instanceof ElementType) { + return 0; // shared + } + if (o instanceof TopNEncoder) { + return 0; // shared + } + if (o instanceof CircuitBreaker) { + return 0; // shared + } + if (o instanceof BlockFactory) { + return 0; // shared + } + if (o instanceof BigArrays) { + return 0; // shared + } + if (o instanceof BytesRefHashTable h) { + return h.ramBytesUsed(); + } + return super.accumulateObject(o, shallowSize, fieldValues, queue); + } + }; + // versions taken from org.elasticsearch.xpack.versionfield.VersionTests private static final List VERSIONS = List.of( "1", @@ -149,10 +184,12 @@ public class TopNOperatorTests extends OperatorTestCase { "1.2.3-rc1" ); - private final int pageSize = randomPageSize(); + protected int[] groupKeys() { + return new int[0]; + } @Override - protected 
TopNOperator.TopNOperatorFactory simple(SimpleOptions options) { + protected Operator.OperatorFactory simple(SimpleOptions options) { List elementTypes = List.of(LONG); List encoders = List.of(DEFAULT_UNSORTABLE); List sortOrders = List.of(new TopNOperator.SortOrder(0, true, false)); @@ -189,21 +226,7 @@ protected Matcher expectedToStringOfSimple() { @Override protected SourceOperator simpleInput(BlockFactory blockFactory, int size) { - return new SequenceLongBlockSourceOperator( - blockFactory, - LongStream.range(0, size).map(l -> ESTestCase.randomLong()), - between(1, size * 2) - ); - } - - protected SourceOperator simpleInput(BlockFactory blockFactory, int size, InputOrdering sortedInput) { - var longs = LongStream.range(0, size).map(l -> ESTestCase.randomLong()); - - return new SequenceLongBlockSourceOperator( - blockFactory, - sortedInput == InputOrdering.SORTED ? longs.sorted() : longs, - between(1, size * 2) - ); + return new SequenceLongBlockSourceOperator(blockFactory, LongStream.range(0, size).map(l -> randomLong()), between(1, size * 2)); } @Override @@ -229,54 +252,41 @@ protected void assertSimpleOutput(List input, List results) { ); } - public void testRamBytesUsed() { - for (var sortedInput : InputOrdering.values()) { - RamUsageTester.Accumulator acc = new RamUsageTester.Accumulator() { - @Override - public long accumulateObject(Object o, long shallowSize, Map fieldValues, Collection queue) { - if (o instanceof ElementType) { - return 0; // shared - } - if (o instanceof TopNEncoder) { - return 0; // shared - } - if (o instanceof CircuitBreaker) { - return 0; // shared - } - if (o instanceof BlockFactory) { - return 0; // shard - } - return super.accumulateObject(o, shallowSize, fieldValues, queue); - } - }; - int topCount = 10_000; - // We under-count by a few bytes because of the lists. In that end that's fine, but we need to account for it here. 
- long underCount = 200; - DriverContext context = driverContext(); - try ( - TopNOperator op = new TopNOperator.TopNOperatorFactory( - topCount, - List.of(LONG), - List.of(DEFAULT_UNSORTABLE), - List.of(new TopNOperator.SortOrder(0, true, false)), - pageSize, - randomJumboPageBytes(), - sortedInput, - null - ).get(context) - ) { - long actualEmpty = RamUsageTester.ramUsed(op, acc); - assertThat(op.ramBytesUsed(), both(greaterThan(actualEmpty - underCount)).and(lessThan(actualEmpty))); - // But when we fill it then we're quite close - for (Page p : CannedSourceOperator.collectPages(simpleInput(context.blockFactory(), topCount, sortedInput))) { - op.addInput(p); - } - long actualFull = RamUsageTester.ramUsed(op, acc); - assertThat(op.ramBytesUsed(), both(greaterThan(actualFull - underCount)).and(lessThan(actualFull))); - - // TODO empty it again and check. - } + /** + * Creates the appropriate top-N operator factory for the test subclass. + */ + protected Operator.OperatorFactory createTopNOperatorFactory( + int topCount, + List elementTypes, + List encoders, + List sortOrders, + int[] groupKeys, + int maxPageSize, + long jumboPageBytes, + TopNOperator.InputOrdering inputOrdering, + @Nullable SharedMinCompetitive.Supplier minCompetitive + ) { + if (groupKeys.length > 0) { + return new GroupedTopNOperator.GroupedTopNOperatorFactory( + topCount, + elementTypes, + encoders, + sortOrders, + IntStream.of(groupKeys).boxed().toList(), + maxPageSize, + jumboPageBytes + ); } + return new TopNOperator.TopNOperatorFactory( + topCount, + elementTypes, + encoders, + sortOrders, + maxPageSize, + jumboPageBytes, + inputOrdering, + minCompetitive + ); } public void testRandomTopN() { @@ -295,11 +305,10 @@ public void testRandomTopNCranky() { } } - private void testRandomTopN(boolean asc, DriverContext context) { + protected void testRandomTopN(boolean asc, DriverContext context) { int limit = randomIntBetween(1, 20); - List inputValues = randomList(0, 5000, ESTestCase::randomLong); 
- Comparator comparator = asc ? naturalOrder() : reverseOrder(); - List expectedValues = inputValues.stream().sorted(comparator).limit(limit).toList(); + List inputValues = randomList(0, 5000, () -> randomLong()); + List expectedValues = inputValues.stream().sorted(asc ? naturalOrder() : reverseOrder()).limit(limit).toList(); List outputValues = topNLong(context, inputValues, limit, asc, false); assertThat(outputValues, equalTo(expectedValues)); } @@ -332,39 +341,44 @@ private void testTopNSortedInput( boolean asc, boolean nullsFirst ) { - List pages = new ArrayList<>(pagesValues.size()); + var layout = ChannelLayout.forDataColumns(1, groupKeys()); + List pages = new ArrayList<>(pagesValues.size()); for (var pageValues : pagesValues) { assert isSorted(pageValues, asc, nullsFirst); - List blocks = new ArrayList<>(pageValues.size()); - try (Block.Builder column = INT.newBlockBuilder(8, driverContext().blockFactory());) { + try (Block.Builder column = INT.newBlockBuilder(8, driverContext().blockFactory())) { for (var value : pageValues) { append(column, value); } - blocks.add(column.build()); + pages.add(new Page(layout.buildPageBlocks(pageValues.size(), driverContext().blockFactory(), column.build()))); } - pages.add(new Page(blocks.toArray(Block[]::new))); } List> actual = new ArrayList<>(); DriverContext driverContext = driverContext(); + List elementTypes = new ArrayList<>(); + List encoders = new ArrayList<>(); + for (int ch = 0; ch < layout.totalChannels(); ch++) { + elementTypes.add(INT); + encoders.add(layout.groupKeySet().contains(ch) ? 
DEFAULT_UNSORTABLE : DEFAULT_SORTABLE); + } + try ( Driver driver = TestDriverFactory.create( driverContext, new CannedSourceOperator(pages.iterator()), List.of( - new TopNOperator( - driverContext.blockFactory(), - nonBreakingBigArrays().breakerService().getBreaker("request"), + createTopNOperatorFactory( topNCount, - List.of(INT), - List.of(DEFAULT_SORTABLE), - List.of(new TopNOperator.SortOrder(0, asc, nullsFirst)), - pageSize, + elementTypes, + encoders, + List.of(new TopNOperator.SortOrder(layout.dataChannel(0), asc, nullsFirst)), + groupKeys(), + randomPageSize(), randomJumboPageBytes(), InputOrdering.SORTED, null - ) + ).get(driverContext) ), new PageConsumerOperator(p -> readInto(actual, p)) ) @@ -372,7 +386,7 @@ private void testTopNSortedInput( new TestDriverRunner().run(driver); } - assertThat(actual, equalTo(expectedResult)); + assertThat(layout.extractDataColumns(actual), equalTo(expectedResult)); } private void testTopNSortedInputWithTwoColumns( @@ -382,11 +396,10 @@ private void testTopNSortedInputWithTwoColumns( boolean asc, boolean nullsFirst ) { - List pages = new ArrayList<>(pagesValues.size()); + var layout = ChannelLayout.forDataColumns(2, groupKeys()); + List pages = new ArrayList<>(pagesValues.size()); for (var pageValues : pagesValues) { - List blocks = new ArrayList<>(pageValues.size()); - try ( Block.Builder firstColumn = INT.newBlockBuilder(8, driverContext().blockFactory()); Block.Builder secondColumn = INT.newBlockBuilder(8, driverContext().blockFactory()); @@ -395,31 +408,42 @@ private void testTopNSortedInputWithTwoColumns( append(firstColumn, value.v1()); append(secondColumn, value.v2()); } - blocks.add(firstColumn.build()); - blocks.add(secondColumn.build()); + pages.add( + new Page( + layout.buildPageBlocks(pageValues.size(), driverContext().blockFactory(), firstColumn.build(), secondColumn.build()) + ) + ); } - pages.add(new Page(blocks.toArray(Block[]::new))); } List> actual = new ArrayList<>(); DriverContext driverContext = 
driverContext(); + List elementTypes = new ArrayList<>(); + List encoders = new ArrayList<>(); + for (int ch = 0; ch < layout.totalChannels(); ch++) { + elementTypes.add(INT); + encoders.add(layout.groupKeySet().contains(ch) ? DEFAULT_UNSORTABLE : DEFAULT_SORTABLE); + } + try ( Driver driver = TestDriverFactory.create( driverContext, new CannedSourceOperator(pages.iterator()), List.of( - new TopNOperator( - driverContext.blockFactory(), - nonBreakingBigArrays().breakerService().getBreaker("request"), + createTopNOperatorFactory( topNCount, - List.of(INT, INT), - List.of(DEFAULT_SORTABLE, DEFAULT_SORTABLE), - List.of(new TopNOperator.SortOrder(0, asc, nullsFirst), new TopNOperator.SortOrder(1, asc, nullsFirst)), - pageSize, + elementTypes, + encoders, + List.of( + new TopNOperator.SortOrder(layout.dataChannel(0), asc, nullsFirst), + new TopNOperator.SortOrder(layout.dataChannel(1), asc, nullsFirst) + ), + groupKeys(), + randomPageSize(), randomJumboPageBytes(), InputOrdering.SORTED, null - ) + ).get(driverContext) ), new PageConsumerOperator(p -> readInto(actual, p)) ) @@ -427,24 +451,39 @@ private void testTopNSortedInputWithTwoColumns( new TestDriverRunner().run(driver); } - assertThat(actual, equalTo(expectedResult)); + assertThat(layout.extractDataColumns(actual), equalTo(expectedResult)); } public final void testTopNWithSortedInputToString() { - var topN = new TopNOperator.TopNOperatorFactory( + var topN = createTopNOperatorFactory( 4, List.of(LONG), List.of(DEFAULT_UNSORTABLE), List.of(new TopNOperator.SortOrder(0, true, false)), + groupKeys(), pageSize, randomJumboPageBytes(), InputOrdering.SORTED, null ); - var expectedDescription = "TopNOperator[count=0/4, elementTypes=[LONG], encoders=[DefaultUnsortable], " - + "sortOrders=[SortOrder[channel=0, asc=true, nullsFirst=false]], inputOrdering=SORTED]"; - try (Operator operator = topN.get(driverContext())) { - assertThat(operator.toString(), equalTo(expectedDescription)); + if (groupKeys().length > 0) { + var 
expectedDescription = "GroupedTopNOperator[count=0/0/4" + + ", elementTypes=[LONG], encoders=[DefaultUnsortable], " + + "sortOrders=[SortOrder[channel=0, asc=true, nullsFirst=false]]" + + ", groupKeys=" + + Arrays.toString(groupKeys()) + + "]"; + try (Operator operator = topN.get(driverContext())) { + assertThat(operator.toString(), equalTo(expectedDescription)); + } + } else { + var expectedDescription = "TopNOperator[count=0/4" + + ", elementTypes=[LONG], encoders=[DefaultUnsortable], " + + "sortOrders=[SortOrder[channel=0, asc=true, nullsFirst=false]]" + + ", inputOrdering=SORTED]"; + try (Operator operator = topN.get(driverContext())) { + assertThat(operator.toString(), equalTo(expectedDescription)); + } } } @@ -700,7 +739,11 @@ public void testBasicTopN() { assertThat(topNLong(values, 100, false, true), equalTo(Arrays.asList(null, null, 100L, 20L, 10L, 5L, 4L, 4L, 2L, 1L))); } - private List topNLong( + private List topNLong(List inputValues, int limit, boolean ascendingOrder, boolean nullsFirst) { + return topNLong(driverContext(), inputValues, limit, ascendingOrder, nullsFirst); + } + + protected List topNLong( DriverContext driverContext, List inputValues, int limit, @@ -712,7 +755,8 @@ private List topNLong( inputValues.stream().map(v -> tuple(v, 0L)).toList(), limit, List.of(DEFAULT_UNSORTABLE, DEFAULT_UNSORTABLE), - List.of(new TopNOperator.SortOrder(0, ascendingOrder, nullsFirst)) + List.of(new TopNOperator.SortOrder(0, ascendingOrder, nullsFirst)), + groupKeys() ).stream().map(Tuple::v1).toList(); } @@ -720,10 +764,6 @@ private static TupleLongLongBlockSourceOperator longLongSourceOperator(DriverCon return new TupleLongLongBlockSourceOperator(driverContext.blockFactory(), values, randomIntBetween(1, 1000)); } - private List topNLong(List inputValues, int limit, boolean ascendingOrder, boolean nullsFirst) { - return topNLong(driverContext(), inputValues, limit, ascendingOrder, nullsFirst); - } - public void testCompareInts() { BlockFactory blockFactory 
= blockFactory(); testCompare( @@ -901,7 +941,8 @@ public void testTopNTwoColumns() { values, 5, List.of(TopNEncoder.DEFAULT_SORTABLE, TopNEncoder.DEFAULT_SORTABLE), - List.of(new TopNOperator.SortOrder(0, true, false), new TopNOperator.SortOrder(1, true, false)) + List.of(new TopNOperator.SortOrder(0, true, false), new TopNOperator.SortOrder(1, true, false)), + groupKeys() ), equalTo(List.of(tuple(1L, 1L), tuple(1L, 2L), tuple(1L, null), tuple(null, 1L), tuple(null, null))) ); @@ -911,7 +952,8 @@ public void testTopNTwoColumns() { values, 5, List.of(TopNEncoder.DEFAULT_SORTABLE, TopNEncoder.DEFAULT_SORTABLE), - List.of(new TopNOperator.SortOrder(0, true, true), new TopNOperator.SortOrder(1, true, false)) + List.of(new TopNOperator.SortOrder(0, true, true), new TopNOperator.SortOrder(1, true, false)), + groupKeys() ), equalTo(List.of(tuple(null, 1L), tuple(null, null), tuple(1L, 1L), tuple(1L, 2L), tuple(1L, null))) ); @@ -921,7 +963,8 @@ public void testTopNTwoColumns() { values, 5, List.of(TopNEncoder.DEFAULT_SORTABLE, TopNEncoder.DEFAULT_SORTABLE), - List.of(new TopNOperator.SortOrder(0, true, false), new TopNOperator.SortOrder(1, true, true)) + List.of(new TopNOperator.SortOrder(0, true, false), new TopNOperator.SortOrder(1, true, true)), + groupKeys() ), equalTo(List.of(tuple(1L, null), tuple(1L, 1L), tuple(1L, 2L), tuple(null, null), tuple(null, 1L))) ); @@ -934,13 +977,11 @@ public void testCollectAllValues() { int size = 10; int topCount = 3; List blocks = new ArrayList<>(); - List> expectedTop = new ArrayList<>(); + List> rawValues = new ArrayList<>(); IntBlock keys = blockFactory.newIntArrayVector(IntStream.range(0, size).toArray(), size).asBlock(); - List topKeys = new ArrayList<>(IntStream.range(size - topCount, size).boxed().toList()); - Collections.reverse(topKeys); - expectedTop.add(topKeys); blocks.add(keys); + rawValues.add(IntStream.range(0, size).mapToObj(Integer::valueOf).toList()); List elementTypes = new ArrayList<>(); List encoders = new 
ArrayList<>(); @@ -955,39 +996,36 @@ public void testCollectAllValues() { } elementTypes.add(e); encoders.add(nonKeyEncoder(e)); - List eTop = new ArrayList<>(); try (Block.Builder builder = e.newBlockBuilder(size, driverContext().blockFactory())) { + var rawValuesForElement = new ArrayList<>(size); for (int i = 0; i < size; i++) { Object value = randomValue(e); append(builder, value); - if (i >= size - topCount) { - eTop.add(value); - } + rawValuesForElement.add(value); } - Collections.reverse(eTop); blocks.add(builder.build()); - expectedTop.add(eTop); + rawValues.add(rawValuesForElement); } } List> actualTop = new ArrayList<>(); + List sortOrders = List.of(new TopNOperator.SortOrder(0, false, false)); try ( Driver driver = TestDriverFactory.create( driverContext, new CannedSourceOperator(List.of(new Page(blocks.toArray(Block[]::new))).iterator()), List.of( - new TopNOperator( - blockFactory, - nonBreakingBigArrays().breakerService().getBreaker("request"), + createTopNOperatorFactory( topCount, elementTypes, encoders, - List.of(new TopNOperator.SortOrder(0, false, false)), - pageSize, + sortOrders, + groupKeys(), + randomPageSize(), randomJumboPageBytes(), InputOrdering.NOT_SORTED, null - ) + ).get(driverContext) ), new PageConsumerOperator(page -> readInto(actualTop, page)) ) @@ -995,10 +1033,55 @@ public void testCollectAllValues() { new TestDriverRunner().run(driver); } - assertMap(actualTop, matchesList(expectedTop)); + assertMap(actualTop, matchesList(expectedTop(rawValues, sortOrders, topCount))); assertDriverContext(driverContext); } + protected List> expectedTop(List> input, List sortOrders, int topCount) { + // input is channel-oriented, transpose to row-oriented for processing and then back format. 
+ return transpose(expectedTopRowOriented(transpose(input), sortOrders, topCount)); + } + + protected List> expectedTopRowOriented( + List> rowOriented, + List sortOrders, + int topCount + ) { + Comparator> comparator = (row1, row2) -> { + for (TopNOperator.SortOrder order : sortOrders) { + Object v1 = row1.get(order.channel()); + Object v2 = row2.get(order.channel()); + boolean firstIsNull = v1 == null; + boolean secondIsNull = v2 == null; + + if (firstIsNull || secondIsNull) { + int nullCompare = Boolean.compare(firstIsNull, secondIsNull) * (order.nullsFirst() ? -1 : 1); + if (nullCompare != 0) { + return nullCompare; + } + continue; + } + + @SuppressWarnings("unchecked") + int cmp = ((Comparable) v1).compareTo(v2); + if (cmp != 0) { + return order.asc() ? cmp : -cmp; + } + } + return 0; + }; + + return rowOriented.stream().sorted(comparator).limit(topCount).toList(); + } + + private static List> transpose(List> input) { + if (input.isEmpty()) { + return new ArrayList<>(); + } + int numRows = input.getFirst().size(); + return IntStream.range(0, numRows).mapToObj(row -> input.stream().map(channel -> channel.get(row)).toList()).toList(); + } + public void testCollectAllValues_RandomMultiValues() { DriverContext driverContext = driverContext(); BlockFactory blockFactory = driverContext.blockFactory(); @@ -1007,13 +1090,11 @@ public void testCollectAllValues_RandomMultiValues() { int topCount = 3; int blocksCount = 20; List blocks = new ArrayList<>(); - List> expectedTop = new ArrayList<>(); + List> rawValues = new ArrayList<>(); IntBlock keys = blockFactory.newIntArrayVector(IntStream.range(0, rows).toArray(), rows).asBlock(); - List topKeys = new ArrayList<>(IntStream.range(rows - topCount, rows).boxed().toList()); - Collections.reverse(topKeys); - expectedTop.add(topKeys); blocks.add(keys); + rawValues.add(IntStream.range(0, rows).mapToObj(Integer::valueOf).toList()); List elementTypes = new ArrayList<>(blocksCount); List encoders = new ArrayList<>(blocksCount); 
@@ -1034,57 +1115,50 @@ public void testCollectAllValues_RandomMultiValues() { } elementTypes.add(e); encoders.add(nonKeyEncoder(e)); - List eTop = new ArrayList<>(); + List channelValues = new ArrayList<>(); try (Block.Builder builder = e.newBlockBuilder(rows, driverContext().blockFactory())) { for (int i = 0; i < rows; i++) { if (e != ElementType.DOC && e != ElementType.NULL && randomBoolean()) { // generate a multi-value block int mvCount = randomIntBetween(5, 10); - List eTopList = new ArrayList<>(mvCount); + List mvValues = new ArrayList<>(mvCount); builder.beginPositionEntry(); for (int j = 0; j < mvCount; j++) { Object value = randomValue(e); append(builder, value); - if (i >= rows - topCount) { - eTopList.add(value); - } + mvValues.add(value); } builder.endPositionEntry(); - if (i >= rows - topCount) { - eTop.add(eTopList); - } + channelValues.add(mvValues); } else { Object value = randomValue(e); append(builder, value); - if (i >= rows - topCount) { - eTop.add(value); - } + channelValues.add(value); } } - Collections.reverse(eTop); blocks.add(builder.build()); - expectedTop.add(eTop); + rawValues.add(channelValues); } } List> actualTop = new ArrayList<>(); + List sortOrders = List.of(new TopNOperator.SortOrder(0, false, false)); try ( Driver driver = TestDriverFactory.create( driverContext, new CannedSourceOperator(List.of(new Page(blocks.toArray(Block[]::new))).iterator()), List.of( - new TopNOperator( - blockFactory, - nonBreakingBigArrays().breakerService().getBreaker("request"), + createTopNOperatorFactory( topCount, elementTypes, encoders, - List.of(new TopNOperator.SortOrder(0, false, false)), - pageSize, + sortOrders, + groupKeys(), + randomPageSize(), randomJumboPageBytes(), InputOrdering.NOT_SORTED, null - ) + ).get(driverContext) ), new PageConsumerOperator(page -> readInto(actualTop, page)) ) @@ -1092,7 +1166,7 @@ public void testCollectAllValues_RandomMultiValues() { new TestDriverRunner().run(driver); } - assertMap(actualTop, 
matchesList(expectedTop)); + assertMap(actualTop, matchesList(expectedTop(rawValues, sortOrders, topCount))); assertDriverContext(driverContext); } @@ -1104,37 +1178,36 @@ private static TopNEncoder nonKeyEncoder(ElementType elementType) { }; } - private List> topNTwoLongColumns( + protected List> topNTwoLongColumns( DriverContext driverContext, List> values, int limit, List encoder, - List sortOrders + List sortOrders, + int[] groupKeys ) { - var pages = topNTwoColumns( + var pages = topNMultipleColumns( driverContext, new TupleLongLongBlockSourceOperator(driverContext.blockFactory(), values, randomIntBetween(1, 1000)), - AlwaysReferencedIndexedByShardId.INSTANCE, limit, encoder, - sortOrders + sortOrders, + groupKeys ); - var result = pageToTuples( + return pageToTuples( (block, i) -> block.isNull(i) ? null : ((LongBlock) block).getLong(i), (block, i) -> block.isNull(i) ? null : ((LongBlock) block).getLong(i), pages ); - assertThat(result, hasSize(Math.min(limit, values.size()))); - return result; } - private List topNTwoColumns( + protected List topNMultipleColumns( DriverContext driverContext, - TupleAbstractBlockSourceOperator sourceOperator, - IndexedByShardId shardRefCounters, + TypedAbstractBlockSourceBuilder sourceOperator, int limit, List encoder, - List sortOrders + List sortOrders, + int[] groupKeys ) { var pages = new ArrayList(); boolean success = false; @@ -1144,18 +1217,17 @@ private List topNTwoColumns( driverContext, sourceOperator, List.of( - new TopNOperator( - driverContext.blockFactory(), - nonBreakingBigArrays().breakerService().getBreaker("request"), + createTopNOperatorFactory( limit, sourceOperator.elementTypes(), encoder, sortOrders, - pageSize, + groupKeys, + randomPageSize(), randomJumboPageBytes(), InputOrdering.NOT_SORTED, null - ) + ).get(driverContext) ), new PageConsumerOperator(pages::add) ) @@ -1172,7 +1244,7 @@ private List topNTwoColumns( return pages; } - private static List> pageToTuples( + protected static List> 
pageToTuples( BiFunction getFirstBlockValue, BiFunction getSecondBlockValue, List pages @@ -1197,12 +1269,13 @@ private static List> pageToTuples( public void testTopNManyDescriptionAndToString() { int fixedLength = between(1, 100); - TopNOperator.TopNOperatorFactory factory = new TopNOperator.TopNOperatorFactory( + Operator.OperatorFactory factory = createTopNOperatorFactory( 10, List.of(BYTES_REF, BYTES_REF), List.of(UTF8, new FixedLengthAscTopNEncoder(fixedLength)), List.of(new TopNOperator.SortOrder(0, false, false), new TopNOperator.SortOrder(1, false, true)), - pageSize, + groupKeys(), + randomPageSize(), randomJumboPageBytes(), InputOrdering.NOT_SORTED, null @@ -1210,14 +1283,28 @@ public void testTopNManyDescriptionAndToString() { String sorts = List.of("SortOrder[channel=0, asc=false, nullsFirst=false]", "SortOrder[channel=1, asc=false, nullsFirst=true]") .stream() .collect(Collectors.joining(", ")); - String tail = ", elementTypes=[BYTES_REF, BYTES_REF], encoders=[Utf8Asc, FixedLengthAsc[" - + fixedLength - + "]], sortOrders=[" - + sorts - + "], inputOrdering=NOT_SORTED]"; - assertThat(factory.describe(), equalTo("TopNOperator[count=10" + tail)); - try (Operator operator = factory.get(driverContext())) { - assertThat(operator.toString(), equalTo("TopNOperator[count=0/10" + tail)); + if (groupKeys().length > 0) { + String tail = ", elementTypes=[BYTES_REF, BYTES_REF], encoders=[Utf8Asc, FixedLengthAsc[" + + fixedLength + + "]], sortOrders=[" + + sorts + + "], groupKeys=" + + Arrays.toString(groupKeys()) + + "]"; + assertThat(factory.describe(), equalTo("GroupedTopNOperator[count=10" + tail)); + try (Operator operator = factory.get(driverContext())) { + assertThat(operator.toString(), equalTo("GroupedTopNOperator[count=0/0/10" + tail)); + } + } else { + String tail = ", elementTypes=[BYTES_REF, BYTES_REF], encoders=[Utf8Asc, FixedLengthAsc[" + + fixedLength + + "]], sortOrders=[" + + sorts + + "], inputOrdering=NOT_SORTED]"; + assertThat(factory.describe(), 
equalTo("TopNOperator[count=10" + tail)); + try (Operator operator = factory.get(driverContext())) { + assertThat(operator.toString(), equalTo("TopNOperator[count=0/10" + tail)); + } } } @@ -1469,12 +1556,10 @@ public void testRandomMultiValuesTopN() { int blocksCount = between(20, 30); int sortingByColumns = between(1, 10); - Set uniqueOrders = new LinkedHashSet<>(sortingByColumns); - List>> expectedValues = new ArrayList<>(rowsPerPage * pageCount); - List> randomValueSuppliers = new ArrayList<>(blocksCount); boolean[] validSortKeys = new boolean[blocksCount]; List elementTypes = new ArrayList<>(blocksCount); List encoders = new ArrayList<>(blocksCount); + List> randomValueSuppliers = new ArrayList<>(blocksCount); for (int type = 0; type < blocksCount; type++) { ElementType e = randomValueOtherThanMany( @@ -1520,6 +1605,13 @@ public void testRandomMultiValuesTopN() { randomValueSuppliers.add(randomValueSupplier); } + int[] gk = groupKeys(); + var layout = ChannelLayout.forDataColumns(blocksCount, gk); + if (gk.length > 0) { + layout.insertGroupKeyEntries(elementTypes, INT); + layout.insertGroupKeyEntries(encoders, DEFAULT_UNSORTABLE); + } + /* * Build sort keys, making sure not to include duplicates. This could * build fewer than the desired sort columns, but it's more important @@ -1528,30 +1620,27 @@ public void testRandomMultiValuesTopN() { * not to include sort keys that simulate geo objects. Those aren't * sortable at all. 
*/ + Set uniqueOrders = new LinkedHashSet<>(sortingByColumns); for (int i = 0; i < sortingByColumns; i++) { int column = randomValueOtherThanMany(c -> false == validSortKeys[c], () -> randomIntBetween(0, blocksCount - 1)); - uniqueOrders.add(new TopNOperator.SortOrder(column, randomBoolean(), randomBoolean())); + uniqueOrders.add(new TopNOperator.SortOrder(layout.dataChannel(column), randomBoolean(), randomBoolean())); } List sortOrders = uniqueOrders.stream().toList(); - NaiveTopNComparator comparator = new NaiveTopNComparator(sortOrders); - SharedMinCompetitive.Supplier minCompetitiveSupplier = randomBoolean() - ? new SharedMinCompetitive.Supplier(blockFactory().breaker(), keyConfigs(elementTypes, encoders, sortOrders)) - : null; - SharedMinCompetitive minCompetitive = minCompetitiveSupplier == null ? null : minCompetitiveSupplier.get(); + List>> expectedValues = new ArrayList<>(rowsPerPage * pageCount); + List>> actualValues = new ArrayList<>(); try ( - TopNOperator operator = new TopNOperator( - driverContext.blockFactory(), - nonBreakingBigArrays().breakerService().getBreaker("request"), + Operator operator = createTopNOperatorFactory( topCount, elementTypes, encoders, sortOrders, + gk, rowsPerPage, Long.MAX_VALUE, InputOrdering.NOT_SORTED, - minCompetitiveSupplier - ) + null + ).get(driverContext) ) { for (int p = 0; p < pageCount; p++) { assertThat(operator.needsInput(), equalTo(true)); @@ -1559,40 +1648,30 @@ public void testRandomMultiValuesTopN() { assertThat(operator.getOutput(), nullValue()); for (int r = 0; r < rowsPerPage; r++) { - expectedValues.add(new ArrayList<>(blocksCount)); + expectedValues.add(new ArrayList<>(blocksCount + gk.length)); } - Block[] blocks = new Block[blocksCount]; + Block[] dataBlocks = new Block[blocksCount]; for (int b = 0; b < blocksCount; b++) { - ElementType elementType = elementTypes.get(b); + ElementType elementType = elementTypes.get(layout.dataChannel(b)); try (Block.Builder builder = 
elementType.newBlockBuilder(rowsPerPage, driverContext().blockFactory())) { List previousValue = null; - for (int r = 0; r < rowsPerPage; r++) { List values = new ArrayList<>(); - // let's make things a bit more real for this TopN sorting: have some "equal" values in different rows for the - // same - // block if (rarely() && previousValue != null) { values = previousValue; } else { if (elementType != ElementType.NULL && randomBoolean()) { - // generate a multi-value block int mvCount = randomIntBetween(5, 10); for (int j = 0; j < mvCount; j++) { - Object value = randomValueSuppliers.get(b).get(); - values.add(value); + values.add(randomValueSuppliers.get(b).get()); } - } else {// null or single-valued value - Object value = randomValueSuppliers.get(b).get(); - values.add(value); + } else { + values.add(randomValueSuppliers.get(b).get()); } - if (usually() && randomBoolean()) { - // let's remember the "previous" value, maybe we'll use it again in a different row previousValue = values; } } - if (values.size() == 1) { append(builder, values.get(0)); } else { @@ -1602,59 +1681,45 @@ public void testRandomMultiValuesTopN() { } builder.endPositionEntry(); } - expectedValues.get(p * rowsPerPage + r).add(values); } - blocks[b] = builder.build(); + dataBlocks[b] = builder.build(); } } - operator.addInput(new Page(blocks)); - - if (minCompetitive != null) { - if ((p + 1) * rowsPerPage < topCount) { - assertThat(minCompetitive.get(blockFactory()), nullValue()); - } else { - List> minCompetitiveRow = expectedValues.stream() - .sorted(comparator) - .skip(topCount - 1) - .findFirst() - .get(); - try (Page min = minCompetitive.get(blockFactory())) { - assertThat(min.getBlockCount(), equalTo(sortOrders.size())); - for (int s = 0; s < min.getBlockCount(); s++) { - logger.info("checking key {}", s); - TopNOperator.SortOrder sort = sortOrders.get(s); - Object actual = BlockUtils.toJavaObject(min.getBlock(s), 0); - Object expected = reduceKey(minCompetitiveRow.get(sort.channel()), 
sort.asc()); - assertThat(actual, equalTo(expected)); - } - } + if (gk.length > 0) { + for (int r = 0; r < rowsPerPage; r++) { + layout.insertGroupKeyEntries(expectedValues.get(p * rowsPerPage + r), List.of(0)); } } + Block[] pageBlocks = gk.length > 0 + ? layout.buildPageBlocks(rowsPerPage, driverContext.blockFactory(), dataBlocks) + : dataBlocks; + operator.addInput(new Page(pageBlocks)); } operator.finish(); assertThat(operator.needsInput(), equalTo(false)); assertThat(operator.isFinished(), equalTo(false)); - - List>> actualValues = new ArrayList<>(); while (operator.isFinished() == false) { - try (Page p = operator.getOutput()) { + Page p = operator.getOutput(); + assertThat(operator.needsInput(), equalTo(false)); + if (p != null) { readAsRows(actualValues, p); - assertThat(operator.needsInput(), equalTo(false)); + p.releaseBlocks(); } } + } - List>> topNExpectedValues = expectedValues.stream().sorted(comparator).limit(topCount).toList(); - List> actualReducedValues = extractAndReduceSortedValues(actualValues, uniqueOrders); - List> expectedReducedValues = extractAndReduceSortedValues(topNExpectedValues, uniqueOrders); + List>> topNExpectedValues = expectedValues.stream() + .sorted(new NaiveTopNComparator(sortOrders)) + .limit(topCount) + .toList(); + List> actualReducedValues = extractAndReduceSortedValues(actualValues, uniqueOrders); + List> expectedReducedValues = extractAndReduceSortedValues(topNExpectedValues, uniqueOrders); - assertMap(actualReducedValues, matchesList(expectedReducedValues)); - } finally { - Releasables.close(minCompetitive); - } + assertMap(actualReducedValues, matchesList(expectedReducedValues)); } - private List keyConfigs( + protected List keyConfigs( List elementTypes, List encoders, List sortOrders @@ -2043,56 +2108,8 @@ public void testRowResizes() { } } - public void testShardContextManagement_limitEqualToCount_noShardContextIsReleased() { - topNShardContextManagementAux(4, Stream.generate(() -> true).limit(4).toList()); - } - - 
public void testShardContextManagement_notAllShardsPassTopN_shardsAreReleased() { - topNShardContextManagementAux(2, List.of(true, false, false, true)); - } - - private void topNShardContextManagementAux(int limit, List expectedOpenAfterTopN) { - List> values = Arrays.asList( - tuple(new BlockUtils.Doc(0, 10, 100), 1L), - tuple(new BlockUtils.Doc(1, 20, 200), 2L), - tuple(new BlockUtils.Doc(2, 30, 300), null), - tuple(new BlockUtils.Doc(3, 40, 400), -3L) - ); - - List refCountedList = Stream.generate(() -> new SimpleRefCounted()).limit(4).toList(); - var shardRefCounters = new IndexedByShardIdFromList<>(refCountedList); - var pages = topNTwoColumns(driverContext(), new TupleDocLongBlockSourceOperator(driverContext().blockFactory(), values) { - @Override - protected Block.Builder firstElementBlockBuilder(int length) { - return DocBlock.newBlockBuilder(blockFactory, length).shardRefCounters(shardRefCounters); - } - }, - shardRefCounters, - limit, - List.of(new DocVectorEncoder(shardRefCounters), DEFAULT_UNSORTABLE), - List.of(new TopNOperator.SortOrder(1, true, false)) - ); - refCountedList.forEach(RefCounted::decRef); - - assertThat(refCountedList.stream().map(RefCounted::hasReferences).toList(), equalTo(expectedOpenAfterTopN)); - - var expectedValues = values.stream() - .sorted(Comparator.comparingLong(t -> t.v2() == null ? 
Long.MAX_VALUE : t.v2())) - .limit(limit) - .toList(); - assertThat( - pageToTuples((b, i) -> (BlockUtils.Doc) BlockUtils.toJavaObject(b, i), (b, i) -> ((LongBlock) b).getLong(i), pages), - equalTo(expectedValues) - ); - Releasables.close(pages); - - for (var rc : refCountedList) { - assertFalse(rc.hasReferences()); - } - } - @SuppressWarnings({ "unchecked", "rawtypes" }) - private static void readAsRows(List>> values, Page page) { + protected static void readAsRows(List>> values, Page page) { if (page.getBlockCount() == 0) { fail("No blocks returned!"); } @@ -2116,6 +2133,13 @@ private static void readAsRows(List>> values, Page page) { } } + protected static List> readAsRowsSingleValue(Page page) { + assertThat(page.getBlockCount(), greaterThan(0)); + return IntStream.range(0, page.getPositionCount()) + .mapToObj(position -> IntStream.range(0, page.getBlockCount()).mapToObj(i -> toJavaObject(page.getBlock(i), position)).toList()) + .toList(); + } + public void testSplitOnSize() { int topCount = 50; long jumboPageBytes = randomJumboPageBytes(); @@ -2142,19 +2166,20 @@ public void testSplitOnSize() { */ maxPageSize += PageCacheRecycler.PAGE_SIZE_IN_BYTES; } + int[] gk = groupKeys(); + int expectedTotal = gk.length > 0 ? 
inputPageRows * inputPageCount : topCount; try ( - TopNOperator op = new TopNOperator( - driverContext().blockFactory(), - factory.breaker(), + Operator op = createTopNOperatorFactory( topCount, List.of(BYTES_REF), List.of(UTF8), List.of(new TopNOperator.SortOrder(0, randomBoolean(), randomBoolean())), + gk, Integer.MAX_VALUE, jumboPageBytes, InputOrdering.NOT_SORTED, null - ) + ).get(driverContext()) ) { for (int p = 0; p < inputPageCount; p++) { try (BytesRefBlock.Builder bytes = factory.newBytesRefBlockBuilder(inputPageRows)) { @@ -2169,14 +2194,14 @@ public void testSplitOnSize() { while (op.isFinished() == false) { try (Page out = op.getOutput()) { totalPositions += out.getPositionCount(); - if (totalPositions < topCount) { + if (totalPositions < expectedTotal) { assertThat(out.ramBytesUsedByBlocks(), both(greaterThanOrEqualTo(minPageSize)).and(lessThanOrEqualTo(maxPageSize))); } else { assertThat(out.ramBytesUsedByBlocks(), lessThanOrEqualTo(maxPageSize)); } } } - assertThat(totalPositions, equalTo(topCount)); + assertThat(totalPositions, equalTo(expectedTotal)); } } @@ -2191,7 +2216,7 @@ public void testSplitOnSize() { * but it's as close as possible to a very general and fully randomized unit test for TopNOperator with multi-values support. 
*/ @SuppressWarnings({ "unchecked", "rawtypes" }) - private List> extractAndReduceSortedValues(List>> rows, Set orders) { + protected List> extractAndReduceSortedValues(List>> rows, Set orders) { List> result = new ArrayList<>(rows.size()); for (List> row : rows) { @@ -2218,7 +2243,7 @@ static Object minMax(List values, boolean asc) { : values.stream().map(element -> (Comparable) element).max(naturalOrder()).get(); } - private class NaiveTopNComparator implements Comparator>> { + protected class NaiveTopNComparator implements Comparator>> { private final List orders; NaiveTopNComparator(List orders) { @@ -2254,6 +2279,562 @@ static Version randomVersion() { return new Version(randomFrom(VERSIONS)); } + protected static class RandomMultiValueBlocksResult { + final List>> expectedValues; + final List blocks; + final boolean[] validSortKeys; + final List elementTypes; + final List encoders; + + RandomMultiValueBlocksResult( + List>> expectedValues, + List blocks, + boolean[] validSortKeys, + List elementTypes, + List encoders + ) { + this.expectedValues = expectedValues; + this.blocks = blocks; + this.validSortKeys = validSortKeys; + this.elementTypes = elementTypes; + this.encoders = encoders; + } + } + + protected RandomMultiValueBlocksResult generateRandomMultiValueBlocks(int rows, int blocksCount, DriverContext driverContext) { + List>> expectedValues = new ArrayList<>(rows); + List blocks = new ArrayList<>(blocksCount); + boolean[] validSortKeys = new boolean[blocksCount]; + List elementTypes = new ArrayList<>(blocksCount); + List encoders = new ArrayList<>(blocksCount); + + for (int i = 0; i < rows; i++) { + expectedValues.add(new ArrayList<>(blocksCount)); + } + + for (int type = 0; type < blocksCount; type++) { + ElementType e = randomValueOtherThanMany( + t -> t == ElementType.UNKNOWN + || t == ElementType.DOC + || t == COMPOSITE + || t == AGGREGATE_METRIC_DOUBLE + || t == EXPONENTIAL_HISTOGRAM + || t == TDIGEST + || t == LONG_RANGE, + () -> 
randomFrom(ElementType.values()) + ); + elementTypes.add(e); + validSortKeys[type] = true; + try (Block.Builder builder = e.newBlockBuilder(rows, driverContext.blockFactory())) { + List previousValue = null; + Function randomValueSupplier = (blockType) -> randomValue(blockType); + if (e == BYTES_REF) { + if (rarely()) { + randomValueSupplier = switch (randomInt(2)) { + case 0 -> { + encoders.add(TopNEncoder.IP); + yield (blockType) -> new BytesRef(InetAddressPoint.encode(randomIp(randomBoolean()))); + } + case 1 -> { + encoders.add(TopNEncoder.VERSION); + yield (blockType) -> randomVersion().toBytesRef(); + } + case 2 -> { + encoders.add(DEFAULT_UNSORTABLE); + validSortKeys[type] = false; + yield (blockType) -> randomPointAsWKB(); + } + default -> throw new UnsupportedOperationException(); + }; + } else { + encoders.add(UTF8); + } + } else { + encoders.add(DEFAULT_SORTABLE); + } + + for (int i = 0; i < rows; i++) { + List values = new ArrayList<>(); + if (rarely() && previousValue != null) { + values = previousValue; + } else { + if (e != ElementType.NULL && randomBoolean()) { + int mvCount = randomIntBetween(5, 10); + for (int j = 0; j < mvCount; j++) { + Object value = randomValueSupplier.apply(e); + values.add(value); + } + } else { + Object value = randomValueSupplier.apply(e); + values.add(value); + } + + if (usually() && randomBoolean()) { + previousValue = values; + } + } + + if (values.size() == 1) { + append(builder, values.get(0)); + } else { + builder.beginPositionEntry(); + for (Object o : values) { + append(builder, o); + } + builder.endPositionEntry(); + } + + expectedValues.get(i).add(values); + } + blocks.add(builder.build()); + } + } + + return new RandomMultiValueBlocksResult(expectedValues, blocks, validSortKeys, elementTypes, encoders); + } + + protected Set generateSortOrders( + int sortingByColumns, + int blocksCount, + boolean[] validSortKeys, + java.util.function.Predicate excludeColumn + ) { + Set uniqueOrders = new 
LinkedHashSet<>(sortingByColumns); + for (int i = 0; i < sortingByColumns; i++) { + int column = randomValueOtherThanMany( + c -> false == validSortKeys[c] || excludeColumn.test(c), + () -> randomIntBetween(0, blocksCount - 1) + ); + uniqueOrders.add(new TopNOperator.SortOrder(column, randomBoolean(), randomBoolean())); + } + return uniqueOrders; + } + + protected static class RandomBlocksResult { + final List> expectedValues; + final List blocks; + final boolean[] validSortKeys; + final List elementTypes; + final List encoders; + + RandomBlocksResult( + List> expectedValues, + List blocks, + boolean[] validSortKeys, + List elementTypes, + List encoders + ) { + this.expectedValues = expectedValues; + this.blocks = blocks; + this.validSortKeys = validSortKeys; + this.elementTypes = elementTypes; + this.encoders = encoders; + } + } + + protected RandomBlocksResult generateRandomSingleValueBlocks(int rows, int blocksCount, DriverContext driverContext) { + List> expectedValues = new ArrayList<>(rows); + List blocks = new ArrayList<>(blocksCount); + boolean[] validSortKeys = new boolean[blocksCount]; + List elementTypes = new ArrayList<>(blocksCount); + List encoders = new ArrayList<>(blocksCount); + + for (int i = 0; i < rows; i++) { + expectedValues.add(new ArrayList<>(blocksCount)); + } + + for (int type = 0; type < blocksCount; type++) { + ElementType e = randomValueOtherThanMany( + t -> t == ElementType.UNKNOWN + || t == ElementType.DOC + || t == COMPOSITE + || t == AGGREGATE_METRIC_DOUBLE + || t == EXPONENTIAL_HISTOGRAM + || t == TDIGEST + || t == LONG_RANGE, + () -> randomFrom(ElementType.values()) + ); + elementTypes.add(e); + validSortKeys[type] = true; + try (Block.Builder builder = e.newBlockBuilder(rows, driverContext.blockFactory())) { + Function randomValueSupplier = (blockType) -> randomValue(blockType); + if (e == BYTES_REF) { + if (rarely()) { + randomValueSupplier = switch (randomInt(2)) { + case 0 -> { + encoders.add(TopNEncoder.IP); + yield 
(blockType) -> new BytesRef(InetAddressPoint.encode(randomIp(randomBoolean()))); + } + case 1 -> { + encoders.add(TopNEncoder.VERSION); + yield (blockType) -> randomVersion().toBytesRef(); + } + case 2 -> { + encoders.add(DEFAULT_UNSORTABLE); + validSortKeys[type] = false; + yield (blockType) -> randomPointAsWKB(); + } + default -> throw new UnsupportedOperationException(); + }; + } else { + encoders.add(UTF8); + } + } else { + encoders.add(DEFAULT_SORTABLE); + } + + for (int i = 0; i < rows; i++) { + Object value = randomValueSupplier.apply(e); + append(builder, value); + expectedValues.get(i).add(value); + } + blocks.add(builder.build()); + } + } + + return new RandomBlocksResult(expectedValues, blocks, validSortKeys, elementTypes, encoders); + } + + /** + * Maps group key channels and data channels for tests that need to add constant-value + * group key columns alongside their data columns. When {@code groupKeys} is empty, the + * layout is an identity mapping (data channel i == channel i). + */ + protected record ChannelLayout(int[] groupKeys, Set groupKeySet, int[] dataChannels, int totalChannels) { + + static ChannelLayout forDataColumns(int dataColumnCount, int[] groupKeys) { + int totalChannels = dataColumnCount + groupKeys.length; + Set groupKeySet = IntStream.of(groupKeys).boxed().collect(Collectors.toSet()); + int[] dataChannels = IntStream.range(0, totalChannels).filter(ch -> groupKeySet.contains(ch) == false).toArray(); + return new ChannelLayout(groupKeys, groupKeySet, dataChannels, totalChannels); + } + + /** Returns the actual channel index for the given data column index. */ + int dataChannel(int dataIndex) { + return dataChannels[dataIndex]; + } + + /** + * Builds a {@code Block[]} with constant-0 INT blocks at the group key channels + * and the supplied data blocks placed at the corresponding data channels. + */ + Block[] buildPageBlocks(int positions, BlockFactory blockFactory, Block... 
dataBlocks) { + Block[] blockArray = new Block[totalChannels]; + for (int g : groupKeys) { + try (Block.Builder groupCol = INT.newBlockBuilder(positions, blockFactory)) { + for (int j = 0; j < positions; j++) { + append(groupCol, 0); + } + blockArray[g] = groupCol.build(); + } + } + for (int d = 0; d < dataBlocks.length; d++) { + blockArray[dataChannels[d]] = dataBlocks[d]; + } + return blockArray; + } + + /** Extracts only the data-channel columns from column-oriented results. */ + List> extractDataColumns(List> actual) { + if (groupKeys.length == 0 || actual.isEmpty()) { + return actual; + } + return IntStream.of(dataChannels).mapToObj(actual::get).toList(); + } + + /** Inserts group key entries into pre-built lists at the correct positions. */ + void insertGroupKeyEntries(List list, T value) { + int[] sortedGk = IntStream.of(groupKeys).sorted().toArray(); + for (int g : sortedGk) { + list.add(g, value); + } + } + } + + public void testRamBytesUsed() { + int topCount = 10_000; + long underCount = groupKeys().length > 0 ? 
1000 : 250; + DriverContext context = driverContext(); + int numColumns = 1 + groupKeys().length; + List elementTypes = Collections.nCopies(numColumns, LONG); + List encoders = Collections.nCopies(numColumns, DEFAULT_UNSORTABLE); + try ( + Operator op = createTopNOperatorFactory( + topCount, + elementTypes, + encoders, + List.of(new TopNOperator.SortOrder(0, true, false)), + groupKeys(), + pageSize, + Long.MAX_VALUE, + TopNOperator.InputOrdering.NOT_SORTED, + null + ).get(context) + ) { + Accountable accountable = (Accountable) op; + long actualEmpty = RamUsageTester.ramUsed(op, RAM_USAGE_ACCUMULATOR); + assertThat(accountable.ramBytesUsed(), both(greaterThan(actualEmpty - underCount)).and(lessThan(actualEmpty))); + for (Page p : CannedSourceOperator.collectPages(simpleInput(context.blockFactory(), topCount))) { + op.addInput(p); + } + long actualFull = RamUsageTester.ramUsed(op, RAM_USAGE_ACCUMULATOR); + assertThat(accountable.ramBytesUsed(), both(greaterThan(actualFull - underCount)).and(lessThan(actualFull))); + } + } + + public void testShardContextManagement_limitEqualToCount_noShardContextIsReleased() { + assumeTrue("only applies to ungrouped TopN", groupKeys().length == 0); + topNShardContextManagementAux(4, Stream.generate(() -> true).limit(4).toList()); + } + + public void testShardContextManagement_notAllShardsPassTopN_shardsAreReleased() { + assumeTrue("only applies to ungrouped TopN", groupKeys().length == 0); + topNShardContextManagementAux(2, List.of(true, false, false, true)); + } + + private void topNShardContextManagementAux(int limit, List expectedOpenAfterTopN) { + List> values = Arrays.asList( + tuple(new BlockUtils.Doc(0, 10, 100), 1L), + tuple(new BlockUtils.Doc(1, 20, 200), 2L), + tuple(new BlockUtils.Doc(2, 30, 300), null), + tuple(new BlockUtils.Doc(3, 40, 400), -3L) + ); + + List refCountedList = Stream.generate(() -> new SimpleRefCounted()).limit(4).toList(); + var shardRefCounters = new IndexedByShardIdFromList<>(refCountedList); + var 
pages = topNMultipleColumns(driverContext(), new TupleDocLongBlockSourceOperator(driverContext().blockFactory(), values) { + @Override + protected Block.Builder firstElementBlockBuilder(int length) { + return DocBlock.newBlockBuilder(blockFactory, length).shardRefCounters(shardRefCounters); + } + }, + limit, + List.of(new DocVectorEncoder(shardRefCounters), DEFAULT_UNSORTABLE), + List.of(new TopNOperator.SortOrder(1, true, false)), + groupKeys() + ); + refCountedList.forEach(RefCounted::decRef); + + assertThat(refCountedList.stream().map(RefCounted::hasReferences).toList(), equalTo(expectedOpenAfterTopN)); + + var expectedValues = values.stream() + .sorted(Comparator.comparingLong(t -> t.v2() == null ? Long.MAX_VALUE : t.v2())) + .limit(limit) + .toList(); + assertThat( + pageToTuples((b, i) -> (BlockUtils.Doc) BlockUtils.toJavaObject(b, i), (b, i) -> ((LongBlock) b).getLong(i), pages), + equalTo(expectedValues) + ); + Releasables.close(pages); + + for (var rc : refCountedList) { + assertFalse(rc.hasReferences()); + } + } + + public void testRandomMultiValuesTopNWithMinCompetitive() { + assumeTrue("only applies to ungrouped TopN", groupKeys().length == 0); + DriverContext driverContext = driverContext(); + int pageCount = between(1, 100); + int rowsPerPage = between(50, 100); + int topCount = between(1, rowsPerPage * pageCount); + int blocksCount = between(20, 30); + int sortingByColumns = between(1, 10); + + Set uniqueOrders = new LinkedHashSet<>(sortingByColumns); + List>> expectedValues = new ArrayList<>(rowsPerPage * pageCount); + List> randomValueSuppliers = new ArrayList<>(blocksCount); + boolean[] validSortKeys = new boolean[blocksCount]; + List elementTypes = new ArrayList<>(blocksCount); + List encoders = new ArrayList<>(blocksCount); + + for (int type = 0; type < blocksCount; type++) { + ElementType e = randomValueOtherThanMany( + t -> t == ElementType.UNKNOWN + || t == ElementType.DOC + || t == ElementType.COMPOSITE + || t == 
ElementType.AGGREGATE_METRIC_DOUBLE + || t == ElementType.EXPONENTIAL_HISTOGRAM + || t == ElementType.TDIGEST + || t == ElementType.LONG_RANGE, + () -> randomFrom(ElementType.values()) + ); + elementTypes.add(e); + validSortKeys[type] = true; + Supplier randomValueSupplier = () -> randomValue(e); + if (e == ElementType.BYTES_REF) { + if (rarely()) { + randomValueSupplier = switch (randomInt(2)) { + case 0 -> { + // Simulate ips + encoders.add(TopNEncoder.IP); + yield () -> new BytesRef(InetAddressPoint.encode(randomIp(randomBoolean()))); + } + case 1 -> { + // Simulate version fields + encoders.add(TopNEncoder.VERSION); + yield () -> randomVersion().toBytesRef(); + } + case 2 -> { + // Simulate geo_shape and geo_point + encoders.add(TopNEncoder.DEFAULT_UNSORTABLE); + validSortKeys[type] = false; + yield TopNEncoderTests::randomPointAsWKB; + } + default -> throw new UnsupportedOperationException(); + }; + } else { + encoders.add(TopNEncoder.UTF8); + } + } else { + encoders.add(TopNEncoder.DEFAULT_SORTABLE); + } + randomValueSuppliers.add(randomValueSupplier); + } + + for (int i = 0; i < sortingByColumns; i++) { + int column = randomValueOtherThanMany(c -> false == validSortKeys[c], () -> randomIntBetween(0, blocksCount - 1)); + uniqueOrders.add(new TopNOperator.SortOrder(column, randomBoolean(), randomBoolean())); + } + List sortOrders = uniqueOrders.stream().toList(); + NaiveTopNComparator comparator = new NaiveTopNComparator(sortOrders); + SharedMinCompetitive.Supplier minCompetitiveSupplier = randomBoolean() + ? new SharedMinCompetitive.Supplier(blockFactory().breaker(), keyConfigs(elementTypes, encoders, sortOrders)) + : null; + SharedMinCompetitive minCompetitive = minCompetitiveSupplier == null ? 
null : minCompetitiveSupplier.get(); + + try ( + TopNOperator operator = new TopNOperator( + driverContext.blockFactory(), + nonBreakingBigArrays().breakerService().getBreaker("request"), + topCount, + elementTypes, + encoders, + sortOrders, + rowsPerPage, + Long.MAX_VALUE, + TopNOperator.InputOrdering.NOT_SORTED, + minCompetitiveSupplier + ) + ) { + for (int p = 0; p < pageCount; p++) { + assertThat(operator.needsInput(), equalTo(true)); + assertThat(operator.isFinished(), equalTo(false)); + assertThat(operator.getOutput(), nullValue()); + + for (int r = 0; r < rowsPerPage; r++) { + expectedValues.add(new ArrayList<>(blocksCount)); + } + Block[] blocks = new Block[blocksCount]; + for (int b = 0; b < blocksCount; b++) { + ElementType elementType = elementTypes.get(b); + try (Block.Builder builder = elementType.newBlockBuilder(rowsPerPage, driverContext().blockFactory())) { + List previousValue = null; + for (int r = 0; r < rowsPerPage; r++) { + List values = new ArrayList<>(); + if (rarely() && previousValue != null) { + values = previousValue; + } else { + if (elementType != ElementType.NULL && randomBoolean()) { + int mvCount = randomIntBetween(5, 10); + for (int j = 0; j < mvCount; j++) { + values.add(randomValueSuppliers.get(b).get()); + } + } else { + values.add(randomValueSuppliers.get(b).get()); + } + if (usually() && randomBoolean()) { + previousValue = values; + } + } + if (values.size() == 1) { + append(builder, values.get(0)); + } else { + builder.beginPositionEntry(); + for (Object o : values) { + append(builder, o); + } + builder.endPositionEntry(); + } + expectedValues.get(p * rowsPerPage + r).add(values); + } + blocks[b] = builder.build(); + } + } + operator.addInput(new Page(blocks)); + + if (minCompetitive != null) { + if ((p + 1) * rowsPerPage < topCount) { + assertThat(minCompetitive.get(blockFactory()), nullValue()); + } else { + List> minCompetitiveRow = expectedValues.stream() + .sorted(comparator) + .skip(topCount - 1) + .findFirst() + 
.get(); + try (Page min = minCompetitive.get(blockFactory())) { + assertThat(min.getBlockCount(), equalTo(sortOrders.size())); + for (int s = 0; s < min.getBlockCount(); s++) { + logger.info("checking key {}", s); + TopNOperator.SortOrder sort = sortOrders.get(s); + Object actual = BlockUtils.toJavaObject(min.getBlock(s), 0); + Object expected = reduceKey(minCompetitiveRow.get(sort.channel()), sort.asc()); + assertThat(actual, equalTo(expected)); + } + } + } + } + } + operator.finish(); + assertThat(operator.needsInput(), equalTo(false)); + assertThat(operator.isFinished(), equalTo(false)); + + List>> actualValues = new ArrayList<>(); + while (operator.isFinished() == false) { + try (Page p = operator.getOutput()) { + readAsRows(actualValues, p); + assertThat(operator.needsInput(), equalTo(false)); + } + } + + List>> topNExpectedValues = expectedValues.stream().sorted(comparator).limit(topCount).toList(); + List> actualReducedValues = extractAndReduceSortedValues(actualValues, uniqueOrders); + List> expectedReducedValues = extractAndReduceSortedValues(topNExpectedValues, uniqueOrders); + + assertMap(actualReducedValues, matchesList(expectedReducedValues)); + } finally { + Releasables.close(minCompetitive); + } + } + + public void testStatus() { + BlockFactory blockFactory = driverContext().blockFactory(); + try (Operator op = simple(SimpleOptions.DEFAULT).get(driverContext())) { + Operator.Status status = op.status(); + assertThat(status, instanceOf(TopNOperatorStatus.class)); + TopNOperatorStatus topNStatus = (TopNOperatorStatus) status; + assertThat(topNStatus.occupiedRows(), equalTo(0)); + assertThat(topNStatus.ramBytesUsed(), greaterThan(0L)); + assertThat(topNStatus.pagesReceived(), equalTo(0)); + assertThat(topNStatus.pagesEmitted(), equalTo(0)); + assertThat(topNStatus.rowsReceived(), equalTo(0L)); + assertThat(topNStatus.rowsEmitted(), equalTo(0L)); + + Page p = new Page(blockFactory.newConstantLongBlockWith(1, 10)); + op.addInput(p); + status = 
op.status(); + topNStatus = (TopNOperatorStatus) status; + assertThat(topNStatus.receiveNanos(), greaterThan(0L)); + assertThat(topNStatus.emitNanos(), equalTo(0L)); + assertThat(topNStatus.occupiedRows(), equalTo(4)); + assertThat(topNStatus.ramBytesUsed(), greaterThan(0L)); + assertThat(topNStatus.pagesReceived(), equalTo(1)); + assertThat(topNStatus.pagesEmitted(), equalTo(0)); + assertThat(topNStatus.rowsReceived(), equalTo(10L)); + assertThat(topNStatus.rowsEmitted(), equalTo(0L)); + } + } + private long randomJumboPageBytes() { return rarely() ? between(1, 100) : ByteSizeValue.ofKb(between(1, 1000)).getBytes(); } diff --git a/x-pack/plugin/esql/compute/test/src/main/java/org/elasticsearch/compute/test/TestBlockBuilder.java b/x-pack/plugin/esql/compute/test/src/main/java/org/elasticsearch/compute/test/TestBlockBuilder.java index d06f8ea4ba44a..43612dc985c1f 100644 --- a/x-pack/plugin/esql/compute/test/src/main/java/org/elasticsearch/compute/test/TestBlockBuilder.java +++ b/x-pack/plugin/esql/compute/test/src/main/java/org/elasticsearch/compute/test/TestBlockBuilder.java @@ -10,13 +10,17 @@ import org.apache.lucene.util.BytesRef; import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.BlockUtils; import org.elasticsearch.compute.data.BooleanBlock; import org.elasticsearch.compute.data.BytesRefBlock; +import org.elasticsearch.compute.data.DocBlock; import org.elasticsearch.compute.data.DoubleBlock; import org.elasticsearch.compute.data.ElementType; import org.elasticsearch.compute.data.FloatBlock; import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.lucene.IndexedByShardId; +import org.elasticsearch.core.RefCounted; import java.util.List; @@ -441,4 +445,68 @@ public void close() { builder.close(); } } + + public static class DocBlockBuilder extends TestBlockBuilder { + private final DocBlock.Builder 
builder; + private final IndexedByShardId shardRefCounters; + + public DocBlockBuilder(BlockFactory blockFactory, IndexedByShardId shardRefCounters) { + this.shardRefCounters = shardRefCounters; + this.builder = DocBlock.newBlockBuilder(blockFactory, 0); + } + + @Override + public TestBlockBuilder appendObject(Object object) { + var doc = (BlockUtils.Doc) object; + builder.appendDoc(doc.doc()); + builder.appendSegment(doc.segment()); + builder.appendShard(doc.shard()); + return this; + } + + @Override + public TestBlockBuilder appendNull() { + builder.appendNull(); + return this; + } + + @Override + public TestBlockBuilder beginPositionEntry() { + builder.beginPositionEntry(); + return this; + } + + @Override + public TestBlockBuilder endPositionEntry() { + builder.endPositionEntry(); + return this; + } + + @Override + public Block.Builder copyFrom(Block block, int beginInclusive, int endExclusive) { + builder.copyFrom(block, beginInclusive, endExclusive); + return this; + } + + @Override + public Block.Builder mvOrdering(Block.MvOrdering mvOrdering) { + builder.mvOrdering(mvOrdering); + return this; + } + + @Override + public long estimatedBytes() { + return builder.estimatedBytes(); + } + + @Override + public Block build() { + return builder.shardRefCounters(shardRefCounters).build(); + } + + @Override + public void close() { + builder.close(); + } + } } diff --git a/x-pack/plugin/esql/compute/test/src/main/java/org/elasticsearch/compute/test/TypedAbstractBlockSourceBuilder.java b/x-pack/plugin/esql/compute/test/src/main/java/org/elasticsearch/compute/test/TypedAbstractBlockSourceBuilder.java new file mode 100644 index 0000000000000..42a022ce04a7b --- /dev/null +++ b/x-pack/plugin/esql/compute/test/src/main/java/org/elasticsearch/compute/test/TypedAbstractBlockSourceBuilder.java @@ -0,0 +1,22 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.test; + +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.test.operator.blocksource.AbstractBlockSourceOperator; + +import java.util.List; + +public abstract class TypedAbstractBlockSourceBuilder extends AbstractBlockSourceOperator { + protected TypedAbstractBlockSourceBuilder(BlockFactory blockFactory, int maxPagePositions) { + super(blockFactory, maxPagePositions); + } + + public abstract List elementTypes(); +} diff --git a/x-pack/plugin/esql/compute/test/src/main/java/org/elasticsearch/compute/test/operator/blocksource/ListRowsBlockSourceOperator.java b/x-pack/plugin/esql/compute/test/src/main/java/org/elasticsearch/compute/test/operator/blocksource/ListRowsBlockSourceOperator.java index 33de84953e690..eab815f4bf1df 100644 --- a/x-pack/plugin/esql/compute/test/src/main/java/org/elasticsearch/compute/test/operator/blocksource/ListRowsBlockSourceOperator.java +++ b/x-pack/plugin/esql/compute/test/src/main/java/org/elasticsearch/compute/test/operator/blocksource/ListRowsBlockSourceOperator.java @@ -12,6 +12,7 @@ import org.elasticsearch.compute.data.ElementType; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.test.TestBlockBuilder; +import org.elasticsearch.compute.test.TypedAbstractBlockSourceBuilder; import org.elasticsearch.core.Releasables; import java.util.List; @@ -22,7 +23,7 @@ /** * A source operator whose output is rows specified as a list {@link List} values. 
*/ -public class ListRowsBlockSourceOperator extends AbstractBlockSourceOperator { +public class ListRowsBlockSourceOperator extends TypedAbstractBlockSourceBuilder { private static final int DEFAULT_MAX_PAGE_POSITIONS = 8 * 1024; private final List types; @@ -39,7 +40,7 @@ protected Page createPage(int positionOffset, int length) { TestBlockBuilder[] blocks = new TestBlockBuilder[types.size()]; try { for (int b = 0; b < blocks.length; b++) { - blocks[b] = TestBlockBuilder.builderOf(blockFactory, types.get(b)); + blocks[b] = getTestBlockBuilder(b); } for (int i = 0; i < length; i++) { List row = rows.get(positionOffset + i); @@ -69,11 +70,16 @@ protected Page createPage(int positionOffset, int length) { } } + protected TestBlockBuilder getTestBlockBuilder(int b) { + return TestBlockBuilder.builderOf(blockFactory, types.get(b)); + } + @Override protected int remaining() { return rows.size() - currentPosition; } + @Override public List elementTypes() { return types; } diff --git a/x-pack/plugin/esql/compute/test/src/main/java/org/elasticsearch/compute/test/operator/blocksource/TupleAbstractBlockSourceOperator.java b/x-pack/plugin/esql/compute/test/src/main/java/org/elasticsearch/compute/test/operator/blocksource/TupleAbstractBlockSourceOperator.java index f15fec7b0e8d8..e13f73785a2bd 100644 --- a/x-pack/plugin/esql/compute/test/src/main/java/org/elasticsearch/compute/test/operator/blocksource/TupleAbstractBlockSourceOperator.java +++ b/x-pack/plugin/esql/compute/test/src/main/java/org/elasticsearch/compute/test/operator/blocksource/TupleAbstractBlockSourceOperator.java @@ -11,6 +11,7 @@ import org.elasticsearch.compute.data.BlockFactory; import org.elasticsearch.compute.data.ElementType; import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.test.TypedAbstractBlockSourceBuilder; import org.elasticsearch.core.Tuple; import java.util.List; @@ -19,7 +20,7 @@ * A source operator whose output is the given tuple values. 
This operator produces pages * with two Blocks. The returned pages preserve the order of values as given in the in initial list. */ -public abstract class TupleAbstractBlockSourceOperator extends AbstractBlockSourceOperator { +public abstract class TupleAbstractBlockSourceOperator extends TypedAbstractBlockSourceBuilder { private static final int DEFAULT_MAX_PAGE_POSITIONS = 8 * 1024; private final List> values; @@ -86,6 +87,7 @@ protected int remaining() { return values.size() - currentPosition; } + @Override public List elementTypes() { return List.of(firstElementType, secondElementType); } From 467ffbb79e4f9f9441dfefc5f1df5358d4fab4d2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Iv=C3=A1n=20Cea=20Fontenla?= Date: Tue, 3 Mar 2026 15:05:42 +0100 Subject: [PATCH 02/22] Format --- .../elasticsearch/compute/operator/topn/TopNOperatorTests.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/TopNOperatorTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/TopNOperatorTests.java index 181c581a98ba1..56602efa7e240 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/TopNOperatorTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/TopNOperatorTests.java @@ -103,10 +103,10 @@ import static org.elasticsearch.test.ListMatcher.matchesList; import static org.elasticsearch.test.MapMatcher.assertMap; import static org.hamcrest.Matchers.both; -import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.greaterThan; import static org.hamcrest.Matchers.greaterThanOrEqualTo; +import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.lessThan; import static org.hamcrest.Matchers.not; import static org.hamcrest.Matchers.nullValue; From 
75f84c76c10c355a81e341f442c97c619becd770 Mon Sep 17 00:00:00 2001 From: ncordon Date: Wed, 4 Mar 2026 12:36:00 +0100 Subject: [PATCH 03/22] Fixes failing tests --- .../elasticsearch/compute/operator/topn/GroupedQueueTests.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/GroupedQueueTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/GroupedQueueTests.java index bead85e16c895..91347f472e48b 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/GroupedQueueTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/GroupedQueueTests.java @@ -199,7 +199,7 @@ private GroupedRow createRow(CircuitBreaker breaker, int groupKey, int sortKey) return row; } - private static void assertRowValues(GroupedRow row, int expectedGroupKey, int expectedSortKey, int expectedValue) { + private static void assertRowValues(GroupedRow row, long expectedGroupKey, int expectedSortKey, int expectedValue) { assertThat(row.groupId, equalTo(expectedGroupKey)); BytesRef keys = row.keys().bytesRefView(); From ee4515144782cf23c09091d4c45d55a9ecf802cd Mon Sep 17 00:00:00 2001 From: ncordon Date: Wed, 4 Mar 2026 13:36:05 +0100 Subject: [PATCH 04/22] Shares more code --- .../compute/operator/topn/GroupedQueue.java | 108 +------- .../operator/topn/GroupedRowFiller.java | 90 ------ .../operator/topn/GroupedTopNOperator.java | 25 +- .../compute/operator/topn/TopNOperator.java | 262 +++--------------- .../compute/operator/topn/TopNQueue.java | 123 ++++++++ .../topn/{GroupedRow.java => TopNRow.java} | 105 ++++--- .../operator/topn/GroupedQueueTests.java | 65 +++-- .../operator/topn/GroupedRowTests.java | 26 +- .../operator/topn/TopNOperatorTests.java | 24 +- .../compute/operator/topn/TopNRowTests.java | 14 +- 10 files changed, 318 insertions(+), 524 deletions(-) delete mode 100644 
x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/GroupedRowFiller.java create mode 100644 x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/TopNQueue.java rename x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/{GroupedRow.java => TopNRow.java} (51%) diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/GroupedQueue.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/GroupedQueue.java index d0b0718bf5ceb..b1ee79748a0be 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/GroupedQueue.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/GroupedQueue.java @@ -8,8 +8,6 @@ package org.elasticsearch.compute.operator.topn; import org.apache.lucene.util.Accountable; -import org.apache.lucene.util.PriorityQueue; -import org.apache.lucene.util.RamUsageEstimator; import org.elasticsearch.common.breaker.CircuitBreaker; import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.common.util.ObjectArray; @@ -22,7 +20,7 @@ import static org.apache.lucene.util.RamUsageEstimator.shallowSizeOfInstance; /** - * A queue that maintains a separate per-group priority queue, indexed by integer group IDs + * A queue that maintains a separate {@link TopNQueue} per group, indexed by integer group IDs * assigned by a {@link org.elasticsearch.compute.aggregation.blockhash.BlockHash}. * Uses a {@link BigArrays}-backed {@link ObjectArray} for better performance and circuit * breaker integration. 
@@ -33,7 +31,7 @@ class GroupedQueue implements Accountable, Releasable { private final CircuitBreaker breaker; private final BigArrays bigArrays; private final int topCount; - private ObjectArray queues; + private ObjectArray queues; GroupedQueue(CircuitBreaker breaker, BigArrays bigArrays, int topCount) { this.breaker = breaker; @@ -50,7 +48,7 @@ public String toString() { int size() { int totalSize = 0; for (long i = 0; i < queues.size(); i++) { - PerGroupQueue queue = queues.get(i); + TopNQueue queue = queues.get(i); if (queue != null) { totalSize += queue.size(); } @@ -59,22 +57,22 @@ int size() { } /** - * Attempts to add the row to the appropriate per-group queue based on {@link GroupedRow#groupId}. + * Attempts to add the row to the per-group queue identified by {@code groupId}. * @return If the row was added and the queue was full, the evicted row. * If the row was added and it wasn't full, {@code null}. * If the row wasn't added, the input row. */ - GroupedRow addRow(GroupedRow row) { - return getOrCreateQueue(row.groupId).addRow(row); + TopNRow addRow(long groupId, TopNRow row) { + return getOrCreateQueue(groupId).addRow(row); } - private PerGroupQueue getOrCreateQueue(long groupId) { + private TopNQueue getOrCreateQueue(long groupId) { if (groupId >= queues.size()) { queues = bigArrays.grow(queues, groupId + 1); } - PerGroupQueue queue = queues.get(groupId); + TopNQueue queue = queues.get(groupId); if (queue == null) { - queue = PerGroupQueue.build(breaker, topCount); + queue = TopNQueue.build(breaker, topCount); queues.set(groupId, queue); } return queue; @@ -85,10 +83,10 @@ private PerGroupQueue getOrCreateQueue(long groupId) { * For an ascending order, the first element will be the min element (or last in the * priority queue), and vice versa. 
*/ - List popAll() { - List allRows = new ArrayList<>(size()); + List popAll() { + List allRows = new ArrayList<>(size()); for (long i = 0; i < queues.size(); i++) { - PerGroupQueue queue = queues.get(i); + TopNQueue queue = queues.get(i); if (queue != null) { queue.popAllInto(allRows); queue.close(); @@ -105,7 +103,7 @@ public long ramBytesUsed() { if (queues != null) { total += queues.ramBytesUsed(); for (long i = 0; i < queues.size(); i++) { - PerGroupQueue queue = queues.get(i); + TopNQueue queue = queues.get(i); if (queue != null) { total += queue.ramBytesUsed(); } @@ -119,7 +117,7 @@ public void close() { Releasables.close(() -> { if (queues != null) { for (long i = 0; i < queues.size(); i++) { - PerGroupQueue queue = queues.get(i); + TopNQueue queue = queues.get(i); if (queue != null) { queue.close(); queues.set(i, null); @@ -128,82 +126,4 @@ public void close() { } }, queues); } - - /** - * A single-group priority queue backed by Lucene's PriorityQueue. - */ - static final class PerGroupQueue extends PriorityQueue implements Accountable, Releasable { - private static final long SHALLOW_SIZE = shallowSizeOfInstance(PerGroupQueue.class); - - private final CircuitBreaker breaker; - private final int topCount; - - private PerGroupQueue(CircuitBreaker breaker, int topCount) { - super(topCount); - this.topCount = topCount; - this.breaker = breaker; - } - - static PerGroupQueue build(CircuitBreaker breaker, int topCount) { - breaker.addEstimateBytesAndMaybeBreak(sizeOf(topCount), "topn"); - return new PerGroupQueue(breaker, topCount); - } - - @Override - protected boolean lessThan(GroupedRow lhs, GroupedRow rhs) { - return lhs.compareTo(rhs) < 0; - } - - GroupedRow addRow(GroupedRow row) { - if (size() < topCount) { - add(row); - return null; - } else if (lessThan(top(), row)) { - GroupedRow evicted = top(); - updateTop(row); - return evicted; - } - return row; - } - - void popAllInto(List target) { - while (size() > 0) { - target.add(pop()); - } - } - - @Override 
- public long ramBytesUsed() { - long total = SHALLOW_SIZE; - total += RamUsageEstimator.alignObjectSize( - RamUsageEstimator.NUM_BYTES_ARRAY_HEADER + RamUsageEstimator.NUM_BYTES_OBJECT_REF * ((long) topCount + 1) - ); - for (GroupedRow r : this) { - total += r == null ? 0 : r.ramBytesUsed(); - } - return total; - } - - @Override - public void close() { - Releasables.close(() -> { - var heapArray = getHeapArray(); - for (int i = 0; i < heapArray.length; i++) { - GroupedRow row = (GroupedRow) heapArray[i]; - if (row != null) { - row.close(); - heapArray[i] = null; - } - } - }, () -> breaker.addWithoutBreaking(-sizeOf(topCount))); - } - - private static long sizeOf(int topCount) { - long total = SHALLOW_SIZE; - total += RamUsageEstimator.alignObjectSize( - RamUsageEstimator.NUM_BYTES_ARRAY_HEADER + RamUsageEstimator.NUM_BYTES_OBJECT_REF * (topCount + 1L) - ); - return total; - } - } } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/GroupedRowFiller.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/GroupedRowFiller.java deleted file mode 100644 index c8c53bbb156ac..0000000000000 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/GroupedRowFiller.java +++ /dev/null @@ -1,90 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.compute.operator.topn; - -import org.elasticsearch.compute.data.ElementType; -import org.elasticsearch.compute.data.Page; -import org.elasticsearch.compute.operator.BreakingBytesRefBuilder; - -import java.util.List; - -/** - * Fills {@link GroupedRow}s from page data for grouped top-N. Handles both sort-key encoding - * and value extraction. 
The group ID is set directly by the caller from the BlockHash callback. - */ -final class GroupedRowFiller { - private final ValueExtractor[] valueExtractors; - private final KeyExtractor[] sortKeyExtractors; - - private int keyPreAllocSize = 0; - private int valuePreAllocSize = 0; - - GroupedRowFiller( - List elementTypes, - List encoders, - List sortOrders, - boolean[] channelInKey, - Page page - ) { - valueExtractors = new ValueExtractor[page.getBlockCount()]; - for (int b = 0; b < valueExtractors.length; b++) { - valueExtractors[b] = ValueExtractor.extractorFor( - elementTypes.get(b), - encoders.get(b).toUnsortable(), - channelInKey[b], - page.getBlock(b) - ); - } - sortKeyExtractors = new KeyExtractor[sortOrders.size()]; - for (int k = 0; k < sortKeyExtractors.length; k++) { - TopNOperator.SortOrder so = sortOrders.get(k); - sortKeyExtractors[k] = KeyExtractor.extractorFor( - elementTypes.get(so.channel()), - encoders.get(so.channel()), - so.asc(), - so.nul(), - so.nonNul(), - page.getBlock(so.channel()) - ); - } - } - - int preAllocatedKeysSize() { - return keyPreAllocSize; - } - - int preAllocatedValueSize() { - return valuePreAllocSize; - } - - void writeSortKey(int position, GroupedRow row) { - for (KeyExtractor keyExtractor : sortKeyExtractors) { - keyExtractor.writeKey(row.keys(), position); - } - keyPreAllocSize = newPreAllocSize(row.keys(), keyPreAllocSize); - } - - void writeValues(int position, GroupedRow row) { - for (ValueExtractor e : valueExtractors) { - var refCounted = e.getRefCountedForShard(position); - if (refCounted != null) { - row.setShardRefCounted(refCounted); - } - e.writeValue(row.values(), position); - } - valuePreAllocSize = newPreAllocSize(row.values(), valuePreAllocSize); - } - - /** - * Pre-allocation size heuristic: use the larger of the current builder length and half - * the previous pre-alloc size, so the size decays after a single unusually large row. 
- */ - private static int newPreAllocSize(BreakingBytesRefBuilder builder, int sparePreAllocSize) { - return Math.max(builder.length(), sparePreAllocSize / 2); - } -} diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/GroupedTopNOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/GroupedTopNOperator.java index 4a69f5d337430..6e57b3c118285 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/GroupedTopNOperator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/GroupedTopNOperator.java @@ -108,7 +108,7 @@ public String describe() { private BytesRefHashTable keysHash; private GroupedQueue inputQueue; - private GroupedRow spare; + private TopNRow spare; private ReleasableIterator output; @@ -172,7 +172,7 @@ public void addInput(Page page) { if (this.topCount <= 0) { return; } - GroupedRowFiller rowFiller = new GroupedRowFiller(elementTypes, encoders, sortOrders, channelInKey, page); + TopNOperator.RowFiller rowFiller = new TopNOperator.RowFiller(elementTypes, encoders, sortOrders, channelInKey, page); for (int pos = 0; pos < page.getPositionCount(); pos++) { BytesRef key = keyEncoder.encode(page, pos); long hashOrd = keysHash.add(key); @@ -187,16 +187,15 @@ public void addInput(Page page) { } } - private void processRow(GroupedRowFiller rowFiller, int position, long groupId) { + private void processRow(TopNOperator.RowFiller rowFiller, int position, long groupId) { if (spare == null) { - spare = new GroupedRow(breaker, rowFiller.preAllocatedKeysSize(), rowFiller.preAllocatedValueSize()); + spare = new TopNRow(breaker, rowFiller.preAllocatedKeysSize(), rowFiller.preAllocatedValueSize()); } else { spare.clear(); } - spare.groupId = groupId; - rowFiller.writeSortKey(position, spare); + rowFiller.writeKey(position, spare); - var nextSpare = inputQueue.addRow(spare); + var nextSpare = 
inputQueue.addRow(groupId, spare); if (nextSpare != spare) { var insertedRow = spare; spare = nextSpare; @@ -309,7 +308,7 @@ private ReleasableIterator buildResult() { return ReleasableIterator.empty(); } - List rows = inputQueue.popAll(); + List rows = inputQueue.popAll(); inputQueue.close(); keysHash.close(); inputQueue = null; @@ -318,10 +317,10 @@ private ReleasableIterator buildResult() { } private class Result implements ReleasableIterator { - private final List rows; + private final List rows; private int r; - private Result(List rows) { + private Result(List rows) { this.rows = rows; } @@ -344,9 +343,9 @@ public Page next() { } int rEnd = r + size; while (r < rEnd) { - try (GroupedRow row = rows.set(r++, null)) { - readKeys(builders, row.keys().bytesRefView()); - readValues(builders, row.values().bytesRefView()); + try (TopNRow row = rows.set(r++, null)) { + readKeys(builders, row.keys.bytesRefView()); + readValues(builders, row.values.bytesRefView()); } if (totalSize(builders) > jumboPageBytes) { break; diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/TopNOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/TopNOperator.java index 2fd7231b7ee36..ecf4d618af1b4 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/TopNOperator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/TopNOperator.java @@ -9,27 +9,21 @@ import org.apache.lucene.util.Accountable; import org.apache.lucene.util.BytesRef; -import org.apache.lucene.util.PriorityQueue; import org.apache.lucene.util.RamUsageEstimator; import org.elasticsearch.common.breaker.CircuitBreaker; import org.elasticsearch.compute.data.BlockFactory; -import org.elasticsearch.compute.data.DocVector; import org.elasticsearch.compute.data.ElementType; import org.elasticsearch.compute.data.Page; import 
org.elasticsearch.compute.operator.BreakingBytesRefBuilder; import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.compute.operator.Operator; import org.elasticsearch.core.Nullable; -import org.elasticsearch.core.RefCounted; -import org.elasticsearch.core.Releasable; import org.elasticsearch.core.ReleasableIterator; import org.elasticsearch.core.Releasables; import java.util.ArrayList; -import java.util.Arrays; import java.util.Collections; import java.util.List; -import java.util.Objects; /** * An operator that sorts "rows" of values by encoding the values to sort on, as bytes (using BytesRef). Each data type is encoded @@ -52,123 +46,16 @@ public enum InputOrdering { } /** - * A single top "row". Implements {@link Comparable} and {@link Row#equals} comparing - * the sort keys. + * Fills {@link TopNRow}s from page data. Handles both sort-key encoding and value + * extraction, and tracks pre-allocation sizes for key and value buffers. */ - static final class Row implements Accountable, Comparable, Releasable { - private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(Row.class); - - private final CircuitBreaker breaker; - - /** - * The sort keys, encoded into bytes so we can sort by calling {@link Arrays#compareUnsigned}. - */ - final BreakingBytesRefBuilder keys; - - /** - * Values to reconstruct the row. Sort of. When we reconstruct the row we read - * from both the {@link #keys} and the {@link #values}. So this only contains - * what is required to reconstruct the row that isn't already stored in {@link #values}. - */ - final BreakingBytesRefBuilder values; - - /** - * Reference counter for the shard this row belongs to, used for rows containing a {@link DocVector} to ensure that the shard - * context before we build the final result. 
- */ - @Nullable - RefCounted shardRefCounter; - - Row(CircuitBreaker breaker, int preAllocatedKeysSize, int preAllocatedValueSize) { - breaker.addEstimateBytesAndMaybeBreak(SHALLOW_SIZE, "topn"); - this.breaker = breaker; - boolean success = false; - try { - keys = new BreakingBytesRefBuilder(breaker, "topn", preAllocatedKeysSize); - values = new BreakingBytesRefBuilder(breaker, "topn", preAllocatedValueSize); - success = true; - } finally { - if (success == false) { - close(); - } - } - } - - @Override - public long ramBytesUsed() { - return SHALLOW_SIZE + keys.ramBytesUsed() + values.ramBytesUsed(); - } - - @Override - public void close() { - clearRefCounters(); - Releasables.closeExpectNoException(() -> breaker.addWithoutBreaking(-SHALLOW_SIZE), keys, values); - } - - public void clearRefCounters() { - if (shardRefCounter != null) { - shardRefCounter.decRef(); - } - shardRefCounter = null; - } - - void setShardRefCounted(RefCounted shardRefCounted) { - if (this.shardRefCounter != null) { - this.shardRefCounter.decRef(); - } - this.shardRefCounter = shardRefCounted; - this.shardRefCounter.mustIncRef(); - } - - @Override - public int compareTo(Row rhs) { - // TODO if we fill the trailing bytes with 0 we could do a comparison on the entire array - // When Nik measured this it was marginally faster. But it's worth a bit of research. 
- return -keys.bytesRefView().compareTo(rhs.keys.bytesRefView()); - } - - @Override - public boolean equals(Object o) { - if (o == null || getClass() != o.getClass()) { - return false; - } - ; - Row row = (Row) o; - return keys.bytesRefView().equals(row.keys.bytesRefView()); - } - - @Override - public int hashCode() { - return Objects.hashCode(keys); - } - - @Override - public String toString() { - StringBuilder b = new StringBuilder("Row[key="); - b.append(keys.bytesRefView()); - b.append(", values="); - - if (values.length() < 100) { - b.append(values.bytesRefView()); - } else { - b.append('['); - assert values.bytesRefView().offset == 0; - for (int i = 0; i < 100; i++) { - if (i != 0) { - b.append(" "); - } - b.append(Integer.toHexString(values.bytesRefView().bytes[i] & 255)); - } - b.append("..."); - } - return b.append("]").toString(); - } - } - static final class RowFiller { private final ValueExtractor[] valueExtractors; private final KeyExtractor[] keyExtractors; + private int keyPreAllocSize = 0; + private int valuePreAllocSize = 0; + RowFiller( List elementTypes, List encoders, @@ -199,13 +86,22 @@ static final class RowFiller { } } - void writeKey(int position, Row row) { + int preAllocatedKeysSize() { + return keyPreAllocSize; + } + + int preAllocatedValueSize() { + return valuePreAllocSize; + } + + void writeKey(int position, TopNRow row) { for (KeyExtractor keyExtractor : keyExtractors) { keyExtractor.writeKey(row.keys, position); } + keyPreAllocSize = newPreAllocSize(row.keys, keyPreAllocSize); } - void writeValues(int position, Row destination) { + void writeValues(int position, TopNRow destination) { for (ValueExtractor e : valueExtractors) { var refCounted = e.getRefCountedForShard(position); if (refCounted != null) { @@ -213,6 +109,15 @@ void writeValues(int position, Row destination) { } e.writeValue(destination.values, position); } + valuePreAllocSize = newPreAllocSize(destination.values, valuePreAllocSize); + } + + /** + * Pre-allocation size 
heuristic: use the larger of the current builder length and half + * the previous pre-alloc size, so the size decays after a single unusually large row. + */ + private static int newPreAllocSize(BreakingBytesRefBuilder builder, int sparePreAllocSize) { + return Math.max(builder.length(), sparePreAllocSize / 2); } } @@ -308,10 +213,8 @@ public String describe() { */ private int minCompetitiveUpdates; - private Queue inputQueue; - private Row spare; - private int spareValuesPreAllocSize = 0; - private int spareKeysPreAllocSize = 0; + private TopNQueue inputQueue; + private TopNRow spare; private ReleasableIterator output; @@ -352,11 +255,11 @@ public TopNOperator( InputOrdering inputOrdering, @Nullable SharedMinCompetitive.Supplier minCompetitiveSupplier ) { - Queue inputQueue = null; + TopNQueue inputQueue = null; SharedMinCompetitive minCompetitive = null; boolean success = false; try { - inputQueue = Queue.build(breaker, topCount); + inputQueue = TopNQueue.build(breaker, topCount); minCompetitive = minCompetitiveSupplier == null ? null : minCompetitiveSupplier.get(); success = true; } finally { @@ -407,32 +310,25 @@ public void addInput(Page page) { for (int i = 0; i < page.getPositionCount(); i++) { if (spare == null) { - spare = new Row(breaker, spareKeysPreAllocSize, spareValuesPreAllocSize); + spare = new TopNRow(breaker, rowFiller.preAllocatedKeysSize(), rowFiller.preAllocatedValueSize()); } else { - spare.keys.clear(); - spare.values.clear(); - spare.clearRefCounters(); + spare.clear(); } rowFiller.writeKey(i, spare); - // When rows are very long, appending the values one by one can lead to lots of allocations. - // To avoid this, pre-allocate at least as much size as in the last seen row. - // Let the pre-allocation size decay in case we only have 1 huge row and smaller rows otherwise. 
- spareKeysPreAllocSize = Math.max(spare.keys.length(), spareKeysPreAllocSize / 2); - - // This is `inputQueue.insertWithOverflow` with followed by filling in the value only if we inserted. + // This is `inputQueue.insertWithOverflow` followed by filling in the value only if we inserted. + // We must write values BEFORE modifying the queue so that if writeValues throws (e.g. circuit + // breaker), spare is not left in both the queue and the spare field (which would double-close). if (inputQueue.size() < inputQueue.topCount) { // Heap not yet full, just add elements rowFiller.writeValues(i, spare); - spareValuesPreAllocSize = Math.max(spare.values.length(), spareValuesPreAllocSize / 2); inputQueue.add(spare); spare = null; modified = true; } else if (inputQueue.lessThan(inputQueue.top(), spare)) { - // Heap full AND this node fit in it. - Row nextSpare = inputQueue.top(); + // Heap full AND this node fits in it. + TopNRow nextSpare = inputQueue.top(); rowFiller.writeValues(i, spare); - spareValuesPreAllocSize = Math.max(spare.values.length(), spareValuesPreAllocSize / 2); inputQueue.updateTop(spare); spare = nextSpare; modified = true; @@ -579,80 +475,6 @@ public String toString() { + "]"; } - private static class Queue extends PriorityQueue implements Accountable, Releasable { - private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(Queue.class); - private final CircuitBreaker breaker; - private final int topCount; - - /** - * Track memory usage in the breaker then build the {@link Queue}. 
- */ - static Queue build(CircuitBreaker breaker, int topCount) { - breaker.addEstimateBytesAndMaybeBreak(Queue.sizeOf(topCount), "esql engine topn"); - return new Queue(breaker, topCount); - } - - private Queue(CircuitBreaker breaker, int topCount) { - super(topCount); - this.breaker = breaker; - this.topCount = topCount; - } - - @Override - protected boolean lessThan(Row lhs, Row rhs) { - return lhs.compareTo(rhs) < 0; - } - - @Override - public String toString() { - return size() + "/" + topCount; - } - - @Override - public long ramBytesUsed() { - long total = SHALLOW_SIZE; - total += RamUsageEstimator.alignObjectSize( - RamUsageEstimator.NUM_BYTES_ARRAY_HEADER + RamUsageEstimator.NUM_BYTES_OBJECT_REF * ((long) topCount + 1) - ); - for (Row r : this) { - total += r == null ? 0 : r.ramBytesUsed(); - } - return total; - } - - @Override - public void close() { - Releasables.close( - /* - * Release all entries in the topn, nulling references to each row after closing them - * so they can be GC immediately. Without this nulling very large heaps can race with - * the circuit breaker itself. With this we're still racing, but we're only racing a - * single row at a time. And single rows can only be so large. And we have enough slop - * to live with being inaccurate by one row. - */ - () -> { - for (int i = 0; i < getHeapArray().length; i++) { - Row row = (Row) getHeapArray()[i]; - if (row != null) { - row.close(); - getHeapArray()[i] = null; - } - } - }, - // Release the array itself - () -> breaker.addWithoutBreaking(-Queue.sizeOf(topCount)) - ); - } - - private static long sizeOf(int topCount) { - long total = SHALLOW_SIZE; - total += RamUsageEstimator.alignObjectSize( - RamUsageEstimator.NUM_BYTES_ARRAY_HEADER + RamUsageEstimator.NUM_BYTES_OBJECT_REF * ((long) topCount + 1) - ); - return total; - } - } - /** * Build the result iterator. Moves all rows from the {@link #inputQueue} and * {@link #close}s it. 
@@ -668,10 +490,8 @@ private ReleasableIterator buildResult() { return ReleasableIterator.empty(); } - List rows = new ArrayList<>(inputQueue.size()); - while (inputQueue.size() > 0) { - rows.add(inputQueue.pop()); - } + List rows = new ArrayList<>(inputQueue.size()); + inputQueue.popAllInto(rows); Collections.reverse(rows); inputQueue.close(); inputQueue = null; @@ -679,10 +499,10 @@ private ReleasableIterator buildResult() { } private class Result implements ReleasableIterator { - private final List rows; + private final List rows; private int r; - private Result(List rows) { + private Result(List rows) { this.rows = rows; } @@ -705,7 +525,7 @@ public Page next() { } int rEnd = r + size; while (r < rEnd) { - try (Row row = rows.set(r++, null)) { + try (TopNRow row = rows.set(r++, null)) { readKeys(builders, row.keys.bytesRefView()); readValues(builders, row.values.bytesRefView()); } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/TopNQueue.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/TopNQueue.java new file mode 100644 index 0000000000000..8bcda7d16e2c8 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/TopNQueue.java @@ -0,0 +1,123 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.compute.operator.topn; + +import org.apache.lucene.util.Accountable; +import org.apache.lucene.util.PriorityQueue; +import org.apache.lucene.util.RamUsageEstimator; +import org.elasticsearch.common.breaker.CircuitBreaker; +import org.elasticsearch.core.Releasable; +import org.elasticsearch.core.Releasables; + +import java.util.List; + +/** + * A bounded min-heap of {@link TopNRow}s used to find the top-N rows by sort key. + * Used both by {@link TopNOperator} (one global queue) and by {@link GroupedQueue} + * (one queue per group). + */ +class TopNQueue extends PriorityQueue implements Accountable, Releasable { + private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(TopNQueue.class); + + private final CircuitBreaker breaker; + final int topCount; + + /** + * Track memory usage in the breaker then build the {@link TopNQueue}. + */ + static TopNQueue build(CircuitBreaker breaker, int topCount) { + breaker.addEstimateBytesAndMaybeBreak(sizeOf(topCount), "esql engine topn"); + return new TopNQueue(breaker, topCount); + } + + private TopNQueue(CircuitBreaker breaker, int topCount) { + super(topCount); + this.breaker = breaker; + this.topCount = topCount; + } + + @Override + protected boolean lessThan(TopNRow lhs, TopNRow rhs) { + return lhs.compareTo(rhs) < 0; + } + + /** + * Attempts to insert a row into the queue. + * @return {@code null} if the row was inserted into a non-full queue; + * the evicted row if the row replaced the current top; + * the input row itself if it was rejected (worse than all in the queue). + */ + TopNRow addRow(TopNRow row) { + if (size() < topCount) { + add(row); + return null; + } else if (lessThan(top(), row)) { + TopNRow evicted = top(); + updateTop(row); + return evicted; + } + return row; + } + + /** + * Drains all rows from this queue into the given list. 
+ */ + void popAllInto(List target) { + while (size() > 0) { + target.add(pop()); + } + } + + @Override + public String toString() { + return size() + "/" + topCount; + } + + @Override + public long ramBytesUsed() { + long total = SHALLOW_SIZE; + total += RamUsageEstimator.alignObjectSize( + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER + RamUsageEstimator.NUM_BYTES_OBJECT_REF * ((long) topCount + 1) + ); + for (TopNRow r : this) { + total += r == null ? 0 : r.ramBytesUsed(); + } + return total; + } + + @Override + public void close() { + Releasables.close( + /* + * Release all entries in the topn, nulling references to each row after closing them + * so they can be GC'd immediately. Without this nulling very large heaps can race with + * the circuit breaker itself. With this we're still racing, but we're only racing a + * single row at a time. And single rows can only be so large. And we have enough slop + * to live with being inaccurate by one row. + */ + () -> { + for (int i = 0; i < getHeapArray().length; i++) { + TopNRow row = (TopNRow) getHeapArray()[i]; + if (row != null) { + row.close(); + getHeapArray()[i] = null; + } + } + }, + () -> breaker.addWithoutBreaking(-sizeOf(topCount)) + ); + } + + static long sizeOf(int topCount) { + long total = SHALLOW_SIZE; + total += RamUsageEstimator.alignObjectSize( + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER + RamUsageEstimator.NUM_BYTES_OBJECT_REF * ((long) topCount + 1) + ); + return total; + } +} diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/GroupedRow.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/TopNRow.java similarity index 51% rename from x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/GroupedRow.java rename to x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/TopNRow.java index 8c139f21c00ed..779c6f9fafd7f 100644 --- 
a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/GroupedRow.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/TopNRow.java @@ -17,42 +17,38 @@ import org.elasticsearch.core.Releasables; import java.util.Arrays; +import java.util.Objects; /** - * A row that belongs to a group, identified by an integer group ID. - * Stores encoded sort keys and values for a single row within a grouped top-N operation. + * A single row in a top-N operation. Stores encoded sort keys and values. + * Implements {@link Comparable} and {@link #equals} comparing the sort keys. */ -final class GroupedRow implements Accountable, Comparable, Releasable { - private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(GroupedRow.class); +final class TopNRow implements Accountable, Comparable, Releasable { + private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(TopNRow.class); private final CircuitBreaker breaker; /** * The sort keys, encoded into bytes so we can sort by calling {@link Arrays#compareUnsigned}. */ - private final BreakingBytesRefBuilder keys; + final BreakingBytesRefBuilder keys; /** - * Values to reconstruct the row. When we reconstruct the row we read from both the - * {@link #keys} and the {@link #values}. So this only contains what is required to - * reconstruct the row that isn't already stored in {@link #keys}. + * Values to reconstruct the row. When we reconstruct the row we read + * from both the {@link #keys} and the {@link #values}. So this only contains + * what is required to reconstruct the row that isn't already stored in {@link #keys}. */ - private final BreakingBytesRefBuilder values; + final BreakingBytesRefBuilder values; /** - * Reference counter for the shard this row belongs to, used for rows containing a DocVector - * to ensure the shard context lives until we build the final result. 
+ * Reference counter for the shard this row belongs to, used for rows containing a + * DocVector to ensure the shard context lives until we build the final result. */ @Nullable - private RefCounted shardRefCounter; + RefCounted shardRefCounter; - /** - * The group ID this row belongs to. - */ - long groupId = -1; - - GroupedRow(CircuitBreaker breaker, int preAllocatedKeysSize, int preAllocatedValueSize) { - breaker.addEstimateBytesAndMaybeBreak(SHALLOW_SIZE, "GroupedRow"); + TopNRow(CircuitBreaker breaker, int preAllocatedKeysSize, int preAllocatedValueSize) { + breaker.addEstimateBytesAndMaybeBreak(SHALLOW_SIZE, "topn"); this.breaker = breaker; boolean success = false; try { @@ -66,49 +62,78 @@ final class GroupedRow implements Accountable, Comparable, Releasabl } } - BreakingBytesRefBuilder keys() { - return keys; - } - - BreakingBytesRefBuilder values() { - return values; + @Override + public long ramBytesUsed() { + return SHALLOW_SIZE + keys.ramBytesUsed() + values.ramBytesUsed(); } - void setShardRefCounted(RefCounted shardRefCounted) { - if (this.shardRefCounter != null) { - this.shardRefCounter.decRef(); - } - this.shardRefCounter = shardRefCounted; - this.shardRefCounter.mustIncRef(); + @Override + public void close() { + clearRefCounters(); + Releasables.closeExpectNoException(() -> breaker.addWithoutBreaking(-SHALLOW_SIZE), keys, values); } void clear() { keys.clear(); values.clear(); clearRefCounters(); - groupId = -1; } - private void clearRefCounters() { + void clearRefCounters() { if (shardRefCounter != null) { shardRefCounter.decRef(); } shardRefCounter = null; } + void setShardRefCounted(RefCounted shardRefCounted) { + if (this.shardRefCounter != null) { + this.shardRefCounter.decRef(); + } + this.shardRefCounter = shardRefCounted; + this.shardRefCounter.mustIncRef(); + } + @Override - public int compareTo(GroupedRow other) { - return -keys.bytesRefView().compareTo(other.keys.bytesRefView()); + public int compareTo(TopNRow rhs) { + // TODO if we 
fill the trailing bytes with 0 we could do a comparison on the entire array + // When Nik measured this it was marginally faster. But it's worth a bit of research. + return -keys.bytesRefView().compareTo(rhs.keys.bytesRefView()); } @Override - public long ramBytesUsed() { - return SHALLOW_SIZE + keys.ramBytesUsed() + values.ramBytesUsed(); + public boolean equals(Object o) { + if (o == null || getClass() != o.getClass()) { + return false; + } + TopNRow row = (TopNRow) o; + return keys.bytesRefView().equals(row.keys.bytesRefView()); } @Override - public void close() { - clearRefCounters(); - Releasables.closeExpectNoException(() -> breaker.addWithoutBreaking(-SHALLOW_SIZE), keys, values); + public int hashCode() { + return Objects.hashCode(keys); + } + + @Override + public String toString() { + StringBuilder b = new StringBuilder("TopNRow[key="); + b.append(keys.bytesRefView()); + b.append(", values="); + + if (values.length() < 100) { + b.append(values.bytesRefView()); + } else { + b.append('['); + assert values.bytesRefView().offset == 0; + for (int i = 0; i < 100; i++) { + if (i != 0) { + b.append(" "); + } + b.append(Integer.toHexString(values.bytesRefView().bytes[i] & 255)); + } + b.append("..."); + } + return b.append("]").toString(); } } diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/GroupedQueueTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/GroupedQueueTests.java index 91347f472e48b..4192d8ba268f4 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/GroupedQueueTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/GroupedQueueTests.java @@ -59,8 +59,8 @@ public void testAddWhenHeapNotFull() { int topCount = 5; try (GroupedQueue queue = new GroupedQueue(breaker, bigArrays, topCount)) { for (int i = 0; i < topCount; i++) { - GroupedRow row = createRow(breaker, i % 2, i * 10); - 
GroupedRow result = queue.addRow(row); + TopNRow row = createRow(breaker, i * 10); + TopNRow result = queue.addRow(i % 2, row); assertThat(result, nullValue()); assertThat(queue.size(), equalTo(i + 1)); } @@ -72,8 +72,8 @@ public void testAddWhenHeapFullAndRowQualifies() { try (GroupedQueue queue = new GroupedQueue(breaker, bigArrays, topCount)) { fillQueueToCapacity(queue, topCount); - try (GroupedRow evicted = queue.addRow(createRow(breaker, 0, 5))) { - assertRowValues(evicted, 0, 20, 40); + try (TopNRow evicted = queue.addRow(0, createRow(breaker, 5))) { + assertRowValues(evicted, 20, 40); } } } @@ -82,8 +82,8 @@ public void testAddWhenHeapFullAndRowDoesNotQualify() { try (GroupedQueue queue = new GroupedQueue(breaker, bigArrays, 3)) { addRows(queue, 0, 30, 40, 50); - try (GroupedRow row = createRow(breaker, 0, 60)) { - GroupedRow result = queue.addRow(row); + try (TopNRow row = createRow(breaker, 60)) { + TopNRow result = queue.addRow(0, row); assertThat(result, sameInstance(row)); } } @@ -91,29 +91,29 @@ public void testAddWhenHeapFullAndRowDoesNotQualify() { public void testAddWithDifferentGroupKeys() { try (GroupedQueue queue = new GroupedQueue(breaker, bigArrays, 2)) { - assertThat(queue.addRow(createRow(breaker, 0, 10)), nullValue()); - assertThat(queue.addRow(createRow(breaker, 1, 20)), nullValue()); - assertThat(queue.addRow(createRow(breaker, 0, 30)), nullValue()); - assertThat(queue.addRow(createRow(breaker, 1, 40)), nullValue()); + assertThat(queue.addRow(0, createRow(breaker, 10)), nullValue()); + assertThat(queue.addRow(1, createRow(breaker, 20)), nullValue()); + assertThat(queue.addRow(0, createRow(breaker, 30)), nullValue()); + assertThat(queue.addRow(1, createRow(breaker, 40)), nullValue()); assertThat(queue.size(), equalTo(4)); - try (GroupedRow evicted = queue.addRow(createRow(breaker, 0, 5))) { + try (TopNRow evicted = queue.addRow(0, createRow(breaker, 5))) { assertThat(evicted, notNullValue()); - assertRowValues(evicted, 0, 30, 60); + 
assertRowValues(evicted, 30, 60); } - try (GroupedRow evicted = queue.addRow(createRow(breaker, 1, 15))) { + try (TopNRow evicted = queue.addRow(1, createRow(breaker, 15))) { assertThat(evicted, notNullValue()); - assertRowValues(evicted, 1, 40, 80); + assertRowValues(evicted, 40, 80); } assertThat(queue.size(), equalTo(4)); - try (GroupedRow row = queue.addRow(createRow(breaker, 0, 50))) { + try (TopNRow row = queue.addRow(0, createRow(breaker, 50))) { assertThat(row, notNullValue()); - assertRowValues(row, 0, 50, 100); + assertRowValues(row, 50, 100); } - try (GroupedRow row = queue.addRow(createRow(breaker, 1, 50))) { + try (TopNRow row = queue.addRow(1, createRow(breaker, 50))) { assertThat(row, notNullValue()); - assertRowValues(row, 1, 50, 100); + assertRowValues(row, 50, 100); } assertThat(queue.size(), equalTo(4)); } @@ -177,13 +177,12 @@ public void testPopAllSortedBySortKey() { } } - private GroupedRow createRow(CircuitBreaker breaker, int groupKey, int sortKey) { - IntBlock groupKeyBlock = blockFactory.newIntBlockBuilder(1).appendInt(groupKey).build(); + private TopNRow createRow(CircuitBreaker breaker, int sortKey) { + IntBlock groupKeyBlock = blockFactory.newIntBlockBuilder(1).appendInt(0).build(); IntBlock keyBlock = blockFactory.newIntBlockBuilder(1).appendInt(sortKey).build(); IntBlock valueBlock = blockFactory.newIntBlockBuilder(1).appendInt(sortKey * 2).build(); - GroupedRow row = new GroupedRow(breaker, 32, 64); - row.groupId = groupKey; - var filler = new GroupedRowFiller( + TopNRow row = new TopNRow(breaker, 32, 64); + var filler = new TopNOperator.RowFiller( List.of(ElementType.INT, ElementType.INT, ElementType.INT), List.of(TopNEncoder.DEFAULT_SORTABLE, TopNEncoder.DEFAULT_SORTABLE, TopNEncoder.DEFAULT_UNSORTABLE), SORT_ORDERS, @@ -191,7 +190,7 @@ private GroupedRow createRow(CircuitBreaker breaker, int groupKey, int sortKey) new Page(groupKeyBlock, keyBlock, valueBlock) ); try { - filler.writeSortKey(0, row); + filler.writeKey(0, row); 
filler.writeValues(0, row); } finally { Releasables.close(groupKeyBlock, keyBlock, valueBlock); @@ -199,16 +198,14 @@ private GroupedRow createRow(CircuitBreaker breaker, int groupKey, int sortKey) return row; } - private static void assertRowValues(GroupedRow row, long expectedGroupKey, int expectedSortKey, int expectedValue) { - assertThat(row.groupId, equalTo(expectedGroupKey)); - - BytesRef keys = row.keys().bytesRefView(); + private static void assertRowValues(TopNRow row, int expectedSortKey, int expectedValue) { + BytesRef keys = row.keys.bytesRefView(); assertThat( TopNEncoder.DEFAULT_SORTABLE.decodeInt(new BytesRef(keys.bytes, keys.offset + 1, keys.length - 1)), equalTo(expectedSortKey) ); - BytesRef values = row.values().bytesRefView(); + BytesRef values = row.values.bytesRefView(); BytesRef reader = new BytesRef(values.bytes, values.offset, values.length); assertThat(TopNEncoder.DEFAULT_UNSORTABLE.decodeVInt(reader), equalTo(1)); TopNEncoder.DEFAULT_UNSORTABLE.decodeInt(reader); @@ -218,8 +215,8 @@ private static void assertRowValues(GroupedRow row, long expectedGroupKey, int e } private void addRow(GroupedQueue queue, int groupKey, int value) { - GroupedRow row = createRow(breaker, groupKey, value); - Releasables.close(queue.addRow(row)); + TopNRow row = createRow(breaker, value); + Releasables.close(queue.addRow(groupKey, row)); } private void fillQueueToCapacity(GroupedQueue queue, int capacity) { @@ -234,12 +231,12 @@ private void addRows(GroupedQueue queue, int groupKey, int... 
values) { private static final List SORT_ORDERS = List.of(new TopNOperator.SortOrder(1, true, false)); - private static void assertQueueContents(GroupedQueue queue, List> groupAndSortKeys) { + private void assertQueueContents(GroupedQueue queue, List> groupAndSortKeys) { assertThat(queue.size(), equalTo(groupAndSortKeys.size())); - List actual = queue.popAll(); + List actual = queue.popAll(); for (int i = 0; i < groupAndSortKeys.size(); i++) { Tuple expectedTuple = groupAndSortKeys.get(i); - assertRowValues(actual.get(i), expectedTuple.v1(), expectedTuple.v2(), expectedTuple.v2() * 2); + assertRowValues(actual.get(i), expectedTuple.v2(), expectedTuple.v2() * 2); } Releasables.close(actual); } diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/GroupedRowTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/GroupedRowTests.java index 63b4bf30dd28e..5981e6c90298a 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/GroupedRowTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/GroupedRowTests.java @@ -25,7 +25,7 @@ public class GroupedRowTests extends ESTestCase { public void testCloseReleasesAllTestsNoPreAllocation() throws Exception { BigArrays bigArrays = new MockBigArrays(PageCacheRecycler.NON_RECYCLING_INSTANCE, ByteSizeValue.ofMb(1)); CircuitBreaker breaker = bigArrays.breakerService().getBreaker(CircuitBreaker.REQUEST); - var row = new GroupedRow(breaker, 0, 0); + var row = new TopNRow(breaker, 0, 0); row.close(); MockBigArrays.ensureAllArraysAreReleased(); assertThat("Not all memory was released", breaker.getUsed(), equalTo(0L)); @@ -34,39 +34,39 @@ public void testCloseReleasesAllTestsNoPreAllocation() throws Exception { public void testCloseReleasesAllTestsWithPreAllocation() throws Exception { BigArrays bigArrays = new MockBigArrays(PageCacheRecycler.NON_RECYCLING_INSTANCE, 
ByteSizeValue.ofMb(1)); CircuitBreaker breaker = bigArrays.breakerService().getBreaker(CircuitBreaker.REQUEST); - var row = new GroupedRow(breaker, 16, 32); + var row = new TopNRow(breaker, 16, 32); row.close(); MockBigArrays.ensureAllArraysAreReleased(); assertThat("Not all memory was released", breaker.getUsed(), equalTo(0L)); } public void testRamBytesUsedEmpty() { - var row = new GroupedRow(breaker, 0, 0); + var row = new TopNRow(breaker, 0, 0); assertThat(row.ramBytesUsed(), equalTo(expectedRamBytesUsed(row))); } public void testRamBytesUsedSmall() { - var row = new GroupedRow(breaker, 0, 0); - row.keys().append(randomByte()); - row.values().append(randomByte()); + var row = new TopNRow(breaker, 0, 0); + row.keys.append(randomByte()); + row.values.append(randomByte()); assertThat(row.ramBytesUsed(), equalTo(expectedRamBytesUsed(row))); } public void testRamBytesUsedBig() { - var row = new GroupedRow(breaker, 0, 0); + var row = new TopNRow(breaker, 0, 0); for (int i = 0; i < 10000; i++) { - row.keys().append(randomByte()); - row.values().append(randomByte()); + row.keys.append(randomByte()); + row.values.append(randomByte()); } assertThat(row.ramBytesUsed(), equalTo(expectedRamBytesUsed(row))); } public void testRamBytesUsedPreAllocated() { - var row = new GroupedRow(breaker, 64, 128); + var row = new TopNRow(breaker, 64, 128); assertThat(row.ramBytesUsed(), equalTo(expectedRamBytesUsed(row))); } - private long expectedRamBytesUsed(GroupedRow row) { + private long expectedRamBytesUsed(TopNRow row) { var expected = RamUsageTester.ramUsed(row); expected -= RamUsageTester.ramUsed(breaker); expected -= sharedRowBytes(); @@ -78,8 +78,8 @@ private static long sharedRowBytes() { return RamUsageTester.ramUsed("topn"); } - static long undercountedBytesForRow(GroupedRow row) { - return emptyByteArrayOverhead(row.values()); + static long undercountedBytesForRow(TopNRow row) { + return emptyByteArrayOverhead(row.values); } private static long 
emptyByteArrayOverhead(BreakingBytesRefBuilder builder) { diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/TopNOperatorTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/TopNOperatorTests.java index 56602efa7e240..972b7490b1207 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/TopNOperatorTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/TopNOperatorTests.java @@ -854,14 +854,14 @@ private void testCompare(Page page, ElementType elementType, TopNEncoder encoder for (int b = 0; b < page.getBlockCount(); b++) { // Non-null identity for (int p = 0; p < page.getPositionCount(); p++) { - TopNOperator.Row row = row(elementType, encoder, b, randomBoolean(), randomBoolean(), page, p); + TopNRow row = row(elementType, encoder, b, randomBoolean(), randomBoolean(), page, p); assertThat(row, equalTo(row)); assertThat(row.compareTo(row), equalTo(0)); } // Null identity for (int p = 0; p < page.getPositionCount(); p++) { - TopNOperator.Row row = row(elementType, encoder, b, randomBoolean(), randomBoolean(), nullPage, p); + TopNRow row = row(elementType, encoder, b, randomBoolean(), randomBoolean(), nullPage, p); assertThat(row, equalTo(row)); assertThat(row.compareTo(row), equalTo(0)); } @@ -869,8 +869,8 @@ private void testCompare(Page page, ElementType elementType, TopNEncoder encoder // nulls first for (int p = 0; p < page.getPositionCount(); p++) { boolean asc = randomBoolean(); - TopNOperator.Row nonNullRow = row(elementType, encoder, b, asc, true, page, p); - TopNOperator.Row nullRow = row(elementType, encoder, b, asc, true, nullPage, p); + TopNRow nonNullRow = row(elementType, encoder, b, asc, true, page, p); + TopNRow nullRow = row(elementType, encoder, b, asc, true, nullPage, p); assertThat(nonNullRow, not(equalTo(nullRow))); assertThat(nonNullRow, lessThan(nullRow)); assertThat(nullRow, 
greaterThan(nonNullRow)); @@ -879,8 +879,8 @@ private void testCompare(Page page, ElementType elementType, TopNEncoder encoder // nulls last for (int p = 0; p < page.getPositionCount(); p++) { boolean asc = randomBoolean(); - TopNOperator.Row nonNullRow = row(elementType, encoder, b, asc, false, page, p); - TopNOperator.Row nullRow = row(elementType, encoder, b, asc, false, nullPage, p); + TopNRow nonNullRow = row(elementType, encoder, b, asc, false, page, p); + TopNRow nullRow = row(elementType, encoder, b, asc, false, nullPage, p); assertThat(nonNullRow, not(equalTo(nullRow))); assertThat(nonNullRow, greaterThan(nullRow)); assertThat(nullRow, lessThan(nonNullRow)); @@ -889,8 +889,8 @@ private void testCompare(Page page, ElementType elementType, TopNEncoder encoder // ascending { boolean nullsFirst = randomBoolean(); - TopNOperator.Row r1 = row(elementType, encoder, b, true, nullsFirst, page, 0); - TopNOperator.Row r2 = row(elementType, encoder, b, true, nullsFirst, page, 1); + TopNRow r1 = row(elementType, encoder, b, true, nullsFirst, page, 0); + TopNRow r2 = row(elementType, encoder, b, true, nullsFirst, page, 1); assertThat(r1, not(equalTo(r2))); assertThat(r1, greaterThan(r2)); assertThat(r2, lessThan(r1)); @@ -898,8 +898,8 @@ private void testCompare(Page page, ElementType elementType, TopNEncoder encoder // descending { boolean nullsFirst = randomBoolean(); - TopNOperator.Row r1 = row(elementType, encoder, b, false, nullsFirst, page, 0); - TopNOperator.Row r2 = row(elementType, encoder, b, false, nullsFirst, page, 1); + TopNRow r1 = row(elementType, encoder, b, false, nullsFirst, page, 0); + TopNRow r2 = row(elementType, encoder, b, false, nullsFirst, page, 1); assertThat(r1, not(equalTo(r2))); assertThat(r1, lessThan(r2)); assertThat(r2, greaterThan(r1)); @@ -908,7 +908,7 @@ private void testCompare(Page page, ElementType elementType, TopNEncoder encoder page.releaseBlocks(); } - private TopNOperator.Row row( + private TopNRow row( ElementType elementType, 
TopNEncoder encoder, int channel, @@ -927,7 +927,7 @@ private TopNOperator.Row row( channelInKey, page ); - TopNOperator.Row row = new TopNOperator.Row(nonBreakingBigArrays().breakerService().getBreaker("request"), 0, 0); + TopNRow row = new TopNRow(nonBreakingBigArrays().breakerService().getBreaker("request"), 0, 0); rf.writeKey(position, row); rf.writeValues(position, row); return row; diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/TopNRowTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/TopNRowTests.java index 92cecdf79528c..0fbf656119310 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/TopNRowTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/TopNRowTests.java @@ -18,12 +18,12 @@ public class TopNRowTests extends ESTestCase { private final CircuitBreaker breaker = new NoopCircuitBreaker(CircuitBreaker.REQUEST); public void testRamBytesUsedEmpty() { - TopNOperator.Row row = new TopNOperator.Row(breaker, 0, 0); + TopNRow row = new TopNRow(breaker, 0, 0); assertThat(row.ramBytesUsed(), equalTo(expectedRamBytesUsed(row))); } public void testRamBytesUsedSmall() { - TopNOperator.Row row = new TopNOperator.Row(new NoopCircuitBreaker(CircuitBreaker.REQUEST), 0, 0); + TopNRow row = new TopNRow(new NoopCircuitBreaker(CircuitBreaker.REQUEST), 0, 0); row.keys.append(randomByte()); row.values.append(randomByte()); assertThat(row.ramBytesUsed(), equalTo(expectedRamBytesUsed(row))); @@ -35,7 +35,7 @@ public void testRamBytesUsedSmall() { * size estimates from previous rows. 
*/ public void testFromHeapDump1() { - TopNOperator.Row row = new TopNOperator.Row(new NoopCircuitBreaker(CircuitBreaker.REQUEST), 56, 24); + TopNRow row = new TopNRow(new NoopCircuitBreaker(CircuitBreaker.REQUEST), 56, 24); assertThat(row.ramBytesUsed(), equalTo(expectedRamBytesUsed(row))); // 304 was measured debugging a heap dump and we've since shrunk assertThat(row.ramBytesUsed(), equalTo(240L)); @@ -47,14 +47,14 @@ public void testFromHeapDump1() { * size estimates from previous rows. */ public void testFromHeapDump2() { - TopNOperator.Row row = new TopNOperator.Row(new NoopCircuitBreaker(CircuitBreaker.REQUEST), 1160, 1_153_096); + TopNRow row = new TopNRow(new NoopCircuitBreaker(CircuitBreaker.REQUEST), 1160, 1_153_096); assertThat(row.ramBytesUsed(), equalTo(expectedRamBytesUsed(row))); // 1,154,464 is measured debugging a heap dump and we've since shrunk assertThat(row.ramBytesUsed(), equalTo(1_154_416L)); } public void testRamBytesUsedBig() { - TopNOperator.Row row = new TopNOperator.Row(new NoopCircuitBreaker(CircuitBreaker.REQUEST), 0, 0); + TopNRow row = new TopNRow(new NoopCircuitBreaker(CircuitBreaker.REQUEST), 0, 0); for (int i = 0; i < 10000; i++) { row.keys.append(randomByte()); row.values.append(randomByte()); @@ -63,11 +63,11 @@ public void testRamBytesUsedBig() { } public void testRamBytesUsedPreAllocated() { - TopNOperator.Row row = new TopNOperator.Row(new NoopCircuitBreaker(CircuitBreaker.REQUEST), 64, 128); + TopNRow row = new TopNRow(new NoopCircuitBreaker(CircuitBreaker.REQUEST), 64, 128); assertThat(row.ramBytesUsed(), equalTo(expectedRamBytesUsed(row))); } - private long expectedRamBytesUsed(TopNOperator.Row row) { + private long expectedRamBytesUsed(TopNRow row) { long expected = RamUsageTester.ramUsed(row); if (row.values.bytes().length == 0) { // We double count the shared empty array for empty rows. This overcounting is *fine*, but throws off the test. 
From b694a8c17206095801cd1bab0621db3bd03eb140 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Iv=C3=A1n=20Cea=20Fontenla?= Date: Thu, 5 Mar 2026 16:09:27 +0100 Subject: [PATCH 05/22] Simplified grouped topn test --- .../topn/GroupedTopNOperatorTests.java | 163 +++++------------- 1 file changed, 43 insertions(+), 120 deletions(-) diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/GroupedTopNOperatorTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/GroupedTopNOperatorTests.java index 5edc16ce03e02..d3635b55fbf97 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/GroupedTopNOperatorTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/GroupedTopNOperatorTests.java @@ -199,6 +199,41 @@ protected void assertSimpleOutput(List input, List results) { ); } + private List> runGroupedTopN( + List pages, + int topCount, + List elementTypes, + List encoders, + List sortOrders, + int[] groupKeys + ) { + DriverContext driverContext = driverContext(); + List> actual = new ArrayList<>(); + try ( + Driver driver = TestDriverFactory.create( + driverContext, + new CannedSourceOperator(pages.iterator()), + List.of( + new GroupedTopNOperator( + driverContext.blockFactory(), + nonBreakingBigArrays().breakerService().getBreaker("request"), + topCount, + elementTypes, + encoders, + sortOrders, + groupKeys, + randomPageSize(), + Long.MAX_VALUE + ) + ), + new PageConsumerOperator(p -> readInto(actual, p)) + ) + ) { + new TestDriverRunner().run(driver); + } + return actual; + } + /** * Tests that the SORTED input ordering optimization short-circuiting addInput() doesn't incorrectly skip rows * belonging to groups not yet populated when another group's row is rejected. 
@@ -212,8 +247,6 @@ protected void assertSimpleOutput(List input, List results) { * } */ public void testSortedInputWithMultipleGroups() { - int topCount = 1; - int[] groupKeys = new int[] { 1 }; List elementTypes = List.of(ElementType.INT, ElementType.INT); List encoders = List.of(TopNEncoder.DEFAULT_SORTABLE, DEFAULT_UNSORTABLE); List sortOrders = List.of(new SortOrder(0, true, true)); @@ -246,31 +279,7 @@ public void testSortedInputWithMultipleGroups() { page2 = new Page(sortCol.build(), groupCol.build()); } - List> actual = new ArrayList<>(); - DriverContext driverContext = driverContext(); - - try ( - Driver driver = TestDriverFactory.create( - driverContext, - new CannedSourceOperator(List.of(page1, page2).iterator()), - List.of( - new GroupedTopNOperator( - driverContext.blockFactory(), - nonBreakingBigArrays().breakerService().getBreaker("request"), - topCount, - elementTypes, - encoders, - sortOrders, - groupKeys, - randomPageSize(), - Long.MAX_VALUE - ) - ), - new PageConsumerOperator(p -> readInto(actual, p)) - ) - ) { - new TestDriverRunner().run(driver); - } + List> actual = runGroupedTopN(List.of(page1, page2), 1, elementTypes, encoders, sortOrders, new int[] { 1 }); // 3 groups, each with 1 value, ordered ASC: [1, 3, 4] assertThat(actual.get(0), equalTo(List.of(1, 3, 4))); @@ -278,50 +287,16 @@ public void testSortedInputWithMultipleGroups() { } public void testMultivalueGroupKey() { - DriverContext driverContext = driverContext(); - BlockFactory blockFactory = driverContext.blockFactory(); - - int topCount = 1; - int[] groupKeys = new int[] { 2 }; // group key at channel 2 + BlockFactory blockFactory = driverContext().blockFactory(); List elementTypes = List.of(LONG, LONG, LONG); List encoders = List.of(TopNEncoder.DEFAULT_SORTABLE, DEFAULT_UNSORTABLE, DEFAULT_UNSORTABLE); List sortOrders = List.of(new SortOrder(0, true, false)); Page page = new Page( - BlockUtils.fromList( - blockFactory, - List.of( - // (To keep indentation) - List.of(10L, 100L, 
List.of(1L, 2L)), - List.of(20L, 200L, 1L), - List.of(30L, 300L, 3L) - ) - ) + BlockUtils.fromList(blockFactory, List.of(List.of(10L, 100L, List.of(1L, 2L)), List.of(20L, 200L, 1L), List.of(30L, 300L, 3L))) ); - List> actual = new ArrayList<>(); - try ( - Driver driver = TestDriverFactory.create( - driverContext, - new CannedSourceOperator(List.of(page).iterator()), - List.of( - new GroupedTopNOperator( - blockFactory, - nonBreakingBigArrays().breakerService().getBreaker("request"), - topCount, - elementTypes, - encoders, - sortOrders, - groupKeys, - randomPageSize(), - Long.MAX_VALUE - ) - ), - new PageConsumerOperator(p -> readInto(actual, p)) - ) - ) { - new TestDriverRunner().run(driver); - } + List> actual = runGroupedTopN(List.of(page), 1, elementTypes, encoders, sortOrders, new int[] { 2 }); // List semantics: [1,2] is one group, 1 is another, 3 is another. Sorted ASC by sort key. assertThat(actual.get(0), equalTo(List.of(10L, 20L, 30L))); // Sort key @@ -330,40 +305,14 @@ public void testMultivalueGroupKey() { } public void testMultivalueGroupKeyDuplicateWinner() { - DriverContext driverContext = driverContext(); - BlockFactory bf = driverContext.blockFactory(); - - int topCount = 1; - int[] groupKeys = new int[] { 2 }; + BlockFactory bf = driverContext().blockFactory(); List elementTypes = List.of(LONG, LONG, LONG); List encoders = List.of(TopNEncoder.DEFAULT_SORTABLE, DEFAULT_UNSORTABLE, DEFAULT_UNSORTABLE); List sortOrders = List.of(new SortOrder(0, true, false)); Page page = new Page(BlockUtils.fromList(bf, List.of(List.of(5L, 50L, List.of(1L, 2L)), List.of(10L, 100L, 1L)))); - List> actual = new ArrayList<>(); - try ( - Driver driver = TestDriverFactory.create( - driverContext, - new CannedSourceOperator(List.of(page).iterator()), - List.of( - new GroupedTopNOperator( - bf, - nonBreakingBigArrays().breakerService().getBreaker("request"), - topCount, - elementTypes, - encoders, - sortOrders, - groupKeys, - randomPageSize(), - Long.MAX_VALUE - ) - ), - 
new PageConsumerOperator(p -> readInto(actual, p)) - ) - ) { - new TestDriverRunner().run(driver); - } + List> actual = runGroupedTopN(List.of(page), 1, elementTypes, encoders, sortOrders, new int[] { 2 }); // List semantics: [1,2] is one group, 1 is another. Sorted ASC by sort key. assertThat(actual.get(0), equalTo(List.of(5L, 10L))); // Sort key @@ -377,11 +326,7 @@ public void testMultivalueGroupKeyDuplicateWinner() { * No cartesian product expansion occurs. */ public void testMultipleMultivalueGroupKeys() { - DriverContext driverContext = driverContext(); - BlockFactory bf = driverContext.blockFactory(); - - int topCount = 1; - int[] groupKeys = new int[] { 2, 3 }; // two group keys at channels 2 and 3 + BlockFactory bf = driverContext().blockFactory(); List elementTypes = List.of(LONG, LONG, LONG, LONG); List encoders = List.of(TopNEncoder.DEFAULT_SORTABLE, DEFAULT_UNSORTABLE, DEFAULT_UNSORTABLE, DEFAULT_UNSORTABLE); List sortOrders = List.of(new SortOrder(0, true, false)); @@ -393,29 +338,7 @@ public void testMultipleMultivalueGroupKeys() { ) ); - List> actual = new ArrayList<>(); - try ( - Driver driver = TestDriverFactory.create( - driverContext, - new CannedSourceOperator(List.of(page).iterator()), - List.of( - new GroupedTopNOperator( - bf, - nonBreakingBigArrays().breakerService().getBreaker("request"), - topCount, - elementTypes, - encoders, - sortOrders, - groupKeys, - randomPageSize(), - Long.MAX_VALUE - ) - ), - new PageConsumerOperator(p -> readInto(actual, p)) - ) - ) { - new TestDriverRunner().run(driver); - } + List> actual = runGroupedTopN(List.of(page), 1, elementTypes, encoders, sortOrders, new int[] { 2, 3 }); // List semantics: 3 distinct groups, each with 1 row, sorted ASC by sort key assertThat(actual.get(0), equalTo(List.of(5L, 10L, 15L))); // Sort key From aa96d78b0efbcbd555c565d31dcaa017a529c223 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Iv=C3=A1n=20Cea=20Fontenla?= Date: Thu, 5 Mar 2026 16:57:27 +0100 Subject: [PATCH 06/22] Make static 
test methods non-static --- .../compute/operator/topn/GroupedTopNOperatorTests.java | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/GroupedTopNOperatorTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/GroupedTopNOperatorTests.java index d3635b55fbf97..0e178b927db51 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/GroupedTopNOperatorTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/GroupedTopNOperatorTests.java @@ -510,13 +510,13 @@ private static Comparator compareValues(SortOrder order) { @SuppressWarnings("unchecked") private static final Comparator CASTING_COMPARATOR = (o1, o2) -> ((Comparable) o1).compareTo(o2); - private static List> computeTopN(List> inputValues, int limit, boolean ascendingOrder) { + private List> computeTopN(List> inputValues, int limit, boolean ascendingOrder) { return computeTopN(inputValues.stream().map(e -> Arrays.asList(e.v1(), e.v2())).toList(), 1, 0, limit, ascendingOrder).stream() .map(l -> Tuple.tuple((Long) l.get(0), (Long) l.get(1))) .toList(); } - private static List> computeTopN( + private List> computeTopN( List> inputValues, int groupChannel, int sortChannel, @@ -529,7 +529,7 @@ private static List> computeTopN( singleValueInput.add(rowAsObject); } List sortOrders = List.of(new SortOrder(sortChannel, ascendingOrder, false)); - return new GroupedTopNOperatorTests().computeTopN(singleValueInput, List.of(groupChannel), sortOrders, limit); + return computeTopN(singleValueInput, List.of(groupChannel), sortOrders, limit); } private List> computeTopN( From af440c6cd5727e5c0a219e29880df5c017a86d44 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Iv=C3=A1n=20Cea=20Fontenla?= Date: Thu, 5 Mar 2026 17:05:33 +0100 Subject: [PATCH 07/22] Initial GroupedTopNOperator benchmarks ---
.../_nightly/esql/GroupedTopNBenchmark.java | 295 ++++++++++++++++++ .../esql/GroupedTopNBenchmarkTests.java | 18 ++ 2 files changed, 313 insertions(+) create mode 100644 benchmarks/src/main/java/org/elasticsearch/benchmark/_nightly/esql/GroupedTopNBenchmark.java create mode 100644 benchmarks/src/test/java/org/elasticsearch/benchmark/_nightly/esql/GroupedTopNBenchmarkTests.java diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/_nightly/esql/GroupedTopNBenchmark.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/_nightly/esql/GroupedTopNBenchmark.java new file mode 100644 index 0000000000000..f8baf3f85f35a --- /dev/null +++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/_nightly/esql/GroupedTopNBenchmark.java @@ -0,0 +1,295 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". 
+ */ + +package org.elasticsearch.benchmark._nightly.esql; + +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.benchmark.Utils; +import org.elasticsearch.common.breaker.NoopCircuitBreaker; +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.BooleanBlock; +import org.elasticsearch.compute.data.BytesRefBlock; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.Operator; +import org.elasticsearch.compute.operator.topn.GroupedTopNOperator; +import org.elasticsearch.compute.operator.topn.TopNEncoder; +import org.elasticsearch.compute.operator.topn.TopNOperator; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OperationsPerInvocation; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.Warmup; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Random; +import java.util.concurrent.TimeUnit; +import java.util.stream.IntStream; + +@Warmup(iterations = 5) +@Measurement(iterations = 7) +@BenchmarkMode(Mode.AverageTime) +@OutputTimeUnit(TimeUnit.NANOSECONDS) +@State(Scope.Thread) +@Fork(1) +public class GroupedTopNBenchmark { + + private static final BlockFactory blockFactory = BlockFactory.builder(BigArrays.NON_RECYCLING_INSTANCE) + .breaker(new NoopCircuitBreaker("none")) + .build(); + + private static final int BLOCK_LENGTH = 4 * 1024; + private static final int NUM_PAGES = 1024; + private static final int 
SELF_TEST_PAGES = 16; + + private static final String LONGS = "longs"; + private static final String INTS = "ints"; + private static final String DOUBLES = "doubles"; + private static final String BOOLEANS = "booleans"; + private static final String BYTES_REFS = "bytes_refs"; + + private static final String ASC = "_asc"; + private static final String DESC = "_desc"; + + private static final String AND = "_and_"; + + static { + Utils.configureBenchmarkLogging(); + // Smoke test all the expected values and force loading subclasses more like prod + selfTest(); + } + + static void selfTest() { + try { + for (String data : GroupedTopNBenchmark.class.getField("data").getAnnotationsByType(Param.class)[0].value()) { + for (String topCount : GroupedTopNBenchmark.class.getField("topCount").getAnnotationsByType(Param.class)[0].value()) { + for (String groupCount : GroupedTopNBenchmark.class.getField("groupCount").getAnnotationsByType(Param.class)[0] + .value()) { + for (String gkType : GroupedTopNBenchmark.class.getField("groupKeyType").getAnnotationsByType(Param.class)[0] + .value()) { + for (String gkCount : GroupedTopNBenchmark.class.getField("groupKeyCount").getAnnotationsByType(Param.class)[0] + .value()) { + run( + data, + Integer.parseInt(topCount), + Integer.parseInt(groupCount), + gkType, + Integer.parseInt(gkCount), + SELF_TEST_PAGES + ); + } + } + } + } + } + } catch (NoSuchFieldException e) { + throw new AssertionError(); + } + } + + @Param({ LONGS + ASC, LONGS + DESC, BYTES_REFS + ASC, LONGS + ASC + AND + LONGS + ASC, LONGS + ASC + AND + BYTES_REFS + ASC }) + public String data; + + @Param({ "1", "10", "1000" }) + public int topCount; + + @Param({ "10", "100", "1000" }) + public int groupCount; + + @Param({ LONGS, BYTES_REFS }) + public String groupKeyType; + + @Param({ "1", "2" }) + public int groupKeyCount; + + private static Operator operator(String data, int topCount, String groupKeyType, int groupKeyCount) { + String[] dataSpec = data.split("_and_"); + List 
elementTypes = new ArrayList<>(Arrays.stream(dataSpec).map(GroupedTopNBenchmark::elementType).toList()); + List encoders = new ArrayList<>(Arrays.stream(dataSpec).map(GroupedTopNBenchmark::encoder).toList()); + List sortOrders = IntStream.range(0, dataSpec.length).mapToObj(c -> sortOrder(c, dataSpec[c])).toList(); + + int[] groupKeys = new int[groupKeyCount]; + ElementType gkElementType = groupKeyElementType(groupKeyType); + for (int i = 0; i < groupKeyCount; i++) { + groupKeys[i] = elementTypes.size(); + elementTypes.add(gkElementType); + encoders.add(TopNEncoder.DEFAULT_UNSORTABLE); + } + + return new GroupedTopNOperator( + blockFactory, + blockFactory.breaker(), + topCount, + elementTypes, + encoders, + sortOrders, + groupKeys, + 8 * 1024, + Long.MAX_VALUE + ); + } + + private static ElementType elementType(String data) { + return switch (data.replace(ASC, "").replace(DESC, "")) { + case LONGS -> ElementType.LONG; + case INTS -> ElementType.INT; + case DOUBLES -> ElementType.DOUBLE; + case BOOLEANS -> ElementType.BOOLEAN; + case BYTES_REFS -> ElementType.BYTES_REF; + default -> throw new IllegalArgumentException("unsupported data type [" + data + "]"); + }; + } + + private static TopNEncoder encoder(String data) { + return switch (data.replace(ASC, "").replace(DESC, "")) { + case LONGS, INTS, DOUBLES, BOOLEANS -> TopNEncoder.DEFAULT_SORTABLE; + case BYTES_REFS -> TopNEncoder.UTF8; + default -> throw new IllegalArgumentException("unsupported data type [" + data + "]"); + }; + } + + private static ElementType groupKeyElementType(String groupKeyType) { + return switch (groupKeyType) { + case LONGS -> ElementType.LONG; + case BYTES_REFS -> ElementType.BYTES_REF; + default -> throw new IllegalArgumentException("unsupported group key type [" + groupKeyType + "]"); + }; + } + + private static boolean ascDesc(String data) { + if (data.endsWith(ASC)) { + return true; + } else if (data.endsWith(DESC)) { + return false; + } else { + throw new IllegalArgumentException("data 
neither asc nor desc: " + data); + } + } + + private static TopNOperator.SortOrder sortOrder(int channel, String data) { + return new TopNOperator.SortOrder(channel, ascDesc(data), false); + } + + private static void checkExpected(int topCount, int groupCount, int numPages, List pages) { + int effectiveGroupCount = Math.min(groupCount, BLOCK_LENGTH); + long expectedOutput = 0; + for (int g = 0; g < effectiveGroupCount; g++) { + int rowsPerPage = BLOCK_LENGTH / effectiveGroupCount + (g < BLOCK_LENGTH % effectiveGroupCount ? 1 : 0); + long totalRowsForGroup = (long) rowsPerPage * numPages; + expectedOutput += Math.min(topCount, totalRowsForGroup); + } + long actualOutput = pages.stream().mapToLong(Page::getPositionCount).sum(); + if (expectedOutput != actualOutput) { + throw new AssertionError("expected [" + expectedOutput + "] but got [" + actualOutput + "]"); + } + } + + private static Page page(String data, int groupCount, String groupKeyType, int groupKeyCount) { + String[] dataSpec = data.split("_and_"); + int effectiveGroupCount = Math.min(groupCount, BLOCK_LENGTH); + int divisor = (int) Math.ceil(Math.sqrt(effectiveGroupCount)); + + Block[] blocks = new Block[dataSpec.length + groupKeyCount]; + for (int i = 0; i < dataSpec.length; i++) { + blocks[i] = block(dataSpec[i]); + } + for (int k = 0; k < groupKeyCount; k++) { + blocks[dataSpec.length + k] = groupKeyBlock(groupKeyType, effectiveGroupCount, divisor, k, groupKeyCount); + } + return new Page(blocks); + } + + private static Block block(String data) { + return switch (data.replace(ASC, "").replace(DESC, "")) { + case LONGS -> { + var builder = blockFactory.newLongBlockBuilder(BLOCK_LENGTH); + new Random().longs(BLOCK_LENGTH, 0, Long.MAX_VALUE).forEach(builder::appendLong); + yield builder.build(); + } + case INTS -> { + var builder = blockFactory.newIntBlockBuilder(BLOCK_LENGTH); + new Random().ints(BLOCK_LENGTH, 0, Integer.MAX_VALUE).forEach(builder::appendInt); + yield builder.build(); + } + case DOUBLES 
-> { + var builder = blockFactory.newDoubleBlockBuilder(BLOCK_LENGTH); + new Random().doubles(BLOCK_LENGTH, 0, Double.MAX_VALUE).forEach(builder::appendDouble); + yield builder.build(); + } + case BOOLEANS -> { + BooleanBlock.Builder builder = blockFactory.newBooleanBlockBuilder(BLOCK_LENGTH); + new Random().ints(BLOCK_LENGTH, 0, 1).forEach(i -> builder.appendBoolean(i == 1)); + yield builder.build(); + } + case BYTES_REFS -> { + BytesRefBlock.Builder builder = blockFactory.newBytesRefBlockBuilder(BLOCK_LENGTH); + new Random().ints(BLOCK_LENGTH, 0, Integer.MAX_VALUE) + .forEach(i -> builder.appendBytesRef(new BytesRef(Integer.toString(i)))); + yield builder.build(); + } + default -> throw new UnsupportedOperationException("unsupported data [" + data + "]"); + }; + } + + private static Block groupKeyBlock(String groupKeyType, int effectiveGroupCount, int divisor, int keyIndex, int groupKeyCount) { + return switch (groupKeyType) { + case LONGS -> { + var builder = blockFactory.newLongBlockBuilder(BLOCK_LENGTH); + for (int i = 0; i < BLOCK_LENGTH; i++) { + int groupId = i % effectiveGroupCount; + long keyValue = groupKeyCount == 1 ? groupId : (keyIndex == 0 ? groupId / divisor : groupId % divisor); + builder.appendLong(keyValue); + } + yield builder.build(); + } + case BYTES_REFS -> { + BytesRefBlock.Builder builder = blockFactory.newBytesRefBlockBuilder(BLOCK_LENGTH); + for (int i = 0; i < BLOCK_LENGTH; i++) { + int groupId = i % effectiveGroupCount; + long keyValue = groupKeyCount == 1 ? groupId : (keyIndex == 0 ? 
groupId / divisor : groupId % divisor); + builder.appendBytesRef(new BytesRef(Long.toString(keyValue))); + } + yield builder.build(); + } + default -> throw new IllegalArgumentException("unsupported group key type [" + groupKeyType + "]"); + }; + } + + @Benchmark + @OperationsPerInvocation(NUM_PAGES * BLOCK_LENGTH) + public void run() { + run(data, topCount, groupCount, groupKeyType, groupKeyCount, NUM_PAGES); + } + + private static void run(String data, int topCount, int groupCount, String groupKeyType, int groupKeyCount, int numPages) { + try (Operator operator = operator(data, topCount, groupKeyType, groupKeyCount)) { + Page page = page(data, groupCount, groupKeyType, groupKeyCount); + for (int i = 0; i < numPages; i++) { + operator.addInput(page.shallowCopy()); + } + operator.finish(); + List results = new ArrayList<>(); + Page p; + while ((p = operator.getOutput()) != null) { + results.add(p); + } + checkExpected(topCount, groupCount, numPages, results); + } + } +} diff --git a/benchmarks/src/test/java/org/elasticsearch/benchmark/_nightly/esql/GroupedTopNBenchmarkTests.java b/benchmarks/src/test/java/org/elasticsearch/benchmark/_nightly/esql/GroupedTopNBenchmarkTests.java new file mode 100644 index 0000000000000..d592f46642584 --- /dev/null +++ b/benchmarks/src/test/java/org/elasticsearch/benchmark/_nightly/esql/GroupedTopNBenchmarkTests.java @@ -0,0 +1,18 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". 
+ */ + +package org.elasticsearch.benchmark._nightly.esql; + +import org.elasticsearch.test.ESTestCase; + +public class GroupedTopNBenchmarkTests extends ESTestCase { + public void test() { + GroupedTopNBenchmark.selfTest(); + } +} From 905b80a645723bce7fee0b74e095220fa1f612c8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Iv=C3=A1n=20Cea=20Fontenla?= Date: Thu, 5 Mar 2026 18:30:04 +0100 Subject: [PATCH 08/22] Simplified tests removing redundant methods --- .../topn/GroupedTopNOperatorTests.java | 141 +++--------------- .../operator/topn/TopNOperatorTests.java | 2 +- 2 files changed, 23 insertions(+), 120 deletions(-) diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/GroupedTopNOperatorTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/GroupedTopNOperatorTests.java index 0e178b927db51..f99370aa02548 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/GroupedTopNOperatorTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/GroupedTopNOperatorTests.java @@ -10,7 +10,6 @@ import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.BlockFactory; import org.elasticsearch.compute.data.BlockUtils; -import org.elasticsearch.compute.data.DocBlock; import org.elasticsearch.compute.data.ElementType; import org.elasticsearch.compute.data.LongBlock; import org.elasticsearch.compute.data.Page; @@ -43,12 +42,10 @@ import java.util.stream.Collectors; import java.util.stream.IntStream; import java.util.stream.LongStream; -import java.util.stream.Stream; import static org.elasticsearch.compute.data.ElementType.DOC; import static org.elasticsearch.compute.data.ElementType.LONG; import static org.elasticsearch.compute.operator.topn.TopNEncoder.DEFAULT_UNSORTABLE; -import static org.elasticsearch.compute.test.BlockTestUtils.append; import static 
org.elasticsearch.compute.test.BlockTestUtils.readInto; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.greaterThan; @@ -131,10 +128,6 @@ public void testBasicTopN() { assertThat(topNLong(values, 2, false, true), equalTo(Arrays.asList(null, null, 4L, 4L, 2L, 1L))); } - private List topNLong(List inputValues, int limit, boolean ascendingOrder, boolean nullsFirst) { - return topNLong(driverContext(), inputValues, limit, ascendingOrder, nullsFirst); - } - @Override protected SourceOperator simpleInput(BlockFactory blockFactory, int size) { return new TupleLongLongBlockSourceOperator( @@ -252,32 +245,8 @@ public void testSortedInputWithMultipleGroups() { List sortOrders = List.of(new SortOrder(0, true, true)); BlockFactory bf = driverContext().blockFactory(); - - // Page 1: sorted ASC by sort key - Page page1; - try ( - Block.Builder sortCol = ElementType.INT.newBlockBuilder(2, bf); - Block.Builder groupCol = ElementType.INT.newBlockBuilder(2, bf) - ) { - append(sortCol, 1); - append(sortCol, 3); - append(groupCol, 0); - append(groupCol, 1); - page1 = new Page(sortCol.build(), groupCol.build()); - } - - // Page 2: sorted ASC by sort key - Page page2; - try ( - Block.Builder sortCol = ElementType.INT.newBlockBuilder(2, bf); - Block.Builder groupCol = ElementType.INT.newBlockBuilder(2, bf) - ) { - append(sortCol, 2); - append(sortCol, 4); - append(groupCol, 0); - append(groupCol, 2); - page2 = new Page(sortCol.build(), groupCol.build()); - } + Page page1 = new Page(BlockUtils.fromList(bf, List.of(List.of(1, 0), List.of(3, 1)))); + Page page2 = new Page(BlockUtils.fromList(bf, List.of(List.of(2, 0), List.of(4, 2)))); List> actual = runGroupedTopN(List.of(page1, page2), 1, elementTypes, encoders, sortOrders, new int[] { 1 }); @@ -348,7 +317,7 @@ public void testMultipleMultivalueGroupKeys() { } public void testShardContextManagement_limitEqualToCount_noShardContextIsReleased() { - topNShardContextManagementAux(2, Stream.generate(() -> 
true).limit(4).toList()); + topNShardContextManagementAux(2, List.of(true, true, true, true)); } public void testShardContextManagement_notAllShardsPassTopN_shardsAreReleased() { @@ -363,7 +332,12 @@ private void topNShardContextManagementAux(int limit, List expectedOpen Arrays.asList(new BlockUtils.Doc(3, 40, 400), -3L, 2L) ); - List refCountedList = Stream.generate(() -> new SimpleRefCounted()).limit(4).toList(); + List refCountedList = List.of( + new SimpleRefCounted(), + new SimpleRefCounted(), + new SimpleRefCounted(), + new SimpleRefCounted() + ); var shardRefCounters = new IndexedByShardIdFromList<>(refCountedList); var pages = topNMultipleColumns( driverContext(), @@ -382,14 +356,19 @@ protected TestBlockBuilder getTestBlockBuilder(int b) { refCountedList.forEach(RefCounted::decRef); assertThat(refCountedList.stream().map(RefCounted::hasReferences).toList(), equalTo(expectedOpenAfterTopN)); - assertThat(pageToValues(pages), equalTo(computeTopN(values, 2, 1, limit, true))); - - for (var rc : refCountedList) { - assertFalse(rc.hasReferences()); + List> valuesAsObjects = values.stream().map(row -> row.stream().map(v -> (Object) v).toList()).toList(); + List> actual = new ArrayList<>(); + for (Page p : pages) { + actual.addAll(readAsRowsSingleValue(p)); } + assertThat(actual, equalTo(computeTopN(valuesAsObjects, List.of(2), List.of(new SortOrder(1, true, false)), limit))); } finally { Releasables.close(pages); } + + for (var rc : refCountedList) { + assertFalse(rc.hasReferences()); + } } public void testRandomMultipleColumns() { @@ -433,7 +412,7 @@ public void testRandomMultipleColumns() { topCount, randomBlocksResult.elementTypes, randomBlocksResult.encoders, - uniqueOrders.stream().toList(), + uniqueOrders, groupKeys.stream().mapToInt(Integer::intValue).toArray(), rows, Long.MAX_VALUE @@ -488,16 +467,6 @@ private static Comparator> comparatorFromSortOrders(List }; } - private static final Comparator> TIE_BREAKING_COMPARATOR = (row1, row2) -> { - for (int i = 
0; i < row1.size(); i++) { - int cmp = compareValues(new SortOrder(i, true, true)).compare(row1.get(i), row2.get(i)); - if (cmp != 0) { - return cmp; - } - } - return 0; - }; - private static boolean isSorted(List> values, Comparator> comparator) { return IntStream.range(1, values.size()).allMatch(i -> comparator.compare(values.get(i - 1), values.get(i)) <= 0); } @@ -511,55 +480,19 @@ private static Comparator compareValues(SortOrder order) { private static final Comparator CASTING_COMPARATOR = (o1, o2) -> ((Comparable) o1).compareTo(o2); private List> computeTopN(List> inputValues, int limit, boolean ascendingOrder) { - return computeTopN(inputValues.stream().map(e -> Arrays.asList(e.v1(), e.v2())).toList(), 1, 0, limit, ascendingOrder).stream() + List> rows = inputValues.stream().map(e -> Arrays.asList(e.v1(), e.v2())).toList(); + return computeTopN(rows, List.of(1), List.of(new SortOrder(0, ascendingOrder, false)), limit).stream() .map(l -> Tuple.tuple((Long) l.get(0), (Long) l.get(1))) .toList(); } - private List> computeTopN( - List> inputValues, - int groupChannel, - int sortChannel, - int limit, - boolean ascendingOrder - ) { - List> singleValueInput = new ArrayList<>(); - for (List row : inputValues) { - List rowAsObject = row.stream().map(v -> (Object) v).toList(); - singleValueInput.add(rowAsObject); - } - List sortOrders = List.of(new SortOrder(sortChannel, ascendingOrder, false)); - return computeTopN(singleValueInput, List.of(groupChannel), sortOrders, limit); - } - private List> computeTopN( List> inputValues, List groupChannels, List sortOrders, int limit ) { - Comparator> comparator = (row1, row2) -> { - for (SortOrder order : sortOrders) { - Object v1 = row1.get(order.channel()); - Object v2 = row2.get(order.channel()); - boolean firstIsNull = v1 == null; - boolean secondIsNull = v2 == null; - - if (firstIsNull || secondIsNull) { - int nullCompare = Boolean.compare(firstIsNull, secondIsNull) * (order.nullsFirst() ? 
-1 : 1); - if (nullCompare != 0) { - return nullCompare; - } - continue; - } - - int cmp = CASTING_COMPARATOR.compare(v1, v2); - if (cmp != 0) { - return order.asc() ? cmp : -cmp; - } - } - return 0; - }; + Comparator> comparator = comparatorFromSortOrders(sortOrders); Map, List>> grouped = inputValues.stream() .collect(Collectors.groupingBy(row -> groupChannels.stream().map(row::get).toList())); @@ -573,34 +506,4 @@ private List> computeTopN( return topNExpectedValues; } - private static List> pageToValues(List pages) { - var result = new ArrayList>(); - for (Page page : pages) { - var blocks = IntStream.range(0, page.getBlockCount()).mapToObj(page::getBlock).toList(); - result.addAll( - IntStream.range(0, page.getPositionCount()) - .mapToObj(position -> blocks.stream().map(block -> getBlockValue(block, position)).toList()) - .toList() - ); - page.releaseBlocks(); - } - - return result; - } - - private static Object getBlockValue(Block block, int position) { - return block.isNull(position) ? 
null : switch (block) { - case LongBlock longBlock -> longBlock.getLong(position); - case DocBlock docBlock -> { - var vector = docBlock.asVector(); - yield new BlockUtils.Doc( - vector.shards().getInt(position), - vector.segments().getInt(position), - vector.docs().getInt(position) - ); - } - default -> throw new IllegalArgumentException("Unsupported block type: " + block.getClass()); - }; - } - } diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/TopNOperatorTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/TopNOperatorTests.java index 972b7490b1207..1f9718db3c771 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/TopNOperatorTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/TopNOperatorTests.java @@ -739,7 +739,7 @@ public void testBasicTopN() { assertThat(topNLong(values, 100, false, true), equalTo(Arrays.asList(null, null, 100L, 20L, 10L, 5L, 4L, 4L, 2L, 1L))); } - private List topNLong(List inputValues, int limit, boolean ascendingOrder, boolean nullsFirst) { + protected List topNLong(List inputValues, int limit, boolean ascendingOrder, boolean nullsFirst) { return topNLong(driverContext(), inputValues, limit, ascendingOrder, nullsFirst); } From 7c57e4312265130efa9c5fc8489552de3626aae4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Iv=C3=A1n=20Cea=20Fontenla?= Date: Thu, 5 Mar 2026 18:35:33 +0100 Subject: [PATCH 09/22] Reorder private method --- .../topn/GroupedTopNOperatorTests.java | 70 +++++++++---------- 1 file changed, 35 insertions(+), 35 deletions(-) diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/GroupedTopNOperatorTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/GroupedTopNOperatorTests.java index f99370aa02548..1132f6cbfa282 100644 --- 
a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/GroupedTopNOperatorTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/GroupedTopNOperatorTests.java @@ -192,41 +192,6 @@ protected void assertSimpleOutput(List input, List results) { ); } - private List> runGroupedTopN( - List pages, - int topCount, - List elementTypes, - List encoders, - List sortOrders, - int[] groupKeys - ) { - DriverContext driverContext = driverContext(); - List> actual = new ArrayList<>(); - try ( - Driver driver = TestDriverFactory.create( - driverContext, - new CannedSourceOperator(pages.iterator()), - List.of( - new GroupedTopNOperator( - driverContext.blockFactory(), - nonBreakingBigArrays().breakerService().getBreaker("request"), - topCount, - elementTypes, - encoders, - sortOrders, - groupKeys, - randomPageSize(), - Long.MAX_VALUE - ) - ), - new PageConsumerOperator(p -> readInto(actual, p)) - ) - ) { - new TestDriverRunner().run(driver); - } - return actual; - } - /** * Tests that the SORTED input ordering optimization short-circuiting addInput() doesn't incorrectly skip rows * belonging to groups not yet populated when another group's row is rejected. 
@@ -454,6 +419,41 @@ public void testRandomMultipleColumns() { assertThat(actualCounts, equalTo(expectedCounts)); } + private List> runGroupedTopN( + List pages, + int topCount, + List elementTypes, + List encoders, + List sortOrders, + int[] groupKeys + ) { + DriverContext driverContext = driverContext(); + List> actual = new ArrayList<>(); + try ( + Driver driver = TestDriverFactory.create( + driverContext, + new CannedSourceOperator(pages.iterator()), + List.of( + new GroupedTopNOperator( + driverContext.blockFactory(), + nonBreakingBigArrays().breakerService().getBreaker("request"), + topCount, + elementTypes, + encoders, + sortOrders, + groupKeys, + randomPageSize(), + Long.MAX_VALUE + ) + ), + new PageConsumerOperator(p -> readInto(actual, p)) + ) + ) { + new TestDriverRunner().run(driver); + } + return actual; + } + private static Comparator> comparatorFromSortOrders(List sortOrders) { return (row1, row2) -> { assertEquals(row1.size(), row2.size()); From ed995a0df849f0303dbaecb460f37a1f273beb07 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Iv=C3=A1n=20Cea=20Fontenla?= Date: Fri, 6 Mar 2026 14:32:44 +0100 Subject: [PATCH 10/22] Javadoc and renames, and removed outdated test --- .../compute/operator/topn/GroupedQueue.java | 5 +- .../operator/topn/GroupedQueueTests.java | 6 +- .../operator/topn/GroupedRowTests.java | 88 ------------------- 3 files changed, 4 insertions(+), 95 deletions(-) delete mode 100644 x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/GroupedRowTests.java diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/GroupedQueue.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/GroupedQueue.java index b1ee79748a0be..a4576d041091e 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/GroupedQueue.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/GroupedQueue.java @@ 
-20,10 +20,7 @@ import static org.apache.lucene.util.RamUsageEstimator.shallowSizeOfInstance; /** - * A queue that maintains a separate {@link TopNQueue} per group, indexed by integer group IDs - * assigned by a {@link org.elasticsearch.compute.aggregation.blockhash.BlockHash}. - * Uses a {@link BigArrays}-backed {@link ObjectArray} for better performance and circuit - * breaker integration. + * A queue that maintains a separate {@link TopNQueue} per group, indexed by group IDs. */ class GroupedQueue implements Accountable, Releasable { private static final long SHALLOW_SIZE = shallowSizeOfInstance(GroupedQueue.class); diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/GroupedQueueTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/GroupedQueueTests.java index 4192d8ba268f4..ca5b248b66d2c 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/GroupedQueueTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/GroupedQueueTests.java @@ -55,7 +55,7 @@ public void testCleanup() { } } - public void testAddWhenHeapNotFull() { + public void testAddWhenQueueNotFull() { int topCount = 5; try (GroupedQueue queue = new GroupedQueue(breaker, bigArrays, topCount)) { for (int i = 0; i < topCount; i++) { @@ -67,7 +67,7 @@ public void testAddWhenHeapNotFull() { } } - public void testAddWhenHeapFullAndRowQualifies() { + public void testAddWhenQueueFullAndRowQualifies() { int topCount = 3; try (GroupedQueue queue = new GroupedQueue(breaker, bigArrays, topCount)) { fillQueueToCapacity(queue, topCount); @@ -78,7 +78,7 @@ public void testAddWhenHeapFullAndRowQualifies() { } } - public void testAddWhenHeapFullAndRowDoesNotQualify() { + public void testAddWhenQueueFullAndRowDoesNotQualify() { try (GroupedQueue queue = new GroupedQueue(breaker, bigArrays, 3)) { addRows(queue, 0, 30, 40, 50); diff --git 
a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/GroupedRowTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/GroupedRowTests.java deleted file mode 100644 index 5981e6c90298a..0000000000000 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/GroupedRowTests.java +++ /dev/null @@ -1,88 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.compute.operator.topn; - -import org.apache.lucene.tests.util.RamUsageTester; -import org.elasticsearch.common.breaker.CircuitBreaker; -import org.elasticsearch.common.breaker.NoopCircuitBreaker; -import org.elasticsearch.common.unit.ByteSizeValue; -import org.elasticsearch.common.util.BigArrays; -import org.elasticsearch.common.util.MockBigArrays; -import org.elasticsearch.common.util.PageCacheRecycler; -import org.elasticsearch.compute.operator.BreakingBytesRefBuilder; -import org.elasticsearch.test.ESTestCase; - -import static org.hamcrest.Matchers.equalTo; - -public class GroupedRowTests extends ESTestCase { - private final CircuitBreaker breaker = new NoopCircuitBreaker(CircuitBreaker.REQUEST); - - public void testCloseReleasesAllTestsNoPreAllocation() throws Exception { - BigArrays bigArrays = new MockBigArrays(PageCacheRecycler.NON_RECYCLING_INSTANCE, ByteSizeValue.ofMb(1)); - CircuitBreaker breaker = bigArrays.breakerService().getBreaker(CircuitBreaker.REQUEST); - var row = new TopNRow(breaker, 0, 0); - row.close(); - MockBigArrays.ensureAllArraysAreReleased(); - assertThat("Not all memory was released", breaker.getUsed(), equalTo(0L)); - } - - public void testCloseReleasesAllTestsWithPreAllocation() throws Exception { - BigArrays bigArrays = new 
MockBigArrays(PageCacheRecycler.NON_RECYCLING_INSTANCE, ByteSizeValue.ofMb(1)); - CircuitBreaker breaker = bigArrays.breakerService().getBreaker(CircuitBreaker.REQUEST); - var row = new TopNRow(breaker, 16, 32); - row.close(); - MockBigArrays.ensureAllArraysAreReleased(); - assertThat("Not all memory was released", breaker.getUsed(), equalTo(0L)); - } - - public void testRamBytesUsedEmpty() { - var row = new TopNRow(breaker, 0, 0); - assertThat(row.ramBytesUsed(), equalTo(expectedRamBytesUsed(row))); - } - - public void testRamBytesUsedSmall() { - var row = new TopNRow(breaker, 0, 0); - row.keys.append(randomByte()); - row.values.append(randomByte()); - assertThat(row.ramBytesUsed(), equalTo(expectedRamBytesUsed(row))); - } - - public void testRamBytesUsedBig() { - var row = new TopNRow(breaker, 0, 0); - for (int i = 0; i < 10000; i++) { - row.keys.append(randomByte()); - row.values.append(randomByte()); - } - assertThat(row.ramBytesUsed(), equalTo(expectedRamBytesUsed(row))); - } - - public void testRamBytesUsedPreAllocated() { - var row = new TopNRow(breaker, 64, 128); - assertThat(row.ramBytesUsed(), equalTo(expectedRamBytesUsed(row))); - } - - private long expectedRamBytesUsed(TopNRow row) { - var expected = RamUsageTester.ramUsed(row); - expected -= RamUsageTester.ramUsed(breaker); - expected -= sharedRowBytes(); - expected += undercountedBytesForRow(row); - return expected; - } - - private static long sharedRowBytes() { - return RamUsageTester.ramUsed("topn"); - } - - static long undercountedBytesForRow(TopNRow row) { - return emptyByteArrayOverhead(row.values); - } - - private static long emptyByteArrayOverhead(BreakingBytesRefBuilder builder) { - return builder.bytes().length == 0 ? 
RamUsageTester.ramUsed(new byte[0]) : 0L; - } -} From f5f96fb698e86194fe1314869381ff23c197d614 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Iv=C3=A1n=20Cea=20Fontenla?= Date: Fri, 6 Mar 2026 14:39:58 +0100 Subject: [PATCH 11/22] Rename test operator --- .../compute/operator/topn/TopNOperatorTests.java | 5 ++--- ...rceBuilder.java => AbstractTypedBlockSourceOperator.java} | 4 ++-- .../operator/blocksource/ListRowsBlockSourceOperator.java | 4 ++-- .../blocksource/TupleAbstractBlockSourceOperator.java | 4 ++-- 4 files changed, 8 insertions(+), 9 deletions(-) rename x-pack/plugin/esql/compute/test/src/main/java/org/elasticsearch/compute/test/{TypedAbstractBlockSourceBuilder.java => AbstractTypedBlockSourceOperator.java} (76%) diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/TopNOperatorTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/TopNOperatorTests.java index 1f9718db3c771..fbda4bddaa593 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/TopNOperatorTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/TopNOperatorTests.java @@ -43,7 +43,7 @@ import org.elasticsearch.compute.test.TestBlockFactory; import org.elasticsearch.compute.test.TestDriverFactory; import org.elasticsearch.compute.test.TestDriverRunner; -import org.elasticsearch.compute.test.TypedAbstractBlockSourceBuilder; +import org.elasticsearch.compute.test.AbstractTypedBlockSourceOperator; import org.elasticsearch.compute.test.operator.blocksource.SequenceLongBlockSourceOperator; import org.elasticsearch.compute.test.operator.blocksource.TupleDocLongBlockSourceOperator; import org.elasticsearch.compute.test.operator.blocksource.TupleLongLongBlockSourceOperator; @@ -99,7 +99,6 @@ import static org.elasticsearch.compute.test.BlockTestUtils.randomValue; import static org.elasticsearch.compute.test.BlockTestUtils.readInto; import static 
org.elasticsearch.core.Tuple.tuple; -import static org.elasticsearch.test.ESTestCase.between; import static org.elasticsearch.test.ListMatcher.matchesList; import static org.elasticsearch.test.MapMatcher.assertMap; import static org.hamcrest.Matchers.both; @@ -1203,7 +1202,7 @@ protected List> topNTwoLongColumns( protected List topNMultipleColumns( DriverContext driverContext, - TypedAbstractBlockSourceBuilder sourceOperator, + AbstractTypedBlockSourceOperator sourceOperator, int limit, List encoder, List sortOrders, diff --git a/x-pack/plugin/esql/compute/test/src/main/java/org/elasticsearch/compute/test/TypedAbstractBlockSourceBuilder.java b/x-pack/plugin/esql/compute/test/src/main/java/org/elasticsearch/compute/test/AbstractTypedBlockSourceOperator.java similarity index 76% rename from x-pack/plugin/esql/compute/test/src/main/java/org/elasticsearch/compute/test/TypedAbstractBlockSourceBuilder.java rename to x-pack/plugin/esql/compute/test/src/main/java/org/elasticsearch/compute/test/AbstractTypedBlockSourceOperator.java index 42a022ce04a7b..281aa91846ae4 100644 --- a/x-pack/plugin/esql/compute/test/src/main/java/org/elasticsearch/compute/test/TypedAbstractBlockSourceBuilder.java +++ b/x-pack/plugin/esql/compute/test/src/main/java/org/elasticsearch/compute/test/AbstractTypedBlockSourceOperator.java @@ -13,8 +13,8 @@ import java.util.List; -public abstract class TypedAbstractBlockSourceBuilder extends AbstractBlockSourceOperator { - protected TypedAbstractBlockSourceBuilder(BlockFactory blockFactory, int maxPagePositions) { +public abstract class AbstractTypedBlockSourceOperator extends AbstractBlockSourceOperator { + protected AbstractTypedBlockSourceOperator(BlockFactory blockFactory, int maxPagePositions) { super(blockFactory, maxPagePositions); } diff --git a/x-pack/plugin/esql/compute/test/src/main/java/org/elasticsearch/compute/test/operator/blocksource/ListRowsBlockSourceOperator.java 
b/x-pack/plugin/esql/compute/test/src/main/java/org/elasticsearch/compute/test/operator/blocksource/ListRowsBlockSourceOperator.java index eab815f4bf1df..68cbc83b5004b 100644 --- a/x-pack/plugin/esql/compute/test/src/main/java/org/elasticsearch/compute/test/operator/blocksource/ListRowsBlockSourceOperator.java +++ b/x-pack/plugin/esql/compute/test/src/main/java/org/elasticsearch/compute/test/operator/blocksource/ListRowsBlockSourceOperator.java @@ -12,7 +12,7 @@ import org.elasticsearch.compute.data.ElementType; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.test.TestBlockBuilder; -import org.elasticsearch.compute.test.TypedAbstractBlockSourceBuilder; +import org.elasticsearch.compute.test.AbstractTypedBlockSourceOperator; import org.elasticsearch.core.Releasables; import java.util.List; @@ -23,7 +23,7 @@ /** * A source operator whose output is rows specified as a list {@link List} values. */ -public class ListRowsBlockSourceOperator extends TypedAbstractBlockSourceBuilder { +public class ListRowsBlockSourceOperator extends AbstractTypedBlockSourceOperator { private static final int DEFAULT_MAX_PAGE_POSITIONS = 8 * 1024; private final List types; diff --git a/x-pack/plugin/esql/compute/test/src/main/java/org/elasticsearch/compute/test/operator/blocksource/TupleAbstractBlockSourceOperator.java b/x-pack/plugin/esql/compute/test/src/main/java/org/elasticsearch/compute/test/operator/blocksource/TupleAbstractBlockSourceOperator.java index e13f73785a2bd..40b7e0ac4b7be 100644 --- a/x-pack/plugin/esql/compute/test/src/main/java/org/elasticsearch/compute/test/operator/blocksource/TupleAbstractBlockSourceOperator.java +++ b/x-pack/plugin/esql/compute/test/src/main/java/org/elasticsearch/compute/test/operator/blocksource/TupleAbstractBlockSourceOperator.java @@ -11,7 +11,7 @@ import org.elasticsearch.compute.data.BlockFactory; import org.elasticsearch.compute.data.ElementType; import org.elasticsearch.compute.data.Page; -import 
org.elasticsearch.compute.test.TypedAbstractBlockSourceBuilder; +import org.elasticsearch.compute.test.AbstractTypedBlockSourceOperator; import org.elasticsearch.core.Tuple; import java.util.List; @@ -20,7 +20,7 @@ * A source operator whose output is the given tuple values. This operator produces pages * with two Blocks. The returned pages preserve the order of values as given in the in initial list. */ -public abstract class TupleAbstractBlockSourceOperator extends TypedAbstractBlockSourceBuilder { +public abstract class TupleAbstractBlockSourceOperator extends AbstractTypedBlockSourceOperator { private static final int DEFAULT_MAX_PAGE_POSITIONS = 8 * 1024; private final List> values; From bce5c4581e238bdc37db612965f175f1052ff5b8 Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Fri, 6 Mar 2026 14:07:32 +0000 Subject: [PATCH 12/22] [CI] Auto commit changes from spotless --- .../elasticsearch/compute/operator/topn/TopNOperatorTests.java | 2 +- .../test/operator/blocksource/ListRowsBlockSourceOperator.java | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/TopNOperatorTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/TopNOperatorTests.java index fbda4bddaa593..5595fcc646e49 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/TopNOperatorTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/TopNOperatorTests.java @@ -37,13 +37,13 @@ import org.elasticsearch.compute.operator.PageConsumerOperator; import org.elasticsearch.compute.operator.SourceOperator; import org.elasticsearch.compute.operator.topn.TopNOperator.InputOrdering; +import org.elasticsearch.compute.test.AbstractTypedBlockSourceOperator; import org.elasticsearch.compute.test.CannedSourceOperator; import org.elasticsearch.compute.test.OperatorTestCase; import 
org.elasticsearch.compute.test.TestBlockBuilder; import org.elasticsearch.compute.test.TestBlockFactory; import org.elasticsearch.compute.test.TestDriverFactory; import org.elasticsearch.compute.test.TestDriverRunner; -import org.elasticsearch.compute.test.AbstractTypedBlockSourceOperator; import org.elasticsearch.compute.test.operator.blocksource.SequenceLongBlockSourceOperator; import org.elasticsearch.compute.test.operator.blocksource.TupleDocLongBlockSourceOperator; import org.elasticsearch.compute.test.operator.blocksource.TupleLongLongBlockSourceOperator; diff --git a/x-pack/plugin/esql/compute/test/src/main/java/org/elasticsearch/compute/test/operator/blocksource/ListRowsBlockSourceOperator.java b/x-pack/plugin/esql/compute/test/src/main/java/org/elasticsearch/compute/test/operator/blocksource/ListRowsBlockSourceOperator.java index 68cbc83b5004b..481cf9c877896 100644 --- a/x-pack/plugin/esql/compute/test/src/main/java/org/elasticsearch/compute/test/operator/blocksource/ListRowsBlockSourceOperator.java +++ b/x-pack/plugin/esql/compute/test/src/main/java/org/elasticsearch/compute/test/operator/blocksource/ListRowsBlockSourceOperator.java @@ -11,8 +11,8 @@ import org.elasticsearch.compute.data.BlockFactory; import org.elasticsearch.compute.data.ElementType; import org.elasticsearch.compute.data.Page; -import org.elasticsearch.compute.test.TestBlockBuilder; import org.elasticsearch.compute.test.AbstractTypedBlockSourceOperator; +import org.elasticsearch.compute.test.TestBlockBuilder; import org.elasticsearch.core.Releasables; import java.util.List; From 476e635434e0e9ecbd82dd7884d65ad4035f9dd9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Iv=C3=A1n=20Cea=20Fontenla?= Date: Fri, 6 Mar 2026 18:00:15 +0100 Subject: [PATCH 13/22] Fix double closing and simplify GroupedQueue tests --- .../common/util/MockBigArrays.java | 9 +- .../compute/operator/topn/GroupedQueue.java | 12 +- .../operator/topn/GroupedTopNOperator.java | 15 +- .../operator/topn/GroupedQueueTests.java | 
191 ++++-------------- 4 files changed, 56 insertions(+), 171 deletions(-) diff --git a/test/framework/src/main/java/org/elasticsearch/common/util/MockBigArrays.java b/test/framework/src/main/java/org/elasticsearch/common/util/MockBigArrays.java index de87772d5ae82..bd73294ddde23 100644 --- a/test/framework/src/main/java/org/elasticsearch/common/util/MockBigArrays.java +++ b/test/framework/src/main/java/org/elasticsearch/common/util/MockBigArrays.java @@ -142,8 +142,15 @@ public static void ensureAllArraysAreReleased() throws Exception { * Create {@linkplain BigArrays} with a configured limit. */ public MockBigArrays(PageCacheRecycler recycler, ByteSizeValue limit) { + this(recycler, new LimitedBreaker(CircuitBreaker.REQUEST, limit)); + } + + /** + * Create {@linkplain BigArrays} with a configured request circuit breaker. + */ + public MockBigArrays(PageCacheRecycler recycler, CircuitBreaker breaker) { this(recycler, mock(CircuitBreakerService.class), true); - when(breakerService.getBreaker(CircuitBreaker.REQUEST)).thenReturn(new LimitedBreaker(CircuitBreaker.REQUEST, limit)); + when(breakerService.getBreaker(CircuitBreaker.REQUEST)).thenReturn(breaker); } /** diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/GroupedQueue.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/GroupedQueue.java index a4576d041091e..5f44e948e9ce6 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/GroupedQueue.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/GroupedQueue.java @@ -53,17 +53,7 @@ int size() { return totalSize; } - /** - * Attempts to add the row to the per-group queue identified by {@code groupId}. - * @return If the row was added and the queue was full, the evicted row. - * If the row was added and it wasn't full, {@code null}. - * If the row wasn't added, the input row. 
- */ - TopNRow addRow(long groupId, TopNRow row) { - return getOrCreateQueue(groupId).addRow(row); - } - - private TopNQueue getOrCreateQueue(long groupId) { + TopNQueue getOrCreateQueue(long groupId) { if (groupId >= queues.size()) { queues = bigArrays.grow(queues, groupId + 1); } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/GroupedTopNOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/GroupedTopNOperator.java index 6e57b3c118285..d7f7147ccc489 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/GroupedTopNOperator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/GroupedTopNOperator.java @@ -195,11 +195,18 @@ private void processRow(TopNOperator.RowFiller rowFiller, int position, long gro } rowFiller.writeKey(position, spare); - var nextSpare = inputQueue.addRow(groupId, spare); - if (nextSpare != spare) { - var insertedRow = spare; + // Write values BEFORE modifying the queue so that if writeValues throws (e.g. circuit breaker), + // spare is not left in both the queue and the spare field (which would double-close). 
+ TopNQueue queue = inputQueue.getOrCreateQueue(groupId); + if (queue.size() < queue.topCount) { + rowFiller.writeValues(position, spare); + queue.add(spare); + spare = null; + } else if (queue.lessThan(queue.top(), spare)) { + rowFiller.writeValues(position, spare); + TopNRow nextSpare = queue.top(); + queue.updateTop(spare); spare = nextSpare; - rowFiller.writeValues(position, insertedRow); } } diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/GroupedQueueTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/GroupedQueueTests.java index ca5b248b66d2c..58908cff6d2ac 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/GroupedQueueTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/GroupedQueueTests.java @@ -18,141 +18,38 @@ import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.Page; import org.elasticsearch.core.Releasables; -import org.elasticsearch.core.Tuple; import org.elasticsearch.test.ESTestCase; -import org.junit.After; import java.util.List; -import java.util.stream.IntStream; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.greaterThan; -import static org.hamcrest.Matchers.notNullValue; -import static org.hamcrest.Matchers.nullValue; -import static org.hamcrest.Matchers.sameInstance; public class GroupedQueueTests extends ESTestCase { - private final BigArrays bigArrays = new MockBigArrays(PageCacheRecycler.NON_RECYCLING_INSTANCE, ByteSizeValue.ofMb(1)); - private final CircuitBreaker breaker = bigArrays.breakerService().getBreaker(CircuitBreaker.REQUEST); + private final CircuitBreaker breaker = newLimitedBreaker(ByteSizeValue.ofMb(1)); + private final BigArrays bigArrays = new MockBigArrays(PageCacheRecycler.NON_RECYCLING_INSTANCE, breaker); private final BlockFactory blockFactory = new BlockFactory(breaker, bigArrays); - 
@After - public void allMemoryReleased() throws Exception { - MockBigArrays.ensureAllArraysAreReleased(); - - assertThat("Not all memory was released", breaker.getUsed(), equalTo(0L)); - assertThat("Not all blocks were released", blockFactory.breaker().getUsed(), equalTo(0L)); - } - - public void testCleanup() { - int topCount = 5; - try (GroupedQueue queue = new GroupedQueue(breaker, bigArrays, topCount)) { - assertThat(queue.size(), equalTo(0)); - - for (int i = 0; i < topCount * 2; i++) { - addRow(queue, i % 3, i * 10); - } - } - } - - public void testAddWhenQueueNotFull() { - int topCount = 5; - try (GroupedQueue queue = new GroupedQueue(breaker, bigArrays, topCount)) { - for (int i = 0; i < topCount; i++) { - TopNRow row = createRow(breaker, i * 10); - TopNRow result = queue.addRow(i % 2, row); - assertThat(result, nullValue()); - assertThat(queue.size(), equalTo(i + 1)); - } - } - } - - public void testAddWhenQueueFullAndRowQualifies() { - int topCount = 3; - try (GroupedQueue queue = new GroupedQueue(breaker, bigArrays, topCount)) { - fillQueueToCapacity(queue, topCount); - - try (TopNRow evicted = queue.addRow(0, createRow(breaker, 5))) { - assertRowValues(evicted, 20, 40); - } - } - } - - public void testAddWhenQueueFullAndRowDoesNotQualify() { - try (GroupedQueue queue = new GroupedQueue(breaker, bigArrays, 3)) { - addRows(queue, 0, 30, 40, 50); - - try (TopNRow row = createRow(breaker, 60)) { - TopNRow result = queue.addRow(0, row); - assertThat(result, sameInstance(row)); - } - } - } - - public void testAddWithDifferentGroupKeys() { + public void testGroupIsolation() { try (GroupedQueue queue = new GroupedQueue(breaker, bigArrays, 2)) { - assertThat(queue.addRow(0, createRow(breaker, 10)), nullValue()); - assertThat(queue.addRow(1, createRow(breaker, 20)), nullValue()); - assertThat(queue.addRow(0, createRow(breaker, 30)), nullValue()); - assertThat(queue.addRow(1, createRow(breaker, 40)), nullValue()); - assertThat(queue.size(), equalTo(4)); - - try 
(TopNRow evicted = queue.addRow(0, createRow(breaker, 5))) { - assertThat(evicted, notNullValue()); - assertRowValues(evicted, 30, 60); - } - try (TopNRow evicted = queue.addRow(1, createRow(breaker, 15))) { - assertThat(evicted, notNullValue()); - assertRowValues(evicted, 40, 80); - } - assertThat(queue.size(), equalTo(4)); - - try (TopNRow row = queue.addRow(0, createRow(breaker, 50))) { - assertThat(row, notNullValue()); - assertRowValues(row, 50, 100); - } - try (TopNRow row = queue.addRow(1, createRow(breaker, 50))) { - assertThat(row, notNullValue()); - assertRowValues(row, 50, 100); - } + addRows(queue, 0, 10, 30, 5); + addRows(queue, 1, 20, 40, 15); assertThat(queue.size(), equalTo(4)); + assertQueueContents(queue, List.of(5, 10, 15, 20)); } } - public void testRamBytesUsedEmpty() { - try (GroupedQueue queue = new GroupedQueue(breaker, bigArrays, 5)) { - assertRamBytesUsedConsistent(queue); - } - } - - /** - * Verifies that ramBytesUsed() accounts for at least the shallow size and grows with content. - * We can't use RamUsageTester for BigArrays-backed structures due to module access restrictions. 
- */ - private void assertRamBytesUsedConsistent(GroupedQueue queue) { - long reported = queue.ramBytesUsed(); - assertThat("ramBytesUsed should be positive", reported, greaterThan(0L)); - } - - public void testRamBytesUsedPartiallyFilled() { + public void testRamBytesUsed() { try (GroupedQueue queue = new GroupedQueue(breaker, bigArrays, 5)) { long emptySize = queue.ramBytesUsed(); - addRows(queue, 0, 10, 20, 30); - long filledSize = queue.ramBytesUsed(); - assertThat("RAM usage should grow after adding rows", filledSize, greaterThan(emptySize)); - } - } + assertThat("ramBytesUsed should be positive", emptySize, greaterThan(0L)); - public void testRamBytesUsedAtCapacity() { - try (GroupedQueue queue = new GroupedQueue(breaker, bigArrays, 5)) { - long emptySize = queue.ramBytesUsed(); - addRows(queue, 0, 10, 20, 30, 40, 50); + addRows(queue, 0, 10, 20, 30); long oneGroupSize = queue.ramBytesUsed(); - addRows(queue, 1, 10, 20, 30, 40, 50); - addRows(queue, 2, 10, 20, 30, 40, 50); - long threeGroupSize = queue.ramBytesUsed(); assertThat("RAM should grow with first group", oneGroupSize, greaterThan(emptySize)); - assertThat("RAM should grow with more groups", threeGroupSize, greaterThan(oneGroupSize)); + + addRows(queue, 1, 10, 20, 30); + assertThat("RAM should grow with more groups", queue.ramBytesUsed(), greaterThan(oneGroupSize)); } } @@ -161,19 +58,7 @@ public void testPopAllSortedBySortKey() { addRows(queue, 0, 30, 10, 50); addRows(queue, 1, 20, 40); addRows(queue, 2, 15, 25, 35); - assertQueueContents( - queue, - List.of( - Tuple.tuple(0, 10), - Tuple.tuple(2, 15), - Tuple.tuple(1, 20), - Tuple.tuple(2, 25), - Tuple.tuple(0, 30), - Tuple.tuple(2, 35), - Tuple.tuple(1, 40), - Tuple.tuple(0, 50) - ) - ); + assertQueueContents(queue, List.of(10, 15, 20, 25, 30, 35, 40, 50)); } } @@ -198,6 +83,29 @@ private TopNRow createRow(CircuitBreaker breaker, int sortKey) { return row; } + private void addRows(GroupedQueue queue, int groupKey, int... 
values) { + for (int value : values) { + addRow(queue, groupKey, value); + } + } + + private void addRow(GroupedQueue queue, int groupKey, int value) { + TopNRow row = createRow(breaker, value); + Releasables.close(queue.getOrCreateQueue(groupKey).addRow(row)); + } + + private static final List SORT_ORDERS = List.of(new TopNOperator.SortOrder(1, true, false)); + + private void assertQueueContents(GroupedQueue queue, List expectedSortKeys) { + assertThat(queue.size(), equalTo(expectedSortKeys.size())); + List actual = queue.popAll(); + for (int i = 0; i < expectedSortKeys.size(); i++) { + int sortKey = expectedSortKeys.get(i); + assertRowValues(actual.get(i), sortKey, sortKey * 2); + } + Releasables.close(actual); + } + private static void assertRowValues(TopNRow row, int expectedSortKey, int expectedValue) { BytesRef keys = row.keys.bytesRefView(); assertThat( @@ -213,31 +121,4 @@ private static void assertRowValues(TopNRow row, int expectedSortKey, int expect assertThat(TopNEncoder.DEFAULT_UNSORTABLE.decodeVInt(reader), equalTo(1)); assertThat(TopNEncoder.DEFAULT_UNSORTABLE.decodeInt(reader), equalTo(expectedValue)); } - - private void addRow(GroupedQueue queue, int groupKey, int value) { - TopNRow row = createRow(breaker, value); - Releasables.close(queue.addRow(groupKey, row)); - } - - private void fillQueueToCapacity(GroupedQueue queue, int capacity) { - addRows(queue, 0, IntStream.range(0, capacity).map(i -> i * 10).toArray()); - } - - private void addRows(GroupedQueue queue, int groupKey, int... 
values) { - for (int value : values) { - addRow(queue, groupKey, value); - } - } - - private static final List SORT_ORDERS = List.of(new TopNOperator.SortOrder(1, true, false)); - - private void assertQueueContents(GroupedQueue queue, List> groupAndSortKeys) { - assertThat(queue.size(), equalTo(groupAndSortKeys.size())); - List actual = queue.popAll(); - for (int i = 0; i < groupAndSortKeys.size(); i++) { - Tuple expectedTuple = groupAndSortKeys.get(i); - assertRowValues(actual.get(i), expectedTuple.v2(), expectedTuple.v2() * 2); - } - Releasables.close(actual); - } } From 7af754101f49b48f42b11698b2e9fef807a6f241 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Iv=C3=A1n=20Cea=20Fontenla?= Date: Mon, 9 Mar 2026 17:13:20 +0100 Subject: [PATCH 14/22] Changed status groupCount to int and fixed TopNRow hashCode --- .../compute/operator/topn/GroupedTopNOperator.java | 2 +- .../operator/topn/GroupedTopNOperatorStatus.java | 10 +++++----- .../elasticsearch/compute/operator/topn/TopNRow.java | 3 +-- .../operator/topn/GroupedTopNOperatorStatusTests.java | 6 +++--- 4 files changed, 10 insertions(+), 11 deletions(-) diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/GroupedTopNOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/GroupedTopNOperator.java index d7f7147ccc489..7e886135bda4d 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/GroupedTopNOperator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/GroupedTopNOperator.java @@ -277,7 +277,7 @@ public Status status() { receiveNanos, emitNanos, inputQueue != null ? inputQueue.size() : 0, - keysHash != null ? keysHash.size() : 0, + keysHash != null ? 
(int) keysHash.size() : 0, ramBytesUsed(), pagesReceived, pagesEmitted, diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/GroupedTopNOperatorStatus.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/GroupedTopNOperatorStatus.java index 6e641c1316fa9..15e524b21d0a2 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/GroupedTopNOperatorStatus.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/GroupedTopNOperatorStatus.java @@ -30,7 +30,7 @@ public class GroupedTopNOperatorStatus implements Operator.Status { private final long receiveNanos; private final long emitNanos; private final int occupiedRows; - private final long groupCount; + private final int groupCount; private final long ramBytesUsed; private final int pagesReceived; private final int pagesEmitted; @@ -41,7 +41,7 @@ public GroupedTopNOperatorStatus( long receiveNanos, long emitNanos, int occupiedRows, - long groupCount, + int groupCount, long ramBytesUsed, int pagesReceived, int pagesEmitted, @@ -63,7 +63,7 @@ public GroupedTopNOperatorStatus( this.receiveNanos = in.readVLong(); this.emitNanos = in.readVLong(); this.occupiedRows = in.readVInt(); - this.groupCount = in.readVLong(); + this.groupCount = in.readVInt(); this.ramBytesUsed = in.readVLong(); this.pagesReceived = in.readVInt(); @@ -78,7 +78,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeVLong(emitNanos); out.writeVInt(occupiedRows); - out.writeVLong(groupCount); + out.writeVInt(groupCount); out.writeVLong(ramBytesUsed); out.writeVInt(pagesReceived); @@ -104,7 +104,7 @@ public int occupiedRows() { return occupiedRows; } - public long groupCount() { + public int groupCount() { return groupCount; } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/TopNRow.java 
b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/TopNRow.java index 779c6f9fafd7f..951d85145ed16 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/TopNRow.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/TopNRow.java @@ -17,7 +17,6 @@ import org.elasticsearch.core.Releasables; import java.util.Arrays; -import java.util.Objects; /** * A single row in a top-N operation. Stores encoded sort keys and values. @@ -112,7 +111,7 @@ public boolean equals(Object o) { @Override public int hashCode() { - return Objects.hashCode(keys); + return keys.bytesRefView().hashCode(); } @Override diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/GroupedTopNOperatorStatusTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/GroupedTopNOperatorStatusTests.java index 25003e93124fc..470061010b24d 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/GroupedTopNOperatorStatusTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/GroupedTopNOperatorStatusTests.java @@ -52,7 +52,7 @@ protected GroupedTopNOperatorStatus createTestInstance() { randomNonNegativeLong(), randomNonNegativeLong(), randomNonNegativeInt(), - randomNonNegativeLong(), + randomNonNegativeInt(), randomNonNegativeLong(), randomNonNegativeInt(), randomNonNegativeInt(), @@ -66,7 +66,7 @@ protected GroupedTopNOperatorStatus mutateInstance(GroupedTopNOperatorStatus ins long receiveNanos = instance.receiveNanos(); long emitNanos = instance.emitNanos(); int occupiedRows = instance.occupiedRows(); - long groupCount = instance.groupCount(); + int groupCount = instance.groupCount(); long ramBytesUsed = instance.ramBytesUsed(); int pagesReceived = instance.pagesReceived(); int pagesEmitted = instance.pagesEmitted(); @@ -83,7 +83,7 @@ protected 
GroupedTopNOperatorStatus mutateInstance(GroupedTopNOperatorStatus ins occupiedRows = randomValueOtherThan(occupiedRows, ESTestCase::randomNonNegativeInt); break; case 3: - groupCount = randomValueOtherThan(groupCount, ESTestCase::randomNonNegativeLong); + groupCount = randomValueOtherThan(groupCount, ESTestCase::randomNonNegativeInt); break; case 4: ramBytesUsed = randomValueOtherThan(ramBytesUsed, ESTestCase::randomNonNegativeLong); From b727c7f78e6893159d13ebbf9f751993d7e8620f Mon Sep 17 00:00:00 2001 From: ncordon Date: Tue, 10 Mar 2026 11:15:01 +0100 Subject: [PATCH 15/22] Changes PositionKeyEncoder by GroupKeyEncoder --- .../compute/operator/GroupKeyEncoder.java | 15 +- .../compute/operator/PositionKeyEncoder.java | 146 ------------------ .../operator/topn/GroupedTopNOperator.java | 20 ++- .../topn/GroupedTopNOperatorTests.java | 16 +- 4 files changed, 41 insertions(+), 156 deletions(-) delete mode 100644 x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/PositionKeyEncoder.java diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/GroupKeyEncoder.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/GroupKeyEncoder.java index 121c098f61aff..25c7e2480b700 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/GroupKeyEncoder.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/GroupKeyEncoder.java @@ -7,7 +7,9 @@ package org.elasticsearch.compute.operator; +import org.apache.lucene.util.Accountable; import org.apache.lucene.util.BytesRef; +import org.apache.lucene.util.RamUsageEstimator; import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.BooleanBlock; import org.elasticsearch.compute.data.BytesRefBlock; @@ -29,7 +31,9 @@ * in block iteration order. This means {@code [1, 2]} and {@code [2, 1]} produce different keys. * Null positions are encoded as a value count of zero. 
*/ -public class GroupKeyEncoder implements Releasable { +public class GroupKeyEncoder implements Accountable, Releasable { + + private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(GroupKeyEncoder.class); private static final DefaultUnsortableTopNEncoder encoder = TopNEncoder.DEFAULT_UNSORTABLE; @@ -113,6 +117,15 @@ private void encodeBlock(Block block, ElementType type, int position) { } } + @Override + public long ramBytesUsed() { + long size = SHALLOW_SIZE; + size += RamUsageEstimator.shallowSizeOf(elementTypes); + size += scratch.ramBytesUsed(); + size += RamUsageEstimator.shallowSizeOfInstance(BytesRef.class); + return size; + } + @Override public void close() { scratch.close(); diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/PositionKeyEncoder.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/PositionKeyEncoder.java deleted file mode 100644 index eb1a71553ac71..0000000000000 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/PositionKeyEncoder.java +++ /dev/null @@ -1,146 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. 
- */ - -package org.elasticsearch.compute.operator; - -import org.apache.lucene.util.Accountable; -import org.apache.lucene.util.BytesRef; -import org.apache.lucene.util.BytesRefBuilder; -import org.apache.lucene.util.RamUsageEstimator; -import org.elasticsearch.compute.data.Block; -import org.elasticsearch.compute.data.BooleanBlock; -import org.elasticsearch.compute.data.BytesRefBlock; -import org.elasticsearch.compute.data.DoubleBlock; -import org.elasticsearch.compute.data.ElementType; -import org.elasticsearch.compute.data.FloatBlock; -import org.elasticsearch.compute.data.IntBlock; -import org.elasticsearch.compute.data.LongBlock; -import org.elasticsearch.compute.data.Page; - -import java.util.List; - -/** - * Encodes the values at a given position across multiple blocks into a single {@link BytesRef} composite key. - * Multivalued positions are serialized with list semantics: the value count is written first, then each value - * in block iteration order. This means {@code [1, 2]} and {@code [2, 1]} produce different keys. - * Null positions are encoded as a value count of zero. - */ -public class PositionKeyEncoder implements Accountable { - - private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(PositionKeyEncoder.class); - - private final int[] groupChannels; - private final ElementType[] elementTypes; - private final BytesRefBuilder scratch = new BytesRefBuilder(); - private final BytesRef scratchBytesRef = new BytesRef(); - - public PositionKeyEncoder(int[] groupChannels, List elementTypes) { - this.groupChannels = groupChannels; - this.elementTypes = new ElementType[groupChannels.length]; - for (int i = 0; i < groupChannels.length; i++) { - this.elementTypes[i] = elementTypes.get(groupChannels[i]); - } - } - - /** - * Encode the group key for the given position from the page into a {@link BytesRef}. - * The returned reference is only valid until the next call to {@code encode}. 
- */ - public BytesRef encode(Page page, int position) { - scratch.clear(); - for (int i = 0; i < groupChannels.length; i++) { - Block block = page.getBlock(groupChannels[i]); - encodeBlock(block, elementTypes[i], position); - } - return scratch.get(); - } - - private void encodeBlock(Block block, ElementType type, int position) { - if (block.isNull(position)) { - writeVInt(0); - return; - } - int firstValueIndex = block.getFirstValueIndex(position); - int valueCount = block.getValueCount(position); - writeVInt(valueCount); - switch (type) { - case INT -> { - IntBlock b = (IntBlock) block; - for (int v = 0; v < valueCount; v++) { - writeInt(b.getInt(firstValueIndex + v)); - } - } - case LONG -> { - LongBlock b = (LongBlock) block; - for (int v = 0; v < valueCount; v++) { - writeLong(b.getLong(firstValueIndex + v)); - } - } - case DOUBLE -> { - DoubleBlock b = (DoubleBlock) block; - for (int v = 0; v < valueCount; v++) { - writeLong(Double.doubleToLongBits(b.getDouble(firstValueIndex + v))); - } - } - case FLOAT -> { - FloatBlock b = (FloatBlock) block; - for (int v = 0; v < valueCount; v++) { - writeInt(Float.floatToIntBits(b.getFloat(firstValueIndex + v))); - } - } - case BOOLEAN -> { - BooleanBlock b = (BooleanBlock) block; - for (int v = 0; v < valueCount; v++) { - scratch.append((byte) (b.getBoolean(firstValueIndex + v) ? 
1 : 0)); - } - } - case BYTES_REF -> { - BytesRefBlock b = (BytesRefBlock) block; - for (int v = 0; v < valueCount; v++) { - BytesRef ref = b.getBytesRef(firstValueIndex + v, scratchBytesRef); - writeVInt(ref.length); - scratch.append(ref.bytes, ref.offset, ref.length); - } - } - case NULL -> { - // already handled by isNull above; nothing extra to write - } - default -> throw new IllegalArgumentException("unsupported element type for group key encoding: " + type); - } - } - - private void writeVInt(int value) { - while ((value & ~0x7F) != 0) { - scratch.append((byte) ((value & 0x7F) | 0x80)); - value >>>= 7; - } - scratch.append((byte) value); - } - - private void writeInt(int value) { - scratch.append((byte) (value >> 24)); - scratch.append((byte) (value >> 16)); - scratch.append((byte) (value >> 8)); - scratch.append((byte) value); - } - - private void writeLong(long value) { - writeInt((int) (value >> 32)); - writeInt((int) value); - } - - @Override - public long ramBytesUsed() { - long size = SHALLOW_SIZE; - size += RamUsageEstimator.sizeOf(groupChannels); - size += RamUsageEstimator.shallowSizeOf(elementTypes); - size += RamUsageEstimator.shallowSizeOfInstance(BytesRefBuilder.class); - size += RamUsageEstimator.sizeOf(scratch.bytes()); - size += RamUsageEstimator.shallowSizeOfInstance(BytesRef.class); - return size; - } -} diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/GroupedTopNOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/GroupedTopNOperator.java index 7e886135bda4d..f252510fe70bf 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/GroupedTopNOperator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/GroupedTopNOperator.java @@ -17,9 +17,10 @@ import org.elasticsearch.compute.data.BlockFactory; import org.elasticsearch.compute.data.ElementType; import 
org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.BreakingBytesRefBuilder; import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.compute.operator.GroupKeyEncoder; import org.elasticsearch.compute.operator.Operator; -import org.elasticsearch.compute.operator.PositionKeyEncoder; import org.elasticsearch.core.ReleasableIterator; import org.elasticsearch.core.Releasables; @@ -28,7 +29,7 @@ /** * A top-N operator for grouped (SORT + LIMIT BY) queries. Maintains per-group priority queues - * using a {@link PositionKeyEncoder} to map group key columns to integer group IDs. + * using a {@link GroupKeyEncoder} to map group key columns to integer group IDs. *

* Group keys use list semantics for multivalues: {@code [1,2]} and {@code [2,1]} are different groups. *

@@ -65,6 +66,9 @@ public record GroupedTopNOperatorFactory( @Override public GroupedTopNOperator get(DriverContext driverContext) { + var scratch = new BreakingBytesRefBuilder(driverContext.breaker(), "group-key-encoder"); + int[] groupKeysArray = groupKeys.stream().mapToInt(Integer::intValue).toArray(); + var keyEncoder = new GroupKeyEncoder(groupKeysArray, elementTypes, scratch); return new GroupedTopNOperator( driverContext.blockFactory(), driverContext.breaker(), @@ -72,7 +76,8 @@ public GroupedTopNOperator get(DriverContext driverContext) { elementTypes, encoders, sortOrders, - groupKeys.stream().mapToInt(Integer::intValue).toArray(), + groupKeysArray, + keyEncoder, maxPageSize, jumboPageBytes ); @@ -104,7 +109,7 @@ public String describe() { private final List sortOrders; private final int[] groupKeys; private final boolean[] channelInKey; - private final PositionKeyEncoder keyEncoder; + private final GroupKeyEncoder keyEncoder; private BytesRefHashTable keysHash; private GroupedQueue inputQueue; @@ -127,6 +132,7 @@ public GroupedTopNOperator( List encoders, List sortOrders, int[] groupKeys, + GroupKeyEncoder keyEncoder, int maxPageSize, long jumboPageBytes ) { @@ -139,10 +145,10 @@ public GroupedTopNOperator( success = true; } finally { if (success == false) { - Releasables.close(keysHash, inputQueue); + Releasables.close(keyEncoder, keysHash, inputQueue); } } - this.keyEncoder = new PositionKeyEncoder(groupKeys, elementTypes); + this.keyEncoder = keyEncoder; this.keysHash = keysHash; this.inputQueue = inputQueue; this.blockFactory = blockFactory; @@ -242,7 +248,7 @@ public Page getOutput() { @Override public void close() { - Releasables.closeExpectNoException(spare, inputQueue, output, keysHash); + Releasables.closeExpectNoException(spare, inputQueue, output, keysHash, keyEncoder); inputQueue = null; output = null; } diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/GroupedTopNOperatorTests.java 
b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/GroupedTopNOperatorTests.java index 1132f6cbfa282..a037e1d8e11c7 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/GroupedTopNOperatorTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/GroupedTopNOperatorTests.java @@ -14,8 +14,10 @@ import org.elasticsearch.compute.data.LongBlock; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.lucene.IndexedByShardIdFromList; +import org.elasticsearch.compute.operator.BreakingBytesRefBuilder; import org.elasticsearch.compute.operator.Driver; import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.compute.operator.GroupKeyEncoder; import org.elasticsearch.compute.operator.Operator; import org.elasticsearch.compute.operator.PageConsumerOperator; import org.elasticsearch.compute.operator.SourceOperator; @@ -90,7 +92,7 @@ public void testStatus() { assertThat(status, instanceOf(GroupedTopNOperatorStatus.class)); GroupedTopNOperatorStatus groupedStatus = (GroupedTopNOperatorStatus) status; assertThat(groupedStatus.occupiedRows(), equalTo(0)); - assertThat(groupedStatus.groupCount(), equalTo(0L)); + assertThat(groupedStatus.groupCount(), equalTo(0)); assertThat(groupedStatus.ramBytesUsed(), greaterThan(0L)); assertThat(groupedStatus.pagesReceived(), equalTo(0)); assertThat(groupedStatus.pagesEmitted(), equalTo(0)); @@ -107,7 +109,7 @@ public void testStatus() { assertThat(groupedStatus.receiveNanos(), greaterThan(0L)); assertThat(groupedStatus.emitNanos(), equalTo(0L)); assertThat(groupedStatus.occupiedRows(), equalTo(8)); - assertThat(groupedStatus.groupCount(), equalTo(2L)); + assertThat(groupedStatus.groupCount(), equalTo(2)); assertThat(groupedStatus.ramBytesUsed(), greaterThan(0L)); assertThat(groupedStatus.pagesReceived(), equalTo(1)); assertThat(groupedStatus.pagesEmitted(), equalTo(0)); @@ -379,6 
+381,11 @@ public void testRandomMultipleColumns() { randomBlocksResult.encoders, uniqueOrders, groupKeys.stream().mapToInt(Integer::intValue).toArray(), + new GroupKeyEncoder( + groupKeys.stream().mapToInt(Integer::intValue).toArray(), + randomBlocksResult.elementTypes, + new BreakingBytesRefBuilder(nonBreakingBigArrays().breakerService().getBreaker("request"), "group-key-encoder") + ), rows, Long.MAX_VALUE ) @@ -442,6 +449,11 @@ private List> runGroupedTopN( encoders, sortOrders, groupKeys, + new GroupKeyEncoder( + groupKeys, + elementTypes, + new BreakingBytesRefBuilder(nonBreakingBigArrays().breakerService().getBreaker("request"), "group-key-encoder") + ), randomPageSize(), Long.MAX_VALUE ) From a7649adc6e6522e513e3610b34739470d09c8118 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Iv=C3=A1n=20Cea=20Fontenla?= Date: Tue, 10 Mar 2026 12:48:06 +0100 Subject: [PATCH 16/22] Removed redundant groupKeys to simplify double accounting --- .../elasticsearch/compute/operator/GroupKeyEncoder.java | 5 +++++ .../compute/operator/topn/GroupedTopNOperator.java | 7 +------ .../compute/operator/topn/GroupedTopNOperatorTests.java | 2 -- 3 files changed, 6 insertions(+), 8 deletions(-) diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/GroupKeyEncoder.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/GroupKeyEncoder.java index 25c7e2480b700..d2b9216808122 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/GroupKeyEncoder.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/GroupKeyEncoder.java @@ -117,9 +117,14 @@ private void encodeBlock(Block block, ElementType type, int position) { } } + public int[] groupChannels() { + return groupChannels; + } + @Override public long ramBytesUsed() { long size = SHALLOW_SIZE; + size += RamUsageEstimator.sizeOf(groupChannels); size += RamUsageEstimator.shallowSizeOf(elementTypes); size += 
scratch.ramBytesUsed(); size += RamUsageEstimator.shallowSizeOfInstance(BytesRef.class); diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/GroupedTopNOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/GroupedTopNOperator.java index f252510fe70bf..3b13ee5172544 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/GroupedTopNOperator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/GroupedTopNOperator.java @@ -76,7 +76,6 @@ public GroupedTopNOperator get(DriverContext driverContext) { elementTypes, encoders, sortOrders, - groupKeysArray, keyEncoder, maxPageSize, jumboPageBytes @@ -107,7 +106,6 @@ public String describe() { private final List elementTypes; private final List encoders; private final List sortOrders; - private final int[] groupKeys; private final boolean[] channelInKey; private final GroupKeyEncoder keyEncoder; @@ -131,7 +129,6 @@ public GroupedTopNOperator( List elementTypes, List encoders, List sortOrders, - int[] groupKeys, GroupKeyEncoder keyEncoder, int maxPageSize, long jumboPageBytes @@ -159,7 +156,6 @@ public GroupedTopNOperator( this.elementTypes = elementTypes; this.encoders = encoders; this.sortOrders = sortOrders; - this.groupKeys = groupKeys; this.channelInKey = new boolean[elementTypes.size()]; for (TopNOperator.SortOrder so : sortOrders) { channelInKey[so.channel()] = true; @@ -261,7 +257,6 @@ public long ramBytesUsed() { size += RamUsageEstimator.alignObjectSize(arrHeader + ref * elementTypes.size()); size += RamUsageEstimator.alignObjectSize(arrHeader + ref * encoders.size()); size += RamUsageEstimator.alignObjectSize(arrHeader + ref * sortOrders.size()); - size += RamUsageEstimator.sizeOf(groupKeys); size += RamUsageEstimator.sizeOf(channelInKey); size += sortOrders.size() * SORT_ORDER_SIZE; size += keyEncoder.ramBytesUsed(); @@ -303,7 +298,7 @@ public String 
toString() { + ", sortOrders=" + sortOrders + ", groupKeys=" - + Arrays.toString(groupKeys) + + Arrays.toString(keyEncoder.groupChannels()) + "]"; } diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/GroupedTopNOperatorTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/GroupedTopNOperatorTests.java index a037e1d8e11c7..772fa0af4f85f 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/GroupedTopNOperatorTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/GroupedTopNOperatorTests.java @@ -380,7 +380,6 @@ public void testRandomMultipleColumns() { randomBlocksResult.elementTypes, randomBlocksResult.encoders, uniqueOrders, - groupKeys.stream().mapToInt(Integer::intValue).toArray(), new GroupKeyEncoder( groupKeys.stream().mapToInt(Integer::intValue).toArray(), randomBlocksResult.elementTypes, @@ -448,7 +447,6 @@ private List> runGroupedTopN( elementTypes, encoders, sortOrders, - groupKeys, new GroupKeyEncoder( groupKeys, elementTypes, From 48fcc857655245f2e2108484b4d3c40a57a64350 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Iv=C3=A1n=20Cea=20Fontenla?= Date: Tue, 10 Mar 2026 13:02:10 +0100 Subject: [PATCH 17/22] Added example with TopNBenchmark instead --- .../_nightly/esql/TopNBenchmark.java | 167 ++++++++++++++++-- 1 file changed, 148 insertions(+), 19 deletions(-) diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/_nightly/esql/TopNBenchmark.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/_nightly/esql/TopNBenchmark.java index afba5638230ec..d951ae27e79e5 100644 --- a/benchmarks/src/main/java/org/elasticsearch/benchmark/_nightly/esql/TopNBenchmark.java +++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/_nightly/esql/TopNBenchmark.java @@ -22,7 +22,10 @@ import org.elasticsearch.compute.data.BytesRefBlock; import org.elasticsearch.compute.data.ElementType; import 
org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.BreakingBytesRefBuilder; +import org.elasticsearch.compute.operator.GroupKeyEncoder; import org.elasticsearch.compute.operator.Operator; +import org.elasticsearch.compute.operator.topn.GroupedTopNOperator; import org.elasticsearch.compute.operator.topn.SharedMinCompetitive; import org.elasticsearch.compute.operator.topn.TopNEncoder; import org.elasticsearch.compute.operator.topn.TopNOperator; @@ -63,6 +66,8 @@ public class TopNBenchmark { .build(); private static final int BLOCK_LENGTH = 4 * 1024; + private static final int NUM_PAGES = 1024; + private static final int SELF_TEST_PAGES = 16; private static final String LONGS = "longs"; private static final String INTS = "ints"; @@ -83,10 +88,28 @@ public class TopNBenchmark { static void selfTest() { try { - for (String data : TopNBenchmark.class.getField("data").getAnnotationsByType(Param.class)[0].value()) { - for (String topCount : TopNBenchmark.class.getField("topCount").getAnnotationsByType(Param.class)[0].value()) { - for (String sortedInput : TopNBenchmark.class.getField("sortedInput").getAnnotationsByType(Param.class)[0].value()) { - run(data, Integer.parseInt(topCount), Boolean.parseBoolean(sortedInput)); + String[] dataValues = TopNBenchmark.class.getField("data").getAnnotationsByType(Param.class)[0].value(); + String[] topCountValues = TopNBenchmark.class.getField("topCount").getAnnotationsByType(Param.class)[0].value(); + String[] sortedInputValues = TopNBenchmark.class.getField("sortedInput").getAnnotationsByType(Param.class)[0].value(); + String[] groupCountValues = TopNBenchmark.class.getField("groupCount").getAnnotationsByType(Param.class)[0].value(); + String[] groupKeyTypeValues = TopNBenchmark.class.getField("groupKeyType").getAnnotationsByType(Param.class)[0].value(); + String[] groupKeyCountValues = TopNBenchmark.class.getField("groupKeyCount").getAnnotationsByType(Param.class)[0].value(); + for (String data : 
dataValues) { + for (String topCount : topCountValues) { + int tc = Integer.parseInt(topCount); + for (String sortedInput : sortedInputValues) { + run(data, tc, Boolean.parseBoolean(sortedInput), 0, groupKeyTypeValues[0], 1, NUM_PAGES); + } + for (String groupCount : groupCountValues) { + int gc = Integer.parseInt(groupCount); + if (gc == 0) { + continue; + } + for (String gkType : groupKeyTypeValues) { + for (String gkCount : groupKeyCountValues) { + run(data, tc, false, gc, gkType, Integer.parseInt(gkCount), SELF_TEST_PAGES); + } + } } } } @@ -123,11 +146,51 @@ static void selfTest() { @Param({ "10", "1000", "4096", "10000" }) public int topCount; - private static Operator operator(String data, int topCount, boolean sortedInput) { + @Param({ "0", "10", "100", "1000" }) + public int groupCount; + + @Param({ LONGS, BYTES_REFS }) + public String groupKeyType; + + @Param({ "1", "2" }) + public int groupKeyCount; + + private static Operator operator( + String data, + int topCount, + boolean sortedInput, + int groupCount, + String groupKeyType, + int groupKeyCount + ) { String[] dataSpec = data.split("_and_"); - List elementTypes = Arrays.stream(dataSpec).map(TopNBenchmark::elementType).toList(); - List encoders = Arrays.stream(dataSpec).map(TopNBenchmark::encoder).toList(); + List elementTypes = new ArrayList<>(Arrays.stream(dataSpec).map(TopNBenchmark::elementType).toList()); + List encoders = new ArrayList<>(Arrays.stream(dataSpec).map(TopNBenchmark::encoder).toList()); List sortOrders = IntStream.range(0, dataSpec.length).mapToObj(c -> sortOrder(c, dataSpec[c])).toList(); + + if (groupCount > 0) { + int[] groupKeys = new int[groupKeyCount]; + ElementType gkElementType = groupKeyElementType(groupKeyType); + for (int i = 0; i < groupKeyCount; i++) { + groupKeys[i] = elementTypes.size(); + elementTypes.add(gkElementType); + encoders.add(TopNEncoder.DEFAULT_UNSORTABLE); + } + var scratch = new BreakingBytesRefBuilder(blockFactory.breaker(), "group-key-encoder"); + var 
keyEncoder = new GroupKeyEncoder(groupKeys, elementTypes, scratch); + return new GroupedTopNOperator( + blockFactory, + blockFactory.breaker(), + topCount, + elementTypes, + encoders, + sortOrders, + keyEncoder, + 8 * 1024, + Long.MAX_VALUE + ); + } + CircuitBreakerService breakerService = new HierarchyCircuitBreakerService( CircuitBreakerMetrics.NOOP, Settings.EMPTY, @@ -157,7 +220,7 @@ private static Operator operator(String data, int topCount, boolean sortedInput) 8 * 1024, Long.MAX_VALUE, sortedInput ? TopNOperator.InputOrdering.SORTED : TopNOperator.InputOrdering.NOT_SORTED, - minCompetitive // This is optional, but doesn't add much overhead either way + minCompetitive ); } @@ -194,14 +257,48 @@ private static TopNOperator.SortOrder sortOrder(int channel, String data) { return new TopNOperator.SortOrder(channel, ascDesc(data), false); } - private static void checkExpected(int topCount, List pages) { - if (topCount != pages.stream().mapToLong(Page::getPositionCount).sum()) { - throw new AssertionError("expected [" + topCount + "] but got [" + pages.size() + "]"); + private static ElementType groupKeyElementType(String groupKeyType) { + return switch (groupKeyType) { + case LONGS -> ElementType.LONG; + case BYTES_REFS -> ElementType.BYTES_REF; + default -> throw new IllegalArgumentException("unsupported group key type [" + groupKeyType + "]"); + }; + } + + private static void checkExpected(int topCount, int groupCount, int numPages, List pages) { + long actualOutput = pages.stream().mapToLong(Page::getPositionCount).sum(); + if (groupCount > 0) { + int effectiveGroupCount = Math.min(groupCount, BLOCK_LENGTH); + long expectedOutput = 0; + for (int g = 0; g < effectiveGroupCount; g++) { + int rowsPerPage = BLOCK_LENGTH / effectiveGroupCount + (g < BLOCK_LENGTH % effectiveGroupCount ? 
1 : 0); + long totalRowsForGroup = (long) rowsPerPage * numPages; + expectedOutput += Math.min(topCount, totalRowsForGroup); + } + if (expectedOutput != actualOutput) { + throw new AssertionError("expected [" + expectedOutput + "] but got [" + actualOutput + "]"); + } + } else { + if (topCount != actualOutput) { + throw new AssertionError("expected [" + topCount + "] but got [" + pages.size() + "]"); + } } } - private static Page page(boolean sortedInput, String data) { + private static Page page(boolean sortedInput, String data, int groupCount, String groupKeyType, int groupKeyCount) { String[] dataSpec = data.split("_and_"); + if (groupCount > 0) { + int effectiveGroupCount = Math.min(groupCount, BLOCK_LENGTH); + int divisor = (int) Math.ceil(Math.sqrt(effectiveGroupCount)); + Block[] blocks = new Block[dataSpec.length + groupKeyCount]; + for (int i = 0; i < dataSpec.length; i++) { + blocks[i] = block(sortedInput, dataSpec[i]); + } + for (int k = 0; k < groupKeyCount; k++) { + blocks[dataSpec.length + k] = groupKeyBlock(groupKeyType, effectiveGroupCount, divisor, k, groupKeyCount); + } + return new Page(blocks); + } return new Page(Arrays.stream(dataSpec).map(d -> block(sortedInput, d)).toArray(Block[]::new)); } @@ -241,6 +338,30 @@ private static Block block(boolean sortedInput, String data) { }; } + private static Block groupKeyBlock(String groupKeyType, int effectiveGroupCount, int divisor, int keyIndex, int groupKeyCount) { + return switch (groupKeyType) { + case LONGS -> { + var builder = blockFactory.newLongBlockBuilder(BLOCK_LENGTH); + for (int i = 0; i < BLOCK_LENGTH; i++) { + int groupId = i % effectiveGroupCount; + long keyValue = groupKeyCount == 1 ? groupId : (keyIndex == 0 ? 
groupId / divisor : groupId % divisor); + builder.appendLong(keyValue); + } + yield builder.build(); + } + case BYTES_REFS -> { + BytesRefBlock.Builder builder = blockFactory.newBytesRefBlockBuilder(BLOCK_LENGTH); + for (int i = 0; i < BLOCK_LENGTH; i++) { + int groupId = i % effectiveGroupCount; + long keyValue = groupKeyCount == 1 ? groupId : (keyIndex == 0 ? groupId / divisor : groupId % divisor); + builder.appendBytesRef(new BytesRef(Long.toString(keyValue))); + } + yield builder.build(); + } + default -> throw new IllegalArgumentException("unsupported group key type [" + groupKeyType + "]"); + }; + } + private static > List maybeSort(boolean sortedInput, String data, Stream randomValues) { List values = new ArrayList<>(); randomValues.forEachOrdered(values::add); @@ -252,15 +373,23 @@ private static > List maybeSort(boolean sortedInput, } @Benchmark - @OperationsPerInvocation(1024 * BLOCK_LENGTH) + @OperationsPerInvocation(NUM_PAGES * BLOCK_LENGTH) public void run() { - run(data, topCount, sortedInput); + run(data, topCount, sortedInput, groupCount, groupKeyType, groupKeyCount, NUM_PAGES); } - private static void run(String data, int topCount, boolean sortedInput) { - try (Operator operator = operator(data, topCount, sortedInput)) { - Page page = page(sortedInput, data); - for (int i = 0; i < 1024; i++) { + private static void run( + String data, + int topCount, + boolean sortedInput, + int groupCount, + String groupKeyType, + int groupKeyCount, + int numPages + ) { + try (Operator operator = operator(data, topCount, sortedInput, groupCount, groupKeyType, groupKeyCount)) { + Page page = page(sortedInput, data, groupCount, groupKeyType, groupKeyCount); + for (int i = 0; i < numPages; i++) { operator.addInput(page.shallowCopy()); } operator.finish(); @@ -269,7 +398,7 @@ private static void run(String data, int topCount, boolean sortedInput) { while ((p = operator.getOutput()) != null) { results.add(p); } - checkExpected(topCount, results); + 
checkExpected(topCount, groupCount, numPages, results); } } } From 26e7c9a1a94bf3b9a1162f12b7e4ee8c0070e759 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Iv=C3=A1n=20Cea=20Fontenla?= Date: Tue, 10 Mar 2026 16:39:39 +0100 Subject: [PATCH 18/22] Revert "Added example with TopNBenchmark instead" This reverts commit 48fcc857655245f2e2108484b4d3c40a57a64350. --- .../_nightly/esql/TopNBenchmark.java | 167 ++---------------- 1 file changed, 19 insertions(+), 148 deletions(-) diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/_nightly/esql/TopNBenchmark.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/_nightly/esql/TopNBenchmark.java index d951ae27e79e5..afba5638230ec 100644 --- a/benchmarks/src/main/java/org/elasticsearch/benchmark/_nightly/esql/TopNBenchmark.java +++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/_nightly/esql/TopNBenchmark.java @@ -22,10 +22,7 @@ import org.elasticsearch.compute.data.BytesRefBlock; import org.elasticsearch.compute.data.ElementType; import org.elasticsearch.compute.data.Page; -import org.elasticsearch.compute.operator.BreakingBytesRefBuilder; -import org.elasticsearch.compute.operator.GroupKeyEncoder; import org.elasticsearch.compute.operator.Operator; -import org.elasticsearch.compute.operator.topn.GroupedTopNOperator; import org.elasticsearch.compute.operator.topn.SharedMinCompetitive; import org.elasticsearch.compute.operator.topn.TopNEncoder; import org.elasticsearch.compute.operator.topn.TopNOperator; @@ -66,8 +63,6 @@ public class TopNBenchmark { .build(); private static final int BLOCK_LENGTH = 4 * 1024; - private static final int NUM_PAGES = 1024; - private static final int SELF_TEST_PAGES = 16; private static final String LONGS = "longs"; private static final String INTS = "ints"; @@ -88,28 +83,10 @@ public class TopNBenchmark { static void selfTest() { try { - String[] dataValues = TopNBenchmark.class.getField("data").getAnnotationsByType(Param.class)[0].value(); - String[] topCountValues = 
TopNBenchmark.class.getField("topCount").getAnnotationsByType(Param.class)[0].value(); - String[] sortedInputValues = TopNBenchmark.class.getField("sortedInput").getAnnotationsByType(Param.class)[0].value(); - String[] groupCountValues = TopNBenchmark.class.getField("groupCount").getAnnotationsByType(Param.class)[0].value(); - String[] groupKeyTypeValues = TopNBenchmark.class.getField("groupKeyType").getAnnotationsByType(Param.class)[0].value(); - String[] groupKeyCountValues = TopNBenchmark.class.getField("groupKeyCount").getAnnotationsByType(Param.class)[0].value(); - for (String data : dataValues) { - for (String topCount : topCountValues) { - int tc = Integer.parseInt(topCount); - for (String sortedInput : sortedInputValues) { - run(data, tc, Boolean.parseBoolean(sortedInput), 0, groupKeyTypeValues[0], 1, NUM_PAGES); - } - for (String groupCount : groupCountValues) { - int gc = Integer.parseInt(groupCount); - if (gc == 0) { - continue; - } - for (String gkType : groupKeyTypeValues) { - for (String gkCount : groupKeyCountValues) { - run(data, tc, false, gc, gkType, Integer.parseInt(gkCount), SELF_TEST_PAGES); - } - } + for (String data : TopNBenchmark.class.getField("data").getAnnotationsByType(Param.class)[0].value()) { + for (String topCount : TopNBenchmark.class.getField("topCount").getAnnotationsByType(Param.class)[0].value()) { + for (String sortedInput : TopNBenchmark.class.getField("sortedInput").getAnnotationsByType(Param.class)[0].value()) { + run(data, Integer.parseInt(topCount), Boolean.parseBoolean(sortedInput)); } } } @@ -146,51 +123,11 @@ static void selfTest() { @Param({ "10", "1000", "4096", "10000" }) public int topCount; - @Param({ "0", "10", "100", "1000" }) - public int groupCount; - - @Param({ LONGS, BYTES_REFS }) - public String groupKeyType; - - @Param({ "1", "2" }) - public int groupKeyCount; - - private static Operator operator( - String data, - int topCount, - boolean sortedInput, - int groupCount, - String groupKeyType, - int 
groupKeyCount - ) { + private static Operator operator(String data, int topCount, boolean sortedInput) { String[] dataSpec = data.split("_and_"); - List elementTypes = new ArrayList<>(Arrays.stream(dataSpec).map(TopNBenchmark::elementType).toList()); - List encoders = new ArrayList<>(Arrays.stream(dataSpec).map(TopNBenchmark::encoder).toList()); + List elementTypes = Arrays.stream(dataSpec).map(TopNBenchmark::elementType).toList(); + List encoders = Arrays.stream(dataSpec).map(TopNBenchmark::encoder).toList(); List sortOrders = IntStream.range(0, dataSpec.length).mapToObj(c -> sortOrder(c, dataSpec[c])).toList(); - - if (groupCount > 0) { - int[] groupKeys = new int[groupKeyCount]; - ElementType gkElementType = groupKeyElementType(groupKeyType); - for (int i = 0; i < groupKeyCount; i++) { - groupKeys[i] = elementTypes.size(); - elementTypes.add(gkElementType); - encoders.add(TopNEncoder.DEFAULT_UNSORTABLE); - } - var scratch = new BreakingBytesRefBuilder(blockFactory.breaker(), "group-key-encoder"); - var keyEncoder = new GroupKeyEncoder(groupKeys, elementTypes, scratch); - return new GroupedTopNOperator( - blockFactory, - blockFactory.breaker(), - topCount, - elementTypes, - encoders, - sortOrders, - keyEncoder, - 8 * 1024, - Long.MAX_VALUE - ); - } - CircuitBreakerService breakerService = new HierarchyCircuitBreakerService( CircuitBreakerMetrics.NOOP, Settings.EMPTY, @@ -220,7 +157,7 @@ private static Operator operator( 8 * 1024, Long.MAX_VALUE, sortedInput ? 
TopNOperator.InputOrdering.SORTED : TopNOperator.InputOrdering.NOT_SORTED, - minCompetitive + minCompetitive // This is optional, but doesn't add much overhead either way ); } @@ -257,48 +194,14 @@ private static TopNOperator.SortOrder sortOrder(int channel, String data) { return new TopNOperator.SortOrder(channel, ascDesc(data), false); } - private static ElementType groupKeyElementType(String groupKeyType) { - return switch (groupKeyType) { - case LONGS -> ElementType.LONG; - case BYTES_REFS -> ElementType.BYTES_REF; - default -> throw new IllegalArgumentException("unsupported group key type [" + groupKeyType + "]"); - }; - } - - private static void checkExpected(int topCount, int groupCount, int numPages, List pages) { - long actualOutput = pages.stream().mapToLong(Page::getPositionCount).sum(); - if (groupCount > 0) { - int effectiveGroupCount = Math.min(groupCount, BLOCK_LENGTH); - long expectedOutput = 0; - for (int g = 0; g < effectiveGroupCount; g++) { - int rowsPerPage = BLOCK_LENGTH / effectiveGroupCount + (g < BLOCK_LENGTH % effectiveGroupCount ? 
1 : 0); - long totalRowsForGroup = (long) rowsPerPage * numPages; - expectedOutput += Math.min(topCount, totalRowsForGroup); - } - if (expectedOutput != actualOutput) { - throw new AssertionError("expected [" + expectedOutput + "] but got [" + actualOutput + "]"); - } - } else { - if (topCount != actualOutput) { - throw new AssertionError("expected [" + topCount + "] but got [" + pages.size() + "]"); - } + private static void checkExpected(int topCount, List pages) { + if (topCount != pages.stream().mapToLong(Page::getPositionCount).sum()) { + throw new AssertionError("expected [" + topCount + "] but got [" + pages.size() + "]"); } } - private static Page page(boolean sortedInput, String data, int groupCount, String groupKeyType, int groupKeyCount) { + private static Page page(boolean sortedInput, String data) { String[] dataSpec = data.split("_and_"); - if (groupCount > 0) { - int effectiveGroupCount = Math.min(groupCount, BLOCK_LENGTH); - int divisor = (int) Math.ceil(Math.sqrt(effectiveGroupCount)); - Block[] blocks = new Block[dataSpec.length + groupKeyCount]; - for (int i = 0; i < dataSpec.length; i++) { - blocks[i] = block(sortedInput, dataSpec[i]); - } - for (int k = 0; k < groupKeyCount; k++) { - blocks[dataSpec.length + k] = groupKeyBlock(groupKeyType, effectiveGroupCount, divisor, k, groupKeyCount); - } - return new Page(blocks); - } return new Page(Arrays.stream(dataSpec).map(d -> block(sortedInput, d)).toArray(Block[]::new)); } @@ -338,30 +241,6 @@ private static Block block(boolean sortedInput, String data) { }; } - private static Block groupKeyBlock(String groupKeyType, int effectiveGroupCount, int divisor, int keyIndex, int groupKeyCount) { - return switch (groupKeyType) { - case LONGS -> { - var builder = blockFactory.newLongBlockBuilder(BLOCK_LENGTH); - for (int i = 0; i < BLOCK_LENGTH; i++) { - int groupId = i % effectiveGroupCount; - long keyValue = groupKeyCount == 1 ? groupId : (keyIndex == 0 ? 
groupId / divisor : groupId % divisor); - builder.appendLong(keyValue); - } - yield builder.build(); - } - case BYTES_REFS -> { - BytesRefBlock.Builder builder = blockFactory.newBytesRefBlockBuilder(BLOCK_LENGTH); - for (int i = 0; i < BLOCK_LENGTH; i++) { - int groupId = i % effectiveGroupCount; - long keyValue = groupKeyCount == 1 ? groupId : (keyIndex == 0 ? groupId / divisor : groupId % divisor); - builder.appendBytesRef(new BytesRef(Long.toString(keyValue))); - } - yield builder.build(); - } - default -> throw new IllegalArgumentException("unsupported group key type [" + groupKeyType + "]"); - }; - } - private static > List maybeSort(boolean sortedInput, String data, Stream randomValues) { List values = new ArrayList<>(); randomValues.forEachOrdered(values::add); @@ -373,23 +252,15 @@ private static > List maybeSort(boolean sortedInput, } @Benchmark - @OperationsPerInvocation(NUM_PAGES * BLOCK_LENGTH) + @OperationsPerInvocation(1024 * BLOCK_LENGTH) public void run() { - run(data, topCount, sortedInput, groupCount, groupKeyType, groupKeyCount, NUM_PAGES); + run(data, topCount, sortedInput); } - private static void run( - String data, - int topCount, - boolean sortedInput, - int groupCount, - String groupKeyType, - int groupKeyCount, - int numPages - ) { - try (Operator operator = operator(data, topCount, sortedInput, groupCount, groupKeyType, groupKeyCount)) { - Page page = page(sortedInput, data, groupCount, groupKeyType, groupKeyCount); - for (int i = 0; i < numPages; i++) { + private static void run(String data, int topCount, boolean sortedInput) { + try (Operator operator = operator(data, topCount, sortedInput)) { + Page page = page(sortedInput, data); + for (int i = 0; i < 1024; i++) { operator.addInput(page.shallowCopy()); } operator.finish(); @@ -398,7 +269,7 @@ private static void run( while ((p = operator.getOutput()) != null) { results.add(p); } - checkExpected(topCount, groupCount, numPages, results); + checkExpected(topCount, results); } } } From 
61b22d6b38e6b37761f97c5138c74980e660cafc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Iv=C3=A1n=20Cea=20Fontenla?= Date: Tue, 10 Mar 2026 18:07:03 +0100 Subject: [PATCH 19/22] Fixed benchmark --- .../benchmark/_nightly/esql/GroupedTopNBenchmark.java | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/_nightly/esql/GroupedTopNBenchmark.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/_nightly/esql/GroupedTopNBenchmark.java index f8baf3f85f35a..d5abeb1ba83f2 100644 --- a/benchmarks/src/main/java/org/elasticsearch/benchmark/_nightly/esql/GroupedTopNBenchmark.java +++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/_nightly/esql/GroupedTopNBenchmark.java @@ -19,6 +19,8 @@ import org.elasticsearch.compute.data.BytesRefBlock; import org.elasticsearch.compute.data.ElementType; import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.BreakingBytesRefBuilder; +import org.elasticsearch.compute.operator.GroupKeyEncoder; import org.elasticsearch.compute.operator.Operator; import org.elasticsearch.compute.operator.topn.GroupedTopNOperator; import org.elasticsearch.compute.operator.topn.TopNEncoder; @@ -139,7 +141,7 @@ private static Operator operator(String data, int topCount, String groupKeyType, elementTypes, encoders, sortOrders, - groupKeys, + new GroupKeyEncoder(groupKeys, elementTypes, new BreakingBytesRefBuilder(blockFactory.breaker(), "group-key-encoder")), 8 * 1024, Long.MAX_VALUE ); From bc1d185a2a05e64c026778529343d491cad7bac9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Iv=C3=A1n=20Cea=20Fontenla?= Date: Wed, 11 Mar 2026 12:32:41 +0100 Subject: [PATCH 20/22] Fixed boolean generation on topn benchmarks, and release pages --- .../_nightly/esql/GroupedTopNBenchmark.java | 15 ++++++++++----- .../benchmark/_nightly/esql/TopNBenchmark.java | 15 ++++++++++----- 2 files changed, 20 insertions(+), 10 deletions(-) diff --git 
a/benchmarks/src/main/java/org/elasticsearch/benchmark/_nightly/esql/GroupedTopNBenchmark.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/_nightly/esql/GroupedTopNBenchmark.java index d5abeb1ba83f2..0a1cf546cd89e 100644 --- a/benchmarks/src/main/java/org/elasticsearch/benchmark/_nightly/esql/GroupedTopNBenchmark.java +++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/_nightly/esql/GroupedTopNBenchmark.java @@ -25,6 +25,7 @@ import org.elasticsearch.compute.operator.topn.GroupedTopNOperator; import org.elasticsearch.compute.operator.topn.TopNEncoder; import org.elasticsearch.compute.operator.topn.TopNOperator; +import org.elasticsearch.core.Releasables; import org.openjdk.jmh.annotations.Benchmark; import org.openjdk.jmh.annotations.BenchmarkMode; import org.openjdk.jmh.annotations.Fork; @@ -236,7 +237,7 @@ private static Block block(String data) { } case BOOLEANS -> { BooleanBlock.Builder builder = blockFactory.newBooleanBlockBuilder(BLOCK_LENGTH); - new Random().ints(BLOCK_LENGTH, 0, 1).forEach(i -> builder.appendBoolean(i == 1)); + new Random().ints(BLOCK_LENGTH, 0, 2).forEach(i -> builder.appendBoolean(i == 1)); yield builder.build(); } case BYTES_REFS -> { @@ -287,11 +288,15 @@ private static void run(String data, int topCount, int groupCount, String groupK } operator.finish(); List results = new ArrayList<>(); - Page p; - while ((p = operator.getOutput()) != null) { - results.add(p); + try { + Page p; + while ((p = operator.getOutput()) != null) { + results.add(p); + } + checkExpected(topCount, groupCount, numPages, results); + } finally { + Releasables.close(results); } - checkExpected(topCount, groupCount, numPages, results); } } } diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/_nightly/esql/TopNBenchmark.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/_nightly/esql/TopNBenchmark.java index afba5638230ec..d644abc1d1db7 100644 --- 
a/benchmarks/src/main/java/org/elasticsearch/benchmark/_nightly/esql/TopNBenchmark.java +++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/_nightly/esql/TopNBenchmark.java @@ -26,6 +26,7 @@ import org.elasticsearch.compute.operator.topn.SharedMinCompetitive; import org.elasticsearch.compute.operator.topn.TopNEncoder; import org.elasticsearch.compute.operator.topn.TopNOperator; +import org.elasticsearch.core.Releasables; import org.elasticsearch.indices.breaker.CircuitBreakerMetrics; import org.elasticsearch.indices.breaker.CircuitBreakerService; import org.elasticsearch.indices.breaker.HierarchyCircuitBreakerService; @@ -227,7 +228,7 @@ private static Block block(boolean sortedInput, String data) { } case BOOLEANS -> { BooleanBlock.Builder builder = blockFactory.newBooleanBlockBuilder(BLOCK_LENGTH); - maybeSort(sortedInput, data, new Random().ints(BLOCK_LENGTH, 0, 1).boxed()).forEach(i -> builder.appendBoolean(i == 1)); + maybeSort(sortedInput, data, new Random().ints(BLOCK_LENGTH, 0, 2).boxed()).forEach(i -> builder.appendBoolean(i == 1)); yield builder.build(); } case BYTES_REFS -> { @@ -265,11 +266,15 @@ private static void run(String data, int topCount, boolean sortedInput) { } operator.finish(); List results = new ArrayList<>(); - Page p; - while ((p = operator.getOutput()) != null) { - results.add(p); + try { + Page p; + while ((p = operator.getOutput()) != null) { + results.add(p); + } + checkExpected(topCount, results); + } finally { + Releasables.close(results); } - checkExpected(topCount, results); } } } From 973f12a60f2d102e8b86dfa170067b61001d937d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Iv=C3=A1n=20Cea=20Fontenla?= Date: Thu, 12 Mar 2026 13:18:36 +0100 Subject: [PATCH 21/22] Improved group keys --- .../_nightly/esql/GroupedTopNBenchmark.java | 70 ++++++++----------- .../_nightly/esql/TopNBenchmark.java | 4 +- 2 files changed, 30 insertions(+), 44 deletions(-) diff --git 
a/benchmarks/src/main/java/org/elasticsearch/benchmark/_nightly/esql/GroupedTopNBenchmark.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/_nightly/esql/GroupedTopNBenchmark.java index 0a1cf546cd89e..428218a11082b 100644 --- a/benchmarks/src/main/java/org/elasticsearch/benchmark/_nightly/esql/GroupedTopNBenchmark.java +++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/_nightly/esql/GroupedTopNBenchmark.java @@ -84,19 +84,15 @@ static void selfTest() { for (String topCount : GroupedTopNBenchmark.class.getField("topCount").getAnnotationsByType(Param.class)[0].value()) { for (String groupCount : GroupedTopNBenchmark.class.getField("groupCount").getAnnotationsByType(Param.class)[0] .value()) { - for (String gkType : GroupedTopNBenchmark.class.getField("groupKeyType").getAnnotationsByType(Param.class)[0] + for (String gk : GroupedTopNBenchmark.class.getField("groupKeys").getAnnotationsByType(Param.class)[0] .value()) { - for (String gkCount : GroupedTopNBenchmark.class.getField("groupKeyCount").getAnnotationsByType(Param.class)[0] - .value()) { - run( - data, - Integer.parseInt(topCount), - Integer.parseInt(groupCount), - gkType, - Integer.parseInt(gkCount), - SELF_TEST_PAGES - ); - } + run( + data, + Integer.parseInt(topCount), + Integer.parseInt(groupCount), + gk, + SELF_TEST_PAGES + ); } } } @@ -115,23 +111,20 @@ static void selfTest() { @Param({ "10", "100", "1000" }) public int groupCount; - @Param({ LONGS, BYTES_REFS }) - public String groupKeyType; + @Param({ LONGS, BYTES_REFS, LONGS + AND + LONGS, BYTES_REFS + AND + BYTES_REFS, LONGS + AND + BYTES_REFS }) + public String groupKeys; - @Param({ "1", "2" }) - public int groupKeyCount; - - private static Operator operator(String data, int topCount, String groupKeyType, int groupKeyCount) { - String[] dataSpec = data.split("_and_"); + private static Operator operator(String data, int topCount, String groupKeys) { + String[] dataSpec = data.split(AND); List elementTypes = new 
ArrayList<>(Arrays.stream(dataSpec).map(GroupedTopNBenchmark::elementType).toList()); List encoders = new ArrayList<>(Arrays.stream(dataSpec).map(GroupedTopNBenchmark::encoder).toList()); List sortOrders = IntStream.range(0, dataSpec.length).mapToObj(c -> sortOrder(c, dataSpec[c])).toList(); - int[] groupKeys = new int[groupKeyCount]; - ElementType gkElementType = groupKeyElementType(groupKeyType); - for (int i = 0; i < groupKeyCount; i++) { - groupKeys[i] = elementTypes.size(); - elementTypes.add(gkElementType); + String[] groupKeySpec = groupKeys.split(AND); + int[] groupKeyChannels = new int[groupKeySpec.length]; + for (int i = 0; i < groupKeySpec.length; i++) { + groupKeyChannels[i] = elementTypes.size(); + elementTypes.add(elementType(groupKeySpec[i])); encoders.add(TopNEncoder.DEFAULT_UNSORTABLE); } @@ -142,7 +135,7 @@ private static Operator operator(String data, int topCount, String groupKeyType, elementTypes, encoders, sortOrders, - new GroupKeyEncoder(groupKeys, elementTypes, new BreakingBytesRefBuilder(blockFactory.breaker(), "group-key-encoder")), + new GroupKeyEncoder(groupKeyChannels, elementTypes, new BreakingBytesRefBuilder(blockFactory.breaker(), "group-key-encoder")), 8 * 1024, Long.MAX_VALUE ); @@ -167,14 +160,6 @@ private static TopNEncoder encoder(String data) { }; } - private static ElementType groupKeyElementType(String groupKeyType) { - return switch (groupKeyType) { - case LONGS -> ElementType.LONG; - case BYTES_REFS -> ElementType.BYTES_REF; - default -> throw new IllegalArgumentException("unsupported group key type [" + groupKeyType + "]"); - }; - } - private static boolean ascDesc(String data) { if (data.endsWith(ASC)) { return true; @@ -203,17 +188,18 @@ private static void checkExpected(int topCount, int groupCount, int numPages, Li } } - private static Page page(String data, int groupCount, String groupKeyType, int groupKeyCount) { - String[] dataSpec = data.split("_and_"); + private static Page page(String data, int groupCount, 
String groupKeys) { + String[] dataSpec = data.split(AND); + String[] groupKeySpec = groupKeys.split(AND); int effectiveGroupCount = Math.min(groupCount, BLOCK_LENGTH); int divisor = (int) Math.ceil(Math.sqrt(effectiveGroupCount)); - Block[] blocks = new Block[dataSpec.length + groupKeyCount]; + Block[] blocks = new Block[dataSpec.length + groupKeySpec.length]; for (int i = 0; i < dataSpec.length; i++) { blocks[i] = block(dataSpec[i]); } - for (int k = 0; k < groupKeyCount; k++) { - blocks[dataSpec.length + k] = groupKeyBlock(groupKeyType, effectiveGroupCount, divisor, k, groupKeyCount); + for (int k = 0; k < groupKeySpec.length; k++) { + blocks[dataSpec.length + k] = groupKeyBlock(groupKeySpec[k], effectiveGroupCount, divisor, k, groupKeySpec.length); } return new Page(blocks); } @@ -277,12 +263,12 @@ private static Block groupKeyBlock(String groupKeyType, int effectiveGroupCount, @Benchmark @OperationsPerInvocation(NUM_PAGES * BLOCK_LENGTH) public void run() { - run(data, topCount, groupCount, groupKeyType, groupKeyCount, NUM_PAGES); + run(data, topCount, groupCount, groupKeys, NUM_PAGES); } - private static void run(String data, int topCount, int groupCount, String groupKeyType, int groupKeyCount, int numPages) { - try (Operator operator = operator(data, topCount, groupKeyType, groupKeyCount)) { - Page page = page(data, groupCount, groupKeyType, groupKeyCount); + private static void run(String data, int topCount, int groupCount, String groupKeys, int numPages) { + try (Operator operator = operator(data, topCount, groupKeys)) { + Page page = page(data, groupCount, groupKeys); for (int i = 0; i < numPages; i++) { operator.addInput(page.shallowCopy()); } diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/_nightly/esql/TopNBenchmark.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/_nightly/esql/TopNBenchmark.java index d644abc1d1db7..391bc19fc5bf9 100644 --- 
a/benchmarks/src/main/java/org/elasticsearch/benchmark/_nightly/esql/TopNBenchmark.java +++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/_nightly/esql/TopNBenchmark.java @@ -125,7 +125,7 @@ static void selfTest() { public int topCount; private static Operator operator(String data, int topCount, boolean sortedInput) { - String[] dataSpec = data.split("_and_"); + String[] dataSpec = data.split(AND); List elementTypes = Arrays.stream(dataSpec).map(TopNBenchmark::elementType).toList(); List encoders = Arrays.stream(dataSpec).map(TopNBenchmark::encoder).toList(); List sortOrders = IntStream.range(0, dataSpec.length).mapToObj(c -> sortOrder(c, dataSpec[c])).toList(); @@ -202,7 +202,7 @@ private static void checkExpected(int topCount, List pages) { } private static Page page(boolean sortedInput, String data) { - String[] dataSpec = data.split("_and_"); + String[] dataSpec = data.split(AND); return new Page(Arrays.stream(dataSpec).map(d -> block(sortedInput, d)).toArray(Block[]::new)); } From 9beae4d9320818013a231c608d0ae5eed153a528 Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Thu, 12 Mar 2026 12:26:06 +0000 Subject: [PATCH 22/22] [CI] Auto commit changes from spotless --- .../benchmark/_nightly/esql/GroupedTopNBenchmark.java | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/_nightly/esql/GroupedTopNBenchmark.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/_nightly/esql/GroupedTopNBenchmark.java index 428218a11082b..a5745fd8ce687 100644 --- a/benchmarks/src/main/java/org/elasticsearch/benchmark/_nightly/esql/GroupedTopNBenchmark.java +++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/_nightly/esql/GroupedTopNBenchmark.java @@ -84,15 +84,8 @@ static void selfTest() { for (String topCount : GroupedTopNBenchmark.class.getField("topCount").getAnnotationsByType(Param.class)[0].value()) { for (String groupCount : 
GroupedTopNBenchmark.class.getField("groupCount").getAnnotationsByType(Param.class)[0] .value()) { - for (String gk : GroupedTopNBenchmark.class.getField("groupKeys").getAnnotationsByType(Param.class)[0] - .value()) { - run( - data, - Integer.parseInt(topCount), - Integer.parseInt(groupCount), - gk, - SELF_TEST_PAGES - ); + for (String gk : GroupedTopNBenchmark.class.getField("groupKeys").getAnnotationsByType(Param.class)[0].value()) { + run(data, Integer.parseInt(topCount), Integer.parseInt(groupCount), gk, SELF_TEST_PAGES); } } }