Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 140 additions & 2 deletions core/src/main/java/org/opensearch/sql/ast/tree/Timechart.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,43 @@

package org.opensearch.sql.ast.tree;

import static org.opensearch.sql.ast.dsl.AstDSL.aggregate;
import static org.opensearch.sql.ast.dsl.AstDSL.doubleLiteral;
import static org.opensearch.sql.ast.dsl.AstDSL.eval;
import static org.opensearch.sql.ast.dsl.AstDSL.function;
import static org.opensearch.sql.ast.dsl.AstDSL.stringLiteral;
import static org.opensearch.sql.ast.expression.IntervalUnit.SECOND;
import static org.opensearch.sql.ast.tree.Timechart.PerFunctionRateExprBuilder.sum;
import static org.opensearch.sql.ast.tree.Timechart.PerFunctionRateExprBuilder.timestampadd;
import static org.opensearch.sql.ast.tree.Timechart.PerFunctionRateExprBuilder.timestampdiff;
import static org.opensearch.sql.calcite.plan.OpenSearchConstants.IMPLICIT_FIELD_TIMESTAMP;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.DIVIDE;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.MULTIPLY;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.SUM;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.TIMESTAMPADD;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.TIMESTAMPDIFF;

import com.google.common.collect.ImmutableList;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import lombok.AllArgsConstructor;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import lombok.ToString;
import org.opensearch.sql.ast.AbstractNodeVisitor;
import org.opensearch.sql.ast.dsl.AstDSL;
import org.opensearch.sql.ast.expression.AggregateFunction;
import org.opensearch.sql.ast.expression.Field;
import org.opensearch.sql.ast.expression.Function;
import org.opensearch.sql.ast.expression.IntervalUnit;
import org.opensearch.sql.ast.expression.Let;
import org.opensearch.sql.ast.expression.Span;
import org.opensearch.sql.ast.expression.SpanUnit;
import org.opensearch.sql.ast.expression.UnresolvedExpression;
import org.opensearch.sql.calcite.utils.PlanUtils;

/** AST node representing a Timechart operation. */
@Getter
Expand Down Expand Up @@ -49,8 +78,9 @@ public Timechart useOther(Boolean useOther) {
}

