From efd20d0a789868cc3635316bfed9b2dc2b765aa4 Mon Sep 17 00:00:00 2001 From: Nhat Nguyen Date: Wed, 16 Jul 2025 17:41:02 -0700 Subject: [PATCH 1/7] Add optimized path for intermediate values aggregator --- .../operator/ValuesAggregatorBenchmark.java | 29 +++-- .../gen/GroupingAggregatorImplementer.java | 119 ++++++++++-------- .../aggregation/ValuesBytesRefAggregator.java | 11 +- ...uesBytesRefGroupingAggregatorFunction.java | 6 +- .../ValuesBytesRefAggregators.java | 119 +++++++++++++----- .../aggregation/X-ValuesAggregator.java.st | 16 +-- 6 files changed, 186 insertions(+), 114 deletions(-) diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/ValuesAggregatorBenchmark.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/ValuesAggregatorBenchmark.java index dfd56996e1c15..4bd33f2c3896a 100644 --- a/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/ValuesAggregatorBenchmark.java +++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/ValuesAggregatorBenchmark.java @@ -113,16 +113,16 @@ static void selfTest() { @Param({ BYTES_REF, INT, LONG }) public String dataType; - private static Operator operator(DriverContext driverContext, int groups, String dataType) { + private static Operator operator(DriverContext driverContext, int groups, String dataType, AggregatorMode mode) { if (groups == 1) { return new AggregationOperator( - List.of(supplier(dataType).aggregatorFactory(AggregatorMode.SINGLE, List.of(0)).apply(driverContext)), + List.of(supplier(dataType).aggregatorFactory(mode, List.of(0)).apply(driverContext)), driverContext ); } List groupSpec = List.of(new BlockHash.GroupSpec(0, ElementType.LONG)); return new HashAggregationOperator( - List.of(supplier(dataType).groupingAggregatorFactory(AggregatorMode.SINGLE, List.of(1))), + List.of(supplier(dataType).groupingAggregatorFactory(mode, List.of(1))), () -> BlockHash.build(groupSpec, driverContext.blockFactory(), 16 * 
1024, false), driverContext ) { @@ -177,6 +177,9 @@ private static void checkGrouped(String prefix, int groups, String dataType, Pag // Check them BytesRefBlock values = page.getBlock(1); + if (values.asOrdinals() == null) { + throw new AssertionError(" expected ordinals; but got " + values); + } for (int p = 0; p < groups; p++) { checkExpectedBytesRef(prefix, values, p, expected.get(p)); } @@ -341,13 +344,21 @@ public void run() { private static void run(int groups, String dataType, int opCount) { DriverContext driverContext = driverContext(); - try (Operator operator = operator(driverContext, groups, dataType)) { - Page page = page(groups, dataType); - for (int i = 0; i < opCount; i++) { - operator.addInput(page.shallowCopy()); + try (Operator finalAggregator = operator(driverContext, groups, dataType, AggregatorMode.FINAL)) { + try (Operator initialAggregator = operator(driverContext, groups, dataType, AggregatorMode.INITIAL)) { + Page rawPage = page(groups, dataType); + for (int i = 0; i < opCount; i++) { + initialAggregator.addInput(rawPage.shallowCopy()); + } + initialAggregator.finish(); + Page intermediatePage = initialAggregator.getOutput(); + for (int i = 0; i < opCount; i++) { + finalAggregator.addInput(intermediatePage.shallowCopy()); + } + finalAggregator.finish(); + Page outputPage = finalAggregator.getOutput(); + checkExpected(groups, dataType, outputPage); } - operator.finish(); - checkExpected(groups, dataType, operator.getOutput()); } } diff --git a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/GroupingAggregatorImplementer.java b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/GroupingAggregatorImplementer.java index b82ea84cd0766..ad24c228f0d0f 100644 --- a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/GroupingAggregatorImplementer.java +++ b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/GroupingAggregatorImplementer.java @@ -609,63 
+609,80 @@ private MethodSpec addIntermediateInput() { .collect(joining(" && ")) ); } - if (intermediateState.stream().map(AggregatorImplementer.IntermediateStateDesc::elementType).anyMatch(n -> n.equals("BYTES_REF"))) { - builder.addStatement("$T scratch = new $T()", BYTES_REF, BYTES_REF); - } - builder.beginControlFlow("for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++)"); - { - builder.addStatement("int groupId = groups.getInt(groupPosition)"); - if (aggState.declaredType().isPrimitive()) { - if (warnExceptions.isEmpty()) { - assert intermediateState.size() == 2; - assert intermediateState.get(1).name().equals("seen"); - builder.beginControlFlow("if (seen.getBoolean(groupPosition + positionOffset))"); - } else { - assert intermediateState.size() == 3; - assert intermediateState.get(1).name().equals("seen"); - assert intermediateState.get(2).name().equals("failed"); - builder.beginControlFlow("if (failed.getBoolean(groupPosition + positionOffset))"); - { - builder.addStatement("state.setFailed(groupId)"); + var bulkCombineIntermediateMethod = optionalStaticMethod( + declarationType, + requireVoidType(), + requireName("combineIntermediate"), + requireArgs( + Stream.of( + Stream.of(aggState.declaredType(), TypeName.INT, INT_VECTOR), // aggState, positionOffset, groupIds + intermediateState.stream().map(AggregatorImplementer.IntermediateStateDesc::combineArgType) + ).flatMap(Function.identity()).map(Methods::requireType).toArray(Methods.TypeMatcher[]::new) + ) + ); + if (bulkCombineIntermediateMethod != null) { + var states = intermediateState.stream().map(AggregatorImplementer.IntermediateStateDesc::name).collect(Collectors.joining(",")); + builder.addStatement("$T.combineIntermediate(state, positionOffset, groups," + states + ")", declarationType); + } else { + if (intermediateState.stream() + .map(AggregatorImplementer.IntermediateStateDesc::elementType) + .anyMatch(n -> n.equals("BYTES_REF"))) { + builder.addStatement("$T 
scratch = new $T()", BYTES_REF, BYTES_REF); + } + builder.beginControlFlow("for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++)"); + { + builder.addStatement("int groupId = groups.getInt(groupPosition)"); + if (aggState.declaredType().isPrimitive()) { + if (warnExceptions.isEmpty()) { + assert intermediateState.size() == 2; + assert intermediateState.get(1).name().equals("seen"); + builder.beginControlFlow("if (seen.getBoolean(groupPosition + positionOffset))"); + } else { + assert intermediateState.size() == 3; + assert intermediateState.get(1).name().equals("seen"); + assert intermediateState.get(2).name().equals("failed"); + builder.beginControlFlow("if (failed.getBoolean(groupPosition + positionOffset))"); + { + builder.addStatement("state.setFailed(groupId)"); + } + builder.nextControlFlow("else if (seen.getBoolean(groupPosition + positionOffset))"); } - builder.nextControlFlow("else if (seen.getBoolean(groupPosition + positionOffset))"); - } - warningsBlock(builder, () -> { - var name = intermediateState.get(0).name(); - var vectorAccessor = vectorAccessorName(intermediateState.get(0).elementType()); - builder.addStatement( - "state.set(groupId, $T.combine(state.getOrDefault(groupId), $L.$L(groupPosition + positionOffset)))", + warningsBlock(builder, () -> { + var name = intermediateState.get(0).name(); + var vectorAccessor = vectorAccessorName(intermediateState.get(0).elementType()); + builder.addStatement( + "state.set(groupId, $T.combine(state.getOrDefault(groupId), $L.$L(groupPosition + positionOffset)))", + declarationType, + name, + vectorAccessor + ); + }); + builder.endControlFlow(); + } else { + var stateHasBlock = intermediateState.stream().anyMatch(AggregatorImplementer.IntermediateStateDesc::block); + requireStaticMethod( declarationType, - name, - vectorAccessor + requireVoidType(), + requireName("combineIntermediate"), + requireArgs( + Stream.of( + Stream.of(aggState.declaredType(), TypeName.INT), // aggState 
and groupId + intermediateState.stream().map(AggregatorImplementer.IntermediateStateDesc::combineArgType), + Stream.of(TypeName.INT).filter(p -> stateHasBlock) // position + ).flatMap(Function.identity()).map(Methods::requireType).toArray(Methods.TypeMatcher[]::new) + ) + ); + builder.addStatement( + "$T.combineIntermediate(state, groupId, " + + intermediateState.stream().map(desc -> desc.access("groupPosition + positionOffset")).collect(joining(", ")) + + (stateHasBlock ? ", groupPosition + positionOffset" : "") + + ")", + declarationType ); - }); + } builder.endControlFlow(); - } else { - var stateHasBlock = intermediateState.stream().anyMatch(AggregatorImplementer.IntermediateStateDesc::block); - requireStaticMethod( - declarationType, - requireVoidType(), - requireName("combineIntermediate"), - requireArgs( - Stream.of( - Stream.of(aggState.declaredType(), TypeName.INT), // aggState and groupId - intermediateState.stream().map(AggregatorImplementer.IntermediateStateDesc::combineArgType), - Stream.of(TypeName.INT).filter(p -> stateHasBlock) // position - ).flatMap(Function.identity()).map(Methods::requireType).toArray(Methods.TypeMatcher[]::new) - ) - ); - - builder.addStatement( - "$T.combineIntermediate(state, groupId, " - + intermediateState.stream().map(desc -> desc.access("groupPosition + positionOffset")).collect(joining(", ")) - + (stateHasBlock ? 
", groupPosition + positionOffset" : "") - + ")", - declarationType - ); } - builder.endControlFlow(); } return builder.build(); } diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregator.java index cb0dff8a86dc5..bf4772048ebfa 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregator.java @@ -80,13 +80,8 @@ public static void combine(GroupingState state, int groupId, BytesRef v) { state.addValue(groupId, v); } - public static void combineIntermediate(GroupingState state, int groupId, BytesRefBlock values, int valuesPosition) { - BytesRef scratch = new BytesRef(); - int start = values.getFirstValueIndex(valuesPosition); - int end = start + values.getValueCount(valuesPosition); - for (int i = start; i < end; i++) { - state.addValue(groupId, values.getBytesRef(i, scratch)); - } + public static void combineIntermediate(GroupingState state, int positionOffset, IntVector groups, BytesRefBlock values) { + ValuesBytesRefAggregators.combineIntermediateInputValues(state, positionOffset, groups, values); } public static void combineStates(GroupingState current, int currentGroupId, GroupingState state, int statePosition) { @@ -216,7 +211,7 @@ Block toBlock(BlockFactory blockFactory, IntVector selected) { } try (var sorted = buildSorted(selected)) { - if (OrdinalBytesRefBlock.isDense(selected.getPositionCount(), Math.toIntExact(values.size()))) { + if (OrdinalBytesRefBlock.isDense(selected.getPositionCount(), Math.toIntExact(bytes.size()))) { return buildOrdinalOutputBlock(blockFactory, selected, sorted.counts, sorted.ids); } else { return buildOutputBlock(blockFactory, selected, sorted.counts, 
sorted.ids); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesBytesRefGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesBytesRefGroupingAggregatorFunction.java index b73a017eacc7c..9ada159cb49de 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesBytesRefGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesBytesRefGroupingAggregatorFunction.java @@ -214,11 +214,7 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page return; } BytesRefBlock values = (BytesRefBlock) valuesUncast; - BytesRef scratch = new BytesRef(); - for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); - ValuesBytesRefAggregator.combineIntermediate(state, groupId, values, groupPosition + positionOffset); - } + ValuesBytesRefAggregator.combineIntermediate(state, positionOffset, groups,values); } @Override diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregators.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregators.java index 4a2fa0923abe4..1ada5241e7a94 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregators.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregators.java @@ -28,15 +28,7 @@ static GroupingAggregatorFunction.AddInput wrapAddInput( if (valuesOrdinal == null) { return delegate; } - BytesRefVector dict = valuesOrdinal.getDictionaryVector(); - final IntVector hashIds; - BytesRef spare = new BytesRef(); - try (var hashIdsBuilder = 
values.blockFactory().newIntVectorFixedBuilder(dict.getPositionCount())) { - for (int p = 0; p < dict.getPositionCount(); p++) { - hashIdsBuilder.appendInt(Math.toIntExact(BlockHash.hashOrdToGroup(state.bytes.add(dict.getBytesRef(p, spare))))); - } - hashIds = hashIdsBuilder.build(); - } + final IntVector hashIds = hashDict(state, valuesOrdinal.getDictionaryVector()); IntBlock ordinalIds = valuesOrdinal.getOrdinalsBlock(); return new GroupingAggregatorFunction.AddInput() { @Override @@ -85,17 +77,7 @@ public void add(int positionOffset, IntBigArrayBlock groupIds) { @Override public void add(int positionOffset, IntVector groupIds) { - for (int groupPosition = 0; groupPosition < groupIds.getPositionCount(); groupPosition++) { - int groupId = groupIds.getInt(groupPosition); - if (ordinalIds.isNull(groupPosition + positionOffset)) { - continue; - } - int valuesStart = ordinalIds.getFirstValueIndex(groupPosition + positionOffset); - int valuesEnd = valuesStart + ordinalIds.getValueCount(groupPosition + positionOffset); - for (int v = valuesStart; v < valuesEnd; v++) { - state.addValueOrdinal(groupId, hashIds.getInt(ordinalIds.getInt(v))); - } - } + addOrdinalInputBlock(state, positionOffset, groupIds, ordinalIds, hashIds); } @Override @@ -114,15 +96,7 @@ static GroupingAggregatorFunction.AddInput wrapAddInput( if (valuesOrdinal == null) { return delegate; } - BytesRefVector dict = valuesOrdinal.getDictionaryVector(); - final IntVector hashIds; - BytesRef spare = new BytesRef(); - try (var hashIdsBuilder = values.blockFactory().newIntVectorFixedBuilder(dict.getPositionCount())) { - for (int p = 0; p < dict.getPositionCount(); p++) { - hashIdsBuilder.appendInt(Math.toIntExact(BlockHash.hashOrdToGroup(state.bytes.add(dict.getBytesRef(p, spare))))); - } - hashIds = hashIdsBuilder.build(); - } + final IntVector hashIds = hashDict(state, valuesOrdinal.getDictionaryVector()); var ordinalIds = valuesOrdinal.getOrdinalsVector(); return new GroupingAggregatorFunction.AddInput() { 
@Override @@ -157,10 +131,7 @@ public void add(int positionOffset, IntBigArrayBlock groupIds) { @Override public void add(int positionOffset, IntVector groupIds) { - for (int groupPosition = 0; groupPosition < groupIds.getPositionCount(); groupPosition++) { - int groupId = groupIds.getInt(groupPosition); - state.addValueOrdinal(groupId, hashIds.getInt(ordinalIds.getInt(groupPosition + positionOffset))); - } + addOrdinalInputVector(state, positionOffset, groupIds, ordinalIds, hashIds); } @Override @@ -169,4 +140,86 @@ public void close() { } }; } + + static IntVector hashDict(ValuesBytesRefAggregator.GroupingState state, BytesRefVector dict) { + BytesRef scratch = new BytesRef(); + try (var hashIdsBuilder = dict.blockFactory().newIntVectorFixedBuilder(dict.getPositionCount())) { + for (int p = 0; p < dict.getPositionCount(); p++) { + final long hashId = BlockHash.hashOrdToGroup(state.bytes.add(dict.getBytesRef(p, scratch))); + hashIdsBuilder.appendInt(Math.toIntExact(hashId)); + } + return hashIdsBuilder.build(); + } + } + + static void addOrdinalInputBlock( + ValuesBytesRefAggregator.GroupingState state, + int positionOffset, + IntVector groupIds, + IntBlock ordinalIds, + IntVector hashIds + ) { + for (int p = 0; p < groupIds.getPositionCount(); p++) { + final int valuePosition = p + positionOffset; + final int groupId = groupIds.getInt(p); // group ids start at 0; only value positions are shifted by positionOffset + final int start = ordinalIds.getFirstValueIndex(valuePosition); + final int end = start + ordinalIds.getValueCount(valuePosition); + for (int i = start; i < end; i++) { + int ord = ordinalIds.getInt(i); + state.addValueOrdinal(groupId, hashIds.getInt(ord)); + } + } + } + + static void addOrdinalInputVector( + ValuesBytesRefAggregator.GroupingState state, + int positionOffset, + IntVector groupIds, + IntVector ordinalIds, + IntVector hashIds + ) { + for (int p = 0; p < groupIds.getPositionCount(); p++) { + int groupId = groupIds.getInt(p); + int ord = ordinalIds.getInt(p + positionOffset); + 
state.addValueOrdinal(groupId, hashIds.getInt(ord)); + } + } + + static void combineIntermediateInputValues( + ValuesBytesRefAggregator.GroupingState state, + int positionOffset, + IntVector groupIds, + BytesRefBlock values + ) { + BytesRefVector dict = null; + IntBlock ordinals = null; + { + final OrdinalBytesRefBlock asOrdinals = values.asOrdinals(); + if (asOrdinals != null) { + dict = asOrdinals.getDictionaryVector(); + ordinals = asOrdinals.getOrdinalsBlock(); + } + } + if (dict != null && dict.getPositionCount() < groupIds.getPositionCount()) { + try (var hashIds = hashDict(state, dict)) { + IntVector ordinalsVector = ordinals.asVector(); + if (ordinalsVector != null) { + addOrdinalInputVector(state, positionOffset, groupIds, ordinalsVector, hashIds); + } else { + addOrdinalInputBlock(state, positionOffset, groupIds, ordinals, hashIds); + } + } + } else { + final BytesRef scratch = new BytesRef(); + for (int p = 0; p < groupIds.getPositionCount(); p++) { + final int valuePosition = p + positionOffset; + final int groupId = groupIds.getInt(p); // group ids start at 0; only value positions are shifted by positionOffset + final int start = values.getFirstValueIndex(valuePosition); + final int end = start + values.getValueCount(valuePosition); + for (int i = start; i < end; i++) { + state.addValue(groupId, values.getBytesRef(i, scratch)); + } + } + } + } } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValuesAggregator.java.st b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValuesAggregator.java.st index fa8ffecea052d..8cb263da20813 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValuesAggregator.java.st +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValuesAggregator.java.st @@ -113,20 +113,20 @@ $endif$ state.addValue(groupId, v); } - public static void combineIntermediate(GroupingState state, int groupId, $Type$Block values, int valuesPosition) { 
$if(BytesRef)$ - BytesRef scratch = new BytesRef(); -$endif$ + public static void combineIntermediate(GroupingState state, int positionOffset, IntVector groups, $Type$Block values) { + ValuesBytesRefAggregators.combineIntermediateInputValues(state, positionOffset, groups, values); + } + +$else$ + public static void combineIntermediate(GroupingState state, int groupId, $Type$Block values, int valuesPosition) { int start = values.getFirstValueIndex(valuesPosition); int end = start + values.getValueCount(valuesPosition); for (int i = start; i < end; i++) { -$if(BytesRef)$ - state.addValue(groupId, values.getBytesRef(i, scratch)); -$else$ state.addValue(groupId, values.get$Type$(i)); -$endif$ } } +$endif$ public static void combineStates(GroupingState current, int currentGroupId, GroupingState state, int statePosition) { if (statePosition > state.maxGroupId) { @@ -327,7 +327,7 @@ $endif$ try (var sorted = buildSorted(selected)) { $if(BytesRef)$ - if (OrdinalBytesRefBlock.isDense(selected.getPositionCount(), Math.toIntExact(values.size()))) { + if (OrdinalBytesRefBlock.isDense(selected.getPositionCount(), Math.toIntExact(bytes.size()))) { return buildOrdinalOutputBlock(blockFactory, selected, sorted.counts, sorted.ids); } else { return buildOutputBlock(blockFactory, selected, sorted.counts, sorted.ids); From 747d56e85bdba72aa0c1d60a2ddddd83b5a0260d Mon Sep 17 00:00:00 2001 From: Nhat Nguyen Date: Wed, 16 Jul 2025 22:01:45 -0700 Subject: [PATCH 2/7] Update docs/changelog/131390.yaml --- docs/changelog/131390.yaml | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 docs/changelog/131390.yaml diff --git a/docs/changelog/131390.yaml b/docs/changelog/131390.yaml new file mode 100644 index 0000000000000..849adcb1a173a --- /dev/null +++ b/docs/changelog/131390.yaml @@ -0,0 +1,5 @@ +pr: 131390 +summary: Add optimized path for intermediate values aggregator +area: ES|QL +type: enhancement +issues: [] From a25e26e1fa78fa29ae42dcba6fa8c0af1b0f0b99 Mon Sep 17 00:00:00 2001 
From: Nhat Nguyen Date: Wed, 16 Jul 2025 22:21:47 -0700 Subject: [PATCH 3/7] extra check --- .../compute/operator/ValuesAggregatorBenchmark.java | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/ValuesAggregatorBenchmark.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/ValuesAggregatorBenchmark.java index 4bd33f2c3896a..def9c58160002 100644 --- a/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/ValuesAggregatorBenchmark.java +++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/ValuesAggregatorBenchmark.java @@ -355,10 +355,10 @@ private static void run(int groups, String dataType, int opCount) { for (int i = 0; i < opCount; i++) { finalAggregator.addInput(intermediatePage.shallowCopy()); } - finalAggregator.finish(); - Page outputPage = finalAggregator.getOutput(); - checkExpected(groups, dataType, outputPage); } + finalAggregator.finish(); + Page outputPage = finalAggregator.getOutput(); + checkExpected(groups, dataType, outputPage); } } From 87bc0503df775ac4a77ab5ecfa491f245bd205b2 Mon Sep 17 00:00:00 2001 From: Nhat Nguyen Date: Thu, 17 Jul 2025 11:44:04 -0700 Subject: [PATCH 4/7] fix dense --- .../compute/aggregation/ValuesBytesRefAggregator.java | 2 +- .../compute/aggregation/X-ValuesAggregator.java.st | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregator.java index cb6c4c7be3392..5d6111d0c172a 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregator.java @@ -194,7 
+194,7 @@ Block toBlock(BlockFactory blockFactory, IntVector selected) { } try (var sorted = buildSorted(selected)) { - if (OrdinalBytesRefBlock.isDense(selected.getPositionCount(), Math.toIntExact(bytes.size()))) { + if (OrdinalBytesRefBlock.isDense(values.size(), bytes.size())) { return buildOrdinalOutputBlock(blockFactory, selected, sorted.counts, sorted.ids); } else { return buildOutputBlock(blockFactory, selected, sorted.counts, sorted.ids); diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValuesAggregator.java.st b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValuesAggregator.java.st index 5364d50559377..0df843f5e4273 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValuesAggregator.java.st +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValuesAggregator.java.st @@ -304,7 +304,7 @@ $endif$ try (var sorted = buildSorted(selected)) { $if(BytesRef)$ - if (OrdinalBytesRefBlock.isDense(selected.getPositionCount(), Math.toIntExact(bytes.size()))) { + if (OrdinalBytesRefBlock.isDense(values.size(), bytes.size())) { return buildOrdinalOutputBlock(blockFactory, selected, sorted.counts, sorted.ids); } else { return buildOutputBlock(blockFactory, selected, sorted.counts, sorted.ids); From 8df8fb8bf40cb69e6ff91f3cb622d7f77286b10b Mon Sep 17 00:00:00 2001 From: Nhat Nguyen Date: Mon, 21 Jul 2025 09:16:00 -0700 Subject: [PATCH 5/7] format --- .../compute/gen/GroupingAggregatorImplementer.java | 6 ++++-- .../ValuesBytesRefGroupingAggregatorFunction.java | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/GroupingAggregatorImplementer.java b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/GroupingAggregatorImplementer.java index a32614fb6d6c2..3c67aa09ac988 100644 --- 
a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/GroupingAggregatorImplementer.java +++ b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/GroupingAggregatorImplementer.java @@ -620,8 +620,10 @@ private MethodSpec addIntermediateInput() { ) ); if (bulkCombineIntermediateMethod != null) { - var states = intermediateState.stream().map(AggregatorImplementer.IntermediateStateDesc::name).collect(Collectors.joining(",")); - builder.addStatement("$T.combineIntermediate(state, positionOffset, groups," + states + ")", declarationType); + var states = intermediateState.stream() + .map(AggregatorImplementer.IntermediateStateDesc::name) + .collect(Collectors.joining(", ")); + builder.addStatement("$T.combineIntermediate(state, positionOffset, groups, " + states + ")", declarationType); } else { if (intermediateState.stream() .map(AggregatorImplementer.IntermediateStateDesc::elementType) diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesBytesRefGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesBytesRefGroupingAggregatorFunction.java index a8ce79b296d3f..af9369da60b34 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesBytesRefGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesBytesRefGroupingAggregatorFunction.java @@ -214,7 +214,7 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page return; } BytesRefBlock values = (BytesRefBlock) valuesUncast; - ValuesBytesRefAggregator.combineIntermediate(state, positionOffset, groups,values); + ValuesBytesRefAggregator.combineIntermediate(state, positionOffset, groups, values); } @Override From faf02520d9962d3e6c98075274dd0ba2b55707fe Mon Sep 17 00:00:00 2001 From: Nhat Nguyen Date: Mon, 21 Jul 2025 
09:19:42 -0700 Subject: [PATCH 6/7] streams --- .../compute/gen/GroupingAggregatorImplementer.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/GroupingAggregatorImplementer.java b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/GroupingAggregatorImplementer.java index 3c67aa09ac988..6ce7e554b6c57 100644 --- a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/GroupingAggregatorImplementer.java +++ b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/GroupingAggregatorImplementer.java @@ -613,10 +613,10 @@ private MethodSpec addIntermediateInput() { requireVoidType(), requireName("combineIntermediate"), requireArgs( - Stream.of( + Stream.concat( Stream.of(aggState.declaredType(), TypeName.INT, INT_VECTOR), // aggState, positionOffset, groupIds intermediateState.stream().map(AggregatorImplementer.IntermediateStateDesc::combineArgType) - ).flatMap(Function.identity()).map(Methods::requireType).toArray(Methods.TypeMatcher[]::new) + ).map(Methods::requireType).toArray(Methods.TypeMatcher[]::new) ) ); if (bulkCombineIntermediateMethod != null) { From 684592804671de0ba97724d250667ff2201f7b02 Mon Sep 17 00:00:00 2001 From: Nhat Nguyen Date: Mon, 21 Jul 2025 10:19:13 -0700 Subject: [PATCH 7/7] fix after merges --- .../gen/GroupingAggregatorImplementer.java | 8 ++++-- .../aggregation/ValuesBytesRefAggregator.java | 4 +++ ...uesBytesRefGroupingAggregatorFunction.java | 26 ++--------------- .../ValuesBytesRefAggregators.java | 28 +++++++++++++++++++ .../aggregation/X-ValuesAggregator.java.st | 4 +++ 5 files changed, 44 insertions(+), 26 deletions(-) diff --git a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/GroupingAggregatorImplementer.java b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/GroupingAggregatorImplementer.java index 
c03ad8058375e..6042915f70aee 100644 --- a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/GroupingAggregatorImplementer.java +++ b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/GroupingAggregatorImplementer.java @@ -58,6 +58,7 @@ import static org.elasticsearch.compute.gen.Types.INTERMEDIATE_STATE_DESC; import static org.elasticsearch.compute.gen.Types.INT_ARRAY_BLOCK; import static org.elasticsearch.compute.gen.Types.INT_BIG_ARRAY_BLOCK; +import static org.elasticsearch.compute.gen.Types.INT_BLOCK; import static org.elasticsearch.compute.gen.Types.INT_VECTOR; import static org.elasticsearch.compute.gen.Types.LIST_AGG_FUNC_DESC; import static org.elasticsearch.compute.gen.Types.LIST_INTEGER; @@ -615,7 +616,8 @@ private MethodSpec addIntermediateInput(TypeName groupsType) { requireName("combineIntermediate"), requireArgs( Stream.concat( - Stream.of(aggState.declaredType(), TypeName.INT, INT_VECTOR), // aggState, positionOffset, groupIds + // aggState, positionOffset, groupIds + Stream.of(aggState.declaredType(), TypeName.INT, groupsIsBlock ? 
INT_BLOCK : INT_VECTOR), intermediateState.stream().map(AggregatorImplementer.IntermediateStateDesc::combineArgType) ).map(Methods::requireType).toArray(Methods.TypeMatcher[]::new) ) @@ -626,7 +628,9 @@ private MethodSpec addIntermediateInput(TypeName groupsType) { .collect(Collectors.joining(", ")); builder.addStatement("$T.combineIntermediate(state, positionOffset, groups, " + states + ")", declarationType); } else { - if (intermediateState.stream().map(AggregatorImplementer.IntermediateStateDesc::elementType).anyMatch(n -> n.equals("BYTES_REF"))) { + if (intermediateState.stream() + .map(AggregatorImplementer.IntermediateStateDesc::elementType) + .anyMatch(n -> n.equals("BYTES_REF"))) { builder.addStatement("$T scratch = new $T()", BYTES_REF, BYTES_REF); } builder.beginControlFlow("for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++)"); diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregator.java index 5d6111d0c172a..79077b6628105 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregator.java @@ -84,6 +84,10 @@ public static void combineIntermediate(GroupingState state, int positionOffset, ValuesBytesRefAggregators.combineIntermediateInputValues(state, positionOffset, groups, values); } + public static void combineIntermediate(GroupingState state, int positionOffset, IntBlock groups, BytesRefBlock values) { + ValuesBytesRefAggregators.combineIntermediateInputValues(state, positionOffset, groups, values); + } + public static Block evaluateFinal(GroupingState state, IntVector selected, DriverContext driverContext) { return 
state.toBlock(driverContext.blockFactory(), selected); } diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesBytesRefGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesBytesRefGroupingAggregatorFunction.java index 7868034555265..b76f52d335a03 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesBytesRefGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesBytesRefGroupingAggregatorFunction.java @@ -152,18 +152,7 @@ public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page return; } BytesRefBlock values = (BytesRefBlock) valuesUncast; - BytesRef scratch = new BytesRef(); - for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { - continue; - } - int groupStart = groups.getFirstValueIndex(groupPosition); - int groupEnd = groupStart + groups.getValueCount(groupPosition); - for (int g = groupStart; g < groupEnd; g++) { - int groupId = groups.getInt(g); - ValuesBytesRefAggregator.combineIntermediate(state, groupId, values, groupPosition + positionOffset); - } - } + ValuesBytesRefAggregator.combineIntermediate(state, positionOffset, groups, values); } private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefBlock values) { @@ -209,18 +198,7 @@ public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Pa return; } BytesRefBlock values = (BytesRefBlock) valuesUncast; - BytesRef scratch = new BytesRef(); - for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { - continue; - } - int groupStart = groups.getFirstValueIndex(groupPosition); - int groupEnd = groupStart + groups.getValueCount(groupPosition); - for (int g = 
groupStart; g < groupEnd; g++) { - int groupId = groups.getInt(g); - ValuesBytesRefAggregator.combineIntermediate(state, groupId, values, groupPosition + positionOffset); - } - } + ValuesBytesRefAggregator.combineIntermediate(state, positionOffset, groups, values); } private void addRawInput(int positionOffset, IntVector groups, BytesRefBlock values) { diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregators.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregators.java index 1ada5241e7a94..20ec7b08f9cb7 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregators.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregators.java @@ -222,4 +222,32 @@ static void combineIntermediateInputValues( } } } + + static void combineIntermediateInputValues( + ValuesBytesRefAggregator.GroupingState state, + int positionOffset, + IntBlock groupIds, + BytesRefBlock values + ) { + final BytesRef scratch = new BytesRef(); + for (int groupPosition = 0; groupPosition < groupIds.getPositionCount(); groupPosition++) { + if (groupIds.isNull(groupPosition)) { + continue; + } + int groupStart = groupIds.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groupIds.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + if (values.isNull(groupPosition + positionOffset)) { + continue; + } + int groupId = groupIds.getInt(g); + int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); + int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); + for (int v = valuesStart; v < valuesEnd; v++) { + var bytes = values.getBytesRef(v, scratch); + state.addValue(groupId, bytes); + } + } + } + } } diff --git 
a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValuesAggregator.java.st b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValuesAggregator.java.st index 0df843f5e4273..d92ac5fa0afce 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValuesAggregator.java.st +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValuesAggregator.java.st @@ -118,6 +118,10 @@ $if(BytesRef)$ ValuesBytesRefAggregators.combineIntermediateInputValues(state, positionOffset, groups, values); } + public static void combineIntermediate(GroupingState state, int positionOffset, IntBlock groups, $Type$Block values) { + ValuesBytesRefAggregators.combineIntermediateInputValues(state, positionOffset, groups, values); + } + $else$ public static void combineIntermediate(GroupingState state, int groupId, $Type$Block values, int valuesPosition) { int start = values.getFirstValueIndex(valuesPosition);