diff --git a/docs/user/ppl/functions/math.rst b/docs/user/ppl/functions/math.rst index 2944f3fb8dc..b2b0dd47415 100644 --- a/docs/user/ppl/functions/math.rst +++ b/docs/user/ppl/functions/math.rst @@ -132,6 +132,84 @@ Example:: +--------------+ +SUM +--- + +Description +>>>>>>>>>>> + +Usage: sum(x, y, ...) calculates the sum of all provided arguments. This function accepts a variable number of arguments. + +Note: This function is only available in the eval command context and is rewritten to arithmetic addition while query parsing. + +Argument type: Variable number of INTEGER/LONG/FLOAT/DOUBLE arguments + +Return type: Wider number type among all arguments + +Example:: + + os> source=accounts | eval `SUM(1, 2, 3)` = SUM(1, 2, 3) | fields `SUM(1, 2, 3)` + fetched rows / total rows = 4/4 + +--------------+ + | SUM(1, 2, 3) | + |--------------| + | 6 | + | 6 | + | 6 | + | 6 | + +--------------+ + + os> source=accounts | eval total = SUM(age, 10, 5) | fields age, total + fetched rows / total rows = 4/4 + +-----+-------+ + | age | total | + |-----+-------| + | 32 | 47 | + | 36 | 51 | + | 28 | 43 | + | 33 | 48 | + +-----+-------+ + + +AVG +--- + +Description +>>>>>>>>>>> + +Usage: avg(x, y, ...) calculates the average (arithmetic mean) of all provided arguments. This function accepts a variable number of arguments. + +Note: This function is only available in the eval command context and is rewritten to arithmetic expression (sum / count) at query parsing time. + +Argument type: Variable number of INTEGER/LONG/FLOAT/DOUBLE arguments + +Return type: DOUBLE + +Example:: + + os> source=accounts | eval `AVG(1, 2, 3)` = AVG(1, 2, 3) | fields `AVG(1, 2, 3)` + fetched rows / total rows = 4/4 + +--------------+ + | AVG(1, 2, 3) | + |--------------| + | 2.0 | + | 2.0 | + | 2.0 | + | 2.0 | + +--------------+ + + os> source=accounts | eval average = AVG(age, 30) | fields age, average + fetched rows / total rows = 4/4 + +-----+---------+ + | age | average | + |-----+---------| + | 32 | 31.0 | + | 36 | 33.0 | + | 28 | 29.0 | + | 33 | 31.5 | + +-----+---------+ + + ACOS ---- diff --git a/integ-test/src/test/java/org/opensearch/sql/ppl/MathematicalFunctionIT.java b/integ-test/src/test/java/org/opensearch/sql/ppl/MathematicalFunctionIT.java index f53049fb5c3..6cd5063b5a0 100644 --- a/integ-test/src/test/java/org/opensearch/sql/ppl/MathematicalFunctionIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/ppl/MathematicalFunctionIT.java @@ -547,4 +547,225 @@ public void testRint() throws IOException { verifySchema(result, schema("f", null, "double")); verifySome(result.getJSONArray("datarows"), rows(Math.rint(1.7))); } + + // SUM function tests for eval command + @Test + public void testEvalSumSingleInteger() throws IOException { + JSONObject result = + executeQuery( + String.format("source=%s | eval f = sum(42) | fields f | head 5", TEST_INDEX_BANK)); + verifySchema(result, schema("f", null, "int")); + verifyDataRows(result, rows(42), rows(42), rows(42), rows(42), rows(42)); + } + + @Test + public void testEvalSumMultipleIntegers() throws IOException { + JSONObject result = + executeQuery( + String.format( + "source=%s | eval f = sum(1, 2, 3) | fields f | head 5", TEST_INDEX_BANK)); + verifySchema(result, schema("f", null, "int")); + verifyDataRows(result, rows(6), rows(6), rows(6), rows(6), rows(6)); + } + + @Test + public void testEvalSumMixedNumericTypes() throws IOException { + JSONObject result = + executeQuery( + String.format("source=%s | eval f = sum(1, 2.5) | fields f | head 5", TEST_INDEX_BANK)); + verifySchema(result, schema("f", null, "double")); + verifyDataRows(result, rows(3.5), rows(3.5), rows(3.5), rows(3.5), rows(3.5)); + } + + @Test + public void testEvalSumWithFields() throws IOException { + JSONObject result = + executeQuery( + String.format( + "source=%s | eval f = sum(age, 10) | fields f | head 7", TEST_INDEX_BANK)); + verifySchema(result, schema("f", null, "int")); + verifyDataRows(result, rows(42), rows(46), rows(38), rows(43), rows(46), rows(49), rows(44)); + } + + @Test + public void testEvalSumMultipleNumericArguments() throws IOException { + JSONObject result = + executeQuery( + String.format( + "source=%s | eval f = sum(1, 2, 3, 4, 5) | fields f | head 5", TEST_INDEX_BANK)); + verifySchema(result, schema("f", null, "int")); + verifyDataRows(result, rows(15), rows(15), rows(15), rows(15), rows(15)); + } + + // AVG function tests for eval command + @Test + public void testEvalAvgSingleInteger() throws IOException { + JSONObject result = + executeQuery( + String.format("source=%s | eval f = avg(42) | fields f | head 5", TEST_INDEX_BANK)); + verifySchema(result, schema("f", null, "double")); + verifyDataRows(result, rows(42.0), rows(42.0), rows(42.0), rows(42.0), rows(42.0)); + } + + @Test + public void testEvalAvgMultipleIntegers() throws IOException { + JSONObject result = + executeQuery( + String.format( + "source=%s | eval f = avg(1, 2, 3) | fields f | head 5", TEST_INDEX_BANK)); + verifySchema(result, schema("f", null, "double")); + verifyDataRows(result, rows(2.0), rows(2.0), rows(2.0), rows(2.0), rows(2.0)); + } + + @Test + public void testEvalAvgTwoIntegers() throws IOException { + JSONObject result = + executeQuery( + String.format("source=%s | eval f = avg(1, 2) | fields f | head 5", TEST_INDEX_BANK)); + verifySchema(result, schema("f", null, "double")); + verifyDataRows(result, rows(1.5), rows(1.5), rows(1.5), rows(1.5), rows(1.5)); + } + + @Test + public void testEvalAvgMixedNumericTypes() throws IOException { + JSONObject result = + executeQuery( + String.format("source=%s | eval f = avg(1, 2.5) | fields f | head 5", TEST_INDEX_BANK)); + verifySchema(result, schema("f", null, "double")); + verifyDataRows(result, rows(1.75), rows(1.75), rows(1.75), rows(1.75), rows(1.75)); + } + + @Test + public void testEvalAvgWithFields() throws IOException { + JSONObject result = + executeQuery( + String.format( + "source=%s | eval f = avg(age, 10) | fields f | head 7", TEST_INDEX_BANK)); + verifySchema(result, schema("f", null, "double")); + verifyDataRows( + result, rows(21.0), rows(23.0), rows(19.0), rows(21.5), rows(23.0), rows(24.5), rows(22.0)); + } + + @Test + public void testEvalAvgMultipleValues() throws IOException { + JSONObject result = + executeQuery( + String.format("source=%s | eval f = avg(1, 4) | fields f | head 5", TEST_INDEX_BANK)); + verifySchema(result, schema("f", null, "double")); + verifyDataRows(result, rows(2.5), rows(2.5), rows(2.5), rows(2.5), rows(2.5)); + } + + @Test + public void testEvalAvgFiveValues() throws IOException { + JSONObject result = + executeQuery( + String.format( + "source=%s | eval f = avg(1, 2, 3, 4, 5) | fields f | head 5", TEST_INDEX_BANK)); + verifySchema(result, schema("f", null, "double")); + verifyDataRows(result, rows(3.0), rows(3.0), rows(3.0), rows(3.0), rows(3.0)); + } + + // Combined sum and avg tests + @Test + public void testEvalSumAndAvgComparison() throws IOException { + JSONObject result = + executeQuery( + String.format( + "source=%s | eval sum_val = sum(10, 20, 30), avg_val = avg(10, 20, 30) | fields" + + " sum_val, avg_val | head 5", + TEST_INDEX_BANK)); + verifySchema(result, schema("sum_val", null, "int"), schema("avg_val", null, "double")); + verifyDataRows( + result, rows(60, 20.0), rows(60, 20.0), rows(60, 20.0), rows(60, 20.0), rows(60, 20.0)); + } + + @Test + public void testEvalSumInWhereClause() throws IOException { + JSONObject result = + executeQuery( + String.format( + "source=%s | where sum(age, 10) > 40 | eval f = sum(age, 10) | fields f | head 6", + TEST_INDEX_BANK)); + verifySchema(result, schema("f", null, "int")); + // Should return rows where age + 10 > 40, so age > 30 + verifyDataRows(result, rows(42), rows(46), rows(43), rows(46), rows(49), rows(44)); + } + + @Test + public void testEvalAvgInWhereClause() throws IOException { + JSONObject result = + executeQuery( + String.format( + "source=%s | where avg(age, 10) > 20 | eval f = avg(age, 10) | fields f | head 6", + TEST_INDEX_BANK)); + verifySchema(result, schema("f", null, "double")); + // Should return rows where (age + 10) / 2 > 20, so age > 30 + verifyDataRows(result, rows(21.0), rows(23.0), rows(21.5), rows(23.0), rows(24.5), rows(22.0)); + } + + @Test + public void testEvalComplexExpression() throws IOException { + JSONObject result = + executeQuery( + String.format( + "source=%s | eval f = sum(age, 5) + avg(10, 20) | fields f | head 5", + TEST_INDEX_BANK)); + verifySchema(result, schema("f", null, "double")); + // sum(age, 5) + avg(10, 20) = (age + 5) + 15.0 + verifyDataRows(result, rows(52.0), rows(56.0), rows(48.0), rows(53.0), rows(56.0)); + } + + @Test + public void testEvalNestedSumAvg() throws IOException { + // Note: This tests the arithmetic expression rewriting, not actual nested function calls + JSONObject result = + executeQuery( + String.format( + "source=%s | eval f = sum(avg(20, 30), 10) | fields f | head 5", TEST_INDEX_BANK)); + verifySchema(result, schema("f", null, "double")); + // avg(20, 30) = 25.0, sum(25.0, 10) = 35.0 + verifyDataRows(result, rows(35.0), rows(35.0), rows(35.0), rows(35.0), rows(35.0)); + } + + @Test + public void testEvalSumWithMultipleFields() throws IOException { + JSONObject result = + executeQuery( + String.format( + "source=%s | eval f = sum(age, age, 10) | fields f | head 5", TEST_INDEX_BANK)); + verifySchema(result, schema("f", null, "int")); + // sum(age, age, 10) = age + age + 10 = 2*age + 10 + verifyDataRows(result, rows(74), rows(82), rows(66), rows(76), rows(82)); + } + + @Test + public void testEvalAvgWithExpression() throws IOException { + JSONObject result = + executeQuery( + String.format( + "source=%s | eval f = avg(age * 2, 10) | fields f | head 5", TEST_INDEX_BANK)); + verifySchema(result, schema("f", null, "double")); + // avg(age * 2, 10) = (age * 2 + 10) / 2 = age + 5 + verifyDataRows(result, rows(37.0), rows(41.0), rows(33.0), rows(38.0), rows(41.0)); + } + + @Test + public void testEvalSumWithNegativeNumbers() throws IOException { + JSONObject result = + executeQuery( + String.format( + "source=%s | eval f = sum(-5, 10, -3) | fields f | head 5", TEST_INDEX_BANK)); + verifySchema(result, schema("f", null, "int")); + verifyDataRows(result, rows(2), rows(2), rows(2), rows(2), rows(2)); + } + + @Test + public void testEvalAvgWithNegativeNumbers() throws IOException { + JSONObject result = + executeQuery( + String.format( + "source=%s | eval f = avg(-10, 10) | fields f | head 5", TEST_INDEX_BANK)); + verifySchema(result, schema("f", null, "double")); + verifyDataRows(result, rows(0.0), rows(0.0), rows(0.0), rows(0.0), rows(0.0)); + } } diff --git a/ppl/src/main/antlr/OpenSearchPPLParser.g4 b/ppl/src/main/antlr/OpenSearchPPLParser.g4 index 2699d7c31e1..2221a36da43 100644 --- a/ppl/src/main/antlr/OpenSearchPPLParser.g4 +++ b/ppl/src/main/antlr/OpenSearchPPLParser.g4 @@ -710,6 +710,8 @@ mathematicalFunctionName | TRUNCATE | RINT | SIGNUM + | SUM + | AVG | trigonometricFunctionName ; diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java index 1170123f69d..5e865fe7333 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java @@ -5,9 +5,7 @@ package org.opensearch.sql.ppl.parser; -import static org.opensearch.sql.expression.function.BuiltinFunctionName.IS_NOT_NULL; -import static org.opensearch.sql.expression.function.BuiltinFunctionName.IS_NULL; -import static org.opensearch.sql.expression.function.BuiltinFunctionName.POSITION; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.*; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.BinaryArithmeticContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.BooleanLiteralContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.BySpanClauseContext; @@ -271,9 +269,16 @@ public UnresolvedExpression visitCaseFunctionCall( @Override public UnresolvedExpression visitEvalFunctionCall(EvalFunctionCallContext ctx) { final String functionName = ctx.evalFunctionName().getText(); - return buildFunction( - FUNCTION_NAME_MAPPING.getOrDefault(functionName.toLowerCase(Locale.ROOT), functionName), - ctx.functionArgs().functionArg()); + final String mappedName = + FUNCTION_NAME_MAPPING.getOrDefault(functionName.toLowerCase(Locale.ROOT), functionName); + + // Rewrite sum and avg functions to arithmetic expressions + if (SUM.getName().getFunctionName().equalsIgnoreCase(mappedName) + || AVG.getName().getFunctionName().equalsIgnoreCase(mappedName)) { + return rewriteSumAvgFunction(mappedName, ctx.functionArgs().functionArg()); + } + + return buildFunction(mappedName, ctx.functionArgs().functionArg()); } /** Cast function. */ @@ -293,6 +298,55 @@ private Function buildFunction( functionName, args.stream().map(this::visitFunctionArg).collect(Collectors.toList())); } + /** + * Rewrites sum(a, b, c, ...) to (a + b + c + ...) and avg(a, b, c, ...) to (a + b + c + ...) / n + * Uses balanced tree construction to avoid deep recursion with large argument lists. + */ + private UnresolvedExpression rewriteSumAvgFunction( + String functionName, List args) { + if (args.isEmpty()) { + throw new SyntaxCheckException(functionName + " function requires at least one argument"); + } + + List arguments = + args.stream().map(this::visitFunctionArg).collect(Collectors.toList()); + + // Build the sum expression as a balanced tree to avoid deep recursion + UnresolvedExpression functionExpr = buildBalancedTree("+", arguments); + + // For avg, divide by the count of arguments + if (AVG.getName().getFunctionName().equalsIgnoreCase(functionName)) { + UnresolvedExpression count = AstDSL.doubleLiteral((double) arguments.size()); + functionExpr = new Function("/", Arrays.asList(functionExpr, count)); + } + + return functionExpr; + } + + /** + * Builds a balanced tree of binary operations to avoid deep recursion. For example, [a, b, c, d] + * becomes ((a + b) + (c + d)) instead of (((a + b) + c) + d). This ensures recursion depth is + * O(log n) instead of O(n). + */ + private UnresolvedExpression buildBalancedTree( + String operator, List expressions) { + if (expressions.size() == 1) { + return expressions.get(0); + } + + if (expressions.size() == 2) { + return new Function(operator, Arrays.asList(expressions.get(0), expressions.get(1))); + } + + // Split the list in half and recursively build balanced subtrees + int mid = expressions.size() / 2; + UnresolvedExpression left = buildBalancedTree(operator, expressions.subList(0, mid)); + UnresolvedExpression right = + buildBalancedTree(operator, expressions.subList(mid, expressions.size())); + + return new Function(operator, Arrays.asList(left, right)); + } + @Override public UnresolvedExpression visitSingleFieldRelevanceFunction( SingleFieldRelevanceFunctionContext ctx) { diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLEvalTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLEvalTest.java index db759a7b9af..e09b62b748a 100644 --- a/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLEvalTest.java +++ b/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLEvalTest.java @@ -98,6 +98,73 @@ public void testEval3() { verifyPPLToSparkSQL(root, expectedSparkSql); } + @Test + public void testEvalSum() { + String ppl = "source=EMP | eval total = sum(1, 2, 3) | fields EMPNO, total"; + RelNode root = getRelNode(ppl); + String expectedLogical = + "LogicalProject(EMPNO=[$0], total=[+(1, +(2, 3))])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n"; + verifyLogical(root, expectedLogical); + + String expectedSparkSql = "SELECT `EMPNO`, 1 + (2 + 3) `total`\n" + "FROM `scott`.`EMP`"; + verifyPPLToSparkSQL(root, expectedSparkSql); + } + + @Test + public void testEvalSumWithFields() { + String ppl = "source=EMP | eval total = sum(SAL, COMM, 100) | fields EMPNO, total"; + RelNode root = getRelNode(ppl); + String expectedLogical = + "LogicalProject(EMPNO=[$0], total=[+($5, +($6, 100))])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n"; + verifyLogical(root, expectedLogical); + + String expectedSparkSql = + "SELECT `EMPNO`, `SAL` + (`COMM` + 100) `total`\n" + "FROM `scott`.`EMP`"; + verifyPPLToSparkSQL(root, expectedSparkSql); + } + + @Test + public void testEvalAvg() { + String ppl = "source=EMP | eval average = avg(10, 20, 30) | fields EMPNO, average"; + RelNode root = getRelNode(ppl); + String expectedLogical = + "LogicalProject(EMPNO=[$0], average=[DIVIDE(+(10, +(20, 30)), 3.0E0:DOUBLE)])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n"; + verifyLogical(root, expectedLogical); + + String expectedSparkSql = + "SELECT `EMPNO`, `DIVIDE`(10 + (20 + 30), 3.0E0) `average`\n" + "FROM `scott`.`EMP`"; + verifyPPLToSparkSQL(root, expectedSparkSql); + } + + @Test + public void testEvalAvgWithFields() { + String ppl = "source=EMP | eval avgSal = avg(SAL, COMM) | fields EMPNO, avgSal"; + RelNode root = getRelNode(ppl); + String expectedLogical = + "LogicalProject(EMPNO=[$0], avgSal=[DIVIDE(+($5, $6), 2.0E0:DOUBLE)])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n"; + verifyLogical(root, expectedLogical); + + String expectedSparkSql = + "SELECT `EMPNO`, `DIVIDE`(`SAL` + `COMM`, 2.0E0) `avgSal`\n" + "FROM `scott`.`EMP`"; + verifyPPLToSparkSQL(root, expectedSparkSql); + } + + @Test + public void testEvalSumSingleArg() { + String ppl = "source=EMP | eval total = sum(42) | fields EMPNO, total"; + RelNode root = getRelNode(ppl); + String expectedLogical = + "LogicalProject(EMPNO=[$0], total=[42])\n" + " LogicalTableScan(table=[[scott, EMP]])\n"; + verifyLogical(root, expectedLogical); + + String expectedSparkSql = "SELECT `EMPNO`, 42 `total`\n" + "FROM `scott`.`EMP`"; + verifyPPLToSparkSQL(root, expectedSparkSql); + } + @Test public void testEvalWithSort() { String ppl = "source=EMP | eval a = EMPNO | sort - a | fields a"; diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstExpressionBuilderTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstExpressionBuilderTest.java index ca0f508823f..52c1c8885fa 100644 --- a/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstExpressionBuilderTest.java +++ b/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstExpressionBuilderTest.java @@ -604,6 +604,174 @@ public void testDataTypeFuncCall() { eval(relation("t"), let(field("f"), cast(intLiteral(1), stringLiteral("string"))))); } + @Test + public void testEvalSumFunctionSingleArg() { + // sum(42) -> 42 + assertEqual("source=t | eval f=sum(42)", eval(relation("t"), let(field("f"), intLiteral(42)))); + } + + @Test + public void testEvalSumFunctionMultipleArgs() { + // sum(1, 2, 3) -> (1 + (2 + 3)) - balanced tree + assertEqual( + "source=t | eval f=sum(1, 2, 3)", + eval( + relation("t"), + let( + field("f"), + function("+", intLiteral(1), function("+", intLiteral(2), intLiteral(3)))))); + } + + @Test + public void testEvalSumFunctionWithFields() { + // sum(a, b, 10) -> (a + (b + 10)) - balanced tree + assertEqual( + "source=t | eval f=sum(a, b, 10)", + eval( + relation("t"), + let(field("f"), function("+", field("a"), function("+", field("b"), intLiteral(10)))))); + } + + @Test + public void testEvalSumFunctionFourArgs() { + // sum(1, 2, 3, 4) -> ((1 + 2) + (3 + 4)) - balanced tree + assertEqual( + "source=t | eval f=sum(1, 2, 3, 4)", + eval( + relation("t"), + let( + field("f"), + function( + "+", + function("+", intLiteral(1), intLiteral(2)), + function("+", intLiteral(3), intLiteral(4)))))); + } + + @Test + public void testEvalSumFunctionMixedTypes() { + // sum(1, 2.5) -> (1 + 2.5) + assertEqual( + "source=t | eval f=sum(1, 2.5)", + eval(relation("t"), let(field("f"), function("+", intLiteral(1), decimalLiteral(2.5))))); + } + + @Test + public void testEvalAvgFunctionSingleArg() { + // avg(42) -> 42 / 1.0 + assertEqual( + "source=t | eval f=avg(42)", + eval(relation("t"), let(field("f"), function("/", intLiteral(42), doubleLiteral(1.0))))); + } + + @Test + public void testEvalAvgFunctionTwoArgs() { + // avg(10, 20) -> (10 + 20) / 2.0 + assertEqual( + "source=t | eval f=avg(10, 20)", + eval( + relation("t"), + let( + field("f"), + function("/", function("+", intLiteral(10), intLiteral(20)), doubleLiteral(2.0))))); + } + + @Test + public void testEvalAvgFunctionMultipleArgs() { + // avg(1, 2, 3) -> (1 + (2 + 3)) / 3.0 - balanced tree + assertEqual( + "source=t | eval f=avg(1, 2, 3)", + eval( + relation("t"), + let( + field("f"), + function( + "/", + function("+", intLiteral(1), function("+", intLiteral(2), intLiteral(3))), + doubleLiteral(3.0))))); + } + + @Test + public void testEvalAvgFunctionWithFields() { + // avg(a, b) -> (a + b) / 2.0 + assertEqual( + "source=t | eval f=avg(a, b)", + eval( + relation("t"), + let( + field("f"), + function("/", function("+", field("a"), field("b")), doubleLiteral(2.0))))); + } + + @Test + public void testEvalAvgFunctionMixedTypes() { + // avg(1, 2.5, 3) -> (1 + (2.5 + 3)) / 3.0 - balanced tree + assertEqual( + "source=t | eval f=avg(1, 2.5, 3)", + eval( + relation("t"), + let( + field("f"), + function( + "/", + function("+", intLiteral(1), function("+", decimalLiteral(2.5), intLiteral(3))), + doubleLiteral(3.0))))); + } + + @Test + public void testEvalComplexExpressionWithSumAndAvg() { + // sum(a, 5) + avg(10, 20) -> (a + 5) + ((10 + 20) / 2.0) + assertEqual( + "source=t | eval f=sum(a, 5) + avg(10, 20)", + eval( + relation("t"), + let( + field("f"), + function( + "+", + function("+", field("a"), intLiteral(5)), + function( + "/", function("+", intLiteral(10), intLiteral(20)), doubleLiteral(2.0)))))); + } + + @Test + public void testWhereSumFunction() { + // where sum(a, 10) > 20 -> where (a + 10) > 20 + assertEqual( + "source=t | where sum(a, 10) > 20", + filter( + relation("t"), + compare(">", function("+", field("a"), intLiteral(10)), intLiteral(20)))); + } + + @Test + public void testWhereAvgFunction() { + // where avg(a, b) < 15.5 -> where (a + b) / 2.0 < 15.5 + assertEqual( + "source=t | where avg(a, b) < 15.5", + filter( + relation("t"), + compare( + "<", + function("/", function("+", field("a"), field("b")), doubleLiteral(2.0)), + decimalLiteral(15.5)))); + } + + @Test + public void testWhereSumAndAvgComparison() { + // where sum(a, b) > avg(10, 20, 30) -> where (a + b) > (10 + (20 + 30)) / 3.0 - balanced tree + assertEqual( + "source=t | where sum(a, b) > avg(10, 20, 30)", + filter( + relation("t"), + compare( + ">", + function("+", field("a"), field("b")), + function( + "/", + function("+", intLiteral(10), function("+", intLiteral(20), intLiteral(30))), + doubleLiteral(3.0))))); + } + @Test public void testNestedFieldName() { assertEqual(