@Override
public Timechart attach(UnresolvedPlan child) {
return toBuilder().child(child).build();
public UnresolvedPlan attach(UnresolvedPlan child) {
  // Rewrite per_* functions only after the child is attached, so the rewrite cannot
  // unintentionally override the child.
  Timechart attached = toBuilder().child(child).build();
  return attached.transformPerFunction();
}

@Override
Expand All @@ -62,4 +92,112 @@ public List<UnresolvedPlan> getChild() {
public <T, C> T accept(AbstractNodeVisitor<T, C> nodeVisitor, C context) {
  // Visitor dispatch: hand this node to the timechart-specific callback.
  final T visited = nodeVisitor.visitTimechart(this, context);
  return visited;
}

/**
 * Rewrites a per_* aggregation into a plain sum followed by an eval that turns the sum into a
 * rate. The bucket length in seconds is derived from the span option by adding the span to the
 * implicit timestamp field and diffing the result against that same field; the sum is then
 * scaled by the function's unit (in seconds) and divided by the bucket length.
 *
 * <p>For example, with span=5m per_second(field): per second rate = sum(field) * 1 / 300.
 *
 * @return an eval over the rewritten timechart if a per function is present, or this timechart
 *     unchanged otherwise.
 */
private UnresolvedPlan transformPerFunction() {
  return PerFunction.from(aggregateFunction)
      .<UnresolvedPlan>map(this::evalRateOverSum)
      .orElse(this);
}

/** Builds the eval-over-timechart plan for an aggregation recognized as a per function. */
private UnresolvedPlan evalRateOverSum(PerFunction perFunc) {
  Span bucket = (Span) this.binExpression;
  Field bucketStart = AstDSL.field(IMPLICIT_FIELD_TIMESTAMP);
  Function bucketEnd = timestampadd(bucket.getUnit(), bucket.getValue(), bucketStart);
  Function bucketSeconds = timestampdiff(SECOND, bucketStart, bucketEnd);

  // The sum is aliased to the per-function display name so the eval can overwrite that column.
  Timechart summed = timechart(AstDSL.alias(perFunc.aggName, sum(perFunc.aggArg)));
  Let rate = let(perFunc.aggName).multiply(perFunc.seconds).dividedBy(bucketSeconds);
  return eval(summed, rate);
}

/** Returns a copy of this timechart built with the given aggregation in place of the current one. */
private Timechart timechart(UnresolvedExpression newAggregateFunction) {
  return toBuilder()
      .aggregateFunction(newAggregateFunction)
      .build();
}

/** TODO: extend to support additional per_* functions */
@RequiredArgsConstructor
static class PerFunction {
private static final Map<String, Integer> UNIT_SECONDS = Map.of("per_second", 1);
private final String aggName;
private final UnresolvedExpression aggArg;
private final int seconds;

static Optional<PerFunction> from(UnresolvedExpression aggExpr) {
if (!(aggExpr instanceof AggregateFunction)) {
return Optional.empty();
}

AggregateFunction aggFunc = (AggregateFunction) aggExpr;
String aggFuncName = aggFunc.getFuncName().toLowerCase(Locale.ROOT);
if (!UNIT_SECONDS.containsKey(aggFuncName)) {
return Optional.empty();
}

String aggName = toAggName(aggFunc);
return Optional.of(
new PerFunction(aggName, aggFunc.getField(), UNIT_SECONDS.get(aggFuncName)));
}

private static String toAggName(AggregateFunction aggFunc) {
String fieldName =
(aggFunc.getField() instanceof Field)
? ((Field) aggFunc.getField()).getField().toString()
: aggFunc.getField().toString();
return String.format(Locale.ROOT, "%s(%s)", aggFunc.getFuncName(), fieldName);
}
}

/** Starts a rate-expression builder whose result is assigned to the field with the given name. */
private PerFunctionRateExprBuilder let(String fieldName) {
  Field target = AstDSL.field(fieldName);
  return new PerFunctionRateExprBuilder(target);
}

/** Fluent builder for creating Let expressions with mathematical operations. */
static class PerFunctionRateExprBuilder {
private final Field field;
private UnresolvedExpression expr;

PerFunctionRateExprBuilder(Field field) {
this.field = field;
this.expr = field;
}

PerFunctionRateExprBuilder multiply(Integer multiplier) {
// Promote to double literal to avoid integer division in downstream
this.expr =
function(
MULTIPLY.getName().getFunctionName(), expr, doubleLiteral(multiplier.doubleValue()));
return this;
}

Let dividedBy(UnresolvedExpression divisor) {
return AstDSL.let(field, function(DIVIDE.getName().getFunctionName(), expr, divisor));
}

static UnresolvedExpression sum(UnresolvedExpression field) {
return aggregate(SUM.getName().getFunctionName(), field);
}

static Function timestampadd(
SpanUnit unit, UnresolvedExpression value, UnresolvedExpression timestampField) {
UnresolvedExpression intervalUnit =
stringLiteral(PlanUtils.spanUnitToIntervalUnit(unit).toString());
return function(
TIMESTAMPADD.getName().getFunctionName(), intervalUnit, value, timestampField);
}

static Function timestampdiff(
IntervalUnit unit, UnresolvedExpression start, UnresolvedExpression end) {
return function(
TIMESTAMPDIFF.getName().getFunctionName(), stringLiteral(unit.toString()), start, end);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1928,6 +1928,9 @@ public RelNode visitFlatten(Flatten node, CalcitePlanContext context) {

/** Helper method to get the function name for proper column naming */
private String getValueFunctionName(UnresolvedExpression aggregateFunction) {
if (aggregateFunction instanceof Alias) {
return ((Alias) aggregateFunction).getName();
}
if (!(aggregateFunction instanceof AggregateFunction)) {
return "value";
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,59 @@ static SpanUnit intervalUnitToSpanUnit(IntervalUnit unit) {

}

/**
 * Maps a span unit, in any of its spellings, to the corresponding interval unit.
 *
 * @param unit span unit to convert
 * @return the matching interval unit
 * @throws UnsupportedOperationException if the span unit has no mapping
 */
static IntervalUnit spanUnitToIntervalUnit(SpanUnit unit) {
  final IntervalUnit mapped;
  switch (unit) {
    case MILLISECOND:
    case MS:
      // NOTE(review): millisecond spans map to MICROSECOND here. Confirm this is intended;
      // a caller that feeds this unit to timestampadd would presumably add microseconds.
      mapped = IntervalUnit.MICROSECOND;
      break;
    case SECOND:
    case SECONDS:
    case SEC:
    case SECS:
    case S:
      mapped = IntervalUnit.SECOND;
      break;
    case MINUTE:
    case MINUTES:
    case MIN:
    case MINS:
    case m:
      mapped = IntervalUnit.MINUTE;
      break;
    case HOUR:
    case HOURS:
    case HR:
    case HRS:
    case H:
      mapped = IntervalUnit.HOUR;
      break;
    case DAY:
    case DAYS:
    case D:
      mapped = IntervalUnit.DAY;
      break;
    case WEEK:
    case WEEKS:
    case W:
      mapped = IntervalUnit.WEEK;
      break;
    case MONTH:
    case MONTHS:
    case MON:
    case M:
      mapped = IntervalUnit.MONTH;
      break;
    case QUARTER:
    case QUARTERS:
    case QTR:
    case QTRS:
    case Q:
      mapped = IntervalUnit.QUARTER;
      break;
    case YEAR:
    case YEARS:
    case Y:
      mapped = IntervalUnit.YEAR;
      break;
    case UNKNOWN:
      mapped = IntervalUnit.UNKNOWN;
      break;
    default:
      throw new UnsupportedOperationException("Unsupported span unit: " + unit);
  }
  return mapped;
}

static RexNode makeOver(
CalcitePlanContext context,
BuiltinFunctionName functionName,
Expand Down
155 changes: 155 additions & 0 deletions core/src/test/java/org/opensearch/sql/ast/tree/TimechartTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.ast.tree;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.opensearch.sql.ast.dsl.AstDSL.aggregate;
import static org.opensearch.sql.ast.dsl.AstDSL.alias;
import static org.opensearch.sql.ast.dsl.AstDSL.doubleLiteral;
import static org.opensearch.sql.ast.dsl.AstDSL.field;
import static org.opensearch.sql.ast.dsl.AstDSL.function;
import static org.opensearch.sql.ast.dsl.AstDSL.intLiteral;
import static org.opensearch.sql.ast.dsl.AstDSL.relation;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.CsvSource;
import org.opensearch.sql.ast.dsl.AstDSL;
import org.opensearch.sql.ast.expression.AggregateFunction;
import org.opensearch.sql.ast.expression.Let;
import org.opensearch.sql.ast.expression.Span;
import org.opensearch.sql.ast.expression.SpanUnit;
import org.opensearch.sql.ast.expression.UnresolvedExpression;

class TimechartTest {

@ParameterizedTest
@CsvSource({"1, m, MINUTE", "30, s, SECOND", "5, m, MINUTE", "2, h, HOUR", "1, d, DAY"})
void should_transform_per_second_for_different_spans(
int spanValue, String spanUnit, String expectedIntervalUnit) {
withTimechart(span(spanValue, spanUnit), perSecond("bytes"))
.whenTransformingPerFunction()
.thenExpect(
eval(
let(
"per_second(bytes)",
divide(
multiply("per_second(bytes)", 1.0),
timestampdiff(
"SECOND",
"@timestamp",
timestampadd(expectedIntervalUnit, spanValue, "@timestamp")))),
timechart(span(spanValue, spanUnit), alias("per_second(bytes)", sum("bytes")))));
}

@Test
void should_not_transform_non_per_functions() {
withTimechart(span(1, "m"), sum("bytes"))
.whenTransformingPerFunction()
.thenExpect(timechart(span(1, "m"), sum("bytes")));
}

@Test
void should_preserve_all_fields_during_per_function_transformation() {
Timechart original =
new Timechart(relation("logs"), perSecond("bytes"))
.span(span(5, "m"))
.by(field("status"))
.limit(20)
.useOther(false);

Timechart expected =
new Timechart(relation("logs"), alias("per_second(bytes)", sum("bytes")))
.span(span(5, "m"))
.by(field("status"))
.limit(20)
.useOther(false);

withTimechart(original)
.whenTransformingPerFunction()
.thenExpect(
eval(
let(
"per_second(bytes)",
divide(
multiply("per_second(bytes)", 1.0),
timestampdiff(
"SECOND", "@timestamp", timestampadd("MINUTE", 5, "@timestamp")))),
expected));
}

// Fluent API for readable test assertions

private static TransformationAssertion withTimechart(Span spanExpr, AggregateFunction aggFunc) {
return new TransformationAssertion(timechart(spanExpr, aggFunc));
}

private static TransformationAssertion withTimechart(Timechart timechart) {
return new TransformationAssertion(timechart);
}

private static Timechart timechart(Span spanExpr, UnresolvedExpression aggExpr) {
// Set child here because expected object won't call attach below
return new Timechart(relation("t"), aggExpr).span(spanExpr).limit(10).useOther(true);
}

private static Span span(int value, String unit) {
return AstDSL.span(field("@timestamp"), intLiteral(value), SpanUnit.of(unit));
}

private static AggregateFunction perSecond(String fieldName) {
return (AggregateFunction) aggregate("per_second", field(fieldName));
}

private static AggregateFunction sum(String fieldName) {
return (AggregateFunction) aggregate("sum", field(fieldName));
}

private static Let let(String fieldName, UnresolvedExpression expression) {
return AstDSL.let(field(fieldName), expression);
}

private static UnresolvedExpression multiply(String fieldName, double right) {
return function("*", field(fieldName), doubleLiteral(right));
}

private static UnresolvedExpression divide(
UnresolvedExpression left, UnresolvedExpression right) {
return function("/", left, right);
}

private static UnresolvedExpression timestampadd(String unit, int value, String timestampField) {
return function(
"timestampadd", AstDSL.stringLiteral(unit), intLiteral(value), field(timestampField));
}

private static UnresolvedExpression timestampdiff(
String unit, String startField, UnresolvedExpression end) {
return function("timestampdiff", AstDSL.stringLiteral(unit), field(startField), end);
}

private static UnresolvedPlan eval(Let letExpr, Timechart timechartExpr) {
return AstDSL.eval(timechartExpr, letExpr);
}

private static class TransformationAssertion {
private final Timechart timechart;
private UnresolvedPlan result;

TransformationAssertion(Timechart timechart) {
this.timechart = timechart;
}

public TransformationAssertion whenTransformingPerFunction() {
this.result = timechart.attach(timechart.getChild().get(0));
return this;
}

public void thenExpect(UnresolvedPlan expected) {
assertEquals(expected, result);
}
}
}
Loading
Loading