diff --git a/core/src/main/java/org/opensearch/sql/calcite/udf/udaf/ListAggFunction.java b/core/src/main/java/org/opensearch/sql/calcite/udf/udaf/ListAggFunction.java
new file mode 100644
index 00000000000..709df157e2d
--- /dev/null
+++ b/core/src/main/java/org/opensearch/sql/calcite/udf/udaf/ListAggFunction.java
@@ -0,0 +1,91 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.calcite.udf.udaf;
+
+import java.util.ArrayList;
+import java.util.List;
+import org.opensearch.sql.calcite.udf.UserDefinedAggFunction;
+
+/**
+ * List aggregation function that collects values into an array preserving duplicates.
+ *
+ * <p>Behavior:
+ *
+ * <ul>
+ *   <li>Null input values are skipped.
+ *   <li>Every other value is converted to a string with {@link String#valueOf(Object)}.
+ *   <li>Duplicates are kept.
+ *   <li>At most 100 values are collected per accumulator; further values are ignored.
+ * </ul>
+ *
+ * <p>Note: Similar to the TAKE function, LIST does not guarantee any specific order of values in
+ * the result array. The order may vary between executions and depends on the underlying query
+ * execution plan and optimizations.
+ */
+public class ListAggFunction implements UserDefinedAggFunction<ListAggFunction.ListAccumulator> {
+
+  /** Maximum number of values kept by a single accumulator. */
+  private static final int DEFAULT_LIMIT = 100;
+
+  /** Creates an empty accumulator. */
+  @Override
+  public ListAccumulator init() {
+    return new ListAccumulator();
+  }
+
+  /** Returns the accumulator's collected values as the aggregation result. */
+  @Override
+  public Object result(ListAccumulator accumulator) {
+    return accumulator.value();
+  }
+
+  /**
+   * Adds the first argument to the accumulator unless it is null or the limit has been reached.
+   *
+   * @param acc accumulator to update
+   * @param values call arguments; only {@code values[0]} is read
+   * @return the same accumulator instance
+   */
+  @Override
+  public ListAccumulator add(ListAccumulator acc, Object... values) {
+    // Nothing to collect when the call carries no arguments.
+    if (values == null || values.length == 0) {
+      return acc;
+    }
+
+    Object value = values[0];
+
+    // Skip nulls and stop collecting once DEFAULT_LIMIT values are held.
+    if (value != null && acc.size() < DEFAULT_LIMIT) {
+      acc.add(String.valueOf(value));
+    }
+
+    return acc;
+  }
+
+  /** Mutable accumulator backed by an {@link ArrayList} of the collected string values. */
+  public static class ListAccumulator implements Accumulator {
+    private final List<String> values;
+
+    public ListAccumulator() {
+      this.values = new ArrayList<>();
+    }
+
+    /** Returns the backing list itself (not a copy); {@code argList} is ignored. */
+    @Override
+    public Object value(Object... argList) {
+      return values;
+    }
+
+    public void add(String value) {
+      values.add(value);
+    }
+
+    public int size() {
+      return values.size();
+    }
+  }
+}
diff --git a/core/src/main/java/org/opensearch/sql/calcite/utils/PPLOperandTypes.java b/core/src/main/java/org/opensearch/sql/calcite/utils/PPLOperandTypes.java index e4b58c38662..faa51825edd 100644 --- a/core/src/main/java/org/opensearch/sql/calcite/utils/PPLOperandTypes.java +++ b/core/src/main/java/org/opensearch/sql/calcite/utils/PPLOperandTypes.java @@ -169,4 +169,31 @@ private PPLOperandTypes() {} SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER))); + + /** + * Operand type checker that accepts any scalar type. 
This includes numeric types, strings, + * booleans, datetime types, and special scalar types like IP and BINARY. Excludes complex types + * like arrays, structs, and maps. + */ + public static final UDFOperandMetadata ANY_SCALAR = + UDFOperandMetadata.wrapUDT( + java.util.List.of( + // Numeric types + java.util.List.of(org.opensearch.sql.data.type.ExprCoreType.BYTE), + java.util.List.of(org.opensearch.sql.data.type.ExprCoreType.SHORT), + java.util.List.of(org.opensearch.sql.data.type.ExprCoreType.INTEGER), + java.util.List.of(org.opensearch.sql.data.type.ExprCoreType.LONG), + java.util.List.of(org.opensearch.sql.data.type.ExprCoreType.FLOAT), + java.util.List.of(org.opensearch.sql.data.type.ExprCoreType.DOUBLE), + // String type + java.util.List.of(org.opensearch.sql.data.type.ExprCoreType.STRING), + // Boolean type + java.util.List.of(org.opensearch.sql.data.type.ExprCoreType.BOOLEAN), + // Temporal types + java.util.List.of(org.opensearch.sql.data.type.ExprCoreType.DATE), + java.util.List.of(org.opensearch.sql.data.type.ExprCoreType.TIME), + java.util.List.of(org.opensearch.sql.data.type.ExprCoreType.TIMESTAMP), + // Special scalar types + java.util.List.of(org.opensearch.sql.data.type.ExprCoreType.IP), + java.util.List.of(org.opensearch.sql.data.type.ExprCoreType.BINARY))); } diff --git a/core/src/main/java/org/opensearch/sql/calcite/utils/PPLReturnTypes.java b/core/src/main/java/org/opensearch/sql/calcite/utils/PPLReturnTypes.java index bb0ea0831c5..4a115b706b3 100644 --- a/core/src/main/java/org/opensearch/sql/calcite/utils/PPLReturnTypes.java +++ b/core/src/main/java/org/opensearch/sql/calcite/utils/PPLReturnTypes.java @@ -55,4 +55,12 @@ private PPLReturnTypes() {} RelDataType firstArgType = argTypes.get(0); return SqlTypeUtil.createArrayType(typeFactory, firstArgType, true); }; + public static final SqlReturnTypeInference STRING_ARRAY = + opBinding -> { + RelDataTypeFactory typeFactory = opBinding.getTypeFactory(); + // Always return array of strings since 
multivalue functions convert everything to strings + RelDataType stringType = + typeFactory.createSqlType(org.apache.calcite.sql.type.SqlTypeName.VARCHAR); + return SqlTypeUtil.createArrayType(typeFactory, stringType, true); + }; } diff --git a/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java b/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java index db7919e4de2..645ee5a2f1e 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java @@ -201,6 +201,9 @@ public enum BuiltinFunctionName { EARLIEST(FunctionName.of("earliest")), LATEST(FunctionName.of("latest")), DISTINCT_COUNT_APPROX(FunctionName.of("distinct_count_approx")), + + // Multivalue aggregation function + LIST(FunctionName.of("list")), // Not always an aggregation query NESTED(FunctionName.of("nested")), @@ -347,6 +350,7 @@ public enum BuiltinFunctionName { .put("earliest", BuiltinFunctionName.EARLIEST) .put("latest", BuiltinFunctionName.LATEST) .put("distinct_count_approx", BuiltinFunctionName.DISTINCT_COUNT_APPROX) + .put("list", BuiltinFunctionName.LIST) .put("pattern", BuiltinFunctionName.INTERNAL_PATTERN) .build(); diff --git a/core/src/main/java/org/opensearch/sql/expression/function/PPLBuiltinOperators.java b/core/src/main/java/org/opensearch/sql/expression/function/PPLBuiltinOperators.java index a4acf68ffef..e69e451d7f8 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/PPLBuiltinOperators.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/PPLBuiltinOperators.java @@ -29,6 +29,7 @@ import org.apache.calcite.sql.type.SqlTypeTransforms; import org.apache.calcite.sql.util.ReflectiveSqlOperatorTable; import org.apache.calcite.util.BuiltInMethod; +import org.opensearch.sql.calcite.udf.udaf.ListAggFunction; import org.opensearch.sql.calcite.udf.udaf.LogPatternAggFunction; import 
org.opensearch.sql.calcite.udf.udaf.NullableSqlAvgAggFunction; import org.opensearch.sql.calcite.udf.udaf.PercentileApproxFunction; @@ -432,6 +433,9 @@ public class PPLBuiltinOperators extends ReflectiveSqlOperatorTable { "pattern", ReturnTypes.explicit(UserDefinedFunctionUtils.nullablePatternAggList), null); + public static final SqlAggFunction LIST = + createUserDefinedAggFunction( + ListAggFunction.class, "LIST", PPLReturnTypes.STRING_ARRAY, PPLOperandTypes.ANY_SCALAR); /** * Returns the PPL specific operator table, creating it if necessary. diff --git a/core/src/main/java/org/opensearch/sql/expression/function/PPLFuncImpTable.java b/core/src/main/java/org/opensearch/sql/expression/function/PPLFuncImpTable.java index 1a5c4f1beca..beb70e41fd7 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/PPLFuncImpTable.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/PPLFuncImpTable.java @@ -105,6 +105,7 @@ import static org.opensearch.sql.expression.function.BuiltinFunctionName.LENGTH; import static org.opensearch.sql.expression.function.BuiltinFunctionName.LESS; import static org.opensearch.sql.expression.function.BuiltinFunctionName.LIKE; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.LIST; import static org.opensearch.sql.expression.function.BuiltinFunctionName.LN; import static org.opensearch.sql.expression.function.BuiltinFunctionName.LOCALTIME; import static org.opensearch.sql.expression.function.BuiltinFunctionName.LOCALTIMESTAMP; @@ -1088,6 +1089,7 @@ void populate() { registerOperator(STDDEV_POP, PPLBuiltinOperators.STDDEV_POP_NULLABLE); registerOperator(TAKE, PPLBuiltinOperators.TAKE); registerOperator(INTERNAL_PATTERN, PPLBuiltinOperators.INTERNAL_PATTERN); + registerOperator(LIST, PPLBuiltinOperators.LIST); register( AVG, diff --git a/docs/user/ppl/cmd/stats.rst b/docs/user/ppl/cmd/stats.rst index 7aaab608358..001164860f4 100644 --- a/docs/user/ppl/cmd/stats.rst +++ b/docs/user/ppl/cmd/stats.rst 
@@ -28,6 +28,8 @@ The following table dataSources the aggregation functions and also indicates how +----------+-------------+-------------+ | MIN | Ignore | Ignore | +----------+-------------+-------------+ +| LIST | Ignore | Ignore | ++----------+-------------+-------------+ Syntax @@ -406,6 +408,40 @@ Example with custom time field:: | inactive | users | +----------------------------+----------+ +LIST +---- + +Description +>>>>>>>>>>> + +Version: 3.3.0 (Calcite engine only) + +Usage: LIST(expr). Collects all values from the specified expression into an array. Values are converted to strings, nulls are filtered, and duplicates are preserved. +The function returns up to 100 values with no guaranteed ordering. + +* expr: The field expression to collect values from. +* This aggregation function doesn't support Array, Struct, Object field types. + +Example with string fields:: + + PPL> source=accounts | stats list(firstname); + fetched rows / total rows = 1/1 + +-------------------------------------+ + | list(firstname) | + |-------------------------------------| + | ["Amber","Hattie","Nanette","Dale"] | + +-------------------------------------+ + +Example with result field rename:: + + PPL> source=accounts | stats list(firstname) as names; + fetched rows / total rows = 1/1 + +-------------------------------------+ + | names | + |-------------------------------------| + | ["Amber","Hattie","Nanette","Dale"] | + +-------------------------------------+ + Example 1: Calculate the count of events ======================================== @@ -628,3 +664,18 @@ PPL query:: | 28 | 20 | F | | 36 | 30 | M | +-----+----------+--------+ + +Example 14: Collect all values in a field using LIST +===================================================== + +The example shows how to collect all firstname values, preserving duplicates; the order of the collected values is not guaranteed. 
+ +PPL query:: + + PPL> source=accounts | stats list(firstname); + fetched rows / total rows = 1/1 + +-------------------------------------+ + | list(firstname) | + |-------------------------------------| + | ["Amber","Hattie","Nanette","Dale"] | + +-------------------------------------+ diff --git a/integ-test/src/test/java/org/opensearch/sql/calcite/CalciteNoPushdownIT.java b/integ-test/src/test/java/org/opensearch/sql/calcite/CalciteNoPushdownIT.java index d4d9f61f42e..dc584a23a94 100644 --- a/integ-test/src/test/java/org/opensearch/sql/calcite/CalciteNoPushdownIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/calcite/CalciteNoPushdownIT.java @@ -46,6 +46,7 @@ CalciteLegacyAPICompatibilityIT.class, CalciteLikeQueryIT.class, CalciteMathematicalFunctionIT.class, + CalciteMultiValueStatsIT.class, CalciteNewAddedCommandsIT.class, CalciteNowLikeFunctionIT.class, CalciteObjectFieldOperateIT.class, diff --git a/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteExplainIT.java b/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteExplainIT.java index 9c1e1f22c1a..eacdd04615c 100644 --- a/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteExplainIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteExplainIT.java @@ -382,6 +382,15 @@ public void testExplainOnEarliestLatestWithCustomTimeField() throws IOException TEST_INDEX_LOGS))); } + @Test + public void testListAggregationExplain() throws IOException { + String expected = loadExpectedPlan("explain_list_aggregation.json"); + assertJsonEqualsIgnoreId( + expected, + explainQueryToString( + "source=opensearch-sql_test_index_account | stats list(age) as age_list")); + } + /** * Executes the PPL query and returns the result as a string with windows-style line breaks * replaced with Unix-style ones. 
diff --git a/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteMultiValueStatsIT.java b/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteMultiValueStatsIT.java new file mode 100644 index 00000000000..e8bebaf291c --- /dev/null +++ b/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteMultiValueStatsIT.java @@ -0,0 +1,277 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.calcite.remote; + +import static org.opensearch.sql.legacy.TestsConstants.TEST_INDEX_CALCS; +import static org.opensearch.sql.legacy.TestsConstants.TEST_INDEX_DATATYPE_NONNUMERIC; +import static org.opensearch.sql.legacy.TestsConstants.TEST_INDEX_DATATYPE_NUMERIC; +import static org.opensearch.sql.util.MatcherUtils.rows; +import static org.opensearch.sql.util.MatcherUtils.schema; +import static org.opensearch.sql.util.MatcherUtils.verifyDataRows; +import static org.opensearch.sql.util.MatcherUtils.verifySchema; + +import java.io.IOException; +import java.util.List; +import org.json.JSONObject; +import org.junit.jupiter.api.Test; +import org.opensearch.sql.ppl.PPLIntegTestCase; + +public class CalciteMultiValueStatsIT extends PPLIntegTestCase { + + @Override + public void init() throws Exception { + super.init(); + enableCalcite(); + loadIndex(Index.DATA_TYPE_NUMERIC); + loadIndex(Index.DATA_TYPE_NONNUMERIC); + loadIndex(Index.CALCS); + } + + // ==================== Positive Tests - All Supported Data Types ==================== + + @Test + public void testListFunctionWithBoolean() throws IOException { + JSONObject response = + executeQuery( + String.format( + "source=%s | stats list(boolean_value) as bool_list", + TEST_INDEX_DATATYPE_NONNUMERIC)); + verifySchema(response, schema("bool_list", "array")); + verifyDataRows(response, rows(List.of("true"))); + } + + @Test + public void testListFunctionWithByte() throws IOException { + JSONObject response = + executeQuery( + String.format( + 
"source=%s | stats list(byte_number) as byte_list", TEST_INDEX_DATATYPE_NUMERIC)); + verifySchema(response, schema("byte_list", "array")); + verifyDataRows(response, rows(List.of("4"))); + } + + @Test + public void testListFunctionWithShort() throws IOException { + JSONObject response = + executeQuery( + String.format( + "source=%s | stats list(short_number) as short_list", TEST_INDEX_DATATYPE_NUMERIC)); + verifySchema(response, schema("short_list", "array")); + verifyDataRows(response, rows(List.of("3"))); + } + + @Test + public void testListFunctionWithInteger() throws IOException { + JSONObject response = + executeQuery( + String.format( + "source=%s | stats list(integer_number) as int_list", TEST_INDEX_DATATYPE_NUMERIC)); + verifySchema(response, schema("int_list", "array")); + verifyDataRows(response, rows(List.of("2"))); + } + + @Test + public void testListFunctionWithLong() throws IOException { + JSONObject response = + executeQuery( + String.format( + "source=%s | stats list(long_number) as long_list", TEST_INDEX_DATATYPE_NUMERIC)); + verifySchema(response, schema("long_list", "array")); + verifyDataRows(response, rows(List.of("1"))); + } + + @Test + public void testListFunctionWithFloat() throws IOException { + JSONObject response = + executeQuery( + String.format( + "source=%s | stats list(float_number) as float_list", TEST_INDEX_DATATYPE_NUMERIC)); + verifySchema(response, schema("float_list", "array")); + verifyDataRows(response, rows(List.of("6.2"))); + } + + @Test + public void testListFunctionWithDouble() throws IOException { + JSONObject response = + executeQuery( + String.format( + "source=%s | stats list(double_number) as double_list", + TEST_INDEX_DATATYPE_NUMERIC)); + verifySchema(response, schema("double_list", "array")); + verifyDataRows(response, rows(List.of("5.1"))); + } + + @Test + public void testListFunctionWithString() throws IOException { + JSONObject response = + executeQuery( + String.format( + "source=%s | stats list(keyword_value) 
as keyword_list", + TEST_INDEX_DATATYPE_NONNUMERIC)); + verifySchema(response, schema("keyword_list", "array")); + verifyDataRows(response, rows(List.of("keyword"))); + + JSONObject textResponse = + executeQuery( + String.format( + "source=%s | stats list(text_value) as text_list", TEST_INDEX_DATATYPE_NONNUMERIC)); + verifySchema(textResponse, schema("text_list", "array")); + verifyDataRows(textResponse, rows(List.of("text"))); + } + + @Test + public void testListFunctionWithDate() throws IOException { + JSONObject response = + executeQuery( + String.format( + "source=%s | stats list(date_value) as date_list", TEST_INDEX_DATATYPE_NONNUMERIC)); + verifySchema(response, schema("date_list", "array")); + // Date values should be returned as timestamp strings + verifyDataRows(response, rows(List.of("2020-10-13 13:00:00"))); + } + + @Test + public void testListFunctionWithTime() throws IOException { + JSONObject response = + executeQuery( + String.format("source=%s | head 1 | stats list(time1) as time_list", TEST_INDEX_CALCS)); + verifySchema(response, schema("time_list", "array")); + // Time values are stored as strings in the test data + verifyDataRows(response, rows(List.of("19:36:22"))); + } + + @Test + public void testListFunctionWithTimestamp() throws IOException { + JSONObject response = + executeQuery( + String.format( + "source=%s | stats list(date_nanos_value) as timestamp_list", + TEST_INDEX_DATATYPE_NONNUMERIC)); + verifySchema(response, schema("timestamp_list", "array")); + // Calcite converts timezone to UTC + verifyDataRows(response, rows(List.of("2019-03-24 01:34:46.123456789"))); + } + + @Test + public void testListFunctionWithIP() throws IOException { + JSONObject response = + executeQuery( + String.format( + "source=%s | stats list(ip_value) as ip_list", TEST_INDEX_DATATYPE_NONNUMERIC)); + verifySchema(response, schema("ip_list", "array")); + verifyDataRows(response, rows(List.of("127.0.0.1"))); + } + + @Test + public void testListFunctionWithBinary() 
throws IOException { + JSONObject response = + executeQuery( + String.format( + "source=%s | stats list(binary_value) as binary_list", + TEST_INDEX_DATATYPE_NONNUMERIC)); + verifySchema(response, schema("binary_list", "array")); + verifyDataRows(response, rows(List.of("U29tZSBiaW5hcnkgYmxvYg=="))); + } + + // ==================== Edge Cases and Complex Scenarios ==================== + + @Test + public void testListFunctionWithNullValues() throws IOException { + JSONObject response = + executeQuery( + String.format("source=%s | head 5 | stats list(int0) as int_list", TEST_INDEX_CALCS)); + verifySchema(response, schema("int_list", "array")); + // Nulls are filtered out by list function + verifyDataRows(response, rows(List.of("1", "7"))); + } + + @Test + public void testListFunctionGroupBy() throws IOException { + JSONObject response = + executeQuery( + String.format( + "source=%s | head 5 | stats list(num0) as num_list by str0", TEST_INDEX_CALCS)); + verifySchema(response, schema("num_list", "array"), schema("str0", null, "string")); + + // Group by str0 field - should have different groups with their respective num0 values + // Just verify we get some results with the correct schema + assert response.has("datarows"); + } + + @Test + public void testListFunctionMultipleFields() throws IOException { + JSONObject response = + executeQuery( + String.format( + "source=%s | head 3 | stats list(str2) as str_list, list(int2) as int_list", + TEST_INDEX_CALCS)); + verifySchema(response, schema("str_list", "array"), schema("int_list", "array")); + + // Verify we get arrays for both fields + assert response.has("datarows"); + // Values should be collected from the first 3 rows (str2 and int2 columns) + // The actual values depend on the test data - int2 column contains 5, -4, 5 + verifyDataRows(response, rows(List.of("one", "two", "three"), List.of("5", "-4", "5"))); + } + + @Test + public void testListFunctionWithComplexGroupBy() throws IOException { + // Test list aggregation 
with multiple grouping fields + JSONObject response = + executeQuery( + String.format( + "source=%s | head 5 | stats list(num0) as values by str0, bool0", + TEST_INDEX_CALCS)); + verifySchema( + response, + schema("values", "array"), + schema("str0", null, "string"), + schema("bool0", null, "boolean")); + + // Should have multiple groups based on str0 and bool0 combinations + assert response.has("datarows"); + // Verify we get grouped results with proper values + assert response.getJSONArray("datarows").length() > 0; + } + + @Test + public void testListFunctionEmptyResult() throws IOException { + // Test list function with no matching records - simplify this test + JSONObject response = + executeQuery( + String.format( + "source=%s | where str0 = 'NONEXISTENT' | stats list(num0) as empty_list", + TEST_INDEX_CALCS)); + verifySchema(response, schema("empty_list", "array")); + + assert response.has("datarows"); + // When no records match, LIST returns null (not an empty list) + verifyDataRows(response, rows((List) null)); + } + + // ==================== Advanced Functionality Tests ==================== + + @Test + public void testListFunctionWithObjectField() throws IOException { + JSONObject response = + executeQuery( + String.format( + "source=%s | stats list(object_value.first) as object_field_list", + TEST_INDEX_DATATYPE_NONNUMERIC)); + verifySchema(response, schema("object_field_list", "array")); + verifyDataRows(response, rows(List.of("Dale"))); + } + + @Test + public void testListFunctionWithArithmeticExpression() throws IOException { + JSONObject response = + executeQuery( + String.format( + "source=%s | head 3 | stats list(int3 + 1) as arithmetic_list", TEST_INDEX_CALCS)); + verifySchema(response, schema("arithmetic_list", "array")); + verifyDataRows(response, rows(List.of("9", "14", "3"))); + } +} diff --git a/integ-test/src/test/resources/expectedOutput/calcite/explain_list_aggregation.json 
b/integ-test/src/test/resources/expectedOutput/calcite/explain_list_aggregation.json new file mode 100644 index 00000000000..1b824c29814 --- /dev/null +++ b/integ-test/src/test/resources/expectedOutput/calcite/explain_list_aggregation.json @@ -0,0 +1,6 @@ +{ + "calcite": { + "logical": "LogicalSystemLimit(fetch=[10000], type=[QUERY_SIZE_LIMIT])\n LogicalAggregate(group=[{}], age_list=[LIST($0)])\n LogicalProject(age=[$8])\n CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]])\n", + "physical": "EnumerableLimit(fetch=[10000])\n EnumerableAggregate(group=[{}], age_list=[LIST($0)])\n CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]], PushDownContext=[[PROJECT->[age]], OpenSearchRequestBuilder(sourceBuilder={\"from\":0,\"timeout\":\"1m\",\"_source\":{\"includes\":[\"age\"],\"excludes\":[]}}, requestedTotalSize=2147483647, pageSize=null, startFrom=0)])\n" + } +} \ No newline at end of file diff --git a/integ-test/src/test/resources/expectedOutput/calcite_no_pushdown/explain_list_aggregation.json b/integ-test/src/test/resources/expectedOutput/calcite_no_pushdown/explain_list_aggregation.json new file mode 100644 index 00000000000..7da65bb8d4f --- /dev/null +++ b/integ-test/src/test/resources/expectedOutput/calcite_no_pushdown/explain_list_aggregation.json @@ -0,0 +1,6 @@ +{ + "calcite": { + "logical": "LogicalSystemLimit(fetch=[10000], type=[QUERY_SIZE_LIMIT])\n LogicalAggregate(group=[{}], age_list=[LIST($0)])\n LogicalProject(age=[$8])\n CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]])\n", + "physical": "EnumerableLimit(fetch=[10000])\n EnumerableAggregate(group=[{}], age_list=[LIST($8)])\n CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]])\n" + } +} \ No newline at end of file diff --git a/ppl/src/main/antlr/OpenSearchPPLParser.g4 b/ppl/src/main/antlr/OpenSearchPPLParser.g4 index 063d9bf5602..66621a67590 100644 --- 
a/ppl/src/main/antlr/OpenSearchPPLParser.g4 +++ b/ppl/src/main/antlr/OpenSearchPPLParser.g4 @@ -494,6 +494,7 @@ statsFunctionName | STDDEV_POP | PERCENTILE | PERCENTILE_APPROX + | LIST ; earliestLatestFunction diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLEvalTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLEvalTest.java index e09b62b748a..d1acdd168b4 100644 --- a/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLEvalTest.java +++ b/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLEvalTest.java @@ -421,4 +421,52 @@ public void testDependedLateralEval() { + "GROUP BY `DEPTNO`"; verifyPPLToSparkSQL(root, expectedSparkSql); } + + @Test + public void testListAggregationWithOtherAgg() { + String ppl = "source=EMP | stats list(DEPTNO), avg(DEPTNO)"; + RelNode root = getRelNode(ppl); + String expectedLogical = + "LogicalAggregate(group=[{}], list(DEPTNO)=[LIST($0)], avg(DEPTNO)=[AVG($0)])\n" + + " LogicalProject(DEPTNO=[$7])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n"; + verifyLogical(root, expectedLogical); + + String expectedSparkSql = + "SELECT `LIST`(`DEPTNO`) `list(DEPTNO)`, AVG(`DEPTNO`) `avg(DEPTNO)`\n" + + "FROM `scott`.`EMP`"; + verifyPPLToSparkSQL(root, expectedSparkSql); + } + + @Test + public void testListAggregationAlone() { + String ppl = "source=EMP | stats list(DEPTNO)"; + RelNode root = getRelNode(ppl); + String expectedLogical = + "LogicalAggregate(group=[{}], list(DEPTNO)=[LIST($0)])\n" + + " LogicalProject(DEPTNO=[$7])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n"; + verifyLogical(root, expectedLogical); + + String expectedSparkSql = "SELECT `LIST`(`DEPTNO`) `list(DEPTNO)`\n" + "FROM `scott`.`EMP`"; + verifyPPLToSparkSQL(root, expectedSparkSql); + } + + @Test + public void testListAggregationWithGroupBy() { + String ppl = "source=EMP | stats list(ENAME) by DEPTNO"; + RelNode root = getRelNode(ppl); + String expectedLogical = + "LogicalProject(list(ENAME)=[$1], 
DEPTNO=[$0])\n" + + " LogicalAggregate(group=[{0}], list(ENAME)=[LIST($1)])\n" + + " LogicalProject(DEPTNO=[$7], ENAME=[$1])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n"; + verifyLogical(root, expectedLogical); + + String expectedSparkSql = + "SELECT `LIST`(`ENAME`) `list(ENAME)`, `DEPTNO`\n" + + "FROM `scott`.`EMP`\n" + + "GROUP BY `DEPTNO`"; + verifyPPLToSparkSQL(root, expectedSparkSql); + } } diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLFunctionTypeTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLFunctionTypeTest.java index 57d7df6393f..907abf76a89 100644 --- a/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLFunctionTypeTest.java +++ b/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLFunctionTypeTest.java @@ -294,4 +294,18 @@ public void testPercentileApproxWithWrongArgType() { + " {[INTEGER,INTEGER],[INTEGER,DOUBLE],[DOUBLE,INTEGER],[DOUBLE,DOUBLE],[INTEGER,INTEGER,INTEGER],[INTEGER,INTEGER,DOUBLE],[INTEGER,DOUBLE,INTEGER],[INTEGER,DOUBLE,DOUBLE],[DOUBLE,INTEGER,INTEGER],[DOUBLE,INTEGER,DOUBLE],[DOUBLE,DOUBLE,INTEGER],[DOUBLE,DOUBLE,DOUBLE]}," + " but got [STRING,INTEGER]"); } + + @Test + public void testListFunctionWithArrayArgType() { + // Test LIST function with array expression (which is not a supported scalar type) + Exception e = + Assert.assertThrows( + ExpressionEvaluationException.class, + () -> getRelNode("source=EMP | stats list(array(ENAME, JOB)) as name_list")); + verifyErrorMessageContains( + e, + "Aggregation function LIST expects field type" + + " {[BYTE],[SHORT],[INTEGER],[LONG],[FLOAT],[DOUBLE],[STRING],[BOOLEAN],[DATE],[TIME],[TIMESTAMP],[IP],[BINARY]}," + + " but got [ARRAY]"); + } }