diff --git a/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteSortCommandIT.java b/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteSortCommandIT.java index f5c51ed04e2..1867f7be6cb 100644 --- a/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteSortCommandIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteSortCommandIT.java @@ -22,10 +22,4 @@ public void init() throws IOException { @Ignore @Override public void testSortIpField() throws IOException {} - - // TODO: Fix incorrect results for NULL values, addressed by issue: - // https://github.com/opensearch-project/sql/issues/3375 - @Ignore - @Override - public void testSortWithNullValue() throws IOException {} } diff --git a/integ-test/src/test/java/org/opensearch/sql/calcite/standalone/CalcitePPLJoinIT.java b/integ-test/src/test/java/org/opensearch/sql/calcite/standalone/CalcitePPLJoinIT.java index 3a23fc0c44a..24b3ab328d7 100644 --- a/integ-test/src/test/java/org/opensearch/sql/calcite/standalone/CalcitePPLJoinIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/calcite/standalone/CalcitePPLJoinIT.java @@ -18,6 +18,7 @@ import org.junit.Ignore; import org.junit.Test; import org.opensearch.client.Request; +import org.opensearch.sql.legacy.TestsConstants; public class CalcitePPLJoinIT extends CalcitePPLIntegTestCase { @@ -29,22 +30,22 @@ public void init() throws IOException { loadIndex(Index.OCCUPATION); loadIndex(Index.HOBBIES); Request request1 = - new Request("PUT", "/opensearch-sql_test_index_state_country/_doc/5?refresh=true"); + new Request("PUT", "/" + TestsConstants.TEST_INDEX_STATE_COUNTRY + "/_doc/5?refresh=true"); request1.setJsonEntity( "{\"name\":\"Jim\",\"age\":27,\"state\":\"B.C\",\"country\":\"Canada\",\"year\":2023,\"month\":4}"); client().performRequest(request1); Request request2 = - new Request("PUT", "/opensearch-sql_test_index_state_country/_doc/6?refresh=true"); + new Request("PUT", "/" + TestsConstants.TEST_INDEX_STATE_COUNTRY + "/_doc/6?refresh=true"); request2.setJsonEntity( "{\"name\":\"Peter\",\"age\":57,\"state\":\"B.C\",\"country\":\"Canada\",\"year\":2023,\"month\":4}"); client().performRequest(request2); Request request3 = - new Request("PUT", "/opensearch-sql_test_index_state_country/_doc/7?refresh=true"); + new Request("PUT", "/" + TestsConstants.TEST_INDEX_STATE_COUNTRY + "/_doc/7?refresh=true"); request3.setJsonEntity( "{\"name\":\"Rick\",\"age\":70,\"state\":\"B.C\",\"country\":\"Canada\",\"year\":2023,\"month\":4}"); client().performRequest(request3); Request request4 = - new Request("PUT", "/opensearch-sql_test_index_state_country/_doc/8?refresh=true"); + new Request("PUT", "/" + TestsConstants.TEST_INDEX_STATE_COUNTRY + "/_doc/8?refresh=true"); request4.setJsonEntity( "{\"name\":\"David\",\"age\":40,\"state\":\"Washington\",\"country\":\"USA\",\"year\":2023,\"month\":4}"); client().performRequest(request4); @@ -213,9 +214,9 @@ public void testComplexLeftJoin() { actual, rows("Jane", 20, "Quebec", "Canada", "Scientist", "Canada", 90000), rows("John", 25, "Ontario", "Canada", "Doctor", "Canada", 120000), - rows("Jim", 27, "B.C", "Canada", null, null, 0), - rows("Peter", 57, "B.C", "Canada", null, null, 0), - rows("Rick", 70, "B.C", "Canada", null, null, 0)); + rows("Jim", 27, "B.C", "Canada", null, null, null), + rows("Peter", 57, "B.C", "Canada", null, null, null), + rows("Rick", 70, "B.C", "Canada", null, null, null)); } @Test @@ -240,10 +241,10 @@ public void testComplexRightJoin() { actual, rows("Jane", 20, "Quebec", "Canada", "Scientist", "Canada", 90000), rows("John", 25, "Ontario", "Canada", "Doctor", "Canada", 120000), - rows(null, 0, null, null, "Engineer", "England", 100000), - rows(null, 0, null, null, "Artist", "USA", 70000), - rows(null, 0, null, null, "Doctor", "USA", 120000), - rows(null, 0, null, null, "Unemployed", "Canada", 0)); + rows(null, null, null, null, "Engineer", "England", 100000), + rows(null, null, null, null, "Artist", "USA", 70000), + rows(null, null, null, null, "Doctor", "USA", 120000), + rows(null, null, null, null, "Unemployed", "Canada", 0)); } @Test @@ -392,6 +393,54 @@ public void testMultipleJoins() { TEST_INDEX_STATE_COUNTRY)); } + @Ignore // TODO seems a calcite bug + public void testMultipleJoinsWithRelationSubquery() { + JSONObject actual = + executeQuery( + String.format( + """ + source = %s + | where country = 'Canada' OR country = 'England' + | inner join left=a, right=b + ON a.name = b.name AND a.year = 2023 AND a.month = 4 AND b.year = 2023 AND b.month = 4 + [ + source = %s + ] + | eval a_name = a.name + | eval a_country = a.country + | eval b_country = b.country + | fields a_name, age, state, a_country, occupation, b_country, salary + | left join left=a, right=b + ON a.a_name = b.name + [ + source = %s + ] + | eval aa_country = a.a_country + | eval ab_country = a.b_country + | eval bb_country = b.country + | fields a_name, age, state, aa_country, occupation, ab_country, salary, bb_country, hobby, language + | cross join left=a, right=b + [ + source = %s + ] + | eval new_country = a.aa_country + | eval new_salary = b.salary + | stats avg(new_salary) as avg_salary by span(age, 5) as age_span, state + | left semi join left=a, right=b + ON a.state = b.state + [ + source = %s + ] + | eval new_avg_salary = floor(avg_salary) + | fields state, age_span, new_avg_salary + """, + TEST_INDEX_STATE_COUNTRY, + TEST_INDEX_OCCUPATION, + TEST_INDEX_HOBBIES, + TEST_INDEX_OCCUPATION, + TEST_INDEX_STATE_COUNTRY)); + } + @Test public void testMultipleJoinsWithoutTableAliases() { JSONObject actual = @@ -442,7 +491,7 @@ public void testMultipleJoinsWithPartTableAliases() { } @Test - public void testMultipleJoinsWithSelfJoin1() { + public void testMultipleJoinsWithSelfJoin() { JSONObject actual = executeQuery( String.format( @@ -469,8 +518,8 @@ public void testMultipleJoinsWithSelfJoin1() { rows("John", "John", "John", "John")); } - @Ignore // TODO table subquery not support - public void testMultipleJoinsWithSelfJoin2() { + @Test + public void testMultipleJoinsWithSubquerySelfJoin() { JSONObject actual = executeQuery( String.format( @@ -481,10 +530,24 @@ public void testMultipleJoinsWithSelfJoin2() { TEST_INDEX_OCCUPATION, TEST_INDEX_HOBBIES, TEST_INDEX_STATE_COUNTRY)); + verifySchema( + actual, + schema("name", "string"), + schema("name0", "string"), + schema("name1", "string"), + schema("name2", "string")); + verifyDataRows( + actual, + rows("David", "David", "David", "David"), + rows("David", "David", "David", "David"), + rows("Hello", "Hello", "Hello", "Hello"), + rows("Jake", "Jake", "Jake", "Jake"), + rows("Jane", "Jane", "Jane", "Jane"), + rows("John", "John", "John", "John")); } @Test - public void testCheckAccessTheReferenceByAliases1() { + public void testCheckAccessTheReferenceByAliases() { String res1 = execute( String.format( @@ -520,8 +583,8 @@ public void testCheckAccessTheReferenceByAliases1() { assertEquals(res4, res5); } - @Ignore // TODO table subquery not support - public void testCheckAccessTheReferenceByAliases2() { + @Test + public void testCheckAccessTheReferenceBySubqueryAliases() { String res1 = execute( String.format( @@ -559,7 +622,7 @@ public void testCheckAccessTheReferenceByAliases2() { } @Test - public void testCheckAccessTheReferenceByAliases3() { + public void testCheckAccessTheReferenceByOverrideAliases() { String res1 = execute( String.format( @@ -581,4 +644,109 @@ public void testCheckAccessTheReferenceByAliases3() { assertEquals(res1, res2); assertEquals(res1, res3); } + + @Test + public void testCheckAccessTheReferenceByOverrideSubqueryAliases() { + String res1 = + execute( + String.format( + "source = %s | JOIN left = t1 right = t2 ON t1.name = t2.name [ source = %s as tt ]" + + " | fields tt.name", + TEST_INDEX_STATE_COUNTRY, TEST_INDEX_OCCUPATION)); + String res2 = + execute( + String.format( + "source = %s | JOIN left = t1 right = t2 ON t1.name = t2.name [ source = %s as tt ]" + + " as t2 | fields tt.name", + TEST_INDEX_STATE_COUNTRY, TEST_INDEX_OCCUPATION)); + String res3 = + execute( + String.format( + "source = %s | JOIN left = t1 right = t2 ON t1.name = t2.name [ source = %s ] as tt" + + " | fields tt.name", + TEST_INDEX_STATE_COUNTRY, TEST_INDEX_OCCUPATION)); + assertEquals(res1, res2); + assertEquals(res1, res3); + } + + @Test + public void testCheckAccessTheReferenceByOverrideSubqueryAliases2() { + String res1 = + execute( + String.format( + "source = %s | JOIN left = t1 right = t2 ON t1.name = t2.name [ source = %s as tt ]" + + " | fields t2.name", + TEST_INDEX_STATE_COUNTRY, TEST_INDEX_OCCUPATION)); + String res2 = + execute( + String.format( + "source = %s | JOIN left = t1 right = t2 ON t1.name = t2.name [ source = %s as tt ]" + + " as t2 | fields t2.name", + TEST_INDEX_STATE_COUNTRY, TEST_INDEX_OCCUPATION)); + String res3 = + execute( + String.format( + "source = %s | JOIN left = t1 right = t2 ON t1.name = t2.name [ source = %s ] as tt" + + " | fields t2.name", + TEST_INDEX_STATE_COUNTRY, TEST_INDEX_OCCUPATION)); + assertEquals(res1, res2); + assertEquals(res1, res3); + } + + @Test + public void testInnerJoinWithRelationSubquery() { + JSONObject actual = + executeQuery( + String.format( + """ + source = %s + | where country = 'USA' OR country = 'England' + | inner join left=a, right=b + ON a.name = b.name + [ + source = %s + | where salary > 0 + | fields name, country, salary + | sort salary + | head 3 + ] + | stats avg(salary) by span(age, 10) as age_span, b.country + """, + TEST_INDEX_STATE_COUNTRY, TEST_INDEX_OCCUPATION)); + verifySchema( + actual, + schema("b.country", "string"), + schema("age_span", "double"), + schema("avg(salary)", "double")); + verifyDataRows(actual, rows("USA", 30, 70000.0), rows("England", 70, 100000)); + } + + @Test + public void testLeftJoinWithRelationSubquery() { + JSONObject actual = + executeQuery( + String.format( + """ + source = %s + | where country = 'USA' OR country = 'England' + | left join left=a, right=b + ON a.name = b.name + [ + source = %s + | where salary > 0 + | fields name, country, salary + | sort salary + | head 3 + ] + | stats avg(salary) by span(age, 10) as age_span, b.country + """, + TEST_INDEX_STATE_COUNTRY, TEST_INDEX_OCCUPATION)); + verifySchema( + actual, + schema("b.country", "string"), + schema("age_span", "double"), + schema("avg(salary)", "double")); + verifyDataRows( + actual, rows("USA", 30, 70000.0), rows("England", 70, 100000), rows(null, 40, 0)); + } } diff --git a/integ-test/src/test/java/org/opensearch/sql/calcite/standalone/CalcitePPLSortIT.java b/integ-test/src/test/java/org/opensearch/sql/calcite/standalone/CalcitePPLSortIT.java index eb582f10882..26b1729c12f 100644 --- a/integ-test/src/test/java/org/opensearch/sql/calcite/standalone/CalcitePPLSortIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/calcite/standalone/CalcitePPLSortIT.java @@ -6,9 +6,11 @@ package org.opensearch.sql.calcite.standalone; import static org.opensearch.sql.legacy.TestsConstants.TEST_INDEX_BANK; +import static org.opensearch.sql.legacy.TestsConstants.TEST_INDEX_BANK_WITH_NULL_VALUES; import static org.opensearch.sql.util.MatcherUtils.rows; import static org.opensearch.sql.util.MatcherUtils.schema; import static org.opensearch.sql.util.MatcherUtils.verifyDataRows; +import static org.opensearch.sql.util.MatcherUtils.verifyOrder; import static org.opensearch.sql.util.MatcherUtils.verifySchema; import java.io.IOException; @@ -22,6 +24,7 @@ public void init() throws IOException { super.init(); loadIndex(Index.BANK); + loadIndex(Index.BANK_WITH_NULL_VALUES); } @Test @@ -191,4 +194,22 @@ public void testSortAgeNameAndFieldsNameAge() { rows("Amber JOHnny", 32), rows("Nanette", 28)); } + + @Test + public void testSortWithNullValue() throws IOException { + JSONObject result = + executeQuery( + String.format( + "source=%s | sort balance | fields firstname, balance", + TEST_INDEX_BANK_WITH_NULL_VALUES)); + verifyOrder( + result, + rows("Dale", 4180), + rows("Nanette", 32838), + rows("Amber JOHnny", 39225), + rows("Dillard", 48086), + rows("Hattie", null), + rows("Elinor", null), + rows("Virginia", null)); + } } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/CalciteOpenSearchIndexScan.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/CalciteOpenSearchIndexScan.java index 9a0f1c3a863..1c84447ab09 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/CalciteOpenSearchIndexScan.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/CalciteOpenSearchIndexScan.java @@ -167,7 +167,7 @@ public CalciteOpenSearchIndexScan pushDownFilter(Filter filter) { // TODO: handle the case where condition contains a score function return newScan; } catch (Exception e) { - LOG.warn("Cannot analyze the filter condition {}", filter.getCondition(), e); + LOG.warn("Cannot pushdown the filter condition {}, ", filter.getCondition()); } return null; } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/util/JdbcOpenSearchDataTypeConvertor.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/util/JdbcOpenSearchDataTypeConvertor.java index 8407ecdc307..a1f14c32ea0 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/util/JdbcOpenSearchDataTypeConvertor.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/util/JdbcOpenSearchDataTypeConvertor.java @@ -5,6 +5,7 @@ package org.opensearch.sql.opensearch.util; +import java.sql.Array; import java.sql.ResultSet; import java.sql.SQLException; import java.sql.Types; @@ -58,55 +59,67 @@ public static ExprType getExprTypeFromSqlType(int sqlType) { public static ExprValue getExprValueFromSqlType( ResultSet rs, int i, int sqlType, RelDataType fieldType) throws SQLException { - Object value; - switch (sqlType) { - case Types.VARCHAR: - case Types.CHAR: - case Types.LONGVARCHAR: - value = rs.getString(i); - break; - case Types.INTEGER: - value = rs.getInt(i); - break; - case Types.BIGINT: - value = rs.getLong(i); - break; - case Types.DECIMAL: - case Types.NUMERIC: - value = rs.getBigDecimal(i); - break; - case Types.DOUBLE: - value = rs.getDouble(i); - break; - case Types.FLOAT: - value = rs.getFloat(i); - break; - case Types.DATE: - value = rs.getString(i); - return value == null ? ExprNullValue.of() : new ExprDateValue((String) value); - case Types.TIME: - value = rs.getString(i); - return value == null ? ExprNullValue.of() : new ExprTimeValue((String) value); - case Types.TIMESTAMP: - value = rs.getString(i); - return value == null ? ExprNullValue.of() : new ExprTimestampValue((String) value); - case Types.BOOLEAN: - value = rs.getBoolean(i); - break; - case Types.ARRAY: - value = rs.getArray(i); - // For calcite - if (value instanceof ArrayImpl) { - value = Arrays.asList((Object[]) ((ArrayImpl) value).getArray()); - } - break; - default: - value = rs.getObject(i); - LOG.warn( - "Unchecked sql type: {}, return Object type {}", - sqlType, - value.getClass().getTypeName()); + Object value = rs.getObject(i); + if (value == null) { + return ExprNullValue.of(); + } + + try { + switch (sqlType) { + case Types.VARCHAR: + case Types.CHAR: + case Types.LONGVARCHAR: + return ExprValueUtils.fromObjectValue(rs.getString(i)); + + case Types.INTEGER: + return ExprValueUtils.fromObjectValue(rs.getInt(i)); + + case Types.BIGINT: + return ExprValueUtils.fromObjectValue(rs.getLong(i)); + + case Types.DECIMAL: + case Types.NUMERIC: + return ExprValueUtils.fromObjectValue(rs.getBigDecimal(i)); + + case Types.DOUBLE: + return ExprValueUtils.fromObjectValue(rs.getDouble(i)); + + case Types.FLOAT: + return ExprValueUtils.fromObjectValue(rs.getFloat(i)); + + case Types.DATE: + String dateStr = rs.getString(i); + return new ExprDateValue(dateStr); + + case Types.TIME: + String timeStr = rs.getString(i); + return new ExprTimeValue(timeStr); + + case Types.TIMESTAMP: + String timestampStr = rs.getString(i); + return new ExprTimestampValue(timestampStr); + + case Types.BOOLEAN: + return ExprValueUtils.fromObjectValue(rs.getBoolean(i)); + + case Types.ARRAY: + Array array = rs.getArray(i); + if (array instanceof ArrayImpl) { + return ExprValueUtils.fromObjectValue( + Arrays.asList((Object[]) ((ArrayImpl) value).getArray())); + } + return ExprValueUtils.fromObjectValue(array); + + default: + LOG.warn( + "Unchecked sql type: {}, return Object type {}", + sqlType, + value.getClass().getTypeName()); + return ExprValueUtils.fromObjectValue(value); + } + } catch (SQLException e) { + LOG.error("Error converting SQL type {}: {}", sqlType, e.getMessage()); + throw e; } - return value == null ? ExprNullValue.of() : ExprValueUtils.fromObjectValue(value); } } diff --git a/ppl/src/main/antlr/OpenSearchPPLParser.g4 b/ppl/src/main/antlr/OpenSearchPPLParser.g4 index 17801cc7ead..7d9b64e793c 100644 --- a/ppl/src/main/antlr/OpenSearchPPLParser.g4 +++ b/ppl/src/main/antlr/OpenSearchPPLParser.g4 @@ -202,19 +202,24 @@ mlArg // clauses fromClause - : SOURCE EQUAL tableSourceClause - | INDEX EQUAL tableSourceClause + : SOURCE EQUAL tableOrSubqueryClause + | INDEX EQUAL tableOrSubqueryClause | SOURCE EQUAL tableFunction | INDEX EQUAL tableFunction ; +tableOrSubqueryClause + : LT_SQR_PRTHS subSearch RT_SQR_PRTHS (AS alias = qualifiedName)? + | tableSourceClause + ; + tableSourceClause : tableSource (COMMA tableSource)* (AS alias = qualifiedName)? ; // join joinCommand - : (joinType) JOIN sideAlias joinHintList? joinCriteria? right = tableSourceClause + : (joinType) JOIN sideAlias joinHintList? joinCriteria? right = tableOrSubqueryClause ; joinType diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java index 37f6696e0fc..b0800574eb3 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java @@ -10,7 +10,6 @@ import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.DescribeCommandContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.EvalCommandContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.FieldsCommandContext; -import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.FromClauseContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.HeadCommandContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.RareCommandContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.RenameCommandContext; @@ -166,14 +165,15 @@ public UnresolvedPlan visitJoinCommand(OpenSearchPPLParser.JoinCommandContext ct ? Optional.of(internalVisitExpression(ctx.sideAlias().leftAlias).toString()) : Optional.empty(); Optional rightAlias = Optional.empty(); - if (ctx.tableSourceClause().alias != null) { - rightAlias = Optional.of(internalVisitExpression(ctx.tableSourceClause().alias).toString()); + if (ctx.tableOrSubqueryClause().alias != null) { + rightAlias = + Optional.of(internalVisitExpression(ctx.tableOrSubqueryClause().alias).toString()); } if (ctx.sideAlias().rightAlias != null) { rightAlias = Optional.of(internalVisitExpression(ctx.sideAlias().rightAlias).toString()); } - UnresolvedPlan rightRelation = visit(ctx.tableSourceClause()); + UnresolvedPlan rightRelation = visit(ctx.tableOrSubqueryClause()); // Add a SubqueryAlias to the right plan when the right alias is present and no duplicated alias // existing in right. UnresolvedPlan right; @@ -406,11 +406,14 @@ public UnresolvedPlan visitTopCommand(TopCommandContext ctx) { groupList); } - /** From clause. */ @Override - public UnresolvedPlan visitFromClause(FromClauseContext ctx) { - if (ctx.tableFunction() != null) { - return visitTableFunction(ctx.tableFunction()); + public UnresolvedPlan visitTableOrSubqueryClause( + OpenSearchPPLParser.TableOrSubqueryClauseContext ctx) { + if (ctx.subSearch() != null) { + return ctx.alias != null + ? new SubqueryAlias( + internalVisitExpression(ctx.alias).toString(), visitSubSearch(ctx.subSearch())) + : visitSubSearch(ctx.subSearch()); } else { return visitTableSourceClause(ctx.tableSourceClause()); } diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLBasicTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLBasicTest.java index 461b1bc265d..2f6de68d671 100644 --- a/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLBasicTest.java +++ b/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLBasicTest.java @@ -337,4 +337,48 @@ public void testBlockComments() { + " LogicalTableScan(table=[[scott, products_temporal]])\n"; verifyLogical(getRelNode(ppl3), expectedLogical3); } + + @Test + public void testTableAlias() { + String ppl = + "source=EMP as e | where (e.DEPTNO = 20 or e.MGR = 30) and e.SAL > 1000 | fields e.EMPNO," + + " e.ENAME"; + RelNode root = getRelNode(ppl); + String expectedLogical = + "" + + "LogicalProject(EMPNO=[$0], ENAME=[$1])\n" + + " LogicalFilter(condition=[AND(OR(=($7, 20), =($3, 30)), >($5, 1000))])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n"; + verifyLogical(root, expectedLogical); + + String expectedSparkSql = + "" + + "SELECT `EMPNO`, `ENAME`\n" + + "FROM `scott`.`EMP`\n" + + "WHERE (`DEPTNO` = 20 OR `MGR` = 30) AND `SAL` > 1000"; + verifyPPLToSparkSQL(root, expectedSparkSql); + } + + @Test + public void testRelationSubqueryAlias() { + String ppl = "source=EMP as e | join on e.DEPTNO = d.DEPTNO [ source=DEPT | head 10 ] as d"; + RelNode root = getRelNode(ppl); + String expectedLogical = + "" + + "LogicalJoin(condition=[=($7, $8)], joinType=[inner])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n" + + " LogicalSort(fetch=[10])\n" + + " LogicalTableScan(table=[[scott, DEPT]])\n"; + verifyLogical(root, expectedLogical); + verifyResultCount(root, 14); + + String expectedSparkSql = + "" + + "SELECT *\n" + + "FROM `scott`.`EMP`\n" + + "INNER JOIN (SELECT `DEPTNO`, `DNAME`, `LOC`\n" + + "FROM `scott`.`DEPT`\n" + + "LIMIT 10) `t` ON `EMP`.`DEPTNO` = `t`.`DEPTNO`"; + verifyPPLToSparkSQL(root, expectedSparkSql); + } } diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLJoinTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLJoinTest.java index feb8ff1607d..5e0e9d93a62 100644 --- a/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLJoinTest.java +++ b/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLJoinTest.java @@ -407,4 +407,233 @@ public void testMultipleJoinWithSelfJoin() { + "INNER JOIN `scott`.`EMP` `EMP0` ON `EMP`.`DEPTNO` = `EMP0`.`DEPTNO`"; verifyPPLToSparkSQL(root, expectedSparkSql); } + + // +-----------------------------+ + // | join with relation subquery | + // +-----------------------------+ + + @Test + public void testJoinWithRelationSubquery() { + String ppl = + """ + source=EMP | join left = t1 right = t2 ON t1.DEPTNO = t2.DEPTNO + [ + source = DEPT + | where DEPTNO > 10 and LOC = 'CHICAGO' + | fields DEPTNO, DNAME + | sort - DEPTNO + | head 10 + ] + | stats count(MGR) as cnt by JOB + """; + RelNode root = getRelNode(ppl); + String expectedLogical = + "" + + "LogicalAggregate(group=[{2}], cnt=[COUNT($3)])\n" + + " LogicalJoin(condition=[=($7, $8)], joinType=[inner])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n" + + " LogicalSort(sort0=[$0], dir0=[DESC], fetch=[10])\n" + + " LogicalProject(DEPTNO=[$0], DNAME=[$1])\n" + + " LogicalFilter(condition=[AND(>($0, 10), =($2, 'CHICAGO'))])\n" + + " LogicalTableScan(table=[[scott, DEPT]])\n"; + verifyLogical(root, expectedLogical); + String expectedResult = + "" + "JOB=SALESMAN; cnt=4\n" + "JOB=CLERK; cnt=1\n" + "JOB=MANAGER; cnt=1\n"; + verifyResult(root, expectedResult); + + String expectedSparkSql = + "" + + "SELECT `EMP`.`JOB`, COUNT(`EMP`.`MGR`) `cnt`\n" + + "FROM `scott`.`EMP`\n" + + "INNER JOIN (SELECT `DEPTNO`, `DNAME`\n" + + "FROM `scott`.`DEPT`\n" + + "WHERE `DEPTNO` > 10 AND `LOC` = 'CHICAGO'\n" + + "ORDER BY `DEPTNO` DESC NULLS FIRST\n" + + "LIMIT 10) `t1` ON `EMP`.`DEPTNO` = `t1`.`DEPTNO`\n" + + "GROUP BY `EMP`.`JOB`"; + verifyPPLToSparkSQL(root, expectedSparkSql); + } + + @Test + public void testMultipleJoinsWithRelationSubquery() { + String ppl = + """ + source=EMP + | head 10 + | inner join left = l right = r ON l.DEPTNO = r.DEPTNO + [ + source = DEPT + | where DEPTNO > 10 and LOC = 'CHICAGO' + ] + | left join left = l right = r ON l.JOB = r.JOB + [ + source = BONUS + | where JOB = 'SALESMAN' + ] + | cross join left = l right = r + [ + source = SALGRADE + | where LOSAL <= 1500 + | sort - GRADE + ] + """; + RelNode root = getRelNode(ppl); + String expectedLogical = + "" + + "LogicalJoin(condition=[true], joinType=[inner])\n" + + " LogicalJoin(condition=[=($2, $12)], joinType=[left])\n" + + " LogicalJoin(condition=[=($7, $8)], joinType=[inner])\n" + + " LogicalSort(fetch=[10])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n" + + " LogicalFilter(condition=[AND(>($0, 10), =($2, 'CHICAGO'))])\n" + + " LogicalTableScan(table=[[scott, DEPT]])\n" + + " LogicalFilter(condition=[=($1, 'SALESMAN')])\n" + + " LogicalTableScan(table=[[scott, BONUS]])\n" + + " LogicalSort(sort0=[$0], dir0=[DESC])\n" + + " LogicalFilter(condition=[<=($1, 1500)])\n" + + " LogicalTableScan(table=[[scott, SALGRADE]])\n"; + verifyLogical(root, expectedLogical); + verifyResultCount(root, 15); + + String expectedSparkSql = + "" + + "SELECT *\n" + + "FROM (SELECT `EMPNO`, `ENAME`, `JOB`, `MGR`, `HIREDATE`, `SAL`, `COMM`, `DEPTNO`\n" + + "FROM `scott`.`EMP`\n" + + "LIMIT 10) `t`\n" + + "INNER JOIN (SELECT *\n" + + "FROM `scott`.`DEPT`\n" + + "WHERE `DEPTNO` > 10 AND `LOC` = 'CHICAGO') `t0` ON `t`.`DEPTNO` = `t0`.`DEPTNO`\n" + + "LEFT JOIN (SELECT *\n" + + "FROM `scott`.`BONUS`\n" + + "WHERE `JOB` = 'SALESMAN') `t1` ON `t`.`JOB` = `t1`.`JOB`\n" + + "CROSS JOIN (SELECT `GRADE`, `LOSAL`, `HISAL`\n" + + "FROM `scott`.`SALGRADE`\n" + + "WHERE `LOSAL` <= 1500\n" + + "ORDER BY `GRADE` DESC NULLS FIRST) `t3`"; + verifyPPLToSparkSQL(root, expectedSparkSql); + } + + @Test + public void testMultipleJoinsWithRelationSubqueryWithAlias() { + String ppl = + """ + source=EMP as t1 + | head 10 + | inner join ON t1.DEPTNO = t2.DEPTNO + [ + source = DEPT as t2 + | where DEPTNO > 10 and LOC = 'CHICAGO' + ] + | left join ON t1.JOB = t3.JOB + [ + source = BONUS as t3 + | where JOB = 'SALESMAN' + ] + | cross join + [ + source = SALGRADE as t4 + | where LOSAL <= 1500 + | sort - GRADE + ] + """; + + RelNode root = getRelNode(ppl); + String expectedLogical = + "" + + "LogicalJoin(condition=[true], joinType=[inner])\n" + + " LogicalJoin(condition=[=($2, $12)], joinType=[left])\n" + + " LogicalJoin(condition=[=($7, $8)], joinType=[inner])\n" + + " LogicalSort(fetch=[10])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n" + + " LogicalFilter(condition=[AND(>($0, 10), =($2, 'CHICAGO'))])\n" + + " LogicalTableScan(table=[[scott, DEPT]])\n" + + " LogicalFilter(condition=[=($1, 'SALESMAN')])\n" + + " LogicalTableScan(table=[[scott, BONUS]])\n" + + " LogicalSort(sort0=[$0], dir0=[DESC])\n" + + " LogicalFilter(condition=[<=($1, 1500)])\n" + + " LogicalTableScan(table=[[scott, SALGRADE]])\n"; + verifyLogical(root, expectedLogical); + + verifyResultCount(root, 15); + + String expectedSparkSql = + "" + + "SELECT *\n" + + "FROM (SELECT `EMPNO`, `ENAME`, `JOB`, `MGR`, `HIREDATE`, `SAL`, `COMM`, `DEPTNO`\n" + + "FROM `scott`.`EMP`\n" + + "LIMIT 10) `t`\n" + + "INNER JOIN (SELECT *\n" + + "FROM `scott`.`DEPT`\n" + + "WHERE `DEPTNO` > 10 AND `LOC` = 'CHICAGO') `t0` ON `t`.`DEPTNO` = `t0`.`DEPTNO`\n" + + "LEFT JOIN (SELECT *\n" + + "FROM `scott`.`BONUS`\n" + + "WHERE `JOB` = 'SALESMAN') `t1` ON `t`.`JOB` = `t1`.`JOB`\n" + + "CROSS JOIN (SELECT `GRADE`, `LOSAL`, `HISAL`\n" + + "FROM `scott`.`SALGRADE`\n" + + "WHERE `LOSAL` <= 1500\n" + + "ORDER BY `GRADE` DESC NULLS FIRST) `t3`"; + verifyPPLToSparkSQL(root, expectedSparkSql); + } + + @Test + public void testMultipleJoinsWithRelationSubqueryWithAlias2() { + String ppl = + """ + source=EMP as t1 + | head 10 + | inner join left = l right = r ON t1.DEPTNO = t2.DEPTNO + [ + source = DEPT as t2 + | where DEPTNO > 10 and LOC = 'CHICAGO' + ] + | left join left = l right = r ON t1.JOB = t3.JOB + [ + source = BONUS as t3 + | where JOB = 'SALESMAN' + ] + | cross join + [ + source = SALGRADE as t4 + | where LOSAL <= 1500 + | sort - GRADE + ] + """; + RelNode root = getRelNode(ppl); + String expectedLogical = + "" + + "LogicalJoin(condition=[true], joinType=[inner])\n" + + " LogicalJoin(condition=[=($2, $12)], joinType=[left])\n" + + " LogicalJoin(condition=[=($7, $8)], joinType=[inner])\n" + + " LogicalSort(fetch=[10])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n" + + " LogicalFilter(condition=[AND(>($0, 10), =($2, 'CHICAGO'))])\n" + + " LogicalTableScan(table=[[scott, DEPT]])\n" + + " LogicalFilter(condition=[=($1, 'SALESMAN')])\n" + + " LogicalTableScan(table=[[scott, BONUS]])\n" + + " LogicalSort(sort0=[$0], dir0=[DESC])\n" + + " LogicalFilter(condition=[<=($1, 1500)])\n" + + " LogicalTableScan(table=[[scott, SALGRADE]])\n"; + verifyLogical(root, expectedLogical); + + verifyResultCount(root, 15); + + String expectedSparkSql = + "" + + "SELECT *\n" + + "FROM (SELECT `EMPNO`, `ENAME`, `JOB`, `MGR`, `HIREDATE`, `SAL`, `COMM`, `DEPTNO`\n" + + "FROM `scott`.`EMP`\n" + + "LIMIT 10) `t`\n" + + "INNER JOIN (SELECT *\n" + + "FROM `scott`.`DEPT`\n" + + "WHERE `DEPTNO` > 10 AND `LOC` = 'CHICAGO') `t0` ON `t`.`DEPTNO` = `t0`.`DEPTNO`\n" + + "LEFT JOIN (SELECT *\n" + + "FROM `scott`.`BONUS`\n" + + "WHERE `JOB` = 'SALESMAN') `t1` ON `t`.`JOB` = `t1`.`JOB`\n" + + "CROSS JOIN (SELECT `GRADE`, `LOSAL`, `HISAL`\n" + + "FROM `scott`.`SALGRADE`\n" + + "WHERE `LOSAL` <= 1500\n" + + "ORDER BY `GRADE` DESC NULLS FIRST) `t3`"; + verifyPPLToSparkSQL(root, expectedSparkSql); + } }