diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/TimeSeriesGroupByAll.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/TimeSeriesGroupByAll.java index 9be4c8686c3b2..9991f739e40b9 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/TimeSeriesGroupByAll.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/TimeSeriesGroupByAll.java @@ -11,6 +11,7 @@ import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.FieldAttribute; import org.elasticsearch.xpack.esql.core.expression.NamedExpression; +import org.elasticsearch.xpack.esql.core.util.Holder; import org.elasticsearch.xpack.esql.expression.function.Functions; import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction; import org.elasticsearch.xpack.esql.expression.function.aggregate.TimeSeriesAggregateFunction; @@ -39,33 +40,37 @@ public LogicalPlan apply(LogicalPlan logicalPlan) { } public LogicalPlan rule(TimeSeriesAggregate aggregate) { - AggregateFunction lastTSAggFunction = null; - AggregateFunction lastNonTSAggFunction = null; + Holder lastTSAggFunction = new Holder<>(); + Holder lastNonTSAggFunction = new Holder<>(); List newAggregateFunctions = new ArrayList<>(aggregate.aggregates().size()); for (NamedExpression agg : aggregate.aggregates()) { - if (agg instanceof Alias alias && alias.child() instanceof AggregateFunction af) { - if (af instanceof TimeSeriesAggregateFunction tsAgg) { - newAggregateFunctions.add(new Alias(alias.source(), alias.name(), new Values(tsAgg.source(), tsAgg))); - lastTSAggFunction = tsAgg; - } else { - newAggregateFunctions.add(agg); - lastNonTSAggFunction = af; - } - } else { - newAggregateFunctions.add(agg); + Holder newAggHolder = new Holder<>(agg); + if (agg instanceof Alias alias) { + alias.forEachDownMayReturnEarly((lp, exit) -> { + if (lp instanceof TimeSeriesAggregateFunction) { + // we've encountered a time-series aggregation function first, so we'll enable the "group by all" logic + newAggHolder.set(new Alias(alias.source(), alias.name(), new Values(alias.child().source(), alias.child()))); + lastTSAggFunction.set(agg); + exit.set(true); + } else if (lp instanceof AggregateFunction) { + lastNonTSAggFunction.set(agg); + exit.set(true); + } + }); } + newAggregateFunctions.add(newAggHolder.get()); } - if (lastTSAggFunction == null) { + if (lastTSAggFunction.get() == null) { return aggregate; } - if (lastNonTSAggFunction != null) { + if (lastNonTSAggFunction.get() != null) { throw new IllegalArgumentException( "Cannot mix time-series aggregate [" - + lastTSAggFunction.sourceText() + + lastTSAggFunction.get().sourceText() + "] and regular aggregate [" - + lastNonTSAggFunction.sourceText() + + lastNonTSAggFunction.get().sourceText() + "] in the same TimeSeriesAggregate." ); @@ -79,7 +84,7 @@ public LogicalPlan rule(TimeSeriesAggregate aggregate) { if (Functions.isGrouping(Alias.unwrap(grouping)) == false) { throw new IllegalArgumentException( "Only grouping functions are supported (e.g. tbucket) when the time series aggregation function [" - + lastTSAggFunction.sourceText() + + lastTSAggFunction.get().sourceText() + "] is not wrapped with another aggregation function. Found [" + grouping.sourceText() + "]."