diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/DoubleArrayState.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/DoubleArrayState.java index 49d26dbb61aca..348c3b28c2d0f 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/DoubleArrayState.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/DoubleArrayState.java @@ -56,6 +56,12 @@ void set(int groupId, double value) { trackGroupId(groupId); } + void increment(int groupId, double value) { + ensureCapacity(groupId); + values.increment(groupId, value); + trackGroupId(groupId); + } + Block toValuesBlock(org.elasticsearch.compute.data.IntVector selected, DriverContext driverContext) { if (false == trackingGroupIds()) { try (var builder = driverContext.blockFactory().newDoubleVectorFixedBuilder(selected.getPositionCount())) { diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/CountApproximateAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/CountApproximateAggregatorFunction.java new file mode 100644 index 0000000000000..ed19391fa1995 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/CountApproximateAggregatorFunction.java @@ -0,0 +1,197 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.compute.aggregation; + +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BooleanBlock; +import org.elasticsearch.compute.data.BooleanVector; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.DoubleVector; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; + +import java.util.List; + +public class CountApproximateAggregatorFunction implements AggregatorFunction { + private static final List INTERMEDIATE_STATE_DESC = List.of( + new IntermediateStateDesc("count", ElementType.DOUBLE), + new IntermediateStateDesc("seen", ElementType.BOOLEAN) + ); + + public static List intermediateStateDesc() { + return INTERMEDIATE_STATE_DESC; + } + + private final DoubleState state; + private final List channels; + private final boolean countAll; + + public static CountApproximateAggregatorFunction create(List inputChannels) { + return new CountApproximateAggregatorFunction(inputChannels, new DoubleState(0)); + } + + protected CountApproximateAggregatorFunction(List channels, DoubleState state) { + this.channels = channels; + this.state = state; + // no channels specified means count-all/count(*) + this.countAll = channels.isEmpty(); + } + + @Override + public int intermediateBlockCount() { + return intermediateStateDesc().size(); + } + + private int blockIndex() { + // In case of countAll, block index is irrelevant. + // Page.positionCount should be used instead, + // because the page could have zero blocks + // (drop all columns scenario) + return countAll ? 
-1 : channels.get(0); + } + + @Override + public void addRawInput(Page page, BooleanVector mask) { + if (countAll) { + // this will work also when the page has no blocks + if (mask.isConstant() && mask.getBoolean(0)) { + state.doubleValue(state.doubleValue() + page.getPositionCount()); + } else { + int count = 0; + for (int i = 0; i < mask.getPositionCount(); i++) { + if (mask.getBoolean(i)) { + count++; + } + } + state.doubleValue(state.doubleValue() + count); + } + } else { + Block block = page.getBlock(blockIndex()); + DoubleState state = this.state; + int count; + if (mask.isConstant()) { + if (mask.getBoolean(0) == false) { + return; + } + count = getBlockTotalValueCount(block); + } else { + count = countMasked(block, mask); + } + state.doubleValue(state.doubleValue() + count); + } + } + + /** + * Returns the number of total values in a block + * @param block block to count values for + * @return number of total values present in the block + */ + protected int getBlockTotalValueCount(Block block) { + return block.getTotalValueCount(); + } + + private int countMasked(Block block, BooleanVector mask) { + int count = 0; + if (countAll) { + for (int p = 0; p < block.getPositionCount(); p++) { + if (mask.getBoolean(p)) { + count++; + } + } + return count; + } + for (int p = 0; p < block.getPositionCount(); p++) { + if (mask.getBoolean(p)) { + count += getBlockValueCountAtPosition(block, p); + } + } + return count; + } + + /** + * Returns the number of values at a given position in a block + * @param block block + * @param position position to get the number of values + * @return + */ + protected int getBlockValueCountAtPosition(Block block, int position) { + return block.getValueCount(position); + } + + @Override + public void addIntermediateInput(Page page) { + assert channels.size() == intermediateBlockCount(); + var blockIndex = blockIndex(); + assert page.getBlockCount() >= blockIndex + intermediateStateDesc().size(); + Block uncastBlock = 
page.getBlock(channels.get(0)); + if (uncastBlock.areAllValuesNull()) { + return; + } + DoubleVector count = page.getBlock(channels.get(0)).asVector(); + BooleanVector seen = page.getBlock(channels.get(1)).asVector(); + assert count.getPositionCount() == 1; + assert count.getPositionCount() == seen.getPositionCount(); + state.doubleValue(state.doubleValue() + count.getDouble(0)); + } + + @Override + public void evaluateIntermediate(Block[] blocks, int offset, DriverContext driverContext) { + state.toIntermediate(blocks, offset, driverContext); + } + + @Override + public void evaluateFinal(Block[] blocks, int offset, DriverContext driverContext) { + blocks[offset] = driverContext.blockFactory().newConstantDoubleBlockWith(state.doubleValue(), 1); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(this.getClass().getSimpleName()).append("["); + sb.append("channels=").append(channels); + sb.append("]"); + return sb.toString(); + } + + @Override + public void close() { + state.close(); + } + + public static AggregatorFunctionSupplier supplier() { + return new CountApproximateAggregatorFunctionSupplier(); + } + + protected static class CountApproximateAggregatorFunctionSupplier implements AggregatorFunctionSupplier { + @Override + public List nonGroupingIntermediateStateDesc() { + return CountApproximateAggregatorFunction.intermediateStateDesc(); + } + + @Override + public List groupingIntermediateStateDesc() { + return CountApproximateGroupingAggregatorFunction.intermediateStateDesc(); + } + + @Override + public AggregatorFunction aggregator(DriverContext driverContext, List channels) { + return CountApproximateAggregatorFunction.create(channels); + } + + @Override + public GroupingAggregatorFunction groupingAggregator(DriverContext driverContext, List channels) { + return CountApproximateGroupingAggregatorFunction.create(driverContext, channels); + } + + @Override + public String describe() { + return "count"; + } + } 
+} diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/CountApproximateGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/CountApproximateGroupingAggregatorFunction.java new file mode 100644 index 0000000000000..10caa4a5a11c2 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/CountApproximateGroupingAggregatorFunction.java @@ -0,0 +1,307 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.aggregation; + +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BooleanBlock; +import org.elasticsearch.compute.data.BooleanVector; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.DoubleVector; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.IntArrayBlock; +import org.elasticsearch.compute.data.IntBigArrayBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.data.Vector; +import org.elasticsearch.compute.operator.DriverContext; + +import java.util.List; + +public class CountApproximateGroupingAggregatorFunction implements GroupingAggregatorFunction { + + private static final List INTERMEDIATE_STATE_DESC = List.of( + new IntermediateStateDesc("count", ElementType.DOUBLE), + new IntermediateStateDesc("seen", ElementType.BOOLEAN) + ); + + private final DoubleArrayState state; + private final List channels; + private final DriverContext driverContext; + private final boolean countAll; + + public static CountApproximateGroupingAggregatorFunction create(DriverContext driverContext, List 
inputChannels) { + return new CountApproximateGroupingAggregatorFunction( + inputChannels, + new DoubleArrayState(driverContext.bigArrays(), 0), + driverContext + ); + } + + public static List intermediateStateDesc() { + return INTERMEDIATE_STATE_DESC; + } + + protected CountApproximateGroupingAggregatorFunction(List channels, DoubleArrayState state, DriverContext driverContext) { + this.channels = channels; + this.state = state; + this.driverContext = driverContext; + this.countAll = channels.isEmpty(); + } + + private int blockIndex() { + return countAll ? 0 : channels.get(0); + } + + @Override + public int intermediateBlockCount() { + return intermediateStateDesc().size(); + } + + @Override + public AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, Page page) { + Block valuesBlock = page.getBlock(blockIndex()); + if (countAll == false) { + Vector valuesVector = valuesBlock.asVector(); + if (valuesVector == null) { + return new AddInput() { + @Override + public void add(int positionOffset, IntArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, valuesBlock); + } + + @Override + public void add(int positionOffset, IntBigArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, valuesBlock); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, valuesBlock); + } + + @Override + public void close() {} + }; + } + } + return new AddInput() { + @Override + public void add(int positionOffset, IntArrayBlock groupIds) { + addRawInput(groupIds); + } + + @Override + public void add(int positionOffset, IntBigArrayBlock groupIds) { + addRawInput(groupIds); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(groupIds); + } + + @Override + public void close() {} + }; + } + + private void addRawInput(int positionOffset, IntVector groups, Block values) { + int position = positionOffset; + for (int groupPosition = 0; groupPosition < 
groups.getPositionCount(); groupPosition++, position++) { + if (values.isNull(position)) { + continue; + } + int groupId = groups.getInt(groupPosition); + state.increment(groupId, getBlockValueCountAtPosition(values, position)); + } + } + + /** + * Returns the number of values at a given position in a block + * @param values block + * @param position position to get the number of values + * @return + */ + protected int getBlockValueCountAtPosition(Block values, int position) { + return values.getValueCount(position); + } + + private void addRawInput(int positionOffset, IntArrayBlock groups, Block values) { + int position = positionOffset; + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++, position++) { + if (groups.isNull(groupPosition) || values.isNull(position)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + state.increment(groupId, getBlockValueCountAtPosition(values, position)); + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, Block values) { + int position = positionOffset; + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++, position++) { + if (groups.isNull(groupPosition) || values.isNull(position)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + state.increment(groupId, getBlockValueCountAtPosition(values, position)); + } + } + } + + /** + * This method is called for count all. 
+ */ + private void addRawInput(IntVector groups) { + if (groups.isConstant()) { + state.increment(groups.getInt(0), groups.getPositionCount()); + } else { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int groupId = groups.getInt(groupPosition); + state.increment(groupId, 1); + } + } + } + + /** + * This method is called for count all. + */ + private void addRawInput(IntArrayBlock groups) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + state.increment(groupId, 1); + } + } + } + + /** + * This method is called for count all. + */ + private void addRawInput(IntBigArrayBlock groups) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + state.increment(groupId, 1); + } + } + } + + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + // no need to track seen groups, as count returns 0 for groups without values. 
+ } + + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + assert channels.size() == intermediateBlockCount(); + assert page.getBlockCount() >= blockIndex() + intermediateStateDesc().size(); + DoubleVector count = page.getBlock(channels.get(0)).asVector(); + BooleanVector seen = page.getBlock(channels.get(1)).asVector(); + assert count.getPositionCount() == seen.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + state.increment(groupId, count.getDouble(groupPosition + positionOffset)); + } + } + } + + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + assert channels.size() == intermediateBlockCount(); + assert page.getBlockCount() >= blockIndex() + intermediateStateDesc().size(); + DoubleVector count = page.getBlock(channels.get(0)).asVector(); + BooleanVector seen = page.getBlock(channels.get(1)).asVector(); + assert count.getPositionCount() == seen.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + state.increment(groupId, count.getDouble(groupPosition + positionOffset)); + } + } + } + + @Override + public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { + assert channels.size() == intermediateBlockCount(); + assert page.getBlockCount() >= blockIndex() + intermediateStateDesc().size(); + 
DoubleVector count = page.getBlock(channels.get(0)).asVector(); + BooleanVector seen = page.getBlock(channels.get(1)).asVector(); + assert count.getPositionCount() == seen.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + state.increment(groups.getInt(groupPosition), count.getDouble(groupPosition + positionOffset)); + } + } + + @Override + public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { + try (var values = driverContext.blockFactory().newDoubleVectorFixedBuilder(selected.getPositionCount())) { + for (int i = 0; i < selected.getPositionCount(); i++) { + int si = selected.getInt(i); + values.appendDouble(state.getOrDefault(si)); + } + blocks[offset] = values.build().asBlock(); + // Unlike other aggregations, we return 0 for groups without values instead of null. + // Therefore, we can always return true for seen, and do not need to track seen groups. + blocks[offset + 1] = driverContext.blockFactory().newConstantBooleanBlockWith(true, selected.getPositionCount()); + } + } + + @Override + public void evaluateFinal(Block[] blocks, int offset, IntVector selected, GroupingAggregatorEvaluationContext evaluationContext) { + try (DoubleVector.Builder builder = evaluationContext.blockFactory().newDoubleVectorFixedBuilder(selected.getPositionCount())) { + for (int i = 0; i < selected.getPositionCount(); i++) { + int si = selected.getInt(i); + builder.appendDouble(state.getOrDefault(si)); + } + blocks[offset] = builder.build().asBlock(); + } + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(this.getClass().getSimpleName()).append("["); + sb.append("channels=").append(channels); + sb.append("]"); + return sb.toString(); + } + + @Override + public void close() { + state.close(); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ArrayState.java.st 
b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ArrayState.java.st index e49d6fcfad641..18dbda9fb5775 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ArrayState.java.st +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ArrayState.java.st @@ -83,8 +83,8 @@ $endif$ trackGroupId(groupId); } -$if(long)$ - void increment(int groupId, long value) { +$if(long||double)$ + void increment(int groupId, $type$ value) { ensureCapacity(groupId); values.increment(groupId, value); trackGroupId(groupId); diff --git a/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/generative/GenerativeForkRestTest.java b/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/generative/GenerativeForkRestTest.java index 4d1687656c62f..b9646bb3d15f3 100644 --- a/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/generative/GenerativeForkRestTest.java +++ b/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/generative/GenerativeForkRestTest.java @@ -15,7 +15,7 @@ import java.util.List; import static org.elasticsearch.xpack.esql.CsvTestUtils.loadCsvSpecValues; -import static org.elasticsearch.xpack.esql.action.EsqlCapabilities.Cap.APPROXIMATION_V2; +import static org.elasticsearch.xpack.esql.action.EsqlCapabilities.Cap.APPROXIMATION_V3; import static org.elasticsearch.xpack.esql.action.EsqlCapabilities.Cap.FORK_V9; import static org.elasticsearch.xpack.esql.action.EsqlCapabilities.Cap.METRICS_GROUP_BY_ALL; import static org.elasticsearch.xpack.esql.action.EsqlCapabilities.Cap.SUBQUERY_IN_FROM_COMMAND; @@ -79,7 +79,7 @@ protected void shouldSkipTest(String testName) throws IOException { assumeFalse( "Tests using query approximation are skipped since query approximation is not supported with FORK", - testCase.requiredCapabilities.contains(APPROXIMATION_V2.capabilityName()) + 
testCase.requiredCapabilities.contains(APPROXIMATION_V3.capabilityName()) ); assumeFalse( diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/approximation.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/approximation.csv-spec index eef8df736937d..d8d2ce37be8df 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/approximation.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/approximation.csv-spec @@ -5,7 +5,7 @@ // - sv: a single-valued field containing the values 1x1, 2x2, ..., 500x500 // (hence the 125,250 rows) // -// - mv: a multi-valued field containing the values 1x1, 2x2, ..., 1000x1000 +// - mv: a multi-valued field containing the values 1x1, 2x2, ..., 2000x2000 // (15-16 values per multi-valued field and 2,001,000 values in total) // // - sv_and_one_mv: same as "sv", except that one row has the field [x,0] instead @@ -21,8 +21,20 @@ // which would make the tests flaky. +No approximation +required_capability: approximation_v3 + +SET approximation=false\; +FROM many_numbers | STATS count=COUNT(*) +; + +count:long +125250 +; + + Exact total row count -required_capability: approximation_v2 +required_capability: approximation_v3 SET approximation={"rows":10000}\; FROM many_numbers | STATS count=COUNT(*) @@ -34,7 +46,7 @@ count:long | CONFIDENCE_INTERVAL(count):long | CERTIFIED(count):boolean Exact total single-valued field count -required_capability: approximation_v2 +required_capability: approximation_v3 SET approximation={"rows":10000}\; FROM many_numbers | STATS count=COUNT(sv) @@ -46,7 +58,7 @@ count:long | CONFIDENCE_INTERVAL(count):long | CERTIFIED(count):boolean Approximate total multi-valued field count -required_capability: approximation_v2 +required_capability: approximation_v3 SET approximation={"rows":10000}\; FROM many_numbers | STATS count=COUNT(mv) @@ -58,11 +70,10 @@ count:long | CONFIDENCE_INTERVAL(count):long | CERTIFIED(count):boolea Approximate stats on large single-valued 
data -required_capability: approximation_v2 +required_capability: approximation_v3 SET approximation={"rows":10000}\; -FROM many_numbers - | STATS count=COUNT(sv), avg=AVG(sv), sum=SUM(sv) +FROM many_numbers | STATS count=COUNT(sv), avg=AVG(sv), sum=SUM(sv) ; count:long | avg:double | sum:long | CONFIDENCE_INTERVAL(count):long | CERTIFIED(count):boolean | CONFIDENCE_INTERVAL(avg):double | CERTIFIED(avg):boolean | CONFIDENCE_INTERVAL(sum):long | CERTIFIED(sum):boolean @@ -71,10 +82,11 @@ count:long | avg:double | sum:long | CONFIDENCE_INTERVAL(count):lo Approximate stats on large multi-valued data -required_capability: approximation_v2 +required_capability: approximation_v3 SET approximation=true\; FROM many_numbers + | MV_EXPAND mv | STATS count=COUNT(mv), avg=AVG(mv), sum=SUM(mv) ; @@ -84,7 +96,7 @@ count:long | avg:double | sum:long | CONFIDENCE_INTERVAL(cou Exact stats on small single-valued data -required_capability: approximation_v2 +required_capability: approximation_v3 SET approximation={"rows":10000}\; FROM many_numbers @@ -98,7 +110,7 @@ count:long | avg:double | sum:long | CONFIDENCE_INTERVAL(count):long | CERTIFIED Exact stats on small multi-valued data -required_capability: approximation_v2 +required_capability: approximation_v3 SET approximation={"rows":10000}\; FROM many_numbers @@ -113,11 +125,10 @@ count:long | avg:double | sum:long | CONFIDENCE_INTERVAL(count):long | CERTIFIED Multiple total counts -required_capability: approximation_v2 +required_capability: approximation_v3 SET approximation={"rows":10000}\; -FROM many_numbers - | STATS count=COUNT(*), count2=COUNT("*"), countValue=COUNT(mv) +FROM many_numbers | STATS count=COUNT(*), count2=COUNT("*"), countValue=COUNT(mv) ; count:long | count2:long | countValue:long | CONFIDENCE_INTERVAL(count):long | CERTIFIED(count):boolean | CONFIDENCE_INTERVAL(count2):long| CERTIFIED(count2):boolean | CONFIDENCE_INTERVAL(countValue):long | CERTIFIED(countValue):boolean @@ -126,7 +137,7 @@ count:long | 
count2:long | countValue:long | CONFIDENCE_INTERVAL(count): Exact count with where on single-valued data -required_capability: approximation_v2 +required_capability: approximation_v3 SET approximation={"rows":10000}\; FROM many_numbers @@ -140,7 +151,7 @@ count:long | CONFIDENCE_INTERVAL(count):long | CERTIFIED(count):boolean Approximate stats with where on multi-valued data -required_capability: approximation_v2 +required_capability: approximation_v3 SET approximation={"rows":10000}\; FROM many_numbers @@ -149,13 +160,13 @@ FROM many_numbers | STATS count=COUNT(), avg=AVG(mv), sum=SUM(mv) ; -count:long | avg:double | sum:long | CONFIDENCE_INTERVAL(count):long | CERTIFIED(count):boolean | CONFIDENCE_INTERVAL(avg):double | CERTIFIED(avg):boolean | CONFIDENCE_INTERVAL(sum):long | CERTIFIED(sum):boolean -1000000..2000000 | 1545..1565 | 1500000000..3000000000 | [1000000..2000000,1000000..2000000] | {any} | [1545..1565,1545..1565] | {any} | [1500000000..3000000000,1500000000..3000000000] | {any} +count:long | avg:double | sum:long | CONFIDENCE_INTERVAL(count):long | CERTIFIED(count):boolean | CONFIDENCE_INTERVAL(avg):double | CERTIFIED(avg):boolean | CONFIDENCE_INTERVAL(sum):long | CERTIFIED(sum):boolean +800000..2200000 | 1525..1585 | 1300000000..3300000000 | [800000..2200000,800000..2200000] | {any} | [1525..1585,1525..1585] | {any} | [1300000000..3300000000,1300000000..3300000000] | {any} ; Approximate stats with stats where -required_capability: approximation_v2 +required_capability: approximation_v3 SET approximation={"rows":10000, "confidence_level":0.85}\; FROM many_numbers @@ -165,13 +176,13 @@ FROM many_numbers sum=SUM(mv) WHERE mv >= 1001 ; -count:long | avg:double | sum:long | CONFIDENCE_INTERVAL(count):long | CERTIFIED(count):boolean | CONFIDENCE_INTERVAL(avg):double | CERTIFIED(avg):boolean | CONFIDENCE_INTERVAL(sum):long | CERTIFIED(sum):boolean -1000000..2000000 | 1545..1565 | 1500000000..3000000000 | [1000000..2000000,1000000..2000000] | {any} | 
[1545..1565,1545..1565] | {any} | [1500000000..3000000000,1500000000..3000000000] | {any} +count:long | avg:double | sum:long | CONFIDENCE_INTERVAL(count):long | CERTIFIED(count):boolean | CONFIDENCE_INTERVAL(avg):double | CERTIFIED(avg):boolean | CONFIDENCE_INTERVAL(sum):long | CERTIFIED(sum):boolean +800000..2200000 | 1525..1585 | 1300000000..3300000000 | [800000..2200000,800000..2200000] | {any} | [1525..1585,1525..1585] | {any} | [1300000000..3300000000,1300000000..3300000000] | {any} ; Approximate stats with sample -required_capability: approximation_v2 +required_capability: approximation_v3 SET approximation={"rows":10000,"confidence_level":0.85}\; FROM many_numbers @@ -180,12 +191,12 @@ FROM many_numbers ; count:long | avg:double | sum:long | CONFIDENCE_INTERVAL(count):long | CERTIFIED(count):boolean | CONFIDENCE_INTERVAL(avg):double | CERTIFIED(avg):boolean | CONFIDENCE_INTERVAL(sum):long | CERTIFIED(sum):boolean -85000..115000 | 315..350 | 25000000..400000000 | [85000..115000,85000..115000] | {any} | [315..350,315..350] | {any} | [25000000..400000000,25000000..400000000] | {any} +70000..130000 | 310..360 | 20000000..450000000 | [70000..130000,70000..130000] | {any} | [310..360,310..360] | {any} | [20000000..450000000,20000000..450000000] | {any} ; Approximate stats with commands before stats -required_capability: approximation_v2 +required_capability: approximation_v3 SET approximation={"rows":10000}\; FROM many_numbers @@ -209,7 +220,7 @@ count:long | avg:double | sum:long | CONFIDENCE_INTERVAL(cou Approximate stats with commands after stats -required_capability: approximation_v2 +required_capability: approximation_v3 SET approximation={"rows":10000}\; FROM many_numbers @@ -227,7 +238,7 @@ avg:double | avg2:double | CONFIDENCE_INTERVAL(avg):double | CERTIFIED(avg):bool Approximate stats with dependent variables that have confidence interval -required_capability: approximation_v2 +required_capability: approximation_v3 SET approximation={"rows":10000}\; 
FROM many_numbers @@ -246,7 +257,7 @@ y:integer | plus1:double | CONFIDENCE_INTERVAL(plus1):double | CERTIFIED(p Approximate stats with dependent string variable -required_capability: approximation_v2 +required_capability: approximation_v3 SET approximation={"rows":10000}\; FROM many_numbers @@ -262,7 +273,7 @@ from_str:double Approximate stats with dependent multi-valued variable -required_capability: approximation_v2 +required_capability: approximation_v3 SET approximation={"rows":10000}\; FROM many_numbers @@ -277,7 +288,7 @@ sv:double Approximate stats by with zero variance -required_capability: approximation_v2 +required_capability: approximation_v3 SET approximation={"rows":100000}\; FROM many_numbers @@ -287,17 +298,17 @@ FROM many_numbers | LIMIT 5 ; -avg:double | median:double | one:double | mv:integer | CONFIDENCE_INTERVAL(avg):double | CERTIFIED(avg):boolean | CONFIDENCE_INTERVAL(median):double | CERTIFIED(median):boolean | CONFIDENCE_INTERVAL(one):double | CERTIFIED(one):boolean -2000 | 2000 | 1 | 2000 | [2000,2000] | {any} | [2000,2000] | {any} | [1,1] | {any} -1999 | 1999 | 1 | 1999 | [1999,1999] | {any} | [1999,1999] | {any} | [1,1] | {any} -1998 | 1998 | 1 | 1998 | [1998,1998] | {any} | [1998,1998] | {any} | [1,1] | {any} -1997 | 1997 | 1 | 1997 | [1997,1997] | {any} | [1997,1997] | {any} | [1,1] | {any} -1996 | 1996 | 1 | 1996 | [1996,1996] | {any} | [1996,1996] | {any} | [1,1] | {any} +avg:double | median:double | one:double | mv:integer | CONFIDENCE_INTERVAL(avg):double | CERTIFIED(avg):boolean | CONFIDENCE_INTERVAL(median):double | CERTIFIED(median):boolean | CONFIDENCE_INTERVAL(one):double | CERTIFIED(one):boolean +1999.999999..2000.000001 | 2000 | 1 | 2000 | [1999.999999..2000.000001,1999.999999..2000.000001] | {any} | [2000,2000] | true | [1,1] | true +1998.999999..1999.000001 | 1999 | 1 | 1999 | [1998.999999..1999.000001,1998.999999..1999.000001] | {any} | [1999,1999] | true | [1,1] | true +1997.999999..1998.000001 | 1998 | 1 | 1998 | 
[1997.999999..1998.000001,1997.999999..1998.000001] | {any} | [1998,1998] | true | [1,1] | true +1996.999999..1997.000001 | 1997 | 1 | 1997 | [1996.999999..1997.000001,1996.999999..1997.000001] | {any} | [1997,1997] | true | [1,1] | true +1995.999999..1996.000001 | 1996 | 1 | 1996 | [1995.999999..1996.000001,1995.999999..1996.000001] | {any} | [1996,1996] | true | [1,1] | true ; Approximate stats by on large single-valued data -required_capability: approximation_v2 +required_capability: approximation_v3 SET approximation={"rows":10000}\; FROM many_numbers @@ -316,7 +327,7 @@ count:long | sv:integer | CONFIDENCE_INTERVAL(count):long | CERTIFIED(count):boo Approximate stats by on large multi-valued data -required_capability: approximation_v2 +required_capability: approximation_v3 SET approximation={"rows":100000}\; FROM many_numbers @@ -336,7 +347,7 @@ count:long | mv:integer | CONFIDENCE_INTERVAL(count):long | CERTIFIED(count):boo Exact stats by on small single-valued data -required_capability: approximation_v2 +required_capability: approximation_v3 SET approximation={"rows":10000}\; FROM many_numbers @@ -355,7 +366,7 @@ count:long | sv:integer | CONFIDENCE_INTERVAL(count):long | CERTIFIED(count):boo Exact stats by on small multi-valued data -required_capability: approximation_v2 +required_capability: approximation_v3 SET approximation={"rows":10000}\; FROM many_numbers @@ -374,18 +385,11 @@ count:long | mv:integer | CONFIDENCE_INTERVAL(count):long | CERTIFIED(count):boo ; -// This test fails in a multi-node setup: on one node, the field "sv_and_one_mv" -// is multi-valued, and on the other nodes it is single-valued. This leads to -// wrong sample correction on the coordinator node. -// TODO: fix this issue by moving the sample correction logic to the data nodes, -// and re-enable this test. 
- -Approximate stats by on mixed data-Ignore -required_capability: approximation_v2 +Approximate stats by on mixed data +required_capability: approximation_v3 SET approximation={"rows":10000}\; -FROM many_numbers - | STATS count=COUNT(sv_and_one_mv) +FROM many_numbers | STATS count=COUNT(sv_and_one_mv) ; count:long | CONFIDENCE_INTERVAL(count):long | CERTIFIED(count):boolean diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java index cf4d343807388..ca05ae070a42b 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java @@ -2086,7 +2086,7 @@ public enum Cap { /** * Support query approximation. */ - APPROXIMATION_V2(Build.current().isSnapshot()), + APPROXIMATION_V3(Build.current().isSnapshot()), /** * Create a ScoreOperator only when shard contexts are available diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/approximation/Approximation.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/approximation/Approximation.java index 7973cea503fb1..e985b83d6c118 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/approximation/Approximation.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/approximation/Approximation.java @@ -9,6 +9,7 @@ import org.apache.lucene.util.SetOnce; import org.elasticsearch.common.util.set.Sets; +import org.elasticsearch.compute.data.DoubleBlock; import org.elasticsearch.compute.data.LongBlock; import org.elasticsearch.index.IndexMode; import org.elasticsearch.logging.LogManager; @@ -16,6 +17,7 @@ import org.elasticsearch.xpack.esql.VerificationException; import org.elasticsearch.xpack.esql.common.Failure; import org.elasticsearch.xpack.esql.core.expression.Alias; +import 
org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.Literal; import org.elasticsearch.xpack.esql.core.expression.NamedExpression; import org.elasticsearch.xpack.esql.core.tree.Source; @@ -24,6 +26,7 @@ import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction; import org.elasticsearch.xpack.esql.expression.function.aggregate.Avg; import org.elasticsearch.xpack.esql.expression.function.aggregate.Count; +import org.elasticsearch.xpack.esql.expression.function.aggregate.CountApproximate; import org.elasticsearch.xpack.esql.expression.function.aggregate.Median; import org.elasticsearch.xpack.esql.expression.function.aggregate.MedianAbsoluteDeviation; import org.elasticsearch.xpack.esql.expression.function.aggregate.Percentile; @@ -213,6 +216,7 @@ public record QueryProperties(boolean hasGrouping, boolean canDecreaseRowCount, static final Set> SUPPORTED_SINGLE_VALUED_AGGS = Set.of( Avg.class, Count.class, + CountApproximate.class, Median.class, MedianAbsoluteDeviation.class, Percentile.class, @@ -258,7 +262,9 @@ public record QueryProperties(boolean hasGrouping, boolean canDecreaseRowCount, private static final Logger logger = LogManager.getLogger(Approximation.class); - private static final AggregateFunction COUNT_ALL_ROWS = new Count(Source.EMPTY, Literal.keyword(Source.EMPTY, StringUtils.WILDCARD)); + private static final Expression WILDCARD = Literal.keyword(Source.EMPTY, StringUtils.WILDCARD); + private static final AggregateFunction COUNT_ALL_ROWS_EXACT = new Count(Source.EMPTY, WILDCARD); + private static final AggregateFunction COUNT_ALL_ROWS_APPROXIMATE = new CountApproximate(Source.EMPTY, WILDCARD); private final LogicalPlan logicalPlan; private final ApproximationSettings settings; @@ -378,7 +384,7 @@ private LogicalPlan sourceCountSubPlan() { Source.EMPTY, leaf, List.of(), - List.of(new Alias(Source.EMPTY, "$source_count", COUNT_ALL_ROWS)) + List.of(new Alias(Source.EMPTY, 
"$source_count", COUNT_ALL_ROWS_EXACT)) ); sourceCountPlan.setOptimized(); return sourceCountPlan; @@ -430,10 +436,13 @@ private LogicalPlan countSubPlan(double sampleProbability) { if (plan instanceof Aggregate aggregate) { // The STATS function should be replaced by a STATS COUNT(*). encounteredStats.set(true); - List aggregations = List.of(new Alias(Source.EMPTY, "$count_p=" + sampleProbability, COUNT_ALL_ROWS)); if (sampleProbability == 1.0) { + List aggregations = List.of(new Alias(Source.EMPTY, "$count_p=1", COUNT_ALL_ROWS_EXACT)); plan = new Aggregate(Source.EMPTY, aggregate.child(), List.of(), aggregations); } else { + List aggregations = List.of( + new Alias(Source.EMPTY, "$count_p=" + sampleProbability, COUNT_ALL_ROWS_APPROXIMATE) + ); plan = new SampledAggregate( Source.EMPTY, aggregate.child(), @@ -508,6 +517,9 @@ private LogicalPlan processCount(long rowCount) { throw new IllegalStateException("Approximation count iteration limit exceeded"); } double sampleProbability = nextSubPlanSampleProbability; + // The row count is sample-corrected, however here we want the actual + // (not-corrected) number of rows reaching the STATS. 
+ rowCount = Math.round(sampleProbability * rowCount); logger.debug("estimated number of rows reaching STATS (p=[{}]): [{}] rows", sampleProbability, rowCount); double newSampleProbability = Math.min(1.0, sampleProbability * sampleRowCount() / Math.max(1, rowCount)); if (newSampleProbability > SAMPLE_PROBABILITY_THRESHOLD) { @@ -547,7 +559,11 @@ private long rowCount(Result countResult) { assert countResult.pages().getFirst().getBlockCount() == 1; assert countResult.pages().getFirst().getPositionCount() == 1; - long rowCount = ((LongBlock) (countResult.pages().getFirst().getBlock(0))).getLong(0); + long rowCount = switch (countResult.pages().getFirst().getBlock(0)) { + case DoubleBlock doubleBlock -> Math.round(doubleBlock.getDouble(0)); + case LongBlock longBlock -> longBlock.getLong(0); + default -> throw new IllegalStateException("Unexpected value: " + countResult.pages().getFirst().getBlock(0)); + }; countResult.pages().getFirst().close(); return rowCount; } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/approximation/ApproximationPlan.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/approximation/ApproximationPlan.java index 0113bf0abc99a..41266dde268c7 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/approximation/ApproximationPlan.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/approximation/ApproximationPlan.java @@ -13,7 +13,6 @@ import org.elasticsearch.xpack.esql.core.expression.Alias; import org.elasticsearch.xpack.esql.core.expression.Attribute; import org.elasticsearch.xpack.esql.core.expression.Expression; -import org.elasticsearch.xpack.esql.core.expression.Expressions; import org.elasticsearch.xpack.esql.core.expression.LeafExpression; import org.elasticsearch.xpack.esql.core.expression.Literal; import org.elasticsearch.xpack.esql.core.expression.NameId; @@ -26,6 +25,7 @@ import org.elasticsearch.xpack.esql.core.util.StringUtils; import 
org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction; import org.elasticsearch.xpack.esql.expression.function.aggregate.Count; +import org.elasticsearch.xpack.esql.expression.function.aggregate.CountApproximate; import org.elasticsearch.xpack.esql.expression.function.aggregate.Sum; import org.elasticsearch.xpack.esql.expression.function.scalar.EsqlScalarFunction; import org.elasticsearch.xpack.esql.expression.function.scalar.approximate.ConfidenceInterval; @@ -36,12 +36,11 @@ import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToLong; import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvAppend; import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvSlice; -import org.elasticsearch.xpack.esql.expression.function.scalar.nulls.Coalesce; import org.elasticsearch.xpack.esql.expression.predicate.logical.And; import org.elasticsearch.xpack.esql.expression.predicate.logical.Or; import org.elasticsearch.xpack.esql.expression.predicate.nulls.IsNotNull; -import org.elasticsearch.xpack.esql.expression.predicate.nulls.IsNull; import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Div; +import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Mul; import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.Equals; import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.GreaterThanOrEqual; import org.elasticsearch.xpack.esql.optimizer.rules.logical.SubstituteApproximationPlan; @@ -84,7 +83,7 @@ public class ApproximationPlan { /** * The number of buckets to use for computing confidence intervals. */ - static final int BUCKET_COUNT = 16; + public static final int BUCKET_COUNT = 16; /** * Default confidence level for confidence intervals. 
@@ -101,13 +100,6 @@ public class ApproximationPlan { */ static final int MIN_ROW_COUNT_FOR_RESULT_INCLUSION = 10; - /** - * These aggregate functions need to be corrected for random sampling, by - * scaling up the sampled value by the inverse of the sampling probability. - * Other aggregate functions do not need any correction. - */ - private static final Set> SAMPLE_CORRECTED_AGGS = Set.of(Count.class, Sum.class); - /** * These numerical scalar functions produce multivalued output. This means that * confidence intervals cannot be computed anymore and are dropped. @@ -212,12 +204,12 @@ public int hashCode() { * | EVAL bucketId = MV_APPEND(RANDOM(B), ... , RANDOM(B)) // T times * | SAMPLED_STATS[SampleProbabilityPlaceHolder] * sampleSize = COUNT(*), - * s = SUM(x) / prob, - * `s$0` = SUM(x) / (prob/B)) WHERE MV_SLICE(bucketId, 0, 0) == 0 + * s = SUM(x), + * `s$0` = SUM(x) WHERE MV_SLICE(bucketId, 0, 0) == 0 * ..., - * `s$T*B-1` = SUM(x) / (prob/B) WHERE MV_SLICE(bucketId, T-1, T-1) == B-1 + * `s$T*B-1` = SUM(x) WHERE MV_SLICE(bucketId, T-1, T-1) == B-1 * BY group - * | WHERE sampleSize >= sampleSizeThreshold + * | WHERE sampleSize >= MIN_ROW_COUNT_FOR_RESULT_INCLUSION / prob * | EVAL t = s*s, `t$0` = `s$0`*`s$0`, ..., `t$T*B-1` = `s$T*B-1`*`s$T*B-1` * | EVAL `CONFIDENCE_INTERVAL(s)` = CONFIDENCE_INTERVAL(s, MV_APPEND(`s$0`, ... `s$T*B-1`), T, B, 0.90), * `CONFIDENCE_INTERVAL(t)` = CONFIDENCE_INTERVAL(t, MV_APPEND(`t$0`, ... `t$T*B-1`), T, B, 0.90) @@ -225,8 +217,8 @@ public int hashCode() { * } * * During execution the {@code SAMPLED_STATS} is replaced on the data node by either - * sampling the source rows and a normal {@code STATS}, or pushed down to Lucene without - * any sampling if that's possible (which would be more efficient if it is). + * sampling the source rows and a normal {@code STATS} (with sample corrections applied + * to intermediate state), or pushed down to Lucene without any sampling (if possible). 
*/ public static LogicalPlan get(LogicalPlan logicalPlan, ApproximationSettings settings) { logger.debug("generating approximation plan"); @@ -237,15 +229,14 @@ public static LogicalPlan get(LogicalPlan logicalPlan, ApproximationSettings set // The keys are the IDs of the fields that have buckets. Confidence intervals are computed // for these fields at the end of the computation. They map to the list of buckets for // that field. - Map> fieldBuckets = new HashMap<>(); + Map> fieldBuckets = new HashMap<>(); - // For each sample-corrected expression, also keep track of the uncorrected expression. - // These are used when a division between two sample-corrected expressions is encountered. - // This results in the same value (because (expr1/prob) / (expr2/prob) == expr1/expr2), - // except that no round-off errors occur if either the numerator or denominator is an - // integer and rounded to that after sample-correction. The most common case is AVG, which + // For each rounded expression, also keep track of the not rounded expression. + // These are used when a division between two rounded expressions is encountered. + // This results in more accurate values, because no round-off errors occur in + // the numerator and denominator. The most common use case is AVG, which // is rewritten to AVG::double = SUM::double / COUNT::long. - Map uncorrectedExpressions = new HashMap<>(); + Map notRoundedExpressions = new HashMap<>(); LogicalPlan approximationPlan = logicalPlan.transformUp(plan -> { if (encounteredStats.get() == false) { @@ -253,16 +244,16 @@ public static LogicalPlan get(LogicalPlan logicalPlan, ApproximationSettings set // Commands before the first STATS function should be left unchanged. return plan; } else { - // The first STATS function should be replaced by a sample-corrected STATS - // and buckets (for computing confidence intervals). + // The first STATS function should be replaced by a STATS with buckets + // (for computing confidence intervals). 
encounteredStats.set(true); - return sampleCorrectedAggregateAndBuckets((Aggregate) plan, fieldBuckets, uncorrectedExpressions); + return sampleCorrectedAggregateAndBuckets((Aggregate) plan, fieldBuckets, notRoundedExpressions); } } else { // After the STATS function, any processing of fields that have buckets, should // also process the buckets, so that confidence intervals for the dependent fields // can be computed. - return planIncludingBuckets(plan, fieldBuckets, uncorrectedExpressions); + return planIncludingBuckets(plan, fieldBuckets, notRoundedExpressions); } }); @@ -273,8 +264,8 @@ public static LogicalPlan get(LogicalPlan logicalPlan, ApproximationSettings set // Drop all bucket fields and uncorrected fields from the output. Set dropAttributes = Stream.concat( fieldBuckets.values().stream().flatMap(List::stream), - uncorrectedExpressions.values().stream() - ).map(NamedExpression::toAttribute).collect(Collectors.toSet()); + notRoundedExpressions.values().stream() + ).collect(Collectors.toSet()); List keepAttributes = new ArrayList<>(approximationPlan.output()); keepAttributes.removeAll(dropAttributes); @@ -284,8 +275,10 @@ public static LogicalPlan get(LogicalPlan logicalPlan, ApproximationSettings set } /** - * Replaces the aggregate by a sample-corrected aggregate and buckets, and - * filters out groups with a too small sample size. This means that: + * Replaces the aggregate by an aggregate with buckets (for confidence intervals), + * and filters out groups with a too small sample size. + *

+ * This means that: *

      *     {@code
      *          STATS s = SUM(x) BY group
@@ -298,21 +291,19 @@ public static LogicalPlan get(LogicalPlan logicalPlan, ApproximationSettings set
      *                s = SUM(x),
      *                `s$0` = SUM(x) WHERE MV_SLICE(bucketId, 0, 0) == 0
      *                ...,
-     *                `s$T*B-1` = SUM(x) / (prob/B) WHERE MV_SLICE(bucketId, T-1, T-1) == B-1
+     *                `s$T*B-1` = SUM(x) WHERE MV_SLICE(bucketId, T-1, T-1) == B-1
      *          BY group
-     *          | WHERE sampleSize >= MIN_ROW_COUNT_FOR_RESULT_INCLUSION
-     *          | EVAL s = s / prob, `s$0` = `s$0` / (prob/B), `s$T*B-1` = `s$T*B-1` / (prob/B)
+     *          | WHERE sampleSize >= MIN_ROW_COUNT_FOR_RESULT_INCLUSION / prob
      *          | DROP sampleSize
      *      }
      * 
*/ private static LogicalPlan sampleCorrectedAggregateAndBuckets( Aggregate aggregate, - Map> fieldBuckets, - Map uncorrectedExpressions + Map> fieldBuckets, + Map notRoundedExpressions ) { Expression sampleProbability = new SampleProbabilityPlaceHolder(Source.EMPTY); - Expression bucketSampleProbability = new Div(Source.EMPTY, sampleProbability, Literal.integer(Source.EMPTY, BUCKET_COUNT)); Expression randomBucketId = new Random(Source.EMPTY, Literal.integer(Source.EMPTY, BUCKET_COUNT)); Expression bucketIds = randomBucketId; @@ -325,12 +316,14 @@ private static LogicalPlan sampleCorrectedAggregateAndBuckets( // The aggregate functions in the approximation plan. List bucketAggregates = new ArrayList<>(); + // List of expression that must be evaluated before the sampled aggregation. + // For integer SUMs, the field must be cast to double before the aggregation. + List preEvals = new ArrayList<>(); // List of expressions that must be evaluated after the sampled aggregation. - // These consist of: - // - sample corrections (to correct counts/sums for sampling) - // - replace zero counts by NULLs (for confidence interval computation) - // - exact total row count if COUNT(*) is used (to avoid sampling errors there) - List evals = new ArrayList<>(); + // For COUNT, zeroes must be replaced by NULLs for the confidence interval computation. + // For COUNT and integer SUMs, the sample-corrected double result must be + // rounded and cast back to the original integer type. 
+ List postEvals = new ArrayList<>(); List originalAggregates = new ArrayList<>(); List projections = new ArrayList<>(); @@ -342,14 +335,40 @@ private static LogicalPlan sampleCorrectedAggregateAndBuckets( continue; } - Alias aggAlias = (Alias) aggOrKey; - AggregateFunction aggFn = (AggregateFunction) aggAlias.child(); + Alias agg = (Alias) aggOrKey; + AggregateFunction aggFn = (AggregateFunction) agg.child(); + + // Double-precision version of the aggregate function if needed, so that + // sample correction (dividing by the sample probability) on data nodes + // stays in floating point and avoids round-off errors from integer + // truncation. + boolean needsTruncation; + if (aggFn instanceof Count count) { + aggFn = new CountApproximate(count.source(), count.field(), count.filter(), count.window()); + needsTruncation = true; + } else if (aggFn instanceof Sum sum && sum.dataType().isWholeNumber()) { + Alias doubleField = new Alias(Source.EMPTY, agg.name() + "$double", new ToDouble(Source.EMPTY, sum.field())); + preEvals.add(doubleField); + aggFn = new Sum(sum.source(), doubleField.toAttribute(), sum.filter(), sum.window(), sum.summationMode()); + needsTruncation = true; + } else { + needsTruncation = false; + } + + projections.add(agg.toAttribute()); + if (needsTruncation) { + Alias approxAgg = new Alias(Source.EMPTY, agg.name() + "$approx", aggFn); + notRoundedExpressions.put(agg.id(), approxAgg.toAttribute()); + postEvals.add(agg.replaceChild(new ToLong(Source.EMPTY, approxAgg.toAttribute()))); + agg = approxAgg; + } + originalAggregates.add(agg); if (Approximation.SUPPORTED_SINGLE_VALUED_AGGS.contains(aggFn.getClass())) { // For the supported single-valued aggregations, add buckets with sampled // values, that will be used to compute a confidence interval. // For multivalued aggregations, confidence intervals do not make sense. 
- List buckets = new ArrayList<>(); + List buckets = new ArrayList<>(); for (int trialId = 0; trialId < TRIAL_COUNT; trialId++) { for (int bucketId = 0; bucketId < BUCKET_COUNT; bucketId++) { Expression bucketIdFilter = new Equals( @@ -362,129 +381,89 @@ private static LogicalPlan sampleCorrectedAggregateAndBuckets( ), Literal.integer(Source.EMPTY, bucketId) ); - Expression bucket = aggFn.withFilter( - aggFn.hasFilter() == false ? bucketIdFilter : new And(Source.EMPTY, aggFn.filter(), bucketIdFilter) - ); - Alias bucketAlias = new Alias( + Alias bucket = new Alias( Source.EMPTY, aggOrKey.name() + "$bucket$" + (trialId * BUCKET_COUNT + bucketId), - bucket + aggFn.withFilter( + aggFn.hasFilter() == false ? bucketIdFilter : new And(Source.EMPTY, aggFn.filter(), bucketIdFilter) + ) ); - if (SAMPLE_CORRECTED_AGGS.contains(aggFn.getClass()) == false) { - buckets.add(bucketAlias); - bucketAggregates.add(bucketAlias); - projections.add(bucketAlias.toAttribute()); - } else { - Alias uncorrectedBucketAlias = new Alias(Source.EMPTY, bucketAlias.name() + "$uncorrected", bucket); - uncorrectedExpressions.put(bucketAlias.id(), uncorrectedBucketAlias); - bucketAggregates.add(uncorrectedBucketAlias); - projections.add(uncorrectedBucketAlias.toAttribute()); - - Expression uncorrectedBucket = uncorrectedBucketAlias.toAttribute(); - if (aggFn instanceof Count) { - // For COUNT, no data should result in NULL, like in other aggregations. - // Otherwise, the confidence interval computation breaks. - uncorrectedBucket = new Case( + + if (needsTruncation) { + Alias approxBucket = new Alias(Source.EMPTY, bucket.name() + "$approx", bucket.child()); + notRoundedExpressions.put(bucket.id(), approxBucket.toAttribute()); + postEvals.add(bucket.replaceChild(new ToLong(Source.EMPTY, approxBucket.toAttribute()))); + bucket = approxBucket; + } + bucketAggregates.add(bucket); + + if (aggFn instanceof Count) { + // COUNT returns 0 for no data, but confidence computation needs NULL. 
+ bucket = new Alias( + Source.EMPTY, + bucket.name(), + new Case( Source.EMPTY, - new Equals(Source.EMPTY, uncorrectedBucket, Literal.fromLong(Source.EMPTY, 0L)), - List.of(Literal.NULL, uncorrectedBucket) - ); - } - - Expression correctedBucket = correctForSampling(uncorrectedBucket, bucketSampleProbability, null); - Alias correctedBucketAlias = bucketAlias.replaceChild(correctedBucket); - evals.add(correctedBucketAlias); - projections.add(correctedBucketAlias.toAttribute()); - buckets.add(correctedBucketAlias); + new Equals(Source.EMPTY, bucket.toAttribute(), Literal.fromDouble(Source.EMPTY, 0.0)), + List.of(Literal.NULL, bucket.toAttribute()) + ) + ); + postEvals.add(bucket); } + buckets.add(bucket.toAttribute()); + projections.add(bucket.toAttribute()); } } fieldBuckets.put(aggOrKey.id(), buckets); } - - // Replace the original aggregation by a sample-corrected one if needed. - if (SAMPLE_CORRECTED_AGGS.contains(aggFn.getClass()) == false) { - originalAggregates.add(aggAlias); - projections.add(aggAlias.toAttribute()); - } else { - Alias uncorrectedAggAlias = new Alias(aggAlias.source(), aggAlias.name() + "$uncorrected", aggFn); - uncorrectedExpressions.put(aggAlias.id(), uncorrectedAggAlias); - originalAggregates.add(uncorrectedAggAlias); - projections.add(uncorrectedAggAlias.toAttribute()); - - Expression correctedAgg = correctForSampling( - uncorrectedAggAlias.toAttribute(), - sampleProbability, - fieldBuckets.get(aggOrKey.id()) - ); - evals.add(aggAlias.replaceChild(correctedAgg)); - projections.add(aggAlias.toAttribute()); - } } - List aggregates = Stream.concat(originalAggregates.stream(), bucketAggregates.stream()) - .collect(Collectors.toList()); + List aggregates = new ArrayList<>(originalAggregates); + aggregates.addAll(bucketAggregates); Alias sampleSize = null; if (aggregate.groupings().isEmpty() == false) { // Add the sample size per grouping to filter out groups with too few sampled rows. 
- sampleSize = new Alias(Source.EMPTY, "$sample_size", COUNT_ALL_ROWS); + sampleSize = new Alias( + Source.EMPTY, + "$sample_size", + new CountApproximate(Source.EMPTY, Literal.keyword(Source.EMPTY, StringUtils.WILDCARD)) + ); aggregates.add(sampleSize); originalAggregates.add(sampleSize); } - // Add the bucket ID, do the aggregations (sampled corrected, including the buckets), + // Add the bucket ID, do the aggregations (including the buckets), // and filter out rows with too few sampled values. LogicalPlan plan = new Eval(Source.EMPTY, aggregate.child(), List.of(bucketIdField)); + if (preEvals.isEmpty() == false) { + plan = new Eval(Source.EMPTY, plan, preEvals); + } plan = new SampledAggregate(aggregate.source(), plan, aggregate.groupings(), aggregates, originalAggregates, sampleProbability); if (sampleSize != null) { - List allBuckets = Expressions.asAttributes(bucketAggregates); + // The sampleSize is sampled-corrected, so we have to multiply by the sample + // probability to get the actual number of sampled rows. plan = new Filter( Source.EMPTY, plan, new Or( Source.EMPTY, - new IsNull(Source.EMPTY, new Coalesce(Source.EMPTY, allBuckets.getFirst(), allBuckets.subList(1, allBuckets.size()))), + new Equals(Source.EMPTY, sampleProbability, Literal.fromDouble(Source.EMPTY, 1.0)), new GreaterThanOrEqual( Source.EMPTY, - sampleSize.toAttribute(), - Literal.integer(Source.EMPTY, MIN_ROW_COUNT_FOR_RESULT_INCLUSION) + new Mul(Source.EMPTY, sampleSize.toAttribute(), sampleProbability), + Literal.fromDouble(Source.EMPTY, MIN_ROW_COUNT_FOR_RESULT_INCLUSION - 0.5) ) ) ); } - plan = new Eval(Source.EMPTY, plan, evals); + plan = new Eval(Source.EMPTY, plan, postEvals); return new Project(Source.EMPTY, plan, projections); } - /** - * Corrects an aggregation function for random sampling. - * Some functions (like COUNT and SUM) need to be scaled up by the inverse of - * the sampling probability, while others (like AVG and MEDIAN) do not. 
- */ - private static Expression correctForSampling(Expression expr, Expression sampleProbability, List buckets) { - Expression correctedAgg = new Div(expr.source(), expr, sampleProbability); - correctedAgg = switch (expr.dataType()) { - case DOUBLE -> correctedAgg; - case LONG -> new ToLong(expr.source(), correctedAgg); - default -> throw new IllegalStateException("unexpected data type [" + expr.dataType() + "]"); - }; - if (buckets != null) { - // All buckets being null indicates that the query was executed - // exactly, hence no sampling correction must be applied. - List rest = buckets.subList(1, buckets.size()).stream().map(Alias::toAttribute).collect(Collectors.toList()); - correctedAgg = new Case( - Source.EMPTY, - new IsNull(Source.EMPTY, new Coalesce(Source.EMPTY, buckets.getFirst().toAttribute(), rest)), - List.of(expr, correctedAgg) - ); - } - return correctedAgg; - } - /** * Returns a plan that also processes the buckets for fields that have them. * Luckily, there's only a limited set of commands that have to do something @@ -492,11 +471,11 @@ private static Expression correctForSampling(Expression expr, Expression sampleP */ private static LogicalPlan planIncludingBuckets( LogicalPlan plan, - Map> fieldBuckets, - Map uncorrectedExpressions + Map> fieldBuckets, + Map notRoundedExpressions ) { return switch (plan) { - case Eval eval -> evalIncludingBuckets(eval, fieldBuckets, uncorrectedExpressions); + case Eval eval -> evalIncludingBuckets(eval, fieldBuckets, notRoundedExpressions); case Project project -> projectIncludingBuckets(project, fieldBuckets); case MvExpand mvExpand -> mvExpandIncludingBuckets(mvExpand, fieldBuckets); default -> plan; @@ -511,8 +490,8 @@ private static LogicalPlan planIncludingBuckets( */ private static LogicalPlan evalIncludingBuckets( Eval eval, - Map> fieldBuckets, - Map uncorrectedExpressions + Map> fieldBuckets, + Map notRoundedExpressions ) { List fields = new ArrayList<>(eval.fields()); for (Alias field : eval.fields()) 
{ @@ -523,41 +502,49 @@ private static LogicalPlan evalIncludingBuckets( } // If any of the field's dependencies has buckets, create buckets for this field as well. if (field.child().anyMatch(e -> e instanceof NamedExpression ne && fieldBuckets.containsKey(ne.id()))) { - List buckets = new ArrayList<>(); + List buckets = new ArrayList<>(); for (int bucketId = 0; bucketId < TRIAL_COUNT * BUCKET_COUNT; bucketId++) { final int finalBucketId = bucketId; - Expression bucket = field.child() - .transformDown( - e -> e instanceof NamedExpression ne && fieldBuckets.containsKey(ne.id()) - ? fieldBuckets.get(ne.id()).get(finalBucketId).toAttribute() - : e - ); - buckets.add(new Alias(Source.EMPTY, field.name() + "$" + bucketId, bucket)); + Alias bucket = new Alias( + Source.EMPTY, + field.name() + "$bucket$" + bucketId, + field.child() + .transformDown( + e -> e instanceof NamedExpression ne && fieldBuckets.containsKey(ne.id()) + ? fieldBuckets.get(ne.id()).get(finalBucketId) + : e + ) + ); + fields.add(bucket); + buckets.add(bucket.toAttribute()); } - fields.addAll(buckets); fieldBuckets.put(field.id(), buckets); } } - // For each division of two sample-corrected expressions, replace it by - // a division of the corresponding uncorrected expressions. + + // For each noninteger division of expressions, use not-rounded values. 
for (int i = 0; i < fields.size(); i++) { Alias field = fields.get(i); fields.set(i, field.replaceChild(field.child().transformUp(e -> { - if (e instanceof Div div - && div.left() instanceof NamedExpression left - && uncorrectedExpressions.containsKey(left.id()) - && div.right() instanceof NamedExpression right - && uncorrectedExpressions.containsKey(right.id())) { - return new Div( - e.source(), - uncorrectedExpressions.get(left.id()).toAttribute(), - uncorrectedExpressions.get(right.id()).toAttribute(), - div.dataType() - ); + if (e instanceof Div div && div.dataType().isRationalNumber()) { + Attribute notRoundedLhs = div.left() instanceof NamedExpression left ? notRoundedExpressions.get(left.id()) : null; + Attribute notRoundedRhs = div.right() instanceof NamedExpression right ? notRoundedExpressions.get(right.id()) : null; + if (notRoundedLhs == null && notRoundedRhs == null) { + return div; + } else { + return new Div( + div.source(), + notRoundedLhs != null ? notRoundedLhs : div.left(), + notRoundedRhs != null ? notRoundedRhs : div.right(), + div.dataType() + ); + } + } else { + return e; } - return e; }))); } + return new Eval(Source.EMPTY, eval.child(), fields); } @@ -565,7 +552,7 @@ private static LogicalPlan evalIncludingBuckets( * For PROJECT, if it renames a field with buckets, add the renamed field * to the map of fields with buckets. 
*/ - private static LogicalPlan projectIncludingBuckets(Project project, Map> fieldBuckets) { + private static LogicalPlan projectIncludingBuckets(Project project, Map> fieldBuckets) { for (NamedExpression projection : project.projections()) { if (projection instanceof Alias alias && alias.child() instanceof NamedExpression named @@ -581,9 +568,7 @@ private static LogicalPlan projectIncludingBuckets(Project project, Map(project.projections()); } - for (Alias bucket : fieldBuckets.get(projection.id())) { - projections.add(bucket.toAttribute()); - } + projections.addAll(fieldBuckets.get(projection.id())); } } if (projections != null) { @@ -597,7 +582,7 @@ private static LogicalPlan projectIncludingBuckets(Project project, Map> fieldBuckets) { + private static LogicalPlan mvExpandIncludingBuckets(MvExpand mvExpand, Map> fieldBuckets) { if (fieldBuckets.containsKey(mvExpand.target().id())) { fieldBuckets.put(mvExpand.expanded().id(), fieldBuckets.get(mvExpand.target().id())); } @@ -617,7 +602,7 @@ private static LogicalPlan mvExpandIncludingBuckets(MvExpand mvExpand, Map getConfidenceIntervals( LogicalPlan logicalPlan, - Map> fieldBuckets, + Map> fieldBuckets, double confidenceLevel ) { Expression constNaN = new Literal(Source.EMPTY, Double.NaN, DataType.DOUBLE); @@ -629,7 +614,7 @@ private static List getConfidenceIntervals( List confidenceIntervalsAndCertified = new ArrayList<>(); for (Attribute output : logicalPlan.output()) { if (fieldBuckets.containsKey(output.id())) { - List buckets = fieldBuckets.get(output.id()); + List buckets = fieldBuckets.get(output.id()); // Collect a multivalued expression with all bucket values, and pass that to the // confidence interval computation. Whenever the bucket value is null, replace it // by NaN, because multivalued fields cannot have nulls. 
@@ -640,7 +625,7 @@ private static List getConfidenceIntervals( // https://github.com/elastic/elasticsearch/issues/141383 Expression bucketsMv = null; for (int i = 0; i < TRIAL_COUNT * BUCKET_COUNT; i++) { - Expression bucket = buckets.get(i).toAttribute(); + Expression bucket = buckets.get(i); if (output.dataType() != DataType.DOUBLE) { bucket = new ToDouble(Source.EMPTY, bucket); } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AggregateWritables.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AggregateWritables.java index 7a79fe1eda7cf..26912d67f6283 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AggregateWritables.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AggregateWritables.java @@ -17,6 +17,7 @@ public static List getNamedWriteables() { return List.of( Avg.ENTRY, Count.ENTRY, + CountApproximate.ENTRY, CountDistinct.ENTRY, First.ENTRY, Max.ENTRY, diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/CountApproximate.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/CountApproximate.java new file mode 100644 index 0000000000000..97a55fb5892ff --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/CountApproximate.java @@ -0,0 +1,97 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.esql.expression.function.aggregate; + +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.compute.aggregation.AggregatorFunctionSupplier; +import org.elasticsearch.compute.aggregation.CountApproximateAggregatorFunction; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.Literal; +import org.elasticsearch.xpack.esql.core.expression.Nullability; +import org.elasticsearch.xpack.esql.core.tree.NodeInfo; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; + +import java.io.IOException; +import java.util.List; + +/** + * Used exclusively in the query approximation plan. + *

+ * Counts values by summing doubles, so that intermediate state is {@link DataType#DOUBLE}. + * This avoids round-off errors when sample correction divides by the sample + * probability on data nodes — the corrected value stays in floating point and + * is only rounded to the target integer type on the coordinator. + */ +public class CountApproximate extends Count { + public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry( + Expression.class, + "CountApproximate", + CountApproximate::new + ); + + public CountApproximate(Source source, Expression field) { + this(source, field, Literal.TRUE, NO_WINDOW); + } + + public CountApproximate(Source source, Expression field, Expression filter, Expression window) { + super(source, field, filter, window); + } + + private CountApproximate(StreamInput in) throws IOException { + this( + Source.readFrom((PlanStreamInput) in), + in.readNamedWriteable(Expression.class), + in.readNamedWriteable(Expression.class), + readWindow(in) + ); + in.readNamedWriteableCollectionAsList(Expression.class); + } + + @Override + public String getWriteableName() { + return ENTRY.name; + } + + @Override + protected NodeInfo info() { + return NodeInfo.create(this, CountApproximate::new, field(), filter(), window()); + } + + @Override + public AggregateFunction withFilter(Expression filter) { + return new CountApproximate(source(), field(), filter, window()); + } + + @Override + public CountApproximate replaceChildren(List newChildren) { + return new CountApproximate(source(), newChildren.get(0), newChildren.get(1), newChildren.get(2)); + } + + @Override + public DataType dataType() { + return DataType.DOUBLE; + } + + @Override + public AggregatorFunctionSupplier supplier() { + return CountApproximateAggregatorFunction.supplier(); + } + + @Override + public Nullability nullable() { + return Nullability.TRUE; + } + + @Override + public Expression surrogate() { + return null; + } +} diff --git 
a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/approximate/ConfidenceInterval.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/approximate/ConfidenceInterval.java index f25f4daee22fc..9610c48de8c7c 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/approximate/ConfidenceInterval.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/approximate/ConfidenceInterval.java @@ -423,15 +423,18 @@ static void process( double lower = mm + scale * sm * zl; double upper = mm + scale * sm * zu; - if (lower <= bestEstimate && bestEstimate <= upper) { + // If the bestEstimate is outside the confidence interval, it is not a sensible interval, + // so return null instead. TODO: this criterion is not ideal, and should be revisited. + // Allow a little bit of numerical imprecision in the consistency check, which can happen + // due to round-off errors when aggregating zero-variance stats (e.g. AVG(x) BY x). + if (lower - 1e-12 * Math.abs(lower) <= bestEstimate && bestEstimate <= upper + 1e-12 * Math.abs(upper)) { resultBuilder.beginPositionEntry(); resultBuilder.appendDouble(lower); resultBuilder.appendDouble(upper); resultBuilder.appendDouble((double) reliableCount / trialCount); resultBuilder.endPositionEntry(); } else { - // If the bestEstimate is outside the confidence interval, it is not a sensible interval, - // so return null instead. TODO: this criterion is not ideal, and should be revisited. 
+ resultBuilder.appendNull(); } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/physical/local/PushStatsToSource.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/physical/local/PushStatsToSource.java index fe04b2c44c01a..976c42417cc34 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/physical/local/PushStatsToSource.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/physical/local/PushStatsToSource.java @@ -17,15 +17,20 @@ import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.FieldAttribute; import org.elasticsearch.xpack.esql.core.expression.NamedExpression; +import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.util.Queries; import org.elasticsearch.xpack.esql.core.util.StringUtils; import org.elasticsearch.xpack.esql.expression.function.aggregate.Count; +import org.elasticsearch.xpack.esql.expression.function.aggregate.CountApproximate; +import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToDouble; import org.elasticsearch.xpack.esql.optimizer.LocalPhysicalOptimizerContext; import org.elasticsearch.xpack.esql.optimizer.PhysicalOptimizerRules; import org.elasticsearch.xpack.esql.plan.physical.AggregateExec; import org.elasticsearch.xpack.esql.plan.physical.EsQueryExec; import org.elasticsearch.xpack.esql.plan.physical.EsStatsQueryExec; +import org.elasticsearch.xpack.esql.plan.physical.EvalExec; import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan; +import org.elasticsearch.xpack.esql.plan.physical.ProjectExec; import org.elasticsearch.xpack.esql.planner.AbstractPhysicalOperationProviders; import java.util.ArrayList; @@ -65,6 +70,22 @@ protected PhysicalPlan rule(AggregateExec aggregateExec, LocalPhysicalOptimizerC tuple.v1(), stats.get(0) ); + + if 
(aggregateExec.aggregates().getFirst() instanceof Alias alias && alias.child() instanceof CountApproximate) { + Attribute originalCount = aggregateExec.intermediateAttributes().get(0); + Attribute originalSeen = aggregateExec.intermediateAttributes().get(1); + Attribute count = tuple.v1().get(0); + Attribute seen = tuple.v1().get(1); + plan = new EvalExec( + plan.source(), + plan, + List.of( + new Alias(count.source(), count.name(), new ToDouble(Source.EMPTY, count), originalCount.id()), + new Alias(seen.source(), seen.name(), seen, originalSeen.id()) + ) + ); + plan = new ProjectExec(aggregateExec.source(), plan, aggregateExec.output()); + } } return plan; } @@ -123,8 +144,17 @@ static Tuple, List> pushableStats( return null; }); if (stat != null) { + NamedExpression aggForIntermediate = agg; + if (agg instanceof Alias as && as.child() instanceof CountApproximate ca) { + aggForIntermediate = new Alias( + as.source(), + as.name(), + new Count(ca.source(), ca.field(), ca.filter(), ca.window()), + as.id() + ); + } List intermediateAttributes = AbstractPhysicalOperationProviders.intermediateAttributes( - singletonList(agg), + singletonList(aggForIntermediate), emptyList() ); // TODO: the attributes have been recreated here; they will have wrong name ids, and the dependency check will diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/physical/local/ReplaceSampledStatsByExactStats.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/physical/local/ReplaceSampledStatsByExactStats.java index 8f4d278d47a1f..66440f79b0bb7 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/physical/local/ReplaceSampledStatsByExactStats.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/physical/local/ReplaceSampledStatsByExactStats.java @@ -10,8 +10,7 @@ import org.elasticsearch.compute.aggregation.AggregatorMode; import 
org.elasticsearch.xpack.esql.approximation.ApproximationPlan; import org.elasticsearch.xpack.esql.core.expression.Alias; -import org.elasticsearch.xpack.esql.core.expression.AttributeSet; -import org.elasticsearch.xpack.esql.core.expression.Literal; +import org.elasticsearch.xpack.esql.core.expression.Attribute; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.optimizer.LocalPhysicalOptimizerContext; import org.elasticsearch.xpack.esql.optimizer.PhysicalOptimizerRules; @@ -22,6 +21,7 @@ import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan; import org.elasticsearch.xpack.esql.plan.physical.SampledAggregateExec; +import java.util.ArrayList; import java.util.List; /** @@ -30,63 +30,63 @@ * should be skipped and the original aggregate should be executed. *

* In that case, this rule replaces the sampled aggregate by a regular - * aggregate and nullifies all buckets. The plan: + * aggregate and replicates the exact intermediate values to all bucket + * intermediates. The plan: *

  * {@code FROM data | EVAL bucket_id=... | SAMPLED_STATS original_aggs, bucket_aggs}
  * 
- * is transformed into: + * is (loosely) transformed into: *
- * {@code FROM data | EVAL bucket_id=... | STATS original_aggs | EVAL bucket_aggs=NULL}
+ * {@code FROM data | ES_STATS_QUERY original_aggs | EVAL bucket_aggs=original_aggs}
  * 
- * All buckets being NULL indicates to the coordinator that the stats are exact. - *

- * The aggregate created by this rule is pushed down to Lucene by the - * {@link PushStatsToSource} rule, whose logic is reused here. + * Replicating the exact value to all buckets makes exact data appear as + * zero-variance sampled data, so confidence intervals remain correct in + * mixed exact/sampled scenarios (where some nodes push down exact stats and + * others use sampling). */ public class ReplaceSampledStatsByExactStats extends PhysicalOptimizerRules.ParameterizedOptimizerRule< SampledAggregateExec, LocalPhysicalOptimizerContext> { @Override - protected PhysicalPlan rule(SampledAggregateExec sampledAggregateExec, LocalPhysicalOptimizerContext context) { - if (sampledAggregateExec.getMode() == AggregatorMode.INITIAL - && sampledAggregateExec.child() instanceof EvalExec evalExec - && evalExec.expressions().size() == 1 - && evalExec.expressions().getFirst() instanceof Alias alias + protected PhysicalPlan rule(SampledAggregateExec plan, LocalPhysicalOptimizerContext context) { + if (plan.getMode() == AggregatorMode.INITIAL + && plan.child() instanceof EvalExec eval + && eval.expressions().size() == 1 + && eval.expressions().getFirst() instanceof Alias alias && alias.name().equals(ApproximationPlan.BUCKET_ID_COLUMN_NAME) - && evalExec.child() instanceof EsQueryExec queryExec) { + && eval.child() instanceof EsQueryExec queryExec) { - var tuple = PushStatsToSource.pushableStats( - sampledAggregateExec.groupings(), - sampledAggregateExec.originalAggregates(), - context - ); + var tuple = PushStatsToSource.pushableStats(plan.groupings(), plan.originalAggregates(), context); // for the moment support pushing count just for one field List stats = tuple.v2(); - if (stats.size() != 1 || stats.size() != sampledAggregateExec.originalAggregates().size()) { - return sampledAggregateExec; + if (stats.size() != 1 || stats.size() != plan.originalAggregates().size()) { + return plan; } - List nullBuckets = sampledAggregateExec.outputSet() - 
.subtract(AttributeSet.of(sampledAggregateExec.originalIntermediateAttributes())) - .stream() - .map(attr -> new Alias(Source.EMPTY, attr.name(), new Literal(Source.EMPTY, null, attr.dataType()), attr.id())) - .toList(); - - PhysicalPlan plan = new AggregateExec( - sampledAggregateExec.source(), + AggregateExec aggregate = new AggregateExec( + plan.source(), queryExec, - sampledAggregateExec.groupings(), - sampledAggregateExec.originalAggregates(), - sampledAggregateExec.getMode(), - sampledAggregateExec.originalIntermediateAttributes(), - sampledAggregateExec.estimatedRowSize() + plan.groupings(), + plan.originalAggregates(), + plan.getMode(), + plan.originalIntermediateAttributes(), + plan.estimatedRowSize() ); - return new EvalExec(Source.EMPTY, plan, nullBuckets); + // The first intermediate attributes of the SampledAggregate are the original aggregations. + // Next follow the bucket aggregations. Each bucket has the same intermediate attributes + // as the original aggregate. + List exactBuckets = new ArrayList<>(); + for (int i = plan.originalIntermediateAttributes().size(); i < plan.intermediateAttributes().size(); i++) { + Attribute attribute = plan.intermediateAttributes().get(i); + Attribute originalAttribute = plan.originalIntermediateAttributes().get(i % plan.originalIntermediateAttributes().size()); + exactBuckets.add(new Alias(Source.EMPTY, attribute.name(), originalAttribute, attribute.id())); + } + return new EvalExec(Source.EMPTY, aggregate, exactBuckets); } else { - return sampledAggregateExec; + return plan; } } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/physical/local/ReplaceSampledStatsBySampleAndStats.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/physical/local/ReplaceSampledStatsBySampleAndStats.java index a9dd7ee247ea9..73416d216ca9d 100644 --- 
a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/physical/local/ReplaceSampledStatsBySampleAndStats.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/physical/local/ReplaceSampledStatsBySampleAndStats.java @@ -7,14 +7,34 @@ package org.elasticsearch.xpack.esql.optimizer.rules.physical.local; +import org.elasticsearch.compute.aggregation.IntermediateStateDesc; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.xpack.esql.approximation.ApproximationPlan; +import org.elasticsearch.xpack.esql.core.expression.Alias; +import org.elasticsearch.xpack.esql.core.expression.Attribute; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.Literal; +import org.elasticsearch.xpack.esql.core.expression.NamedExpression; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.expression.Foldables; +import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction; +import org.elasticsearch.xpack.esql.expression.function.aggregate.Count; +import org.elasticsearch.xpack.esql.expression.function.aggregate.Sum; +import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Div; import org.elasticsearch.xpack.esql.optimizer.PhysicalOptimizerRules; import org.elasticsearch.xpack.esql.plan.physical.AggregateExec; +import org.elasticsearch.xpack.esql.plan.physical.EvalExec; import org.elasticsearch.xpack.esql.plan.physical.LeafExec; import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan; +import org.elasticsearch.xpack.esql.plan.physical.ProjectExec; import org.elasticsearch.xpack.esql.plan.physical.SampleExec; import org.elasticsearch.xpack.esql.plan.physical.SampledAggregateExec; +import org.elasticsearch.xpack.esql.planner.AggregateMapper; + +import java.util.ArrayList; +import java.util.List; +import java.util.Set; +import java.util.stream.Collectors; 
/** * If the original aggregate wrapped by the sampled aggregate cannot be @@ -22,13 +42,14 @@ * should be used to speed up the aggregation. *

* In that case, this rule replaces the sampled aggregate by a regular - * aggregate on top of a sample. The plan: + * aggregate on top of a sample, with intermediate state corrections + * for sample-corrected aggregates (COUNT, SUM). The plan: *

  * {@code FROM data | commands | SAMPLED_STATS[prob] aggs}
  * 
* is transformed into: *
- * {@code FROM data | SAMPLE prob | commands | STATS aggs}
+ * {@code FROM data | SAMPLE prob | commands | STATS aggs | EVAL sample_correction}
  * 
*/ public class ReplaceSampledStatsBySampleAndStats extends PhysicalOptimizerRules.OptimizerRule { @@ -36,16 +57,79 @@ public class ReplaceSampledStatsBySampleAndStats extends PhysicalOptimizerRules. @Override protected PhysicalPlan rule(SampledAggregateExec plan) { double sampleProbability = (double) Foldables.literalValueOf(plan.sampleProbability()); - return new AggregateExec( + + PhysicalPlan child = sampleProbability == 1.0 + ? plan.child() + : plan.child().transformUp(LeafExec.class, leaf -> new SampleExec(Source.EMPTY, leaf, plan.sampleProbability())); + + List sampleCorrections = new ArrayList<>(); + List intermediateAttributes = new ArrayList<>(); + + if (sampleProbability == 1.0) { + intermediateAttributes = plan.intermediateAttributes(); + } else { + Expression bucketSampleProbability = new Div( + Source.EMPTY, + plan.sampleProbability(), + Literal.integer(Source.EMPTY, ApproximationPlan.BUCKET_COUNT) + ); + + Set originalIntermediateNames = plan.originalIntermediateAttributes() + .stream() + .map(NamedExpression::name) + .collect(Collectors.toSet()); + + // The first intermediate attributes are the grouping keys. + int idx = 0; + for (int g = 0; g < plan.groupings().size(); g++) { + intermediateAttributes.add(plan.intermediateAttributes().get(idx++)); + } + + // The following intermediate attributes are the aggregates states. + // They come in the same order as the aggregates. + for (NamedExpression aggOrKey : plan.aggregates()) { + if ((aggOrKey instanceof Alias alias && alias.child() instanceof AggregateFunction) == false) { + // This is a grouping key and has already been added to the intermediate attributes. 
+ continue; + } + + AggregateFunction aggFn = (AggregateFunction) ((Alias) aggOrKey).child(); + boolean aggFnNeedsCorrection = aggFn instanceof Count || aggFn instanceof Sum; + + List stateDescs = AggregateMapper.intermediateStateDesc(aggFn, plan.groupings().isEmpty() == false); + for (IntermediateStateDesc desc : stateDescs) { + Attribute attr = plan.intermediateAttributes().get(idx++); + + if (aggFnNeedsCorrection && desc.type() != ElementType.BOOLEAN) { + // Create a new alias for the uncorrected value, and reuse the existing attribute for the corrected value. + Alias uncorrectedAlias = new Alias(Source.EMPTY, attr.name(), attr); + intermediateAttributes.add(uncorrectedAlias.toAttribute()); + Expression corrected = new Div( + Source.EMPTY, + uncorrectedAlias.toAttribute(), + originalIntermediateNames.contains(attr.name()) ? plan.sampleProbability() : bucketSampleProbability + ); + Alias correctedAlias = new Alias(Source.EMPTY, attr.name(), corrected, attr.id()); + sampleCorrections.add(correctedAlias); + } else { + intermediateAttributes.add(attr); + } + } + } + } + + PhysicalPlan result = new AggregateExec( plan.source(), - sampleProbability == 1.0 - ? 
plan.child() - : plan.child().transformUp(LeafExec.class, leaf -> new SampleExec(Source.EMPTY, leaf, plan.sampleProbability())), + child, plan.groupings(), plan.aggregates(), plan.getMode(), - plan.intermediateAttributes(), + intermediateAttributes, plan.estimatedRowSize() ); + if (sampleCorrections.isEmpty() == false) { + result = new ProjectExec(Source.EMPTY, new EvalExec(Source.EMPTY, result, sampleCorrections), plan.output()); + } + return result; } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/approximation/ApproximationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/approximation/ApproximationTests.java index 077baa3f68963..14810bef86270 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/approximation/ApproximationTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/approximation/ApproximationTests.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.esql.approximation; import org.elasticsearch.xpack.esql.expression.function.aggregate.Count; +import org.elasticsearch.xpack.esql.expression.function.aggregate.CountApproximate; import org.elasticsearch.xpack.esql.expression.function.aggregate.Sum; import org.elasticsearch.xpack.esql.plan.logical.Aggregate; import org.elasticsearch.xpack.esql.plan.logical.Filter; @@ -214,17 +215,17 @@ public void testPlans_largeDataAfterFiltering() throws Exception { subplan = approximation.firstSubPlan(); assertThat(subplan, hasPlan(Filter.class)); assertThat(subplan, not(hasPlan(Aggregate.class))); - assertThat(subplan, hasPlan(SampledAggregate.class, withProbability(1e-8), withAggs(Count.class))); + assertThat(subplan, hasPlan(SampledAggregate.class, withProbability(1e-8), withAggs(CountApproximate.class))); - // Filtered count of 10, so increase the sample probability. 
- approximation.newMainPlan(newCountResult(10)); + // Sample-corrected filtered count of 10^9 (so actual count of 10), so increase the sample probability. + approximation.newMainPlan(newCountResult(1_000_000_000)); subplan = approximation.firstSubPlan(); assertThat(subplan, hasPlan(Filter.class)); assertThat(subplan, not(hasPlan(Aggregate.class))); - assertThat(subplan, hasPlan(SampledAggregate.class, withProbability(1e-5), withAggs(Count.class))); + assertThat(subplan, hasPlan(SampledAggregate.class, withProbability(1e-5), withAggs(CountApproximate.class))); - // Filtered count of 10_000, so no more subplans. - mainPlan = approximation.newMainPlan(newCountResult(10_000)); + // Sample-corrected filtered count of 10^9 (so actual count of 10_000), so no more subplans. + mainPlan = approximation.newMainPlan(newCountResult(1_000_000_000)); subplan = approximation.firstSubPlan(); assertThat(subplan, nullValue()); @@ -253,28 +254,28 @@ public void testPlans_smallDataAfterFiltering() throws Exception { subplan = approximation.firstSubPlan(); assertThat(subplan, hasPlan(Filter.class)); assertThat(subplan, not(hasPlan(Aggregate.class))); - assertThat(subplan, hasPlan(SampledAggregate.class, withProbability(1e-14), withAggs(Count.class))); + assertThat(subplan, hasPlan(SampledAggregate.class, withProbability(1e-14), withAggs(CountApproximate.class))); // Filtered count of 0, so increase the sample probability. approximation.newMainPlan(newCountResult(0)); subplan = approximation.firstSubPlan(); assertThat(subplan, hasPlan(Filter.class)); assertThat(subplan, not(hasPlan(Aggregate.class))); - assertThat(subplan, hasPlan(SampledAggregate.class, withProbability(1e-10), withAggs(Count.class))); + assertThat(subplan, hasPlan(SampledAggregate.class, withProbability(1e-10), withAggs(CountApproximate.class))); // Filtered count of 0, so increase the sample probability. 
approximation.newMainPlan(newCountResult(0)); subplan = approximation.firstSubPlan(); assertThat(subplan, hasPlan(Filter.class)); assertThat(subplan, not(hasPlan(Aggregate.class))); - assertThat(subplan, hasPlan(SampledAggregate.class, withProbability(1e-6), withAggs(Count.class))); + assertThat(subplan, hasPlan(SampledAggregate.class, withProbability(1e-6), withAggs(CountApproximate.class))); // Filtered count of 0, so increase the sample probability. approximation.newMainPlan(newCountResult(0)); subplan = approximation.firstSubPlan(); assertThat(subplan, hasPlan(Filter.class)); assertThat(subplan, not(hasPlan(Aggregate.class))); - assertThat(subplan, hasPlan(SampledAggregate.class, withProbability(1e-2), withAggs(Count.class))); + assertThat(subplan, hasPlan(SampledAggregate.class, withProbability(1e-2), withAggs(CountApproximate.class))); // Filtered count of 0, so no more subplans. mainPlan = approximation.newMainPlan(newCountResult(0)); @@ -397,10 +398,10 @@ public void testPlans_largeDataBeforeMvExpanding() throws Exception { subplan = approximation.firstSubPlan(); assertThat(subplan, hasPlan(MvExpand.class)); assertThat(subplan, not(hasPlan(Aggregate.class))); - assertThat(subplan, hasPlan(SampledAggregate.class, withProbability(1e-5), withAggs(Count.class))); + assertThat(subplan, hasPlan(SampledAggregate.class, withProbability(1e-5), withAggs(CountApproximate.class))); - // sampled mv_expanded count of 10^7, so no more subplans. - mainPlan = approximation.newMainPlan(newCountResult(10_000_000)); + // Sample-corrected mv_expanded count of 10^12 (so actual of 10^7), so no more subplans. 
+ mainPlan = approximation.newMainPlan(newCountResult(1_000_000_000_000L)); subplan = approximation.firstSubPlan(); assertThat(subplan, nullValue()); @@ -454,7 +455,7 @@ public void testCountPlan_sampleProbabilityThreshold_withFilter() throws Excepti subplan = approximation.firstSubPlan(); assertThat(subplan, hasPlan(Filter.class)); assertThat(subplan, not(hasPlan(Aggregate.class))); - assertThat(subplan, hasPlan(SampledAggregate.class, withProbability(1e-8), withAggs(Count.class))); + assertThat(subplan, hasPlan(SampledAggregate.class, withProbability(1e-8), withAggs(CountApproximate.class))); // Sampled filtered count of 0, so increase the sample probability. approximation.newMainPlan(newCountResult(0)); @@ -462,7 +463,7 @@ public void testCountPlan_sampleProbabilityThreshold_withFilter() throws Excepti subplan = approximation.firstSubPlan(); assertThat(subplan, hasPlan(Filter.class)); assertThat(subplan, not(hasPlan(Aggregate.class))); - assertThat(subplan, hasPlan(SampledAggregate.class, withProbability(1e-4), withAggs(Count.class))); + assertThat(subplan, hasPlan(SampledAggregate.class, withProbability(1e-4), withAggs(CountApproximate.class))); // Sampled filtered count of 20, which would next to a sample probability of 0.2, // which is above the threshold, so no more subplans. 
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/physical/local/ReplaceSampledStatsByExactStatsTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/physical/local/ReplaceSampledStatsByExactStatsTests.java index a85d754c6501d..967a3d95cae38 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/physical/local/ReplaceSampledStatsByExactStatsTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/physical/local/ReplaceSampledStatsByExactStatsTests.java @@ -42,18 +42,18 @@ import java.util.List; import static org.elasticsearch.xpack.esql.EsqlTestUtils.TEST_CFG; +import static org.hamcrest.Matchers.in; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; -import static org.hamcrest.Matchers.nullValue; public class ReplaceSampledStatsByExactStatsTests extends ESTestCase { /** * COUNT(*) is pushable to Lucene, so SampledAggregateExec should be replaced - * by AggregateExec wrapped in EvalExec that nullifies bucket columns. + * by AggregateExec wrapped in EvalExec that replicates original values to buckets. 
* * Plan: SampledAggregateExec(INITIAL) -> EvalExec($bucket_id) -> EsQueryExec - * Expected: EvalExec(null buckets) -> AggregateExec(INITIAL) -> EsQueryExec + * Expected: EvalExec(replicated buckets) -> AggregateExec(INITIAL) -> EsQueryExec */ public void testReplace_countStar() { Alias count = countAlias(Literal.keyword(Source.EMPTY, "*")); @@ -63,8 +63,7 @@ public void testReplace_countStar() { assertThat(result, instanceOf(EvalExec.class)); EvalExec evalExec = (EvalExec) result; - assertAllFieldsNull(evalExec); - + assertBucketsReplicateOriginal(evalExec, sampledAgg.originalIntermediateAttributes()); assertThat(evalExec.child(), instanceOf(AggregateExec.class)); AggregateExec aggExec = (AggregateExec) evalExec.child(); assertThat(aggExec.getMode(), is(AggregatorMode.INITIAL)); @@ -75,7 +74,7 @@ public void testReplace_countStar() { } /** - * COUNT(field) on a single-value field is pushable, so the same replacement should happen. + * COUNT(field) on a single-valued field is pushable, so the same replacement should happen. 
*/ public void testReplace_countFieldSingleValue() { FieldAttribute field = fieldAttribute("emp_no", DataType.INTEGER); @@ -86,8 +85,7 @@ public void testReplace_countFieldSingleValue() { assertThat(result, instanceOf(EvalExec.class)); EvalExec evalExec = (EvalExec) result; - assertAllFieldsNull(evalExec); - + assertBucketsReplicateOriginal(evalExec, sampledAgg.originalIntermediateAttributes()); assertThat(evalExec.child(), instanceOf(AggregateExec.class)); AggregateExec aggExec = (AggregateExec) evalExec.child(); assertThat(aggExec.getMode(), is(AggregatorMode.INITIAL)); @@ -308,15 +306,24 @@ private static List intermediateAttributes(List originalAttributes) { + assertThat("at least one bucket field expected", evalExec.fields().isEmpty(), is(false)); for (Alias field : evalExec.fields()) { assertThat( - "bucket field '" + field.name() + "' should be set to null", - field.child() instanceof Literal lit && lit.value() == null, - is(true) + "bucket field '" + field.name() + "' should alias an original attribute", + field.child(), + instanceOf(Attribute.class) + ); + assertThat( + "bucket field '" + field.name() + "' should alias an original attribute", + (Attribute) field.child(), + in(originalAttributes) ); - assertThat(field.child().fold(FoldContext.small()), nullValue()); } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/physical/local/ReplaceSampledStatsBySampleAndStatsTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/physical/local/ReplaceSampledStatsBySampleAndStatsTests.java index 34349d3e8ca80..12e01f996e4a9 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/physical/local/ReplaceSampledStatsBySampleAndStatsTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/physical/local/ReplaceSampledStatsBySampleAndStatsTests.java @@ -10,22 +10,24 @@ import 
org.elasticsearch.compute.aggregation.AggregatorMode; import org.elasticsearch.index.IndexMode; import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.xpack.esql.approximation.ApproximationPlan; import org.elasticsearch.xpack.esql.core.expression.Alias; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.FieldAttribute; import org.elasticsearch.xpack.esql.core.expression.Literal; import org.elasticsearch.xpack.esql.core.expression.NamedExpression; -import org.elasticsearch.xpack.esql.core.expression.ReferenceAttribute; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; import org.elasticsearch.xpack.esql.core.type.EsField; +import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction; import org.elasticsearch.xpack.esql.expression.function.aggregate.Count; +import org.elasticsearch.xpack.esql.expression.function.aggregate.StdDev; import org.elasticsearch.xpack.esql.expression.function.aggregate.Sum; +import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.Equals; import org.elasticsearch.xpack.esql.plan.physical.AggregateExec; import org.elasticsearch.xpack.esql.plan.physical.EsQueryExec; import org.elasticsearch.xpack.esql.plan.physical.EvalExec; import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan; +import org.elasticsearch.xpack.esql.plan.physical.ProjectExec; import org.elasticsearch.xpack.esql.plan.physical.SampleExec; import org.elasticsearch.xpack.esql.plan.physical.SampledAggregateExec; import org.elasticsearch.xpack.esql.planner.AbstractPhysicalOperationProviders; @@ -34,66 +36,64 @@ import java.util.HashMap; import java.util.List; +import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; public class ReplaceSampledStatsBySampleAndStatsTests extends ESTestCase { - /** - * COUNT(*), no groupings, 
INITIAL mode, prob=0.5. - * SampleExec should wrap the leaf. - */ - public void testReplace_countStarWithSampling() { + public void testReplace_count() { Alias count = countAlias(Literal.keyword(Source.EMPTY, "*")); SampledAggregateExec sampledAgg = sampledAggregate(esQueryExec(), List.of(count), List.of(), AggregatorMode.INITIAL, 0.5); PhysicalPlan result = applyRule(sampledAgg); - AggregateExec aggExec = assertAggregate(result, sampledAgg); + assertThat(result, instanceOf(ProjectExec.class)); + ProjectExec project = (ProjectExec) result; + assertThat(project.child(), instanceOf(EvalExec.class)); + EvalExec eval = (EvalExec) project.child(); + // COUNT and its bucket must be sample-corrected. + assertThat(eval.fields(), hasSize(2)); + AggregateExec aggExec = assertAggregate(eval.child(), sampledAgg); assertThat(aggExec.child(), instanceOf(SampleExec.class)); SampleExec sampleExec = (SampleExec) aggExec.child(); assertThat(sampleExec.probability(), is(sampledAgg.sampleProbability())); assertThat(sampleExec.child(), instanceOf(EsQueryExec.class)); } - /** - * SUM(salary), BY dept, INITIAL mode, prob=1.0. - * No SampleExec when probability is 1.0. - */ - public void testReplace_sumWithGroupingNoSampling() { + public void testReplace_sumWithGrouping_noSampling() { FieldAttribute salary = fieldAttribute("salary", DataType.INTEGER); Alias sum = new Alias(Source.EMPTY, "sum", new Sum(Source.EMPTY, salary)); FieldAttribute dept = fieldAttribute("dept", DataType.KEYWORD); SampledAggregateExec sampledAgg = sampledAggregate(esQueryExec(), List.of(sum), List.of(dept), AggregatorMode.INITIAL, 1.0); PhysicalPlan result = applyRule(sampledAgg); - AggregateExec aggExec = assertAggregate(result, sampledAgg); + AggregateExec aggExec = assertAggregate(result, sampledAgg); assertThat(aggExec.child(), instanceOf(EsQueryExec.class)); } - /** - * COUNT(emp_no), no groupings, FINAL mode, prob=0.3. - * Verifies non-INITIAL mode is preserved. 
- */ - public void testReplace_countFieldFinalMode() { + public void testReplace_countAndStddev_finalMode() { FieldAttribute empNo = fieldAttribute("emp_no", DataType.INTEGER); Alias count = countAlias(empNo); - SampledAggregateExec sampledAgg = sampledAggregate(esQueryExec(), List.of(count), List.of(), AggregatorMode.FINAL, 0.3); + Alias stddev = new Alias(Source.EMPTY, "stddev", new StdDev(Source.EMPTY, empNo)); + SampledAggregateExec sampledAgg = sampledAggregate(esQueryExec(), List.of(count, stddev), List.of(), AggregatorMode.FINAL, 0.3); PhysicalPlan result = applyRule(sampledAgg); - AggregateExec aggExec = assertAggregate(result, sampledAgg); + assertThat(result, instanceOf(ProjectExec.class)); + ProjectExec project = (ProjectExec) result; + assertThat(project.child(), instanceOf(EvalExec.class)); + EvalExec eval = (EvalExec) project.child(); + // COUNT and its bucket must be sample-corrected. + assertThat(eval.fields(), hasSize(2)); + AggregateExec aggExec = assertAggregate(eval.child(), sampledAgg); assertThat(aggExec.child(), instanceOf(SampleExec.class)); SampleExec sampleExec = (SampleExec) aggExec.child(); assertThat(sampleExec.probability(), is(sampledAgg.sampleProbability())); } - /** - * COUNT(*) + SUM(salary), BY dept, INITIAL mode, prob=0.5. - * SampleExec should wrap the leaf, not the intermediate EvalExec. 
- */ - public void testReplace_multipleAggsWithEvalChild() { + public void testReplace_countAndSum() { Alias count = countAlias(Literal.keyword(Source.EMPTY, "*")); FieldAttribute salary = fieldAttribute("salary", DataType.INTEGER); Alias sum = new Alias(Source.EMPTY, "sum", new Sum(Source.EMPTY, salary)); @@ -106,8 +106,14 @@ public void testReplace_multipleAggsWithEvalChild() { SampledAggregateExec sampledAgg = sampledAggregate(evalExec, List.of(count, sum), List.of(dept), AggregatorMode.INITIAL, 0.5); PhysicalPlan result = applyRule(sampledAgg); - AggregateExec aggExec = assertAggregate(result, sampledAgg); + assertThat(result, instanceOf(ProjectExec.class)); + ProjectExec project = (ProjectExec) result; + assertThat(project.child(), instanceOf(EvalExec.class)); + EvalExec eval = (EvalExec) project.child(); + // COUNT and SUM and their buckets must be sample-corrected. + assertThat(eval.fields(), hasSize(4)); + AggregateExec aggExec = assertAggregate(eval.child(), sampledAgg); assertThat(aggExec.child(), instanceOf(EvalExec.class)); EvalExec resultEval = (EvalExec) aggExec.child(); assertThat(resultEval.child(), instanceOf(SampleExec.class)); @@ -116,17 +122,33 @@ public void testReplace_multipleAggsWithEvalChild() { assertThat(sampleExec.child(), instanceOf(EsQueryExec.class)); } + public void testReplace_stdDev() { + FieldAttribute empNo = fieldAttribute("emp_no", DataType.INTEGER); + Alias stddev = new Alias(Source.EMPTY, "stddev", new StdDev(Source.EMPTY, empNo)); + SampledAggregateExec sampledAgg = sampledAggregate(esQueryExec(), List.of(stddev), List.of(), AggregatorMode.FINAL, 0.3); + + PhysicalPlan result = applyRule(sampledAgg); + + AggregateExec aggExec = assertAggregate(result, sampledAgg); + assertThat(aggExec.child(), instanceOf(SampleExec.class)); + SampleExec sampleExec = (SampleExec) aggExec.child(); + assertThat(sampleExec.probability(), is(sampledAgg.sampleProbability())); + } + private static PhysicalPlan applyRule(SampledAggregateExec 
sampledAgg) { return new ReplaceSampledStatsBySampleAndStats().apply(sampledAgg); } - private static AggregateExec assertAggregate(PhysicalPlan result, SampledAggregateExec sampledAgg) { - assertThat(result, instanceOf(AggregateExec.class)); - AggregateExec aggExec = (AggregateExec) result; + private static AggregateExec assertAggregate(PhysicalPlan plan, SampledAggregateExec sampledAgg) { + assertThat(plan, instanceOf(AggregateExec.class)); + AggregateExec aggExec = (AggregateExec) plan; assertThat(aggExec.aggregates(), is(sampledAgg.aggregates())); assertThat(aggExec.groupings(), is(sampledAgg.groupings())); assertThat(aggExec.getMode(), is(sampledAgg.getMode())); - assertThat(aggExec.intermediateAttributes(), is(sampledAgg.intermediateAttributes())); + assertThat(aggExec.intermediateAttributes(), hasSize(sampledAgg.intermediateAttributes().size())); + for (int i = 0; i < aggExec.intermediateAttributes().size(); i++) { + assertThat(aggExec.intermediateAttributes().get(i).children(), is(sampledAgg.intermediateAttributes().get(i).children())); + } return aggExec; } @@ -139,7 +161,13 @@ private static SampledAggregateExec sampledAggregate( ) { ArrayList allAggregates = new ArrayList<>(originalAggregates); for (NamedExpression agg : originalAggregates) { - allAggregates.add(new Alias(Source.EMPTY, agg.name() + "_bucket", agg.toAttribute())); + allAggregates.add( + new Alias( + Source.EMPTY, + agg.name() + "_bucket", + ((AggregateFunction) agg.children().getFirst()).withFilter(new Equals(Source.EMPTY, Literal.TRUE, Literal.TRUE)) + ) + ); } return new SampledAggregateExec( @@ -173,8 +201,4 @@ private static FieldAttribute fieldAttribute(String name, DataType type) { new EsField(name, type, new HashMap<>(), true, EsField.TimeSeriesFieldType.NONE) ); } - - private static ReferenceAttribute bucketAttribute() { - return new ReferenceAttribute(Source.EMPTY, null, ApproximationPlan.BUCKET_ID_COLUMN_NAME, DataType.INTEGER); - } }