@@ -1,6 +1,6 @@
{
"calcite": {
"logical": "LogicalSystemLimit(fetch=[10000], type=[QUERY_SIZE_LIMIT])\n LogicalProject(earliest_message=[$1], latest_message=[$2], server=[$0])\n LogicalAggregate(group=[{0}], earliest_message=[ARG_MIN($1, $2)], latest_message=[ARG_MAX($1, $2)])\n LogicalProject(server=[$1], message=[$3], @timestamp=[$2])\n CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_logs]])\n",
"physical": "EnumerableLimit(fetch=[10000])\n EnumerableCalc(expr#0..2=[{inputs}], earliest_message=[$t1], latest_message=[$t2], server=[$t0])\n EnumerableAggregate(group=[{0}], earliest_message=[ARG_MIN($1, $2)], latest_message=[ARG_MAX($1, $2)])\n CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_logs]], PushDownContext=[[PROJECT->[server, message, @timestamp]], OpenSearchRequestBuilder(sourceBuilder={\"from\":0,\"timeout\":\"1m\",\"_source\":{\"includes\":[\"server\",\"message\",\"@timestamp\"],\"excludes\":[]}}, requestedTotalSize=2147483647, pageSize=null, startFrom=0)])\n"
"physical": "EnumerableLimit(fetch=[10000])\n EnumerableCalc(expr#0..2=[{inputs}], earliest_message=[$t1], latest_message=[$t2], server=[$t0])\n CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_logs]], PushDownContext=[[AGGREGATION->rel#:LogicalAggregate.NONE.[](input=RelSubset#,group={0},earliest_message=ARG_MIN($1, $2),latest_message=ARG_MAX($1, $2))], OpenSearchRequestBuilder(sourceBuilder={\"from\":0,\"size\":0,\"timeout\":\"1m\",\"aggregations\":{\"composite_buckets\":{\"composite\":{\"size\":1000,\"sources\":[{\"server\":{\"terms\":{\"field\":\"server\",\"missing_bucket\":true,\"missing_order\":\"first\",\"order\":\"asc\"}}}]},\"aggregations\":{\"earliest_message\":{\"top_hits\":{\"from\":0,\"size\":1,\"version\":false,\"seq_no_primary_term\":false,\"explain\":false,\"_source\":{\"includes\":[\"message\"],\"excludes\":[]},\"sort\":[{\"@timestamp\":{\"order\":\"asc\"}}]}},\"latest_message\":{\"top_hits\":{\"from\":0,\"size\":1,\"version\":false,\"seq_no_primary_term\":false,\"explain\":false,\"_source\":{\"includes\":[\"message\"],\"excludes\":[]},\"sort\":[{\"@timestamp\":{\"order\":\"desc\"}}]}}}}}}, requestedTotalSize=2147483647, pageSize=null, startFrom=0)])\n"
}
}
@@ -1,6 +1,6 @@
{
"calcite": {
"logical": "LogicalSystemLimit(fetch=[10000], type=[QUERY_SIZE_LIMIT])\n LogicalProject(earliest_message=[$1], latest_message=[$2], level=[$0])\n LogicalAggregate(group=[{0}], earliest_message=[ARG_MIN($1, $2)], latest_message=[ARG_MAX($1, $2)])\n LogicalProject(level=[$4], message=[$3], created_at=[$0])\n CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_logs]])\n",
"physical": "EnumerableLimit(fetch=[10000])\n EnumerableCalc(expr#0..2=[{inputs}], earliest_message=[$t1], latest_message=[$t2], level=[$t0])\n EnumerableAggregate(group=[{0}], earliest_message=[ARG_MIN($1, $2)], latest_message=[ARG_MAX($1, $2)])\n CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_logs]], PushDownContext=[[PROJECT->[level, message, created_at]], OpenSearchRequestBuilder(sourceBuilder={\"from\":0,\"timeout\":\"1m\",\"_source\":{\"includes\":[\"level\",\"message\",\"created_at\"],\"excludes\":[]}}, requestedTotalSize=2147483647, pageSize=null, startFrom=0)])\n"
"physical": "EnumerableLimit(fetch=[10000])\n EnumerableCalc(expr#0..2=[{inputs}], earliest_message=[$t1], latest_message=[$t2], level=[$t0])\n CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_logs]], PushDownContext=[[AGGREGATION->rel#:LogicalAggregate.NONE.[](input=RelSubset#,group={0},earliest_message=ARG_MIN($1, $2),latest_message=ARG_MAX($1, $2))], OpenSearchRequestBuilder(sourceBuilder={\"from\":0,\"size\":0,\"timeout\":\"1m\",\"aggregations\":{\"composite_buckets\":{\"composite\":{\"size\":1000,\"sources\":[{\"level\":{\"terms\":{\"field\":\"level\",\"missing_bucket\":true,\"missing_order\":\"first\",\"order\":\"asc\"}}}]},\"aggregations\":{\"earliest_message\":{\"top_hits\":{\"from\":0,\"size\":1,\"version\":false,\"seq_no_primary_term\":false,\"explain\":false,\"_source\":{\"includes\":[\"message\"],\"excludes\":[]},\"sort\":[{\"created_at\":{\"order\":\"asc\"}}]}},\"latest_message\":{\"top_hits\":{\"from\":0,\"size\":1,\"version\":false,\"seq_no_primary_term\":false,\"explain\":false,\"_source\":{\"includes\":[\"message\"],\"excludes\":[]},\"sort\":[{\"created_at\":{\"order\":\"desc\"}}]}}}}}}, requestedTotalSize=2147483647, pageSize=null, startFrom=0)])\n"
}
}
@@ -69,6 +69,7 @@
import org.opensearch.sql.data.type.ExprType;
import org.opensearch.sql.expression.function.BuiltinFunctionName;
import org.opensearch.sql.opensearch.request.PredicateAnalyzer.NamedFieldExpression;
import org.opensearch.sql.opensearch.response.agg.ArgMaxMinParser;
import org.opensearch.sql.opensearch.response.agg.CompositeAggregationParser;
import org.opensearch.sql.opensearch.response.agg.MetricParser;
import org.opensearch.sql.opensearch.response.agg.NoBucketAggregationParser;
@@ -305,6 +306,26 @@ private static Pair<AggregationBuilder, MetricParser> createRegularAggregation(
return Pair.of(
helper.build(args.get(0), AggregationBuilders.extendedStats(aggFieldName)),
new StatsParser(ExtendedStats::getStdDeviationPopulation, aggFieldName));
case ARG_MAX:
return Pair.of(
AggregationBuilders.topHits(aggFieldName)
.fetchSource(helper.inferNamedField(args.get(0)).getRootName(), null)
.size(1)
.from(0)
.sort(
helper.inferNamedField(args.get(1)).getRootName(),
org.opensearch.search.sort.SortOrder.DESC),
new ArgMaxMinParser(aggFieldName));
case ARG_MIN:
return Pair.of(
AggregationBuilders.topHits(aggFieldName)
.fetchSource(helper.inferNamedField(args.get(0)).getRootName(), null)
.size(1)
.from(0)
.sort(
helper.inferNamedField(args.get(1)).getRootName(),
org.opensearch.search.sort.SortOrder.ASC),
new ArgMaxMinParser(aggFieldName));
case OTHER_FUNCTION:
BuiltinFunctionName functionName =
BuiltinFunctionName.ofAggregation(aggCall.getAggregation().getName()).get();
@@ -0,0 +1,39 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.opensearch.response.agg;

import java.util.Collections;
import java.util.Map;
import lombok.Value;
import org.opensearch.search.SearchHit;
import org.opensearch.search.aggregations.Aggregation;
import org.opensearch.search.aggregations.metrics.TopHits;

/** {@link TopHits} metric parser for ARG_MAX/ARG_MIN aggregations. */
@Value
public class ArgMaxMinParser implements MetricParser {

String name;

@Override
public Map<String, Object> parse(Aggregation agg) {
TopHits topHits = (TopHits) agg;
SearchHit[] hits = topHits.getHits().getHits();

if (hits.length == 0) {
return Collections.singletonMap(agg.getName(), null);
}

Map<String, Object> source = hits[0].getSourceAsMap();

if (source.isEmpty()) {
return Collections.singletonMap(agg.getName(), null);
} else {
Object value = source.values().iterator().next();
return Collections.singletonMap(agg.getName(), value);
}
}
}
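To illustrate the parse contract documented on this new parser, here is a minimal, hypothetical sketch (not part of this change): the single _source field of the first top hit is returned under the aggregation's name, and an empty hit array yields a null value. It assumes Mockito is available on the test classpath and that these OpenSearch types can be mocked there (final classes such as SearchHit/SearchHits may need the inline mock maker); the field names and values are illustrative.

package org.opensearch.sql.opensearch.response.agg;

import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

import java.util.Collections;
import java.util.Map;
import org.opensearch.search.SearchHit;
import org.opensearch.search.SearchHits;
import org.opensearch.search.aggregations.metrics.TopHits;

public class ArgMaxMinParserSketch {
  public static void main(String[] args) {
    // One top hit whose _source carries only the selected field ("message").
    // Assumption: these types are mockable in this setup (inline mock maker if final).
    SearchHit hit = mock(SearchHit.class);
    when(hit.getSourceAsMap()).thenReturn(Collections.singletonMap("message", "disk full"));

    SearchHits hits = mock(SearchHits.class);
    when(hits.getHits()).thenReturn(new SearchHit[] {hit});

    TopHits topHits = mock(TopHits.class);
    when(topHits.getName()).thenReturn("latest_message");
    when(topHits.getHits()).thenReturn(hits);

    // Prints {latest_message=disk full}; with zero hits it would print {latest_message=null}.
    Map<String, Object> parsed = new ArgMaxMinParser("latest_message").parse(topHits);
    System.out.println(parsed);
  }
}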
@@ -26,7 +26,6 @@
import org.apache.calcite.rel.rel2sql.SqlImplementor;
import org.apache.calcite.schema.SchemaPlus;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.dialect.SparkSqlDialect;
import org.apache.calcite.sql.parser.SqlParser;
import org.apache.calcite.test.CalciteAssert;
import org.apache.calcite.tools.Frameworks;
@@ -53,7 +52,7 @@ public class CalcitePPLAbstractTest {
public CalcitePPLAbstractTest(CalciteAssert.SchemaSpec... schemaSpecs) {
this.config = config(schemaSpecs);
this.planTransformer = new CalciteRelNodeVisitor();
this.converter = new RelToSqlConverter(SparkSqlDialect.DEFAULT);
this.converter = new RelToSqlConverter(OpenSearchSparkSqlDialect.DEFAULT);
this.settings = mock(Settings.class);
}

@@ -160,7 +159,7 @@ public void verifyPPLToSparkSQL(RelNode rel, String expected) {
String normalized = expected.replace("\n", System.lineSeparator());
SqlImplementor.Result result = converter.visitRoot(rel);
final SqlNode sqlNode = result.asStatement();
final String sql = sqlNode.toSqlString(SparkSqlDialect.DEFAULT).getSql();
final String sql = sqlNode.toSqlString(OpenSearchSparkSqlDialect.DEFAULT).getSql();
assertThat(sql, is(normalized));
}

@@ -107,7 +107,7 @@ public void testEarliestWithoutSecondArgument() {
verifyResult(root, expectedResult);

String expectedSparkSql =
"SELECT ARG_MIN(`message`, `@timestamp`) `earliest_message`\n" + "FROM `POST`.`LOGS`";
"SELECT MIN_BY (`message`, `@timestamp`) `earliest_message`\n" + "FROM `POST`.`LOGS`";
verifyPPLToSparkSQL(root, expectedSparkSql);
}

@@ -125,7 +125,7 @@ public void testLatestWithoutSecondArgument() {
verifyResult(root, expectedResult);

String expectedSparkSql =
"SELECT ARG_MAX(`message`, `@timestamp`) `latest_message`\n" + "FROM `POST`.`LOGS`";
"SELECT MAX_BY (`message`, `@timestamp`) `latest_message`\n" + "FROM `POST`.`LOGS`";
verifyPPLToSparkSQL(root, expectedSparkSql);
}

@@ -147,7 +147,7 @@ public void testEarliestByServerWithoutSecondArgument() {
verifyResult(root, expectedResult);

String expectedSparkSql =
"SELECT ARG_MIN(`message`, `@timestamp`) `earliest_message`, `server`\n"
"SELECT MIN_BY (`message`, `@timestamp`) `earliest_message`, `server`\n"
+ "FROM `POST`.`LOGS`\n"
+ "GROUP BY `server`";
verifyPPLToSparkSQL(root, expectedSparkSql);
@@ -171,7 +171,7 @@ public void testLatestByServerWithoutSecondArgument() {
verifyResult(root, expectedResult);

String expectedSparkSql =
"SELECT ARG_MAX(`message`, `@timestamp`) `latest_message`, `server`\n"
"SELECT MAX_BY (`message`, `@timestamp`) `latest_message`, `server`\n"
+ "FROM `POST`.`LOGS`\n"
+ "GROUP BY `server`";
verifyPPLToSparkSQL(root, expectedSparkSql);
@@ -196,7 +196,7 @@ public void testEarliestWithOtherAggregatesWithoutSecondArgument() {
verifyResult(root, expectedResult);

String expectedSparkSql =
"SELECT ARG_MIN(`message`, `@timestamp`) `earliest_message`, "
"SELECT MIN_BY (`message`, `@timestamp`) `earliest_message`, "
+ "COUNT(*) `cnt`, `server`\n"
+ "FROM `POST`.`LOGS`\n"
+ "GROUP BY `server`";
@@ -217,7 +217,7 @@ public void testEarliestWithExplicitTimestampField() {
verifyResult(root, expectedResult);

String expectedSparkSql =
"SELECT ARG_MIN(`message`, `created_at`) `earliest_message`\n" + "FROM `POST`.`LOGS`";
"SELECT MIN_BY (`message`, `created_at`) `earliest_message`\n" + "FROM `POST`.`LOGS`";
verifyPPLToSparkSQL(root, expectedSparkSql);
}

@@ -235,7 +235,7 @@ public void testLatestWithExplicitTimestampField() {
verifyResult(root, expectedResult);

String expectedSparkSql =
"SELECT ARG_MAX(`message`, `created_at`) `latest_message`\n" + "FROM `POST`.`LOGS`";
"SELECT MAX_BY (`message`, `created_at`) `latest_message`\n" + "FROM `POST`.`LOGS`";
verifyPPLToSparkSQL(root, expectedSparkSql);
}

@@ -0,0 +1,58 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.ppl.calcite;

import com.google.common.collect.ImmutableMap;
import java.util.Map;
import org.apache.calcite.sql.SqlCall;
import org.apache.calcite.sql.SqlWriter;
import org.apache.calcite.sql.dialect.SparkSqlDialect;

/**
* Custom Spark SQL dialect that extends Calcite's SparkSqlDialect to handle OpenSearch-specific
* function translations. This dialect ensures that functions are translated to their correct Spark
* SQL equivalents.
*/
public class OpenSearchSparkSqlDialect extends SparkSqlDialect {

/** Singleton instance of the OpenSearch Spark SQL dialect. */
public static final OpenSearchSparkSqlDialect DEFAULT = new OpenSearchSparkSqlDialect();

private static final Map<String, String> CALCITE_TO_SPARK_MAPPING =
ImmutableMap.of(
"ARG_MIN", "MIN_BY",
"ARG_MAX", "MAX_BY");

private OpenSearchSparkSqlDialect() {
super(DEFAULT_CONTEXT);
}

@Override
public void unparseCall(SqlWriter writer, SqlCall call, int leftPrec, int rightPrec) {
String operatorName = call.getOperator().getName();

// Replace Calcite specific functions with their Spark SQL equivalents
if (CALCITE_TO_SPARK_MAPPING.containsKey(operatorName)) {
unparseFunction(
writer, call, CALCITE_TO_SPARK_MAPPING.get(operatorName), leftPrec, rightPrec);
} else {
super.unparseCall(writer, call, leftPrec, rightPrec);
}
}

private void unparseFunction(
SqlWriter writer, SqlCall call, String functionName, int leftPrec, int rightPrec) {
writer.keyword(functionName);
final SqlWriter.Frame frame = writer.startList("(", ")");
for (int i = 0; i < call.operandCount(); i++) {
if (i > 0) {
writer.sep(",");
}
call.operand(i).unparse(writer, leftPrec, rightPrec);
}
writer.endList(frame);
}
}
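As a usage note, here is a small, hypothetical sketch (not part of this change) of the effect of the unparseCall override: when a SqlNode tree containing an ARG_MAX or ARG_MIN call is rendered with this dialect, the call comes out as MAX_BY or MIN_BY, which is how CalcitePPLAbstractTest#verifyPPLToSparkSQL above uses it. The query text, table, and column names are illustrative, and Calcite's default parser configuration is assumed.

package org.opensearch.sql.ppl.calcite;

import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.parser.SqlParser;

public class OpenSearchSparkSqlDialectSketch {
  public static void main(String[] args) throws Exception {
    // Parse a query that spells the aggregate the Calcite way (ARG_MAX).
    SqlNode node =
        SqlParser.create("SELECT ARG_MAX(message, ts) AS latest_message FROM logs").parseQuery();
    // Rendering through the dialect routes every call through unparseCall,
    // so ARG_MAX appears as MAX_BY in the printed SQL.
    System.out.println(node.toSqlString(OpenSearchSparkSqlDialect.DEFAULT).getSql());
  }
}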