diff --git a/libs/core/src/main/java/org/elasticsearch/core/Releasables.java b/libs/core/src/main/java/org/elasticsearch/core/Releasables.java index 385258b5c10c3..8618d3134b8e4 100644 --- a/libs/core/src/main/java/org/elasticsearch/core/Releasables.java +++ b/libs/core/src/main/java/org/elasticsearch/core/Releasables.java @@ -67,7 +67,7 @@ public static void closeExpectNoException(Releasable... releasables) { } /** Release the provided {@link Releasable} expecting no exception to by thrown. */ - public static void closeExpectNoException(Releasable releasable) { + public static void closeExpectNoException(@Nullable Releasable releasable) { try { close(releasable); } catch (RuntimeException e) { diff --git a/test/framework/src/main/java/org/elasticsearch/test/ESTestCase.java b/test/framework/src/main/java/org/elasticsearch/test/ESTestCase.java index e1ba99dd95f37..1016dc43d2f73 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/ESTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/test/ESTestCase.java @@ -1198,6 +1198,11 @@ public static long randomLong() { return random().nextLong(); } + /** A random long from 0..max (inclusive). 
*/ + public static long randomLong(long max) { + return RandomNumbers.randomLongBetween(random(), 0L, max); + } + public static LongStream randomLongs() { return random().longs(); } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/GroupKeyEncoder.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/GroupKeyEncoder.java index 121c098f61aff..d2b9216808122 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/GroupKeyEncoder.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/GroupKeyEncoder.java @@ -7,7 +7,9 @@ package org.elasticsearch.compute.operator; +import org.apache.lucene.util.Accountable; import org.apache.lucene.util.BytesRef; +import org.apache.lucene.util.RamUsageEstimator; import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.BooleanBlock; import org.elasticsearch.compute.data.BytesRefBlock; @@ -29,7 +31,9 @@ * in block iteration order. This means {@code [1, 2]} and {@code [2, 1]} produce different keys. * Null positions are encoded as a value count of zero. 
*/ -public class GroupKeyEncoder implements Releasable { +public class GroupKeyEncoder implements Accountable, Releasable { + + private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(GroupKeyEncoder.class); private static final DefaultUnsortableTopNEncoder encoder = TopNEncoder.DEFAULT_UNSORTABLE; @@ -113,6 +117,20 @@ private void encodeBlock(Block block, ElementType type, int position) { } } + public int[] groupChannels() { + return groupChannels; + } + + @Override + public long ramBytesUsed() { + long size = SHALLOW_SIZE; + size += RamUsageEstimator.sizeOf(groupChannels); + size += RamUsageEstimator.shallowSizeOf(elementTypes); + size += scratch.ramBytesUsed(); + size += RamUsageEstimator.shallowSizeOfInstance(BytesRef.class); + return size; + } + @Override public void close() { scratch.close(); diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/GroupedQueue.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/GroupedQueue.java new file mode 100644 index 0000000000000..5f44e948e9ce6 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/GroupedQueue.java @@ -0,0 +1,116 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.compute.operator.topn; + +import org.apache.lucene.util.Accountable; +import org.elasticsearch.common.breaker.CircuitBreaker; +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.common.util.ObjectArray; +import org.elasticsearch.core.Releasable; +import org.elasticsearch.core.Releasables; + +import java.util.ArrayList; +import java.util.List; + +import static org.apache.lucene.util.RamUsageEstimator.shallowSizeOfInstance; + +/** + * A queue that maintains a separate {@link TopNQueue} per group, indexed by group IDs. + */ +class GroupedQueue implements Accountable, Releasable { + private static final long SHALLOW_SIZE = shallowSizeOfInstance(GroupedQueue.class); + + private final CircuitBreaker breaker; + private final BigArrays bigArrays; + private final int topCount; + private ObjectArray queues; + + GroupedQueue(CircuitBreaker breaker, BigArrays bigArrays, int topCount) { + this.breaker = breaker; + this.bigArrays = bigArrays; + this.topCount = topCount; + this.queues = bigArrays.newObjectArray(0); + } + + @Override + public String toString() { + return size() + "/" + queues.size() + "/" + topCount; + } + + int size() { + int totalSize = 0; + for (long i = 0; i < queues.size(); i++) { + TopNQueue queue = queues.get(i); + if (queue != null) { + totalSize += queue.size(); + } + } + return totalSize; + } + + TopNQueue getOrCreateQueue(long groupId) { + if (groupId >= queues.size()) { + queues = bigArrays.grow(queues, groupId + 1); + } + TopNQueue queue = queues.get(groupId); + if (queue == null) { + queue = TopNQueue.build(breaker, topCount); + queues.set(groupId, queue); + } + return queue; + } + + /** + * Removes and returns all rows from all per-group queues. + * For an ascending order, the first element will be the min element (or last in the + * priority queue), and vice versa. 
+ */ + List popAll() { + List allRows = new ArrayList<>(size()); + for (long i = 0; i < queues.size(); i++) { + TopNQueue queue = queues.get(i); + if (queue != null) { + queue.popAllInto(allRows); + queue.close(); + queues.set(i, null); + } + } + allRows.sort((r1, r2) -> -r1.compareTo(r2)); + return allRows; + } + + @Override + public long ramBytesUsed() { + long total = SHALLOW_SIZE; + if (queues != null) { + total += queues.ramBytesUsed(); + for (long i = 0; i < queues.size(); i++) { + TopNQueue queue = queues.get(i); + if (queue != null) { + total += queue.ramBytesUsed(); + } + } + } + return total; + } + + @Override + public void close() { + Releasables.close(() -> { + if (queues != null) { + for (long i = 0; i < queues.size(); i++) { + TopNQueue queue = queues.get(i); + if (queue != null) { + queue.close(); + queues.set(i, null); + } + } + } + }, queues); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/GroupedTopNOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/GroupedTopNOperator.java new file mode 100644 index 0000000000000..3b13ee5172544 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/GroupedTopNOperator.java @@ -0,0 +1,408 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.compute.operator.topn; + +import org.apache.lucene.util.Accountable; +import org.apache.lucene.util.BytesRef; +import org.apache.lucene.util.RamUsageEstimator; +import org.elasticsearch.common.breaker.CircuitBreaker; +import org.elasticsearch.common.util.BytesRefHashTable; +import org.elasticsearch.compute.aggregation.blockhash.BlockHash; +import org.elasticsearch.compute.aggregation.blockhash.HashImplFactory; +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.BreakingBytesRefBuilder; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.compute.operator.GroupKeyEncoder; +import org.elasticsearch.compute.operator.Operator; +import org.elasticsearch.core.ReleasableIterator; +import org.elasticsearch.core.Releasables; + +import java.util.Arrays; +import java.util.List; + +/** + * A top-N operator for grouped (SORT + LIMIT BY) queries. Maintains per-group priority queues + * using a {@link GroupKeyEncoder} to map group key columns to integer group IDs. + *
<p>
+ * Group keys use list semantics for multivalues: {@code [1,2]} and {@code [2,1]} are different groups. + *
<p>
+ * Unlike {@link TopNOperator}, this operator does not support sorted input optimization + * or {@link SharedMinCompetitive} tracking, as these optimizations are not applicable + * to grouped top-N. + */ +public class GroupedTopNOperator implements Operator, Accountable { + + private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(GroupedTopNOperator.class) + RamUsageEstimator + .shallowSizeOfInstance(List.class) * 3; + + private static final long SORT_ORDER_SIZE = RamUsageEstimator.shallowSizeOfInstance(TopNOperator.SortOrder.class); + + public record GroupedTopNOperatorFactory( + int topCount, + List elementTypes, + List encoders, + List sortOrders, + List groupKeys, + int maxPageSize, + long jumboPageBytes + ) implements OperatorFactory { + public GroupedTopNOperatorFactory { + for (ElementType e : elementTypes) { + if (e == null) { + throw new IllegalArgumentException("ElementType not known"); + } + } + if (groupKeys.isEmpty()) { + throw new IllegalArgumentException("GroupedTopNOperator requires at least one group key"); + } + } + + @Override + public GroupedTopNOperator get(DriverContext driverContext) { + var scratch = new BreakingBytesRefBuilder(driverContext.breaker(), "group-key-encoder"); + int[] groupKeysArray = groupKeys.stream().mapToInt(Integer::intValue).toArray(); + var keyEncoder = new GroupKeyEncoder(groupKeysArray, elementTypes, scratch); + return new GroupedTopNOperator( + driverContext.blockFactory(), + driverContext.breaker(), + topCount, + elementTypes, + encoders, + sortOrders, + keyEncoder, + maxPageSize, + jumboPageBytes + ); + } + + @Override + public String describe() { + return "GroupedTopNOperator[count=" + + topCount + + ", elementTypes=" + + elementTypes + + ", encoders=" + + encoders + + ", sortOrders=" + + sortOrders + + ", groupKeys=" + + groupKeys + + "]"; + } + } + + private final BlockFactory blockFactory; + private final CircuitBreaker breaker; + private final int maxPageSize; + private final long 
jumboPageBytes; + private final int topCount; + private final List elementTypes; + private final List encoders; + private final List sortOrders; + private final boolean[] channelInKey; + private final GroupKeyEncoder keyEncoder; + + private BytesRefHashTable keysHash; + private GroupedQueue inputQueue; + private TopNRow spare; + + private ReleasableIterator output; + + private long receiveNanos; + private long emitNanos; + private int pagesReceived; + private int pagesEmitted; + private long rowsReceived; + private long rowsEmitted; + + public GroupedTopNOperator( + BlockFactory blockFactory, + CircuitBreaker breaker, + int topCount, + List elementTypes, + List encoders, + List sortOrders, + GroupKeyEncoder keyEncoder, + int maxPageSize, + long jumboPageBytes + ) { + BytesRefHashTable keysHash = null; + GroupedQueue inputQueue = null; + boolean success = false; + try { + keysHash = HashImplFactory.newBytesRefHash(blockFactory); + inputQueue = new GroupedQueue(breaker, blockFactory.bigArrays(), topCount); + success = true; + } finally { + if (success == false) { + Releasables.close(keyEncoder, keysHash, inputQueue); + } + } + this.keyEncoder = keyEncoder; + this.keysHash = keysHash; + this.inputQueue = inputQueue; + this.blockFactory = blockFactory; + this.breaker = breaker; + this.maxPageSize = maxPageSize; + this.jumboPageBytes = jumboPageBytes; + this.topCount = topCount; + this.elementTypes = elementTypes; + this.encoders = encoders; + this.sortOrders = sortOrders; + this.channelInKey = new boolean[elementTypes.size()]; + for (TopNOperator.SortOrder so : sortOrders) { + channelInKey[so.channel()] = true; + } + } + + @Override + public boolean needsInput() { + return output == null; + } + + @Override + public void addInput(Page page) { + long start = System.nanoTime(); + try { + if (this.topCount <= 0) { + return; + } + TopNOperator.RowFiller rowFiller = new TopNOperator.RowFiller(elementTypes, encoders, sortOrders, channelInKey, page); + for (int pos = 0; pos < 
page.getPositionCount(); pos++) { + BytesRef key = keyEncoder.encode(page, pos); + long hashOrd = keysHash.add(key); + long groupId = BlockHash.hashOrdToGroup(hashOrd); + processRow(rowFiller, pos, groupId); + } + } finally { + page.releaseBlocks(); + pagesReceived++; + rowsReceived += page.getPositionCount(); + receiveNanos += System.nanoTime() - start; + } + } + + private void processRow(TopNOperator.RowFiller rowFiller, int position, long groupId) { + if (spare == null) { + spare = new TopNRow(breaker, rowFiller.preAllocatedKeysSize(), rowFiller.preAllocatedValueSize()); + } else { + spare.clear(); + } + rowFiller.writeKey(position, spare); + + // Write values BEFORE modifying the queue so that if writeValues throws (e.g. circuit breaker), + // spare is not left in both the queue and the spare field (which would double-close). + TopNQueue queue = inputQueue.getOrCreateQueue(groupId); + if (queue.size() < queue.topCount) { + rowFiller.writeValues(position, spare); + queue.add(spare); + spare = null; + } else if (queue.lessThan(queue.top(), spare)) { + rowFiller.writeValues(position, spare); + TopNRow nextSpare = queue.top(); + queue.updateTop(spare); + spare = nextSpare; + } + } + + @Override + public void finish() { + if (output == null) { + long start = System.nanoTime(); + output = buildResult(); + emitNanos += System.nanoTime() - start; + } + } + + @Override + public boolean isFinished() { + return output != null && output.hasNext() == false; + } + + @Override + public boolean canProduceMoreDataWithoutExtraInput() { + return output != null && output.hasNext(); + } + + @Override + public Page getOutput() { + if (output == null || output.hasNext() == false) { + return null; + } + Page ret = output.next(); + pagesEmitted++; + rowsEmitted += ret.getPositionCount(); + return ret; + } + + @Override + public void close() { + Releasables.closeExpectNoException(spare, inputQueue, output, keysHash, keyEncoder); + inputQueue = null; + output = null; + } + + @Override + 
public long ramBytesUsed() { + long arrHeader = RamUsageEstimator.NUM_BYTES_ARRAY_HEADER; + long ref = RamUsageEstimator.NUM_BYTES_OBJECT_REF; + long size = SHALLOW_SIZE; + size += RamUsageEstimator.alignObjectSize(arrHeader + ref * elementTypes.size()); + size += RamUsageEstimator.alignObjectSize(arrHeader + ref * encoders.size()); + size += RamUsageEstimator.alignObjectSize(arrHeader + ref * sortOrders.size()); + size += RamUsageEstimator.sizeOf(channelInKey); + size += sortOrders.size() * SORT_ORDER_SIZE; + size += keyEncoder.ramBytesUsed(); + if (keysHash != null) { + size += keysHash.ramBytesUsed(); + } + if (inputQueue != null) { + size += inputQueue.ramBytesUsed(); + } + if (spare != null) { + size += spare.ramBytesUsed(); + } + return size; + } + + @Override + public Status status() { + return new GroupedTopNOperatorStatus( + receiveNanos, + emitNanos, + inputQueue != null ? inputQueue.size() : 0, + keysHash != null ? (int) keysHash.size() : 0, + ramBytesUsed(), + pagesReceived, + pagesEmitted, + rowsReceived, + rowsEmitted + ); + } + + @Override + public String toString() { + return "GroupedTopNOperator[count=" + + inputQueue + + ", elementTypes=" + + elementTypes + + ", encoders=" + + encoders + + ", sortOrders=" + + sortOrders + + ", groupKeys=" + + Arrays.toString(keyEncoder.groupChannels()) + + "]"; + } + + /** + * Build the result iterator. Moves all rows from the {@link #inputQueue} and + * {@link #close}s it. 
+ */ + private ReleasableIterator buildResult() { + if (spare != null) { + spare.close(); + spare = null; + } + + if (inputQueue.size() == 0) { + return ReleasableIterator.empty(); + } + + List rows = inputQueue.popAll(); + inputQueue.close(); + keysHash.close(); + inputQueue = null; + keysHash = null; + return new Result(rows); + } + + private class Result implements ReleasableIterator { + private final List rows; + private int r; + + private Result(List rows) { + this.rows = rows; + } + + @Override + public boolean hasNext() { + return r < rows.size(); + } + + @Override + public Page next() { + long start = System.nanoTime(); + int size = Math.min(maxPageSize, rows.size() - r); + if (size <= 0) { + throw new IllegalStateException("can't make empty pages. " + size + " must be > 0"); + } + ResultBuilder[] builders = new ResultBuilder[elementTypes.size()]; + try { + for (int b = 0; b < builders.length; b++) { + builders[b] = ResultBuilder.resultBuilderFor(blockFactory, elementTypes.get(b), encoders.get(b), channelInKey[b], size); + } + int rEnd = r + size; + while (r < rEnd) { + try (TopNRow row = rows.set(r++, null)) { + readKeys(builders, row.keys.bytesRefView()); + readValues(builders, row.values.bytesRefView()); + } + if (totalSize(builders) > jumboPageBytes) { + break; + } + } + + return new Page(ResultBuilder.buildAll(builders)); + } finally { + Releasables.close(builders); + emitNanos += System.nanoTime() - start; + } + } + + private long totalSize(ResultBuilder[] builders) { + long total = 0; + for (ResultBuilder b : builders) { + total += b.estimatedBytes(); + } + return total; + } + + @Override + public void close() { + Releasables.close(rows); + } + + private void readKeys(ResultBuilder[] builders, BytesRef keys) { + for (TopNOperator.SortOrder so : sortOrders) { + if (keys.bytes[keys.offset] == so.nul()) { + keys.offset++; + keys.length--; + continue; + } + keys.offset++; + keys.length--; + builders[so.channel()].decodeKey(keys, so.asc()); + } + if 
(keys.length != 0) { + throw new IllegalArgumentException("didn't read all keys"); + } + } + + private void readValues(ResultBuilder[] builders, BytesRef values) { + for (ResultBuilder builder : builders) { + builder.decodeValue(values); + } + if (values.length != 0) { + throw new IllegalArgumentException("didn't read all values"); + } + } + } +} diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/GroupedTopNOperatorStatus.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/GroupedTopNOperatorStatus.java new file mode 100644 index 0000000000000..15e524b21d0a2 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/GroupedTopNOperatorStatus.java @@ -0,0 +1,194 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.compute.operator.topn; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.unit.ByteSizeValue; +import org.elasticsearch.compute.operator.Operator; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.Objects; + +public class GroupedTopNOperatorStatus implements Operator.Status { + public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry( + Operator.Status.class, + "groupedtopn", + GroupedTopNOperatorStatus::new + ); + + private final long receiveNanos; + private final long emitNanos; + private final int occupiedRows; + private final int groupCount; + private final long ramBytesUsed; + private final int pagesReceived; + private final int pagesEmitted; + private final long rowsReceived; + private final long rowsEmitted; + + public GroupedTopNOperatorStatus( + long receiveNanos, + long emitNanos, + int occupiedRows, + int groupCount, + long ramBytesUsed, + int pagesReceived, + int pagesEmitted, + long rowsReceived, + long rowsEmitted + ) { + this.receiveNanos = receiveNanos; + this.emitNanos = emitNanos; + this.occupiedRows = occupiedRows; + this.groupCount = groupCount; + this.ramBytesUsed = ramBytesUsed; + this.pagesReceived = pagesReceived; + this.pagesEmitted = pagesEmitted; + this.rowsReceived = rowsReceived; + this.rowsEmitted = rowsEmitted; + } + + GroupedTopNOperatorStatus(StreamInput in) throws IOException { + this.receiveNanos = in.readVLong(); + this.emitNanos = in.readVLong(); + this.occupiedRows = in.readVInt(); + this.groupCount = in.readVInt(); + this.ramBytesUsed = in.readVLong(); + + this.pagesReceived = in.readVInt(); + this.pagesEmitted = 
in.readVInt(); + this.rowsReceived = in.readVLong(); + this.rowsEmitted = in.readVLong(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeVLong(receiveNanos); + out.writeVLong(emitNanos); + + out.writeVInt(occupiedRows); + out.writeVInt(groupCount); + out.writeVLong(ramBytesUsed); + + out.writeVInt(pagesReceived); + out.writeVInt(pagesEmitted); + out.writeVLong(rowsReceived); + out.writeVLong(rowsEmitted); + } + + @Override + public String getWriteableName() { + return ENTRY.name; + } + + public long receiveNanos() { + return receiveNanos; + } + + public long emitNanos() { + return emitNanos; + } + + public int occupiedRows() { + return occupiedRows; + } + + public int groupCount() { + return groupCount; + } + + public long ramBytesUsed() { + return ramBytesUsed; + } + + public int pagesReceived() { + return pagesReceived; + } + + public int pagesEmitted() { + return pagesEmitted; + } + + public long rowsReceived() { + return rowsReceived; + } + + public long rowsEmitted() { + return rowsEmitted; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field("receive_nanos", receiveNanos); + if (builder.humanReadable()) { + builder.field("receive_time", TimeValue.timeValueNanos(receiveNanos).toString()); + } + builder.field("emit_nanos", emitNanos); + if (builder.humanReadable()) { + builder.field("emit_time", TimeValue.timeValueNanos(emitNanos).toString()); + } + builder.field("occupied_rows", occupiedRows); + builder.field("group_count", groupCount); + builder.field("ram_bytes_used", ramBytesUsed); + builder.field("ram_used", ByteSizeValue.ofBytes(ramBytesUsed)); + builder.field("pages_received", pagesReceived); + builder.field("pages_emitted", pagesEmitted); + builder.field("rows_received", rowsReceived); + builder.field("rows_emitted", rowsEmitted); + return builder.endObject(); + } + + @Override + public boolean equals(Object 
o) { + if (o == null || getClass() != o.getClass()) { + return false; + } + GroupedTopNOperatorStatus that = (GroupedTopNOperatorStatus) o; + return receiveNanos == that.receiveNanos + && emitNanos == that.emitNanos + && occupiedRows == that.occupiedRows + && groupCount == that.groupCount + && ramBytesUsed == that.ramBytesUsed + && pagesReceived == that.pagesReceived + && pagesEmitted == that.pagesEmitted + && rowsReceived == that.rowsReceived + && rowsEmitted == that.rowsEmitted; + } + + @Override + public int hashCode() { + return Objects.hash( + receiveNanos, + emitNanos, + occupiedRows, + groupCount, + ramBytesUsed, + pagesReceived, + pagesEmitted, + rowsReceived, + rowsEmitted + ); + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersion.minimumCompatible(); + } + + @Override + public String toString() { + return Strings.toString(this); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/TopNOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/TopNOperator.java index 2fd7231b7ee36..ecf4d618af1b4 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/TopNOperator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/TopNOperator.java @@ -9,27 +9,21 @@ import org.apache.lucene.util.Accountable; import org.apache.lucene.util.BytesRef; -import org.apache.lucene.util.PriorityQueue; import org.apache.lucene.util.RamUsageEstimator; import org.elasticsearch.common.breaker.CircuitBreaker; import org.elasticsearch.compute.data.BlockFactory; -import org.elasticsearch.compute.data.DocVector; import org.elasticsearch.compute.data.ElementType; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.BreakingBytesRefBuilder; import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.compute.operator.Operator; 
import org.elasticsearch.core.Nullable; -import org.elasticsearch.core.RefCounted; -import org.elasticsearch.core.Releasable; import org.elasticsearch.core.ReleasableIterator; import org.elasticsearch.core.Releasables; import java.util.ArrayList; -import java.util.Arrays; import java.util.Collections; import java.util.List; -import java.util.Objects; /** * An operator that sorts "rows" of values by encoding the values to sort on, as bytes (using BytesRef). Each data type is encoded @@ -52,123 +46,16 @@ public enum InputOrdering { } /** - * A single top "row". Implements {@link Comparable} and {@link Row#equals} comparing - * the sort keys. + * Fills {@link TopNRow}s from page data. Handles both sort-key encoding and value + * extraction, and tracks pre-allocation sizes for key and value buffers. */ - static final class Row implements Accountable, Comparable, Releasable { - private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(Row.class); - - private final CircuitBreaker breaker; - - /** - * The sort keys, encoded into bytes so we can sort by calling {@link Arrays#compareUnsigned}. - */ - final BreakingBytesRefBuilder keys; - - /** - * Values to reconstruct the row. Sort of. When we reconstruct the row we read - * from both the {@link #keys} and the {@link #values}. So this only contains - * what is required to reconstruct the row that isn't already stored in {@link #values}. - */ - final BreakingBytesRefBuilder values; - - /** - * Reference counter for the shard this row belongs to, used for rows containing a {@link DocVector} to ensure that the shard - * context before we build the final result. 
- */ - @Nullable - RefCounted shardRefCounter; - - Row(CircuitBreaker breaker, int preAllocatedKeysSize, int preAllocatedValueSize) { - breaker.addEstimateBytesAndMaybeBreak(SHALLOW_SIZE, "topn"); - this.breaker = breaker; - boolean success = false; - try { - keys = new BreakingBytesRefBuilder(breaker, "topn", preAllocatedKeysSize); - values = new BreakingBytesRefBuilder(breaker, "topn", preAllocatedValueSize); - success = true; - } finally { - if (success == false) { - close(); - } - } - } - - @Override - public long ramBytesUsed() { - return SHALLOW_SIZE + keys.ramBytesUsed() + values.ramBytesUsed(); - } - - @Override - public void close() { - clearRefCounters(); - Releasables.closeExpectNoException(() -> breaker.addWithoutBreaking(-SHALLOW_SIZE), keys, values); - } - - public void clearRefCounters() { - if (shardRefCounter != null) { - shardRefCounter.decRef(); - } - shardRefCounter = null; - } - - void setShardRefCounted(RefCounted shardRefCounted) { - if (this.shardRefCounter != null) { - this.shardRefCounter.decRef(); - } - this.shardRefCounter = shardRefCounted; - this.shardRefCounter.mustIncRef(); - } - - @Override - public int compareTo(Row rhs) { - // TODO if we fill the trailing bytes with 0 we could do a comparison on the entire array - // When Nik measured this it was marginally faster. But it's worth a bit of research. 
- return -keys.bytesRefView().compareTo(rhs.keys.bytesRefView()); - } - - @Override - public boolean equals(Object o) { - if (o == null || getClass() != o.getClass()) { - return false; - } - ; - Row row = (Row) o; - return keys.bytesRefView().equals(row.keys.bytesRefView()); - } - - @Override - public int hashCode() { - return Objects.hashCode(keys); - } - - @Override - public String toString() { - StringBuilder b = new StringBuilder("Row[key="); - b.append(keys.bytesRefView()); - b.append(", values="); - - if (values.length() < 100) { - b.append(values.bytesRefView()); - } else { - b.append('['); - assert values.bytesRefView().offset == 0; - for (int i = 0; i < 100; i++) { - if (i != 0) { - b.append(" "); - } - b.append(Integer.toHexString(values.bytesRefView().bytes[i] & 255)); - } - b.append("..."); - } - return b.append("]").toString(); - } - } - static final class RowFiller { private final ValueExtractor[] valueExtractors; private final KeyExtractor[] keyExtractors; + private int keyPreAllocSize = 0; + private int valuePreAllocSize = 0; + RowFiller( List elementTypes, List encoders, @@ -199,13 +86,22 @@ static final class RowFiller { } } - void writeKey(int position, Row row) { + int preAllocatedKeysSize() { + return keyPreAllocSize; + } + + int preAllocatedValueSize() { + return valuePreAllocSize; + } + + void writeKey(int position, TopNRow row) { for (KeyExtractor keyExtractor : keyExtractors) { keyExtractor.writeKey(row.keys, position); } + keyPreAllocSize = newPreAllocSize(row.keys, keyPreAllocSize); } - void writeValues(int position, Row destination) { + void writeValues(int position, TopNRow destination) { for (ValueExtractor e : valueExtractors) { var refCounted = e.getRefCountedForShard(position); if (refCounted != null) { @@ -213,6 +109,15 @@ void writeValues(int position, Row destination) { } e.writeValue(destination.values, position); } + valuePreAllocSize = newPreAllocSize(destination.values, valuePreAllocSize); + } + + /** + * Pre-allocation size 
heuristic: use the larger of the current builder length and half + * the previous pre-alloc size, so the size decays after a single unusually large row. + */ + private static int newPreAllocSize(BreakingBytesRefBuilder builder, int sparePreAllocSize) { + return Math.max(builder.length(), sparePreAllocSize / 2); } } @@ -308,10 +213,8 @@ public String describe() { */ private int minCompetitiveUpdates; - private Queue inputQueue; - private Row spare; - private int spareValuesPreAllocSize = 0; - private int spareKeysPreAllocSize = 0; + private TopNQueue inputQueue; + private TopNRow spare; private ReleasableIterator output; @@ -352,11 +255,11 @@ public TopNOperator( InputOrdering inputOrdering, @Nullable SharedMinCompetitive.Supplier minCompetitiveSupplier ) { - Queue inputQueue = null; + TopNQueue inputQueue = null; SharedMinCompetitive minCompetitive = null; boolean success = false; try { - inputQueue = Queue.build(breaker, topCount); + inputQueue = TopNQueue.build(breaker, topCount); minCompetitive = minCompetitiveSupplier == null ? null : minCompetitiveSupplier.get(); success = true; } finally { @@ -407,32 +310,25 @@ public void addInput(Page page) { for (int i = 0; i < page.getPositionCount(); i++) { if (spare == null) { - spare = new Row(breaker, spareKeysPreAllocSize, spareValuesPreAllocSize); + spare = new TopNRow(breaker, rowFiller.preAllocatedKeysSize(), rowFiller.preAllocatedValueSize()); } else { - spare.keys.clear(); - spare.values.clear(); - spare.clearRefCounters(); + spare.clear(); } rowFiller.writeKey(i, spare); - // When rows are very long, appending the values one by one can lead to lots of allocations. - // To avoid this, pre-allocate at least as much size as in the last seen row. - // Let the pre-allocation size decay in case we only have 1 huge row and smaller rows otherwise. 
- spareKeysPreAllocSize = Math.max(spare.keys.length(), spareKeysPreAllocSize / 2); - - // This is `inputQueue.insertWithOverflow` with followed by filling in the value only if we inserted. + // This is `inputQueue.insertWithOverflow` followed by filling in the value only if we inserted. + // We must write values BEFORE modifying the queue so that if writeValues throws (e.g. circuit + // breaker), spare is not left in both the queue and the spare field (which would double-close). if (inputQueue.size() < inputQueue.topCount) { // Heap not yet full, just add elements rowFiller.writeValues(i, spare); - spareValuesPreAllocSize = Math.max(spare.values.length(), spareValuesPreAllocSize / 2); inputQueue.add(spare); spare = null; modified = true; } else if (inputQueue.lessThan(inputQueue.top(), spare)) { - // Heap full AND this node fit in it. - Row nextSpare = inputQueue.top(); + // Heap full AND this node fits in it. + TopNRow nextSpare = inputQueue.top(); rowFiller.writeValues(i, spare); - spareValuesPreAllocSize = Math.max(spare.values.length(), spareValuesPreAllocSize / 2); inputQueue.updateTop(spare); spare = nextSpare; modified = true; @@ -579,80 +475,6 @@ public String toString() { + "]"; } - private static class Queue extends PriorityQueue implements Accountable, Releasable { - private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(Queue.class); - private final CircuitBreaker breaker; - private final int topCount; - - /** - * Track memory usage in the breaker then build the {@link Queue}. 
- */ - static Queue build(CircuitBreaker breaker, int topCount) { - breaker.addEstimateBytesAndMaybeBreak(Queue.sizeOf(topCount), "esql engine topn"); - return new Queue(breaker, topCount); - } - - private Queue(CircuitBreaker breaker, int topCount) { - super(topCount); - this.breaker = breaker; - this.topCount = topCount; - } - - @Override - protected boolean lessThan(Row lhs, Row rhs) { - return lhs.compareTo(rhs) < 0; - } - - @Override - public String toString() { - return size() + "/" + topCount; - } - - @Override - public long ramBytesUsed() { - long total = SHALLOW_SIZE; - total += RamUsageEstimator.alignObjectSize( - RamUsageEstimator.NUM_BYTES_ARRAY_HEADER + RamUsageEstimator.NUM_BYTES_OBJECT_REF * ((long) topCount + 1) - ); - for (Row r : this) { - total += r == null ? 0 : r.ramBytesUsed(); - } - return total; - } - - @Override - public void close() { - Releasables.close( - /* - * Release all entries in the topn, nulling references to each row after closing them - * so they can be GC immediately. Without this nulling very large heaps can race with - * the circuit breaker itself. With this we're still racing, but we're only racing a - * single row at a time. And single rows can only be so large. And we have enough slop - * to live with being inaccurate by one row. - */ - () -> { - for (int i = 0; i < getHeapArray().length; i++) { - Row row = (Row) getHeapArray()[i]; - if (row != null) { - row.close(); - getHeapArray()[i] = null; - } - } - }, - // Release the array itself - () -> breaker.addWithoutBreaking(-Queue.sizeOf(topCount)) - ); - } - - private static long sizeOf(int topCount) { - long total = SHALLOW_SIZE; - total += RamUsageEstimator.alignObjectSize( - RamUsageEstimator.NUM_BYTES_ARRAY_HEADER + RamUsageEstimator.NUM_BYTES_OBJECT_REF * ((long) topCount + 1) - ); - return total; - } - } - /** * Build the result iterator. Moves all rows from the {@link #inputQueue} and * {@link #close}s it. 
@@ -668,10 +490,8 @@ private ReleasableIterator buildResult() { return ReleasableIterator.empty(); } - List rows = new ArrayList<>(inputQueue.size()); - while (inputQueue.size() > 0) { - rows.add(inputQueue.pop()); - } + List rows = new ArrayList<>(inputQueue.size()); + inputQueue.popAllInto(rows); Collections.reverse(rows); inputQueue.close(); inputQueue = null; @@ -679,10 +499,10 @@ private ReleasableIterator buildResult() { } private class Result implements ReleasableIterator { - private final List rows; + private final List rows; private int r; - private Result(List rows) { + private Result(List rows) { this.rows = rows; } @@ -705,7 +525,7 @@ public Page next() { } int rEnd = r + size; while (r < rEnd) { - try (Row row = rows.set(r++, null)) { + try (TopNRow row = rows.set(r++, null)) { readKeys(builders, row.keys.bytesRefView()); readValues(builders, row.values.bytesRefView()); } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/TopNQueue.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/TopNQueue.java new file mode 100644 index 0000000000000..8bcda7d16e2c8 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/TopNQueue.java @@ -0,0 +1,123 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.compute.operator.topn; + +import org.apache.lucene.util.Accountable; +import org.apache.lucene.util.PriorityQueue; +import org.apache.lucene.util.RamUsageEstimator; +import org.elasticsearch.common.breaker.CircuitBreaker; +import org.elasticsearch.core.Releasable; +import org.elasticsearch.core.Releasables; + +import java.util.List; + +/** + * A bounded min-heap of {@link TopNRow}s used to find the top-N rows by sort key. + * Used both by {@link TopNOperator} (one global queue) and by {@link GroupedQueue} + * (one queue per group). + */ +class TopNQueue extends PriorityQueue implements Accountable, Releasable { + private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(TopNQueue.class); + + private final CircuitBreaker breaker; + final int topCount; + + /** + * Track memory usage in the breaker then build the {@link TopNQueue}. + */ + static TopNQueue build(CircuitBreaker breaker, int topCount) { + breaker.addEstimateBytesAndMaybeBreak(sizeOf(topCount), "esql engine topn"); + return new TopNQueue(breaker, topCount); + } + + private TopNQueue(CircuitBreaker breaker, int topCount) { + super(topCount); + this.breaker = breaker; + this.topCount = topCount; + } + + @Override + protected boolean lessThan(TopNRow lhs, TopNRow rhs) { + return lhs.compareTo(rhs) < 0; + } + + /** + * Attempts to insert a row into the queue. + * @return {@code null} if the row was inserted into a non-full queue; + * the evicted row if the row replaced the current top; + * the input row itself if it was rejected (worse than all in the queue). + */ + TopNRow addRow(TopNRow row) { + if (size() < topCount) { + add(row); + return null; + } else if (lessThan(top(), row)) { + TopNRow evicted = top(); + updateTop(row); + return evicted; + } + return row; + } + + /** + * Drains all rows from this queue into the given list. 
+ */ + void popAllInto(List target) { + while (size() > 0) { + target.add(pop()); + } + } + + @Override + public String toString() { + return size() + "/" + topCount; + } + + @Override + public long ramBytesUsed() { + long total = SHALLOW_SIZE; + total += RamUsageEstimator.alignObjectSize( + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER + RamUsageEstimator.NUM_BYTES_OBJECT_REF * ((long) topCount + 1) + ); + for (TopNRow r : this) { + total += r == null ? 0 : r.ramBytesUsed(); + } + return total; + } + + @Override + public void close() { + Releasables.close( + /* + * Release all entries in the topn, nulling references to each row after closing them + * so they can be GC immediately. Without this nulling very large heaps can race with + * the circuit breaker itself. With this we're still racing, but we're only racing a + * single row at a time. And single rows can only be so large. And we have enough slop + * to live with being inaccurate by one row. + */ + () -> { + for (int i = 0; i < getHeapArray().length; i++) { + TopNRow row = (TopNRow) getHeapArray()[i]; + if (row != null) { + row.close(); + getHeapArray()[i] = null; + } + } + }, + () -> breaker.addWithoutBreaking(-sizeOf(topCount)) + ); + } + + static long sizeOf(int topCount) { + long total = SHALLOW_SIZE; + total += RamUsageEstimator.alignObjectSize( + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER + RamUsageEstimator.NUM_BYTES_OBJECT_REF * ((long) topCount + 1) + ); + return total; + } +} diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/TopNRow.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/TopNRow.java new file mode 100644 index 0000000000000..951d85145ed16 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/TopNRow.java @@ -0,0 +1,138 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.operator.topn; + +import org.apache.lucene.util.Accountable; +import org.apache.lucene.util.RamUsageEstimator; +import org.elasticsearch.common.breaker.CircuitBreaker; +import org.elasticsearch.compute.operator.BreakingBytesRefBuilder; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.RefCounted; +import org.elasticsearch.core.Releasable; +import org.elasticsearch.core.Releasables; + +import java.util.Arrays; + +/** + * A single row in a top-N operation. Stores encoded sort keys and values. + * Implements {@link Comparable} and {@link #equals} comparing the sort keys. + */ +final class TopNRow implements Accountable, Comparable, Releasable { + private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(TopNRow.class); + + private final CircuitBreaker breaker; + + /** + * The sort keys, encoded into bytes so we can sort by calling {@link Arrays#compareUnsigned}. + */ + final BreakingBytesRefBuilder keys; + + /** + * Values to reconstruct the row. When we reconstruct the row we read + * from both the {@link #keys} and the {@link #values}. So this only contains + * what is required to reconstruct the row that isn't already stored in {@link #keys}. + */ + final BreakingBytesRefBuilder values; + + /** + * Reference counter for the shard this row belongs to, used for rows containing a + * DocVector to ensure the shard context lives until we build the final result. 
+ */ + @Nullable + RefCounted shardRefCounter; + + TopNRow(CircuitBreaker breaker, int preAllocatedKeysSize, int preAllocatedValueSize) { + breaker.addEstimateBytesAndMaybeBreak(SHALLOW_SIZE, "topn"); + this.breaker = breaker; + boolean success = false; + try { + keys = new BreakingBytesRefBuilder(breaker, "topn", preAllocatedKeysSize); + values = new BreakingBytesRefBuilder(breaker, "topn", preAllocatedValueSize); + success = true; + } finally { + if (success == false) { + close(); + } + } + } + + @Override + public long ramBytesUsed() { + return SHALLOW_SIZE + keys.ramBytesUsed() + values.ramBytesUsed(); + } + + @Override + public void close() { + clearRefCounters(); + Releasables.closeExpectNoException(() -> breaker.addWithoutBreaking(-SHALLOW_SIZE), keys, values); + } + + void clear() { + keys.clear(); + values.clear(); + clearRefCounters(); + } + + void clearRefCounters() { + if (shardRefCounter != null) { + shardRefCounter.decRef(); + } + shardRefCounter = null; + } + + void setShardRefCounted(RefCounted shardRefCounted) { + if (this.shardRefCounter != null) { + this.shardRefCounter.decRef(); + } + this.shardRefCounter = shardRefCounted; + this.shardRefCounter.mustIncRef(); + } + + @Override + public int compareTo(TopNRow rhs) { + // TODO if we fill the trailing bytes with 0 we could do a comparison on the entire array + // When Nik measured this it was marginally faster. But it's worth a bit of research. 
+ return -keys.bytesRefView().compareTo(rhs.keys.bytesRefView()); + } + + @Override + public boolean equals(Object o) { + if (o == null || getClass() != o.getClass()) { + return false; + } + TopNRow row = (TopNRow) o; + return keys.bytesRefView().equals(row.keys.bytesRefView()); + } + + @Override + public int hashCode() { + return keys.bytesRefView().hashCode(); + } + + @Override + public String toString() { + StringBuilder b = new StringBuilder("TopNRow[key="); + b.append(keys.bytesRefView()); + b.append(", values="); + + if (values.length() < 100) { + b.append(values.bytesRefView()); + } else { + b.append('['); + assert values.bytesRefView().offset == 0; + for (int i = 0; i < 100; i++) { + if (i != 0) { + b.append(" "); + } + b.append(Integer.toHexString(values.bytesRefView().bytes[i] & 255)); + } + b.append("..."); + } + return b.append("]").toString(); + } +} diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/GroupedQueueTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/GroupedQueueTests.java new file mode 100644 index 0000000000000..e08f817a9c4fc --- /dev/null +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/GroupedQueueTests.java @@ -0,0 +1,126 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.compute.operator.topn; + +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.common.breaker.CircuitBreaker; +import org.elasticsearch.common.unit.ByteSizeValue; +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.common.util.MockBigArrays; +import org.elasticsearch.common.util.PageCacheRecycler; +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.indices.breaker.CircuitBreakerService; +import org.elasticsearch.test.ESTestCase; + +import java.util.List; + +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.greaterThan; + +public class GroupedQueueTests extends ESTestCase { + private final CircuitBreakerService breakerService = newLimitedBreakerService(ByteSizeValue.ofMb(1)); + private final CircuitBreaker breaker = breakerService.getBreaker(CircuitBreaker.REQUEST); + private final BigArrays bigArrays = new MockBigArrays(PageCacheRecycler.NON_RECYCLING_INSTANCE, breakerService); + private final BlockFactory blockFactory = new BlockFactory(breaker, bigArrays); + + public void testGroupIsolation() { + try (GroupedQueue queue = new GroupedQueue(breaker, bigArrays, 2)) { + addRows(queue, 0, 10, 30, 5); + addRows(queue, 1, 20, 40, 15); + assertThat(queue.size(), equalTo(4)); + assertQueueContents(queue, List.of(5, 10, 15, 20)); + } + } + + public void testRamBytesUsed() { + try (GroupedQueue queue = new GroupedQueue(breaker, bigArrays, 5)) { + long emptySize = queue.ramBytesUsed(); + assertThat("ramBytesUsed should be positive", emptySize, greaterThan(0L)); + + addRows(queue, 0, 10, 20, 30); + long oneGroupSize = queue.ramBytesUsed(); + assertThat("RAM should grow with first group", oneGroupSize, greaterThan(emptySize)); + + addRows(queue, 1, 10, 20, 30); + 
assertThat("RAM should grow with more groups", queue.ramBytesUsed(), greaterThan(oneGroupSize)); + } + } + + public void testPopAllSortedBySortKey() { + try (GroupedQueue queue = new GroupedQueue(breaker, bigArrays, 5)) { + addRows(queue, 0, 30, 10, 50); + addRows(queue, 1, 20, 40); + addRows(queue, 2, 15, 25, 35); + assertQueueContents(queue, List.of(10, 15, 20, 25, 30, 35, 40, 50)); + } + } + + private TopNRow createRow(CircuitBreaker breaker, int sortKey) { + IntBlock groupKeyBlock = blockFactory.newIntBlockBuilder(1).appendInt(0).build(); + IntBlock keyBlock = blockFactory.newIntBlockBuilder(1).appendInt(sortKey).build(); + IntBlock valueBlock = blockFactory.newIntBlockBuilder(1).appendInt(sortKey * 2).build(); + TopNRow row = new TopNRow(breaker, 32, 64); + var filler = new TopNOperator.RowFiller( + List.of(ElementType.INT, ElementType.INT, ElementType.INT), + List.of(TopNEncoder.DEFAULT_SORTABLE, TopNEncoder.DEFAULT_SORTABLE, TopNEncoder.DEFAULT_UNSORTABLE), + SORT_ORDERS, + new boolean[] { false, true, false }, + new Page(groupKeyBlock, keyBlock, valueBlock) + ); + try { + filler.writeKey(0, row); + filler.writeValues(0, row); + } finally { + Releasables.close(groupKeyBlock, keyBlock, valueBlock); + } + return row; + } + + private void addRows(GroupedQueue queue, int groupKey, int... 
values) { + for (int value : values) { + addRow(queue, groupKey, value); + } + } + + private void addRow(GroupedQueue queue, int groupKey, int value) { + TopNRow row = createRow(breaker, value); + Releasables.close(queue.getOrCreateQueue(groupKey).addRow(row)); + } + + private static final List SORT_ORDERS = List.of(new TopNOperator.SortOrder(1, true, false)); + + private void assertQueueContents(GroupedQueue queue, List expectedSortKeys) { + assertThat(queue.size(), equalTo(expectedSortKeys.size())); + List actual = queue.popAll(); + for (int i = 0; i < expectedSortKeys.size(); i++) { + int sortKey = expectedSortKeys.get(i); + assertRowValues(actual.get(i), sortKey, sortKey * 2); + } + Releasables.close(actual); + } + + private static void assertRowValues(TopNRow row, int expectedSortKey, int expectedValue) { + BytesRef keys = row.keys.bytesRefView(); + assertThat( + TopNEncoder.DEFAULT_SORTABLE.decodeInt(new BytesRef(keys.bytes, keys.offset + 1, keys.length - 1)), + equalTo(expectedSortKey) + ); + + BytesRef values = row.values.bytesRefView(); + BytesRef reader = new BytesRef(values.bytes, values.offset, values.length); + assertThat(TopNEncoder.DEFAULT_UNSORTABLE.decodeVInt(reader), equalTo(1)); + TopNEncoder.DEFAULT_UNSORTABLE.decodeInt(reader); + assertThat(TopNEncoder.DEFAULT_UNSORTABLE.decodeVInt(reader), equalTo(1)); + assertThat(TopNEncoder.DEFAULT_UNSORTABLE.decodeVInt(reader), equalTo(1)); + assertThat(TopNEncoder.DEFAULT_UNSORTABLE.decodeInt(reader), equalTo(expectedValue)); + } +} diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/GroupedTopNOperatorStatusTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/GroupedTopNOperatorStatusTests.java new file mode 100644 index 0000000000000..470061010b24d --- /dev/null +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/GroupedTopNOperatorStatusTests.java @@ -0,0 +1,118 @@ +/* + * Copyright 
Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.operator.topn; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.test.ESTestCase; + +import static org.hamcrest.Matchers.equalTo; + +public class GroupedTopNOperatorStatusTests extends AbstractWireSerializingTestCase { + public static GroupedTopNOperatorStatus simple() { + return new GroupedTopNOperatorStatus(100, 40, 10, 5, 2000, 123, 123, 111, 222); + } + + public static String simpleToJson() { + return """ + { + "receive_nanos" : 100, + "receive_time" : "100nanos", + "emit_nanos" : 40, + "emit_time" : "40nanos", + "occupied_rows" : 10, + "group_count" : 5, + "ram_bytes_used" : 2000, + "ram_used" : "1.9kb", + "pages_received" : 123, + "pages_emitted" : 123, + "rows_received" : 111, + "rows_emitted" : 222 + }"""; + } + + public void testToXContent() { + assertThat(Strings.toString(simple(), true, true), equalTo(simpleToJson())); + } + + @Override + protected Writeable.Reader instanceReader() { + return GroupedTopNOperatorStatus::new; + } + + @Override + protected GroupedTopNOperatorStatus createTestInstance() { + return new GroupedTopNOperatorStatus( + randomNonNegativeLong(), + randomNonNegativeLong(), + randomNonNegativeInt(), + randomNonNegativeInt(), + randomNonNegativeLong(), + randomNonNegativeInt(), + randomNonNegativeInt(), + randomNonNegativeLong(), + randomNonNegativeLong() + ); + } + + @Override + protected GroupedTopNOperatorStatus mutateInstance(GroupedTopNOperatorStatus instance) { + long receiveNanos = instance.receiveNanos(); + long emitNanos = instance.emitNanos(); + int occupiedRows = instance.occupiedRows(); + int groupCount = 
instance.groupCount(); + long ramBytesUsed = instance.ramBytesUsed(); + int pagesReceived = instance.pagesReceived(); + int pagesEmitted = instance.pagesEmitted(); + long rowsReceived = instance.rowsReceived(); + long rowsEmitted = instance.rowsEmitted(); + switch (between(0, 8)) { + case 0: + receiveNanos = randomValueOtherThan(receiveNanos, ESTestCase::randomNonNegativeLong); + break; + case 1: + emitNanos = randomValueOtherThan(emitNanos, ESTestCase::randomNonNegativeLong); + break; + case 2: + occupiedRows = randomValueOtherThan(occupiedRows, ESTestCase::randomNonNegativeInt); + break; + case 3: + groupCount = randomValueOtherThan(groupCount, ESTestCase::randomNonNegativeInt); + break; + case 4: + ramBytesUsed = randomValueOtherThan(ramBytesUsed, ESTestCase::randomNonNegativeLong); + break; + case 5: + pagesReceived = randomValueOtherThan(pagesReceived, ESTestCase::randomNonNegativeInt); + break; + case 6: + pagesEmitted = randomValueOtherThan(pagesEmitted, ESTestCase::randomNonNegativeInt); + break; + case 7: + rowsReceived = randomValueOtherThan(rowsReceived, ESTestCase::randomNonNegativeLong); + break; + case 8: + rowsEmitted = randomValueOtherThan(rowsEmitted, ESTestCase::randomNonNegativeLong); + break; + default: + throw new IllegalArgumentException(); + } + return new GroupedTopNOperatorStatus( + receiveNanos, + emitNanos, + occupiedRows, + groupCount, + ramBytesUsed, + pagesReceived, + pagesEmitted, + rowsReceived, + rowsEmitted + ); + } +} diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/GroupedTopNOperatorTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/GroupedTopNOperatorTests.java new file mode 100644 index 0000000000000..772fa0af4f85f --- /dev/null +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/GroupedTopNOperatorTests.java @@ -0,0 +1,519 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. 
under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.operator.topn; + +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.BlockUtils; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.lucene.IndexedByShardIdFromList; +import org.elasticsearch.compute.operator.BreakingBytesRefBuilder; +import org.elasticsearch.compute.operator.Driver; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.compute.operator.GroupKeyEncoder; +import org.elasticsearch.compute.operator.Operator; +import org.elasticsearch.compute.operator.PageConsumerOperator; +import org.elasticsearch.compute.operator.SourceOperator; +import org.elasticsearch.compute.operator.topn.TopNOperator.SortOrder; +import org.elasticsearch.compute.test.CannedSourceOperator; +import org.elasticsearch.compute.test.TestBlockBuilder; +import org.elasticsearch.compute.test.TestDriverFactory; +import org.elasticsearch.compute.test.TestDriverRunner; +import org.elasticsearch.compute.test.operator.blocksource.ListRowsBlockSourceOperator; +import org.elasticsearch.compute.test.operator.blocksource.TupleLongLongBlockSourceOperator; +import org.elasticsearch.core.RefCounted; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.core.SimpleRefCounted; +import org.elasticsearch.core.Tuple; +import org.elasticsearch.test.ESTestCase; +import org.hamcrest.Matcher; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Comparator; +import java.util.List; +import java.util.Map; +import java.util.function.Function; +import java.util.stream.Collectors; +import java.util.stream.IntStream; 
+import java.util.stream.LongStream; + +import static org.elasticsearch.compute.data.ElementType.DOC; +import static org.elasticsearch.compute.data.ElementType.LONG; +import static org.elasticsearch.compute.operator.topn.TopNEncoder.DEFAULT_UNSORTABLE; +import static org.elasticsearch.compute.test.BlockTestUtils.readInto; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.greaterThan; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.lessThanOrEqualTo; + +public class GroupedTopNOperatorTests extends TopNOperatorTests { + private static final int TOP_COUNT = 4; + + @Override + protected int[] groupKeys() { + return new int[] { 0 }; + } + + @Override + protected void testRandomTopN(boolean asc, DriverContext context) { + int limit = randomIntBetween(1, 20); + List> inputValues = randomList(0, 1000, () -> Tuple.tuple(ESTestCase.randomLong(), ESTestCase.randomLong(9))); + List> expectedValues = computeTopN(inputValues, limit, asc); + List> outputValues = topNTwoLongColumns( + context, + inputValues, + limit, + List.of(DEFAULT_UNSORTABLE, DEFAULT_UNSORTABLE), + List.of(new SortOrder(0, asc, false)), + new int[] { 1 } + ); + + assertThat(outputValues, equalTo(expectedValues)); + } + + @Override + protected List> expectedTopRowOriented(List> rowOriented, List sortOrders, int topCount) { + return computeTopN(rowOriented, IntStream.of(groupKeys()).boxed().toList(), sortOrders, topCount); + } + + @Override + public void testStatus() { + BlockFactory blockFactory = driverContext().blockFactory(); + try (Operator op = simple(SimpleOptions.DEFAULT).get(driverContext())) { + Operator.Status status = op.status(); + assertThat(status, instanceOf(GroupedTopNOperatorStatus.class)); + GroupedTopNOperatorStatus groupedStatus = (GroupedTopNOperatorStatus) status; + assertThat(groupedStatus.occupiedRows(), equalTo(0)); + assertThat(groupedStatus.groupCount(), equalTo(0)); + assertThat(groupedStatus.ramBytesUsed(), 
greaterThan(0L)); + assertThat(groupedStatus.pagesReceived(), equalTo(0)); + assertThat(groupedStatus.pagesEmitted(), equalTo(0)); + assertThat(groupedStatus.rowsReceived(), equalTo(0L)); + assertThat(groupedStatus.rowsEmitted(), equalTo(0L)); + + Page p = new Page( + blockFactory.newConstantLongBlockWith(1, 10), + blockFactory.newLongArrayVector(new long[] { 1L, 1L, 1L, 1L, 1L, 2L, 2L, 2L, 2L, 2L }, 10).asBlock() + ); + op.addInput(p); + status = op.status(); + groupedStatus = (GroupedTopNOperatorStatus) status; + assertThat(groupedStatus.receiveNanos(), greaterThan(0L)); + assertThat(groupedStatus.emitNanos(), equalTo(0L)); + assertThat(groupedStatus.occupiedRows(), equalTo(8)); + assertThat(groupedStatus.groupCount(), equalTo(2)); + assertThat(groupedStatus.ramBytesUsed(), greaterThan(0L)); + assertThat(groupedStatus.pagesReceived(), equalTo(1)); + assertThat(groupedStatus.pagesEmitted(), equalTo(0)); + assertThat(groupedStatus.rowsReceived(), equalTo(10L)); + assertThat(groupedStatus.rowsEmitted(), equalTo(0L)); + } + } + + public void testBasicTopN() { + List values = Arrays.asList(2L, 1L, 4L, null, 4L, null); + assertThat(topNLong(values, 1, true, false), equalTo(Arrays.asList(1L, 2L, 4L, null))); + assertThat(topNLong(values, 1, false, false), equalTo(Arrays.asList(4L, 2L, 1L, null))); + assertThat(topNLong(values, 1, true, true), equalTo(Arrays.asList(null, 1L, 2L, 4L))); + assertThat(topNLong(values, 1, false, true), equalTo(Arrays.asList(null, 4L, 2L, 1L))); + assertThat(topNLong(values, 2, true, false), equalTo(Arrays.asList(1L, 2L, 4L, 4L, null, null))); + assertThat(topNLong(values, 2, false, false), equalTo(Arrays.asList(4L, 4L, 2L, 1L, null, null))); + assertThat(topNLong(values, 2, true, true), equalTo(Arrays.asList(null, null, 1L, 2L, 4L, 4L))); + assertThat(topNLong(values, 2, false, true), equalTo(Arrays.asList(null, null, 4L, 4L, 2L, 1L))); + } + + @Override + protected SourceOperator simpleInput(BlockFactory blockFactory, int size) { + return 
new TupleLongLongBlockSourceOperator( + blockFactory, + LongStream.range(0, size).mapToObj(l -> Tuple.tuple(ESTestCase.randomLong(), ESTestCase.randomLong(9))), + between(1, size * 2) + ); + } + + @Override + protected Operator.OperatorFactory simple(SimpleOptions options) { + return new GroupedTopNOperator.GroupedTopNOperatorFactory( + TOP_COUNT, + List.of(LONG, LONG), + List.of(DEFAULT_UNSORTABLE, DEFAULT_UNSORTABLE), + List.of(new SortOrder(0, true, false)), + List.of(1), + pageSize, + Long.MAX_VALUE + ); + } + + @Override + protected Matcher expectedDescriptionOfSimple() { + return equalTo( + "GroupedTopNOperator[count=4, elementTypes=[LONG, LONG], encoders=[DefaultUnsortable, DefaultUnsortable], " + + "sortOrders=[SortOrder[channel=0, asc=true, nullsFirst=false]], groupKeys=[1]]" + ); + } + + @Override + protected Matcher expectedToStringOfSimple() { + return equalTo( + "GroupedTopNOperator[count=0/0/4, elementTypes=[LONG, LONG], encoders=[DefaultUnsortable, DefaultUnsortable], " + + "sortOrders=[SortOrder[channel=0, asc=true, nullsFirst=false]], groupKeys=[1]]" + ); + } + + @Override + protected void assertSimpleOutput(List input, List results) { + for (int i = 0; i < results.size() - 1; i++) { + assertThat(results.get(i).getPositionCount(), equalTo(pageSize)); + } + assertThat(results.get(results.size() - 1).getPositionCount(), lessThanOrEqualTo(pageSize)); + List> values = input.stream() + .flatMap( + page -> IntStream.range(0, page.getPositionCount()) + .filter(p -> false == page.getBlock(0).isNull(p)) + .mapToObj(p -> Tuple.tuple(((LongBlock) page.getBlock(0)).getLong(p), ((LongBlock) page.getBlock(1)).getLong(p))) + + ) + .toList(); + var expected = computeTopN(values, TOP_COUNT, true); + assertThat( + results.stream() + .flatMap( + page -> IntStream.range(0, page.getPositionCount()) + .mapToObj(i -> Tuple.tuple(page.getBlock(0).getLong(i), page.getBlock(1).getLong(i))) + ) + .toList(), + equalTo(expected) + ); + } + + /** + * Tests that the SORTED input 
ordering optimization short-circuiting addInput() doesn't incorrectly skip rows + * belonging to groups not yet populated when another group's row is rejected. + * + *

<pre>{@code
+     * Scenario (SORT ASC, LIMIT 1 BY group):
+     * - Page 1: (group=0), (group=1) → both groups populated
+     * - Page 2: (group=0), (group=2) → (group=0) rejected from full group 0 → break skips (group=2), which was empty
+     * Expected groups: {0, 1, 2}
+     * Bug result: {0, 1} (group 2 missing)
+     * }</pre>
+ */ + public void testSortedInputWithMultipleGroups() { + List elementTypes = List.of(ElementType.INT, ElementType.INT); + List encoders = List.of(TopNEncoder.DEFAULT_SORTABLE, DEFAULT_UNSORTABLE); + List sortOrders = List.of(new SortOrder(0, true, true)); + + BlockFactory bf = driverContext().blockFactory(); + Page page1 = new Page(BlockUtils.fromList(bf, List.of(List.of(1, 0), List.of(3, 1)))); + Page page2 = new Page(BlockUtils.fromList(bf, List.of(List.of(2, 0), List.of(4, 2)))); + + List> actual = runGroupedTopN(List.of(page1, page2), 1, elementTypes, encoders, sortOrders, new int[] { 1 }); + + // 3 groups, each with 1 value, ordered ASC: [1, 3, 4] + assertThat(actual.get(0), equalTo(List.of(1, 3, 4))); + assertThat(actual.get(1), equalTo(List.of(0, 1, 2))); + } + + public void testMultivalueGroupKey() { + BlockFactory blockFactory = driverContext().blockFactory(); + List elementTypes = List.of(LONG, LONG, LONG); + List encoders = List.of(TopNEncoder.DEFAULT_SORTABLE, DEFAULT_UNSORTABLE, DEFAULT_UNSORTABLE); + List sortOrders = List.of(new SortOrder(0, true, false)); + + Page page = new Page( + BlockUtils.fromList(blockFactory, List.of(List.of(10L, 100L, List.of(1L, 2L)), List.of(20L, 200L, 1L), List.of(30L, 300L, 3L))) + ); + + List> actual = runGroupedTopN(List.of(page), 1, elementTypes, encoders, sortOrders, new int[] { 2 }); + + // List semantics: [1,2] is one group, 1 is another, 3 is another. Sorted ASC by sort key. 
+ assertThat(actual.get(0), equalTo(List.of(10L, 20L, 30L))); // Sort key + assertThat(actual.get(1), equalTo(List.of(100L, 200L, 300L))); // Value + assertThat(actual.get(2), equalTo(List.of(List.of(1L, 2L), 1L, 3L))); // Group key (MV preserved) + } + + public void testMultivalueGroupKeyDuplicateWinner() { + BlockFactory bf = driverContext().blockFactory(); + List elementTypes = List.of(LONG, LONG, LONG); + List encoders = List.of(TopNEncoder.DEFAULT_SORTABLE, DEFAULT_UNSORTABLE, DEFAULT_UNSORTABLE); + List sortOrders = List.of(new SortOrder(0, true, false)); + + Page page = new Page(BlockUtils.fromList(bf, List.of(List.of(5L, 50L, List.of(1L, 2L)), List.of(10L, 100L, 1L)))); + + List> actual = runGroupedTopN(List.of(page), 1, elementTypes, encoders, sortOrders, new int[] { 2 }); + + // List semantics: [1,2] is one group, 1 is another. Sorted ASC by sort key. + assertThat(actual.get(0), equalTo(List.of(5L, 10L))); // Sort key + assertThat(actual.get(1), equalTo(List.of(50L, 100L))); // Value + assertThat(actual.get(2), equalTo(List.of(List.of(1L, 2L), 1L))); // Group key (MV preserved) + } + + /** + * Tests list semantics with two multivalue group keys: a row with group_key1=[1, 2] + * and group_key2=[10, 20] belongs to exactly one group keyed by ([1,2], [10,20]). + * No cartesian product expansion occurs. 
+ */ + public void testMultipleMultivalueGroupKeys() { + BlockFactory bf = driverContext().blockFactory(); + List elementTypes = List.of(LONG, LONG, LONG, LONG); + List encoders = List.of(TopNEncoder.DEFAULT_SORTABLE, DEFAULT_UNSORTABLE, DEFAULT_UNSORTABLE, DEFAULT_UNSORTABLE); + List sortOrders = List.of(new SortOrder(0, true, false)); + + Page page = new Page( + BlockUtils.fromList( + bf, + List.of(List.of(10L, 100L, List.of(1L, 2L), List.of(10L, 20L)), List.of(5L, 50L, 1L, 10L), List.of(15L, 150L, 2L, 20L)) + ) + ); + + List> actual = runGroupedTopN(List.of(page), 1, elementTypes, encoders, sortOrders, new int[] { 2, 3 }); + + // List semantics: 3 distinct groups, each with 1 row, sorted ASC by sort key + assertThat(actual.get(0), equalTo(List.of(5L, 10L, 15L))); // Sort key + assertThat(actual.get(1), equalTo(List.of(50L, 100L, 150L))); // Value + assertThat(actual.get(2), equalTo(List.of(1L, List.of(1L, 2L), 2L))); // Group key 1 (MV preserved) + assertThat(actual.get(3), equalTo(List.of(10L, List.of(10L, 20L), 20L))); // Group key 2 (MV preserved) + } + + public void testShardContextManagement_limitEqualToCount_noShardContextIsReleased() { + topNShardContextManagementAux(2, List.of(true, true, true, true)); + } + + public void testShardContextManagement_notAllShardsPassTopN_shardsAreReleased() { + topNShardContextManagementAux(1, List.of(true, false, false, true)); + } + + private void topNShardContextManagementAux(int limit, List expectedOpenAfterTopN) { + List> values = Arrays.asList( + Arrays.asList(new BlockUtils.Doc(0, 10, 100), 1L, 1L), + Arrays.asList(new BlockUtils.Doc(1, 20, 200), 2L, 2L), + Arrays.asList(new BlockUtils.Doc(2, 30, 300), null, 1L), + Arrays.asList(new BlockUtils.Doc(3, 40, 400), -3L, 2L) + ); + + List refCountedList = List.of( + new SimpleRefCounted(), + new SimpleRefCounted(), + new SimpleRefCounted(), + new SimpleRefCounted() + ); + var shardRefCounters = new IndexedByShardIdFromList<>(refCountedList); + var pages = 
topNMultipleColumns( + driverContext(), + new ListRowsBlockSourceOperator(driverContext().blockFactory(), List.of(DOC, LONG, LONG), values) { + @Override + protected TestBlockBuilder getTestBlockBuilder(int b) { + return b == 0 ? new TestBlockBuilder.DocBlockBuilder(blockFactory, shardRefCounters) : super.getTestBlockBuilder(b); + } + }, + limit, + List.of(new DocVectorEncoder(shardRefCounters), DEFAULT_UNSORTABLE, DEFAULT_UNSORTABLE), + List.of(new SortOrder(1, true, false)), + new int[] { 2 } + ); + try { + refCountedList.forEach(RefCounted::decRef); + /* The decRef above drops one reference from every counter (presumably the initial one owned by this test; confirm), so hasReferences() now only reflects references taken elsewhere, e.g. by the returned pages. */ + assertThat(refCountedList.stream().map(RefCounted::hasReferences).toList(), equalTo(expectedOpenAfterTopN)); + List> valuesAsObjects = values.stream().map(row -> row.stream().map(v -> (Object) v).toList()).toList(); + List> actual = new ArrayList<>(); + for (Page p : pages) { + actual.addAll(readAsRowsSingleValue(p)); + } + assertThat(actual, equalTo(computeTopN(valuesAsObjects, List.of(2), List.of(new SortOrder(1, true, false)), limit))); + } finally { + Releasables.close(pages); + } + /* Once the pages have been closed, no counter may have references left. */ + for (var rc : refCountedList) { + assertFalse(rc.hasReferences()); + } + } + /** Randomized test: generates random single-value blocks, picks 2-3 distinct sort columns and 2-3 distinct group-key columns (disjoint from the sort columns, all flagged in validSortKeys), runs a GroupedTopNOperator over them and compares the result with computeTopN. */ + public void testRandomMultipleColumns() { + DriverContext driverContext = driverContext(); + int rows = randomIntBetween(50, 100); + int topCount = randomIntBetween(1, 10); + int blocksCount = randomIntBetween(10, 20); + int sortingByColumns = randomIntBetween(2, 3); + int groupKeysCount = randomIntBetween(2, 3); + + RandomBlocksResult randomBlocksResult = generateRandomSingleValueBlocks(rows, blocksCount, driverContext); + + List sortColumns = new ArrayList<>(); + for (int i = 0; i < sortingByColumns; i++) { + sortColumns.add( + randomValueOtherThanMany( + c -> randomBlocksResult.validSortKeys[c] == false || sortColumns.contains(c), + () -> randomIntBetween(0, blocksCount - 1) + ) + ); + } + + List groupKeys = new ArrayList<>(); + for (int i = 0; i < groupKeysCount; i++) { + groupKeys.add( + randomValueOtherThanMany( + c -> sortColumns.contains(c) || 
groupKeys.contains(c) || randomBlocksResult.validSortKeys[c] == false, + () -> randomIntBetween(0, blocksCount - 1) + ) + ); + } + /* One SortOrder per chosen sort column, each with a random direction and a random null placement. */ + List uniqueOrders = sortColumns.stream().map(column -> new SortOrder(column, randomBoolean(), randomBoolean())).toList(); + /* All random blocks are fed to the operator as one single page. */ + List results = new TestDriverRunner().builder(driverContext) + .input(List.of(new Page(randomBlocksResult.blocks.toArray(Block[]::new))).iterator()) + .run( + new GroupedTopNOperator( + driverContext.blockFactory(), + nonBreakingBigArrays().breakerService().getBreaker("request"), + topCount, + randomBlocksResult.elementTypes, + randomBlocksResult.encoders, + uniqueOrders, + new GroupKeyEncoder( + groupKeys.stream().mapToInt(Integer::intValue).toArray(), + randomBlocksResult.elementTypes, + new BreakingBytesRefBuilder(nonBreakingBigArrays().breakerService().getBreaker("request"), "group-key-encoder") + ), + rows, + Long.MAX_VALUE + ) + ); + List> actualValues = new ArrayList<>(); + for (Page p : results) { + actualValues.addAll(readAsRowsSingleValue(p)); + p.releaseBlocks(); + } + + List> topNExpectedValues = computeTopN(randomBlocksResult.expectedValues, groupKeys, uniqueOrders, topCount); + + // We verify the output we got from the operator is sorted, but since we're asserting the results, we also need to handle ties which + // the operator doesn't actually guarantee any order for. + Comparator> sortOrderComparator = comparatorFromSortOrders(uniqueOrders); + assertThat(isSorted(actualValues, sortOrderComparator), equalTo(true)); + + // Given that sorting on repeated sort keys (especially booleans, nulls...) 
may include arbitrary rows, + // we'll only assert the expected groups and keys, by including them in this signature + Function, List> signature = row -> { + List sig = new ArrayList<>(); + for (int gk : groupKeys.stream().mapToInt(Integer::intValue).toArray()) { + sig.add(row.get(gk)); + } + for (SortOrder so : uniqueOrders) { + sig.add(row.get(so.channel())); + } + return sig; + }; + /* Compare how often each signature occurs rather than the rows themselves, so rows that tie on every sort key do not make the test flaky. */ + Map, Long> actualCounts = actualValues.stream() + .map(signature) + .collect(Collectors.groupingBy(s -> s, Collectors.counting())); + Map, Long> expectedCounts = topNExpectedValues.stream() + .map(signature) + .collect(Collectors.groupingBy(s -> s, Collectors.counting())); + + assertThat(actualCounts, equalTo(expectedCounts)); + } + /** Runs the given pages through a GroupedTopNOperator built from the supplied limit, element types, encoders, sort orders and group-key channels (random page size, Long.MAX_VALUE as the last constructor argument) and returns whatever readInto accumulated from the output pages; callers index the result by channel first. */ + private List> runGroupedTopN( + List pages, + int topCount, + List elementTypes, + List encoders, + List sortOrders, + int[] groupKeys + ) { + DriverContext driverContext = driverContext(); + List> actual = new ArrayList<>(); + try ( + Driver driver = TestDriverFactory.create( + driverContext, + new CannedSourceOperator(pages.iterator()), + List.of( + new GroupedTopNOperator( + driverContext.blockFactory(), + nonBreakingBigArrays().breakerService().getBreaker("request"), + topCount, + elementTypes, + encoders, + sortOrders, + new GroupKeyEncoder( + groupKeys, + elementTypes, + new BreakingBytesRefBuilder(nonBreakingBigArrays().breakerService().getBreaker("request"), "group-key-encoder") + ), + randomPageSize(), + Long.MAX_VALUE + ) + ), + new PageConsumerOperator(p -> readInto(actual, p)) + ) + ) { + new TestDriverRunner().run(driver); + } + return actual; + } + /** Builds a row comparator that applies the sort orders in sequence, comparing the values found at each order's channel and returning the first non-zero result (0 when all of them tie). It also asserts that both rows have the same size. */ + private static Comparator> comparatorFromSortOrders(List sortOrders) { + return (row1, row2) -> { + assertEquals(row1.size(), row2.size()); + for (SortOrder order : sortOrders) { + int cmp = compareValues(order).compare(row1.get(order.channel()), row2.get(order.channel())); + if (cmp != 0) { + return cmp; + } + } + return 0; + }; + } + /** Returns true when every pair of adjacent rows is in non-decreasing order according to the comparator; lists with fewer than two rows are always sorted. */ + private static boolean isSorted(List> values, 
Comparator> comparator) { + return IntStream.range(1, values.size()).allMatch(i -> comparator.compare(values.get(i - 1), values.get(i)) <= 0); + } + /** Returns the comparator for single values under the given sort order: CASTING_COMPARATOR, reversed when the order is descending, wrapped so that nulls sort first or last according to order.nullsFirst(). The null placement is applied after the reversal, so it does not depend on the direction. */ + private static Comparator compareValues(SortOrder order) { + Comparator baseComparator = order.asc() ? CASTING_COMPARATOR : CASTING_COMPARATOR.reversed(); + return order.nullsFirst() ? Comparator.nullsFirst(baseComparator) : Comparator.nullsLast(baseComparator); + } + /** Compares two values by casting the first one to Comparable and calling compareTo with the second; the cast is unchecked and nulls must be handled by the caller (see compareValues). */ + @SuppressWarnings("unchecked") + private static final Comparator CASTING_COMPARATOR = (o1, o2) -> ((Comparable) o1).compareTo(o2); + /** Tuple convenience overload: turns every tuple into the row [v1, v2], groups by channel 1 (v2), sorts on channel 0 (v1) in the requested direction with nullsFirst=false, and maps the resulting rows back to tuples of Long. */ + private List> computeTopN(List> inputValues, int limit, boolean ascendingOrder) { + List> rows = inputValues.stream().map(e -> Arrays.asList(e.v1(), e.v2())).toList(); + return computeTopN(rows, List.of(1), List.of(new SortOrder(0, ascendingOrder, false)), limit).stream() + .map(l -> Tuple.tuple((Long) l.get(0), (Long) l.get(1))) + .toList(); + } + /** Reference implementation of the grouped top-N: groups the rows by the values found at groupChannels, keeps the first limit rows of each group after sorting with the comparator derived from sortOrders, then concatenates the groups and sorts the combined list with the same comparator. */ + private List> computeTopN( + List> inputValues, + List groupChannels, + List sortOrders, + int limit + ) { + Comparator> comparator = comparatorFromSortOrders(sortOrders); + + Map, List>> grouped = inputValues.stream() + .collect(Collectors.groupingBy(row -> groupChannels.stream().map(row::get).toList())); + + List> topNExpectedValues = new ArrayList<>(); + for (List> groupRows : grouped.values()) { + List> sortedGroup = groupRows.stream().sorted(comparator).limit(limit).toList(); + topNExpectedValues.addAll(sortedGroup); + } + topNExpectedValues.sort(comparator); + return topNExpectedValues; + } + +} diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/TopNOperatorTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/TopNOperatorTests.java index 2952c2a3bff0f..a59752553d3c9 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/TopNOperatorTests.java +++ 
b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/TopNOperatorTests.java @@ -9,11 +9,14 @@ import org.apache.lucene.document.InetAddressPoint; import org.apache.lucene.tests.util.RamUsageTester; +import org.apache.lucene.util.Accountable; import org.apache.lucene.util.BytesRef; import org.elasticsearch.common.breaker.CircuitBreaker; import org.elasticsearch.common.breaker.CircuitBreakingException; import org.elasticsearch.common.network.NetworkAddress; import org.elasticsearch.common.unit.ByteSizeValue; +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.common.util.BytesRefHashTable; import org.elasticsearch.common.util.LimitedBreaker; import org.elasticsearch.common.util.PageCacheRecycler; import org.elasticsearch.compute.data.Block; @@ -26,7 +29,6 @@ import org.elasticsearch.compute.data.LongBlock; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.lucene.AlwaysReferencedIndexedByShardId; -import org.elasticsearch.compute.lucene.IndexedByShardId; import org.elasticsearch.compute.lucene.IndexedByShardIdFromList; import org.elasticsearch.compute.operator.CountingCircuitBreaker; import org.elasticsearch.compute.operator.Driver; @@ -35,6 +37,7 @@ import org.elasticsearch.compute.operator.PageConsumerOperator; import org.elasticsearch.compute.operator.SourceOperator; import org.elasticsearch.compute.operator.topn.TopNOperator.InputOrdering; +import org.elasticsearch.compute.test.AbstractTypedBlockSourceOperator; import org.elasticsearch.compute.test.CannedSourceOperator; import org.elasticsearch.compute.test.OperatorTestCase; import org.elasticsearch.compute.test.TestBlockBuilder; @@ -42,15 +45,14 @@ import org.elasticsearch.compute.test.TestDriverFactory; import org.elasticsearch.compute.test.TestDriverRunner; import org.elasticsearch.compute.test.operator.blocksource.SequenceLongBlockSourceOperator; -import 
org.elasticsearch.compute.test.operator.blocksource.TupleAbstractBlockSourceOperator; import org.elasticsearch.compute.test.operator.blocksource.TupleDocLongBlockSourceOperator; import org.elasticsearch.compute.test.operator.blocksource.TupleLongLongBlockSourceOperator; +import org.elasticsearch.core.Nullable; import org.elasticsearch.core.RefCounted; import org.elasticsearch.core.Releasables; import org.elasticsearch.core.SimpleRefCounted; import org.elasticsearch.core.Tuple; import org.elasticsearch.indices.CrankyCircuitBreakerService; -import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.ListMatcher; import org.elasticsearch.xpack.versionfield.Version; import org.hamcrest.Matcher; @@ -68,6 +70,7 @@ import java.util.Map; import java.util.Set; import java.util.function.BiFunction; +import java.util.function.Function; import java.util.function.Supplier; import java.util.stream.Collectors; import java.util.stream.IntStream; @@ -91,6 +94,7 @@ import static org.elasticsearch.compute.operator.topn.TopNEncoder.DEFAULT_SORTABLE; import static org.elasticsearch.compute.operator.topn.TopNEncoder.DEFAULT_UNSORTABLE; import static org.elasticsearch.compute.operator.topn.TopNEncoder.UTF8; +import static org.elasticsearch.compute.operator.topn.TopNEncoderTests.randomPointAsWKB; import static org.elasticsearch.compute.test.BlockTestUtils.append; import static org.elasticsearch.compute.test.BlockTestUtils.randomValue; import static org.elasticsearch.compute.test.BlockTestUtils.readInto; @@ -101,13 +105,43 @@ import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.greaterThan; import static org.hamcrest.Matchers.greaterThanOrEqualTo; -import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.lessThan; -import static org.hamcrest.Matchers.lessThanOrEqualTo; import static org.hamcrest.Matchers.not; import static org.hamcrest.Matchers.nullValue; +import static 
org.hamcrest.number.OrderingComparison.lessThanOrEqualTo; public class TopNOperatorTests extends OperatorTestCase { + protected final int pageSize = randomPageSize(); + + /** + * Accumulator for {@link RamUsageTester} that excludes shared objects not owned by the operator. + */ + protected static final RamUsageTester.Accumulator RAM_USAGE_ACCUMULATOR = new RamUsageTester.Accumulator() { + @Override + public long accumulateObject(Object o, long shallowSize, Map fieldValues, Collection queue) { + if (o instanceof ElementType) { + return 0; // shared + } + if (o instanceof TopNEncoder) { + return 0; // shared + } + if (o instanceof CircuitBreaker) { + return 0; // shared + } + if (o instanceof BlockFactory) { + return 0; // shared + } + if (o instanceof BigArrays) { + return 0; // shared + } + if (o instanceof BytesRefHashTable h) { + return h.ramBytesUsed(); + } + return super.accumulateObject(o, shallowSize, fieldValues, queue); + } + }; + // versions taken from org.elasticsearch.xpack.versionfield.VersionTests private static final List VERSIONS = List.of( "1", @@ -148,10 +182,12 @@ public class TopNOperatorTests extends OperatorTestCase { "1.2.3-rc1" ); - private final int pageSize = randomPageSize(); + protected int[] groupKeys() { + return new int[0]; + } @Override - protected TopNOperator.TopNOperatorFactory simple(SimpleOptions options) { + protected Operator.OperatorFactory simple(SimpleOptions options) { List elementTypes = List.of(LONG); List encoders = List.of(DEFAULT_UNSORTABLE); List sortOrders = List.of(new TopNOperator.SortOrder(0, true, false)); @@ -188,21 +224,7 @@ protected Matcher expectedToStringOfSimple() { @Override protected SourceOperator simpleInput(BlockFactory blockFactory, int size) { - return new SequenceLongBlockSourceOperator( - blockFactory, - LongStream.range(0, size).map(l -> ESTestCase.randomLong()), - between(1, size * 2) - ); - } - - protected SourceOperator simpleInput(BlockFactory blockFactory, int size, InputOrdering 
sortedInput) { - var longs = LongStream.range(0, size).map(l -> ESTestCase.randomLong()); - - return new SequenceLongBlockSourceOperator( - blockFactory, - sortedInput == InputOrdering.SORTED ? longs.sorted() : longs, - between(1, size * 2) - ); + return new SequenceLongBlockSourceOperator(blockFactory, LongStream.range(0, size).map(l -> randomLong()), between(1, size * 2)); } @Override @@ -228,54 +250,41 @@ protected void assertSimpleOutput(List input, List results) { ); } - public void testRamBytesUsed() { - for (var sortedInput : InputOrdering.values()) { - RamUsageTester.Accumulator acc = new RamUsageTester.Accumulator() { - @Override - public long accumulateObject(Object o, long shallowSize, Map fieldValues, Collection queue) { - if (o instanceof ElementType) { - return 0; // shared - } - if (o instanceof TopNEncoder) { - return 0; // shared - } - if (o instanceof CircuitBreaker) { - return 0; // shared - } - if (o instanceof BlockFactory) { - return 0; // shard - } - return super.accumulateObject(o, shallowSize, fieldValues, queue); - } - }; - int topCount = 10_000; - // We under-count by a few bytes because of the lists. In that end that's fine, but we need to account for it here. 
- long underCount = 200; - DriverContext context = driverContext(); - try ( - TopNOperator op = new TopNOperator.TopNOperatorFactory( - topCount, - List.of(LONG), - List.of(DEFAULT_UNSORTABLE), - List.of(new TopNOperator.SortOrder(0, true, false)), - pageSize, - randomJumboPageBytes(), - sortedInput, - null - ).get(context) - ) { - long actualEmpty = RamUsageTester.ramUsed(op, acc); - assertThat(op.ramBytesUsed(), both(greaterThan(actualEmpty - underCount)).and(lessThan(actualEmpty))); - // But when we fill it then we're quite close - for (Page p : CannedSourceOperator.collectPages(simpleInput(context.blockFactory(), topCount, sortedInput))) { - op.addInput(p); - } - long actualFull = RamUsageTester.ramUsed(op, acc); - assertThat(op.ramBytesUsed(), both(greaterThan(actualFull - underCount)).and(lessThan(actualFull))); - - // TODO empty it again and check. - } + /** + * Creates the appropriate top-N operator factory for the test subclass. + */ + protected Operator.OperatorFactory createTopNOperatorFactory( + int topCount, + List elementTypes, + List encoders, + List sortOrders, + int[] groupKeys, + int maxPageSize, + long jumboPageBytes, + TopNOperator.InputOrdering inputOrdering, + @Nullable SharedMinCompetitive.Supplier minCompetitive + ) { + if (groupKeys.length > 0) { + return new GroupedTopNOperator.GroupedTopNOperatorFactory( + topCount, + elementTypes, + encoders, + sortOrders, + IntStream.of(groupKeys).boxed().toList(), + maxPageSize, + jumboPageBytes + ); } + return new TopNOperator.TopNOperatorFactory( + topCount, + elementTypes, + encoders, + sortOrders, + maxPageSize, + jumboPageBytes, + inputOrdering, + minCompetitive + ); } public void testRandomTopN() { @@ -294,11 +303,10 @@ public void testRandomTopNCranky() { } } - private void testRandomTopN(boolean asc, DriverContext context) { + protected void testRandomTopN(boolean asc, DriverContext context) { int limit = randomIntBetween(1, 20); - List inputValues = randomList(0, 5000, ESTestCase::randomLong); 
- Comparator comparator = asc ? naturalOrder() : reverseOrder(); - List expectedValues = inputValues.stream().sorted(comparator).limit(limit).toList(); + List inputValues = randomList(0, 5000, () -> randomLong()); + List expectedValues = inputValues.stream().sorted(asc ? naturalOrder() : reverseOrder()).limit(limit).toList(); List outputValues = topNLong(context, inputValues, limit, asc, false); assertThat(outputValues, equalTo(expectedValues)); } @@ -331,39 +339,44 @@ private void testTopNSortedInput( boolean asc, boolean nullsFirst ) { - List pages = new ArrayList<>(pagesValues.size()); + var layout = ChannelLayout.forDataColumns(1, groupKeys()); + List pages = new ArrayList<>(pagesValues.size()); for (var pageValues : pagesValues) { assert isSorted(pageValues, asc, nullsFirst); - List blocks = new ArrayList<>(pageValues.size()); - try (Block.Builder column = INT.newBlockBuilder(8, driverContext().blockFactory());) { + try (Block.Builder column = INT.newBlockBuilder(8, driverContext().blockFactory())) { for (var value : pageValues) { append(column, value); } - blocks.add(column.build()); + pages.add(new Page(layout.buildPageBlocks(pageValues.size(), driverContext().blockFactory(), column.build()))); } - pages.add(new Page(blocks.toArray(Block[]::new))); } List> actual = new ArrayList<>(); DriverContext driverContext = driverContext(); + List elementTypes = new ArrayList<>(); + List encoders = new ArrayList<>(); + for (int ch = 0; ch < layout.totalChannels(); ch++) { + elementTypes.add(INT); + encoders.add(layout.groupKeySet().contains(ch) ? 
DEFAULT_UNSORTABLE : DEFAULT_SORTABLE); + } + try ( Driver driver = TestDriverFactory.create( driverContext, new CannedSourceOperator(pages.iterator()), List.of( - new TopNOperator( - driverContext.blockFactory(), - nonBreakingBigArrays().breakerService().getBreaker("request"), + createTopNOperatorFactory( topNCount, - List.of(INT), - List.of(DEFAULT_SORTABLE), - List.of(new TopNOperator.SortOrder(0, asc, nullsFirst)), - pageSize, + elementTypes, + encoders, + List.of(new TopNOperator.SortOrder(layout.dataChannel(0), asc, nullsFirst)), + groupKeys(), + randomPageSize(), randomJumboPageBytes(), InputOrdering.SORTED, null - ) + ).get(driverContext) ), new PageConsumerOperator(p -> readInto(actual, p)) ) @@ -371,7 +384,7 @@ private void testTopNSortedInput( new TestDriverRunner().run(driver); } - assertThat(actual, equalTo(expectedResult)); + assertThat(layout.extractDataColumns(actual), equalTo(expectedResult)); } private void testTopNSortedInputWithTwoColumns( @@ -381,11 +394,10 @@ private void testTopNSortedInputWithTwoColumns( boolean asc, boolean nullsFirst ) { - List pages = new ArrayList<>(pagesValues.size()); + var layout = ChannelLayout.forDataColumns(2, groupKeys()); + List pages = new ArrayList<>(pagesValues.size()); for (var pageValues : pagesValues) { - List blocks = new ArrayList<>(pageValues.size()); - try ( Block.Builder firstColumn = INT.newBlockBuilder(8, driverContext().blockFactory()); Block.Builder secondColumn = INT.newBlockBuilder(8, driverContext().blockFactory()); @@ -394,31 +406,42 @@ private void testTopNSortedInputWithTwoColumns( append(firstColumn, value.v1()); append(secondColumn, value.v2()); } - blocks.add(firstColumn.build()); - blocks.add(secondColumn.build()); + pages.add( + new Page( + layout.buildPageBlocks(pageValues.size(), driverContext().blockFactory(), firstColumn.build(), secondColumn.build()) + ) + ); } - pages.add(new Page(blocks.toArray(Block[]::new))); } List> actual = new ArrayList<>(); DriverContext driverContext = 
driverContext(); + List elementTypes = new ArrayList<>(); + List encoders = new ArrayList<>(); + for (int ch = 0; ch < layout.totalChannels(); ch++) { + elementTypes.add(INT); + encoders.add(layout.groupKeySet().contains(ch) ? DEFAULT_UNSORTABLE : DEFAULT_SORTABLE); + } + try ( Driver driver = TestDriverFactory.create( driverContext, new CannedSourceOperator(pages.iterator()), List.of( - new TopNOperator( - driverContext.blockFactory(), - nonBreakingBigArrays().breakerService().getBreaker("request"), + createTopNOperatorFactory( topNCount, - List.of(INT, INT), - List.of(DEFAULT_SORTABLE, DEFAULT_SORTABLE), - List.of(new TopNOperator.SortOrder(0, asc, nullsFirst), new TopNOperator.SortOrder(1, asc, nullsFirst)), - pageSize, + elementTypes, + encoders, + List.of( + new TopNOperator.SortOrder(layout.dataChannel(0), asc, nullsFirst), + new TopNOperator.SortOrder(layout.dataChannel(1), asc, nullsFirst) + ), + groupKeys(), + randomPageSize(), randomJumboPageBytes(), InputOrdering.SORTED, null - ) + ).get(driverContext) ), new PageConsumerOperator(p -> readInto(actual, p)) ) @@ -426,24 +449,39 @@ private void testTopNSortedInputWithTwoColumns( new TestDriverRunner().run(driver); } - assertThat(actual, equalTo(expectedResult)); + assertThat(layout.extractDataColumns(actual), equalTo(expectedResult)); } public final void testTopNWithSortedInputToString() { - var topN = new TopNOperator.TopNOperatorFactory( + var topN = createTopNOperatorFactory( 4, List.of(LONG), List.of(DEFAULT_UNSORTABLE), List.of(new TopNOperator.SortOrder(0, true, false)), + groupKeys(), pageSize, randomJumboPageBytes(), InputOrdering.SORTED, null ); - var expectedDescription = "TopNOperator[count=0/4, elementTypes=[LONG], encoders=[DefaultUnsortable], " - + "sortOrders=[SortOrder[channel=0, asc=true, nullsFirst=false]], inputOrdering=SORTED]"; - try (Operator operator = topN.get(driverContext())) { - assertThat(operator.toString(), equalTo(expectedDescription)); + if (groupKeys().length > 0) { + var 
expectedDescription = "GroupedTopNOperator[count=0/0/4" + + ", elementTypes=[LONG], encoders=[DefaultUnsortable], " + + "sortOrders=[SortOrder[channel=0, asc=true, nullsFirst=false]]" + + ", groupKeys=" + + Arrays.toString(groupKeys()) + + "]"; + try (Operator operator = topN.get(driverContext())) { + assertThat(operator.toString(), equalTo(expectedDescription)); + } + } else { + var expectedDescription = "TopNOperator[count=0/4" + + ", elementTypes=[LONG], encoders=[DefaultUnsortable], " + + "sortOrders=[SortOrder[channel=0, asc=true, nullsFirst=false]]" + + ", inputOrdering=SORTED]"; + try (Operator operator = topN.get(driverContext())) { + assertThat(operator.toString(), equalTo(expectedDescription)); + } } } @@ -699,7 +737,11 @@ public void testBasicTopN() { assertThat(topNLong(values, 100, false, true), equalTo(Arrays.asList(null, null, 100L, 20L, 10L, 5L, 4L, 4L, 2L, 1L))); } - private List topNLong( + protected List topNLong(List inputValues, int limit, boolean ascendingOrder, boolean nullsFirst) { + return topNLong(driverContext(), inputValues, limit, ascendingOrder, nullsFirst); + } + + protected List topNLong( DriverContext driverContext, List inputValues, int limit, @@ -711,7 +753,8 @@ private List topNLong( inputValues.stream().map(v -> tuple(v, 0L)).toList(), limit, List.of(DEFAULT_UNSORTABLE, DEFAULT_UNSORTABLE), - List.of(new TopNOperator.SortOrder(0, ascendingOrder, nullsFirst)) + List.of(new TopNOperator.SortOrder(0, ascendingOrder, nullsFirst)), + groupKeys() ).stream().map(Tuple::v1).toList(); } @@ -719,10 +762,6 @@ private static TupleLongLongBlockSourceOperator longLongSourceOperator(DriverCon return new TupleLongLongBlockSourceOperator(driverContext.blockFactory(), values, randomIntBetween(1, 1000)); } - private List topNLong(List inputValues, int limit, boolean ascendingOrder, boolean nullsFirst) { - return topNLong(driverContext(), inputValues, limit, ascendingOrder, nullsFirst); - } - public void testCompareInts() { BlockFactory 
blockFactory = blockFactory(); testCompare( @@ -813,14 +852,14 @@ private void testCompare(Page page, ElementType elementType, TopNEncoder encoder for (int b = 0; b < page.getBlockCount(); b++) { // Non-null identity for (int p = 0; p < page.getPositionCount(); p++) { - TopNOperator.Row row = row(elementType, encoder, b, randomBoolean(), randomBoolean(), page, p); + TopNRow row = row(elementType, encoder, b, randomBoolean(), randomBoolean(), page, p); assertThat(row, equalTo(row)); assertThat(row.compareTo(row), equalTo(0)); } // Null identity for (int p = 0; p < page.getPositionCount(); p++) { - TopNOperator.Row row = row(elementType, encoder, b, randomBoolean(), randomBoolean(), nullPage, p); + TopNRow row = row(elementType, encoder, b, randomBoolean(), randomBoolean(), nullPage, p); assertThat(row, equalTo(row)); assertThat(row.compareTo(row), equalTo(0)); } @@ -828,8 +867,8 @@ private void testCompare(Page page, ElementType elementType, TopNEncoder encoder // nulls first for (int p = 0; p < page.getPositionCount(); p++) { boolean asc = randomBoolean(); - TopNOperator.Row nonNullRow = row(elementType, encoder, b, asc, true, page, p); - TopNOperator.Row nullRow = row(elementType, encoder, b, asc, true, nullPage, p); + TopNRow nonNullRow = row(elementType, encoder, b, asc, true, page, p); + TopNRow nullRow = row(elementType, encoder, b, asc, true, nullPage, p); assertThat(nonNullRow, not(equalTo(nullRow))); assertThat(nonNullRow, lessThan(nullRow)); assertThat(nullRow, greaterThan(nonNullRow)); @@ -838,8 +877,8 @@ private void testCompare(Page page, ElementType elementType, TopNEncoder encoder // nulls last for (int p = 0; p < page.getPositionCount(); p++) { boolean asc = randomBoolean(); - TopNOperator.Row nonNullRow = row(elementType, encoder, b, asc, false, page, p); - TopNOperator.Row nullRow = row(elementType, encoder, b, asc, false, nullPage, p); + TopNRow nonNullRow = row(elementType, encoder, b, asc, false, page, p); + TopNRow nullRow = row(elementType, 
encoder, b, asc, false, nullPage, p); assertThat(nonNullRow, not(equalTo(nullRow))); assertThat(nonNullRow, greaterThan(nullRow)); assertThat(nullRow, lessThan(nonNullRow)); @@ -848,8 +887,8 @@ private void testCompare(Page page, ElementType elementType, TopNEncoder encoder // ascending { boolean nullsFirst = randomBoolean(); - TopNOperator.Row r1 = row(elementType, encoder, b, true, nullsFirst, page, 0); - TopNOperator.Row r2 = row(elementType, encoder, b, true, nullsFirst, page, 1); + TopNRow r1 = row(elementType, encoder, b, true, nullsFirst, page, 0); + TopNRow r2 = row(elementType, encoder, b, true, nullsFirst, page, 1); assertThat(r1, not(equalTo(r2))); assertThat(r1, greaterThan(r2)); assertThat(r2, lessThan(r1)); @@ -857,8 +896,8 @@ private void testCompare(Page page, ElementType elementType, TopNEncoder encoder // descending { boolean nullsFirst = randomBoolean(); - TopNOperator.Row r1 = row(elementType, encoder, b, false, nullsFirst, page, 0); - TopNOperator.Row r2 = row(elementType, encoder, b, false, nullsFirst, page, 1); + TopNRow r1 = row(elementType, encoder, b, false, nullsFirst, page, 0); + TopNRow r2 = row(elementType, encoder, b, false, nullsFirst, page, 1); assertThat(r1, not(equalTo(r2))); assertThat(r1, lessThan(r2)); assertThat(r2, greaterThan(r1)); @@ -867,7 +906,7 @@ private void testCompare(Page page, ElementType elementType, TopNEncoder encoder page.releaseBlocks(); } - private TopNOperator.Row row( + private TopNRow row( ElementType elementType, TopNEncoder encoder, int channel, @@ -886,7 +925,7 @@ private TopNOperator.Row row( channelInKey, page ); - TopNOperator.Row row = new TopNOperator.Row(nonBreakingBigArrays().breakerService().getBreaker("request"), 0, 0); + TopNRow row = new TopNRow(nonBreakingBigArrays().breakerService().getBreaker("request"), 0, 0); rf.writeKey(position, row); rf.writeValues(position, row); return row; @@ -900,7 +939,8 @@ public void testTopNTwoColumns() { values, 5, List.of(TopNEncoder.DEFAULT_SORTABLE, 
TopNEncoder.DEFAULT_SORTABLE), - List.of(new TopNOperator.SortOrder(0, true, false), new TopNOperator.SortOrder(1, true, false)) + List.of(new TopNOperator.SortOrder(0, true, false), new TopNOperator.SortOrder(1, true, false)), + groupKeys() ), equalTo(List.of(tuple(1L, 1L), tuple(1L, 2L), tuple(1L, null), tuple(null, 1L), tuple(null, null))) ); @@ -910,7 +950,8 @@ public void testTopNTwoColumns() { values, 5, List.of(TopNEncoder.DEFAULT_SORTABLE, TopNEncoder.DEFAULT_SORTABLE), - List.of(new TopNOperator.SortOrder(0, true, true), new TopNOperator.SortOrder(1, true, false)) + List.of(new TopNOperator.SortOrder(0, true, true), new TopNOperator.SortOrder(1, true, false)), + groupKeys() ), equalTo(List.of(tuple(null, 1L), tuple(null, null), tuple(1L, 1L), tuple(1L, 2L), tuple(1L, null))) ); @@ -920,7 +961,8 @@ public void testTopNTwoColumns() { values, 5, List.of(TopNEncoder.DEFAULT_SORTABLE, TopNEncoder.DEFAULT_SORTABLE), - List.of(new TopNOperator.SortOrder(0, true, false), new TopNOperator.SortOrder(1, true, true)) + List.of(new TopNOperator.SortOrder(0, true, false), new TopNOperator.SortOrder(1, true, true)), + groupKeys() ), equalTo(List.of(tuple(1L, null), tuple(1L, 1L), tuple(1L, 2L), tuple(null, null), tuple(null, 1L))) ); @@ -933,13 +975,11 @@ public void testCollectAllValues() { int size = 10; int topCount = 3; List blocks = new ArrayList<>(); - List> expectedTop = new ArrayList<>(); + List> rawValues = new ArrayList<>(); IntBlock keys = blockFactory.newIntArrayVector(IntStream.range(0, size).toArray(), size).asBlock(); - List topKeys = new ArrayList<>(IntStream.range(size - topCount, size).boxed().toList()); - Collections.reverse(topKeys); - expectedTop.add(topKeys); blocks.add(keys); + rawValues.add(IntStream.range(0, size).mapToObj(Integer::valueOf).toList()); List elementTypes = new ArrayList<>(); List encoders = new ArrayList<>(); @@ -954,39 +994,36 @@ public void testCollectAllValues() { } elementTypes.add(e); encoders.add(nonKeyEncoder(e)); - List 
eTop = new ArrayList<>(); try (Block.Builder builder = e.newBlockBuilder(size, driverContext().blockFactory())) { + var rawValuesForElement = new ArrayList<>(size); for (int i = 0; i < size; i++) { Object value = randomValue(e); append(builder, value); - if (i >= size - topCount) { - eTop.add(value); - } + rawValuesForElement.add(value); } - Collections.reverse(eTop); blocks.add(builder.build()); - expectedTop.add(eTop); + rawValues.add(rawValuesForElement); } } List> actualTop = new ArrayList<>(); + List sortOrders = List.of(new TopNOperator.SortOrder(0, false, false)); try ( Driver driver = TestDriverFactory.create( driverContext, new CannedSourceOperator(List.of(new Page(blocks.toArray(Block[]::new))).iterator()), List.of( - new TopNOperator( - blockFactory, - nonBreakingBigArrays().breakerService().getBreaker("request"), + createTopNOperatorFactory( topCount, elementTypes, encoders, - List.of(new TopNOperator.SortOrder(0, false, false)), - pageSize, + sortOrders, + groupKeys(), + randomPageSize(), randomJumboPageBytes(), InputOrdering.NOT_SORTED, null - ) + ).get(driverContext) ), new PageConsumerOperator(page -> readInto(actualTop, page)) ) @@ -994,10 +1031,55 @@ public void testCollectAllValues() { new TestDriverRunner().run(driver); } - assertMap(actualTop, matchesList(expectedTop)); + assertMap(actualTop, matchesList(expectedTop(rawValues, sortOrders, topCount))); assertDriverContext(driverContext); } + /** Computes the expected top-N for channel-oriented input (one list per channel) and returns it channel-oriented as well. */ protected List> expectedTop(List> input, List sortOrders, int topCount) { + // The input is channel-oriented: transpose it to row-oriented for processing, then transpose the result back.
+ return transpose(expectedTopRowOriented(transpose(input), sortOrders, topCount)); + } + /** Sorts the row-oriented input with a comparator derived from the sort orders and returns the first topCount rows. NOTE(review): protected, presumably so that subclasses can supply different expectations; confirm. */ + protected List> expectedTopRowOriented( + List> rowOriented, + List sortOrders, + int topCount + ) { + Comparator> comparator = (row1, row2) -> { + for (TopNOperator.SortOrder order : sortOrders) { + Object v1 = row1.get(order.channel()); + Object v2 = row2.get(order.channel()); + boolean firstIsNull = v1 == null; + boolean secondIsNull = v2 == null; + /* Nulls are placed according to nullsFirst only, independent of asc; two nulls tie and the next sort order decides. */ + if (firstIsNull || secondIsNull) { + int nullCompare = Boolean.compare(firstIsNull, secondIsNull) * (order.nullsFirst() ? -1 : 1); + if (nullCompare != 0) { + return nullCompare; + } + continue; + } + + @SuppressWarnings("unchecked") + int cmp = ((Comparable) v1).compareTo(v2); + if (cmp != 0) { + return order.asc() ? cmp : -cmp; + } + } + return 0; + }; + + return rowOriented.stream().sorted(comparator).limit(topCount).toList(); + } + /** Swaps the two list dimensions: the element at [i][j] of the input ends up at [j][i] of the result. The size of the first inner list is used as the row count, so every inner list is assumed to be at least that long; an empty input yields an empty list. */ + private static List> transpose(List> input) { + if (input.isEmpty()) { + return new ArrayList<>(); + } + int numRows = input.getFirst().size(); + return IntStream.range(0, numRows).mapToObj(row -> input.stream().map(channel -> channel.get(row)).toList()).toList(); + } + public void testCollectAllValues_RandomMultiValues() { DriverContext driverContext = driverContext(); BlockFactory blockFactory = driverContext.blockFactory(); @@ -1006,13 +1088,11 @@ public void testCollectAllValues_RandomMultiValues() { int topCount = 3; int blocksCount = 20; List blocks = new ArrayList<>(); - List> expectedTop = new ArrayList<>(); + List> rawValues = new ArrayList<>(); IntBlock keys = blockFactory.newIntArrayVector(IntStream.range(0, rows).toArray(), rows).asBlock(); - List topKeys = new ArrayList<>(IntStream.range(rows - topCount, rows).boxed().toList()); - Collections.reverse(topKeys); - expectedTop.add(topKeys); blocks.add(keys); + rawValues.add(IntStream.range(0, rows).mapToObj(Integer::valueOf).toList()); List elementTypes = new ArrayList<>(blocksCount); List encoders = new ArrayList<>(blocksCount); 
@@ -1033,57 +1113,50 @@ public void testCollectAllValues_RandomMultiValues() { } elementTypes.add(e); encoders.add(nonKeyEncoder(e)); - List eTop = new ArrayList<>(); + List channelValues = new ArrayList<>(); try (Block.Builder builder = e.newBlockBuilder(rows, driverContext().blockFactory())) { for (int i = 0; i < rows; i++) { if (e != ElementType.DOC && e != ElementType.NULL && randomBoolean()) { // generate a multi-value block int mvCount = randomIntBetween(5, 10); - List eTopList = new ArrayList<>(mvCount); + List mvValues = new ArrayList<>(mvCount); builder.beginPositionEntry(); for (int j = 0; j < mvCount; j++) { Object value = randomValue(e); append(builder, value); - if (i >= rows - topCount) { - eTopList.add(value); - } + mvValues.add(value); } builder.endPositionEntry(); - if (i >= rows - topCount) { - eTop.add(eTopList); - } + channelValues.add(mvValues); } else { Object value = randomValue(e); append(builder, value); - if (i >= rows - topCount) { - eTop.add(value); - } + channelValues.add(value); } } - Collections.reverse(eTop); blocks.add(builder.build()); - expectedTop.add(eTop); + rawValues.add(channelValues); } } List> actualTop = new ArrayList<>(); + List sortOrders = List.of(new TopNOperator.SortOrder(0, false, false)); try ( Driver driver = TestDriverFactory.create( driverContext, new CannedSourceOperator(List.of(new Page(blocks.toArray(Block[]::new))).iterator()), List.of( - new TopNOperator( - blockFactory, - nonBreakingBigArrays().breakerService().getBreaker("request"), + createTopNOperatorFactory( topCount, elementTypes, encoders, - List.of(new TopNOperator.SortOrder(0, false, false)), - pageSize, + sortOrders, + groupKeys(), + randomPageSize(), randomJumboPageBytes(), InputOrdering.NOT_SORTED, null - ) + ).get(driverContext) ), new PageConsumerOperator(page -> readInto(actualTop, page)) ) @@ -1091,7 +1164,7 @@ public void testCollectAllValues_RandomMultiValues() { new TestDriverRunner().run(driver); } - assertMap(actualTop, 
matchesList(expectedTop)); + assertMap(actualTop, matchesList(expectedTop(rawValues, sortOrders, topCount))); assertDriverContext(driverContext); } @@ -1103,37 +1176,36 @@ private static TopNEncoder nonKeyEncoder(ElementType elementType) { }; } - private List> topNTwoLongColumns( + protected List> topNTwoLongColumns( DriverContext driverContext, List> values, int limit, List encoder, - List sortOrders + List sortOrders, + int[] groupKeys ) { - var pages = topNTwoColumns( + var pages = topNMultipleColumns( driverContext, new TupleLongLongBlockSourceOperator(driverContext.blockFactory(), values, randomIntBetween(1, 1000)), - AlwaysReferencedIndexedByShardId.INSTANCE, limit, encoder, - sortOrders + sortOrders, + groupKeys ); - var result = pageToTuples( + return pageToTuples( (block, i) -> block.isNull(i) ? null : ((LongBlock) block).getLong(i), (block, i) -> block.isNull(i) ? null : ((LongBlock) block).getLong(i), pages ); - assertThat(result, hasSize(Math.min(limit, values.size()))); - return result; } - private List topNTwoColumns( + protected List topNMultipleColumns( DriverContext driverContext, - TupleAbstractBlockSourceOperator sourceOperator, - IndexedByShardId shardRefCounters, + AbstractTypedBlockSourceOperator sourceOperator, int limit, List encoder, - List sortOrders + List sortOrders, + int[] groupKeys ) { var pages = new ArrayList(); boolean success = false; @@ -1143,18 +1215,17 @@ private List topNTwoColumns( driverContext, sourceOperator, List.of( - new TopNOperator( - driverContext.blockFactory(), - nonBreakingBigArrays().breakerService().getBreaker("request"), + createTopNOperatorFactory( limit, sourceOperator.elementTypes(), encoder, sortOrders, - pageSize, + groupKeys, + randomPageSize(), randomJumboPageBytes(), InputOrdering.NOT_SORTED, null - ) + ).get(driverContext) ), new PageConsumerOperator(pages::add) ) @@ -1171,7 +1242,7 @@ private List topNTwoColumns( return pages; } - private static List> pageToTuples( + protected static List> 
pageToTuples( BiFunction getFirstBlockValue, BiFunction getSecondBlockValue, List pages @@ -1196,12 +1267,13 @@ private static List> pageToTuples( public void testTopNManyDescriptionAndToString() { int fixedLength = between(1, 100); - TopNOperator.TopNOperatorFactory factory = new TopNOperator.TopNOperatorFactory( + Operator.OperatorFactory factory = createTopNOperatorFactory( 10, List.of(BYTES_REF, BYTES_REF), List.of(UTF8, new FixedLengthAscTopNEncoder(fixedLength)), List.of(new TopNOperator.SortOrder(0, false, false), new TopNOperator.SortOrder(1, false, true)), - pageSize, + groupKeys(), + randomPageSize(), randomJumboPageBytes(), InputOrdering.NOT_SORTED, null @@ -1209,14 +1281,28 @@ public void testTopNManyDescriptionAndToString() { String sorts = List.of("SortOrder[channel=0, asc=false, nullsFirst=false]", "SortOrder[channel=1, asc=false, nullsFirst=true]") .stream() .collect(Collectors.joining(", ")); - String tail = ", elementTypes=[BYTES_REF, BYTES_REF], encoders=[Utf8Asc, FixedLengthAsc[" - + fixedLength - + "]], sortOrders=[" - + sorts - + "], inputOrdering=NOT_SORTED]"; - assertThat(factory.describe(), equalTo("TopNOperator[count=10" + tail)); - try (Operator operator = factory.get(driverContext())) { - assertThat(operator.toString(), equalTo("TopNOperator[count=0/10" + tail)); + if (groupKeys().length > 0) { + String tail = ", elementTypes=[BYTES_REF, BYTES_REF], encoders=[Utf8Asc, FixedLengthAsc[" + + fixedLength + + "]], sortOrders=[" + + sorts + + "], groupKeys=" + + Arrays.toString(groupKeys()) + + "]"; + assertThat(factory.describe(), equalTo("GroupedTopNOperator[count=10" + tail)); + try (Operator operator = factory.get(driverContext())) { + assertThat(operator.toString(), equalTo("GroupedTopNOperator[count=0/0/10" + tail)); + } + } else { + String tail = ", elementTypes=[BYTES_REF, BYTES_REF], encoders=[Utf8Asc, FixedLengthAsc[" + + fixedLength + + "]], sortOrders=[" + + sorts + + "], inputOrdering=NOT_SORTED]"; + assertThat(factory.describe(), 
equalTo("TopNOperator[count=10" + tail)); + try (Operator operator = factory.get(driverContext())) { + assertThat(operator.toString(), equalTo("TopNOperator[count=0/10" + tail)); + } } } @@ -1468,12 +1554,10 @@ public void testRandomMultiValuesTopN() { int blocksCount = between(20, 30); int sortingByColumns = between(1, 10); - Set uniqueOrders = new LinkedHashSet<>(sortingByColumns); - List>> expectedValues = new ArrayList<>(rowsPerPage * pageCount); - List> randomValueSuppliers = new ArrayList<>(blocksCount); boolean[] validSortKeys = new boolean[blocksCount]; List elementTypes = new ArrayList<>(blocksCount); List encoders = new ArrayList<>(blocksCount); + List> randomValueSuppliers = new ArrayList<>(blocksCount); for (int type = 0; type < blocksCount; type++) { ElementType e = randomValueOtherThanMany( @@ -1519,6 +1603,13 @@ public void testRandomMultiValuesTopN() { randomValueSuppliers.add(randomValueSupplier); } + int[] gk = groupKeys(); + var layout = ChannelLayout.forDataColumns(blocksCount, gk); + if (gk.length > 0) { + layout.insertGroupKeyEntries(elementTypes, INT); + layout.insertGroupKeyEntries(encoders, DEFAULT_UNSORTABLE); + } + /* * Build sort keys, making sure not to include duplicates. This could * build fewer than the desired sort columns, but it's more important @@ -1527,30 +1618,27 @@ public void testRandomMultiValuesTopN() { * not to include sort keys that simulate geo objects. Those aren't * sortable at all. 
*/ + Set uniqueOrders = new LinkedHashSet<>(sortingByColumns); for (int i = 0; i < sortingByColumns; i++) { int column = randomValueOtherThanMany(c -> false == validSortKeys[c], () -> randomIntBetween(0, blocksCount - 1)); - uniqueOrders.add(new TopNOperator.SortOrder(column, randomBoolean(), randomBoolean())); + uniqueOrders.add(new TopNOperator.SortOrder(layout.dataChannel(column), randomBoolean(), randomBoolean())); } List sortOrders = uniqueOrders.stream().toList(); - NaiveTopNComparator comparator = new NaiveTopNComparator(sortOrders); - SharedMinCompetitive.Supplier minCompetitiveSupplier = randomBoolean() - ? new SharedMinCompetitive.Supplier(blockFactory().breaker(), keyConfigs(elementTypes, encoders, sortOrders)) - : null; - SharedMinCompetitive minCompetitive = minCompetitiveSupplier == null ? null : minCompetitiveSupplier.get(); + List>> expectedValues = new ArrayList<>(rowsPerPage * pageCount); + List>> actualValues = new ArrayList<>(); try ( - TopNOperator operator = new TopNOperator( - driverContext.blockFactory(), - nonBreakingBigArrays().breakerService().getBreaker("request"), + Operator operator = createTopNOperatorFactory( topCount, elementTypes, encoders, sortOrders, + gk, rowsPerPage, Long.MAX_VALUE, InputOrdering.NOT_SORTED, - minCompetitiveSupplier - ) + null + ).get(driverContext) ) { for (int p = 0; p < pageCount; p++) { assertThat(operator.needsInput(), equalTo(true)); @@ -1558,40 +1646,30 @@ public void testRandomMultiValuesTopN() { assertThat(operator.getOutput(), nullValue()); for (int r = 0; r < rowsPerPage; r++) { - expectedValues.add(new ArrayList<>(blocksCount)); + expectedValues.add(new ArrayList<>(blocksCount + gk.length)); } - Block[] blocks = new Block[blocksCount]; + Block[] dataBlocks = new Block[blocksCount]; for (int b = 0; b < blocksCount; b++) { - ElementType elementType = elementTypes.get(b); + ElementType elementType = elementTypes.get(layout.dataChannel(b)); try (Block.Builder builder = 
elementType.newBlockBuilder(rowsPerPage, driverContext().blockFactory())) { List previousValue = null; - for (int r = 0; r < rowsPerPage; r++) { List values = new ArrayList<>(); - // let's make things a bit more real for this TopN sorting: have some "equal" values in different rows for the - // same - // block if (rarely() && previousValue != null) { values = previousValue; } else { if (elementType != ElementType.NULL && randomBoolean()) { - // generate a multi-value block int mvCount = randomIntBetween(5, 10); for (int j = 0; j < mvCount; j++) { - Object value = randomValueSuppliers.get(b).get(); - values.add(value); + values.add(randomValueSuppliers.get(b).get()); } - } else {// null or single-valued value - Object value = randomValueSuppliers.get(b).get(); - values.add(value); + } else { + values.add(randomValueSuppliers.get(b).get()); } - if (usually() && randomBoolean()) { - // let's remember the "previous" value, maybe we'll use it again in a different row previousValue = values; } } - if (values.size() == 1) { append(builder, values.get(0)); } else { @@ -1601,59 +1679,45 @@ public void testRandomMultiValuesTopN() { } builder.endPositionEntry(); } - expectedValues.get(p * rowsPerPage + r).add(values); } - blocks[b] = builder.build(); + dataBlocks[b] = builder.build(); } } - operator.addInput(new Page(blocks)); - - if (minCompetitive != null) { - if ((p + 1) * rowsPerPage < topCount) { - assertThat(minCompetitive.get(blockFactory()), nullValue()); - } else { - List> minCompetitiveRow = expectedValues.stream() - .sorted(comparator) - .skip(topCount - 1) - .findFirst() - .get(); - try (Page min = minCompetitive.get(blockFactory())) { - assertThat(min.getBlockCount(), equalTo(sortOrders.size())); - for (int s = 0; s < min.getBlockCount(); s++) { - logger.info("checking key {}", s); - TopNOperator.SortOrder sort = sortOrders.get(s); - Object actual = BlockUtils.toJavaObject(min.getBlock(s), 0); - Object expected = reduceKey(minCompetitiveRow.get(sort.channel()), 
sort.asc()); - assertThat(actual, equalTo(expected)); - } - } + if (gk.length > 0) { + for (int r = 0; r < rowsPerPage; r++) { + layout.insertGroupKeyEntries(expectedValues.get(p * rowsPerPage + r), List.of(0)); } } + Block[] pageBlocks = gk.length > 0 + ? layout.buildPageBlocks(rowsPerPage, driverContext.blockFactory(), dataBlocks) + : dataBlocks; + operator.addInput(new Page(pageBlocks)); } operator.finish(); assertThat(operator.needsInput(), equalTo(false)); assertThat(operator.isFinished(), equalTo(false)); - - List>> actualValues = new ArrayList<>(); while (operator.isFinished() == false) { - try (Page p = operator.getOutput()) { + Page p = operator.getOutput(); + assertThat(operator.needsInput(), equalTo(false)); + if (p != null) { readAsRows(actualValues, p); - assertThat(operator.needsInput(), equalTo(false)); + p.releaseBlocks(); } } + } - List>> topNExpectedValues = expectedValues.stream().sorted(comparator).limit(topCount).toList(); - List> actualReducedValues = extractAndReduceSortedValues(actualValues, uniqueOrders); - List> expectedReducedValues = extractAndReduceSortedValues(topNExpectedValues, uniqueOrders); + List>> topNExpectedValues = expectedValues.stream() + .sorted(new NaiveTopNComparator(sortOrders)) + .limit(topCount) + .toList(); + List> actualReducedValues = extractAndReduceSortedValues(actualValues, uniqueOrders); + List> expectedReducedValues = extractAndReduceSortedValues(topNExpectedValues, uniqueOrders); - assertMap(actualReducedValues, matchesList(expectedReducedValues)); - } finally { - Releasables.close(minCompetitive); - } + assertMap(actualReducedValues, matchesList(expectedReducedValues)); } - private List keyConfigs( + protected List keyConfigs( List elementTypes, List encoders, List sortOrders @@ -2040,56 +2104,8 @@ public void testRowResizes() { } } - public void testShardContextManagement_limitEqualToCount_noShardContextIsReleased() { - topNShardContextManagementAux(4, Stream.generate(() -> true).limit(4).toList()); - } - - 
public void testShardContextManagement_notAllShardsPassTopN_shardsAreReleased() { - topNShardContextManagementAux(2, List.of(true, false, false, true)); - } - - private void topNShardContextManagementAux(int limit, List expectedOpenAfterTopN) { - List> values = Arrays.asList( - tuple(new BlockUtils.Doc(0, 10, 100), 1L), - tuple(new BlockUtils.Doc(1, 20, 200), 2L), - tuple(new BlockUtils.Doc(2, 30, 300), null), - tuple(new BlockUtils.Doc(3, 40, 400), -3L) - ); - - List refCountedList = Stream.generate(() -> new SimpleRefCounted()).limit(4).toList(); - var shardRefCounters = new IndexedByShardIdFromList<>(refCountedList); - var pages = topNTwoColumns(driverContext(), new TupleDocLongBlockSourceOperator(driverContext().blockFactory(), values) { - @Override - protected Block.Builder firstElementBlockBuilder(int length) { - return DocBlock.newBlockBuilder(blockFactory, length).shardRefCounters(shardRefCounters); - } - }, - shardRefCounters, - limit, - List.of(new DocVectorEncoder(shardRefCounters), DEFAULT_UNSORTABLE), - List.of(new TopNOperator.SortOrder(1, true, false)) - ); - refCountedList.forEach(RefCounted::decRef); - - assertThat(refCountedList.stream().map(RefCounted::hasReferences).toList(), equalTo(expectedOpenAfterTopN)); - - var expectedValues = values.stream() - .sorted(Comparator.comparingLong(t -> t.v2() == null ? 
Long.MAX_VALUE : t.v2())) - .limit(limit) - .toList(); - assertThat( - pageToTuples((b, i) -> (BlockUtils.Doc) BlockUtils.toJavaObject(b, i), (b, i) -> ((LongBlock) b).getLong(i), pages), - equalTo(expectedValues) - ); - Releasables.close(pages); - - for (var rc : refCountedList) { - assertFalse(rc.hasReferences()); - } - } - @SuppressWarnings({ "unchecked", "rawtypes" }) - private static void readAsRows(List>> values, Page page) { + protected static void readAsRows(List>> values, Page page) { if (page.getBlockCount() == 0) { fail("No blocks returned!"); } @@ -2113,6 +2129,13 @@ private static void readAsRows(List>> values, Page page) { } } + protected static List> readAsRowsSingleValue(Page page) { + assertThat(page.getBlockCount(), greaterThan(0)); + return IntStream.range(0, page.getPositionCount()) + .mapToObj(position -> IntStream.range(0, page.getBlockCount()).mapToObj(i -> toJavaObject(page.getBlock(i), position)).toList()) + .toList(); + } + public void testSplitOnSize() { int topCount = 50; long jumboPageBytes = randomJumboPageBytes(); @@ -2139,19 +2162,20 @@ public void testSplitOnSize() { */ maxPageSize += PageCacheRecycler.PAGE_SIZE_IN_BYTES; } + int[] gk = groupKeys(); + int expectedTotal = gk.length > 0 ? 
inputPageRows * inputPageCount : topCount; try ( - TopNOperator op = new TopNOperator( - driverContext().blockFactory(), - factory.breaker(), + Operator op = createTopNOperatorFactory( topCount, List.of(BYTES_REF), List.of(UTF8), List.of(new TopNOperator.SortOrder(0, randomBoolean(), randomBoolean())), + gk, Integer.MAX_VALUE, jumboPageBytes, InputOrdering.NOT_SORTED, null - ) + ).get(driverContext()) ) { for (int p = 0; p < inputPageCount; p++) { try (BytesRefBlock.Builder bytes = factory.newBytesRefBlockBuilder(inputPageRows)) { @@ -2166,14 +2190,14 @@ public void testSplitOnSize() { while (op.isFinished() == false) { try (Page out = op.getOutput()) { totalPositions += out.getPositionCount(); - if (totalPositions < topCount) { + if (totalPositions < expectedTotal) { assertThat(out.ramBytesUsedByBlocks(), both(greaterThanOrEqualTo(minPageSize)).and(lessThanOrEqualTo(maxPageSize))); } else { assertThat(out.ramBytesUsedByBlocks(), lessThanOrEqualTo(maxPageSize)); } } } - assertThat(totalPositions, equalTo(topCount)); + assertThat(totalPositions, equalTo(expectedTotal)); } } @@ -2188,7 +2212,7 @@ public void testSplitOnSize() { * but it's as close as possible to a very general and fully randomized unit test for TopNOperator with multi-values support. 
*/ @SuppressWarnings({ "unchecked", "rawtypes" }) - private List> extractAndReduceSortedValues(List>> rows, Set orders) { + protected List> extractAndReduceSortedValues(List>> rows, Set orders) { List> result = new ArrayList<>(rows.size()); for (List> row : rows) { @@ -2215,7 +2239,7 @@ static Object minMax(List values, boolean asc) { : values.stream().map(element -> (Comparable) element).max(naturalOrder()).get(); } - private class NaiveTopNComparator implements Comparator>> { + protected class NaiveTopNComparator implements Comparator>> { private final List orders; NaiveTopNComparator(List orders) { @@ -2251,6 +2275,562 @@ static Version randomVersion() { return new Version(randomFrom(VERSIONS)); } + protected static class RandomMultiValueBlocksResult { + final List>> expectedValues; + final List blocks; + final boolean[] validSortKeys; + final List elementTypes; + final List encoders; + + RandomMultiValueBlocksResult( + List>> expectedValues, + List blocks, + boolean[] validSortKeys, + List elementTypes, + List encoders + ) { + this.expectedValues = expectedValues; + this.blocks = blocks; + this.validSortKeys = validSortKeys; + this.elementTypes = elementTypes; + this.encoders = encoders; + } + } + + protected RandomMultiValueBlocksResult generateRandomMultiValueBlocks(int rows, int blocksCount, DriverContext driverContext) { + List>> expectedValues = new ArrayList<>(rows); + List blocks = new ArrayList<>(blocksCount); + boolean[] validSortKeys = new boolean[blocksCount]; + List elementTypes = new ArrayList<>(blocksCount); + List encoders = new ArrayList<>(blocksCount); + + for (int i = 0; i < rows; i++) { + expectedValues.add(new ArrayList<>(blocksCount)); + } + + for (int type = 0; type < blocksCount; type++) { + ElementType e = randomValueOtherThanMany( + t -> t == ElementType.UNKNOWN + || t == ElementType.DOC + || t == COMPOSITE + || t == AGGREGATE_METRIC_DOUBLE + || t == EXPONENTIAL_HISTOGRAM + || t == TDIGEST + || t == LONG_RANGE, + () -> 
randomFrom(ElementType.values()) + ); + elementTypes.add(e); + validSortKeys[type] = true; + try (Block.Builder builder = e.newBlockBuilder(rows, driverContext.blockFactory())) { + List previousValue = null; + Function randomValueSupplier = (blockType) -> randomValue(blockType); + if (e == BYTES_REF) { + if (rarely()) { + randomValueSupplier = switch (randomInt(2)) { + case 0 -> { + encoders.add(TopNEncoder.IP); + yield (blockType) -> new BytesRef(InetAddressPoint.encode(randomIp(randomBoolean()))); + } + case 1 -> { + encoders.add(TopNEncoder.VERSION); + yield (blockType) -> randomVersion().toBytesRef(); + } + case 2 -> { + encoders.add(DEFAULT_UNSORTABLE); + validSortKeys[type] = false; + yield (blockType) -> randomPointAsWKB(); + } + default -> throw new UnsupportedOperationException(); + }; + } else { + encoders.add(UTF8); + } + } else { + encoders.add(DEFAULT_SORTABLE); + } + + for (int i = 0; i < rows; i++) { + List values = new ArrayList<>(); + if (rarely() && previousValue != null) { + values = previousValue; + } else { + if (e != ElementType.NULL && randomBoolean()) { + int mvCount = randomIntBetween(5, 10); + for (int j = 0; j < mvCount; j++) { + Object value = randomValueSupplier.apply(e); + values.add(value); + } + } else { + Object value = randomValueSupplier.apply(e); + values.add(value); + } + + if (usually() && randomBoolean()) { + previousValue = values; + } + } + + if (values.size() == 1) { + append(builder, values.get(0)); + } else { + builder.beginPositionEntry(); + for (Object o : values) { + append(builder, o); + } + builder.endPositionEntry(); + } + + expectedValues.get(i).add(values); + } + blocks.add(builder.build()); + } + } + + return new RandomMultiValueBlocksResult(expectedValues, blocks, validSortKeys, elementTypes, encoders); + } + + protected Set generateSortOrders( + int sortingByColumns, + int blocksCount, + boolean[] validSortKeys, + java.util.function.Predicate excludeColumn + ) { + Set uniqueOrders = new 
LinkedHashSet<>(sortingByColumns); + for (int i = 0; i < sortingByColumns; i++) { + int column = randomValueOtherThanMany( + c -> false == validSortKeys[c] || excludeColumn.test(c), + () -> randomIntBetween(0, blocksCount - 1) + ); + uniqueOrders.add(new TopNOperator.SortOrder(column, randomBoolean(), randomBoolean())); + } + return uniqueOrders; + } + + protected static class RandomBlocksResult { + final List> expectedValues; + final List blocks; + final boolean[] validSortKeys; + final List elementTypes; + final List encoders; + + RandomBlocksResult( + List> expectedValues, + List blocks, + boolean[] validSortKeys, + List elementTypes, + List encoders + ) { + this.expectedValues = expectedValues; + this.blocks = blocks; + this.validSortKeys = validSortKeys; + this.elementTypes = elementTypes; + this.encoders = encoders; + } + } + + protected RandomBlocksResult generateRandomSingleValueBlocks(int rows, int blocksCount, DriverContext driverContext) { + List> expectedValues = new ArrayList<>(rows); + List blocks = new ArrayList<>(blocksCount); + boolean[] validSortKeys = new boolean[blocksCount]; + List elementTypes = new ArrayList<>(blocksCount); + List encoders = new ArrayList<>(blocksCount); + + for (int i = 0; i < rows; i++) { + expectedValues.add(new ArrayList<>(blocksCount)); + } + + for (int type = 0; type < blocksCount; type++) { + ElementType e = randomValueOtherThanMany( + t -> t == ElementType.UNKNOWN + || t == ElementType.DOC + || t == COMPOSITE + || t == AGGREGATE_METRIC_DOUBLE + || t == EXPONENTIAL_HISTOGRAM + || t == TDIGEST + || t == LONG_RANGE, + () -> randomFrom(ElementType.values()) + ); + elementTypes.add(e); + validSortKeys[type] = true; + try (Block.Builder builder = e.newBlockBuilder(rows, driverContext.blockFactory())) { + Function randomValueSupplier = (blockType) -> randomValue(blockType); + if (e == BYTES_REF) { + if (rarely()) { + randomValueSupplier = switch (randomInt(2)) { + case 0 -> { + encoders.add(TopNEncoder.IP); + yield 
(blockType) -> new BytesRef(InetAddressPoint.encode(randomIp(randomBoolean()))); + } + case 1 -> { + encoders.add(TopNEncoder.VERSION); + yield (blockType) -> randomVersion().toBytesRef(); + } + case 2 -> { + encoders.add(DEFAULT_UNSORTABLE); + validSortKeys[type] = false; + yield (blockType) -> randomPointAsWKB(); + } + default -> throw new UnsupportedOperationException(); + }; + } else { + encoders.add(UTF8); + } + } else { + encoders.add(DEFAULT_SORTABLE); + } + + for (int i = 0; i < rows; i++) { + Object value = randomValueSupplier.apply(e); + append(builder, value); + expectedValues.get(i).add(value); + } + blocks.add(builder.build()); + } + } + + return new RandomBlocksResult(expectedValues, blocks, validSortKeys, elementTypes, encoders); + } + + /** + * Maps group key channels and data channels for tests that need to add constant-value + * group key columns alongside their data columns. When {@code groupKeys} is empty, the + * layout is an identity mapping (data channel i == channel i). + */ + protected record ChannelLayout(int[] groupKeys, Set groupKeySet, int[] dataChannels, int totalChannels) { + + static ChannelLayout forDataColumns(int dataColumnCount, int[] groupKeys) { + int totalChannels = dataColumnCount + groupKeys.length; + Set groupKeySet = IntStream.of(groupKeys).boxed().collect(Collectors.toSet()); + int[] dataChannels = IntStream.range(0, totalChannels).filter(ch -> groupKeySet.contains(ch) == false).toArray(); + return new ChannelLayout(groupKeys, groupKeySet, dataChannels, totalChannels); + } + + /** Returns the actual channel index for the given data column index. */ + int dataChannel(int dataIndex) { + return dataChannels[dataIndex]; + } + + /** + * Builds a {@code Block[]} with constant-0 INT blocks at the group key channels + * and the supplied data blocks placed at the corresponding data channels. + */ + Block[] buildPageBlocks(int positions, BlockFactory blockFactory, Block... 
dataBlocks) { + Block[] blockArray = new Block[totalChannels]; + for (int g : groupKeys) { + try (Block.Builder groupCol = INT.newBlockBuilder(positions, blockFactory)) { + for (int j = 0; j < positions; j++) { + append(groupCol, 0); + } + blockArray[g] = groupCol.build(); + } + } + for (int d = 0; d < dataBlocks.length; d++) { + blockArray[dataChannels[d]] = dataBlocks[d]; + } + return blockArray; + } + + /** Extracts only the data-channel columns from column-oriented results. */ + List> extractDataColumns(List> actual) { + if (groupKeys.length == 0 || actual.isEmpty()) { + return actual; + } + return IntStream.of(dataChannels).mapToObj(actual::get).toList(); + } + + /** Inserts group key entries into pre-built lists at the correct positions. */ + void insertGroupKeyEntries(List list, T value) { + int[] sortedGk = IntStream.of(groupKeys).sorted().toArray(); + for (int g : sortedGk) { + list.add(g, value); + } + } + } + + public void testRamBytesUsed() { + int topCount = 10_000; + long underCount = groupKeys().length > 0 ? 
1000 : 250; + DriverContext context = driverContext(); + int numColumns = 1 + groupKeys().length; + List elementTypes = Collections.nCopies(numColumns, LONG); + List encoders = Collections.nCopies(numColumns, DEFAULT_UNSORTABLE); + try ( + Operator op = createTopNOperatorFactory( + topCount, + elementTypes, + encoders, + List.of(new TopNOperator.SortOrder(0, true, false)), + groupKeys(), + pageSize, + Long.MAX_VALUE, + TopNOperator.InputOrdering.NOT_SORTED, + null + ).get(context) + ) { + Accountable accountable = (Accountable) op; + long actualEmpty = RamUsageTester.ramUsed(op, RAM_USAGE_ACCUMULATOR); + assertThat(accountable.ramBytesUsed(), both(greaterThan(actualEmpty - underCount)).and(lessThan(actualEmpty))); + for (Page p : CannedSourceOperator.collectPages(simpleInput(context.blockFactory(), topCount))) { + op.addInput(p); + } + long actualFull = RamUsageTester.ramUsed(op, RAM_USAGE_ACCUMULATOR); + assertThat(accountable.ramBytesUsed(), both(greaterThan(actualFull - underCount)).and(lessThan(actualFull))); + } + } + + public void testShardContextManagement_limitEqualToCount_noShardContextIsReleased() { + assumeTrue("only applies to ungrouped TopN", groupKeys().length == 0); + topNShardContextManagementAux(4, Stream.generate(() -> true).limit(4).toList()); + } + + public void testShardContextManagement_notAllShardsPassTopN_shardsAreReleased() { + assumeTrue("only applies to ungrouped TopN", groupKeys().length == 0); + topNShardContextManagementAux(2, List.of(true, false, false, true)); + } + + private void topNShardContextManagementAux(int limit, List expectedOpenAfterTopN) { + List> values = Arrays.asList( + tuple(new BlockUtils.Doc(0, 10, 100), 1L), + tuple(new BlockUtils.Doc(1, 20, 200), 2L), + tuple(new BlockUtils.Doc(2, 30, 300), null), + tuple(new BlockUtils.Doc(3, 40, 400), -3L) + ); + + List refCountedList = Stream.generate(() -> new SimpleRefCounted()).limit(4).toList(); + var shardRefCounters = new IndexedByShardIdFromList<>(refCountedList); + var 
pages = topNMultipleColumns(driverContext(), new TupleDocLongBlockSourceOperator(driverContext().blockFactory(), values) { + @Override + protected Block.Builder firstElementBlockBuilder(int length) { + return DocBlock.newBlockBuilder(blockFactory, length).shardRefCounters(shardRefCounters); + } + }, + limit, + List.of(new DocVectorEncoder(shardRefCounters), DEFAULT_UNSORTABLE), + List.of(new TopNOperator.SortOrder(1, true, false)), + groupKeys() + ); + refCountedList.forEach(RefCounted::decRef); + + assertThat(refCountedList.stream().map(RefCounted::hasReferences).toList(), equalTo(expectedOpenAfterTopN)); + + var expectedValues = values.stream() + .sorted(Comparator.comparingLong(t -> t.v2() == null ? Long.MAX_VALUE : t.v2())) + .limit(limit) + .toList(); + assertThat( + pageToTuples((b, i) -> (BlockUtils.Doc) BlockUtils.toJavaObject(b, i), (b, i) -> ((LongBlock) b).getLong(i), pages), + equalTo(expectedValues) + ); + Releasables.close(pages); + + for (var rc : refCountedList) { + assertFalse(rc.hasReferences()); + } + } + + public void testRandomMultiValuesTopNWithMinCompetitive() { + assumeTrue("only applies to ungrouped TopN", groupKeys().length == 0); + DriverContext driverContext = driverContext(); + int pageCount = between(1, 100); + int rowsPerPage = between(50, 100); + int topCount = between(1, rowsPerPage * pageCount); + int blocksCount = between(20, 30); + int sortingByColumns = between(1, 10); + + Set uniqueOrders = new LinkedHashSet<>(sortingByColumns); + List>> expectedValues = new ArrayList<>(rowsPerPage * pageCount); + List> randomValueSuppliers = new ArrayList<>(blocksCount); + boolean[] validSortKeys = new boolean[blocksCount]; + List elementTypes = new ArrayList<>(blocksCount); + List encoders = new ArrayList<>(blocksCount); + + for (int type = 0; type < blocksCount; type++) { + ElementType e = randomValueOtherThanMany( + t -> t == ElementType.UNKNOWN + || t == ElementType.DOC + || t == ElementType.COMPOSITE + || t == 
ElementType.AGGREGATE_METRIC_DOUBLE + || t == ElementType.EXPONENTIAL_HISTOGRAM + || t == ElementType.TDIGEST + || t == ElementType.LONG_RANGE, + () -> randomFrom(ElementType.values()) + ); + elementTypes.add(e); + validSortKeys[type] = true; + Supplier randomValueSupplier = () -> randomValue(e); + if (e == ElementType.BYTES_REF) { + if (rarely()) { + randomValueSupplier = switch (randomInt(2)) { + case 0 -> { + // Simulate ips + encoders.add(TopNEncoder.IP); + yield () -> new BytesRef(InetAddressPoint.encode(randomIp(randomBoolean()))); + } + case 1 -> { + // Simulate version fields + encoders.add(TopNEncoder.VERSION); + yield () -> randomVersion().toBytesRef(); + } + case 2 -> { + // Simulate geo_shape and geo_point + encoders.add(TopNEncoder.DEFAULT_UNSORTABLE); + validSortKeys[type] = false; + yield TopNEncoderTests::randomPointAsWKB; + } + default -> throw new UnsupportedOperationException(); + }; + } else { + encoders.add(TopNEncoder.UTF8); + } + } else { + encoders.add(TopNEncoder.DEFAULT_SORTABLE); + } + randomValueSuppliers.add(randomValueSupplier); + } + + for (int i = 0; i < sortingByColumns; i++) { + int column = randomValueOtherThanMany(c -> false == validSortKeys[c], () -> randomIntBetween(0, blocksCount - 1)); + uniqueOrders.add(new TopNOperator.SortOrder(column, randomBoolean(), randomBoolean())); + } + List sortOrders = uniqueOrders.stream().toList(); + NaiveTopNComparator comparator = new NaiveTopNComparator(sortOrders); + SharedMinCompetitive.Supplier minCompetitiveSupplier = randomBoolean() + ? new SharedMinCompetitive.Supplier(blockFactory().breaker(), keyConfigs(elementTypes, encoders, sortOrders)) + : null; + SharedMinCompetitive minCompetitive = minCompetitiveSupplier == null ? 
null : minCompetitiveSupplier.get(); + + try ( + TopNOperator operator = new TopNOperator( + driverContext.blockFactory(), + nonBreakingBigArrays().breakerService().getBreaker("request"), + topCount, + elementTypes, + encoders, + sortOrders, + rowsPerPage, + Long.MAX_VALUE, + TopNOperator.InputOrdering.NOT_SORTED, + minCompetitiveSupplier + ) + ) { + for (int p = 0; p < pageCount; p++) { + assertThat(operator.needsInput(), equalTo(true)); + assertThat(operator.isFinished(), equalTo(false)); + assertThat(operator.getOutput(), nullValue()); + + for (int r = 0; r < rowsPerPage; r++) { + expectedValues.add(new ArrayList<>(blocksCount)); + } + Block[] blocks = new Block[blocksCount]; + for (int b = 0; b < blocksCount; b++) { + ElementType elementType = elementTypes.get(b); + try (Block.Builder builder = elementType.newBlockBuilder(rowsPerPage, driverContext().blockFactory())) { + List previousValue = null; + for (int r = 0; r < rowsPerPage; r++) { + List values = new ArrayList<>(); + if (rarely() && previousValue != null) { + values = previousValue; + } else { + if (elementType != ElementType.NULL && randomBoolean()) { + int mvCount = randomIntBetween(5, 10); + for (int j = 0; j < mvCount; j++) { + values.add(randomValueSuppliers.get(b).get()); + } + } else { + values.add(randomValueSuppliers.get(b).get()); + } + if (usually() && randomBoolean()) { + previousValue = values; + } + } + if (values.size() == 1) { + append(builder, values.get(0)); + } else { + builder.beginPositionEntry(); + for (Object o : values) { + append(builder, o); + } + builder.endPositionEntry(); + } + expectedValues.get(p * rowsPerPage + r).add(values); + } + blocks[b] = builder.build(); + } + } + operator.addInput(new Page(blocks)); + + if (minCompetitive != null) { + if ((p + 1) * rowsPerPage < topCount) { + assertThat(minCompetitive.get(blockFactory()), nullValue()); + } else { + List> minCompetitiveRow = expectedValues.stream() + .sorted(comparator) + .skip(topCount - 1) + .findFirst() + 
.get(); + try (Page min = minCompetitive.get(blockFactory())) { + assertThat(min.getBlockCount(), equalTo(sortOrders.size())); + for (int s = 0; s < min.getBlockCount(); s++) { + logger.info("checking key {}", s); + TopNOperator.SortOrder sort = sortOrders.get(s); + Object actual = BlockUtils.toJavaObject(min.getBlock(s), 0); + Object expected = reduceKey(minCompetitiveRow.get(sort.channel()), sort.asc()); + assertThat(actual, equalTo(expected)); + } + } + } + } + } + operator.finish(); + assertThat(operator.needsInput(), equalTo(false)); + assertThat(operator.isFinished(), equalTo(false)); + + List>> actualValues = new ArrayList<>(); + while (operator.isFinished() == false) { + try (Page p = operator.getOutput()) { + readAsRows(actualValues, p); + assertThat(operator.needsInput(), equalTo(false)); + } + } + + List>> topNExpectedValues = expectedValues.stream().sorted(comparator).limit(topCount).toList(); + List> actualReducedValues = extractAndReduceSortedValues(actualValues, uniqueOrders); + List> expectedReducedValues = extractAndReduceSortedValues(topNExpectedValues, uniqueOrders); + + assertMap(actualReducedValues, matchesList(expectedReducedValues)); + } finally { + Releasables.close(minCompetitive); + } + } + + /** Checks the {@link TopNOperatorStatus} reported by the operator built by {@code simple(SimpleOptions.DEFAULT)}: first before any input, then again after a single page has been added. Nothing is emitted in this test, so the emitted page/row counters and {@code emitNanos} are expected to stay at zero throughout. */ public void testStatus() { + BlockFactory blockFactory = driverContext().blockFactory(); + try (Operator op = simple(SimpleOptions.DEFAULT).get(driverContext())) { + Operator.Status status = op.status(); + assertThat(status, instanceOf(TopNOperatorStatus.class)); + TopNOperatorStatus topNStatus = (TopNOperatorStatus) status; + assertThat(topNStatus.occupiedRows(), equalTo(0)); + assertThat(topNStatus.ramBytesUsed(), greaterThan(0L)); + assertThat(topNStatus.pagesReceived(), equalTo(0)); + assertThat(topNStatus.pagesEmitted(), equalTo(0)); + assertThat(topNStatus.rowsReceived(), equalTo(0L)); + assertThat(topNStatus.rowsEmitted(), equalTo(0L)); + + Page p = new Page(blockFactory.newConstantLongBlockWith(1, 10)); /* a single constant long block with 10 positions, which is why rowsReceived is expected to be 10 below */ + op.addInput(p); + status = 
op.status(); + topNStatus = (TopNOperatorStatus) status; + assertThat(topNStatus.receiveNanos(), greaterThan(0L)); + assertThat(topNStatus.emitNanos(), equalTo(0L)); + assertThat(topNStatus.occupiedRows(), equalTo(4)); /* NOTE(review): 4 is presumably the row limit configured by simple(SimpleOptions.DEFAULT), which is defined outside this view - confirm */ + assertThat(topNStatus.ramBytesUsed(), greaterThan(0L)); + assertThat(topNStatus.pagesReceived(), equalTo(1)); + assertThat(topNStatus.pagesEmitted(), equalTo(0)); + assertThat(topNStatus.rowsReceived(), equalTo(10L)); + assertThat(topNStatus.rowsEmitted(), equalTo(0L)); + } + } + private long randomJumboPageBytes() { return rarely() ? between(1, 100) : ByteSizeValue.ofKb(between(1, 1000)).getBytes(); } diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/TopNRowTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/TopNRowTests.java index 92cecdf79528c..0fbf656119310 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/TopNRowTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/TopNRowTests.java @@ -18,12 +18,12 @@ public class TopNRowTests extends ESTestCase { private final CircuitBreaker breaker = new NoopCircuitBreaker(CircuitBreaker.REQUEST); public void testRamBytesUsedEmpty() { - TopNOperator.Row row = new TopNOperator.Row(breaker, 0, 0); + TopNRow row = new TopNRow(breaker, 0, 0); assertThat(row.ramBytesUsed(), equalTo(expectedRamBytesUsed(row))); } public void testRamBytesUsedSmall() { - TopNOperator.Row row = new TopNOperator.Row(new NoopCircuitBreaker(CircuitBreaker.REQUEST), 0, 0); + TopNRow row = new TopNRow(new NoopCircuitBreaker(CircuitBreaker.REQUEST), 0, 0); row.keys.append(randomByte()); row.values.append(randomByte()); assertThat(row.ramBytesUsed(), equalTo(expectedRamBytesUsed(row))); @@ -35,7 +35,7 @@ public void testRamBytesUsedSmall() { * size estimates from previous rows. 
*/ public void testFromHeapDump1() { - TopNOperator.Row row = new TopNOperator.Row(new NoopCircuitBreaker(CircuitBreaker.REQUEST), 56, 24); + TopNRow row = new TopNRow(new NoopCircuitBreaker(CircuitBreaker.REQUEST), 56, 24); assertThat(row.ramBytesUsed(), equalTo(expectedRamBytesUsed(row))); // 304 was measured debugging a heap dump and we've since shrunk assertThat(row.ramBytesUsed(), equalTo(240L)); @@ -47,14 +47,14 @@ public void testFromHeapDump1() { * size estimates from previous rows. */ public void testFromHeapDump2() { - TopNOperator.Row row = new TopNOperator.Row(new NoopCircuitBreaker(CircuitBreaker.REQUEST), 1160, 1_153_096); + TopNRow row = new TopNRow(new NoopCircuitBreaker(CircuitBreaker.REQUEST), 1160, 1_153_096); assertThat(row.ramBytesUsed(), equalTo(expectedRamBytesUsed(row))); // 1,154,464 is measured debugging a heap dump and we've since shrunk assertThat(row.ramBytesUsed(), equalTo(1_154_416L)); } public void testRamBytesUsedBig() { - TopNOperator.Row row = new TopNOperator.Row(new NoopCircuitBreaker(CircuitBreaker.REQUEST), 0, 0); + TopNRow row = new TopNRow(new NoopCircuitBreaker(CircuitBreaker.REQUEST), 0, 0); for (int i = 0; i < 10000; i++) { row.keys.append(randomByte()); row.values.append(randomByte()); @@ -63,11 +63,11 @@ public void testRamBytesUsedBig() { } public void testRamBytesUsedPreAllocated() { - TopNOperator.Row row = new TopNOperator.Row(new NoopCircuitBreaker(CircuitBreaker.REQUEST), 64, 128); + TopNRow row = new TopNRow(new NoopCircuitBreaker(CircuitBreaker.REQUEST), 64, 128); assertThat(row.ramBytesUsed(), equalTo(expectedRamBytesUsed(row))); } - private long expectedRamBytesUsed(TopNOperator.Row row) { + private long expectedRamBytesUsed(TopNRow row) { long expected = RamUsageTester.ramUsed(row); if (row.values.bytes().length == 0) { // We double count the shared empty array for empty rows. This overcounting is *fine*, but throws off the test. 
diff --git a/x-pack/plugin/esql/compute/test/src/main/java/org/elasticsearch/compute/test/AbstractTypedBlockSourceOperator.java b/x-pack/plugin/esql/compute/test/src/main/java/org/elasticsearch/compute/test/AbstractTypedBlockSourceOperator.java new file mode 100644 index 0000000000000..281aa91846ae4 --- /dev/null +++ b/x-pack/plugin/esql/compute/test/src/main/java/org/elasticsearch/compute/test/AbstractTypedBlockSourceOperator.java @@ -0,0 +1,22 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.test; + +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.test.operator.blocksource.AbstractBlockSourceOperator; + +import java.util.List; + +/** Base class for test block source operators that additionally expose, through {@link #elementTypes()}, the element types of the blocks they produce. It adds no behavior of its own beyond that abstract accessor. */ public abstract class AbstractTypedBlockSourceOperator extends AbstractBlockSourceOperator { + /** Passes {@code blockFactory} and {@code maxPagePositions} unchanged to the {@link AbstractBlockSourceOperator} constructor. */ protected AbstractTypedBlockSourceOperator(BlockFactory blockFactory, int maxPagePositions) { + super(blockFactory, maxPagePositions); + } + + /** Returns the element types of this source; {@code ListRowsBlockSourceOperator} and {@code TupleAbstractBlockSourceOperator} return one entry per block of the pages they create, in block order. NOTE(review): the entries are presumably {@link ElementType} values (that import is otherwise unused here) - confirm. */ public abstract List elementTypes(); +} diff --git a/x-pack/plugin/esql/compute/test/src/main/java/org/elasticsearch/compute/test/TestBlockBuilder.java b/x-pack/plugin/esql/compute/test/src/main/java/org/elasticsearch/compute/test/TestBlockBuilder.java index d06f8ea4ba44a..43612dc985c1f 100644 --- a/x-pack/plugin/esql/compute/test/src/main/java/org/elasticsearch/compute/test/TestBlockBuilder.java +++ b/x-pack/plugin/esql/compute/test/src/main/java/org/elasticsearch/compute/test/TestBlockBuilder.java @@ -10,13 +10,17 @@ import org.apache.lucene.util.BytesRef; import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.BlockUtils; import org.elasticsearch.compute.data.BooleanBlock; import 
org.elasticsearch.compute.data.BytesRefBlock; +import org.elasticsearch.compute.data.DocBlock; import org.elasticsearch.compute.data.DoubleBlock; import org.elasticsearch.compute.data.ElementType; import org.elasticsearch.compute.data.FloatBlock; import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.lucene.IndexedByShardId; +import org.elasticsearch.core.RefCounted; import java.util.List; @@ -441,4 +445,68 @@ public void close() { builder.close(); } } + + /** A {@link TestBlockBuilder} that forwards every call to a wrapped {@link DocBlock.Builder} and hands the supplied shard ref counters to that builder when {@link #build()} is called. */ public static class DocBlockBuilder extends TestBlockBuilder { + private final DocBlock.Builder builder; + private final IndexedByShardId shardRefCounters; + + /** Creates the wrapped {@link DocBlock.Builder} from {@code blockFactory} and keeps {@code shardRefCounters} for use in {@link #build()}. */ public DocBlockBuilder(BlockFactory blockFactory, IndexedByShardId shardRefCounters) { + this.shardRefCounters = shardRefCounters; + this.builder = DocBlock.newBlockBuilder(blockFactory, 0); /* NOTE(review): the 0 is presumably an initial size estimate for the wrapped builder - confirm against DocBlock.newBlockBuilder */ + } + + /** Appends one document: {@code object} is cast to {@link BlockUtils.Doc} and its doc, segment and shard values are each appended to the wrapped builder. */ @Override + public TestBlockBuilder appendObject(Object object) { + var doc = (BlockUtils.Doc) object; + builder.appendDoc(doc.doc()); + builder.appendSegment(doc.segment()); + builder.appendShard(doc.shard()); + return this; + } + + /** Forwards to the wrapped builder and returns this builder for chaining. */ @Override + public TestBlockBuilder appendNull() { + builder.appendNull(); + return this; + } + + /** Forwards to the wrapped builder and returns this builder for chaining. */ @Override + public TestBlockBuilder beginPositionEntry() { + builder.beginPositionEntry(); + return this; + } + + /** Forwards to the wrapped builder and returns this builder for chaining. */ @Override + public TestBlockBuilder endPositionEntry() { + builder.endPositionEntry(); + return this; + } + + /** Forwards the copy to the wrapped builder; the returned builder is this wrapper, not the wrapped builder. */ @Override + public Block.Builder copyFrom(Block block, int beginInclusive, int endExclusive) { + builder.copyFrom(block, beginInclusive, endExclusive); + return this; + } + + /** Forwards the ordering to the wrapped builder; the returned builder is this wrapper, not the wrapped builder. */ @Override + public Block.Builder mvOrdering(Block.MvOrdering mvOrdering) { + builder.mvOrdering(mvOrdering); + return this; + } + + /** Returns the wrapped builder's own estimate. */ @Override + public long estimatedBytes() { + return builder.estimatedBytes(); + } + + /** Sets the shard ref counters given to the constructor on the wrapped builder, then builds the block from it. */ @Override + public Block build() { + return builder.shardRefCounters(shardRefCounters).build(); + } + + /** Closes the wrapped builder. */ @Override + public void close() { + 
builder.close(); + } + } } diff --git a/x-pack/plugin/esql/compute/test/src/main/java/org/elasticsearch/compute/test/operator/blocksource/ListRowsBlockSourceOperator.java b/x-pack/plugin/esql/compute/test/src/main/java/org/elasticsearch/compute/test/operator/blocksource/ListRowsBlockSourceOperator.java index 33de84953e690..481cf9c877896 100644 --- a/x-pack/plugin/esql/compute/test/src/main/java/org/elasticsearch/compute/test/operator/blocksource/ListRowsBlockSourceOperator.java +++ b/x-pack/plugin/esql/compute/test/src/main/java/org/elasticsearch/compute/test/operator/blocksource/ListRowsBlockSourceOperator.java @@ -11,6 +11,7 @@ import org.elasticsearch.compute.data.BlockFactory; import org.elasticsearch.compute.data.ElementType; import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.test.AbstractTypedBlockSourceOperator; import org.elasticsearch.compute.test.TestBlockBuilder; import org.elasticsearch.core.Releasables; @@ -22,7 +23,7 @@ /** * A source operator whose output is rows specified as a list {@link List} values. 
*/ -public class ListRowsBlockSourceOperator extends AbstractBlockSourceOperator { +public class ListRowsBlockSourceOperator extends AbstractTypedBlockSourceOperator { private static final int DEFAULT_MAX_PAGE_POSITIONS = 8 * 1024; private final List types; @@ -39,7 +40,7 @@ protected Page createPage(int positionOffset, int length) { TestBlockBuilder[] blocks = new TestBlockBuilder[types.size()]; try { for (int b = 0; b < blocks.length; b++) { - blocks[b] = TestBlockBuilder.builderOf(blockFactory, types.get(b)); + blocks[b] = getTestBlockBuilder(b); } for (int i = 0; i < length; i++) { List row = rows.get(positionOffset + i); @@ -69,11 +70,16 @@ protected Page createPage(int positionOffset, int length) { } } + /** Creates the builder for block index {@code b} of a page, using the element type at {@code types.get(b)}; {@code createPage} calls this once per block. NOTE(review): it is protected, presumably so that subclasses can supply a different builder for some blocks - confirm against the subclasses. */ protected TestBlockBuilder getTestBlockBuilder(int b) { + return TestBlockBuilder.builderOf(blockFactory, types.get(b)); + } + @Override protected int remaining() { return rows.size() - currentPosition; } + @Override public List elementTypes() { return types; } diff --git a/x-pack/plugin/esql/compute/test/src/main/java/org/elasticsearch/compute/test/operator/blocksource/TupleAbstractBlockSourceOperator.java b/x-pack/plugin/esql/compute/test/src/main/java/org/elasticsearch/compute/test/operator/blocksource/TupleAbstractBlockSourceOperator.java index f15fec7b0e8d8..40b7e0ac4b7be 100644 --- a/x-pack/plugin/esql/compute/test/src/main/java/org/elasticsearch/compute/test/operator/blocksource/TupleAbstractBlockSourceOperator.java +++ b/x-pack/plugin/esql/compute/test/src/main/java/org/elasticsearch/compute/test/operator/blocksource/TupleAbstractBlockSourceOperator.java @@ -11,6 +11,7 @@ import org.elasticsearch.compute.data.BlockFactory; import org.elasticsearch.compute.data.ElementType; import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.test.AbstractTypedBlockSourceOperator; import org.elasticsearch.core.Tuple; import java.util.List; @@ -19,7 +20,7 @@ * A source operator whose output is the given tuple values. 
This operator produces pages * with two Blocks. The returned pages preserve the order of values as given in the in initial list. */ -public abstract class TupleAbstractBlockSourceOperator extends AbstractBlockSourceOperator { +public abstract class TupleAbstractBlockSourceOperator extends AbstractTypedBlockSourceOperator { private static final int DEFAULT_MAX_PAGE_POSITIONS = 8 * 1024; private final List> values; @@ -86,6 +87,7 @@ protected int remaining() { return values.size() - currentPosition; } + @Override public List elementTypes() { return List.of(firstElementType, secondElementType); }