diff --git a/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java b/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java index 0bded5f12c2..da4fffabf61 100644 --- a/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java +++ b/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java @@ -57,6 +57,8 @@ import org.opensearch.sql.ast.expression.UnresolvedExpression; import org.opensearch.sql.ast.expression.WindowFunction; import org.opensearch.sql.ast.tree.AD; +import org.opensearch.sql.ast.tree.AddColTotals; +import org.opensearch.sql.ast.tree.AddTotals; import org.opensearch.sql.ast.tree.Aggregation; import org.opensearch.sql.ast.tree.Append; import org.opensearch.sql.ast.tree.AppendCol; @@ -522,6 +524,16 @@ public LogicalPlan visitEval(Eval node, AnalysisContext context) { return new LogicalEval(child, expressionsBuilder.build()); } + @Override + public LogicalPlan visitAddTotals(AddTotals node, AnalysisContext context) { + throw getOnlyForCalciteException("addtotals"); + } + + @Override + public LogicalPlan visitAddColTotals(AddColTotals node, AnalysisContext context) { + throw getOnlyForCalciteException("addcoltotals"); + } + /** Build {@link ParseExpression} to context and skip to child nodes. */ @Override public LogicalPlan visitParse(Parse node, AnalysisContext context) { diff --git a/core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java b/core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java index f9e6c295d0d..f835000cb38 100644 --- a/core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java +++ b/core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java @@ -45,6 +45,8 @@ import org.opensearch.sql.ast.statement.Query; import org.opensearch.sql.ast.statement.Statement; import org.opensearch.sql.ast.tree.AD; +import org.opensearch.sql.ast.tree.AddColTotals; +import org.opensearch.sql.ast.tree.AddTotals; import org.opensearch.sql.ast.tree.Aggregation; import org.opensearch.sql.ast.tree.Append; import org.opensearch.sql.ast.tree.AppendCol; @@ -452,4 +454,12 @@ public T visitAppend(Append node, C context) { public T visitMultisearch(Multisearch node, C context) { return visitChildren(node, context); } + + public T visitAddTotals(AddTotals node, C context) { + return visitChildren(node, context); + } + + public T visitAddColTotals(AddColTotals node, C context) { + return visitChildren(node, context); + } } diff --git a/core/src/main/java/org/opensearch/sql/ast/tree/AddColTotals.java b/core/src/main/java/org/opensearch/sql/ast/tree/AddColTotals.java new file mode 100644 index 00000000000..6f999488b9c --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/ast/tree/AddColTotals.java @@ -0,0 +1,47 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.tree; + +import com.google.common.collect.ImmutableList; +import java.util.List; +import java.util.Map; +import lombok.*; +import org.opensearch.sql.ast.AbstractNodeVisitor; +import org.opensearch.sql.ast.expression.Field; +import org.opensearch.sql.ast.expression.Literal; + +/** + * AST node representing the PPL addcoltotals command. Computes column-wise totals across events and + * optionally appends a summary event. + * + * @see AddTotals for row-wise totals + */ +@Getter +@Setter +@ToString +@EqualsAndHashCode(callSuper = false) +@RequiredArgsConstructor +public class AddColTotals extends UnresolvedPlan { + private final List fieldList; + private final Map options; + private UnresolvedPlan child; + + @Override + public AddColTotals attach(UnresolvedPlan child) { + this.child = child; + return this; + } + + @Override + public List getChild() { + return child == null ? ImmutableList.of() : ImmutableList.of(child); + } + + @Override + public T accept(AbstractNodeVisitor visitor, C context) { + return visitor.visitAddColTotals(this, context); + } +} diff --git a/core/src/main/java/org/opensearch/sql/ast/tree/AddTotals.java b/core/src/main/java/org/opensearch/sql/ast/tree/AddTotals.java new file mode 100644 index 00000000000..93ff5ccfbe7 --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/ast/tree/AddTotals.java @@ -0,0 +1,45 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.tree; + +import com.google.common.collect.ImmutableList; +import java.util.List; +import java.util.Map; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.RequiredArgsConstructor; +import lombok.Setter; +import lombok.ToString; +import org.opensearch.sql.ast.AbstractNodeVisitor; +import org.opensearch.sql.ast.expression.Field; +import org.opensearch.sql.ast.expression.Literal; + +@Getter +@Setter +@ToString +@EqualsAndHashCode(callSuper = false) +@RequiredArgsConstructor +public class AddTotals extends UnresolvedPlan { + private final List fieldList; + private final Map options; + private UnresolvedPlan child; + + @Override + public AddTotals attach(UnresolvedPlan child) { + this.child = child; + return this; + } + + @Override + public List getChild() { + return child == null ? ImmutableList.of() : ImmutableList.of(child); + } + + @Override + public T accept(AbstractNodeVisitor visitor, C context) { + return visitor.visitAddTotals(this, context); + } +} diff --git a/core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java b/core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java index aaef0a00be2..489d155d841 100644 --- a/core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java +++ b/core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java @@ -99,6 +99,8 @@ import org.opensearch.sql.ast.expression.WindowFunction; import org.opensearch.sql.ast.expression.subquery.SubqueryExpression; import org.opensearch.sql.ast.tree.AD; +import org.opensearch.sql.ast.tree.AddColTotals; +import org.opensearch.sql.ast.tree.AddTotals; import org.opensearch.sql.ast.tree.Aggregation; import org.opensearch.sql.ast.tree.Append; import org.opensearch.sql.ast.tree.AppendCol; @@ -2496,6 +2498,282 @@ private String getAggFieldAlias(UnresolvedExpression aggregateFunction) { return sb.toString(); } + /** Transforms visitAddColTotals command into SQL-based operations. */ + @Override + public RelNode visitAddColTotals(AddColTotals node, CalcitePlanContext context) { + visitChildren(node, context); + + // Parse options from the AddTotals node + Map options = node.getOptions(); + String label = getOptionValue(options, "label", "Total"); + String labelField = getOptionValue(options, "labelfield", null); + // Determine which fields to aggregate + + // Handle row=true option: add a new field that sums all specified fields for each row + List fieldsToAggregate = node.getFieldList(); + return buildAddRowTotalAggregate( + context, fieldsToAggregate, false, true, null, labelField, label); + } + + /** + * Cast integer sum to long, real/float to double to avoid ClassCastException + * + * @param context + * @param fieldRef + * @param fieldDataType + * @return + */ + public RexNode getAggregateDataTypeFieldRef( + CalcitePlanContext context, RexNode fieldRef, RelDataTypeField fieldDataType) { + RexNode castFieldRef = fieldRef; + if (fieldDataType.getType().getSqlTypeName() == SqlTypeName.INTEGER) { + castFieldRef = context.relBuilder.cast(fieldRef, SqlTypeName.BIGINT); + } else if ((fieldDataType.getType().getSqlTypeName() == SqlTypeName.FLOAT) + || (fieldDataType.getType().getSqlTypeName() == SqlTypeName.REAL)) { + castFieldRef = context.relBuilder.cast(fieldRef, SqlTypeName.DOUBLE); + } + + return castFieldRef; + } + + public RelNode buildAddRowTotalAggregate( + CalcitePlanContext context, + List fieldsToAggregate, + boolean addTotalsForEachRow, + boolean addTotalsForEachColumn, + String newColTotalsFieldName, + String labelField, + String label) { + + // Build aggregation calls for totals calculation + boolean extraColTotalField = false; + RexNode sumExpression = null; + List aggCalls = new ArrayList<>(); + List fieldNameToSum = new ArrayList<>(); + RelNode originalData = context.relBuilder.peek(); + List fieldNames = originalData.getRowType().getFieldNames(); + boolean foundLabelField = false; + int labelLength = + (labelField != null) && (labelField.length() > label.length()) + ? labelField.length() + : label.length(); + + RelDataType labelVarcharType = + context.relBuilder.getTypeFactory().createSqlType(SqlTypeName.VARCHAR, labelLength); + + // If no specific fields specified, use all numeric fields + if (fieldsToAggregate.isEmpty()) { + fieldsToAggregate = getAllNumericFields(originalData, context); + } + List orginalDataProjectedFields = new ArrayList<>(); + List fieldsToSum = new ArrayList<>(); + java.util.List fieldList = + originalData.getRowType().getFieldList(); + for (RelDataTypeField fieldDataType : fieldList) { + RexNode fieldRef = context.relBuilder.field(fieldDataType.getName()); + boolean columnAddedToNewProject = false; + if (shouldAggregateField(fieldDataType.getName(), fieldsToAggregate)) { + + if (isNumericField(fieldRef, context)) { + fieldsToSum.add(fieldRef); + if (addTotalsForEachColumn) { + // Cast integer sum to long/double for int/float types to avoid ClassCastException + RexNode castFieldRef = getAggregateDataTypeFieldRef(context, fieldRef, fieldDataType); + orginalDataProjectedFields.add(castFieldRef); + columnAddedToNewProject = true; + + AggCall sumCall = context.relBuilder.sum(castFieldRef).as(fieldDataType.getName()); + aggCalls.add(sumCall); + } + fieldNameToSum.add(fieldDataType.getName()); + if (addTotalsForEachRow) { + // Use cast field for row totals to avoid ClassCastException + RexNode rowCastFieldRef = + getAggregateDataTypeFieldRef(context, fieldRef, fieldDataType); + + if (sumExpression == null) { + sumExpression = rowCastFieldRef; + } else { + sumExpression = + context.relBuilder.call( + org.apache.calcite.sql.fun.SqlStdOperatorTable.PLUS, + sumExpression, + rowCastFieldRef); + } + } + } + } + if (!columnAddedToNewProject) { + orginalDataProjectedFields.add(fieldRef); + } + if (addTotalsForEachColumn && fieldDataType.getName().equals(labelField)) { + // Use specified label field for the label + foundLabelField = true; + } + } + context.relBuilder.project(orginalDataProjectedFields, fieldNames); + if (addTotalsForEachRow && !fieldsToSum.isEmpty()) { + // Add the new column with the sum + context.relBuilder.projectPlus( + context.relBuilder.alias(sumExpression, newColTotalsFieldName)); + if (newColTotalsFieldName.equals(labelField)) { + foundLabelField = true; + } + } + if (addTotalsForEachColumn) { + if (!foundLabelField && (labelField != null)) { + context.relBuilder.projectPlus( + context.relBuilder.alias( + context.relBuilder.getRexBuilder().makeNullLiteral(labelVarcharType), labelField)); + extraColTotalField = true; + } + } + + originalData = context.relBuilder.build(); + context.relBuilder.push(originalData); + if (addTotalsForEachColumn) { + // Perform aggregation (no group by - single totals row) + context.relBuilder.aggregate( + context.relBuilder.groupKey(), // Empty group key for single totals row + aggCalls); + // 3. Build the totals row with proper field order and labels + List selectList = new ArrayList<>(); + + fieldList = originalData.getRowType().getFieldList(); + for (RelDataTypeField fieldDataType : fieldList) { + if (fieldNameToSum.contains(fieldDataType.getName())) { + selectList.add( + context.relBuilder.alias( + context.relBuilder.field(fieldDataType.getName()), fieldDataType.getName())); + + } else if (fieldDataType.getName().equals(labelField) + && (extraColTotalField + || fieldDataType.getType().getFamily() == SqlTypeFamily.CHARACTER)) { + // Use specified label field for the label - cast to match original field type + RexNode labelLiteral = + context.relBuilder.getRexBuilder().makeLiteral(label, fieldDataType.getType(), true); + selectList.add(context.relBuilder.alias(labelLiteral, fieldDataType.getName())); + + } else { + // Other fields get NULL in totals row - cast to match original field type + selectList.add( + context.relBuilder.alias( + context.relBuilder.getRexBuilder().makeNullLiteral(fieldDataType.getType()), + fieldDataType.getName())); + } + } + + // Project the totals row with proper field order and labels + context.relBuilder.project(selectList); + RelNode totalsRow = context.relBuilder.build(); + // 4. Union original data with totals row + context.relBuilder.push(originalData); + context.relBuilder.push(totalsRow); + context.relBuilder.union(true); // Use UNION ALL to preserve order + } + return context.relBuilder.peek(); + } + + /** Transforms visitAddTotals command into SQL-based operations. */ + @Override + public RelNode visitAddTotals(AddTotals node, CalcitePlanContext context) { + // 1. Process child plan first + visitChildren(node, context); + + // Parse options from the AddTotals node + Map options = node.getOptions(); + String label = + getOptionValue( + options, "label", "Total"); // when col=true , add summary event with this label + String labelField = + getOptionValue( + options, + "labelfield", + null); // when col=true , add summary event with this label field at the end of rows + String newColTotalsFieldName = + getOptionValue( + options, "fieldname", "Total"); // when row=true , add new field as new column + boolean addTotalsForEachRow = getBooleanOptionValue(options, "row", true); + boolean addTotalsForEachColumn = + getBooleanOptionValue(options, "col", false); // when col=true/false check + + // Determine which fields to aggregate + List fieldsToAggregate = node.getFieldList(); + + // Handle row=true option: add a new field that sums all specified fields for each row + return buildAddRowTotalAggregate( + context, + fieldsToAggregate, + addTotalsForEachRow, + addTotalsForEachColumn, + newColTotalsFieldName, + labelField, + label); + } + + private String getOptionValue(Map options, String key, String defaultValue) { + Literal literal = options.get(key); + if (literal == null) { + return defaultValue; + } + Object value = literal.getValue(); + if (value == null) { + return defaultValue; + } + return value.toString(); + } + + /** Helper method to extract boolean option values */ + private boolean getBooleanOptionValue( + Map options, String key, boolean defaultValue) { + if (options.containsKey(key)) { + Object value = options.get(key).getValue(); + if (value instanceof Boolean) { + return (Boolean) value; + } + if (value instanceof String) { + return Boolean.parseBoolean((String) value); + } + } + return defaultValue; + } + + /** Get all numeric fields from the RelNode */ + private List getAllNumericFields(RelNode relNode, CalcitePlanContext context) { + List numericFields = new ArrayList<>(); + for (String fieldName : relNode.getRowType().getFieldNames()) { + if (isNumericFieldName(fieldName, relNode)) { + numericFields.add( + new Field(new org.opensearch.sql.ast.expression.QualifiedName(fieldName))); + } + } + return numericFields; + } + + /** Check if a field should be aggregated based on the field list */ + private boolean shouldAggregateField(String fieldName, List fieldsToAggregate) { + if (fieldsToAggregate.isEmpty()) { + return true; // Aggregate all fields when none specified + } + return fieldsToAggregate.stream() + .anyMatch(field -> field.getField().toString().equals(fieldName)); + } + + /** Check if a RexNode represents a numeric field */ + private boolean isNumericField(RexNode rexNode, CalcitePlanContext context) { + return rexNode.getType().getSqlTypeName().getFamily() == SqlTypeFamily.NUMERIC; + } + + /** Check if a field name represents a numeric field in the RelNode */ + private boolean isNumericFieldName(String fieldName, RelNode relNode) { + try { + RelDataTypeField field = relNode.getRowType().getField(fieldName, false, false); + return field != null && field.getType().getSqlTypeName().getFamily() == SqlTypeFamily.NUMERIC; + } catch (Exception e) { + return false; + } + } + @Override public RelNode visitChart(Chart node, CalcitePlanContext context) { visitChildren(node, context); diff --git a/docs/category.json b/docs/category.json index 0ba23db151e..280dae44e9e 100644 --- a/docs/category.json +++ b/docs/category.json @@ -10,6 +10,8 @@ "ppl_cli_calcite": [ "user/ppl/cmd/ad.md", "user/ppl/cmd/append.md", + "user/ppl/cmd/addtotals.md", + "user/ppl/cmd/addcoltotals.md", "user/ppl/cmd/bin.md", "user/ppl/cmd/dedup.md", "user/ppl/cmd/describe.md", diff --git a/docs/user/ppl/cmd/addcoltotals.md b/docs/user/ppl/cmd/addcoltotals.md new file mode 100644 index 00000000000..bcc089859ec --- /dev/null +++ b/docs/user/ppl/cmd/addcoltotals.md @@ -0,0 +1,88 @@ +# AddColTotals + + +# Description + +The `addcoltotals` command computes the sum of each column and add a summary event at the end to show the total of each column. This command works the same way `addtotals` command works with row=false and col=true option. This is useful for creating summary reports with subtotals or grand totals. The `addcoltotals` command only sums numeric fields (integers, floats, doubles). Non-numeric fields in the field list are ignored even if its specified in field-list or in the case of no field-list specified. + +# Syntax + +`addcoltotals [field-list] [label=] [labelfield=]` + +- `field-list`: Optional. Comma-separated list of numeric fields to sum. If not specified, all numeric fields are summed. +- `labelfield=`: Optional. Field name to place the label. If it specifies a non-existing field, adds the field and shows label at the summary event row at this field. +- `label=`: Optional. Custom text for the totals row labelfield\'s label. Default is \"Total\". + +# Example 1: Basic Example + +The example shows placing the label in an existing field. + +```ppl +source=accounts +| fields firstname, balance +| head 3 +| addcoltotals labelfield='firstname' +``` + +Expected output: + +```text +fetched rows / total rows = 4/4 ++-----------+---------+ +| firstname | balance | +|-----------+---------| +| Amber | 39225 | +| Hattie | 5686 | +| Nanette | 32838 | +| Total | 77749 | ++-----------+---------+ +``` + +# Example 2: Adding column totals and adding a summary event with label specified. + +The example shows adding totals after a stats command where final summary event label is \'Sum\' and row=true value was used by default when not specified. It also added new field specified by labelfield as it did not match existing field. + +```ppl +source=accounts +| stats count() by gender +| addcoltotals `count()` label='Sum' labelfield='Total' +``` + +Expected output: + +```text +fetched rows / total rows = 3/3 ++---------+--------+-------+ +| count() | gender | Total | +|---------+--------+-------| +| 1 | F | null | +| 3 | M | null | +| 4 | null | Sum | ++---------+--------+-------+ +``` + +# Example 3: With all options + +The example shows using addcoltotals with all options set. + +```ppl +source=accounts +| where age > 30 +| stats avg(balance) as avg_balance, count() as count by state +| head 3 +| addcoltotals avg_balance, count label='Sum' labelfield='Column Total' +``` + +Expected output: + +```text +fetched rows / total rows = 4/4 ++-------------+-------+-------+--------------+ +| avg_balance | count | state | Column Total | +|-------------+-------+-------+--------------| +| 39225.0 | 1 | IL | null | +| 4180.0 | 1 | MD | null | +| 5686.0 | 1 | TN | null | +| 49091.0 | 3 | null | Sum | ++-------------+-------+-------+--------------+ +``` diff --git a/docs/user/ppl/cmd/addtotals.md b/docs/user/ppl/cmd/addtotals.md new file mode 100644 index 00000000000..745b1ae750f --- /dev/null +++ b/docs/user/ppl/cmd/addtotals.md @@ -0,0 +1,116 @@ +# AddTotals + + +## Description + +The `addtotals` command computes the sum of numeric fields and appends a row with the totals to the result. The command can also add row totals and add a field to store row totals. This is useful for creating summary reports with subtotals or grand totals. The `addtotals` command only sums numeric fields (integers, floats, doubles). Non-numeric fields in the field list are ignored even if it\'s specified in field-list or in the case of no field-list specified. + +## Syntax + +`addtotals [field-list] [label=] [labelfield=] [row=] [col=] [fieldname=]` + +- `field-list`: Optional. Comma-separated list of numeric fields to sum. If not specified, all numeric fields are summed. +- `row=`: Optional. Calculates total of each row and add a new field with the total. Default is true. +- `col=`: Optional. Calculates total of each column and add a new event at the end of all events with the total. Default is false. +- `labelfield=`: Optional. Field name to place the label. If it specifies a non-existing field, adds the field and shows label at the summary event row at this field. This is applicable when col=true. +- `label=`: Optional. Custom text for the totals row labelfield\'s label. Default is \"Total\". This is applicable when col=true. This does not have any effect when labelfield and fieldname parameter both have same value. +- `fieldname=`: Optional. Calculates total of each row and add a new field to store this total. This is applicable when row=true. + +## Example 1: Basic Example + +The example shows placing the label in an existing field. + +```ppl +source=accounts +| head 3 +|fields firstname, balance +| addtotals col=true labelfield='firstname' label='Total' +``` + +Expected output: + +```text +fetched rows / total rows = 4/4 ++-----------+---------+-------+ +| firstname | balance | Total | +|-----------+---------+-------| +| Amber | 39225 | 39225 | +| Hattie | 5686 | 5686 | +| Nanette | 32838 | 32838 | +| Total | 77749 | null | ++-----------+---------+-------+ +``` + +## Example 2: Adding column totals and adding a summary event with label specified. + +The example shows adding totals after a stats command where final summary event label is \'Sum\'. It also added new field specified by labelfield as it did not match existing field. + +```ppl +source=accounts +| fields account_number, firstname , balance , age +| addtotals col=true row=false label='Sum' labelfield='Total' +``` + +Expected output: + +```text +fetched rows / total rows = 5/5 ++----------------+-----------+---------+-----+-------+ +| account_number | firstname | balance | age | Total | +|----------------+-----------+---------+-----+-------| +| 1 | Amber | 39225 | 32 | null | +| 6 | Hattie | 5686 | 36 | null | +| 13 | Nanette | 32838 | 28 | null | +| 18 | Dale | 4180 | 33 | null | +| 38 | null | 81929 | 129 | Sum | ++----------------+-----------+---------+-----+-------+ +``` + +if row=true in above example, there will be conflict between column added for column totals and column added for row totals being same field \'Total\', in that case the output will have final event row label null instead of \'Sum\' because the column is number type and it cannot output String in number type column. + +```ppl +source=accounts +| fields account_number, firstname , balance , age +| addtotals col=true row=true label='Sum' labelfield='Total' +``` + +Expected output: + +```text +fetched rows / total rows = 5/5 ++----------------+-----------+---------+-----+-------+ +| account_number | firstname | balance | age | Total | +|----------------+-----------+---------+-----+-------| +| 1 | Amber | 39225 | 32 | 39258 | +| 6 | Hattie | 5686 | 36 | 5728 | +| 13 | Nanette | 32838 | 28 | 32879 | +| 18 | Dale | 4180 | 33 | 4231 | +| 38 | null | 81929 | 129 | null | ++----------------+-----------+---------+-----+-------+ +``` + +## Example 3: With all options + +The example shows using addtotals with all options set. + +```ppl +source=accounts +| where age > 30 +| stats avg(balance) as avg_balance, count() as count by state +| head 3 +| addtotals avg_balance, count row=true col=true fieldname='Row Total' label='Sum' labelfield='Column Total' +``` + +Expected output: + +```text +fetched rows / total rows = 4/4 ++-------------+-------+-------+-----------+--------------+ +| avg_balance | count | state | Row Total | Column Total | +|-------------+-------+-------+-----------+--------------| +| 39225.0 | 1 | IL | 39226.0 | null | +| 4180.0 | 1 | MD | 4181.0 | null | +| 5686.0 | 1 | TN | 5687.0 | null | +| 49091.0 | 3 | null | null | Sum | ++-------------+-------+-------+-----------+--------------+ +``` \ No newline at end of file diff --git a/docs/user/ppl/index.md b/docs/user/ppl/index.md index a8fcb5a480b..2525ed9b908 100644 --- a/docs/user/ppl/index.md +++ b/docs/user/ppl/index.md @@ -78,7 +78,9 @@ source=accounts | [describe command](cmd/describe.md) | 2.1 | stable (since 2.1) | Query the metadata of an index. | | [explain command](cmd/explain.md) | 3.1 | stable (since 3.1) | Explain the plan of query. | | [show datasources command](cmd/showdatasources.md) | 2.4 | stable (since 2.4) | Query datasources configured in the PPL engine. | - + | [addtotals command](cmd/addtotals.md) | 3.4 | stable (since 3.4) | Adds row and column values and appends a totals column and row. | + | [addcoltotals command](cmd/addcoltotals.md) | 3.4 | stable (since 3.4) | Adds column values and appends a totals row. | + - [Syntax](cmd/syntax.md) - PPL query structure and command syntax formatting * **Functions** - [Aggregation Functions](functions/aggregations.md) diff --git a/integ-test/src/test/java/org/opensearch/sql/calcite/CalciteNoPushdownIT.java b/integ-test/src/test/java/org/opensearch/sql/calcite/CalciteNoPushdownIT.java index 15051417db1..c254fb47c44 100644 --- a/integ-test/src/test/java/org/opensearch/sql/calcite/CalciteNoPushdownIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/calcite/CalciteNoPushdownIT.java @@ -21,6 +21,8 @@ @RunWith(Suite.class) @Suite.SuiteClasses({ CalciteExplainIT.class, + CalciteAddTotalsCommandIT.class, + CalciteAddColTotalsCommandIT.class, CalciteArrayFunctionIT.class, CalciteBinCommandIT.class, CalciteConvertTZFunctionIT.class, diff --git a/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteAddColTotalsCommandIT.java b/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteAddColTotalsCommandIT.java new file mode 100644 index 00000000000..fecdd7479c5 --- /dev/null +++ b/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteAddColTotalsCommandIT.java @@ -0,0 +1,199 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.calcite.remote; + +import static org.opensearch.sql.legacy.TestsConstants.TEST_INDEX_ACCOUNT; +import static org.opensearch.sql.util.MatcherUtils.schema; +import static org.opensearch.sql.util.MatcherUtils.verifySchema; + +import java.io.IOException; +import java.math.BigDecimal; +import java.util.ArrayList; +import java.util.List; +import org.junit.jupiter.api.Test; +import org.opensearch.sql.ppl.PPLIntegTestCase; + +/** + * Integration tests for PPL addcoltotals command with Calcite engine enabled. Tests column-wise + * total computation scenarios including field selection, custom labels, and interactions with other + * PPL commands. + */ +public class CalciteAddColTotalsCommandIT extends PPLIntegTestCase { + + @Override + public void init() throws Exception { + super.init(); + enableCalcite(); + loadIndex(Index.ACCOUNT); + loadIndex(Index.BANK); + } + + @Test + public void testAddColTotalsTotalWithTotalField() throws IOException { + var result = + executeQuery( + String.format( + "source=%s | where age > 25 | fields age, balance | addcoltotals", + TEST_INDEX_ACCOUNT)); + + // Verify that we get original rows plus totals row + verifySchema(result, schema("age", "bigint"), schema("balance", "bigint")); + + // Should have original data plus one totals row + var dataRows = result.getJSONArray("datarows"); + // Iterate through all data rows + ArrayList field_indexes = new ArrayList<>(); + field_indexes.add(0); + field_indexes.add(1); + + verifyColTotals(dataRows, field_indexes, null); + } + + @Test + public void testAddColTotalsRowWithSpecificFields() throws IOException { + var result = + executeQuery( + String.format( + "source=%s | where age > 25 | fields age, balance | addcoltotals balance", + TEST_INDEX_ACCOUNT)); + + // Verify that we get original rows plus totals row + verifySchema(result, schema("age", "bigint"), schema("balance", "bigint")); + + var dataRows = result.getJSONArray("datarows"); + ArrayList field_indexes = new ArrayList<>(); + field_indexes.add(1); + + verifyColTotals(dataRows, field_indexes, null); + } + + private static boolean isNumeric(String str) { + return str != null && str.matches("-?\\d+(\\.\\d+)?"); + } + + private void verifyColTotals( + org.json.JSONArray dataRows, List field_indexes, String finalSummaryEventLevel) { + + BigDecimal[] cColTotals = new BigDecimal[field_indexes.size()]; + for (int i = 0; i < dataRows.length() - 1; i++) { + var row = dataRows.getJSONArray(i); + + // Iterate through each field in the row + for (int j = 0; j < field_indexes.size(); j++) { + + int colIndex = field_indexes.get(j); + if (cColTotals[j] == null) { + cColTotals[j] = new BigDecimal(0); + } + Object value = row.isNull(colIndex) ? 0 : row.get(colIndex); + if (value instanceof Integer) { + cColTotals[j] = cColTotals[j].add(new BigDecimal((Integer) (value))); + } else if (value instanceof Double) { + cColTotals[j] = cColTotals[j].add(new BigDecimal((Double) (value))); + } else if (value instanceof BigDecimal) { + cColTotals[j] = cColTotals[j].add((BigDecimal) value); + + } else if (value instanceof String) { + if (org.opensearch.sql.calcite.remote.CalciteAddColTotalsCommandIT.isNumeric( + (String) value)) { + cColTotals[j] = cColTotals[j].add(new BigDecimal((String) (value))); + } + } + } + } + var total_row = dataRows.getJSONArray((dataRows.length() - 1)); + for (int j = 0; j < field_indexes.size(); j++) { + int colIndex = field_indexes.get(j); + BigDecimal foundTotal = total_row.getBigDecimal(colIndex); + assertEquals(foundTotal.doubleValue(), cColTotals[j].doubleValue(), 0.000001); + } + if (finalSummaryEventLevel != null) { + String foundSummaryEventLabel = total_row.getString(total_row.length() - 1); + + assertEquals(foundSummaryEventLabel, finalSummaryEventLevel); + } + } + + @Test + public void testAddColTotalsRowFieldsNonNumeric() throws IOException { + var result = + executeQuery( + String.format( + "source=%s | where age > 25 |fields age address balance | addcoltotals ", + TEST_INDEX_ACCOUNT)); + + // Verify that we get original rows plus totals row + verifySchema( + result, schema("age", "bigint"), schema("address", "string"), schema("balance", "bigint")); + + var dataRows = result.getJSONArray("datarows"); + ArrayList field_indexes = new ArrayList<>(); + field_indexes.add(0); + field_indexes.add(2); + + verifyColTotals(dataRows, field_indexes, null); + } + + @Test + public void testAddColTotalsWithCustomLabel() throws IOException { + var result = + executeQuery( + String.format( + "source=%s | where age > 25 | head 2|fields age, balance | addcoltotals label='Sum'" + + " labelfield='Grand Total'", + TEST_INDEX_ACCOUNT)); + + verifySchema( + result, + schema("age", "bigint"), + schema("balance", "bigint"), + schema("Grand Total", "string")); + + var dataRows = result.getJSONArray("datarows"); + ArrayList field_indexes = new ArrayList<>(); + field_indexes.add(0); + field_indexes.add(1); + + verifyColTotals(dataRows, field_indexes, "Sum"); + } + + @Test + public void testAddColTotalsWithNoData() throws IOException { + var result = + executeQuery( + String.format( + "source=%s | where age > 1000 | fields age, balance | addcoltotals", + TEST_INDEX_ACCOUNT)); + + // Should still have totals row even with no input data + var dataRows = result.getJSONArray("datarows"); + assertEquals(1, dataRows.length()); // Only totals row + } + + @Test + public void testAddColTotalsWithLabelAndLabelField() throws IOException { + var result = + executeQuery( + String.format( + "source=%s | where age > 25 |head 3| fields age, balance,firstname | addcoltotals " + + " age balance label='Sum' labelfield='firstname'", + TEST_INDEX_ACCOUNT)); + + // Verify schema includes custom fieldname + verifySchema( + result, + schema("age", "bigint"), + schema("balance", "bigint"), + schema("firstname", "string")); + + var dataRows = result.getJSONArray("datarows"); + ArrayList field_indexes = new ArrayList<>(); + field_indexes.add(0); + field_indexes.add(1); + + verifyColTotals(dataRows, field_indexes, "Sum"); + } +} diff --git a/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteAddTotalsCommandIT.java b/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteAddTotalsCommandIT.java new file mode 100644 index 00000000000..6ce11142ddf --- /dev/null +++ b/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteAddTotalsCommandIT.java @@ -0,0 +1,370 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.calcite.remote; + +import static org.opensearch.sql.legacy.TestsConstants.TEST_INDEX_ACCOUNT; +import static org.opensearch.sql.util.MatcherUtils.schema; +import static org.opensearch.sql.util.MatcherUtils.verifySchema; + +import java.io.IOException; +import java.math.BigDecimal; +import java.util.ArrayList; +import java.util.List; +import org.junit.jupiter.api.Test; +import org.opensearch.sql.ppl.PPLIntegTestCase; + +public class CalciteAddTotalsCommandIT extends PPLIntegTestCase { + + @Override + public void init() throws Exception { + super.init(); + enableCalcite(); + loadIndex(Index.ACCOUNT); + loadIndex(Index.BANK); + } + + /** + * default test without parameters on account index + * + * @throws IOException + */ + @Test + public void testAddTotalsTotalWithTotalField() throws IOException { + var result = + executeQuery( + String.format( + "source=%s | where age > 25 | fields age, balance | addtotals", + TEST_INDEX_ACCOUNT)); + + // Verify that we get original rows plus totals row + verifySchema( + result, schema("age", "bigint"), schema("balance", "bigint"), schema("Total", "bigint")); + + // Should have original data plus one totals row + var dataRows = result.getJSONArray("datarows"); + // Iterate through all data rows + for (int i = 0; i < dataRows.length(); i++) { + var row = dataRows.getJSONArray(i); + + BigDecimal cRowTotal = new BigDecimal(0); + // Iterate through each field in the row + for (int j = 0; j < row.length() - 1; j++) { + Object value = row.isNull(j) ? 0 : row.get(j); + if (value instanceof Integer) { + cRowTotal = cRowTotal.add(new BigDecimal((Integer) (value))); + } else if (value instanceof Double) { + cRowTotal = cRowTotal.add(new BigDecimal((Double) (value))); + } else if (value instanceof String) { + cRowTotal = cRowTotal.add(new BigDecimal((String) (value))); + } + } + BigDecimal foundTotal = row.getBigDecimal(row.length() - 1); + assertEquals(foundTotal.doubleValue(), cRowTotal.doubleValue(), 0.000001); + } + } + + @Test + public void testAddTotalsRowWithSpecificFields() throws IOException { + var result = + executeQuery( + String.format( + "source=%s | where age > 25 | fields age, balance | addtotals balance", + TEST_INDEX_ACCOUNT)); + + // Verify that we get original rows plus totals row + verifySchema( + result, schema("age", "bigint"), schema("balance", "bigint"), schema("Total", "bigint")); + + // sum for balance, "Total" for label + var dataRows = result.getJSONArray("datarows"); + // Iterate through all data rows + for (int i = 0; i < dataRows.length(); i++) { + var row = dataRows.getJSONArray(i); + + BigDecimal cRowTotal = new BigDecimal(0); + // Iterate through each field in the row + + Object value = row.isNull(1) ? 0 : row.get(1); + if (value instanceof Integer) { + cRowTotal = cRowTotal.add(new BigDecimal((Integer) (value))); + } else if (value instanceof Double) { + cRowTotal = cRowTotal.add(new BigDecimal((Double) (value))); + } else if (value instanceof String) { + cRowTotal = cRowTotal.add(new BigDecimal((String) (value))); + } + + BigDecimal foundTotal = row.getBigDecimal(row.length() - 1); + assertEquals(foundTotal.doubleValue(), cRowTotal.doubleValue(), 0.000001); + } + } + + public static boolean isNumeric(String str) { + return str != null && str.matches("-?\\d+(\\.\\d+)?"); + } + + private void compareDataRowTotals( + org.json.JSONArray dataRows, List fieldIndexes, int totalColIndex) { + for (int i = 0; i < dataRows.length(); i++) { + var row = dataRows.getJSONArray(i); + + BigDecimal cRowTotal = new BigDecimal(0); + // Iterate through each field in the row + for (int j = 0; j < fieldIndexes.size(); j++) { + int colIndex = fieldIndexes.get(j); + Object value = row.isNull(colIndex) ? 0 : row.get(colIndex); + if (value instanceof Integer) { + cRowTotal = cRowTotal.add(new BigDecimal((Integer) (value))); + } else if (value instanceof Double) { + cRowTotal = cRowTotal.add(new BigDecimal((Double) (value))); + } else if (value instanceof BigDecimal) { + cRowTotal = cRowTotal.add((BigDecimal) value); + + } else if (value instanceof String) { + if (isNumeric((String) value)) { + cRowTotal = cRowTotal.add(new BigDecimal((String) (value))); + } + } + } + BigDecimal foundTotal = row.getBigDecimal(totalColIndex); + assertEquals(foundTotal.doubleValue(), cRowTotal.doubleValue(), 0.000001); + } + } + + private void verifyColTotals( + org.json.JSONArray dataRows, List field_indexes, String finalSummaryEventLevel) { + + BigDecimal[] cColTotals = new BigDecimal[field_indexes.size()]; + for (int i = 0; i < dataRows.length() - 1; i++) { + var row = dataRows.getJSONArray(i); + + // Iterate through each field in the row + for (int j = 0; j < field_indexes.size(); j++) { + + int colIndex = field_indexes.get(j); + if (cColTotals[j] == null) { + cColTotals[j] = new BigDecimal(0); + } + Object value = row.isNull(colIndex) ? 0 : row.get(colIndex); + if (value instanceof Integer) { + cColTotals[j] = cColTotals[j].add(new BigDecimal((Integer) (value))); + } else if (value instanceof Double) { + cColTotals[j] = cColTotals[j].add(new BigDecimal((Double) (value))); + } else if (value instanceof BigDecimal) { + cColTotals[j] = cColTotals[j].add((BigDecimal) value); + + } else if (value instanceof String) { + if (isNumeric((String) value)) { + cColTotals[j] = cColTotals[j].add(new BigDecimal((String) (value))); + } + } + } + } + var total_row = dataRows.getJSONArray((dataRows.length() - 1)); + for (int j = 0; j < field_indexes.size(); j++) { + int colIndex = field_indexes.get(j); + BigDecimal foundTotal = total_row.getBigDecimal(colIndex); + assertEquals(foundTotal.doubleValue(), cColTotals[j].doubleValue(), 0.000001); + } + String foundSummaryEventLabel = total_row.getString(total_row.length() - 1); + assertEquals(foundSummaryEventLabel, finalSummaryEventLevel); + } + + @Test + public void testAddTotalsRowFieldsNonNumeric() throws IOException { + var result = + executeQuery( + String.format( + "source=%s | where age > 25 |fields age address balance | addtotals ", + TEST_INDEX_ACCOUNT)); + + // Verify that we get original rows plus totals row + verifySchema( + result, + schema("age", "bigint"), + schema("address", "string"), + schema("balance", "bigint"), + schema("Total", "bigint")); + + // sum for balance, "Total" for label + // Should have original data plus one totals row + var dataRows = result.getJSONArray("datarows"); + // Iterate through all data rows + for (int i = 0; i < dataRows.length(); i++) { + var row = dataRows.getJSONArray(i); + + BigDecimal cRowTotal = new BigDecimal(0); + // Iterate through each field in the row + for (int j = 0; j < row.length() - 1; j++) { + Object value = row.isNull(j) ? 0 : row.get(j); + if (value instanceof Integer) { + cRowTotal = cRowTotal.add(new BigDecimal((Integer) (value))); + } else if (value instanceof Double) { + cRowTotal = cRowTotal.add(new BigDecimal((Double) (value))); + } else if (value instanceof String) { + if (org.opensearch.sql.calcite.remote.CalciteAddTotalsCommandIT.isNumeric( + (String) value)) { + cRowTotal = cRowTotal.add(new BigDecimal((String) (value))); + } + } + } + BigDecimal foundTotal = row.getBigDecimal(row.length() - 1); + assertEquals(foundTotal.doubleValue(), cRowTotal.doubleValue(), 0.000001); + } + } + + @Test + public void testAddTotalsWithCustomLabel() throws IOException { + var result = + executeQuery( + String.format( + "source=%s | where age > 25 | head 2|fields age, balance | addtotals" + + " fieldname='Grand Total'", + TEST_INDEX_ACCOUNT)); + + verifySchema( + result, + schema("age", "bigint"), + schema("balance", "bigint"), + schema("Grand Total", "bigint")); + } + + @Test + public void testAddTotalsAfterStats() throws IOException { + var result = + executeQuery( + String.format( + "source=%s | stats count() by gender | addtotals `count()`", TEST_INDEX_ACCOUNT)); + + var dataRows = result.getJSONArray("datarows"); + for (int i = 0; i < dataRows.length(); i++) { + var row = dataRows.getJSONArray(i); + assertEquals(row.get(0), row.get(2)); + } + } + + @Test + public void testAddTotalsWithNoData() throws IOException { + var result = + executeQuery( + String.format( + "source=%s | where age > 1000 | fields age, balance | addtotals", + TEST_INDEX_ACCOUNT)); + + // Should still have totals row even with no input data + var dataRows = result.getJSONArray("datarows"); + assertEquals(0, dataRows.length()); // Only totals row + } + + @Test + public void testAddTotalsInComplexPipeline() throws IOException { + var result = + executeQuery( + String.format( + "source=%s | where age > 25 | stats avg(balance) as avg_balance, count() as" + + " total_count by gender | addtotals avg_balance, total_count", + TEST_INDEX_ACCOUNT)); + + var dataRows = result.getJSONArray("datarows"); + ArrayList field_indexes = new ArrayList<>(); + field_indexes.add(0); + field_indexes.add(1); + + compareDataRowTotals(dataRows, field_indexes, 3); + } + + @Test + public void testAddTotalsWithRowFalse() throws IOException { + var result = + executeQuery( + String.format( + "source=%s | where age > 25 | fields age, balance | addtotals row=false", + TEST_INDEX_ACCOUNT)); + + // With row=false, should not append totals row + var dataRows = result.getJSONArray("datarows"); + + // Verify that no totals row was added - all rows should have actual data + for (int i = 0; i < dataRows.length(); i++) { + var row = dataRows.getJSONArray(i); + // None of these rows should have "Total" label + for (int j = 0; j < row.length(); j++) { + if (!row.isNull(j) && row.get(j).equals("Total")) { + fail("Found totals row when row=false was specified"); + } + } + } + } + + @Test + public void testAddTotalsWithLabelAndLabelField() throws IOException { + var result = + executeQuery( + String.format( + "source=%s | where age > 25 |head 3| fields age, balance | addtotals row=false" + + " col=true label='Sum' labelfield='Total Summary'", + TEST_INDEX_ACCOUNT)); + + // Verify schema includes custom fieldname + verifySchema( + result, + schema("age", "bigint"), + schema("balance", "bigint"), + schema("Total Summary", "string")); + + var dataRows = result.getJSONArray("datarows"); + ArrayList field_indexes = new ArrayList<>(); + field_indexes.add(0); + field_indexes.add(1); + + verifyColTotals(dataRows, field_indexes, "Sum"); + } + + @Test + public void testAddTotalsWithFieldnameAndSpecificFields() throws IOException { + var result = + executeQuery( + String.format( + "source=%s | where age > 25 |head 2| fields age, balance | addtotals balance" + + " fieldname='BalanceSum'", + TEST_INDEX_ACCOUNT)); + + verifySchema( + result, + schema("age", "bigint"), + schema("balance", "bigint"), + schema("BalanceSum", "bigint")); + + var dataRows = result.getJSONArray("datarows"); + ArrayList field_indexes = new ArrayList<>(); + field_indexes.add(1); + + compareDataRowTotals(dataRows, field_indexes, 2); + } + + @Test + public void testAddTotalsWithFieldnameNoRow() throws IOException { + var result = + executeQuery( + String.format( + "source=%s | where age > 25 | fields age, balance | " + + "addtotals balance fieldname='CustomSum' row=false", + TEST_INDEX_ACCOUNT)); + + // With row=false, should not append totals row regardless of fieldname + var dataRows = result.getJSONArray("datarows"); + + // Verify that no totals row was added + for (int i = 0; i < dataRows.length(); i++) { + var row = dataRows.getJSONArray(i); + // None of these rows should have "CustomSum" label + for (int j = 0; j < row.length(); j++) { + if (!row.isNull(j) && row.get(j).equals("CustomSum")) { + fail("Found totals row when row=false was specified"); + } + } + } + } +} diff --git a/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteExplainIT.java b/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteExplainIT.java index 5db4d7f9e50..669eb4eeaba 100644 --- a/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteExplainIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteExplainIT.java @@ -1997,6 +1997,30 @@ public void testInternalItemAccessOnStructs() throws IOException { } @Test + public void testaddTotalsExplain() throws IOException { + enabledOnlyWhenPushdownIsEnabled(); + String expected = loadExpectedPlan("explain_add_totals.yaml"); + assertYamlEqualsIgnoreId( + expected, + explainQueryYaml( + "source=opensearch-sql_test_index_account" + + "| head 5 " + + "| addtotals balance age label='ColTotal' " + + " fieldname='CustomSum' labelfield='all_emp_total' row=true col=true")); + } + + @Test + public void testaddColTotalsExplain() throws IOException { + enabledOnlyWhenPushdownIsEnabled(); + String expected = loadExpectedPlan("explain_add_col_totals.yaml"); + assertYamlEqualsIgnoreId( + expected, + explainQueryYaml( + "source=opensearch-sql_test_index_account" + + "| head 5 " + + "| addcoltotals balance age label='GrandTotal'")); + } + public void testComplexDedup() throws IOException { enabledOnlyWhenPushdownIsEnabled(); String expected = loadExpectedPlan("explain_dedup_complex1.yaml"); diff --git a/integ-test/src/test/java/org/opensearch/sql/ppl/NewAddedCommandsIT.java b/integ-test/src/test/java/org/opensearch/sql/ppl/NewAddedCommandsIT.java index 93e9af8e2a5..15f3c508b14 100644 --- a/integ-test/src/test/java/org/opensearch/sql/ppl/NewAddedCommandsIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/ppl/NewAddedCommandsIT.java @@ -180,6 +180,28 @@ public void testStrftimeFunction() throws IOException { } } + @Test + public void testAddTotalCommand() throws IOException { + JSONObject result; + try { + executeQuery(String.format("search source=%s | addtotals ", TEST_INDEX_BANK)); + } catch (ResponseException e) { + result = new JSONObject(TestUtils.getResponseBody(e.getResponse())); + verifyQuery(result); + } + } + + @Test + public void testAddColTotalCommand() throws IOException { + JSONObject result; + try { + executeQuery(String.format("search source=%s | addcoltotals ", TEST_INDEX_BANK)); + } catch (ResponseException e) { + result = new JSONObject(TestUtils.getResponseBody(e.getResponse())); + verifyQuery(result); + } + } + private void verifyQuery(JSONObject result) throws IOException { if (isCalciteEnabled()) { assertFalse(result.getJSONArray("datarows").isEmpty()); diff --git a/integ-test/src/test/java/org/opensearch/sql/security/CrossClusterSearchIT.java b/integ-test/src/test/java/org/opensearch/sql/security/CrossClusterSearchIT.java index a35351cf8e0..4c032bbb623 100644 --- a/integ-test/src/test/java/org/opensearch/sql/security/CrossClusterSearchIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/security/CrossClusterSearchIT.java @@ -247,6 +247,43 @@ public void testCrossClusterQueryStringWithoutFields() throws IOException { verifyDataRows(result, rows("Hattie")); } + @Test + public void testCrossClusterAddTotals() throws IOException { + try { + enableCalcite(); + + // Test query_string without fields parameter on remote cluster + JSONObject result = + executeQuery( + String.format( + "search source=%s| sort 1 age | fields firstname, age | addtotals age", + TEST_INDEX_BANK_REMOTE)); + verifyDataRows(result, rows("Nanette", 28, 28)); + } finally { + disableCalcite(); + } + } + + /** CrossClusterSearchIT Test for addcoltotals. */ + @Test + public void testCrossClusterAddColTotals() throws IOException { + try { + enableCalcite(); + + // Test query_string without fields parameter on remote cluster + JSONObject result = + executeQuery( + String.format( + "search source=%s | where firstname='Hattie' or firstname ='Nanette'|fields" + + " firstname,age,balance | addcoltotals age balance", + TEST_INDEX_BANK_REMOTE)); + verifyDataRows( + result, rows("Hattie", 36, 5686), rows("Nanette", 28, 32838), rows(null, 64, 38524)); + } finally { + disableCalcite(); + } + } + @Test public void testCrossClusterAppend() throws IOException { // TODO: We should enable calcite by default in CrossClusterSearchIT? diff --git a/integ-test/src/test/resources/expectedOutput/calcite/explain_add_col_totals.yaml b/integ-test/src/test/resources/expectedOutput/calcite/explain_add_col_totals.yaml new file mode 100644 index 00000000000..0a8139b1eaa --- /dev/null +++ b/integ-test/src/test/resources/expectedOutput/calcite/explain_add_col_totals.yaml @@ -0,0 +1,18 @@ +calcite: + logical: | + LogicalSystemLimit(fetch=[10000], type=[QUERY_SIZE_LIMIT]) + LogicalProject(account_number=[$0], firstname=[$1], address=[$2], balance=[$3], gender=[$4], city=[$5], employer=[$6], state=[$7], age=[$8], email=[$9], lastname=[$10]) + LogicalUnion(all=[true]) + LogicalSort(fetch=[5]) + CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]]) + LogicalProject(account_number=[null:BIGINT], firstname=[null:VARCHAR], address=[null:VARCHAR], balance=[$0], gender=[null:VARCHAR], city=[null:VARCHAR], employer=[null:VARCHAR], state=[null:VARCHAR], age=[$1], email=[null:VARCHAR], lastname=[null:VARCHAR], _id=[null:VARCHAR], _index=[null:VARCHAR], _score=[null:REAL], _maxscore=[null:REAL], _sort=[null:BIGINT], _routing=[null:VARCHAR]) + LogicalAggregate(group=[{}], balance=[SUM($3)], age=[SUM($8)]) + LogicalSort(fetch=[5]) + CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]]) + physical: | + EnumerableLimit(fetch=[10000]) + EnumerableUnion(all=[true]) + CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]], PushDownContext=[[PROJECT->[account_number, firstname, address, balance, gender, city, employer, state, age, email, lastname], LIMIT->5], OpenSearchRequestBuilder(sourceBuilder={"from":0,"size":5,"timeout":"1m","_source":{"includes":["account_number","firstname","address","balance","gender","city","employer","state","age","email","lastname"],"excludes":[]}}, requestedTotalSize=5, pageSize=null, startFrom=0)]) + EnumerableCalc(expr#0..1=[{inputs}], expr#2=[null:BIGINT], expr#3=[null:VARCHAR], account_number=[$t2], firstname=[$t3], address=[$t3], balance=[$t0], gender=[$t3], city=[$t3], employer=[$t3], state=[$t3], age=[$t1], email=[$t3], lastname=[$t3]) + EnumerableAggregate(group=[{}], balance=[SUM($0)], age=[SUM($1)]) + CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]], PushDownContext=[[PROJECT->[balance, age], LIMIT->5], OpenSearchRequestBuilder(sourceBuilder={"from":0,"size":5,"timeout":"1m","_source":{"includes":["balance","age"],"excludes":[]}}, requestedTotalSize=5, pageSize=null, startFrom=0)]) \ No newline at end of file diff --git a/integ-test/src/test/resources/expectedOutput/calcite/explain_add_totals.yaml b/integ-test/src/test/resources/expectedOutput/calcite/explain_add_totals.yaml new file mode 100644 index 00000000000..0c8b4ec26a2 --- /dev/null +++ b/integ-test/src/test/resources/expectedOutput/calcite/explain_add_totals.yaml @@ -0,0 +1,22 @@ +calcite: + logical: | + LogicalSystemLimit(fetch=[10000], type=[QUERY_SIZE_LIMIT]) + LogicalProject(account_number=[$0], firstname=[$1], address=[$2], balance=[$3], gender=[$4], city=[$5], employer=[$6], state=[$7], age=[$8], email=[$9], lastname=[$10], CustomSum=[$17], all_emp_total=[$18]) + LogicalUnion(all=[true]) + LogicalProject(account_number=[$0], firstname=[$1], address=[$2], balance=[$3], gender=[$4], city=[$5], employer=[$6], state=[$7], age=[$8], email=[$9], lastname=[$10], _id=[$11], _index=[$12], _score=[$13], _maxscore=[$14], _sort=[$15], _routing=[$16], CustomSum=[+($3, $8)], all_emp_total=[null:VARCHAR(13)]) + LogicalSort(fetch=[5]) + CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]]) + LogicalProject(account_number=[null:BIGINT], firstname=[null:VARCHAR], address=[null:VARCHAR], balance=[$0], gender=[null:VARCHAR], city=[null:VARCHAR], employer=[null:VARCHAR], state=[null:VARCHAR], age=[$1], email=[null:VARCHAR], lastname=[null:VARCHAR], _id=[null:VARCHAR], _index=[null:VARCHAR], _score=[null:REAL], _maxscore=[null:REAL], _sort=[null:BIGINT], _routing=[null:VARCHAR], CustomSum=[null:BIGINT], all_emp_total=['ColTotal':VARCHAR(13)]) + LogicalAggregate(group=[{}], balance=[SUM($0)], age=[SUM($1)]) + LogicalProject(balance=[$3], age=[$8]) + LogicalSort(fetch=[5]) + CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]]) + physical: | + EnumerableLimit(fetch=[10000]) + EnumerableUnion(all=[true]) + EnumerableCalc(expr#0..10=[{inputs}], expr#11=[+($t3, $t8)], expr#12=[null:VARCHAR(13)], proj#0..12=[{exprs}]) + CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]], PushDownContext=[[PROJECT->[account_number, firstname, address, balance, gender, city, employer, state, age, email, lastname], LIMIT->5, LIMIT->10000], OpenSearchRequestBuilder(sourceBuilder={"from":0,"size":5,"timeout":"1m","_source":{"includes":["account_number","firstname","address","balance","gender","city","employer","state","age","email","lastname"],"excludes":[]}}, requestedTotalSize=5, pageSize=null, startFrom=0)]) + EnumerableLimit(fetch=[10000]) + EnumerableCalc(expr#0..1=[{inputs}], expr#2=[null:BIGINT], expr#3=[null:VARCHAR], expr#4=['ColTotal':VARCHAR(13)], account_number=[$t2], firstname=[$t3], address=[$t3], balance=[$t0], gender=[$t3], city=[$t3], employer=[$t3], state=[$t3], age=[$t1], email=[$t3], lastname=[$t3], CustomSum=[$t2], all_emp_total=[$t4]) + EnumerableAggregate(group=[{}], balance=[SUM($0)], age=[SUM($1)]) + CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]], PushDownContext=[[PROJECT->[balance, age], LIMIT->5], OpenSearchRequestBuilder(sourceBuilder={"from":0,"size":5,"timeout":"1m","_source":{"includes":["balance","age"],"excludes":[]}}, requestedTotalSize=5, pageSize=null, startFrom=0)]) diff --git a/ppl/src/main/antlr/OpenSearchPPLLexer.g4 b/ppl/src/main/antlr/OpenSearchPPLLexer.g4 index 694aabf43ab..ac72575e895 100644 --- a/ppl/src/main/antlr/OpenSearchPPLLexer.g4 +++ b/ppl/src/main/antlr/OpenSearchPPLLexer.g4 @@ -49,6 +49,10 @@ TRENDLINE: 'TRENDLINE'; CHART: 'CHART'; TIMECHART: 'TIMECHART'; APPENDCOL: 'APPENDCOL'; +ADDTOTALS: 'ADDTOTALS'; +ADDCOLTOTALS: 'ADDCOLTOTALS'; +ROW: 'ROW'; +COL: 'COL'; EXPAND: 'EXPAND'; SIMPLE_PATTERN: 'SIMPLE_PATTERN'; BRAIN: 'BRAIN'; @@ -59,6 +63,9 @@ MAX_SAMPLE_COUNT: 'MAX_SAMPLE_COUNT'; MAX_MATCH: 'MAX_MATCH'; OFFSET_FIELD: 'OFFSET_FIELD'; BUFFER_LIMIT: 'BUFFER_LIMIT'; +FIELDLIST: 'FIELDLIST'; +LABELFIELD: 'LABELFIELD'; +FIELDNAME: 'FIELDNAME'; LABEL: 'LABEL'; SHOW_NUMBERED_TOKEN: 'SHOW_NUMBERED_TOKEN'; AGGREGATION: 'AGGREGATION'; diff --git a/ppl/src/main/antlr/OpenSearchPPLParser.g4 b/ppl/src/main/antlr/OpenSearchPPLParser.g4 index 1cc33cd7f5d..867e2f0d28a 100644 --- a/ppl/src/main/antlr/OpenSearchPPLParser.g4 +++ b/ppl/src/main/antlr/OpenSearchPPLParser.g4 @@ -76,6 +76,8 @@ commands | fillnullCommand | trendlineCommand | appendcolCommand + | addtotalsCommand + | addcoltotalsCommand | appendCommand | expandCommand | flattenCommand @@ -126,6 +128,8 @@ commandName | EXPLAIN | REVERSE | REGEX + | ADDTOTALS + | ADDCOLTOTALS | APPEND | MULTISEARCH | REX @@ -586,6 +590,29 @@ mlArg : (argName = ident EQUAL argValue = literalValue) ; +addtotalsCommand + : ADDTOTALS (fieldList)? addtotalsOption* + | ADDTOTALS addtotalsOption* (fieldList)? + ; + +addtotalsOption + : (LABEL EQUAL stringLiteral) + | (LABELFIELD EQUAL stringLiteral) + | (FIELDNAME EQUAL stringLiteral) + | (ROW EQUAL booleanLiteral) + | (COL EQUAL booleanLiteral) + ; + +addcoltotalsCommand + : ADDCOLTOTALS (fieldList)? addcoltotalsOption* + | ADDCOLTOTALS addcoltotalsOption* (fieldList)? + ; + +addcoltotalsOption + : (LABEL EQUAL stringLiteral) + | (LABELFIELD EQUAL stringLiteral) + ; + // clauses fromClause : SOURCE EQUAL tableOrSubqueryClause @@ -1659,4 +1686,12 @@ searchableKeyWord | LEFT_HINT | RIGHT_HINT | PERCENTILE_SHORTCUT + | ADDTOTALS + | ADDCOLTOTALS + | LABEL + | LABELFIELD + | FIELDNAME + | ROW + | COL ; + diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java index bcdb30a9fc3..87e128373d5 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java @@ -68,6 +68,8 @@ import org.opensearch.sql.ast.expression.WindowFrame; import org.opensearch.sql.ast.expression.WindowFunction; import org.opensearch.sql.ast.tree.AD; +import org.opensearch.sql.ast.tree.AddColTotals; +import org.opensearch.sql.ast.tree.AddTotals; import org.opensearch.sql.ast.tree.Aggregation; import org.opensearch.sql.ast.tree.Append; import org.opensearch.sql.ast.tree.AppendCol; @@ -1413,4 +1415,43 @@ private boolean hasActualWildcards(OpenSearchPPLParser.FieldsCommandBodyContext } return false; } + + @Override + public UnresolvedPlan visitAddtotalsCommand(OpenSearchPPLParser.AddtotalsCommandContext ctx) { + + List fieldList = new ArrayList<>(); + if (ctx.fieldList() != null) { + fieldList = getFieldList(ctx.fieldList()); + } + ImmutableMap.Builder cmdOptionsBuilder = ImmutableMap.builder(); + ctx.addtotalsOption() + .forEach( + option -> { + String argName = option.children.get(0).toString(); + Literal value = (Literal) internalVisitExpression(option.children.get(2)); + cmdOptionsBuilder.put(argName, value); + }); + java.util.Map options = cmdOptionsBuilder.build(); + return new AddTotals(fieldList, options); + } + + @Override + public UnresolvedPlan visitAddcoltotalsCommand( + OpenSearchPPLParser.AddcoltotalsCommandContext ctx) { + + List fieldList = new ArrayList<>(); + if (ctx.fieldList() != null) { + fieldList = getFieldList(ctx.fieldList()); + } + ImmutableMap.Builder cmdOptionsBuilder = ImmutableMap.builder(); + ctx.addcoltotalsOption() + .forEach( + option -> { + String argName = option.children.get(0).toString(); + Literal value = (Literal) internalVisitExpression(option.children.get(2)); + cmdOptionsBuilder.put(argName, value); + }); + java.util.Map options = cmdOptionsBuilder.build(); + return new AddColTotals(fieldList, options); + } } diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/utils/PPLQueryDataAnonymizer.java b/ppl/src/main/java/org/opensearch/sql/ppl/utils/PPLQueryDataAnonymizer.java index 7e9043ecf36..160dc9e34b5 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/utils/PPLQueryDataAnonymizer.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/utils/PPLQueryDataAnonymizer.java @@ -59,6 +59,8 @@ import org.opensearch.sql.ast.statement.Explain; import org.opensearch.sql.ast.statement.Query; import org.opensearch.sql.ast.statement.Statement; +import org.opensearch.sql.ast.tree.AddColTotals; +import org.opensearch.sql.ast.tree.AddTotals; import org.opensearch.sql.ast.tree.Aggregation; import org.opensearch.sql.ast.tree.Append; import org.opensearch.sql.ast.tree.AppendCol; @@ -789,6 +791,41 @@ public String visitSpath(SPath node, String context) { return builder.toString(); } + public void appendAddTotalsOptionParameters( + List fieldList, java.util.Map options, StringBuilder builder) { + + if (!fieldList.isEmpty()) { + builder.append(visitExpressionList(fieldList, " ")); + } + if (!options.isEmpty()) { + for (String key : options.keySet()) { + String value = options.get(key).toString(); + if (value.matches(".*\\s.*")) { + value = StringUtils.format("'%s'", value); + } + builder.append(" ").append(key).append("=").append(value); + } + } + } + + @Override + public String visitAddTotals(AddTotals node, String context) { + String child = node.getChild().get(0).accept(this, context); + StringBuilder builder = new StringBuilder(); + builder.append(child).append(" | addtotals"); + appendAddTotalsOptionParameters(node.getFieldList(), node.getOptions(), builder); + return builder.toString(); + } + + @Override + public String visitAddColTotals(AddColTotals node, String context) { + String child = node.getChild().get(0).accept(this, context); + StringBuilder builder = new StringBuilder(); + builder.append(child).append(" | addcoltotals"); + appendAddTotalsOptionParameters(node.getFieldList(), node.getOptions(), builder); + return builder.toString(); + } + @Override public String visitPatterns(Patterns node, String context) { String child = node.getChild().get(0).accept(this, context); diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/antlr/PPLSyntaxParserTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/antlr/PPLSyntaxParserTest.java index c9a62d55dca..6d89bd352b6 100644 --- a/ppl/src/test/java/org/opensearch/sql/ppl/antlr/PPLSyntaxParserTest.java +++ b/ppl/src/test/java/org/opensearch/sql/ppl/antlr/PPLSyntaxParserTest.java @@ -887,4 +887,36 @@ public void testWhereCommandWithDoubleEqual() { "SOURCE=test | WHERE query_string(['field1', 'field2' ^ 3.2], 'test query'," + " analyzer='keyword')")); } + + @Test + public void testAddTotalsCommandShouldPass() { + ParseTree tree = new PPLSyntaxParser().parse("source=t | addtotals"); + assertNotEquals(null, tree); + } + + @Test + public void testAddTotalsCommandWithFieldsShouldPass() { + ParseTree tree = new PPLSyntaxParser().parse("source=t | addtotals price, quantity"); + assertNotEquals(null, tree); + } + + @Test + public void testAddTotalsCommandWithLabelShouldPass() { + ParseTree tree = new PPLSyntaxParser().parse("source=t | addtotals label='Grand Total'"); + assertNotEquals(null, tree); + } + + @Test + public void testAddTotalsCommandWithLabelFieldShouldPass() { + ParseTree tree = new PPLSyntaxParser().parse("source=t | addtotals labelfield='category'"); + assertNotEquals(null, tree); + } + + @Test + public void testAddTotalsCommandWithAllOptionsShouldPass() { + ParseTree tree = + new PPLSyntaxParser() + .parse("source=t | addtotals price, quantity label='Total' labelfield='type'"); + assertNotEquals(null, tree); + } } diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLAddColTotalsTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLAddColTotalsTest.java new file mode 100644 index 00000000000..cd9fb985764 --- /dev/null +++ b/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLAddColTotalsTest.java @@ -0,0 +1,377 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ppl.calcite; + +import java.io.IOException; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.test.CalciteAssert; +import org.junit.Test; + +public class CalcitePPLAddColTotalsTest extends CalcitePPLAbstractTest { + + public CalcitePPLAddColTotalsTest() { + super(CalciteAssert.SchemaSpec.SCOTT_WITH_TEMPORAL); + } + + @Test + public void testAddColTotals() throws IOException { + String ppl = "source=EMP | fields DEPTNO, SAL, JOB | addcoltotals "; + RelNode root = getRelNode(ppl); + String expectedLogical = + "LogicalUnion(all=[true])\n" + + " LogicalProject(DEPTNO=[$7], SAL=[$5], JOB=[$2])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n" + + " LogicalProject(DEPTNO=[$0], SAL=[$1], JOB=[null:VARCHAR(9)])\n" + + " LogicalAggregate(group=[{}], DEPTNO=[SUM($0)], SAL=[SUM($1)])\n" + + " LogicalProject(DEPTNO=[$7], SAL=[$5])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n"; + verifyLogical(root, expectedLogical); + String expectedResult = + "DEPTNO=20; SAL=800.00; JOB=CLERK\n" + + "DEPTNO=30; SAL=1600.00; JOB=SALESMAN\n" + + "DEPTNO=30; SAL=1250.00; JOB=SALESMAN\n" + + "DEPTNO=20; SAL=2975.00; JOB=MANAGER\n" + + "DEPTNO=30; SAL=1250.00; JOB=SALESMAN\n" + + "DEPTNO=30; SAL=2850.00; JOB=MANAGER\n" + + "DEPTNO=10; SAL=2450.00; JOB=MANAGER\n" + + "DEPTNO=20; SAL=3000.00; JOB=ANALYST\n" + + "DEPTNO=10; SAL=5000.00; JOB=PRESIDENT\n" + + "DEPTNO=30; SAL=1500.00; JOB=SALESMAN\n" + + "DEPTNO=20; SAL=1100.00; JOB=CLERK\n" + + "DEPTNO=30; SAL=950.00; JOB=CLERK\n" + + "DEPTNO=20; SAL=3000.00; JOB=ANALYST\n" + + "DEPTNO=10; SAL=1300.00; JOB=CLERK\n" + + "DEPTNO=310; SAL=29025.00; JOB=null\n"; + + verifyResult(root, expectedResult); + + String expectedSparkSql = + "SELECT `DEPTNO`, `SAL`, `JOB`\n" + + "FROM `scott`.`EMP`\n" + + "UNION ALL\n" + + "SELECT SUM(`DEPTNO`) `DEPTNO`, SUM(`SAL`) `SAL`, CAST(NULL AS STRING) `JOB`\n" + + "FROM `scott`.`EMP`"; + + verifyPPLToSparkSQL(root, expectedSparkSql); + } + + @Test + public void testAddColTotalsFieldSpecified() throws IOException { + String ppl = "source=EMP | fields DEPTNO, SAL, JOB | addcoltotals SAL "; + RelNode root = getRelNode(ppl); + String expectedLogical = + "LogicalUnion(all=[true])\n" + + " LogicalProject(DEPTNO=[$7], SAL=[$5], JOB=[$2])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n" + + " LogicalProject(DEPTNO=[null:TINYINT], SAL=[$0], JOB=[null:VARCHAR(9)])\n" + + " LogicalAggregate(group=[{}], SAL=[SUM($0)])\n" + + " LogicalProject(SAL=[$5])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n"; + verifyLogical(root, expectedLogical); + String expectedResult = + "DEPTNO=20; SAL=800.00; JOB=CLERK\n" + + "DEPTNO=30; SAL=1600.00; JOB=SALESMAN\n" + + "DEPTNO=30; SAL=1250.00; JOB=SALESMAN\n" + + "DEPTNO=20; SAL=2975.00; JOB=MANAGER\n" + + "DEPTNO=30; SAL=1250.00; JOB=SALESMAN\n" + + "DEPTNO=30; SAL=2850.00; JOB=MANAGER\n" + + "DEPTNO=10; SAL=2450.00; JOB=MANAGER\n" + + "DEPTNO=20; SAL=3000.00; JOB=ANALYST\n" + + "DEPTNO=10; SAL=5000.00; JOB=PRESIDENT\n" + + "DEPTNO=30; SAL=1500.00; JOB=SALESMAN\n" + + "DEPTNO=20; SAL=1100.00; JOB=CLERK\n" + + "DEPTNO=30; SAL=950.00; JOB=CLERK\n" + + "DEPTNO=20; SAL=3000.00; JOB=ANALYST\n" + + "DEPTNO=10; SAL=1300.00; JOB=CLERK\n" + + "DEPTNO=null; SAL=29025.00; JOB=null\n"; + + verifyResult(root, expectedResult); + + String expectedSparkSql = + "SELECT `DEPTNO`, `SAL`, `JOB`\n" + + "FROM `scott`.`EMP`\n" + + "UNION ALL\n" + + "SELECT CAST(NULL AS TINYINT) `DEPTNO`, SUM(`SAL`) `SAL`, CAST(NULL AS STRING)" + + " `JOB`\n" + + "FROM `scott`.`EMP`"; + + verifyPPLToSparkSQL(root, expectedSparkSql); + } + + @Test + public void testAddColTotalsAllFields() throws IOException { + String ppl = "source=EMP | fields DEPTNO, SAL, JOB | addcoltotals "; + RelNode root = getRelNode(ppl); + String expectedLogical = + "LogicalUnion(all=[true])\n" + + " LogicalProject(DEPTNO=[$7], SAL=[$5], JOB=[$2])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n" + + " LogicalProject(DEPTNO=[$0], SAL=[$1], JOB=[null:VARCHAR(9)])\n" + + " LogicalAggregate(group=[{}], DEPTNO=[SUM($0)], SAL=[SUM($1)])\n" + + " LogicalProject(DEPTNO=[$7], SAL=[$5])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n"; + verifyLogical(root, expectedLogical); + String expectedResult = + "DEPTNO=20; SAL=800.00; JOB=CLERK\n" + + "DEPTNO=30; SAL=1600.00; JOB=SALESMAN\n" + + "DEPTNO=30; SAL=1250.00; JOB=SALESMAN\n" + + "DEPTNO=20; SAL=2975.00; JOB=MANAGER\n" + + "DEPTNO=30; SAL=1250.00; JOB=SALESMAN\n" + + "DEPTNO=30; SAL=2850.00; JOB=MANAGER\n" + + "DEPTNO=10; SAL=2450.00; JOB=MANAGER\n" + + "DEPTNO=20; SAL=3000.00; JOB=ANALYST\n" + + "DEPTNO=10; SAL=5000.00; JOB=PRESIDENT\n" + + "DEPTNO=30; SAL=1500.00; JOB=SALESMAN\n" + + "DEPTNO=20; SAL=1100.00; JOB=CLERK\n" + + "DEPTNO=30; SAL=950.00; JOB=CLERK\n" + + "DEPTNO=20; SAL=3000.00; JOB=ANALYST\n" + + "DEPTNO=10; SAL=1300.00; JOB=CLERK\n" + + "DEPTNO=310; SAL=29025.00; JOB=null\n"; + + verifyResult(root, expectedResult); + + String expectedSparkSql = + "SELECT `DEPTNO`, `SAL`, `JOB`\n" + + "FROM `scott`.`EMP`\n" + + "UNION ALL\n" + + "SELECT SUM(`DEPTNO`) `DEPTNO`, SUM(`SAL`) `SAL`, CAST(NULL AS STRING) `JOB`\n" + + "FROM `scott`.`EMP`"; + verifyPPLToSparkSQL(root, expectedSparkSql); + } + + @Test + public void testAddColTotalsMultiFields() throws IOException { + String ppl = "source=EMP | fields DEPTNO, SAL, JOB | addcoltotals DEPTNO SAL "; + RelNode root = getRelNode(ppl); + String expectedLogical = + "LogicalUnion(all=[true])\n" + + " LogicalProject(DEPTNO=[$7], SAL=[$5], JOB=[$2])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n" + + " LogicalProject(DEPTNO=[$0], SAL=[$1], JOB=[null:VARCHAR(9)])\n" + + " LogicalAggregate(group=[{}], DEPTNO=[SUM($0)], SAL=[SUM($1)])\n" + + " LogicalProject(DEPTNO=[$7], SAL=[$5])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n"; + verifyLogical(root, expectedLogical); + String expectedResult = + "DEPTNO=20; SAL=800.00; JOB=CLERK\n" + + "DEPTNO=30; SAL=1600.00; JOB=SALESMAN\n" + + "DEPTNO=30; SAL=1250.00; JOB=SALESMAN\n" + + "DEPTNO=20; SAL=2975.00; JOB=MANAGER\n" + + "DEPTNO=30; SAL=1250.00; JOB=SALESMAN\n" + + "DEPTNO=30; SAL=2850.00; JOB=MANAGER\n" + + "DEPTNO=10; SAL=2450.00; JOB=MANAGER\n" + + "DEPTNO=20; SAL=3000.00; JOB=ANALYST\n" + + "DEPTNO=10; SAL=5000.00; JOB=PRESIDENT\n" + + "DEPTNO=30; SAL=1500.00; JOB=SALESMAN\n" + + "DEPTNO=20; SAL=1100.00; JOB=CLERK\n" + + "DEPTNO=30; SAL=950.00; JOB=CLERK\n" + + "DEPTNO=20; SAL=3000.00; JOB=ANALYST\n" + + "DEPTNO=10; SAL=1300.00; JOB=CLERK\n" + + "DEPTNO=310; SAL=29025.00; JOB=null\n"; + + verifyResult(root, expectedResult); + + String expectedSparkSql = + "SELECT `DEPTNO`, `SAL`, `JOB`\n" + + "FROM `scott`.`EMP`\n" + + "UNION ALL\n" + + "SELECT SUM(`DEPTNO`) `DEPTNO`, SUM(`SAL`) `SAL`, CAST(NULL AS STRING) `JOB`\n" + + "FROM `scott`.`EMP`"; + verifyPPLToSparkSQL(root, expectedSparkSql); + } + + @Test + public void testAddColTotalsWithAllOptions() throws IOException { + String ppl = + "source=EMP | fields DEPTNO, SAL, JOB | addcoltotals SAL label='GrandTotal'" + + " labelfield='all_emp_total' "; + RelNode root = getRelNode(ppl); + String expectedLogical = + "LogicalUnion(all=[true])\n" + + " LogicalProject(DEPTNO=[$7], SAL=[$5], JOB=[$2]," + + " all_emp_total=[null:VARCHAR(13)])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n" + + " LogicalProject(DEPTNO=[null:TINYINT], SAL=[$0], JOB=[null:VARCHAR(9)]," + + " all_emp_total=['GrandTotal':VARCHAR(13)])\n" + + " LogicalAggregate(group=[{}], SAL=[SUM($0)])\n" + + " LogicalProject(SAL=[$5])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n"; + verifyLogical(root, expectedLogical); + String expectedResult = + "DEPTNO=20; SAL=800.00; JOB=CLERK; all_emp_total=null\n" + + "DEPTNO=30; SAL=1600.00; JOB=SALESMAN; all_emp_total=null\n" + + "DEPTNO=30; SAL=1250.00; JOB=SALESMAN; all_emp_total=null\n" + + "DEPTNO=20; SAL=2975.00; JOB=MANAGER; all_emp_total=null\n" + + "DEPTNO=30; SAL=1250.00; JOB=SALESMAN; all_emp_total=null\n" + + "DEPTNO=30; SAL=2850.00; JOB=MANAGER; all_emp_total=null\n" + + "DEPTNO=10; SAL=2450.00; JOB=MANAGER; all_emp_total=null\n" + + "DEPTNO=20; SAL=3000.00; JOB=ANALYST; all_emp_total=null\n" + + "DEPTNO=10; SAL=5000.00; JOB=PRESIDENT; all_emp_total=null\n" + + "DEPTNO=30; SAL=1500.00; JOB=SALESMAN; all_emp_total=null\n" + + "DEPTNO=20; SAL=1100.00; JOB=CLERK; all_emp_total=null\n" + + "DEPTNO=30; SAL=950.00; JOB=CLERK; all_emp_total=null\n" + + "DEPTNO=20; SAL=3000.00; JOB=ANALYST; all_emp_total=null\n" + + "DEPTNO=10; SAL=1300.00; JOB=CLERK; all_emp_total=null\n" + + "DEPTNO=null; SAL=29025.00; JOB=null; all_emp_total=GrandTotal\n"; + verifyResult(root, expectedResult); + + String expectedSparkSql = + "SELECT `DEPTNO`, `SAL`, `JOB`, CAST(NULL AS STRING) `all_emp_total`\n" + + "FROM `scott`.`EMP`\n" + + "UNION ALL\n" + + "SELECT CAST(NULL AS TINYINT) `DEPTNO`, SUM(`SAL`) `SAL`, CAST(NULL AS STRING) `JOB`," + + " 'GrandTotal' `all_emp_total`\n" + + "FROM `scott`.`EMP`"; + + verifyPPLToSparkSQL(root, expectedSparkSql); + } + + @Test + public void testAddColTotalsMatchingLabelFieldWithExisting() throws IOException { + String ppl = + "source=EMP | fields DEPTNO, SAL, JOB | addcoltotals SAL label='GrandTotal'" + + " labelfield='JOB' "; + RelNode root = getRelNode(ppl); + String expectedLogical = + "LogicalUnion(all=[true])\n" + + " LogicalProject(DEPTNO=[$7], SAL=[$5], JOB=[$2])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n" + + " LogicalProject(DEPTNO=[null:TINYINT], SAL=[$0], JOB=['GrandTota':VARCHAR(9)])\n" + + " LogicalAggregate(group=[{}], SAL=[SUM($0)])\n" + + " LogicalProject(SAL=[$5])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n"; + verifyLogical(root, expectedLogical); + String expectedResult = + "DEPTNO=20; SAL=800.00; JOB=CLERK\n" + + "DEPTNO=30; SAL=1600.00; JOB=SALESMAN\n" + + "DEPTNO=30; SAL=1250.00; JOB=SALESMAN\n" + + "DEPTNO=20; SAL=2975.00; JOB=MANAGER\n" + + "DEPTNO=30; SAL=1250.00; JOB=SALESMAN\n" + + "DEPTNO=30; SAL=2850.00; JOB=MANAGER\n" + + "DEPTNO=10; SAL=2450.00; JOB=MANAGER\n" + + "DEPTNO=20; SAL=3000.00; JOB=ANALYST\n" + + "DEPTNO=10; SAL=5000.00; JOB=PRESIDENT\n" + + "DEPTNO=30; SAL=1500.00; JOB=SALESMAN\n" + + "DEPTNO=20; SAL=1100.00; JOB=CLERK\n" + + "DEPTNO=30; SAL=950.00; JOB=CLERK\n" + + "DEPTNO=20; SAL=3000.00; JOB=ANALYST\n" + + "DEPTNO=10; SAL=1300.00; JOB=CLERK\n" + + "DEPTNO=null; SAL=29025.00; JOB=GrandTota\n"; + verifyResult(root, expectedResult); + + String expectedSparkSql = + "SELECT `DEPTNO`, `SAL`, `JOB`\n" + + "FROM `scott`.`EMP`\n" + + "UNION ALL\n" + + "SELECT CAST(NULL AS TINYINT) `DEPTNO`, SUM(`SAL`) `SAL`, 'GrandTota' `JOB`\n" + + "FROM `scott`.`EMP`"; + + verifyPPLToSparkSQL(root, expectedSparkSql); + } + + @Test + public void testAddColTotalsMatchingLabelFieldWithExistingChangedOrder() throws IOException { + String ppl = + "source=EMP | fields DEPTNO, SAL, JOB | addcoltotals label='GrandTotal'" + + " labelfield='JOB' SAL "; + RelNode root = getRelNode(ppl); + String expectedLogical = + "LogicalUnion(all=[true])\n" + + " LogicalProject(DEPTNO=[$7], SAL=[$5], JOB=[$2])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n" + + " LogicalProject(DEPTNO=[null:TINYINT], SAL=[$0], JOB=['GrandTota':VARCHAR(9)])\n" + + " LogicalAggregate(group=[{}], SAL=[SUM($0)])\n" + + " LogicalProject(SAL=[$5])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n"; + verifyLogical(root, expectedLogical); + String expectedResult = + "DEPTNO=20; SAL=800.00; JOB=CLERK\n" + + "DEPTNO=30; SAL=1600.00; JOB=SALESMAN\n" + + "DEPTNO=30; SAL=1250.00; JOB=SALESMAN\n" + + "DEPTNO=20; SAL=2975.00; JOB=MANAGER\n" + + "DEPTNO=30; SAL=1250.00; JOB=SALESMAN\n" + + "DEPTNO=30; SAL=2850.00; JOB=MANAGER\n" + + "DEPTNO=10; SAL=2450.00; JOB=MANAGER\n" + + "DEPTNO=20; SAL=3000.00; JOB=ANALYST\n" + + "DEPTNO=10; SAL=5000.00; JOB=PRESIDENT\n" + + "DEPTNO=30; SAL=1500.00; JOB=SALESMAN\n" + + "DEPTNO=20; SAL=1100.00; JOB=CLERK\n" + + "DEPTNO=30; SAL=950.00; JOB=CLERK\n" + + "DEPTNO=20; SAL=3000.00; JOB=ANALYST\n" + + "DEPTNO=10; SAL=1300.00; JOB=CLERK\n" + + "DEPTNO=null; SAL=29025.00; JOB=GrandTota\n"; + verifyResult(root, expectedResult); + + String expectedSparkSql = + "SELECT `DEPTNO`, `SAL`, `JOB`\n" + + "FROM `scott`.`EMP`\n" + + "UNION ALL\n" + + "SELECT CAST(NULL AS TINYINT) `DEPTNO`, SUM(`SAL`) `SAL`, 'GrandTota' `JOB`\n" + + "FROM `scott`.`EMP`"; + + verifyPPLToSparkSQL(root, expectedSparkSql); + } + + @Test + public void testAddColTotalsAllFieldsWithLabel() throws IOException { + String ppl = "source=EMP | addcoltotals label='GrandTotal' " + " labelfield='JOB' "; + RelNode root = getRelNode(ppl); + String expectedLogical = + "LogicalUnion(all=[true])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n" + + " LogicalProject(EMPNO=[$0], ENAME=[null:VARCHAR(10)], JOB=['GrandTota':VARCHAR(9)]," + + " MGR=[$1], HIREDATE=[null:DATE], SAL=[$2], COMM=[$3], DEPTNO=[$4])\n" + + " LogicalAggregate(group=[{}], EMPNO=[SUM($0)], MGR=[SUM($3)], SAL=[SUM($5)]," + + " COMM=[SUM($6)], DEPTNO=[SUM($7)])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n"; + + verifyLogical(root, expectedLogical); + String expectedResult = + "EMPNO=7369; ENAME=SMITH; JOB=CLERK; MGR=7902; HIREDATE=1980-12-17; SAL=800.00; COMM=null;" + + " DEPTNO=20\n" + + "EMPNO=7499; ENAME=ALLEN; JOB=SALESMAN; MGR=7698; HIREDATE=1981-02-20; SAL=1600.00;" + + " COMM=300.00; DEPTNO=30\n" + + "EMPNO=7521; ENAME=WARD; JOB=SALESMAN; MGR=7698; HIREDATE=1981-02-22; SAL=1250.00;" + + " COMM=500.00; DEPTNO=30\n" + + "EMPNO=7566; ENAME=JONES; JOB=MANAGER; MGR=7839; HIREDATE=1981-02-04; SAL=2975.00;" + + " COMM=null; DEPTNO=20\n" + + "EMPNO=7654; ENAME=MARTIN; JOB=SALESMAN; MGR=7698; HIREDATE=1981-09-28; SAL=1250.00;" + + " COMM=1400.00; DEPTNO=30\n" + + "EMPNO=7698; ENAME=BLAKE; JOB=MANAGER; MGR=7839; HIREDATE=1981-01-05; SAL=2850.00;" + + " COMM=null; DEPTNO=30\n" + + "EMPNO=7782; ENAME=CLARK; JOB=MANAGER; MGR=7839; HIREDATE=1981-06-09; SAL=2450.00;" + + " COMM=null; DEPTNO=10\n" + + "EMPNO=7788; ENAME=SCOTT; JOB=ANALYST; MGR=7566; HIREDATE=1987-04-19; SAL=3000.00;" + + " COMM=null; DEPTNO=20\n" + + "EMPNO=7839; ENAME=KING; JOB=PRESIDENT; MGR=null; HIREDATE=1981-11-17; SAL=5000.00;" + + " COMM=null; DEPTNO=10\n" + + "EMPNO=7844; ENAME=TURNER; JOB=SALESMAN; MGR=7698; HIREDATE=1981-09-08; SAL=1500.00;" + + " COMM=0.00; DEPTNO=30\n" + + "EMPNO=7876; ENAME=ADAMS; JOB=CLERK; MGR=7788; HIREDATE=1987-05-23; SAL=1100.00;" + + " COMM=null; DEPTNO=20\n" + + "EMPNO=7900; ENAME=JAMES; JOB=CLERK; MGR=7698; HIREDATE=1981-12-03; SAL=950.00;" + + " COMM=null; DEPTNO=30\n" + + "EMPNO=7902; ENAME=FORD; JOB=ANALYST; MGR=7566; HIREDATE=1981-12-03; SAL=3000.00;" + + " COMM=null; DEPTNO=20\n" + + "EMPNO=7934; ENAME=MILLER; JOB=CLERK; MGR=7782; HIREDATE=1982-01-23; SAL=1300.00;" + + " COMM=null; DEPTNO=10\n" + + "EMPNO=108172; ENAME=null; JOB=GrandTota; MGR=100611; HIREDATE=null; SAL=29025.00;" + + " COMM=2200.00; DEPTNO=310\n"; + + verifyResult(root, expectedResult); + + String expectedSparkSql = + "SELECT *\n" + + "FROM `scott`.`EMP`\n" + + "UNION ALL\n" + + "SELECT SUM(`EMPNO`) `EMPNO`, CAST(NULL AS STRING) `ENAME`, 'GrandTota' `JOB`," + + " SUM(`MGR`) `MGR`, CAST(NULL AS DATE) `HIREDATE`, SUM(`SAL`) `SAL`, SUM(`COMM`)" + + " `COMM`, SUM(`DEPTNO`) `DEPTNO`\n" + + "FROM `scott`.`EMP`"; + + verifyPPLToSparkSQL(root, expectedSparkSql); + } +} diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLAddTotalsTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLAddTotalsTest.java new file mode 100644 index 00000000000..a0454f0917c --- /dev/null +++ b/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLAddTotalsTest.java @@ -0,0 +1,472 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ppl.calcite; + +import java.io.IOException; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.test.CalciteAssert; +import org.junit.Test; + +public class CalcitePPLAddTotalsTest extends CalcitePPLAbstractTest { + + public CalcitePPLAddTotalsTest() { + super(CalciteAssert.SchemaSpec.SCOTT_WITH_TEMPORAL); + } + + @Test + public void testAddTotals() throws IOException { + String ppl = "source=EMP | fields DEPTNO, SAL, JOB | addtotals SAL "; + RelNode root = getRelNode(ppl); + String expectedLogical = + "LogicalProject(DEPTNO=[$7], SAL=[$5], JOB=[$2], Total=[$5])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n"; + verifyLogical(root, expectedLogical); + String expectedResult = + "DEPTNO=20; SAL=800.00; JOB=CLERK; Total=800.00\n" + + "DEPTNO=30; SAL=1600.00; JOB=SALESMAN; Total=1600.00\n" + + "DEPTNO=30; SAL=1250.00; JOB=SALESMAN; Total=1250.00\n" + + "DEPTNO=20; SAL=2975.00; JOB=MANAGER; Total=2975.00\n" + + "DEPTNO=30; SAL=1250.00; JOB=SALESMAN; Total=1250.00\n" + + "DEPTNO=30; SAL=2850.00; JOB=MANAGER; Total=2850.00\n" + + "DEPTNO=10; SAL=2450.00; JOB=MANAGER; Total=2450.00\n" + + "DEPTNO=20; SAL=3000.00; JOB=ANALYST; Total=3000.00\n" + + "DEPTNO=10; SAL=5000.00; JOB=PRESIDENT; Total=5000.00\n" + + "DEPTNO=30; SAL=1500.00; JOB=SALESMAN; Total=1500.00\n" + + "DEPTNO=20; SAL=1100.00; JOB=CLERK; Total=1100.00\n" + + "DEPTNO=30; SAL=950.00; JOB=CLERK; Total=950.00\n" + + "DEPTNO=20; SAL=3000.00; JOB=ANALYST; Total=3000.00\n" + + "DEPTNO=10; SAL=1300.00; JOB=CLERK; Total=1300.00\n"; + + verifyResult(root, expectedResult); + + String expectedSparkSql = "SELECT `DEPTNO`, `SAL`, `JOB`, `SAL` `Total`\nFROM `scott`.`EMP`"; + verifyPPLToSparkSQL(root, expectedSparkSql); + } + + @Test + public void testAddTotalsAllFields() throws IOException { + String ppl = "source=EMP | fields DEPTNO, SAL, JOB | addtotals "; + RelNode root = getRelNode(ppl); + String expectedLogical = + "LogicalProject(DEPTNO=[$7], SAL=[$5], JOB=[$2], Total=[+($7, $5)])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n"; + verifyLogical(root, expectedLogical); + String expectedResult = + "DEPTNO=20; SAL=800.00; JOB=CLERK; Total=820.00\n" + + "DEPTNO=30; SAL=1600.00; JOB=SALESMAN; Total=1630.00\n" + + "DEPTNO=30; SAL=1250.00; JOB=SALESMAN; Total=1280.00\n" + + "DEPTNO=20; SAL=2975.00; JOB=MANAGER; Total=2995.00\n" + + "DEPTNO=30; SAL=1250.00; JOB=SALESMAN; Total=1280.00\n" + + "DEPTNO=30; SAL=2850.00; JOB=MANAGER; Total=2880.00\n" + + "DEPTNO=10; SAL=2450.00; JOB=MANAGER; Total=2460.00\n" + + "DEPTNO=20; SAL=3000.00; JOB=ANALYST; Total=3020.00\n" + + "DEPTNO=10; SAL=5000.00; JOB=PRESIDENT; Total=5010.00\n" + + "DEPTNO=30; SAL=1500.00; JOB=SALESMAN; Total=1530.00\n" + + "DEPTNO=20; SAL=1100.00; JOB=CLERK; Total=1120.00\n" + + "DEPTNO=30; SAL=950.00; JOB=CLERK; Total=980.00\n" + + "DEPTNO=20; SAL=3000.00; JOB=ANALYST; Total=3020.00\n" + + "DEPTNO=10; SAL=1300.00; JOB=CLERK; Total=1310.00\n"; + + verifyResult(root, expectedResult); + + String expectedSparkSql = + "SELECT `DEPTNO`, `SAL`, `JOB`, `DEPTNO` + `SAL` `Total`\nFROM `scott`.`EMP`"; + verifyPPLToSparkSQL(root, expectedSparkSql); + } + + @Test + public void testAddTotalsMultiFields() throws IOException { + String ppl = "source=EMP | fields DEPTNO, SAL, JOB | addtotals DEPTNO SAL "; + RelNode root = getRelNode(ppl); + String expectedLogical = + "LogicalProject(DEPTNO=[$7], SAL=[$5], JOB=[$2], Total=[+($7, $5)])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n"; + verifyLogical(root, expectedLogical); + String expectedResult = + "DEPTNO=20; SAL=800.00; JOB=CLERK; Total=820.00\n" + + "DEPTNO=30; SAL=1600.00; JOB=SALESMAN; Total=1630.00\n" + + "DEPTNO=30; SAL=1250.00; JOB=SALESMAN; Total=1280.00\n" + + "DEPTNO=20; SAL=2975.00; JOB=MANAGER; Total=2995.00\n" + + "DEPTNO=30; SAL=1250.00; JOB=SALESMAN; Total=1280.00\n" + + "DEPTNO=30; SAL=2850.00; JOB=MANAGER; Total=2880.00\n" + + "DEPTNO=10; SAL=2450.00; JOB=MANAGER; Total=2460.00\n" + + "DEPTNO=20; SAL=3000.00; JOB=ANALYST; Total=3020.00\n" + + "DEPTNO=10; SAL=5000.00; JOB=PRESIDENT; Total=5010.00\n" + + "DEPTNO=30; SAL=1500.00; JOB=SALESMAN; Total=1530.00\n" + + "DEPTNO=20; SAL=1100.00; JOB=CLERK; Total=1120.00\n" + + "DEPTNO=30; SAL=950.00; JOB=CLERK; Total=980.00\n" + + "DEPTNO=20; SAL=3000.00; JOB=ANALYST; Total=3020.00\n" + + "DEPTNO=10; SAL=1300.00; JOB=CLERK; Total=1310.00\n"; + + verifyResult(root, expectedResult); + + String expectedSparkSql = + "SELECT `DEPTNO`, `SAL`, `JOB`, `DEPTNO` + `SAL` `Total`\nFROM `scott`.`EMP`"; + verifyPPLToSparkSQL(root, expectedSparkSql); + } + + @Test + public void testAddTotalsWithFieldname() throws IOException { + String ppl = "source=EMP | fields DEPTNO, SAL, JOB | addtotals SAL fieldname='CustomSum' "; + RelNode root = getRelNode(ppl); + String expectedLogical = + "LogicalProject(DEPTNO=[$7], SAL=[$5], JOB=[$2], CustomSum=[$5])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n"; + verifyLogical(root, expectedLogical); + String expectedResult = + "DEPTNO=20; SAL=800.00; JOB=CLERK; CustomSum=800.00\n" + + "DEPTNO=30; SAL=1600.00; JOB=SALESMAN; CustomSum=1600.00\n" + + "DEPTNO=30; SAL=1250.00; JOB=SALESMAN; CustomSum=1250.00\n" + + "DEPTNO=20; SAL=2975.00; JOB=MANAGER; CustomSum=2975.00\n" + + "DEPTNO=30; SAL=1250.00; JOB=SALESMAN; CustomSum=1250.00\n" + + "DEPTNO=30; SAL=2850.00; JOB=MANAGER; CustomSum=2850.00\n" + + "DEPTNO=10; SAL=2450.00; JOB=MANAGER; CustomSum=2450.00\n" + + "DEPTNO=20; SAL=3000.00; JOB=ANALYST; CustomSum=3000.00\n" + + "DEPTNO=10; SAL=5000.00; JOB=PRESIDENT; CustomSum=5000.00\n" + + "DEPTNO=30; SAL=1500.00; JOB=SALESMAN; CustomSum=1500.00\n" + + "DEPTNO=20; SAL=1100.00; JOB=CLERK; CustomSum=1100.00\n" + + "DEPTNO=30; SAL=950.00; JOB=CLERK; CustomSum=950.00\n" + + "DEPTNO=20; SAL=3000.00; JOB=ANALYST; CustomSum=3000.00\n" + + "DEPTNO=10; SAL=1300.00; JOB=CLERK; CustomSum=1300.00\n"; + + verifyResult(root, expectedResult); + + String expectedSparkSql = + "SELECT `DEPTNO`, `SAL`, `JOB`, `SAL` `CustomSum`\nFROM `scott`.`EMP`"; + verifyPPLToSparkSQL(root, expectedSparkSql); + } + + @Test + public void testAddTotalsWithFieldnameRowOptionTrue() throws IOException { + String ppl = + "source=EMP | fields DEPTNO, SAL, JOB | addtotals SAL fieldname='CustomSum' row=true "; + RelNode root = getRelNode(ppl); + String expectedLogical = + "LogicalProject(DEPTNO=[$7], SAL=[$5], JOB=[$2], CustomSum=[$5])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n"; + verifyLogical(root, expectedLogical); + String expectedResult = + "DEPTNO=20; SAL=800.00; JOB=CLERK; CustomSum=800.00\n" + + "DEPTNO=30; SAL=1600.00; JOB=SALESMAN; CustomSum=1600.00\n" + + "DEPTNO=30; SAL=1250.00; JOB=SALESMAN; CustomSum=1250.00\n" + + "DEPTNO=20; SAL=2975.00; JOB=MANAGER; CustomSum=2975.00\n" + + "DEPTNO=30; SAL=1250.00; JOB=SALESMAN; CustomSum=1250.00\n" + + "DEPTNO=30; SAL=2850.00; JOB=MANAGER; CustomSum=2850.00\n" + + "DEPTNO=10; SAL=2450.00; JOB=MANAGER; CustomSum=2450.00\n" + + "DEPTNO=20; SAL=3000.00; JOB=ANALYST; CustomSum=3000.00\n" + + "DEPTNO=10; SAL=5000.00; JOB=PRESIDENT; CustomSum=5000.00\n" + + "DEPTNO=30; SAL=1500.00; JOB=SALESMAN; CustomSum=1500.00\n" + + "DEPTNO=20; SAL=1100.00; JOB=CLERK; CustomSum=1100.00\n" + + "DEPTNO=30; SAL=950.00; JOB=CLERK; CustomSum=950.00\n" + + "DEPTNO=20; SAL=3000.00; JOB=ANALYST; CustomSum=3000.00\n" + + "DEPTNO=10; SAL=1300.00; JOB=CLERK; CustomSum=1300.00\n"; + + verifyResult(root, expectedResult); + + String expectedSparkSql = + "SELECT `DEPTNO`, `SAL`, `JOB`, `SAL` `CustomSum`\nFROM `scott`.`EMP`"; + verifyPPLToSparkSQL(root, expectedSparkSql); + } + + @Test + public void testAddTotalsWithFieldnameRowOptionFalse() throws IOException { + String ppl = + "source=EMP | fields DEPTNO, SAL, JOB | addtotals SAL fieldname='CustomSum' row=false "; + RelNode root = getRelNode(ppl); + String expectedLogical = + "LogicalProject(DEPTNO=[$7], SAL=[$5], JOB=[$2])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n"; + verifyLogical(root, expectedLogical); + String expectedResult = + "DEPTNO=20; SAL=800.00; JOB=CLERK\n" + + "DEPTNO=30; SAL=1600.00; JOB=SALESMAN\n" + + "DEPTNO=30; SAL=1250.00; JOB=SALESMAN\n" + + "DEPTNO=20; SAL=2975.00; JOB=MANAGER\n" + + "DEPTNO=30; SAL=1250.00; JOB=SALESMAN\n" + + "DEPTNO=30; SAL=2850.00; JOB=MANAGER\n" + + "DEPTNO=10; SAL=2450.00; JOB=MANAGER\n" + + "DEPTNO=20; SAL=3000.00; JOB=ANALYST\n" + + "DEPTNO=10; SAL=5000.00; JOB=PRESIDENT\n" + + "DEPTNO=30; SAL=1500.00; JOB=SALESMAN\n" + + "DEPTNO=20; SAL=1100.00; JOB=CLERK\n" + + "DEPTNO=30; SAL=950.00; JOB=CLERK\n" + + "DEPTNO=20; SAL=3000.00; JOB=ANALYST\n" + + "DEPTNO=10; SAL=1300.00; JOB=CLERK\n"; + + verifyResult(root, expectedResult); + + String expectedSparkSql = "SELECT `DEPTNO`, `SAL`, `JOB`\nFROM `scott`.`EMP`"; + verifyPPLToSparkSQL(root, expectedSparkSql); + } + + @Test + public void testAddTotalsWithColTrueNoSummaryLabel() throws IOException { + String ppl = "source=EMP | fields DEPTNO, SAL, JOB | addtotals SAL col=true"; + RelNode root = getRelNode(ppl); + String expectedLogical = + "LogicalUnion(all=[true])\n" + + " LogicalProject(DEPTNO=[$7], SAL=[$5], JOB=[$2], Total=[$5])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n" + + " LogicalProject(DEPTNO=[null:TINYINT], SAL=[$0], JOB=[null:VARCHAR(9)]," + + " Total=[null:DECIMAL(7, 2)])\n" + + " LogicalAggregate(group=[{}], SAL=[SUM($0)])\n" + + " LogicalProject(SAL=[$5])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n"; + + verifyLogical(root, expectedLogical); + String expectedResult = + "DEPTNO=20; SAL=800.00; JOB=CLERK; Total=800.00\n" + + "DEPTNO=30; SAL=1600.00; JOB=SALESMAN; Total=1600.00\n" + + "DEPTNO=30; SAL=1250.00; JOB=SALESMAN; Total=1250.00\n" + + "DEPTNO=20; SAL=2975.00; JOB=MANAGER; Total=2975.00\n" + + "DEPTNO=30; SAL=1250.00; JOB=SALESMAN; Total=1250.00\n" + + "DEPTNO=30; SAL=2850.00; JOB=MANAGER; Total=2850.00\n" + + "DEPTNO=10; SAL=2450.00; JOB=MANAGER; Total=2450.00\n" + + "DEPTNO=20; SAL=3000.00; JOB=ANALYST; Total=3000.00\n" + + "DEPTNO=10; SAL=5000.00; JOB=PRESIDENT; Total=5000.00\n" + + "DEPTNO=30; SAL=1500.00; JOB=SALESMAN; Total=1500.00\n" + + "DEPTNO=20; SAL=1100.00; JOB=CLERK; Total=1100.00\n" + + "DEPTNO=30; SAL=950.00; JOB=CLERK; Total=950.00\n" + + "DEPTNO=20; SAL=3000.00; JOB=ANALYST; Total=3000.00\n" + + "DEPTNO=10; SAL=1300.00; JOB=CLERK; Total=1300.00\n" + + "DEPTNO=null; SAL=29025.00; JOB=null; Total=null\n"; + verifyResult(root, expectedResult); + + String expectedSparkSql = + "SELECT `DEPTNO`, `SAL`, `JOB`, `SAL` `Total`\n" + + "FROM `scott`.`EMP`\n" + + "UNION ALL\n" + + "SELECT CAST(NULL AS TINYINT) `DEPTNO`, SUM(`SAL`) `SAL`, CAST(NULL AS STRING) `JOB`," + + " CAST(NULL AS DECIMAL(7, 2)) `Total`\n" + + "FROM `scott`.`EMP`"; + verifyPPLToSparkSQL(root, expectedSparkSql); + } + + @Test + public void testAddTotalsWithColTrueRowFalseNoSummaryLabel() throws IOException { + String ppl = "source=EMP | fields DEPTNO, SAL, JOB | addtotals SAL col=true row=false"; + RelNode root = getRelNode(ppl); + String expectedLogical = + "LogicalUnion(all=[true])\n" + + " LogicalProject(DEPTNO=[$7], SAL=[$5], JOB=[$2])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n" + + " LogicalProject(DEPTNO=[null:TINYINT], SAL=[$0], JOB=[null:VARCHAR(9)])\n" + + " LogicalAggregate(group=[{}], SAL=[SUM($0)])\n" + + " LogicalProject(SAL=[$5])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n"; + verifyLogical(root, expectedLogical); + String expectedResult = + "DEPTNO=20; SAL=800.00; JOB=CLERK\n" + + "DEPTNO=30; SAL=1600.00; JOB=SALESMAN\n" + + "DEPTNO=30; SAL=1250.00; JOB=SALESMAN\n" + + "DEPTNO=20; SAL=2975.00; JOB=MANAGER\n" + + "DEPTNO=30; SAL=1250.00; JOB=SALESMAN\n" + + "DEPTNO=30; SAL=2850.00; JOB=MANAGER\n" + + "DEPTNO=10; SAL=2450.00; JOB=MANAGER\n" + + "DEPTNO=20; SAL=3000.00; JOB=ANALYST\n" + + "DEPTNO=10; SAL=5000.00; JOB=PRESIDENT\n" + + "DEPTNO=30; SAL=1500.00; JOB=SALESMAN\n" + + "DEPTNO=20; SAL=1100.00; JOB=CLERK\n" + + "DEPTNO=30; SAL=950.00; JOB=CLERK\n" + + "DEPTNO=20; SAL=3000.00; JOB=ANALYST\n" + + "DEPTNO=10; SAL=1300.00; JOB=CLERK\n" + + "DEPTNO=null; SAL=29025.00; JOB=null\n"; + + verifyResult(root, expectedResult); + + String expectedSparkSql = + "SELECT `DEPTNO`, `SAL`, `JOB`\n" + + "FROM `scott`.`EMP`\n" + + "UNION ALL\n" + + "SELECT CAST(NULL AS TINYINT) `DEPTNO`, SUM(`SAL`) `SAL`, CAST(NULL AS STRING)" + + " `JOB`\n" + + "FROM `scott`.`EMP`"; + verifyPPLToSparkSQL(root, expectedSparkSql); + } + + @Test + public void testAddTotalsWithAllOptionsIncludingDefaultFieldname() throws IOException { + String ppl = + "source=EMP | fields DEPTNO, SAL, JOB | addtotals SAL label='ColTotal'" + + " labelfield='Total' col=true"; + RelNode root = getRelNode(ppl); + String expectedLogical = + "LogicalUnion(all=[true])\n" + + " LogicalProject(DEPTNO=[$7], SAL=[$5], JOB=[$2], Total=[$5])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n" + + " LogicalProject(DEPTNO=[null:TINYINT], SAL=[$0], JOB=[null:VARCHAR(9)]," + + " Total=[null:DECIMAL(7, 2)])\n" + + " LogicalAggregate(group=[{}], SAL=[SUM($0)])\n" + + " LogicalProject(SAL=[$5])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n"; + verifyLogical(root, expectedLogical); + String expectedResult = + "DEPTNO=20; SAL=800.00; JOB=CLERK; Total=800.00\n" + + "DEPTNO=30; SAL=1600.00; JOB=SALESMAN; Total=1600.00\n" + + "DEPTNO=30; SAL=1250.00; JOB=SALESMAN; Total=1250.00\n" + + "DEPTNO=20; SAL=2975.00; JOB=MANAGER; Total=2975.00\n" + + "DEPTNO=30; SAL=1250.00; JOB=SALESMAN; Total=1250.00\n" + + "DEPTNO=30; SAL=2850.00; JOB=MANAGER; Total=2850.00\n" + + "DEPTNO=10; SAL=2450.00; JOB=MANAGER; Total=2450.00\n" + + "DEPTNO=20; SAL=3000.00; JOB=ANALYST; Total=3000.00\n" + + "DEPTNO=10; SAL=5000.00; JOB=PRESIDENT; Total=5000.00\n" + + "DEPTNO=30; SAL=1500.00; JOB=SALESMAN; Total=1500.00\n" + + "DEPTNO=20; SAL=1100.00; JOB=CLERK; Total=1100.00\n" + + "DEPTNO=30; SAL=950.00; JOB=CLERK; Total=950.00\n" + + "DEPTNO=20; SAL=3000.00; JOB=ANALYST; Total=3000.00\n" + + "DEPTNO=10; SAL=1300.00; JOB=CLERK; Total=1300.00\n" + + "DEPTNO=null; SAL=29025.00; JOB=null; Total=null\n"; + // by default row=true , new field added as 'Total' and labelfield='Total' will have conflict + // and 'ColTotal' will not be set in Total column as it will be number type being row=true + verifyResult(root, expectedResult); + + String expectedSparkSql = + "SELECT `DEPTNO`, `SAL`, `JOB`, `SAL` `Total`\n" + + "FROM `scott`.`EMP`\n" + + "UNION ALL\n" + + "SELECT CAST(NULL AS TINYINT) `DEPTNO`, SUM(`SAL`) `SAL`, CAST(NULL AS STRING) `JOB`," + + " CAST(NULL AS DECIMAL(7, 2)) `Total`\n" + + "FROM `scott`.`EMP`"; + + verifyPPLToSparkSQL(root, expectedSparkSql); + } + + @Test + public void testAddTotalsWithAllOptionsIncludingFieldname() throws IOException { + String ppl = + "source=EMP | fields DEPTNO, SAL, JOB | addtotals SAL label='ColTotal'" + + " fieldname='CustomSum' labelfield='all_emp_total' row=true col=true"; + RelNode root = getRelNode(ppl); + String expectedLogical = + "LogicalUnion(all=[true])\n" + + " LogicalProject(DEPTNO=[$7], SAL=[$5], JOB=[$2], CustomSum=[$5]," + + " all_emp_total=[null:NULL])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n" + + " LogicalProject(DEPTNO=[null:NULL], SAL=[$0], JOB=[null:NULL]," + + " CustomSum=[null:NULL], all_emp_total=['ColTotal'])\n" + + " LogicalAggregate(group=[{}], SAL=[SUM($0)])\n" + + " LogicalProject(SAL=[$5])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n"; + // verifyLogical(root, expectedLogical); + String expectedResult = + "DEPTNO=20; SAL=800.00; JOB=CLERK; CustomSum=800.00; all_emp_total=null\n" + + "DEPTNO=30; SAL=1600.00; JOB=SALESMAN; CustomSum=1600.00; all_emp_total=null\n" + + "DEPTNO=30; SAL=1250.00; JOB=SALESMAN; CustomSum=1250.00; all_emp_total=null\n" + + "DEPTNO=20; SAL=2975.00; JOB=MANAGER; CustomSum=2975.00; all_emp_total=null\n" + + "DEPTNO=30; SAL=1250.00; JOB=SALESMAN; CustomSum=1250.00; all_emp_total=null\n" + + "DEPTNO=30; SAL=2850.00; JOB=MANAGER; CustomSum=2850.00; all_emp_total=null\n" + + "DEPTNO=10; SAL=2450.00; JOB=MANAGER; CustomSum=2450.00; all_emp_total=null\n" + + "DEPTNO=20; SAL=3000.00; JOB=ANALYST; CustomSum=3000.00; all_emp_total=null\n" + + "DEPTNO=10; SAL=5000.00; JOB=PRESIDENT; CustomSum=5000.00; all_emp_total=null\n" + + "DEPTNO=30; SAL=1500.00; JOB=SALESMAN; CustomSum=1500.00; all_emp_total=null\n" + + "DEPTNO=20; SAL=1100.00; JOB=CLERK; CustomSum=1100.00; all_emp_total=null\n" + + "DEPTNO=30; SAL=950.00; JOB=CLERK; CustomSum=950.00; all_emp_total=null\n" + + "DEPTNO=20; SAL=3000.00; JOB=ANALYST; CustomSum=3000.00; all_emp_total=null\n" + + "DEPTNO=10; SAL=1300.00; JOB=CLERK; CustomSum=1300.00; all_emp_total=null\n" + + "DEPTNO=null; SAL=29025.00; JOB=null; CustomSum=null; all_emp_total=ColTotal\n"; + verifyResult(root, expectedResult); + + String expectedSparkSql = + "SELECT `DEPTNO`, `SAL`, `JOB`, `SAL` `CustomSum`, CAST(NULL AS STRING) `all_emp_total`\n" + + "FROM `scott`.`EMP`\n" + + "UNION ALL\n" + + "SELECT CAST(NULL AS TINYINT) `DEPTNO`, SUM(`SAL`) `SAL`, CAST(NULL AS STRING) `JOB`," + + " CAST(NULL AS DECIMAL(7, 2)) `CustomSum`, 'ColTotal' `all_emp_total`\n" + + "FROM `scott`.`EMP`"; + + verifyPPLToSparkSQL(root, expectedSparkSql); + } + + @Test + public void testAddTotalsMatchingLabelFieldWithExisting() throws IOException { + String ppl = + "source=EMP | fields DEPTNO, SAL, JOB | addtotals SAL DEPTNO col=true label='ColTotal'" + + " labelfield='JOB' "; + // default is row=true for addtotals + RelNode root = getRelNode(ppl); + String expectedLogical = + "LogicalUnion(all=[true])\n" + + " LogicalProject(DEPTNO=[$7], SAL=[$5], JOB=[$2], Total=[+($7, $5)])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n" + + " LogicalProject(DEPTNO=[$0], SAL=[$1], JOB=['ColTotal':VARCHAR(9)]," + + " Total=[null:DECIMAL(8, 2)])\n" + + " LogicalAggregate(group=[{}], DEPTNO=[SUM($0)], SAL=[SUM($1)])\n" + + " LogicalProject(DEPTNO=[$7], SAL=[$5])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n"; + verifyLogical(root, expectedLogical); + String expectedResult = + "DEPTNO=20; SAL=800.00; JOB=CLERK; Total=820.00\n" + + "DEPTNO=30; SAL=1600.00; JOB=SALESMAN; Total=1630.00\n" + + "DEPTNO=30; SAL=1250.00; JOB=SALESMAN; Total=1280.00\n" + + "DEPTNO=20; SAL=2975.00; JOB=MANAGER; Total=2995.00\n" + + "DEPTNO=30; SAL=1250.00; JOB=SALESMAN; Total=1280.00\n" + + "DEPTNO=30; SAL=2850.00; JOB=MANAGER; Total=2880.00\n" + + "DEPTNO=10; SAL=2450.00; JOB=MANAGER; Total=2460.00\n" + + "DEPTNO=20; SAL=3000.00; JOB=ANALYST; Total=3020.00\n" + + "DEPTNO=10; SAL=5000.00; JOB=PRESIDENT; Total=5010.00\n" + + "DEPTNO=30; SAL=1500.00; JOB=SALESMAN; Total=1530.00\n" + + "DEPTNO=20; SAL=1100.00; JOB=CLERK; Total=1120.00\n" + + "DEPTNO=30; SAL=950.00; JOB=CLERK; Total=980.00\n" + + "DEPTNO=20; SAL=3000.00; JOB=ANALYST; Total=3020.00\n" + + "DEPTNO=10; SAL=1300.00; JOB=CLERK; Total=1310.00\n" + + "DEPTNO=310; SAL=29025.00; JOB=ColTotal; Total=null\n"; + verifyResult(root, expectedResult); + + String expectedSparkSql = + "SELECT `DEPTNO`, `SAL`, `JOB`, `DEPTNO` + `SAL` `Total`\n" + + "FROM `scott`.`EMP`\n" + + "UNION ALL\n" + + "SELECT SUM(`DEPTNO`) `DEPTNO`, SUM(`SAL`) `SAL`, 'ColTotal' `JOB`, CAST(NULL AS" + + " DECIMAL(8, 2)) `Total`\n" + + "FROM `scott`.`EMP`"; + verifyPPLToSparkSQL(root, expectedSparkSql); + } + + @Test + public void testAddTotalsMatchingLabelFieldWithExistingChangedOrder() throws IOException { + String ppl = + "source=EMP | fields DEPTNO, SAL, JOB | addtotals col=true label='ColTotal'" + + " labelfield='JOB' SAL DEPTNO "; + // default is row=true for addtotals + RelNode root = getRelNode(ppl); + String expectedLogical = + "LogicalUnion(all=[true])\n" + + " LogicalProject(DEPTNO=[$7], SAL=[$5], JOB=[$2], Total=[+($7, $5)])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n" + + " LogicalProject(DEPTNO=[$0], SAL=[$1], JOB=['ColTotal':VARCHAR(9)]," + + " Total=[null:DECIMAL(8, 2)])\n" + + " LogicalAggregate(group=[{}], DEPTNO=[SUM($0)], SAL=[SUM($1)])\n" + + " LogicalProject(DEPTNO=[$7], SAL=[$5])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n"; + verifyLogical(root, expectedLogical); + String expectedResult = + "DEPTNO=20; SAL=800.00; JOB=CLERK; Total=820.00\n" + + "DEPTNO=30; SAL=1600.00; JOB=SALESMAN; Total=1630.00\n" + + "DEPTNO=30; SAL=1250.00; JOB=SALESMAN; Total=1280.00\n" + + "DEPTNO=20; SAL=2975.00; JOB=MANAGER; Total=2995.00\n" + + "DEPTNO=30; SAL=1250.00; JOB=SALESMAN; Total=1280.00\n" + + "DEPTNO=30; SAL=2850.00; JOB=MANAGER; Total=2880.00\n" + + "DEPTNO=10; SAL=2450.00; JOB=MANAGER; Total=2460.00\n" + + "DEPTNO=20; SAL=3000.00; JOB=ANALYST; Total=3020.00\n" + + "DEPTNO=10; SAL=5000.00; JOB=PRESIDENT; Total=5010.00\n" + + "DEPTNO=30; SAL=1500.00; JOB=SALESMAN; Total=1530.00\n" + + "DEPTNO=20; SAL=1100.00; JOB=CLERK; Total=1120.00\n" + + "DEPTNO=30; SAL=950.00; JOB=CLERK; Total=980.00\n" + + "DEPTNO=20; SAL=3000.00; JOB=ANALYST; Total=3020.00\n" + + "DEPTNO=10; SAL=1300.00; JOB=CLERK; Total=1310.00\n" + + "DEPTNO=310; SAL=29025.00; JOB=ColTotal; Total=null\n"; + verifyResult(root, expectedResult); + + String expectedSparkSql = + "SELECT `DEPTNO`, `SAL`, `JOB`, `DEPTNO` + `SAL` `Total`\n" + + "FROM `scott`.`EMP`\n" + + "UNION ALL\n" + + "SELECT SUM(`DEPTNO`) `DEPTNO`, SUM(`SAL`) `SAL`, 'ColTotal' `JOB`, CAST(NULL AS" + + " DECIMAL(8, 2)) `Total`\n" + + "FROM `scott`.`EMP`"; + verifyPPLToSparkSQL(root, expectedSparkSql); + } +} diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/utils/PPLQueryDataAnonymizerTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/utils/PPLQueryDataAnonymizerTest.java index ec166c81b7e..053ce5dabbb 100644 --- a/ppl/src/test/java/org/opensearch/sql/ppl/utils/PPLQueryDataAnonymizerTest.java +++ b/ppl/src/test/java/org/opensearch/sql/ppl/utils/PPLQueryDataAnonymizerTest.java @@ -502,6 +502,23 @@ public void testAppendcol() { anonymize("source=t | appendcol override=false [ where a = 1 ]")); } + @Test + public void testAddTotals() { + assertEquals( + "source=table | addtotals row=true col=true label=identifier labelfield=identifier" + + " fieldname=identifier", + anonymize( + "source=table | addtotals row=true col=true label='identifier' labelfield='identifier'" + + " fieldname='identifier'")); + } + + @Test + public void testAddColTotals() { + assertEquals( + "source=table | addcoltotals label=identifier labelfield=identifier", + anonymize("source=table | addcoltotals label='identifier' labelfield='identifier'")); + } + @Test public void testAppend() { assertEquals(