diff --git a/client/trino-cli/src/main/java/io/trino/cli/lexer/DelimiterLexer.java b/client/trino-cli/src/main/java/io/trino/cli/lexer/DelimiterLexer.java index ccc96dd15044..e4b0ce50aa55 100644 --- a/client/trino-cli/src/main/java/io/trino/cli/lexer/DelimiterLexer.java +++ b/client/trino-cli/src/main/java/io/trino/cli/lexer/DelimiterLexer.java @@ -21,6 +21,7 @@ import org.antlr.v4.runtime.LexerNoViableAltException; import org.antlr.v4.runtime.Token; +import java.util.HashSet; import java.util.Set; /** @@ -30,17 +31,26 @@ * The code in nextToken() is a copy of the implementation in org.antlr.v4.runtime.Lexer, with a * bit added to match the token before the default behavior is invoked. */ -class DelimiterLexer +public class DelimiterLexer extends SqlBaseLexer { private final Set delimiters; + private final boolean useSemicolon; public DelimiterLexer(CharStream input, Set delimiters) { super(input); + delimiters = new HashSet<>(delimiters); + this.useSemicolon = delimiters.remove(";"); this.delimiters = ImmutableSet.copyOf(delimiters); } + public boolean isDelimiter(Token token) + { + return (token.getType() == SqlBaseParser.DELIMITER) || + (useSemicolon && (token.getType() == SEMICOLON)); + } + @Override public Token nextToken() { diff --git a/client/trino-cli/src/main/java/io/trino/cli/lexer/StatementSplitter.java b/client/trino-cli/src/main/java/io/trino/cli/lexer/StatementSplitter.java index 86f922d97488..01b960b63142 100644 --- a/client/trino-cli/src/main/java/io/trino/cli/lexer/StatementSplitter.java +++ b/client/trino-cli/src/main/java/io/trino/cli/lexer/StatementSplitter.java @@ -15,11 +15,17 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; +import io.trino.grammar.sql.SqlBaseBaseVisitor; import io.trino.grammar.sql.SqlBaseLexer; import io.trino.grammar.sql.SqlBaseParser; +import io.trino.grammar.sql.SqlBaseParser.FunctionSpecificationContext; import org.antlr.v4.runtime.CharStreams; +import org.antlr.v4.runtime.CommonTokenStream; +import org.antlr.v4.runtime.ParserRuleContext; import org.antlr.v4.runtime.Token; import org.antlr.v4.runtime.TokenSource; +import org.antlr.v4.runtime.tree.ParseTree; +import org.antlr.v4.runtime.tree.RuleNode; import java.util.List; import java.util.Objects; @@ -39,25 +45,50 @@ public StatementSplitter(String sql) public StatementSplitter(String sql, Set delimiters) { - TokenSource tokens = getLexer(sql, delimiters); + DelimiterLexer lexer = getLexer(sql, delimiters); + CommonTokenStream tokenStream = new CommonTokenStream(lexer); + tokenStream.fill(); + + SqlBaseParser parser = new SqlBaseParser(tokenStream); + parser.removeErrorListeners(); + ImmutableList.Builder list = ImmutableList.builder(); StringBuilder sb = new StringBuilder(); - while (true) { - Token token = tokens.nextToken(); - if (token.getType() == Token.EOF) { - break; - } - if (token.getType() == SqlBaseParser.DELIMITER) { - String statement = sb.toString().trim(); - if (!statement.isEmpty()) { - list.add(new Statement(statement, token.getText())); + int index = 0; + + while (index < tokenStream.size()) { + ParserRuleContext context = parser.statement(); + + if (containsFunction(context)) { + Token stop = context.getStop(); + if ((stop != null) && (stop.getTokenIndex() >= index)) { + int endIndex = stop.getTokenIndex(); + while (index <= endIndex) { + Token token = tokenStream.get(index); + index++; + sb.append(token.getText()); + } } - sb = new StringBuilder(); } - else { + + while (index < tokenStream.size()) { + Token token = tokenStream.get(index); + index++; + if (token.getType() == Token.EOF) { + break; + } + if (lexer.isDelimiter(token)) { + String statement = sb.toString().trim(); + if (!statement.isEmpty()) { + list.add(new Statement(statement, token.getText())); + } + sb = new StringBuilder(); + break; + } sb.append(token.getText()); } } + this.completeStatements = list.build(); this.partialStatement = sb.toString().trim(); } @@ -105,12 +136,42 @@ public static boolean isEmptyStatement(String sql) } } - public static TokenSource getLexer(String sql, Set terminators) + public static DelimiterLexer getLexer(String sql, Set terminators) { requireNonNull(sql, "sql is null"); return new DelimiterLexer(CharStreams.fromString(sql), terminators); } + private static boolean containsFunction(ParseTree tree) + { + return new SqlBaseBaseVisitor() + { + @Override + protected Boolean defaultResult() + { + return false; + } + + @Override + protected Boolean aggregateResult(Boolean aggregate, Boolean nextResult) + { + return aggregate || nextResult; + } + + @Override + protected boolean shouldVisitNextChild(RuleNode node, Boolean currentResult) + { + return !currentResult; + } + + @Override + public Boolean visitFunctionSpecification(FunctionSpecificationContext context) + { + return true; + } + }.visit(tree); + } + public static class Statement { private final String statement; diff --git a/client/trino-cli/src/test/java/io/trino/cli/lexer/TestStatementSplitter.java b/client/trino-cli/src/test/java/io/trino/cli/lexer/TestStatementSplitter.java index 52037fa7d7d7..6b219c1a73cf 100644 --- a/client/trino-cli/src/test/java/io/trino/cli/lexer/TestStatementSplitter.java +++ b/client/trino-cli/src/test/java/io/trino/cli/lexer/TestStatementSplitter.java @@ -284,6 +284,151 @@ public void testSplitterSelectItemsWithoutComma() assertThat(splitter.getPartialStatement()).isEmpty(); } + @Test + public void testSplitterSimpleInlineFunction() + { + String function = "WITH FUNCTION abc() RETURNS int RETURN 42 SELECT abc() FROM t"; + String sql = function + "; SELECT 456;"; + StatementSplitter splitter = new StatementSplitter(sql); + assertThat(splitter.getCompleteStatements()).containsExactly( + statement(function), + statement("SELECT 456")); + assertThat(splitter.getPartialStatement()).isEmpty(); + } + + @Test + public void testSplitterSimpleInlineFunctionWithIncompleteSelect() + { + String function = "WITH FUNCTION abc() RETURNS int RETURN 42 SELECT abc(), FROM t"; + String sql = function + "; SELECT 456;"; + StatementSplitter splitter = new StatementSplitter(sql); + assertThat(splitter.getCompleteStatements()).containsExactly( + statement(function), + statement("SELECT 456")); + assertThat(splitter.getPartialStatement()).isEmpty(); + } + + @Test + public void testSplitterSimpleInlineFunctionWithComments() + { + String function = "/* start */ WITH FUNCTION abc() RETURNS int /* middle */ RETURN 42 SELECT abc() FROM t /* end */"; + String sql = function + "; SELECT 456;"; + StatementSplitter splitter = new StatementSplitter(sql); + assertThat(splitter.getCompleteStatements()).containsExactly( + statement(function), + statement("SELECT 456")); + assertThat(splitter.getPartialStatement()).isEmpty(); + } + + @Test + public void testSplitterCreateFunction() + { + String function = "CREATE FUNCTION fib(n int) RETURNS int BEGIN IF false THEN RETURN 0; END IF; RETURN 1; END"; + String sql = function + "; SELECT 123;"; + StatementSplitter splitter = new StatementSplitter(sql); + assertThat(splitter.getCompleteStatements()).containsExactly( + statement(function), + statement("SELECT 123")); + assertThat(splitter.getPartialStatement()).isEmpty(); + } + + @Test + public void testSplitterCreateFunctionInvalidThen() + { + String function = "CREATE FUNCTION fib(n int) RETURNS int BEGIN IF false THEN oops; END IF; RETURN 1; END"; + String sql = function + "; SELECT 123;"; + StatementSplitter splitter = new StatementSplitter(sql); + assertThat(splitter.getCompleteStatements()).containsExactly( + statement(function), + statement("SELECT 123")); + assertThat(splitter.getPartialStatement()).isEmpty(); + } + + @Test + public void testSplitterCreateFunctionInvalidReturn() + { + String function = "CREATE FUNCTION fib(n int) RETURNS int BEGIN IF false THEN oops; END IF; RETURN 1 xxx; END"; + String sql = function + "; SELECT 123;"; + StatementSplitter splitter = new StatementSplitter(sql); + assertThat(splitter.getCompleteStatements()).containsExactly( + statement(function), + statement("SELECT 123")); + assertThat(splitter.getPartialStatement()).isEmpty(); + } + + @Test + public void testSplitterCreateFunctionInvalidBegin() + { + String function = "CREATE FUNCTION fib(n int) RETURNS int BEGIN xxx IF false THEN oops; END IF; RETURN 1; END"; + String sql = function + "; SELECT 123;"; + StatementSplitter splitter = new StatementSplitter(sql); + assertThat(splitter.getCompleteStatements()).containsExactly( + statement("CREATE FUNCTION fib(n int) RETURNS int BEGIN xxx IF false THEN oops; END IF"), + statement("RETURN 1"), + statement("END"), + statement("SELECT 123")); + assertThat(splitter.getPartialStatement()).isEmpty(); + } + + @Test + public void testSplitterCreateFunctionInvalidDelimitedThen() + { + String function = "CREATE FUNCTION fib(n int) RETURNS int BEGIN IF false THEN; oops; END IF; RETURN 1; END"; + String sql = function + "; SELECT 123;"; + StatementSplitter splitter = new StatementSplitter(sql); + assertThat(splitter.getCompleteStatements()).containsExactly( + statement(function), + statement("SELECT 123")); + assertThat(splitter.getPartialStatement()).isEmpty(); + } + + @Test + public void testSplitterComplexCreateFunction() + { + String function = "" + + "CREATE FUNCTION fib(n bigint)\n" + + "RETURNS bigint\n" + + "BEGIN\n" + + " DECLARE a bigint DEFAULT 1;\n" + + " DECLARE b bigint DEFAULT 1;\n" + + " DECLARE c bigint;\n" + + " IF n <= 2 THEN\n" + + " RETURN 1;\n" + + " END IF;\n" + + " WHILE n > 2 DO\n" + + " SET n = n - 1;\n" + + " SET c = a + b;\n" + + " SET a = b;\n" + + " SET b = c;\n" + + " END WHILE;\n" + + " RETURN c;\n" + + "END"; + String sql = function + ";\nSELECT 123;\nSELECT 456;\n"; + StatementSplitter splitter = new StatementSplitter(sql); + assertThat(splitter.getCompleteStatements()).containsExactly( + statement(function), + statement("SELECT 123"), + statement("SELECT 456")); + assertThat(splitter.getPartialStatement()).isEmpty(); + } + + @Test + public void testSplitterMultipleFunctions() + { + String function1 = "CREATE FUNCTION f1() RETURNS int BEGIN IF false THEN RETURN 0; END IF; RETURN 1; END"; + String function2 = "CREATE FUNCTION f2() RETURNS int BEGIN IF false THEN RETURN 0; END IF; RETURN 1; END"; + String sql = "SELECT 11;" + function1 + ";" + function2 + ";SELECT 22;" + function2 + ";SELECT 33;"; + StatementSplitter splitter = new StatementSplitter(sql); + assertThat(splitter.getCompleteStatements()).containsExactly( + statement("SELECT 11"), + statement(function1), + statement(function2), + statement("SELECT 22"), + statement(function2), + statement("SELECT 33")); + assertThat(splitter.getPartialStatement()).isEmpty(); + } + @Test public void testIsEmptyStatement() { diff --git a/core/trino-grammar/src/main/antlr4/io/trino/grammar/sql/SqlBase.g4 b/core/trino-grammar/src/main/antlr4/io/trino/grammar/sql/SqlBase.g4 index 2fcacc2f4062..0f01927a5023 100644 --- a/core/trino-grammar/src/main/antlr4/io/trino/grammar/sql/SqlBase.g4 +++ b/core/trino-grammar/src/main/antlr4/io/trino/grammar/sql/SqlBase.g4 @@ -40,8 +40,12 @@ standaloneRowPattern : rowPattern EOF ; +standaloneFunctionSpecification + : functionSpecification EOF + ; + statement - : query #statementDefault + : rootQuery #statementDefault | USE schema=identifier #use | USE catalog=identifier '.' schema=identifier #use | CREATE CATALOG (IF NOT EXISTS)? catalog=identifier @@ -60,14 +64,14 @@ statement | CREATE (OR REPLACE)? TABLE (IF NOT EXISTS)? qualifiedName columnAliases? (COMMENT string)? - (WITH properties)? AS (query | '('query')') + (WITH properties)? AS (rootQuery | '('rootQuery')') (WITH (NO)? DATA)? #createTableAsSelect | CREATE (OR REPLACE)? TABLE (IF NOT EXISTS)? qualifiedName '(' tableElement (',' tableElement)* ')' (COMMENT string)? (WITH properties)? #createTable | DROP TABLE (IF EXISTS)? qualifiedName #dropTable - | INSERT INTO qualifiedName columnAliases? query #insertInto + | INSERT INTO qualifiedName columnAliases? rootQuery #insertInto | DELETE FROM qualifiedName (WHERE booleanExpression)? #delete | TRUNCATE TABLE qualifiedName #truncateTable | COMMENT ON TABLE qualifiedName IS (string | NULL) #commentTable @@ -95,10 +99,10 @@ statement (IF NOT EXISTS)? qualifiedName (GRACE PERIOD interval)? (COMMENT string)? - (WITH properties)? AS query #createMaterializedView + (WITH properties)? AS rootQuery #createMaterializedView | CREATE (OR REPLACE)? VIEW qualifiedName (COMMENT string)? - (SECURITY (DEFINER | INVOKER))? AS query #createView + (SECURITY (DEFINER | INVOKER))? AS rootQuery #createView | REFRESH MATERIALIZED VIEW qualifiedName #refreshMaterializedView | DROP MATERIALIZED VIEW (IF EXISTS)? qualifiedName #dropMaterializedView | ALTER MATERIALIZED VIEW (IF EXISTS)? from=qualifiedName @@ -109,6 +113,8 @@ statement | ALTER VIEW from=qualifiedName RENAME TO to=qualifiedName #renameView | ALTER VIEW from=qualifiedName SET AUTHORIZATION principal #setViewAuthorization | CALL qualifiedName '(' (callArgument (',' callArgument)*)? ')' #call + | CREATE (OR REPLACE)? functionSpecification #createFunction + | DROP FUNCTION (IF EXISTS)? functionDeclaration #dropFunction | CREATE ROLE name=identifier (WITH ADMIN grantor)? (IN catalog=identifier)? #createRole @@ -157,7 +163,7 @@ statement | SHOW COLUMNS (FROM | IN) qualifiedName? (LIKE pattern=string (ESCAPE escape=string)?)? #showColumns | SHOW STATS FOR qualifiedName #showStats - | SHOW STATS FOR '(' query ')' #showStatsForQuery + | SHOW STATS FOR '(' rootQuery ')' #showStatsForQuery | SHOW CURRENT? ROLES ((FROM | IN) identifier)? #showRoles | SHOW ROLE GRANTS ((FROM | IN) identifier)? #showRoleGrants | DESCRIBE qualifiedName #showColumns @@ -188,8 +194,16 @@ statement USING relation ON expression mergeCase+ #merge ; +rootQuery + : withFunction? query + ; + +withFunction + : WITH functionSpecification (',' functionSpecification)* + ; + query - : with? queryNoWith + : with? queryNoWith ; with @@ -841,6 +855,65 @@ pathSpecification : pathElement (',' pathElement)* ; +functionSpecification + : FUNCTION functionDeclaration returnsClause routineCharacteristic* controlStatement + ; + +functionDeclaration + : qualifiedName '(' (parameterDeclaration (',' parameterDeclaration)*)? ')' + ; + +parameterDeclaration + : identifier? type + ; + +returnsClause + : RETURNS type + ; + +routineCharacteristic + : LANGUAGE identifier #languageCharacteristic + | NOT? DETERMINISTIC #deterministicCharacteristic + | RETURNS NULL ON NULL INPUT #returnsNullOnNullInputCharacteristic + | CALLED ON NULL INPUT #calledOnNullInputCharacteristic + | SECURITY (DEFINER | INVOKER) #securityCharacteristic + | COMMENT string #commentCharacteristic + ; + +controlStatement + : RETURN valueExpression #returnStatement + | SET identifier EQ expression #assignmentStatement + | CASE expression caseStatementWhenClause+ elseClause? END CASE #simpleCaseStatement + | CASE caseStatementWhenClause+ elseClause? END CASE #searchedCaseStatement + | IF expression THEN sqlStatementList elseIfClause* elseClause? END IF #ifStatement + | ITERATE identifier #iterateStatement + | LEAVE identifier #leaveStatement + | BEGIN (variableDeclaration SEMICOLON)* sqlStatementList? END #compoundStatement + | (label=identifier ':')? LOOP sqlStatementList END LOOP #loopStatement + | (label=identifier ':')? WHILE expression DO sqlStatementList END WHILE #whileStatement + | (label=identifier ':')? REPEAT sqlStatementList UNTIL expression END REPEAT #repeatStatement + ; + +caseStatementWhenClause + : WHEN expression THEN sqlStatementList + ; + +elseIfClause + : ELSEIF expression THEN sqlStatementList + ; + +elseClause + : ELSE sqlStatementList + ; + +variableDeclaration + : DECLARE identifier (',' identifier)* type (DEFAULT valueExpression)? + ; + +sqlStatementList + : (controlStatement SEMICOLON)+ + ; + privilege : CREATE | SELECT | DELETE | INSERT | UPDATE ; @@ -896,29 +969,29 @@ authorizationUser nonReserved // IMPORTANT: this rule must only contain tokens. Nested rules are not supported. See SqlParser.exitNonReserved : ABSENT | ADD | ADMIN | AFTER | ALL | ANALYZE | ANY | ARRAY | ASC | AT | AUTHORIZATION - | BERNOULLI | BOTH - | CALL | CASCADE | CATALOG | CATALOGS | COLUMN | COLUMNS | COMMENT | COMMIT | COMMITTED | CONDITIONAL | COPARTITION | COUNT | CURRENT - | DATA | DATE | DAY | DEFAULT | DEFINE | DEFINER | DENY | DESC | DESCRIPTOR | DISTRIBUTED | DOUBLE - | EMPTY | ENCODING | ERROR | EXCLUDING | EXPLAIN - | FETCH | FILTER | FINAL | FIRST | FOLLOWING | FORMAT | FUNCTIONS + | BEGIN | BERNOULLI | BOTH + | CALL | CALLED | CASCADE | CATALOG | CATALOGS | COLUMN | COLUMNS | COMMENT | COMMIT | COMMITTED | CONDITIONAL | COPARTITION | COUNT | CURRENT + | DATA | DATE | DAY | DECLARE | DEFAULT | DEFINE | DEFINER | DENY | DESC | DESCRIPTOR | DETERMINISTIC | DISTRIBUTED | DO | DOUBLE + | ELSEIF | EMPTY | ENCODING | ERROR | EXCLUDING | EXPLAIN + | FETCH | FILTER | FINAL | FIRST | FOLLOWING | FORMAT | FUNCTION | FUNCTIONS | GRACE | GRANT | GRANTED | GRANTS | GRAPHVIZ | GROUPS | HOUR - | IF | IGNORE | IMMEDIATE | INCLUDING | INITIAL | INPUT | INTERVAL | INVOKER | IO | ISOLATION + | IF | IGNORE | IMMEDIATE | INCLUDING | INITIAL | INPUT | INTERVAL | INVOKER | IO | ITERATE | ISOLATION | JSON | KEEP | KEY | KEYS - | LAST | LATERAL | LEADING | LEVEL | LIMIT | LOCAL | LOGICAL + | LANGUAGE | LAST | LATERAL | LEADING | LEAVE | LEVEL | LIMIT | LOCAL | LOGICAL | LOOP | MAP | MATCH | MATCHED | MATCHES | MATCH_RECOGNIZE | MATERIALIZED | MEASURES | MERGE | MINUTE | MONTH | NESTED | NEXT | NFC | NFD | NFKC | NFKD | NO | NONE | NULLIF | NULLS | OBJECT | OF | OFFSET | OMIT | ONE | ONLY | OPTION | ORDINALITY | OUTPUT | OVER | OVERFLOW | PARTITION | PARTITIONS | PASSING | PAST | PATH | PATTERN | PER | PERIOD | PERMUTE | PLAN | POSITION | PRECEDING | PRECISION | PRIVILEGES | PROPERTIES | PRUNE | QUOTES - | RANGE | READ | REFRESH | RENAME | REPEATABLE | REPLACE | RESET | RESPECT | RESTRICT | RETURNING | REVOKE | ROLE | ROLES | ROLLBACK | ROW | ROWS | RUNNING + | RANGE | READ | REFRESH | RENAME | REPEAT | REPEATABLE | REPLACE | RESET | RESPECT | RESTRICT | RETURN | RETURNING | RETURNS | REVOKE | ROLE | ROLES | ROLLBACK | ROW | ROWS | RUNNING | SCALAR | SCHEMA | SCHEMAS | SECOND | SECURITY | SEEK | SERIALIZABLE | SESSION | SET | SETS | SHOW | SOME | START | STATS | SUBSET | SUBSTRING | SYSTEM | TABLES | TABLESAMPLE | TEXT | TEXT_STRING | TIES | TIME | TIMESTAMP | TO | TRAILING | TRANSACTION | TRUNCATE | TRY_CAST | TYPE - | UNBOUNDED | UNCOMMITTED | UNCONDITIONAL | UNIQUE | UNKNOWN | UNMATCHED | UPDATE | USE | USER | UTF16 | UTF32 | UTF8 + | UNBOUNDED | UNCOMMITTED | UNCONDITIONAL | UNIQUE | UNKNOWN | UNMATCHED | UNTIL | UPDATE | USE | USER | UTF16 | UTF32 | UTF8 | VALIDATE | VALUE | VERBOSE | VERSION | VIEW - | WINDOW | WITHIN | WITHOUT | WORK | WRAPPER | WRITE + | WHILE | WINDOW | WITHIN | WITHOUT | WORK | WRAPPER | WRITE | YEAR | ZONE ; @@ -937,11 +1010,13 @@ AS: 'AS'; ASC: 'ASC'; AT: 'AT'; AUTHORIZATION: 'AUTHORIZATION'; +BEGIN: 'BEGIN'; BERNOULLI: 'BERNOULLI'; BETWEEN: 'BETWEEN'; BOTH: 'BOTH'; BY: 'BY'; CALL: 'CALL'; +CALLED: 'CALLED'; CASCADE: 'CASCADE'; CASE: 'CASE'; CAST: 'CAST'; @@ -972,6 +1047,7 @@ DATA: 'DATA'; DATE: 'DATE'; DAY: 'DAY'; DEALLOCATE: 'DEALLOCATE'; +DECLARE: 'DECLARE'; DEFAULT: 'DEFAULT'; DEFINE: 'DEFINE'; DEFINER: 'DEFINER'; @@ -980,12 +1056,15 @@ DENY: 'DENY'; DESC: 'DESC'; DESCRIBE: 'DESCRIBE'; DESCRIPTOR: 'DESCRIPTOR'; +DETERMINISTIC: 'DETERMINISTIC'; DISTINCT: 'DISTINCT'; DISTRIBUTED: 'DISTRIBUTED'; +DO: 'DO'; DOUBLE: 'DOUBLE'; DROP: 'DROP'; ELSE: 'ELSE'; EMPTY: 'EMPTY'; +ELSEIF: 'ELSEIF'; ENCODING: 'ENCODING'; END: 'END'; ERROR: 'ERROR'; @@ -1006,6 +1085,7 @@ FOR: 'FOR'; FORMAT: 'FORMAT'; FROM: 'FROM'; FULL: 'FULL'; +FUNCTION: 'FUNCTION'; FUNCTIONS: 'FUNCTIONS'; GRACE: 'GRACE'; GRANT: 'GRANT'; @@ -1033,6 +1113,7 @@ INVOKER: 'INVOKER'; IO: 'IO'; IS: 'IS'; ISOLATION: 'ISOLATION'; +ITERATE: 'ITERATE'; JOIN: 'JOIN'; JSON: 'JSON'; JSON_ARRAY: 'JSON_ARRAY'; @@ -1044,9 +1125,11 @@ JSON_VALUE: 'JSON_VALUE'; KEEP: 'KEEP'; KEY: 'KEY'; KEYS: 'KEYS'; +LANGUAGE: 'LANGUAGE'; LAST: 'LAST'; LATERAL: 'LATERAL'; LEADING: 'LEADING'; +LEAVE: 'LEAVE'; LEFT: 'LEFT'; LEVEL: 'LEVEL'; LIKE: 'LIKE'; @@ -1056,6 +1139,7 @@ LOCAL: 'LOCAL'; LOCALTIME: 'LOCALTIME'; LOCALTIMESTAMP: 'LOCALTIMESTAMP'; LOGICAL: 'LOGICAL'; +LOOP: 'LOOP'; MAP: 'MAP'; MATCH: 'MATCH'; MATCHED: 'MATCHED'; @@ -1118,12 +1202,15 @@ READ: 'READ'; RECURSIVE: 'RECURSIVE'; REFRESH: 'REFRESH'; RENAME: 'RENAME'; +REPEAT: 'REPEAT'; REPEATABLE: 'REPEATABLE'; REPLACE: 'REPLACE'; RESET: 'RESET'; RESPECT: 'RESPECT'; RESTRICT: 'RESTRICT'; +RETURN: 'RETURN'; RETURNING: 'RETURNING'; +RETURNS: 'RETURNS'; REVOKE: 'REVOKE'; RIGHT: 'RIGHT'; ROLE: 'ROLE'; @@ -1177,6 +1264,7 @@ UNIQUE: 'UNIQUE'; UNKNOWN: 'UNKNOWN'; UNMATCHED: 'UNMATCHED'; UNNEST: 'UNNEST'; +UNTIL: 'UNTIL'; UPDATE: 'UPDATE'; USE: 'USE'; USER: 'USER'; @@ -1192,6 +1280,7 @@ VERSION: 'VERSION'; VIEW: 'VIEW'; WHEN: 'WHEN'; WHERE: 'WHERE'; +WHILE: 'WHILE'; WINDOW: 'WINDOW'; WITH: 'WITH'; WITHIN: 'WITHIN'; @@ -1216,6 +1305,7 @@ SLASH: '/'; PERCENT: '%'; CONCAT: '||'; QUESTION_MARK: '?'; +SEMICOLON: ';'; STRING : '\'' ( ~'\'' | '\'\'' )* '\'' diff --git a/core/trino-grammar/src/test/java/io/trino/grammar/sql/TestSqlKeywords.java b/core/trino-grammar/src/test/java/io/trino/grammar/sql/TestSqlKeywords.java index fb0880b300ae..5249453a081e 100644 --- a/core/trino-grammar/src/test/java/io/trino/grammar/sql/TestSqlKeywords.java +++ b/core/trino-grammar/src/test/java/io/trino/grammar/sql/TestSqlKeywords.java @@ -16,6 +16,7 @@ import com.google.common.collect.ImmutableSet; import org.junit.jupiter.api.Test; +import static com.google.common.collect.ImmutableSet.toImmutableSet; import static org.assertj.core.api.Assertions.assertThat; public class TestSqlKeywords @@ -23,7 +24,7 @@ public class TestSqlKeywords @Test public void test() { - assertThat(SqlKeywords.sqlKeywords()) + assertThat(SqlKeywords.sqlKeywords().stream().sorted().collect(toImmutableSet())) .isEqualTo(ImmutableSet.of( "ABSENT", "ADD", @@ -39,11 +40,13 @@ public void test() "ASC", "AT", "AUTHORIZATION", + "BEGIN", "BERNOULLI", "BETWEEN", "BOTH", "BY", "CALL", + "CALLED", "CASCADE", "CASE", "CAST", @@ -74,6 +77,7 @@ public void test() "DATE", "DAY", "DEALLOCATE", + "DECLARE", "DEFAULT", "DEFINE", "DEFINER", @@ -82,11 +86,14 @@ public void test() "DESC", "DESCRIBE", "DESCRIPTOR", + "DETERMINISTIC", "DISTINCT", "DISTRIBUTED", + "DO", "DOUBLE", "DROP", "ELSE", + "ELSEIF", "EMPTY", "ENCODING", "END", @@ -108,6 +115,7 @@ public void test() "FORMAT", "FROM", "FULL", + "FUNCTION", "FUNCTIONS", "GRACE", "GRANT", @@ -135,6 +143,7 @@ public void test() "IO", "IS", "ISOLATION", + "ITERATE", "JOIN", "JSON", "JSON_ARRAY", @@ -146,9 +155,11 @@ public void test() "KEEP", "KEY", "KEYS", + "LANGUAGE", "LAST", "LATERAL", "LEADING", + "LEAVE", "LEFT", "LEVEL", "LIKE", @@ -158,6 +169,7 @@ public void test() "LOCALTIME", "LOCALTIMESTAMP", "LOGICAL", + "LOOP", "MAP", "MATCH", "MATCHED", @@ -220,12 +232,15 @@ public void test() "RECURSIVE", "REFRESH", "RENAME", + "REPEAT", "REPEATABLE", "REPLACE", "RESET", "RESPECT", "RESTRICT", + "RETURN", "RETURNING", + "RETURNS", "REVOKE", "RIGHT", "ROLE", @@ -280,6 +295,7 @@ public void test() "UNKNOWN", "UNMATCHED", "UNNEST", + "UNTIL", "UPDATE", "USE", "USER", @@ -292,6 +308,7 @@ public void test() "VIEW", "WHEN", "WHERE", + "WHILE", "WINDOW", "WITH", "WITHIN", diff --git a/core/trino-main/src/main/java/io/trino/sql/analyzer/Analysis.java b/core/trino-main/src/main/java/io/trino/sql/analyzer/Analysis.java index 85ecd972178b..d015fa71e734 100644 --- a/core/trino-main/src/main/java/io/trino/sql/analyzer/Analysis.java +++ b/core/trino-main/src/main/java/io/trino/sql/analyzer/Analysis.java @@ -648,6 +648,13 @@ public void registerTable( columnMaskScopes.isEmpty())); } + public Set getResolvedFunctions() + { + return resolvedFunctions.values().stream() + .map(RoutineEntry::getFunction) + .collect(toImmutableSet()); + } + public ResolvedFunction getResolvedFunction(Expression node) { return resolvedFunctions.get(NodeRef.of(node)).getFunction(); @@ -680,6 +687,11 @@ public boolean isColumnReference(Expression expression) return columnReferences.containsKey(NodeRef.of(expression)); } + public void addType(Expression expression, Type type) + { + this.types.put(NodeRef.of(expression), type); + } + public void addTypes(Map, Type> types) { this.types.putAll(types); diff --git a/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java b/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java index afdcfa920950..e0753191590a 100644 --- a/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java +++ b/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java @@ -1517,6 +1517,8 @@ protected Scope visitExplainAnalyze(ExplainAnalyze node, Optional scope) @Override protected Scope visitQuery(Query node, Optional scope) { + verify(node.getFunctions().isEmpty(), "Inline functions not yet supported"); + Scope withScope = analyzeWith(node, scope); Scope queryBodyScope = process(node.getQueryBody(), withScope); diff --git a/core/trino-main/src/main/java/io/trino/sql/gen/LambdaBytecodeGenerator.java b/core/trino-main/src/main/java/io/trino/sql/gen/LambdaBytecodeGenerator.java index 2cb355d0523a..55b0cc4ab935 100644 --- a/core/trino-main/src/main/java/io/trino/sql/gen/LambdaBytecodeGenerator.java +++ b/core/trino-main/src/main/java/io/trino/sql/gen/LambdaBytecodeGenerator.java @@ -359,7 +359,7 @@ public BytecodeNode visitVariableReference(VariableReferenceExpression reference }; } - static class CompiledLambda + public static class CompiledLambda { // lambda method information private final Handle lambdaAsmHandle; diff --git a/core/trino-main/src/main/java/io/trino/sql/gen/RowExpressionCompiler.java b/core/trino-main/src/main/java/io/trino/sql/gen/RowExpressionCompiler.java index 5c58fdf2045f..13845182a816 100644 --- a/core/trino-main/src/main/java/io/trino/sql/gen/RowExpressionCompiler.java +++ b/core/trino-main/src/main/java/io/trino/sql/gen/RowExpressionCompiler.java @@ -51,7 +51,7 @@ public class RowExpressionCompiler private final FunctionManager functionManager; private final Map compiledLambdaMap; - RowExpressionCompiler( + public RowExpressionCompiler( CallSiteBinder callSiteBinder, CachedInstanceBinder cachedInstanceBinder, RowExpressionVisitor fieldReferenceCompiler, diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/LogicalPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/LogicalPlanner.java index cd8981cc43e0..edf827db7fa2 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/LogicalPlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/LogicalPlanner.java @@ -909,7 +909,7 @@ private RelationPlanner getRelationPlanner(Analysis analysis) return new RelationPlanner(analysis, symbolAllocator, idAllocator, buildLambdaDeclarationToSymbolMap(analysis, symbolAllocator), plannerContext, Optional.empty(), session, ImmutableMap.of()); } - private static Map, Symbol> buildLambdaDeclarationToSymbolMap(Analysis analysis, SymbolAllocator symbolAllocator) + public static Map, Symbol> buildLambdaDeclarationToSymbolMap(Analysis analysis, SymbolAllocator symbolAllocator) { Map allocations = new HashMap<>(); Map, Symbol> result = new LinkedHashMap<>(); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/TranslationMap.java b/core/trino-main/src/main/java/io/trino/sql/planner/TranslationMap.java index 55cc935ed77a..9e5ec42b8aa4 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/TranslationMap.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/TranslationMap.java @@ -117,7 +117,7 @@ *
  • AST expressions contain Identifiers, while IR expressions contain SymbolReferences
  • *
  • FunctionCalls in AST expressions are SQL function names. In IR expressions, they contain an encoded name representing a resolved function
  • */ -class TranslationMap +public class TranslationMap { // all expressions are rewritten in terms of fields declared by this relation plan private final Scope scope; diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/sanity/SugarFreeChecker.java b/core/trino-main/src/main/java/io/trino/sql/planner/sanity/SugarFreeChecker.java index bd5fd4f3056e..f39495c6b0d4 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/sanity/SugarFreeChecker.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/sanity/SugarFreeChecker.java @@ -57,7 +57,7 @@ public void validate(PlanNode planNode, ExpressionExtractor.forEachExpression(planNode, SugarFreeChecker::validate); } - private static void validate(Expression expression) + public static void validate(Expression expression) { VISITOR.process(expression, null); } diff --git a/core/trino-main/src/main/java/io/trino/sql/relational/SqlToRowExpressionTranslator.java b/core/trino-main/src/main/java/io/trino/sql/relational/SqlToRowExpressionTranslator.java index 2fea1165a359..8fd3e2efd5b8 100644 --- a/core/trino-main/src/main/java/io/trino/sql/relational/SqlToRowExpressionTranslator.java +++ b/core/trino-main/src/main/java/io/trino/sql/relational/SqlToRowExpressionTranslator.java @@ -151,7 +151,7 @@ public static RowExpression translate( return result; } - private static class Visitor + public static class Visitor extends AstVisitor { private final Metadata metadata; @@ -159,7 +159,7 @@ private static class Visitor private final Map layout; private final StandardFunctionResolution standardFunctionResolution; - private Visitor( + protected Visitor( Metadata metadata, Map, Type> types, Map layout) diff --git a/core/trino-main/src/main/java/io/trino/sql/routine/SqlRoutineAnalysis.java b/core/trino-main/src/main/java/io/trino/sql/routine/SqlRoutineAnalysis.java new file mode 100644 index 000000000000..2c27383efd95 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/sql/routine/SqlRoutineAnalysis.java @@ -0,0 +1,42 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.routine; + +import com.google.common.collect.ImmutableMap; +import io.trino.spi.type.Type; +import io.trino.sql.analyzer.Analysis; + +import java.util.Map; +import java.util.Optional; + +import static java.util.Objects.requireNonNull; + +public record SqlRoutineAnalysis( + String name, + Map arguments, + Type returnType, + boolean calledOnNull, + boolean deterministic, + Optional comment, + Analysis analysis) +{ + public SqlRoutineAnalysis + { + requireNonNull(name, "name is null"); + arguments = ImmutableMap.copyOf(requireNonNull(arguments, "arguments is null")); + requireNonNull(returnType, "returnType is null"); + requireNonNull(comment, "comment is null"); + requireNonNull(analysis, "analysis is null"); + } +} diff --git a/core/trino-main/src/main/java/io/trino/sql/routine/SqlRoutineAnalyzer.java b/core/trino-main/src/main/java/io/trino/sql/routine/SqlRoutineAnalyzer.java new file mode 100644 index 000000000000..ab6280785dae --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/sql/routine/SqlRoutineAnalyzer.java @@ -0,0 +1,594 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.routine; + +import com.google.common.collect.ImmutableMap; +import io.trino.Session; +import io.trino.execution.warnings.WarningCollector; +import io.trino.metadata.ResolvedFunction; +import io.trino.security.AccessControl; +import io.trino.spi.function.FunctionId; +import io.trino.spi.function.FunctionMetadata; +import io.trino.spi.function.Signature; +import io.trino.spi.type.Type; +import io.trino.spi.type.TypeNotFoundException; +import io.trino.sql.PlannerContext; +import io.trino.sql.analyzer.Analysis; +import io.trino.sql.analyzer.CorrelationSupport; +import io.trino.sql.analyzer.ExpressionAnalyzer; +import io.trino.sql.analyzer.Field; +import io.trino.sql.analyzer.QueryType; +import io.trino.sql.analyzer.RelationId; +import io.trino.sql.analyzer.RelationType; +import io.trino.sql.analyzer.Scope; +import io.trino.sql.analyzer.TypeSignatureTranslator; +import io.trino.sql.tree.AssignmentStatement; +import io.trino.sql.tree.AstVisitor; +import io.trino.sql.tree.CaseStatement; +import io.trino.sql.tree.CaseStatementWhenClause; +import io.trino.sql.tree.CommentCharacteristic; +import io.trino.sql.tree.CompoundStatement; +import io.trino.sql.tree.ControlStatement; +import io.trino.sql.tree.DataType; +import io.trino.sql.tree.DeterministicCharacteristic; +import io.trino.sql.tree.ElseClause; +import io.trino.sql.tree.ElseIfClause; +import io.trino.sql.tree.Expression; +import io.trino.sql.tree.FunctionSpecification; +import io.trino.sql.tree.Identifier; +import io.trino.sql.tree.IfStatement; +import io.trino.sql.tree.IterateStatement; +import io.trino.sql.tree.LanguageCharacteristic; +import io.trino.sql.tree.LeaveStatement; +import io.trino.sql.tree.LoopStatement; +import io.trino.sql.tree.Node; +import io.trino.sql.tree.NullInputCharacteristic; +import io.trino.sql.tree.ParameterDeclaration; +import io.trino.sql.tree.RepeatStatement; +import io.trino.sql.tree.ReturnStatement; +import io.trino.sql.tree.ReturnsClause; +import io.trino.sql.tree.SecurityCharacteristic; +import io.trino.sql.tree.SecurityCharacteristic.Security; +import io.trino.sql.tree.VariableDeclaration; +import io.trino.sql.tree.WhileStatement; +import io.trino.type.TypeCoercion; + +import java.util.LinkedHashMap; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.Iterables.getLast; +import static io.trino.spi.StandardErrorCode.ALREADY_EXISTS; +import static io.trino.spi.StandardErrorCode.INVALID_ARGUMENTS; +import static io.trino.spi.StandardErrorCode.MISSING_RETURN; +import static io.trino.spi.StandardErrorCode.NOT_FOUND; +import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED; +import static io.trino.spi.StandardErrorCode.SYNTAX_ERROR; +import static io.trino.spi.StandardErrorCode.TYPE_MISMATCH; +import static io.trino.spi.type.BooleanType.BOOLEAN; +import static io.trino.sql.analyzer.SemanticExceptions.semanticException; +import static io.trino.sql.analyzer.TypeSignatureTranslator.toTypeSignature; +import static java.lang.String.format; +import static java.util.Collections.nCopies; +import static java.util.Objects.requireNonNull; +import static java.util.function.Predicate.not; + +public class SqlRoutineAnalyzer +{ + private final PlannerContext plannerContext; + private final WarningCollector warningCollector; + + public SqlRoutineAnalyzer(PlannerContext plannerContext, WarningCollector warningCollector) + { + this.plannerContext = requireNonNull(plannerContext, "plannerContext is null"); + this.warningCollector = requireNonNull(warningCollector, "warningCollector is null"); + } + + public static FunctionMetadata extractFunctionMetadata(FunctionId functionId, FunctionSpecification function) + { + validateLanguage(function); + validateReturn(function); + + String functionName = getFunctionName(function); + Signature.Builder signatureBuilder = Signature.builder() + .returnType(toTypeSignature(function.getReturnsClause().getReturnType())); + + validateArguments(function); + function.getParameters().stream() + .map(ParameterDeclaration::getType) + .map(TypeSignatureTranslator::toTypeSignature) + .forEach(signatureBuilder::argumentType); + Signature signature = signatureBuilder.build(); + + FunctionMetadata.Builder builder = FunctionMetadata.scalarBuilder(functionName) + .functionId(functionId) + .signature(signature) + .nullable() + .argumentNullability(nCopies(signature.getArgumentTypes().size(), isCalledOnNull(function))); + + getComment(function) + .filter(not(String::isBlank)) + .ifPresentOrElse(builder::description, builder::noDescription); + + if (!getDeterministic(function).orElse(true)) { + builder.nondeterministic(); + } + + validateSecurity(function); + + return builder.build(); + } + + public SqlRoutineAnalysis analyze(Session session, AccessControl accessControl, FunctionSpecification function) + { + String functionName = getFunctionName(function); + + validateLanguage(function); + + boolean calledOnNull = isCalledOnNull(function); + Optional comment = getComment(function); + validateSecurity(function); + + ReturnsClause returnsClause = function.getReturnsClause(); + Type returnType = getType(returnsClause, returnsClause.getReturnType()); + + Map arguments = getArguments(function); + + validateReturn(function); + + StatementVisitor visitor = new StatementVisitor(session, accessControl, returnType); + visitor.process(function.getStatement(), new Context(arguments, Set.of())); + + Analysis analysis = visitor.getAnalysis(); + + boolean actuallyDeterministic = analysis.getResolvedFunctions().stream().allMatch(ResolvedFunction::isDeterministic); + + boolean declaredDeterministic = getDeterministic(function).orElse(true); + if (!declaredDeterministic && actuallyDeterministic) { + throw semanticException(INVALID_ARGUMENTS, function, "Deterministic function declared NOT DETERMINISTIC"); + } + if (declaredDeterministic && !actuallyDeterministic) { + throw semanticException(INVALID_ARGUMENTS, function, "Non-deterministic function declared DETERMINISTIC"); + } + + return new SqlRoutineAnalysis( + functionName, + arguments, + returnType, + calledOnNull, + actuallyDeterministic, + comment, + visitor.getAnalysis()); + } + + private static String getFunctionName(FunctionSpecification function) + { + String name = function.getName().getSuffix(); + if (name.contains("@") || name.contains("$")) { + throw semanticException(NOT_SUPPORTED, function, "Function name cannot contain '@' or '$'"); + } + return name; + } + + private Type getType(Node node, DataType type) + { + try { + return plannerContext.getTypeManager().getType(toTypeSignature(type)); + } + catch (TypeNotFoundException e) { + throw semanticException(TYPE_MISMATCH, node, "Unknown type: " + type); + } + } + + private Map getArguments(FunctionSpecification function) + { + validateArguments(function); + + Map arguments = new LinkedHashMap<>(); + for (ParameterDeclaration parameter : function.getParameters()) { + arguments.put( + identifierValue(parameter.getName().orElseThrow()), + getType(parameter, parameter.getType())); + } + return arguments; + } + + private static void validateArguments(FunctionSpecification function) + { + Set argumentNames = new LinkedHashSet<>(); + for (ParameterDeclaration parameter : function.getParameters()) { + if (parameter.getName().isEmpty()) { + throw semanticException(INVALID_ARGUMENTS, parameter, "Function parameters must have a name"); + } + String name = identifierValue(parameter.getName().get()); + if (!argumentNames.add(name)) { + throw semanticException(INVALID_ARGUMENTS, parameter, "Duplicate function parameter name: " + name); + } + } + } + + private static Optional getLanguage(FunctionSpecification function) + { + List language = function.getRoutineCharacteristics().stream() + .filter(LanguageCharacteristic.class::isInstance) + .map(LanguageCharacteristic.class::cast) + .collect(toImmutableList()); + + if (language.size() > 1) { + throw semanticException(SYNTAX_ERROR, function, "Multiple language clauses specified"); + } + + return language.stream() + .map(LanguageCharacteristic::getLanguage) + .map(Identifier::getValue) + .findAny(); + } + + private static void validateLanguage(FunctionSpecification function) + { + Optional language = getLanguage(function); + if (language.isPresent() && !language.get().equalsIgnoreCase("sql")) { + throw semanticException(NOT_SUPPORTED, function, "Unsupported language: %s", language.get()); + } + } + + private static Optional getDeterministic(FunctionSpecification function) + { + List deterministic = function.getRoutineCharacteristics().stream() + .filter(DeterministicCharacteristic.class::isInstance) + .map(DeterministicCharacteristic.class::cast) + .collect(toImmutableList()); + + if (deterministic.size() > 1) { + throw semanticException(SYNTAX_ERROR, function, "Multiple deterministic clauses specified"); + } + + return deterministic.stream() + .map(DeterministicCharacteristic::isDeterministic) + .findAny(); + } + + private static boolean isCalledOnNull(FunctionSpecification function) + { + List nullInput = function.getRoutineCharacteristics().stream() + .filter(NullInputCharacteristic.class::isInstance) + .map(NullInputCharacteristic.class::cast) + .collect(toImmutableList()); + + if (nullInput.size() > 1) { + throw semanticException(SYNTAX_ERROR, function, "Multiple null-call clauses specified"); + } + + return nullInput.stream() + .map(NullInputCharacteristic::isCalledOnNull) + .findAny() + .orElse(true); + } + + public static boolean isRunAsInvoker(FunctionSpecification function) + { + List security = function.getRoutineCharacteristics().stream() + .filter(SecurityCharacteristic.class::isInstance) + .map(SecurityCharacteristic.class::cast) + .collect(toImmutableList()); + + if (security.size() > 1) { + throw semanticException(SYNTAX_ERROR, function, "Multiple security clauses specified"); + } + + return security.stream() + .map(SecurityCharacteristic::getSecurity) + .map(Security.INVOKER::equals) + .findAny() + .orElse(false); + } + + private static void validateSecurity(FunctionSpecification function) + { + isRunAsInvoker(function); + } + + private static Optional getComment(FunctionSpecification function) + { + List comment = function.getRoutineCharacteristics().stream() + .filter(CommentCharacteristic.class::isInstance) + .map(CommentCharacteristic.class::cast) + .collect(toImmutableList()); + + if (comment.size() > 1) { + throw semanticException(SYNTAX_ERROR, function, "Multiple comment clauses specified"); + } + + return comment.stream() + .map(CommentCharacteristic::getComment) + .findAny(); + } + + private static void validateReturn(FunctionSpecification function) + { + ControlStatement statement = function.getStatement(); + if (statement instanceof ReturnStatement) { + return; + } + + checkArgument(statement instanceof CompoundStatement, "invalid function statement: %s", statement); + CompoundStatement body = (CompoundStatement) statement; + if (!(getLast(body.getStatements(), null) instanceof ReturnStatement)) { + throw semanticException(MISSING_RETURN, body, "Function must end in a RETURN statement"); + } + } + + private class StatementVisitor + extends AstVisitor + { + private final Session session; + private final AccessControl accessControl; + private final Type returnType; + + private final Analysis analysis = new Analysis(null, ImmutableMap.of(), QueryType.OTHERS); + private final TypeCoercion typeCoercion = new TypeCoercion(plannerContext.getTypeManager()::getType); + + public StatementVisitor(Session session, AccessControl accessControl, Type returnType) + { + this.session = requireNonNull(session, "session is null"); + this.accessControl = requireNonNull(accessControl, "accessControl is null"); + this.returnType = requireNonNull(returnType, "returnType is null"); + } + + public Analysis getAnalysis() + { + return analysis; + } + + @Override + protected Void visitNode(Node node, Context context) + { + throw new UnsupportedOperationException("Analysis not yet implemented: " + node); + } + + @Override + protected Void visitCompoundStatement(CompoundStatement node, Context context) + { + Context newContext = context.newScope(); + + for (VariableDeclaration declaration : node.getVariableDeclarations()) { + Type type = getType(declaration, declaration.getType()); + analysis.addType(declaration.getType(), type); + declaration.getDefaultValue().ifPresent(value -> + analyzeExpression(newContext, value, type, "Value of DEFAULT")); + + for (Identifier name : declaration.getNames()) { + if (newContext.variables().put(identifierValue(name), type) != null) { + throw semanticException(ALREADY_EXISTS, name, "Variable already declared in this scope: %s", name); + } + } + } + + analyzeNodes(newContext, node.getStatements()); + + return null; + } + + @Override + protected Void visitIfStatement(IfStatement node, Context context) + { + analyzeExpression(context, node.getExpression(), BOOLEAN, "Condition of IF statement"); + analyzeNodes(context, node.getStatements()); + analyzeNodes(context, node.getElseIfClauses()); + node.getElseClause().ifPresent(statement -> process(statement, context)); + return null; + } + + @Override + protected Void visitElseIfClause(ElseIfClause node, Context context) + { + analyzeExpression(context, node.getExpression(), BOOLEAN, "Condition of ELSEIF clause"); + analyzeNodes(context, node.getStatements()); + return null; + } + + @Override + protected Void visitElseClause(ElseClause node, Context context) + { + analyzeNodes(context, node.getStatements()); + return null; + } + + @Override + protected Void visitCaseStatement(CaseStatement node, Context context) + { + // when clause condition + if (node.getExpression().isPresent()) { + Type valueType = analyzeExpression(context, node.getExpression().get()); + for (CaseStatementWhenClause whenClause : node.getWhenClauses()) { + Type whenType = analyzeExpression(context, whenClause.getExpression()); + Optional superType = typeCoercion.getCommonSuperType(valueType, whenType); + if (superType.isEmpty()) { + throw semanticException(TYPE_MISMATCH, whenClause.getExpression(), "WHEN clause value must evaluate to CASE value type %s (actual: %s)", valueType, whenType); + } + if (!whenType.equals(superType.get())) { + addCoercion(whenClause.getExpression(), whenType, superType.get()); + } + } + } + else { + for (CaseStatementWhenClause whenClause : node.getWhenClauses()) { + analyzeExpression(context, whenClause.getExpression(), BOOLEAN, "Condition of WHEN clause"); + } + } + + // when clause body + for (CaseStatementWhenClause whenClause : node.getWhenClauses()) { + analyzeNodes(context, whenClause.getStatements()); + } + + // else clause body + if (node.getElseClause().isPresent()) { + process(node.getElseClause().get(), context); + } + return null; + } + + @Override + protected Void visitWhileStatement(WhileStatement node, Context context) + { + Context newContext = context.newScope(); + node.getLabel().ifPresent(name -> defineLabel(newContext, name)); + analyzeExpression(newContext, node.getExpression(), BOOLEAN, "Condition of WHILE statement"); + analyzeNodes(newContext, node.getStatements()); + return null; + } + + @Override + protected Void visitRepeatStatement(RepeatStatement node, Context context) + { + Context newContext = context.newScope(); + node.getLabel().ifPresent(name -> defineLabel(newContext, name)); + analyzeExpression(newContext, node.getCondition(), BOOLEAN, "Condition of REPEAT statement"); + analyzeNodes(newContext, node.getStatements()); + return null; + } + + @Override + protected Void visitLoopStatement(LoopStatement node, Context context) + { + Context newContext = context.newScope(); + node.getLabel().ifPresent(name -> defineLabel(newContext, name)); + analyzeNodes(newContext, node.getStatements()); + return null; + } + + @Override + protected Void visitReturnStatement(ReturnStatement node, Context context) + { + analyzeExpression(context, node.getValue(), returnType, "Value of RETURN"); + return null; + } + + @Override + protected Void visitAssignmentStatement(AssignmentStatement node, Context context) + { + Identifier name = node.getTarget(); + Type targetType = context.variables().get(identifierValue(name)); + if (targetType == null) { + throw semanticException(NOT_FOUND, name, "Variable cannot be resolved: %s", name); + } + analyzeExpression(context, node.getValue(), targetType, format("Value of SET '%s'", name)); + return null; + } + + @Override + protected Void visitIterateStatement(IterateStatement node, Context context) + { + verifyLabelExists(context, node.getLabel()); + return null; + } + + @Override + protected Void visitLeaveStatement(LeaveStatement node, Context context) + { + verifyLabelExists(context, node.getLabel()); + return null; + } + + private void analyzeExpression(Context context, Expression expression, Type expectedType, String message) + { + Type actualType = analyzeExpression(context, expression); + if (actualType.equals(expectedType)) { + return; + } + if (!typeCoercion.canCoerce(actualType, expectedType)) { + throw semanticException(TYPE_MISMATCH, expression, message + " must evaluate to %s (actual: %s)", expectedType, actualType); + } + + addCoercion(expression, actualType, expectedType); + } + + private Type analyzeExpression(Context context, Expression expression) + { + List fields = context.variables().entrySet().stream() + .map(entry -> Field.newUnqualified(entry.getKey(), entry.getValue())) + .collect(toImmutableList()); + + Scope scope = Scope.builder() + .withRelationType(RelationId.of(expression), new RelationType(fields)) + .build(); + + ExpressionAnalyzer.analyzeExpressionWithoutSubqueries( + session, + plannerContext, + accessControl, + scope, + analysis, + expression, + NOT_SUPPORTED, + "Queries are not allowed in functions", + warningCollector, + CorrelationSupport.DISALLOWED); + + return analysis.getType(expression); + } + + private void addCoercion(Expression expression, Type actualType, Type expectedType) + { + analysis.addCoercion(expression, expectedType, typeCoercion.isTypeOnlyCoercion(actualType, expectedType)); + } + + private void analyzeNodes(Context context, List statements) + { + for (Node statement : statements) { + process(statement, context); + } + } + + private static void defineLabel(Context context, Identifier name) + { + if (!context.labels().add(identifierValue(name))) { + throw semanticException(ALREADY_EXISTS, name, "Label already declared in this scope: %s", name); + } + } + + private static void verifyLabelExists(Context context, Identifier name) + { + if (!context.labels().contains(identifierValue(name))) { + throw semanticException(NOT_FOUND, name, "Label not defined: %s", name); + } + } + } + + private record Context(Map variables, Set labels) + { + private Context + { + variables = new LinkedHashMap<>(variables); + labels = new LinkedHashSet<>(labels); + } + + public Context newScope() + { + return new Context(variables, labels); + } + } + + private static String identifierValue(Identifier name) + { + // TODO: this should use getCanonicalValue() + return name.getValue(); + } +} diff --git a/core/trino-main/src/main/java/io/trino/sql/routine/SqlRoutineCompiler.java b/core/trino-main/src/main/java/io/trino/sql/routine/SqlRoutineCompiler.java new file mode 100644 index 000000000000..c6bb848a977c --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/sql/routine/SqlRoutineCompiler.java @@ -0,0 +1,591 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.routine; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.VerifyException; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import io.airlift.bytecode.BytecodeBlock; +import io.airlift.bytecode.BytecodeNode; +import io.airlift.bytecode.ClassDefinition; +import io.airlift.bytecode.DynamicClassLoader; +import io.airlift.bytecode.MethodDefinition; +import io.airlift.bytecode.Parameter; +import io.airlift.bytecode.ParameterizedType; +import io.airlift.bytecode.Scope; +import io.airlift.bytecode.Variable; +import io.airlift.bytecode.control.DoWhileLoop; +import io.airlift.bytecode.control.IfStatement; +import io.airlift.bytecode.control.WhileLoop; +import io.airlift.bytecode.instruction.LabelNode; +import io.trino.metadata.FunctionManager; +import io.trino.operator.scalar.SpecializedSqlScalarFunction; +import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.function.InvocationConvention; +import io.trino.spi.function.ScalarFunctionAdapter; +import io.trino.spi.function.ScalarFunctionImplementation; +import io.trino.spi.type.BooleanType; +import io.trino.spi.type.Type; +import io.trino.sql.gen.CachedInstanceBinder; +import io.trino.sql.gen.CallSiteBinder; +import io.trino.sql.gen.LambdaBytecodeGenerator.CompiledLambda; +import io.trino.sql.gen.RowExpressionCompiler; +import io.trino.sql.relational.CallExpression; +import io.trino.sql.relational.ConstantExpression; +import io.trino.sql.relational.InputReferenceExpression; +import io.trino.sql.relational.LambdaDefinitionExpression; +import io.trino.sql.relational.RowExpression; +import io.trino.sql.relational.RowExpressionVisitor; +import io.trino.sql.relational.SpecialForm; +import io.trino.sql.relational.VariableReferenceExpression; +import io.trino.sql.routine.ir.DefaultIrNodeVisitor; +import io.trino.sql.routine.ir.IrBlock; +import io.trino.sql.routine.ir.IrBreak; +import io.trino.sql.routine.ir.IrContinue; +import io.trino.sql.routine.ir.IrIf; +import io.trino.sql.routine.ir.IrLabel; +import io.trino.sql.routine.ir.IrLoop; +import io.trino.sql.routine.ir.IrNode; +import io.trino.sql.routine.ir.IrNodeVisitor; +import io.trino.sql.routine.ir.IrRepeat; +import io.trino.sql.routine.ir.IrReturn; +import io.trino.sql.routine.ir.IrRoutine; +import io.trino.sql.routine.ir.IrSet; +import io.trino.sql.routine.ir.IrStatement; +import io.trino.sql.routine.ir.IrVariable; +import io.trino.sql.routine.ir.IrWhile; +import io.trino.util.Reflection; + +import java.lang.invoke.MethodHandle; +import java.util.ArrayList; +import java.util.Collections; +import java.util.IdentityHashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.function.Function; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Verify.verify; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.ImmutableMap.toImmutableMap; +import static com.google.common.collect.MoreCollectors.onlyElement; +import static com.google.common.primitives.Primitives.wrap; +import static io.airlift.bytecode.Access.FINAL; +import static io.airlift.bytecode.Access.PUBLIC; +import static io.airlift.bytecode.Access.a; +import static io.airlift.bytecode.Parameter.arg; +import static io.airlift.bytecode.ParameterizedType.type; +import static io.airlift.bytecode.expression.BytecodeExpressions.constantInt; +import static io.airlift.bytecode.expression.BytecodeExpressions.constantString; +import static io.airlift.bytecode.expression.BytecodeExpressions.greaterThanOrEqual; +import static io.airlift.bytecode.expression.BytecodeExpressions.invokeStatic; +import static io.airlift.bytecode.expression.BytecodeExpressions.newInstance; +import static io.airlift.bytecode.instruction.Constant.loadBoolean; +import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BOXED_NULLABLE; +import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.NULLABLE_RETURN; +import static io.trino.sql.gen.BytecodeUtils.boxPrimitiveIfNecessary; +import static io.trino.sql.gen.BytecodeUtils.unboxPrimitiveIfNecessary; +import static io.trino.sql.gen.LambdaBytecodeGenerator.preGenerateLambdaExpression; +import static io.trino.sql.gen.LambdaExpressionExtractor.extractLambdaExpressions; +import static io.trino.util.CompilerUtils.defineClass; +import static io.trino.util.CompilerUtils.makeClassName; +import static io.trino.util.Reflection.constructorMethodHandle; +import static java.util.Arrays.stream; +import static java.util.Objects.requireNonNull; +import static java.util.function.Function.identity; + +public final class SqlRoutineCompiler +{ + private final FunctionManager functionManager; + + public SqlRoutineCompiler(FunctionManager functionManager) + { + this.functionManager = requireNonNull(functionManager, "functionManager is null"); + } + + public SpecializedSqlScalarFunction compile(IrRoutine routine) + { + Type returnType = routine.returnType(); + List parameterTypes = routine.parameters().stream() + .map(IrVariable::type) + .collect(toImmutableList()); + + InvocationConvention callingConvention = new InvocationConvention( + // todo this should be based on the declared nullability of the parameters + Collections.nCopies(parameterTypes.size(), BOXED_NULLABLE), + NULLABLE_RETURN, + true, + true); + + Class clazz = compileClass(routine); + + MethodHandle handle = stream(clazz.getMethods()) + .filter(method -> method.getName().equals("run")) + .map(Reflection::methodHandle) + .collect(onlyElement()); + + MethodHandle instanceFactory = constructorMethodHandle(clazz); + + MethodHandle objectHandle = handle.asType(handle.type().changeParameterType(0, Object.class)); + MethodHandle objectInstanceFactory = instanceFactory.asType(instanceFactory.type().changeReturnType(Object.class)); + + return invocationConvention -> { + MethodHandle adapted = ScalarFunctionAdapter.adapt( + objectHandle, + returnType, + parameterTypes, + callingConvention, + invocationConvention); + return ScalarFunctionImplementation.builder() + .methodHandle(adapted) + .instanceFactory(objectInstanceFactory) + .build(); + }; + } + + @VisibleForTesting + public Class compileClass(IrRoutine routine) + { + ClassDefinition classDefinition = new ClassDefinition( + a(PUBLIC, FINAL), + makeClassName("SqlRoutine"), + type(Object.class)); + + CallSiteBinder callSiteBinder = new CallSiteBinder(); + CachedInstanceBinder cachedInstanceBinder = new CachedInstanceBinder(classDefinition, callSiteBinder); + + Map compiledLambdaMap = generateMethodsForLambda(classDefinition, cachedInstanceBinder, routine); + + generateRunMethod(classDefinition, cachedInstanceBinder, compiledLambdaMap, routine); + + declareConstructor(classDefinition, cachedInstanceBinder); + + return defineClass(classDefinition, Object.class, callSiteBinder.getBindings(), new DynamicClassLoader(getClass().getClassLoader())); + } + + private Map generateMethodsForLambda( + ClassDefinition containerClassDefinition, + CachedInstanceBinder cachedInstanceBinder, + IrNode node) + { + Set lambdaExpressions = extractLambda(node); + ImmutableMap.Builder compiledLambdaMap = ImmutableMap.builder(); + int counter = 0; + for (LambdaDefinitionExpression lambdaExpression : lambdaExpressions) { + CompiledLambda compiledLambda = preGenerateLambdaExpression( + lambdaExpression, + "lambda_" + counter, + containerClassDefinition, + compiledLambdaMap.buildOrThrow(), + cachedInstanceBinder.getCallSiteBinder(), + cachedInstanceBinder, + functionManager); + compiledLambdaMap.put(lambdaExpression, compiledLambda); + counter++; + } + return compiledLambdaMap.buildOrThrow(); + } + + private void generateRunMethod( + ClassDefinition classDefinition, + CachedInstanceBinder cachedInstanceBinder, + Map compiledLambdaMap, + IrRoutine routine) + { + ImmutableList.Builder parameterBuilder = ImmutableList.builder(); + parameterBuilder.add(arg("session", ConnectorSession.class)); + for (IrVariable sqlVariable : routine.parameters()) { + parameterBuilder.add(arg(name(sqlVariable), compilerType(sqlVariable.type()))); + } + + MethodDefinition method = classDefinition.declareMethod( + a(PUBLIC), + "run", + compilerType(routine.returnType()), + parameterBuilder.build()); + + Scope scope = method.getScope(); + + scope.declareVariable(boolean.class, "wasNull"); + + Map variables = VariableExtractor.extract(routine).stream().distinct() + .collect(toImmutableMap(identity(), variable -> getOrDeclareVariable(scope, variable))); + + BytecodeVisitor visitor = new BytecodeVisitor(cachedInstanceBinder, compiledLambdaMap, variables); + method.getBody().append(visitor.process(routine, scope)); + } + + private static BytecodeNode throwIfInterrupted() + { + return new IfStatement() + .condition(invokeStatic(Thread.class, "currentThread", Thread.class) + .invoke("isInterrupted", boolean.class)) + .ifTrue(new BytecodeBlock() + .append(newInstance(RuntimeException.class, constantString("Thread interrupted"))) + .throwObject()); + } + + private static void declareConstructor(ClassDefinition classDefinition, CachedInstanceBinder cachedInstanceBinder) + { + MethodDefinition constructorDefinition = classDefinition.declareConstructor(a(PUBLIC)); + BytecodeBlock body = constructorDefinition.getBody(); + body.append(constructorDefinition.getThis()) + .invokeConstructor(Object.class); + cachedInstanceBinder.generateInitializations(constructorDefinition.getThis(), body); + body.ret(); + } + + private static Variable getOrDeclareVariable(Scope scope, IrVariable variable) + { + return getOrDeclareVariable(scope, compilerType(variable.type()), name(variable)); + } + + private static Variable getOrDeclareVariable(Scope scope, ParameterizedType type, String name) + { + try { + return scope.getVariable(name); + } + catch (IllegalArgumentException e) { + return scope.declareVariable(type, name); + } + } + + private static ParameterizedType compilerType(Type type) + { + return type(wrap(type.getJavaType())); + } + + private static String name(IrVariable variable) + { + return name(variable.field()); + } + + private static String name(int field) + { + return "v" + field; + } + + private class BytecodeVisitor + implements IrNodeVisitor + { + private final CachedInstanceBinder cachedInstanceBinder; + private final Map compiledLambdaMap; + private final Map variables; + + private final Map continueLabels = new IdentityHashMap<>(); + private final Map breakLabels = new IdentityHashMap<>(); + + public BytecodeVisitor( + CachedInstanceBinder cachedInstanceBinder, + Map compiledLambdaMap, + Map variables) + { + this.cachedInstanceBinder = requireNonNull(cachedInstanceBinder, "cachedInstanceBinder is null"); + this.compiledLambdaMap = requireNonNull(compiledLambdaMap, "compiledLambdaMap is null"); + this.variables = requireNonNull(variables, "variables is null"); + } + + @Override + public BytecodeNode visitNode(IrNode node, Scope context) + { + throw new VerifyException("Unsupported node: " + node.getClass().getSimpleName()); + } + + @Override + public BytecodeNode visitRoutine(IrRoutine node, Scope scope) + { + return process(node.body(), scope); + } + + @Override + public BytecodeNode visitSet(IrSet node, Scope scope) + { + return new BytecodeBlock() + .append(compile(node.value(), scope)) + .putVariable(variables.get(node.target())); + } + + @Override + public BytecodeNode visitBlock(IrBlock node, Scope scope) + { + BytecodeBlock block = new BytecodeBlock(); + + for (IrVariable sqlVariable : node.variables()) { + block.append(compile(sqlVariable.defaultValue(), scope)) + .putVariable(variables.get(sqlVariable)); + } + + LabelNode continueLabel = new LabelNode("continue"); + LabelNode breakLabel = new LabelNode("break"); + + if (node.label().isPresent()) { + continueLabels.put(node.label().get(), continueLabel); + breakLabels.put(node.label().get(), breakLabel); + block.visitLabel(continueLabel); + } + + for (IrStatement statement : node.statements()) { + block.append(process(statement, scope)); + } + + if (node.label().isPresent()) { + block.visitLabel(breakLabel); + } + + return block; + } + + @Override + public BytecodeNode visitReturn(IrReturn node, Scope scope) + { + return new BytecodeBlock() + .append(compile(node.value(), scope)) + .ret(wrap(node.value().getType().getJavaType())); + } + + @Override + public BytecodeNode visitContinue(IrContinue node, Scope scope) + { + LabelNode label = continueLabels.get(node.target()); + verify(label != null, "continue target does not exist"); + return new BytecodeBlock() + .gotoLabel(label); + } + + @Override + public BytecodeNode visitBreak(IrBreak node, Scope scope) + { + LabelNode label = breakLabels.get(node.target()); + verify(label != null, "break target does not exist"); + return new BytecodeBlock() + .gotoLabel(label); + } + + @Override + public BytecodeNode visitIf(IrIf node, Scope scope) + { + IfStatement ifStatement = new IfStatement() + .condition(compileBoolean(node.condition(), scope)) + .ifTrue(process(node.ifTrue(), scope)); + + if (node.ifFalse().isPresent()) { + ifStatement.ifFalse(process(node.ifFalse().get(), scope)); + } + + return ifStatement; + } + + @Override + public BytecodeNode visitWhile(IrWhile node, Scope scope) + { + return compileLoop(scope, node.label(), interruption -> new WhileLoop() + .condition(compileBoolean(node.condition(), scope)) + .body(new BytecodeBlock() + .append(interruption) + .append(process(node.body(), scope)))); + } + + @Override + public BytecodeNode visitRepeat(IrRepeat node, Scope scope) + { + return compileLoop(scope, node.label(), interruption -> new DoWhileLoop() + .condition(not(compileBoolean(node.condition(), scope))) + .body(new BytecodeBlock() + .append(interruption) + .append(process(node.block(), scope)))); + } + + @Override + public BytecodeNode visitLoop(IrLoop node, Scope scope) + { + return compileLoop(scope, node.label(), interruption -> new WhileLoop() + .condition(loadBoolean(true)) + .body(new BytecodeBlock() + .append(interruption) + .append(process(node.block(), scope)))); + } + + private BytecodeNode compileLoop(Scope scope, Optional label, Function loop) + { + BytecodeBlock block = new BytecodeBlock(); + + Variable interruption = scope.createTempVariable(int.class); + block.putVariable(interruption, 0); + + BytecodeBlock interruptionBlock = new BytecodeBlock() + .append(interruption.increment()) + .append(new IfStatement() + .condition(greaterThanOrEqual(interruption, constantInt(1000))) + .ifTrue(new BytecodeBlock() + .append(interruption.set(constantInt(0))) + .append(throwIfInterrupted()))); + + LabelNode continueLabel = new LabelNode("continue"); + LabelNode breakLabel = new LabelNode("break"); + + if (label.isPresent()) { + continueLabels.put(label.get(), continueLabel); + breakLabels.put(label.get(), breakLabel); + block.visitLabel(continueLabel); + } + + block.append(loop.apply(interruptionBlock)); + + if (label.isPresent()) { + block.visitLabel(breakLabel); + } + + return block; + } + + private BytecodeNode compile(RowExpression expression, Scope scope) + { + if (expression instanceof InputReferenceExpression input) { + return scope.getVariable(name(input.getField())); + } + + RowExpressionCompiler rowExpressionCompiler = new RowExpressionCompiler( + cachedInstanceBinder.getCallSiteBinder(), + cachedInstanceBinder, + FieldReferenceCompiler.INSTANCE, + functionManager, + compiledLambdaMap); + + return new BytecodeBlock() + .comment("boolean wasNull = false;") + .putVariable(scope.getVariable("wasNull"), expression.getType().getJavaType() == void.class) + .comment("expression: " + expression) + .append(rowExpressionCompiler.compile(expression, scope)) + .append(boxPrimitiveIfNecessary(scope, wrap(expression.getType().getJavaType()))); + } + + private BytecodeNode compileBoolean(RowExpression expression, Scope scope) + { + checkArgument(expression.getType().equals(BooleanType.BOOLEAN), "type must be boolean"); + + LabelNode notNull = new LabelNode("notNull"); + LabelNode done = new LabelNode("done"); + + return new BytecodeBlock() + .append(compile(expression, scope)) + .comment("if value is null, return false, otherwise unbox") + .dup() + .ifNotNullGoto(notNull) + .pop() + .push(false) + .gotoLabel(done) + .visitLabel(notNull) + .invokeVirtual(Boolean.class, "booleanValue", boolean.class) + .visitLabel(done); + } + + private static BytecodeNode not(BytecodeNode node) + { + LabelNode trueLabel = new LabelNode("true"); + LabelNode endLabel = new LabelNode("end"); + return new BytecodeBlock() + .append(node) + .comment("boolean not") + .ifTrueGoto(trueLabel) + .push(true) + .gotoLabel(endLabel) + .visitLabel(trueLabel) + .push(false) + .visitLabel(endLabel); + } + } + + private static Set extractLambda(IrNode node) + { + ImmutableSet.Builder expressions = ImmutableSet.builder(); + node.accept(new DefaultIrNodeVisitor() + { + @Override + public void visitRowExpression(RowExpression expression) + { + expressions.addAll(extractLambdaExpressions(expression)); + } + }, null); + return expressions.build(); + } + + private static class FieldReferenceCompiler + implements RowExpressionVisitor + { + public static final FieldReferenceCompiler INSTANCE = new FieldReferenceCompiler(); + + @Override + public BytecodeNode visitInputReference(InputReferenceExpression node, Scope scope) + { + Class boxedType = wrap(node.getType().getJavaType()); + return new BytecodeBlock() + .append(scope.getVariable(name(node.getField()))) + .append(unboxPrimitiveIfNecessary(scope, boxedType)); + } + + @Override + public BytecodeNode visitCall(CallExpression call, Scope scope) + { + throw new UnsupportedOperationException(); + } + + @Override + public BytecodeNode visitSpecialForm(SpecialForm specialForm, Scope context) + { + throw new UnsupportedOperationException(); + } + + @Override + public BytecodeNode visitConstant(ConstantExpression literal, Scope scope) + { + throw new UnsupportedOperationException(); + } + + @Override + public BytecodeNode visitLambda(LambdaDefinitionExpression lambda, Scope context) + { + throw new UnsupportedOperationException(); + } + + @Override + public BytecodeNode visitVariableReference(VariableReferenceExpression reference, Scope context) + { + throw new UnsupportedOperationException(); + } + } + + private static class VariableExtractor + extends DefaultIrNodeVisitor + { + private final List variables = new ArrayList<>(); + + @Override + public Void visitVariable(IrVariable node, Void context) + { + variables.add(node); + return null; + } + + public static List extract(IrNode node) + { + VariableExtractor extractor = new VariableExtractor(); + extractor.process(node, null); + return extractor.variables; + } + } +} diff --git a/core/trino-main/src/main/java/io/trino/sql/routine/SqlRoutinePlanner.java b/core/trino-main/src/main/java/io/trino/sql/routine/SqlRoutinePlanner.java new file mode 100644 index 000000000000..1b3518e3e2b3 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/sql/routine/SqlRoutinePlanner.java @@ -0,0 +1,465 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.routine; + +import com.google.common.base.VerifyException; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Lists; +import io.trino.Session; +import io.trino.execution.warnings.WarningCollector; +import io.trino.metadata.Metadata; +import io.trino.metadata.ResolvedFunction; +import io.trino.security.AllowAllAccessControl; +import io.trino.spi.type.Type; +import io.trino.sql.PlannerContext; +import io.trino.sql.analyzer.Analysis; +import io.trino.sql.analyzer.ExpressionAnalyzer; +import io.trino.sql.analyzer.Field; +import io.trino.sql.analyzer.RelationId; +import io.trino.sql.analyzer.RelationType; +import io.trino.sql.analyzer.Scope; +import io.trino.sql.planner.ExpressionInterpreter; +import io.trino.sql.planner.LiteralEncoder; +import io.trino.sql.planner.NoOpSymbolResolver; +import io.trino.sql.planner.Symbol; +import io.trino.sql.planner.SymbolAllocator; +import io.trino.sql.planner.TranslationMap; +import io.trino.sql.planner.TypeProvider; +import io.trino.sql.planner.iterative.rule.LambdaCaptureDesugaringRewriter; +import io.trino.sql.planner.sanity.SugarFreeChecker; +import io.trino.sql.relational.RowExpression; +import io.trino.sql.relational.SqlToRowExpressionTranslator; +import io.trino.sql.relational.StandardFunctionResolution; +import io.trino.sql.relational.optimizer.ExpressionOptimizer; +import io.trino.sql.routine.ir.IrBlock; +import io.trino.sql.routine.ir.IrBreak; +import io.trino.sql.routine.ir.IrContinue; +import io.trino.sql.routine.ir.IrIf; +import io.trino.sql.routine.ir.IrLabel; +import io.trino.sql.routine.ir.IrLoop; +import io.trino.sql.routine.ir.IrRepeat; +import io.trino.sql.routine.ir.IrReturn; +import io.trino.sql.routine.ir.IrRoutine; +import io.trino.sql.routine.ir.IrSet; +import io.trino.sql.routine.ir.IrStatement; +import io.trino.sql.routine.ir.IrVariable; +import io.trino.sql.routine.ir.IrWhile; +import io.trino.sql.tree.AssignmentStatement; +import io.trino.sql.tree.AstVisitor; +import io.trino.sql.tree.CaseStatement; +import io.trino.sql.tree.CaseStatementWhenClause; +import io.trino.sql.tree.Cast; +import io.trino.sql.tree.CompoundStatement; +import io.trino.sql.tree.ControlStatement; +import io.trino.sql.tree.ElseIfClause; +import io.trino.sql.tree.Expression; +import io.trino.sql.tree.FunctionSpecification; +import io.trino.sql.tree.Identifier; +import io.trino.sql.tree.IfStatement; +import io.trino.sql.tree.IterateStatement; +import io.trino.sql.tree.LambdaArgumentDeclaration; +import io.trino.sql.tree.LeaveStatement; +import io.trino.sql.tree.LoopStatement; +import io.trino.sql.tree.Node; +import io.trino.sql.tree.NodeRef; +import io.trino.sql.tree.RepeatStatement; +import io.trino.sql.tree.ReturnStatement; +import io.trino.sql.tree.SymbolReference; +import io.trino.sql.tree.VariableDeclaration; +import io.trino.sql.tree.WhileStatement; + +import java.util.ArrayList; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Verify.verify; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.ImmutableMap.toImmutableMap; +import static io.trino.sql.analyzer.TypeSignatureTranslator.toSqlType; +import static io.trino.sql.planner.LogicalPlanner.buildLambdaDeclarationToSymbolMap; +import static io.trino.sql.relational.Expressions.call; +import static io.trino.sql.relational.Expressions.constantNull; +import static io.trino.sql.relational.Expressions.field; +import static io.trino.sql.tree.ComparisonExpression.Operator.EQUAL; +import static java.util.Objects.requireNonNull; + +public final class SqlRoutinePlanner +{ + private final PlannerContext plannerContext; + private final WarningCollector warningCollector; + + public SqlRoutinePlanner(PlannerContext plannerContext, WarningCollector warningCollector) + { + this.plannerContext = requireNonNull(plannerContext, "plannerContext is null"); + this.warningCollector = requireNonNull(warningCollector, "warningCollector is null"); + } + + public IrRoutine planSqlFunction(Session session, FunctionSpecification function, SqlRoutineAnalysis routineAnalysis) + { + List allVariables = new ArrayList<>(); + Map scopeVariables = new LinkedHashMap<>(); + + ImmutableList.Builder parameters = ImmutableList.builder(); + routineAnalysis.arguments().forEach((name, type) -> { + IrVariable variable = new IrVariable(allVariables.size(), type, constantNull(type)); + allVariables.add(variable); + scopeVariables.put(name, variable); + parameters.add(variable); + }); + + Analysis analysis = routineAnalysis.analysis(); + StatementVisitor visitor = new StatementVisitor(session, allVariables, analysis); + IrStatement body = visitor.process(function.getStatement(), new Context(scopeVariables, Map.of())); + + return new IrRoutine(routineAnalysis.returnType(), parameters.build(), body); + } + + private class StatementVisitor + extends AstVisitor + { + private final Session session; + private final List allVariables; + private final Analysis analysis; + private final StandardFunctionResolution resolution; + + public StatementVisitor( + Session session, + List allVariables, + Analysis analysis) + { + this.session = requireNonNull(session, "session is null"); + this.resolution = new StandardFunctionResolution(plannerContext.getMetadata()); + this.allVariables = requireNonNull(allVariables, "allVariables is null"); + this.analysis = requireNonNull(analysis, "analysis is null"); + } + + @Override + protected IrStatement visitNode(Node node, Context context) + { + throw new UnsupportedOperationException("Not implemented: " + node); + } + + @Override + protected IrStatement visitCompoundStatement(CompoundStatement node, Context context) + { + Context newContext = context.newScope(); + + ImmutableList.Builder blockVariables = ImmutableList.builder(); + for (VariableDeclaration declaration : node.getVariableDeclarations()) { + Type type = analysis.getType(declaration.getType()); + RowExpression defaultValue = declaration.getDefaultValue() + .map(expression -> toRowExpression(newContext, expression)) + .orElse(constantNull(type)); + + for (Identifier name : declaration.getNames()) { + IrVariable variable = new IrVariable(allVariables.size(), type, defaultValue); + allVariables.add(variable); + verify(newContext.variables().put(identifierValue(name), variable) == null, "Variable already declared in scope: %s", name); + blockVariables.add(variable); + } + } + + List statements = node.getStatements().stream() + .map(statement -> process(statement, newContext)) + .collect(toImmutableList()); + + return new IrBlock(blockVariables.build(), statements); + } + + @Override + protected IrStatement visitIfStatement(IfStatement node, Context context) + { + IrStatement statement = null; + + List elseIfList = Lists.reverse(node.getElseIfClauses()); + for (int i = 0; i < elseIfList.size(); i++) { + ElseIfClause elseIf = elseIfList.get(i); + RowExpression condition = toRowExpression(context, elseIf.getExpression()); + IrStatement ifTrue = block(statements(elseIf.getStatements(), context)); + + Optional ifFalse = Optional.empty(); + if ((i == 0) && node.getElseClause().isPresent()) { + List elseList = node.getElseClause().get().getStatements(); + ifFalse = Optional.of(block(statements(elseList, context))); + } + else if (statement != null) { + ifFalse = Optional.of(statement); + } + + statement = new IrIf(condition, ifTrue, ifFalse); + } + + return new IrIf( + toRowExpression(context, node.getExpression()), + block(statements(node.getStatements(), context)), + Optional.ofNullable(statement)); + } + + @Override + protected IrStatement visitCaseStatement(CaseStatement node, Context context) + { + if (node.getExpression().isPresent()) { + RowExpression valueExpression = toRowExpression(context, node.getExpression().get()); + IrVariable valueVariable = new IrVariable(allVariables.size(), valueExpression.getType(), valueExpression); + + IrStatement statement = node.getElseClause() + .map(elseClause -> block(statements(elseClause.getStatements(), context))) + .orElseGet(() -> new IrBlock(ImmutableList.of(), ImmutableList.of())); + + for (CaseStatementWhenClause whenClause : Lists.reverse(node.getWhenClauses())) { + RowExpression conditionValue = toRowExpression(context, whenClause.getExpression()); + + RowExpression testValue = field(valueVariable.field(), valueVariable.type()); + if (!testValue.getType().equals(conditionValue.getType())) { + ResolvedFunction castFunction = plannerContext.getMetadata().getCoercion(testValue.getType(), conditionValue.getType()); + testValue = call(castFunction, testValue); + } + + ResolvedFunction equals = resolution.comparisonFunction(EQUAL, testValue.getType(), conditionValue.getType()); + RowExpression condition = call(equals, testValue, conditionValue); + + IrStatement ifTrue = block(statements(whenClause.getStatements(), context)); + statement = new IrIf(condition, ifTrue, Optional.of(statement)); + } + return new IrBlock(ImmutableList.of(valueVariable), ImmutableList.of(statement)); + } + + IrStatement statement = node.getElseClause() + .map(elseClause -> block(statements(elseClause.getStatements(), context))) + .orElseGet(() -> new IrBlock(ImmutableList.of(), ImmutableList.of())); + + for (CaseStatementWhenClause whenClause : Lists.reverse(node.getWhenClauses())) { + RowExpression condition = toRowExpression(context, whenClause.getExpression()); + IrStatement ifTrue = block(statements(whenClause.getStatements(), context)); + statement = new IrIf(condition, ifTrue, Optional.of(statement)); + } + + return statement; + } + + @Override + protected IrStatement visitWhileStatement(WhileStatement node, Context context) + { + Context newContext = context.newScope(); + Optional label = getSqlLabel(newContext, node.getLabel()); + RowExpression condition = toRowExpression(newContext, node.getExpression()); + List statements = statements(node.getStatements(), newContext); + return new IrWhile(label, condition, block(statements)); + } + + @Override + protected IrStatement visitRepeatStatement(RepeatStatement node, Context context) + { + Context newContext = context.newScope(); + Optional label = getSqlLabel(newContext, node.getLabel()); + RowExpression condition = toRowExpression(newContext, node.getCondition()); + List statements = statements(node.getStatements(), newContext); + return new IrRepeat(label, condition, block(statements)); + } + + @Override + protected IrStatement visitLoopStatement(LoopStatement node, Context context) + { + Context newContext = context.newScope(); + Optional label = getSqlLabel(newContext, node.getLabel()); + List statements = statements(node.getStatements(), newContext); + return new IrLoop(label, block(statements)); + } + + @Override + protected IrStatement visitReturnStatement(ReturnStatement node, Context context) + { + return new IrReturn(toRowExpression(context, node.getValue())); + } + + @Override + protected IrStatement visitAssignmentStatement(AssignmentStatement node, Context context) + { + Identifier name = node.getTarget(); + IrVariable target = context.variables().get(identifierValue(name)); + checkArgument(target != null, "Variable not declared in scope: %s", name); + return new IrSet(target, toRowExpression(context, node.getValue())); + } + + @Override + protected IrStatement visitIterateStatement(IterateStatement node, Context context) + { + return new IrContinue(label(context, node.getLabel())); + } + + @Override + protected IrStatement visitLeaveStatement(LeaveStatement node, Context context) + { + return new IrBreak(label(context, node.getLabel())); + } + + private static Optional getSqlLabel(Context context, Optional labelName) + { + return labelName.map(name -> { + IrLabel label = new IrLabel(identifierValue(name)); + verify(context.labels().put(identifierValue(name), label) == null, "Label already declared in this scope: %s", name); + return label; + }); + } + + private static IrLabel label(Context context, Identifier name) + { + IrLabel label = context.labels().get(identifierValue(name)); + checkArgument(label != null, "Label not defined: %s", name); + return label; + } + + private RowExpression toRowExpression(Context context, Expression expression) + { + // build symbol and field indexes for translation + TypeProvider typeProvider = TypeProvider.viewOf( + context.variables().entrySet().stream().collect(toImmutableMap( + entry -> new Symbol(entry.getKey()), + entry -> entry.getValue().type()))); + + List fields = context.variables().entrySet().stream() + .map(entry -> Field.newUnqualified(entry.getKey(), entry.getValue().type())) + .collect(toImmutableList()); + + Scope scope = Scope.builder() + .withRelationType(RelationId.of(expression), new RelationType(fields)) + .build(); + + SymbolAllocator symbolAllocator = new SymbolAllocator(); + List fieldSymbols = fields.stream() + .map(symbolAllocator::newSymbol) + .collect(toImmutableList()); + + Map, Symbol> nodeRefSymbolMap = buildLambdaDeclarationToSymbolMap(analysis, symbolAllocator); + + // Apply casts, desugar expression, and preform other rewrites + TranslationMap translationMap = new TranslationMap(Optional.empty(), scope, analysis, nodeRefSymbolMap, fieldSymbols, session, plannerContext); + Expression translated = coerceIfNecessary(analysis, expression, translationMap.rewrite(expression)); + + // desugar the lambda captures + Expression lambdaCaptureDesugared = LambdaCaptureDesugaringRewriter.rewrite(translated, typeProvider, symbolAllocator); + + // The expression tree has been rewritten which breaks all the identity maps, so redo the analysis + // to re-analyze coercions that might be necessary + ExpressionAnalyzer analyzer = createExpressionAnalyzer(session, typeProvider); + analyzer.analyze(lambdaCaptureDesugared, scope); + + // optimize the expression + ExpressionInterpreter interpreter = new ExpressionInterpreter(lambdaCaptureDesugared, plannerContext, session, analyzer.getExpressionTypes()); + Expression optimized = new LiteralEncoder(plannerContext) + .toExpression(interpreter.optimize(NoOpSymbolResolver.INSTANCE), analyzer.getExpressionTypes().get(NodeRef.of(lambdaCaptureDesugared))); + + // validate expression + SugarFreeChecker.validate(optimized); + + // Analyze again after optimization + analyzer = createExpressionAnalyzer(session, typeProvider); + analyzer.analyze(optimized, scope); + + // translate to RowExpression + TranslationVisitor translator = new TranslationVisitor(plannerContext.getMetadata(), analyzer.getExpressionTypes(), ImmutableMap.of(), context.variables()); + RowExpression rowExpression = translator.process(optimized, null); + + // optimize RowExpression + ExpressionOptimizer optimizer = new ExpressionOptimizer(plannerContext.getMetadata(), plannerContext.getFunctionManager(), session); + rowExpression = optimizer.optimize(rowExpression); + + return rowExpression; + } + + public static Expression coerceIfNecessary(Analysis analysis, Expression original, Expression rewritten) + { + Type coercion = analysis.getCoercion(original); + if (coercion == null) { + return rewritten; + } + return new Cast(rewritten, toSqlType(coercion), false, analysis.isTypeOnlyCoercion(original)); + } + + private ExpressionAnalyzer createExpressionAnalyzer(Session session, TypeProvider typeProvider) + { + return ExpressionAnalyzer.createWithoutSubqueries( + plannerContext, + new AllowAllAccessControl(), + session, + typeProvider, + ImmutableMap.of(), + node -> new VerifyException("Unexpected subquery"), + warningCollector, + false); + } + + private List statements(List statements, Context context) + { + return statements.stream() + .map(statement -> process(statement, context)) + .collect(toImmutableList()); + } + + private static IrBlock block(List statements) + { + return new IrBlock(ImmutableList.of(), statements); + } + + private static String identifierValue(Identifier name) + { + // TODO: this should use getCanonicalValue() + return name.getValue(); + } + } + + private record Context(Map variables, Map labels) + { + public Context + { + variables = new LinkedHashMap<>(variables); + labels = new LinkedHashMap<>(labels); + } + + public Context newScope() + { + return new Context(variables, labels); + } + } + + private static class TranslationVisitor + extends SqlToRowExpressionTranslator.Visitor + { + private final Map variables; + + public TranslationVisitor( + Metadata metadata, + Map, Type> types, + Map layout, + Map variables) + { + super(metadata, types, layout); + this.variables = requireNonNull(variables, "variables is null"); + } + + @Override + protected RowExpression visitSymbolReference(SymbolReference node, Void context) + { + IrVariable variable = variables.get(node.getName()); + if (variable != null) { + return field(variable.field(), variable.type()); + } + return super.visitSymbolReference(node, context); + } + } +} diff --git a/core/trino-main/src/main/java/io/trino/sql/routine/ir/DefaultIrNodeVisitor.java b/core/trino-main/src/main/java/io/trino/sql/routine/ir/DefaultIrNodeVisitor.java new file mode 100644 index 000000000000..1de92820096e --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/sql/routine/ir/DefaultIrNodeVisitor.java @@ -0,0 +1,112 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.routine.ir; + +import io.trino.sql.relational.RowExpression; + +public class DefaultIrNodeVisitor + implements IrNodeVisitor +{ + @Override + public Void visitRoutine(IrRoutine node, Void context) + { + for (IrVariable parameter : node.parameters()) { + process(parameter, context); + } + process(node.body(), context); + return null; + } + + @Override + public Void visitVariable(IrVariable node, Void context) + { + visitRowExpression(node.defaultValue()); + return null; + } + + @Override + public Void visitBlock(IrBlock node, Void context) + { + for (IrVariable variable : node.variables()) { + process(variable, context); + } + for (IrStatement statement : node.statements()) { + process(statement, context); + } + return null; + } + + @Override + public Void visitBreak(IrBreak node, Void context) + { + return null; + } + + @Override + public Void visitContinue(IrContinue node, Void context) + { + return null; + } + + @Override + public Void visitIf(IrIf node, Void context) + { + visitRowExpression(node.condition()); + process(node.ifTrue(), context); + if (node.ifFalse().isPresent()) { + process(node.ifFalse().get(), context); + } + return null; + } + + @Override + public Void visitWhile(IrWhile node, Void context) + { + visitRowExpression(node.condition()); + process(node.body(), context); + return null; + } + + @Override + public Void visitRepeat(IrRepeat node, Void context) + { + visitRowExpression(node.condition()); + process(node.block(), context); + return null; + } + + @Override + public Void visitLoop(IrLoop node, Void context) + { + process(node.block(), context); + return null; + } + + @Override + public Void visitReturn(IrReturn node, Void context) + { + visitRowExpression(node.value()); + return null; + } + + @Override + public Void visitSet(IrSet node, Void context) + { + visitRowExpression(node.value()); + process(node.target(), context); + return null; + } + + public void visitRowExpression(RowExpression expression) {} +} diff --git a/core/trino-main/src/main/java/io/trino/sql/routine/ir/IrBlock.java b/core/trino-main/src/main/java/io/trino/sql/routine/ir/IrBlock.java new file mode 100644 index 000000000000..6cbe622875dd --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/sql/routine/ir/IrBlock.java @@ -0,0 +1,43 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.routine.ir; + +import com.google.common.collect.ImmutableList; + +import java.util.List; +import java.util.Optional; + +import static java.util.Objects.requireNonNull; + +public record IrBlock(Optional label, List variables, List statements) + implements IrStatement +{ + public IrBlock(List variables, List statements) + { + this(Optional.empty(), variables, statements); + } + + public IrBlock + { + requireNonNull(label, "label is null"); + variables = ImmutableList.copyOf(requireNonNull(variables, "variables is null")); + statements = ImmutableList.copyOf(requireNonNull(statements, "statements is null")); + } + + @Override + public R accept(IrNodeVisitor visitor, C context) + { + return visitor.visitBlock(this, context); + } +} diff --git a/core/trino-main/src/main/java/io/trino/sql/routine/ir/IrBreak.java b/core/trino-main/src/main/java/io/trino/sql/routine/ir/IrBreak.java new file mode 100644 index 000000000000..6c23f64f1f11 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/sql/routine/ir/IrBreak.java @@ -0,0 +1,31 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.routine.ir; + +import static java.util.Objects.requireNonNull; + +public record IrBreak(IrLabel target) + implements IrStatement +{ + public IrBreak + { + requireNonNull(target, "target is null"); + } + + @Override + public R accept(IrNodeVisitor visitor, C context) + { + return visitor.visitBreak(this, context); + } +} diff --git a/core/trino-main/src/main/java/io/trino/sql/routine/ir/IrContinue.java b/core/trino-main/src/main/java/io/trino/sql/routine/ir/IrContinue.java new file mode 100644 index 000000000000..edae2b13e4ce --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/sql/routine/ir/IrContinue.java @@ -0,0 +1,31 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.routine.ir; + +import static java.util.Objects.requireNonNull; + +public record IrContinue(IrLabel target) + implements IrStatement +{ + public IrContinue + { + requireNonNull(target, "target is null"); + } + + @Override + public R accept(IrNodeVisitor visitor, C context) + { + return visitor.visitContinue(this, context); + } +} diff --git a/core/trino-main/src/main/java/io/trino/sql/routine/ir/IrIf.java b/core/trino-main/src/main/java/io/trino/sql/routine/ir/IrIf.java new file mode 100644 index 000000000000..421e11880b3a --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/sql/routine/ir/IrIf.java @@ -0,0 +1,37 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.routine.ir; + +import io.trino.sql.relational.RowExpression; + +import java.util.Optional; + +import static java.util.Objects.requireNonNull; + +public record IrIf(RowExpression condition, IrStatement ifTrue, Optional ifFalse) + implements IrStatement +{ + public IrIf + { + requireNonNull(condition, "condition is null"); + requireNonNull(ifTrue, "ifTrue is null"); + requireNonNull(ifFalse, "ifFalse is null"); + } + + @Override + public R accept(IrNodeVisitor visitor, C context) + { + return visitor.visitIf(this, context); + } +} diff --git a/core/trino-main/src/main/java/io/trino/sql/routine/ir/IrLabel.java b/core/trino-main/src/main/java/io/trino/sql/routine/ir/IrLabel.java new file mode 100644 index 000000000000..0d8ae28d6034 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/sql/routine/ir/IrLabel.java @@ -0,0 +1,24 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.routine.ir; + +import static java.util.Objects.requireNonNull; + +public record IrLabel(String name) +{ + public IrLabel + { + requireNonNull(name, "name is null"); + } +} diff --git a/core/trino-main/src/main/java/io/trino/sql/routine/ir/IrLoop.java b/core/trino-main/src/main/java/io/trino/sql/routine/ir/IrLoop.java new file mode 100644 index 000000000000..c686a365e4c1 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/sql/routine/ir/IrLoop.java @@ -0,0 +1,34 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.routine.ir; + +import java.util.Optional; + +import static java.util.Objects.requireNonNull; + +public record IrLoop(Optional label, IrBlock block) + implements IrStatement +{ + public IrLoop + { + requireNonNull(label, "label is null"); + requireNonNull(block, "body is null"); + } + + @Override + public R accept(IrNodeVisitor visitor, C context) + { + return visitor.visitLoop(this, context); + } +} diff --git a/core/trino-main/src/main/java/io/trino/sql/routine/ir/IrNode.java b/core/trino-main/src/main/java/io/trino/sql/routine/ir/IrNode.java new file mode 100644 index 000000000000..e7e2b2c58d7f --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/sql/routine/ir/IrNode.java @@ -0,0 +1,19 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.routine.ir; + +public interface IrNode +{ + R accept(IrNodeVisitor visitor, C context); +} diff --git a/core/trino-main/src/main/java/io/trino/sql/routine/ir/IrNodeVisitor.java b/core/trino-main/src/main/java/io/trino/sql/routine/ir/IrNodeVisitor.java new file mode 100644 index 000000000000..3a0b9d1fb209 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/sql/routine/ir/IrNodeVisitor.java @@ -0,0 +1,82 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.routine.ir; + +public interface IrNodeVisitor +{ + default R process(IrNode node, C context) + { + return node.accept(this, context); + } + + default R visitNode(IrNode node, C context) + { + return null; + } + + default R visitRoutine(IrRoutine node, C context) + { + return visitNode(node, context); + } + + default R visitVariable(IrVariable node, C context) + { + return visitNode(node, context); + } + + default R visitBlock(IrBlock node, C context) + { + return visitNode(node, context); + } + + default R visitBreak(IrBreak node, C context) + { + return visitNode(node, context); + } + + default R visitContinue(IrContinue node, C context) + { + return visitNode(node, context); + } + + default R visitIf(IrIf node, C context) + { + return visitNode(node, context); + } + + default R visitRepeat(IrRepeat node, C context) + { + return visitNode(node, context); + } + + default R visitLoop(IrLoop node, C context) + { + return visitNode(node, context); + } + + default R visitReturn(IrReturn node, C context) + { + return visitNode(node, context); + } + + default R visitSet(IrSet node, C context) + { + return visitNode(node, context); + } + + default R visitWhile(IrWhile node, C context) + { + return visitNode(node, context); + } +} diff --git a/core/trino-main/src/main/java/io/trino/sql/routine/ir/IrRepeat.java b/core/trino-main/src/main/java/io/trino/sql/routine/ir/IrRepeat.java new file mode 100644 index 000000000000..527e37ea345a --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/sql/routine/ir/IrRepeat.java @@ -0,0 +1,37 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.routine.ir; + +import io.trino.sql.relational.RowExpression; + +import java.util.Optional; + +import static java.util.Objects.requireNonNull; + +public record IrRepeat(Optional label, RowExpression condition, IrBlock block) + implements IrStatement +{ + public IrRepeat + { + requireNonNull(label, "label is null"); + requireNonNull(condition, "condition is null"); + requireNonNull(block, "body is null"); + } + + @Override + public R accept(IrNodeVisitor visitor, C context) + { + return visitor.visitRepeat(this, context); + } +} diff --git a/core/trino-main/src/main/java/io/trino/sql/routine/ir/IrReturn.java b/core/trino-main/src/main/java/io/trino/sql/routine/ir/IrReturn.java new file mode 100644 index 000000000000..e33b1d3763ae --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/sql/routine/ir/IrReturn.java @@ -0,0 +1,33 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.routine.ir; + +import io.trino.sql.relational.RowExpression; + +import static java.util.Objects.requireNonNull; + +public record IrReturn(RowExpression value) + implements IrStatement +{ + public IrReturn + { + requireNonNull(value, "value is null"); + } + + @Override + public R accept(IrNodeVisitor visitor, C context) + { + return visitor.visitReturn(this, context); + } +} diff --git a/core/trino-main/src/main/java/io/trino/sql/routine/ir/IrRoutine.java b/core/trino-main/src/main/java/io/trino/sql/routine/ir/IrRoutine.java new file mode 100644 index 000000000000..5fb5209cbb42 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/sql/routine/ir/IrRoutine.java @@ -0,0 +1,37 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.routine.ir; + +import io.trino.spi.type.Type; + +import java.util.List; + +import static java.util.Objects.requireNonNull; + +public record IrRoutine(Type returnType, List parameters, IrStatement body) + implements IrNode +{ + public IrRoutine + { + requireNonNull(returnType, "returnType is null"); + requireNonNull(parameters, "parameters is null"); + requireNonNull(body, "body is null"); + } + + @Override + public R accept(IrNodeVisitor visitor, C context) + { + return visitor.visitRoutine(this, context); + } +} diff --git a/core/trino-main/src/main/java/io/trino/sql/routine/ir/IrSet.java b/core/trino-main/src/main/java/io/trino/sql/routine/ir/IrSet.java new file mode 100644 index 000000000000..f20dd3e6abfb --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/sql/routine/ir/IrSet.java @@ -0,0 +1,34 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.routine.ir; + +import io.trino.sql.relational.RowExpression; + +import static java.util.Objects.requireNonNull; + +public record IrSet(IrVariable target, RowExpression value) + implements IrStatement +{ + public IrSet + { + requireNonNull(target, "target is null"); + requireNonNull(value, "value is null"); + } + + @Override + public R accept(IrNodeVisitor visitor, C context) + { + return visitor.visitSet(this, context); + } +} diff --git a/core/trino-main/src/main/java/io/trino/sql/routine/ir/IrStatement.java b/core/trino-main/src/main/java/io/trino/sql/routine/ir/IrStatement.java new file mode 100644 index 000000000000..aa043071eb0c --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/sql/routine/ir/IrStatement.java @@ -0,0 +1,34 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.routine.ir; + +import com.fasterxml.jackson.annotation.JsonSubTypes; +import com.fasterxml.jackson.annotation.JsonTypeInfo; + +@JsonTypeInfo(use = JsonTypeInfo.Id.NAME) +@JsonSubTypes({ + @JsonSubTypes.Type(value = IrBlock.class, name = "block"), + @JsonSubTypes.Type(value = IrBreak.class, name = "break"), + @JsonSubTypes.Type(value = IrContinue.class, name = "continue"), + @JsonSubTypes.Type(value = IrIf.class, name = "if"), + @JsonSubTypes.Type(value = IrLoop.class, name = "loop"), + @JsonSubTypes.Type(value = IrRepeat.class, name = "repeat"), + @JsonSubTypes.Type(value = IrReturn.class, name = "return"), + @JsonSubTypes.Type(value = IrSet.class, name = "set"), + @JsonSubTypes.Type(value = IrWhile.class, name = "while"), +}) +@SuppressWarnings("MarkerInterface") +public sealed interface IrStatement + extends IrNode + permits IrBlock, IrBreak, IrContinue, IrIf, IrLoop, IrRepeat, IrReturn, IrSet, IrWhile {} diff --git a/core/trino-main/src/main/java/io/trino/sql/routine/ir/IrVariable.java b/core/trino-main/src/main/java/io/trino/sql/routine/ir/IrVariable.java new file mode 100644 index 000000000000..114ac3d55a16 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/sql/routine/ir/IrVariable.java @@ -0,0 +1,35 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.routine.ir; + +import io.trino.spi.type.Type; +import io.trino.sql.relational.RowExpression; + +import static java.util.Objects.requireNonNull; + +public record IrVariable(int field, Type type, RowExpression defaultValue) + implements IrNode +{ + public IrVariable + { + requireNonNull(type, "type is null"); + requireNonNull(defaultValue, "value is null"); + } + + @Override + public R accept(IrNodeVisitor visitor, C context) + { + return visitor.visitVariable(this, context); + } +} diff --git a/core/trino-main/src/main/java/io/trino/sql/routine/ir/IrWhile.java b/core/trino-main/src/main/java/io/trino/sql/routine/ir/IrWhile.java new file mode 100644 index 000000000000..f563facf3a77 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/sql/routine/ir/IrWhile.java @@ -0,0 +1,37 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.routine.ir; + +import io.trino.sql.relational.RowExpression; + +import java.util.Optional; + +import static java.util.Objects.requireNonNull; + +public record IrWhile(Optional label, RowExpression condition, IrBlock body) + implements IrStatement +{ + public IrWhile + { + requireNonNull(label, "label is null"); + requireNonNull(condition, "condition is null"); + requireNonNull(body, "body is null"); + } + + @Override + public R accept(IrNodeVisitor visitor, C context) + { + return visitor.visitWhile(this, context); + } +} diff --git a/core/trino-main/src/test/java/io/trino/sql/routine/TestSqlFunctions.java b/core/trino-main/src/test/java/io/trino/sql/routine/TestSqlFunctions.java new file mode 100644 index 000000000000..8807b83c3c39 --- /dev/null +++ b/core/trino-main/src/test/java/io/trino/sql/routine/TestSqlFunctions.java @@ -0,0 +1,478 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.routine; + +import io.airlift.slice.Slice; +import io.trino.Session; +import io.trino.execution.warnings.WarningCollector; +import io.trino.operator.scalar.SpecializedSqlScalarFunction; +import io.trino.security.AllowAllAccessControl; +import io.trino.spi.function.FunctionId; +import io.trino.spi.function.FunctionMetadata; +import io.trino.spi.function.InvocationConvention; +import io.trino.spi.function.ScalarFunctionImplementation; +import io.trino.spi.type.Type; +import io.trino.spi.type.VarcharType; +import io.trino.sql.PlannerContext; +import io.trino.sql.parser.SqlParser; +import io.trino.sql.routine.ir.IrRoutine; +import io.trino.sql.tree.FunctionSpecification; +import io.trino.transaction.TransactionManager; +import org.assertj.core.api.ThrowingConsumer; +import org.intellij.lang.annotations.Language; +import org.junit.jupiter.api.Test; + +import java.lang.invoke.MethodHandle; +import java.util.concurrent.atomic.AtomicLong; + +import static com.google.common.base.Throwables.throwIfUnchecked; +import static io.airlift.slice.Slices.utf8Slice; +import static io.trino.metadata.FunctionManager.createTestingFunctionManager; +import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BOXED_NULLABLE; +import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.NEVER_NULL; +import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; +import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.NULLABLE_RETURN; +import static io.trino.spi.type.BigintType.BIGINT; +import static io.trino.spi.type.BooleanType.BOOLEAN; +import static io.trino.spi.type.DoubleType.DOUBLE; +import static io.trino.spi.type.VarcharType.VARCHAR; +import static io.trino.sql.planner.TestingPlannerContext.plannerContextBuilder; +import static io.trino.testing.TestingSession.testSessionBuilder; +import static io.trino.transaction.InMemoryTransactionManager.createTestTransactionManager; +import static io.trino.transaction.TransactionBuilder.transaction; +import static io.trino.type.UnknownType.UNKNOWN; +import static java.lang.Math.floor; +import static org.assertj.core.api.Assertions.assertThat; + +class TestSqlFunctions +{ + private static final SqlParser SQL_PARSER = new SqlParser(); + private static final TransactionManager TRANSACTION_MANAGER = createTestTransactionManager(); + private static final PlannerContext PLANNER_CONTEXT = plannerContextBuilder() + .withTransactionManager(TRANSACTION_MANAGER) + .build(); + private static final Session SESSION = testSessionBuilder().build(); + + @Test + void testConstantReturn() + { + @Language("SQL") String sql = """ + FUNCTION answer() + RETURNS BIGINT + RETURN 42 + """; + assertFunction(sql, handle -> assertThat(handle.invoke()).isEqualTo(42L)); + } + + @Test + void testSimpleReturn() + { + @Language("SQL") String sql = """ + FUNCTION hello(s VARCHAR) + RETURNS VARCHAR + RETURN 'Hello, ' || s || '!' + """; + assertFunction(sql, handle -> { + assertThat(handle.invoke(utf8Slice("world"))).isEqualTo(utf8Slice("Hello, world!")); + assertThat(handle.invoke(utf8Slice("WORLD"))).isEqualTo(utf8Slice("Hello, WORLD!")); + }); + + testSingleExpression(VARCHAR, utf8Slice("foo"), VARCHAR, "Hello, foo!", "'Hello, ' || p || '!'"); + } + + @Test + void testSimpleExpression() + { + @Language("SQL") String sql = """ + FUNCTION test(a bigint) + RETURNS bigint + BEGIN + DECLARE x bigint DEFAULT CAST(99 AS bigint); + RETURN x * a; + END + """; + assertFunction(sql, handle -> { + assertThat(handle.invoke(0L)).isEqualTo(0L); + assertThat(handle.invoke(1L)).isEqualTo(99L); + assertThat(handle.invoke(42L)).isEqualTo(42L * 99); + assertThat(handle.invoke(123L)).isEqualTo(123L * 99); + }); + } + + @Test + void testSimpleCase() + { + @Language("SQL") String sql = """ + FUNCTION simple_case(a bigint) + RETURNS varchar + BEGIN + CASE a + WHEN 0 THEN RETURN 'zero'; + WHEN 1 THEN RETURN 'one'; + WHEN DECIMAL '10.0' THEN RETURN 'ten'; + WHEN 20.0E0 THEN RETURN 'twenty'; + ELSE RETURN 'other'; + END CASE; + RETURN NULL; + END + """; + assertFunction(sql, handle -> { + assertThat(handle.invoke(0L)).isEqualTo(utf8Slice("zero")); + assertThat(handle.invoke(1L)).isEqualTo(utf8Slice("one")); + assertThat(handle.invoke(10L)).isEqualTo(utf8Slice("ten")); + assertThat(handle.invoke(20L)).isEqualTo(utf8Slice("twenty")); + assertThat(handle.invoke(42L)).isEqualTo(utf8Slice("other")); + }); + } + + @Test + void testSearchCase() + { + @Language("SQL") String sql = """ + FUNCTION search_case(a bigint, b bigint) + RETURNS varchar + BEGIN + CASE + WHEN a = 0 THEN RETURN 'zero'; + WHEN b = 1 THEN RETURN 'one'; + WHEN a = DECIMAL '10.0' THEN RETURN 'ten'; + WHEN b = 20.0E0 THEN RETURN 'twenty'; + ELSE RETURN 'other'; + END CASE; + RETURN NULL; + END + """; + assertFunction(sql, handle -> { + assertThat(handle.invoke(0L, 42L)).isEqualTo(utf8Slice("zero")); + assertThat(handle.invoke(42L, 1L)).isEqualTo(utf8Slice("one")); + assertThat(handle.invoke(10L, 42L)).isEqualTo(utf8Slice("ten")); + assertThat(handle.invoke(42L, 20L)).isEqualTo(utf8Slice("twenty")); + assertThat(handle.invoke(42L, 42L)).isEqualTo(utf8Slice("other")); + + // verify ordering + assertThat(handle.invoke(0L, 1L)).isEqualTo(utf8Slice("zero")); + assertThat(handle.invoke(10L, 1L)).isEqualTo(utf8Slice("one")); + assertThat(handle.invoke(10L, 20L)).isEqualTo(utf8Slice("ten")); + assertThat(handle.invoke(42L, 20L)).isEqualTo(utf8Slice("twenty")); + }); + } + + @Test + void testFibonacciWhileLoop() + { + @Language("SQL") String sql = """ + FUNCTION fib(n bigint) + RETURNS bigint + BEGIN + DECLARE a, b bigint DEFAULT 1; + DECLARE c bigint; + IF n <= 2 THEN + RETURN 1; + END IF; + WHILE n > 2 DO + SET n = n - 1; + SET c = a + b; + SET a = b; + SET b = c; + END WHILE; + RETURN c; + END + """; + assertFunction(sql, handle -> { + assertThat(handle.invoke(1L)).isEqualTo(1L); + assertThat(handle.invoke(2L)).isEqualTo(1L); + assertThat(handle.invoke(3L)).isEqualTo(2L); + assertThat(handle.invoke(4L)).isEqualTo(3L); + assertThat(handle.invoke(5L)).isEqualTo(5L); + assertThat(handle.invoke(6L)).isEqualTo(8L); + assertThat(handle.invoke(7L)).isEqualTo(13L); + assertThat(handle.invoke(8L)).isEqualTo(21L); + }); + } + + @Test + void testBreakContinue() + { + @Language("SQL") String sql = """ + FUNCTION test() + RETURNS bigint + BEGIN + DECLARE a, b int DEFAULT 0; + top: WHILE a < 10 DO + SET a = a + 1; + IF a < 3 THEN + ITERATE top; + END IF; + SET b = b + 1; + IF a > 6 THEN + LEAVE top; + END IF; + END WHILE; + RETURN b; + END + """; + assertFunction(sql, handle -> assertThat(handle.invoke()).isEqualTo(5L)); + } + + @Test + void testRepeat() + { + @Language("SQL") String sql = """ + FUNCTION test_repeat(a bigint) + RETURNS bigint + BEGIN + REPEAT + SET a = a + 1; + UNTIL a >= 10 END REPEAT; + RETURN a; + END + """; + assertFunction(sql, handle -> { + assertThat(handle.invoke(0L)).isEqualTo(10L); + assertThat(handle.invoke(100L)).isEqualTo(101L); + }); + } + + @Test + void testRepeatContinue() + { + @Language("SQL") String sql = """ + FUNCTION test_repeat_continue() + RETURNS bigint + BEGIN + DECLARE a int DEFAULT 0; + DECLARE b int DEFAULT 0; + top: REPEAT + SET a = a + 1; + IF a <= 3 THEN + ITERATE top; + END IF; + SET b = b + 1; + UNTIL a >= 10 END REPEAT; + RETURN b; + END + """; + assertFunction(sql, handle -> assertThat(handle.invoke()).isEqualTo(7L)); + } + + @Test + void testReuseLabels() + { + @Language("SQL") String sql = """ + FUNCTION test() + RETURNS int + BEGIN + DECLARE r int DEFAULT 0; + abc: LOOP + SET r = r + 1; + LEAVE abc; + END LOOP; + abc: LOOP + SET r = r + 1; + LEAVE abc; + END LOOP; + RETURN r; + END + """; + assertFunction(sql, handle -> assertThat(handle.invoke()).isEqualTo(2L)); + } + + @Test + void testReuseVariables() + { + @Language("SQL") String sql = """ + FUNCTION test() + RETURNS bigint + BEGIN + DECLARE r bigint DEFAULT 0; + BEGIN + DECLARE x varchar DEFAULT 'hello'; + SET r = r + length(x); + END; + BEGIN + DECLARE x array(int) DEFAULT array[1, 2, 3]; + SET r = r + cardinality(x); + END; + RETURN r; + END + """; + assertFunction(sql, handle -> assertThat(handle.invoke()).isEqualTo(8L)); + } + + @Test + void testAssignParameter() + { + @Language("SQL") String sql = """ + FUNCTION test(x int) + RETURNS int + BEGIN + SET x = x * 3; + RETURN x; + END + """; + assertFunction(sql, handle -> assertThat(handle.invoke(2L)).isEqualTo(6L)); + } + + @Test + void testCall() + { + testSingleExpression(BIGINT, -123L, BIGINT, 123L, "abs(p)"); + } + + @Test + void testCallNested() + { + testSingleExpression(BIGINT, -123L, BIGINT, 123L, "abs(ceiling(p))"); + testSingleExpression(BIGINT, 42L, DOUBLE, 42.0, "to_unixTime(from_unixtime(p))"); + } + + @Test + void testArray() + { + testSingleExpression(BIGINT, 3L, BIGINT, 5L, "array[3,4,5,6,7][p]"); + testSingleExpression(BIGINT, 0L, BIGINT, 0L, "array_sort(array[3,2,4,5,1,p])[1]"); + } + + @Test + void testRow() + { + testSingleExpression(BIGINT, 8L, BIGINT, 8L, "ROW(1, 'a', p)[3]"); + } + + @Test + void testLambda() + { + testSingleExpression(BIGINT, 3L, BIGINT, 9L, "(transform(ARRAY [5, 6], x -> x + p)[2])", false); + } + + @Test + void testTry() + { + testSingleExpression(VARCHAR, utf8Slice("42"), BIGINT, 42L, "try(cast(p AS bigint))"); + testSingleExpression(VARCHAR, utf8Slice("abc"), BIGINT, null, "try(cast(p AS bigint))"); + } + + @Test + void testTryCast() + { + testSingleExpression(VARCHAR, utf8Slice("42"), BIGINT, 42L, "try_cast(p AS bigint)"); + testSingleExpression(VARCHAR, utf8Slice("abc"), BIGINT, null, "try_cast(p AS bigint)"); + } + + @Test + void testNonCanonical() + { + testSingleExpression(BIGINT, 100_000L, BIGINT, 1970L, "EXTRACT(YEAR FROM from_unixtime(p))"); + } + + @Test + void testAtTimeZone() + { + testSingleExpression(UNKNOWN, null, VARCHAR, "2012-10-30 18:00:00 America/Los_Angeles", "CAST(TIMESTAMP '2012-10-31 01:00 UTC' AT TIME ZONE 'America/Los_Angeles' AS VARCHAR)"); + } + + @Test + void testSession() + { + testSingleExpression(UNKNOWN, null, DOUBLE, floor(SESSION.getStart().toEpochMilli() / 1000.0), "floor(to_unixtime(localtimestamp))"); + testSingleExpression(UNKNOWN, null, VARCHAR, SESSION.getUser(), "current_user"); + } + + @Test + void testSpecialType() + { + testSingleExpression(VARCHAR, utf8Slice("abc"), BOOLEAN, true, "(p LIKE '%bc')"); + testSingleExpression(VARCHAR, utf8Slice("xb"), BOOLEAN, false, "(p LIKE '%bc')"); + testSingleExpression(VARCHAR, utf8Slice("abc"), BOOLEAN, false, "regexp_like(p, '\\d')"); + testSingleExpression(VARCHAR, utf8Slice("123"), BOOLEAN, true, "regexp_like(p, '\\d')"); + testSingleExpression(VARCHAR, utf8Slice("[4,5,6]"), VARCHAR, "6", "json_extract_scalar(p, '$[2]')"); + } + + private final AtomicLong nextId = new AtomicLong(); + + private void testSingleExpression(Type inputType, Object input, Type outputType, Object output, String expression) + { + testSingleExpression(inputType, input, outputType, output, expression, true); + } + + private void testSingleExpression(Type inputType, Object input, Type outputType, Object output, String expression, boolean deterministic) + { + @Language("SQL") String sql = "FUNCTION %s(p %s)\nRETURNS %s\n%s\nRETURN %s".formatted( + "test" + nextId.incrementAndGet(), + inputType.getTypeSignature(), + outputType.getTypeSignature(), + deterministic ? "DETERMINISTIC" : "NOT DETERMINISTIC", + expression); + + assertFunction(sql, handle -> { + Object result = handle.invoke(input); + + if ((outputType instanceof VarcharType) && (result instanceof Slice slice)) { + result = slice.toStringUtf8(); + } + + assertThat(result).isEqualTo(output); + }); + } + + private static void assertFunction(@Language("SQL") String sql, ThrowingConsumer consumer) + { + transaction(TRANSACTION_MANAGER, PLANNER_CONTEXT.getMetadata(), new AllowAllAccessControl()) + .singleStatement() + .execute(SESSION, session -> { + ScalarFunctionImplementation implementation = compileFunction(sql, session); + MethodHandle handle = implementation.getMethodHandle() + .bindTo(getInstance(implementation)) + .bindTo(session.toConnectorSession()); + consumer.accept(handle); + }); + } + + private static Object getInstance(ScalarFunctionImplementation implementation) + { + try { + return implementation.getInstanceFactory().orElseThrow().invoke(); + } + catch (Throwable t) { + throwIfUnchecked(t); + throw new RuntimeException(t); + } + } + + private static ScalarFunctionImplementation compileFunction(@Language("SQL") String sql, Session session) + { + FunctionSpecification function = SQL_PARSER.createFunctionSpecification(sql); + + FunctionMetadata metadata = SqlRoutineAnalyzer.extractFunctionMetadata(new FunctionId("test"), function); + + SqlRoutineAnalyzer analyzer = new SqlRoutineAnalyzer(PLANNER_CONTEXT, WarningCollector.NOOP); + SqlRoutineAnalysis analysis = analyzer.analyze(session, new AllowAllAccessControl(), function); + + SqlRoutinePlanner planner = new SqlRoutinePlanner(PLANNER_CONTEXT, WarningCollector.NOOP); + IrRoutine routine = planner.planSqlFunction(session, function, analysis); + + SqlRoutineCompiler compiler = new SqlRoutineCompiler(createTestingFunctionManager()); + SpecializedSqlScalarFunction sqlScalarFunction = compiler.compile(routine); + + InvocationConvention invocationConvention = new InvocationConvention( + metadata.getFunctionNullability().getArgumentNullable().stream() + .map(nullable -> nullable ? BOXED_NULLABLE : NEVER_NULL) + .toList(), + metadata.getFunctionNullability().isReturnNullable() ? NULLABLE_RETURN : FAIL_ON_NULL, + true, + true); + + return sqlScalarFunction.getScalarFunctionImplementation(invocationConvention); + } +} diff --git a/core/trino-main/src/test/java/io/trino/sql/routine/TestSqlRoutineAnalyzer.java b/core/trino-main/src/test/java/io/trino/sql/routine/TestSqlRoutineAnalyzer.java new file mode 100644 index 000000000000..4d8c32592648 --- /dev/null +++ b/core/trino-main/src/test/java/io/trino/sql/routine/TestSqlRoutineAnalyzer.java @@ -0,0 +1,538 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.routine; + +import io.trino.execution.warnings.WarningCollector; +import io.trino.security.AllowAllAccessControl; +import io.trino.sql.PlannerContext; +import io.trino.sql.parser.SqlParser; +import io.trino.sql.tree.FunctionSpecification; +import io.trino.testing.assertions.TrinoExceptionAssert; +import io.trino.transaction.TestingTransactionManager; +import io.trino.transaction.TransactionManager; +import org.intellij.lang.annotations.Language; +import org.junit.jupiter.api.Test; + +import static io.trino.spi.StandardErrorCode.ALREADY_EXISTS; +import static io.trino.spi.StandardErrorCode.INVALID_ARGUMENTS; +import static io.trino.spi.StandardErrorCode.MISSING_RETURN; +import static io.trino.spi.StandardErrorCode.NOT_FOUND; +import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED; +import static io.trino.spi.StandardErrorCode.SYNTAX_ERROR; +import static io.trino.spi.StandardErrorCode.TYPE_MISMATCH; +import static io.trino.sql.planner.TestingPlannerContext.plannerContextBuilder; +import static io.trino.testing.TestingSession.testSession; +import static io.trino.testing.assertions.TrinoExceptionAssert.assertTrinoExceptionThrownBy; +import static io.trino.transaction.TransactionBuilder.transaction; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.from; + +class TestSqlRoutineAnalyzer +{ + private static final SqlParser SQL_PARSER = new SqlParser(); + + @Test + void testParameters() + { + assertFails("FUNCTION test(x) RETURNS int RETURN 123") + .hasErrorCode(INVALID_ARGUMENTS) + .hasMessage("line 1:15: Function parameters must have a name"); + + assertFails("FUNCTION test(x int, y int, x bigint) RETURNS int RETURN 123") + .hasErrorCode(INVALID_ARGUMENTS) + .hasMessage("line 1:29: Duplicate function parameter name: x"); + } + + @Test + void testCharacteristics() + { + assertFails("FUNCTION test() RETURNS int CALLED ON NULL INPUT CALLED ON NULL INPUT RETURN 123") + .hasErrorCode(SYNTAX_ERROR) + .hasMessage("line 1:1: Multiple null-call clauses specified"); + + assertFails("FUNCTION test() RETURNS int RETURNS NULL ON NULL INPUT CALLED ON NULL INPUT RETURN 123") + .hasErrorCode(SYNTAX_ERROR) + .hasMessage("line 1:1: Multiple null-call clauses specified"); + + assertFails("FUNCTION test() RETURNS int COMMENT 'abc' COMMENT 'xyz' RETURN 123") + .hasErrorCode(SYNTAX_ERROR) + .hasMessage("line 1:1: Multiple comment clauses specified"); + + assertFails("FUNCTION test() RETURNS int LANGUAGE abc LANGUAGE xyz RETURN 123") + .hasErrorCode(SYNTAX_ERROR) + .hasMessage("line 1:1: Multiple language clauses specified"); + + assertFails("FUNCTION test() RETURNS int NOT DETERMINISTIC DETERMINISTIC RETURN 123") + .hasErrorCode(SYNTAX_ERROR) + .hasMessage("line 1:1: Multiple deterministic clauses specified"); + } + + @Test + void testParameterTypeUnknown() + { + assertFails("FUNCTION test(x abc) RETURNS int RETURN 123") + .hasErrorCode(TYPE_MISMATCH) + .hasMessage("line 1:15: Unknown type: abc"); + } + + @Test + void testReturnTypeUnknown() + { + assertFails("FUNCTION test() RETURNS abc RETURN 123") + .hasErrorCode(TYPE_MISMATCH) + .hasMessage("line 1:17: Unknown type: abc"); + } + + @Test + void testReturnType() + { + analyze("FUNCTION test() RETURNS bigint RETURN smallint '123'"); + analyze("FUNCTION test() RETURNS varchar(10) RETURN 'test'"); + + assertFails("FUNCTION test() RETURNS varchar(2) RETURN 'test'") + .hasErrorCode(TYPE_MISMATCH) + .hasMessage("line 1:43: Value of RETURN must evaluate to varchar(2) (actual: varchar(4))"); + + assertFails("FUNCTION test() RETURNS bigint RETURN random()") + .hasErrorCode(TYPE_MISMATCH) + .hasMessage("line 1:39: Value of RETURN must evaluate to bigint (actual: double)"); + + assertFails("FUNCTION test() RETURNS real RETURN random()") + .hasErrorCode(TYPE_MISMATCH) + .hasMessage("line 1:37: Value of RETURN must evaluate to real (actual: double)"); + } + + @Test + void testLanguage() + { + assertThat(analyze("FUNCTION test() RETURNS bigint LANGUAGE SQL RETURN abs(-42)")) + .returns(true, from(SqlRoutineAnalysis::deterministic)); + + assertFails("FUNCTION test() RETURNS bigint LANGUAGE JAVASCRIPT RETURN abs(-42)") + .hasErrorCode(NOT_SUPPORTED) + .hasMessage("line 1:1: Unsupported language: JAVASCRIPT"); + } + + @Test + void testDeterministic() + { + assertThat(analyze("FUNCTION test() RETURNS bigint RETURN abs(-42)")) + .returns(true, from(SqlRoutineAnalysis::deterministic)); + + assertThat(analyze("FUNCTION test() RETURNS bigint DETERMINISTIC RETURN abs(-42)")) + .returns(true, from(SqlRoutineAnalysis::deterministic)); + + assertFails("FUNCTION test() RETURNS bigint NOT DETERMINISTIC RETURN abs(-42)") + .hasErrorCode(INVALID_ARGUMENTS) + .hasMessage("line 1:1: Deterministic function declared NOT DETERMINISTIC"); + + assertThat(analyze("FUNCTION test() RETURNS varchar RETURN reverse('test')")) + .returns(true, from(SqlRoutineAnalysis::deterministic)); + + assertThat(analyze("FUNCTION test() RETURNS double NOT DETERMINISTIC RETURN 42 * random()")) + .returns(false, from(SqlRoutineAnalysis::deterministic)); + + assertFails("FUNCTION test() RETURNS double RETURN 42 * random()") + .hasErrorCode(INVALID_ARGUMENTS) + .hasMessage("line 1:1: Non-deterministic function declared DETERMINISTIC"); + } + + @Test + void testIfConditionType() + { + assertFails(""" + FUNCTION test() RETURNS int + BEGIN + IF random() THEN + RETURN 13; + END IF; + RETURN 0; + END + """) + .hasErrorCode(TYPE_MISMATCH) + .hasMessage("line 3:6: Condition of IF statement must evaluate to boolean (actual: double)"); + } + + @Test + void testElseIfConditionType() + { + assertFails(""" + FUNCTION test() RETURNS int + BEGIN + IF false THEN + RETURN 13; + ELSEIF random() THEN + RETURN 13; + END IF; + RETURN 0; + END + """) + .hasErrorCode(TYPE_MISMATCH) + .hasMessage("line 5:10: Condition of ELSEIF clause must evaluate to boolean (actual: double)"); + } + + @Test + void testCaseWhenClauseValueType() + { + assertFails(""" + FUNCTION test(x int) RETURNS int + BEGIN + CASE x + WHEN 13 THEN RETURN 13; + WHEN 'abc' THEN RETURN 42; + END CASE; + RETURN 0; + END + """) + .hasErrorCode(TYPE_MISMATCH) + .hasMessage("line 5:10: WHEN clause value must evaluate to CASE value type integer (actual: varchar(3))"); + } + + @Test + void testCaseWhenClauseConditionType() + { + assertFails(""" + FUNCTION test() RETURNS int + BEGIN + CASE + WHEN true THEN RETURN 42; + WHEN 13 THEN RETURN 13; + END CASE; + RETURN 0; + END + """) + .hasErrorCode(TYPE_MISMATCH) + .hasMessage("line 5:10: Condition of WHEN clause must evaluate to boolean (actual: integer)"); + } + + @Test + void testMissingReturn() + { + assertFails("FUNCTION test() RETURNS int BEGIN END") + .hasErrorCode(MISSING_RETURN) + .hasMessage("line 1:29: Function must end in a RETURN statement"); + + assertFails(""" + FUNCTION test() RETURNS int + BEGIN + IF false THEN + RETURN 13; + END IF; + END + """) + .hasErrorCode(MISSING_RETURN) + .hasMessage("line 2:1: Function must end in a RETURN statement"); + } + + @Test + void testBadVariableDefault() + { + assertFails(""" + FUNCTION test() RETURNS int + BEGIN + DECLARE x int DEFAULT 'abc'; + RETURN 0; + END + """) + .hasErrorCode(TYPE_MISMATCH) + .hasMessage("line 3:25: Value of DEFAULT must evaluate to integer (actual: varchar(3))"); + } + + @Test + void testVariableAlreadyDeclared() + { + assertFails(""" + FUNCTION test() RETURNS int + BEGIN + DECLARE x int; + DECLARE x int; + RETURN 0; + END + """) + .hasErrorCode(ALREADY_EXISTS) + .hasMessage("line 4:11: Variable already declared in this scope: x"); + + assertFails(""" + FUNCTION test() RETURNS int + BEGIN + DECLARE x int; + DECLARE y, x int; + RETURN 0; + END + """) + .hasErrorCode(ALREADY_EXISTS) + .hasMessage("line 4:14: Variable already declared in this scope: x"); + + assertFails(""" + FUNCTION test() RETURNS int + BEGIN + DECLARE x, y, x int; + RETURN 0; + END + """) + .hasErrorCode(ALREADY_EXISTS) + .hasMessage("line 3:17: Variable already declared in this scope: x"); + + assertFails(""" + FUNCTION test() RETURNS int + BEGIN + DECLARE x int; + BEGIN + DECLARE x int; + END; + RETURN 0; + END + """) + .hasErrorCode(ALREADY_EXISTS) + .hasMessage("line 5:13: Variable already declared in this scope: x"); + + analyze(""" + FUNCTION test() RETURNS int + BEGIN + BEGIN + DECLARE x int; + END; + BEGIN + DECLARE x varchar; + END; + RETURN 0; + END + """); + } + + @Test + void testAssignmentUnknownTarget() + { + assertFails(""" + FUNCTION test() RETURNS int + BEGIN + SET x = 13; + RETURN 0; + END + """) + .hasErrorCode(NOT_FOUND) + .hasMessage("line 3:7: Variable cannot be resolved: x"); + } + + @Test + void testAssignmentType() + { + assertFails(""" + FUNCTION test() RETURNS int + BEGIN + DECLARE x int; + SET x = 'abc'; + RETURN 0; + END + """) + .hasErrorCode(TYPE_MISMATCH) + .hasMessage("line 4:11: Value of SET 'x' must evaluate to integer (actual: varchar(3))"); + } + + @Test + void testWhileConditionType() + { + assertFails(""" + FUNCTION test() RETURNS int + BEGIN + WHILE 13 DO + RETURN 0; + END WHILE; + RETURN 0; + END + """) + .hasErrorCode(TYPE_MISMATCH) + .hasMessage("line 3:9: Condition of WHILE statement must evaluate to boolean (actual: integer)"); + } + + @Test + void testUntilConditionType() + { + assertFails(""" + FUNCTION test() RETURNS int + BEGIN + REPEAT + RETURN 42; + UNTIL 13 END REPEAT; + RETURN 0; + END + """) + .hasErrorCode(TYPE_MISMATCH) + .hasMessage("line 5:9: Condition of REPEAT statement must evaluate to boolean (actual: integer)"); + } + + @Test + void testIterateUnknownLabel() + { + assertFails(""" + FUNCTION test() RETURNS int + BEGIN + WHILE true DO + ITERATE abc; + END WHILE; + RETURN 0; + END + """) + .hasErrorCode(NOT_FOUND) + .hasMessage("line 4:13: Label not defined: abc"); + } + + @Test + void testLeaveUnknownLabel() + { + assertFails(""" + FUNCTION test() RETURNS int + BEGIN + LEAVE abc; + RETURN 0; + END + """) + .hasErrorCode(NOT_FOUND) + .hasMessage("line 3:9: Label not defined: abc"); + } + + @Test + void testDuplicateWhileLabel() + { + assertFails(""" + FUNCTION test() RETURNS int + BEGIN + abc: WHILE true DO + LEAVE abc; + abc: WHILE true DO + LEAVE abc; + END WHILE; + END WHILE; + RETURN 0; + END + """) + .hasErrorCode(ALREADY_EXISTS) + .hasMessage("line 5:5: Label already declared in this scope: abc"); + + analyze(""" + FUNCTION test() RETURNS int + BEGIN + abc: WHILE true DO + LEAVE abc; + END WHILE; + abc: WHILE true DO + LEAVE abc; + END WHILE; + RETURN 0; + END + """); + } + + @Test + void testDuplicateRepeatLabel() + { + assertFails(""" + FUNCTION test() RETURNS int + BEGIN + abc: REPEAT + LEAVE abc; + abc: REPEAT + LEAVE abc; + UNTIL true END REPEAT; + UNTIL true END REPEAT; + RETURN 0; + END + """) + .hasErrorCode(ALREADY_EXISTS) + .hasMessage("line 5:5: Label already declared in this scope: abc"); + + analyze(""" + FUNCTION test() RETURNS int + BEGIN + abc: REPEAT + LEAVE abc; + UNTIL true END REPEAT; + abc: REPEAT + LEAVE abc; + UNTIL true END REPEAT; + RETURN 0; + END + """); + } + + @Test + void testDuplicateLoopLabel() + { + assertFails(""" + FUNCTION test() RETURNS int + BEGIN + abc: LOOP + LEAVE abc; + abc: LOOP + LEAVE abc; + END LOOP; + END LOOP; + RETURN 0; + END + """) + .hasErrorCode(ALREADY_EXISTS) + .hasMessage("line 5:5: Label already declared in this scope: abc"); + + analyze(""" + FUNCTION test() RETURNS int + BEGIN + abc: LOOP + LEAVE abc; + END LOOP; + abc: LOOP + LEAVE abc; + END LOOP; + RETURN 0; + END + """); + } + + @Test + void testSubquery() + { + assertFails("FUNCTION test() RETURNS int RETURN (SELECT 123)") + .hasErrorCode(NOT_SUPPORTED) + .hasMessage("line 1:36: Queries are not allowed in functions"); + + assertFails(""" + FUNCTION test() RETURNS int + BEGIN + RETURN (SELECT 123); + END + """) + .hasErrorCode(NOT_SUPPORTED) + .hasMessage("line 3:10: Queries are not allowed in functions"); + } + + private static TrinoExceptionAssert assertFails(@Language("SQL") String function) + { + return assertTrinoExceptionThrownBy(() -> analyze(function)); + } + + private static SqlRoutineAnalysis analyze(@Language("SQL") String function) + { + FunctionSpecification specification = SQL_PARSER.createFunctionSpecification(function); + + TransactionManager transactionManager = new TestingTransactionManager(); + PlannerContext plannerContext = plannerContextBuilder() + .withTransactionManager(transactionManager) + .build(); + return transaction(transactionManager, plannerContext.getMetadata(), new AllowAllAccessControl()) + .singleStatement() + .execute(testSession(), transactionSession -> { + SqlRoutineAnalyzer analyzer = new SqlRoutineAnalyzer(plannerContext, WarningCollector.NOOP); + return analyzer.analyze(transactionSession, new AllowAllAccessControl(), specification); + }); + } +} diff --git a/core/trino-main/src/test/java/io/trino/sql/routine/TestSqlRoutineCompiler.java b/core/trino-main/src/test/java/io/trino/sql/routine/TestSqlRoutineCompiler.java new file mode 100644 index 000000000000..5771adb312e2 --- /dev/null +++ b/core/trino-main/src/test/java/io/trino/sql/routine/TestSqlRoutineCompiler.java @@ -0,0 +1,316 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.routine; + +import com.google.common.collect.ImmutableList; +import io.trino.Session; +import io.trino.metadata.ResolvedFunction; +import io.trino.spi.function.OperatorType; +import io.trino.spi.type.Type; +import io.trino.sql.relational.InputReferenceExpression; +import io.trino.sql.relational.RowExpression; +import io.trino.sql.routine.ir.IrBlock; +import io.trino.sql.routine.ir.IrBreak; +import io.trino.sql.routine.ir.IrContinue; +import io.trino.sql.routine.ir.IrIf; +import io.trino.sql.routine.ir.IrLabel; +import io.trino.sql.routine.ir.IrLoop; +import io.trino.sql.routine.ir.IrRepeat; +import io.trino.sql.routine.ir.IrReturn; +import io.trino.sql.routine.ir.IrRoutine; +import io.trino.sql.routine.ir.IrSet; +import io.trino.sql.routine.ir.IrStatement; +import io.trino.sql.routine.ir.IrVariable; +import io.trino.sql.routine.ir.IrWhile; +import io.trino.util.Reflection; +import org.junit.jupiter.api.Test; + +import java.lang.invoke.MethodHandle; +import java.util.List; +import java.util.Optional; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Supplier; + +import static com.google.common.collect.MoreCollectors.onlyElement; +import static io.trino.spi.function.OperatorType.ADD; +import static io.trino.spi.function.OperatorType.LESS_THAN; +import static io.trino.spi.function.OperatorType.LESS_THAN_OR_EQUAL; +import static io.trino.spi.function.OperatorType.MULTIPLY; +import static io.trino.spi.function.OperatorType.SUBTRACT; +import static io.trino.spi.type.BigintType.BIGINT; +import static io.trino.spi.type.BooleanType.BOOLEAN; +import static io.trino.sql.planner.TestingPlannerContext.PLANNER_CONTEXT; +import static io.trino.sql.relational.Expressions.call; +import static io.trino.sql.relational.Expressions.constant; +import static io.trino.sql.relational.Expressions.constantNull; +import static io.trino.testing.TestingSession.testSessionBuilder; +import static io.trino.util.Reflection.constructorMethodHandle; +import static java.util.Arrays.stream; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +public class TestSqlRoutineCompiler +{ + private static final Session TEST_SESSION = testSessionBuilder().build(); + private final SqlRoutineCompiler compiler = new SqlRoutineCompiler(PLANNER_CONTEXT.getFunctionManager()); + + @Test + public void testSimpleExpression() + throws Throwable + { + // CREATE FUNCTION test(a bigint) + // RETURNS bigint + // BEGIN + // DECLARE x bigint DEFAULT 99; + // RETURN x * a; + // END + + IrVariable arg = new IrVariable(0, BIGINT, constantNull(BIGINT)); + IrVariable variable = new IrVariable(1, BIGINT, constant(99L, BIGINT)); + + ResolvedFunction multiply = operator(MULTIPLY, BIGINT, BIGINT); + + IrRoutine routine = new IrRoutine( + BIGINT, + parameters(arg), + new IrBlock(variables(variable), statements( + new IrSet(variable, call(multiply, reference(variable), reference(arg))), + new IrReturn(reference(variable))))); + + MethodHandle handle = compile(routine); + + assertThat(handle.invoke(0L)).isEqualTo(0L); + assertThat(handle.invoke(1L)).isEqualTo(99L); + assertThat(handle.invoke(42L)).isEqualTo(42L * 99); + assertThat(handle.invoke(123L)).isEqualTo(123L * 99); + } + + @Test + public void testFibonacciWhileLoop() + throws Throwable + { + // CREATE FUNCTION fib(n bigint) + // RETURNS bigint + // BEGIN + // DECLARE a bigint DEFAULT 1; + // DECLARE b bigint DEFAULT 1; + // DECLARE c bigint; + // + // IF n <= 2 THEN + // RETURN 1; + // END IF; + // + // WHILE n > 2 DO + // SET n = n - 1; + // SET c = a + b; + // SET a = b; + // SET b = c; + // END WHILE; + // + // RETURN c; + // END + + IrVariable n = new IrVariable(0, BIGINT, constantNull(BIGINT)); + IrVariable a = new IrVariable(1, BIGINT, constant(1L, BIGINT)); + IrVariable b = new IrVariable(2, BIGINT, constant(1L, BIGINT)); + IrVariable c = new IrVariable(3, BIGINT, constantNull(BIGINT)); + + ResolvedFunction add = operator(ADD, BIGINT, BIGINT); + ResolvedFunction subtract = operator(SUBTRACT, BIGINT, BIGINT); + ResolvedFunction lessThan = operator(LESS_THAN, BIGINT, BIGINT); + ResolvedFunction lessThanOrEqual = operator(LESS_THAN_OR_EQUAL, BIGINT, BIGINT); + + IrRoutine routine = new IrRoutine( + BIGINT, + parameters(n), + new IrBlock(variables(a, b, c), statements( + new IrIf( + call(lessThanOrEqual, reference(n), constant(2L, BIGINT)), + new IrReturn(constant(1L, BIGINT)), + Optional.empty()), + new IrWhile( + Optional.empty(), + call(lessThan, constant(2L, BIGINT), reference(n)), + new IrBlock( + variables(), + statements( + new IrSet(n, call(subtract, reference(n), constant(1L, BIGINT))), + new IrSet(c, call(add, reference(a), reference(b))), + new IrSet(a, reference(b)), + new IrSet(b, reference(c))))), + new IrReturn(reference(c))))); + + MethodHandle handle = compile(routine); + + assertThat(handle.invoke(1L)).isEqualTo(1L); + assertThat(handle.invoke(2L)).isEqualTo(1L); + assertThat(handle.invoke(3L)).isEqualTo(2L); + assertThat(handle.invoke(4L)).isEqualTo(3L); + assertThat(handle.invoke(5L)).isEqualTo(5L); + assertThat(handle.invoke(6L)).isEqualTo(8L); + assertThat(handle.invoke(7L)).isEqualTo(13L); + assertThat(handle.invoke(8L)).isEqualTo(21L); + } + + @Test + public void testBreakContinue() + throws Throwable + { + // CREATE FUNCTION test() + // RETURNS bigint + // BEGIN + // DECLARE a bigint DEFAULT 0; + // DECLARE b bigint DEFAULT 0; + // + // top: WHILE a < 10 DO + // SET a = a + 1; + // IF a < 3 THEN + // ITERATE top; + // END IF; + // SET b = b + 1; + // IF a > 6 THEN + // LEAVE top; + // END IF; + // END WHILE; + // + // RETURN b; + // END + + IrVariable a = new IrVariable(0, BIGINT, constant(0L, BIGINT)); + IrVariable b = new IrVariable(1, BIGINT, constant(0L, BIGINT)); + + ResolvedFunction add = operator(ADD, BIGINT, BIGINT); + ResolvedFunction lessThan = operator(LESS_THAN, BIGINT, BIGINT); + + IrLabel label = new IrLabel("test"); + + IrRoutine routine = new IrRoutine( + BIGINT, + parameters(), + new IrBlock(variables(a, b), statements( + new IrWhile( + Optional.of(label), + call(lessThan, reference(a), constant(10L, BIGINT)), + new IrBlock(variables(), statements( + new IrSet(a, call(add, reference(a), constant(1L, BIGINT))), + new IrIf( + call(lessThan, reference(a), constant(3L, BIGINT)), + new IrContinue(label), + Optional.empty()), + new IrSet(b, call(add, reference(b), constant(1L, BIGINT))), + new IrIf( + call(lessThan, constant(6L, BIGINT), reference(a)), + new IrBreak(label), + Optional.empty())))), + new IrReturn(reference(b))))); + + MethodHandle handle = compile(routine); + + assertThat(handle.invoke()).isEqualTo(5L); + } + + @Test + public void testInterruptionWhile() + throws Throwable + { + assertRoutineInterruption(() -> new IrWhile( + Optional.empty(), + constant(true, BOOLEAN), + new IrBlock(variables(), statements()))); + } + + @Test + public void testInterruptionRepeat() + throws Throwable + { + assertRoutineInterruption(() -> new IrRepeat( + Optional.empty(), + constant(false, BOOLEAN), + new IrBlock(variables(), statements()))); + } + + @Test + public void testInterruptionLoop() + throws Throwable + { + assertRoutineInterruption(() -> new IrLoop( + Optional.empty(), + new IrBlock(variables(), statements()))); + } + + private void assertRoutineInterruption(Supplier loopFactory) + throws Throwable + { + IrRoutine routine = new IrRoutine( + BIGINT, + parameters(), + new IrBlock(variables(), statements( + loopFactory.get(), + new IrReturn(constant(null, BIGINT))))); + + MethodHandle handle = compile(routine); + + AtomicBoolean interrupted = new AtomicBoolean(); + Thread thread = new Thread(() -> { + assertThatThrownBy(handle::invoke) + .hasMessageContaining("Thread interrupted"); + interrupted.set(true); + }); + thread.start(); + thread.interrupt(); + thread.join(TimeUnit.SECONDS.toMillis(10)); + assertThat(interrupted).isTrue(); + } + + private MethodHandle compile(IrRoutine routine) + throws Throwable + { + Class clazz = compiler.compileClass(routine); + + MethodHandle handle = stream(clazz.getMethods()) + .filter(method -> method.getName().equals("run")) + .map(Reflection::methodHandle) + .collect(onlyElement()); + + Object instance = constructorMethodHandle(clazz).invoke(); + + return handle.bindTo(instance).bindTo(TEST_SESSION.toConnectorSession()); + } + + private static List parameters(IrVariable... variables) + { + return ImmutableList.copyOf(variables); + } + + private static List variables(IrVariable... variables) + { + return ImmutableList.copyOf(variables); + } + + private static List statements(IrStatement... statements) + { + return ImmutableList.copyOf(statements); + } + + private static RowExpression reference(IrVariable variable) + { + return new InputReferenceExpression(variable.field(), variable.type()); + } + + private static ResolvedFunction operator(OperatorType operator, Type... argumentTypes) + { + return PLANNER_CONTEXT.getMetadata().resolveOperator(operator, ImmutableList.copyOf(argumentTypes)); + } +} diff --git a/core/trino-parser/src/main/java/io/trino/sql/QueryUtil.java b/core/trino-parser/src/main/java/io/trino/sql/QueryUtil.java index b8b86d57f99c..194de410efdc 100644 --- a/core/trino-parser/src/main/java/io/trino/sql/QueryUtil.java +++ b/core/trino-parser/src/main/java/io/trino/sql/QueryUtil.java @@ -270,6 +270,7 @@ public static Query singleValueQuery(String columnName, boolean value) public static Query query(QueryBody body) { return new Query( + ImmutableList.of(), Optional.empty(), body, Optional.empty(), diff --git a/core/trino-parser/src/main/java/io/trino/sql/SqlFormatter.java b/core/trino-parser/src/main/java/io/trino/sql/SqlFormatter.java index 3ef1f3341850..9f17c3a0ebcb 100644 --- a/core/trino-parser/src/main/java/io/trino/sql/SqlFormatter.java +++ b/core/trino-parser/src/main/java/io/trino/sql/SqlFormatter.java @@ -20,13 +20,20 @@ import io.trino.sql.tree.AliasedRelation; import io.trino.sql.tree.AllColumns; import io.trino.sql.tree.Analyze; +import io.trino.sql.tree.AssignmentStatement; import io.trino.sql.tree.AstVisitor; import io.trino.sql.tree.Call; import io.trino.sql.tree.CallArgument; +import io.trino.sql.tree.CaseStatement; +import io.trino.sql.tree.CaseStatementWhenClause; import io.trino.sql.tree.ColumnDefinition; import io.trino.sql.tree.Comment; +import io.trino.sql.tree.CommentCharacteristic; import io.trino.sql.tree.Commit; +import io.trino.sql.tree.CompoundStatement; +import io.trino.sql.tree.ControlStatement; import io.trino.sql.tree.CreateCatalog; +import io.trino.sql.tree.CreateFunction; import io.trino.sql.tree.CreateMaterializedView; import io.trino.sql.tree.CreateRole; import io.trino.sql.tree.CreateSchema; @@ -38,13 +45,17 @@ import io.trino.sql.tree.Deny; import io.trino.sql.tree.DescribeInput; import io.trino.sql.tree.DescribeOutput; +import io.trino.sql.tree.DeterministicCharacteristic; import io.trino.sql.tree.DropCatalog; import io.trino.sql.tree.DropColumn; +import io.trino.sql.tree.DropFunction; import io.trino.sql.tree.DropMaterializedView; import io.trino.sql.tree.DropRole; import io.trino.sql.tree.DropSchema; import io.trino.sql.tree.DropTable; import io.trino.sql.tree.DropView; +import io.trino.sql.tree.ElseClause; +import io.trino.sql.tree.ElseIfClause; import io.trino.sql.tree.Except; import io.trino.sql.tree.Execute; import io.trino.sql.tree.ExecuteImmediate; @@ -55,13 +66,16 @@ import io.trino.sql.tree.ExplainType; import io.trino.sql.tree.Expression; import io.trino.sql.tree.FetchFirst; +import io.trino.sql.tree.FunctionSpecification; import io.trino.sql.tree.Grant; import io.trino.sql.tree.GrantRoles; import io.trino.sql.tree.GrantorSpecification; import io.trino.sql.tree.Identifier; +import io.trino.sql.tree.IfStatement; import io.trino.sql.tree.Insert; import io.trino.sql.tree.Intersect; import io.trino.sql.tree.Isolation; +import io.trino.sql.tree.IterateStatement; import io.trino.sql.tree.Join; import io.trino.sql.tree.JoinCriteria; import io.trino.sql.tree.JoinOn; @@ -69,9 +83,12 @@ import io.trino.sql.tree.JsonTable; import io.trino.sql.tree.JsonTableColumnDefinition; import io.trino.sql.tree.JsonTableDefaultPlan; +import io.trino.sql.tree.LanguageCharacteristic; import io.trino.sql.tree.Lateral; +import io.trino.sql.tree.LeaveStatement; import io.trino.sql.tree.LikeClause; import io.trino.sql.tree.Limit; +import io.trino.sql.tree.LoopStatement; import io.trino.sql.tree.Merge; import io.trino.sql.tree.MergeCase; import io.trino.sql.tree.MergeDelete; @@ -80,9 +97,11 @@ import io.trino.sql.tree.NaturalJoin; import io.trino.sql.tree.NestedColumns; import io.trino.sql.tree.Node; +import io.trino.sql.tree.NullInputCharacteristic; import io.trino.sql.tree.Offset; import io.trino.sql.tree.OrderBy; import io.trino.sql.tree.OrdinalityColumn; +import io.trino.sql.tree.ParameterDeclaration; import io.trino.sql.tree.PatternRecognitionRelation; import io.trino.sql.tree.PlanLeaf; import io.trino.sql.tree.PlanParentChild; @@ -102,14 +121,19 @@ import io.trino.sql.tree.RenameSchema; import io.trino.sql.tree.RenameTable; import io.trino.sql.tree.RenameView; +import io.trino.sql.tree.RepeatStatement; import io.trino.sql.tree.ResetSession; import io.trino.sql.tree.ResetSessionAuthorization; +import io.trino.sql.tree.ReturnStatement; +import io.trino.sql.tree.ReturnsClause; import io.trino.sql.tree.Revoke; import io.trino.sql.tree.RevokeRoles; import io.trino.sql.tree.Rollback; +import io.trino.sql.tree.RoutineCharacteristic; import io.trino.sql.tree.Row; import io.trino.sql.tree.RowPattern; import io.trino.sql.tree.SampledRelation; +import io.trino.sql.tree.SecurityCharacteristic; import io.trino.sql.tree.Select; import io.trino.sql.tree.SelectItem; import io.trino.sql.tree.SetColumnType; @@ -151,6 +175,8 @@ import io.trino.sql.tree.UpdateAssignment; import io.trino.sql.tree.ValueColumn; import io.trino.sql.tree.Values; +import io.trino.sql.tree.VariableDeclaration; +import io.trino.sql.tree.WhileStatement; import io.trino.sql.tree.WithQuery; import java.util.ArrayList; @@ -170,6 +196,7 @@ import static io.trino.sql.RowPatternFormatter.formatPattern; import static io.trino.sql.tree.SaveMode.IGNORE; import static io.trino.sql.tree.SaveMode.REPLACE; +import static java.lang.String.join; import static java.util.Objects.requireNonNull; import static java.util.stream.Collectors.joining; @@ -613,6 +640,18 @@ protected Void visitDescribeInput(DescribeInput node, Integer indent) @Override protected Void visitQuery(Query node, Integer indent) { + if (!node.getFunctions().isEmpty()) { + builder.append("WITH\n"); + Iterator functions = node.getFunctions().iterator(); + while (functions.hasNext()) { + process(functions.next(), indent + 1); + if (functions.hasNext()) { + builder.append(','); + } + builder.append('\n'); + } + } + node.getWith().ifPresent(with -> { append(indent, "WITH"); if (with.isRecursive()) { @@ -2094,7 +2133,7 @@ public Void visitGrant(Grant node, Integer indent) builder.append("GRANT "); builder.append(node.getPrivileges() - .map(privileges -> String.join(", ", privileges)) + .map(privileges -> join(", ", privileges)) .orElse("ALL PRIVILEGES")); builder.append(" ON "); @@ -2117,7 +2156,7 @@ public Void visitDeny(Deny node, Integer indent) builder.append("DENY "); if (node.getPrivileges().isPresent()) { - builder.append(String.join(", ", node.getPrivileges().get())); + builder.append(join(", ", node.getPrivileges().get())); } else { builder.append("ALL PRIVILEGES"); @@ -2145,7 +2184,7 @@ public Void visitRevoke(Revoke node, Integer indent) } builder.append(node.getPrivileges() - .map(privileges -> String.join(", ", privileges)) + .map(privileges -> join(", ", privileges)) .orElse("ALL PRIVILEGES")); builder.append(" ON "); @@ -2216,6 +2255,301 @@ public Void visitSetTimeZone(SetTimeZone node, Integer indent) return null; } + @Override + protected Void visitCreateFunction(CreateFunction node, Integer indent) + { + builder.append("CREATE "); + if (node.isReplace()) { + builder.append("OR REPLACE "); + } + process(node.getSpecification(), indent); + return null; + } + + @Override + protected Void visitDropFunction(DropFunction node, Integer indent) + { + builder.append("DROP FUNCTION "); + if (node.isExists()) { + builder.append("IF EXISTS "); + } + builder.append(formatName(node.getName())); + processParameters(node.getParameters(), indent); + return null; + } + + @Override + protected Void visitFunctionSpecification(FunctionSpecification node, Integer indent) + { + append(indent, "FUNCTION ") + .append(formatName(node.getName())); + processParameters(node.getParameters(), indent); + builder.append("\n"); + process(node.getReturnsClause(), indent); + builder.append("\n"); + for (RoutineCharacteristic characteristic : node.getRoutineCharacteristics()) { + process(characteristic, indent); + builder.append("\n"); + } + process(node.getStatement(), indent); + return null; + } + + @Override + protected Void visitParameterDeclaration(ParameterDeclaration node, Integer indent) + { + node.getName().ifPresent(value -> + builder.append(formatName(value)).append(" ")); + builder.append(formatExpression(node.getType())); + return null; + } + + @Override + protected Void visitLanguageCharacteristic(LanguageCharacteristic node, Integer indent) + { + append(indent, "LANGUAGE ") + .append(formatName(node.getLanguage())); + return null; + } + + @Override + protected Void visitDeterministicCharacteristic(DeterministicCharacteristic node, Integer indent) + { + append(indent, (node.isDeterministic() ? "" : "NOT ") + "DETERMINISTIC"); + return null; + } + + @Override + protected Void visitNullInputCharacteristic(NullInputCharacteristic node, Integer indent) + { + if (node.isCalledOnNull()) { + append(indent, "CALLED ON NULL INPUT"); + } + else { + append(indent, "RETURNS NULL ON NULL INPUT"); + } + return null; + } + + @Override + protected Void visitSecurityCharacteristic(SecurityCharacteristic node, Integer indent) + { + append(indent, "SECURITY ") + .append(node.getSecurity().name()); + return null; + } + + @Override + protected Void visitCommentCharacteristic(CommentCharacteristic node, Integer indent) + { + append(indent, "COMMENT ") + .append(formatStringLiteral(node.getComment())); + return null; + } + + @Override + protected Void visitReturnClause(ReturnsClause node, Integer indent) + { + append(indent, "RETURNS ") + .append(formatExpression(node.getReturnType())); + return null; + } + + @Override + protected Void visitReturnStatement(ReturnStatement node, Integer indent) + { + append(indent, "RETURN ") + .append(formatExpression(node.getValue())); + return null; + } + + @Override + protected Void visitCompoundStatement(CompoundStatement node, Integer indent) + { + append(indent, "BEGIN\n"); + for (VariableDeclaration variableDeclaration : node.getVariableDeclarations()) { + process(variableDeclaration, indent + 1); + builder.append(";\n"); + } + for (ControlStatement statement : node.getStatements()) { + process(statement, indent + 1); + builder.append(";\n"); + } + append(indent, "END"); + return null; + } + + @Override + protected Void visitVariableDeclaration(VariableDeclaration node, Integer indent) + { + append(indent, "DECLARE ") + .append(node.getNames().stream() + .map(SqlFormatter::formatName) + .collect(joining(", "))) + .append(" ") + .append(formatExpression(node.getType())); + if (node.getDefaultValue().isPresent()) { + builder.append(" DEFAULT ") + .append(formatExpression(node.getDefaultValue().get())); + } + return null; + } + + @Override + protected Void visitAssignmentStatement(AssignmentStatement node, Integer indent) + { + append(indent, "SET "); + builder.append(formatName(node.getTarget())) + .append(" = ") + .append(formatExpression(node.getValue())); + return null; + } + + @Override + protected Void visitCaseStatement(CaseStatement node, Integer indent) + { + append(indent, "CASE"); + if (node.getExpression().isPresent()) { + builder.append(" ") + .append(formatExpression(node.getExpression().get())); + } + builder.append("\n"); + for (CaseStatementWhenClause whenClause : node.getWhenClauses()) { + process(whenClause, indent + 1); + } + if (node.getElseClause().isPresent()) { + process(node.getElseClause().get(), indent + 1); + } + append(indent, "END CASE"); + return null; + } + + @Override + protected Void visitCaseStatementWhenClause(CaseStatementWhenClause node, Integer indent) + { + append(indent, "WHEN ") + .append(formatExpression(node.getExpression())) + .append(" THEN\n"); + for (ControlStatement statement : node.getStatements()) { + process(statement, indent + 1); + builder.append(";\n"); + } + return null; + } + + @Override + protected Void visitIfStatement(IfStatement node, Integer indent) + { + append(indent, "IF ") + .append(formatExpression(node.getExpression())) + .append(" THEN\n"); + for (ControlStatement statement : node.getStatements()) { + process(statement, indent + 1); + builder.append(";\n"); + } + for (ElseIfClause elseIfClause : node.getElseIfClauses()) { + process(elseIfClause, indent); + } + if (node.getElseClause().isPresent()) { + process(node.getElseClause().get(), indent); + } + append(indent, "END IF"); + return null; + } + + @Override + protected Void visitElseIfClause(ElseIfClause node, Integer indent) + { + append(indent, "ELSEIF ") + .append(formatExpression(node.getExpression())) + .append(" THEN\n"); + for (ControlStatement statement : node.getStatements()) { + process(statement, indent + 1); + builder.append(";\n"); + } + return null; + } + + @Override + protected Void visitElseClause(ElseClause node, Integer indent) + { + append(indent, "ELSE\n"); + for (ControlStatement statement : node.getStatements()) { + process(statement, indent + 1); + builder.append(";\n"); + } + return null; + } + + @Override + protected Void visitIterateStatement(IterateStatement node, Integer indent) + { + append(indent, "ITERATE ") + .append(formatName(node.getLabel())); + return null; + } + + @Override + protected Void visitLeaveStatement(LeaveStatement node, Integer indent) + { + append(indent, "LEAVE ") + .append(formatName(node.getLabel())); + return null; + } + + @Override + protected Void visitLoopStatement(LoopStatement node, Integer indent) + { + builder.append(indentString(indent)); + appendBeginLabel(node.getLabel()); + builder.append("LOOP\n"); + for (ControlStatement statement : node.getStatements()) { + process(statement, indent + 1); + builder.append(";\n"); + } + append(indent, "END LOOP"); + return null; + } + + @Override + protected Void visitWhileStatement(WhileStatement node, Integer indent) + { + builder.append(indentString(indent)); + appendBeginLabel(node.getLabel()); + builder.append("WHILE ") + .append(formatExpression(node.getExpression())) + .append(" DO\n"); + for (ControlStatement statement : node.getStatements()) { + process(statement, indent + 1); + builder.append(";\n"); + } + append(indent, "END WHILE"); + return null; + } + + @Override + protected Void visitRepeatStatement(RepeatStatement node, Integer indent) + { + builder.append(indentString(indent)); + appendBeginLabel(node.getLabel()); + builder.append("REPEAT\n"); + for (ControlStatement statement : node.getStatements()) { + process(statement, indent + 1); + builder.append(";\n"); + } + append(indent, "UNTIL ") + .append(formatExpression(node.getCondition())) + .append("\n"); + append(indent, "END REPEAT"); + return null; + } + + private void appendBeginLabel(Optional label) + { + label.ifPresent(value -> + builder.append(formatName(value)).append(": ")); + } + private void processRelation(Relation relation, Integer indent) { // TODO: handle this properly @@ -2229,6 +2563,19 @@ private void processRelation(Relation relation, Integer indent) } } + private void processParameters(List parameters, Integer indent) + { + builder.append("("); + Iterator iterator = parameters.iterator(); + while (iterator.hasNext()) { + process(iterator.next(), indent); + if (iterator.hasNext()) { + builder.append(", "); + } + } + builder.append(")"); + } + private SqlBuilder append(int indent, String value) { return builder.append(indentString(indent)) diff --git a/core/trino-parser/src/main/java/io/trino/sql/parser/AstBuilder.java b/core/trino-parser/src/main/java/io/trino/sql/parser/AstBuilder.java index cc4e13ac5059..b2581d57cfb8 100644 --- a/core/trino-parser/src/main/java/io/trino/sql/parser/AstBuilder.java +++ b/core/trino-parser/src/main/java/io/trino/sql/parser/AstBuilder.java @@ -31,6 +31,7 @@ import io.trino.sql.tree.ArithmeticBinaryExpression; import io.trino.sql.tree.ArithmeticUnaryExpression; import io.trino.sql.tree.Array; +import io.trino.sql.tree.AssignmentStatement; import io.trino.sql.tree.AtTimeZone; import io.trino.sql.tree.BetweenPredicate; import io.trino.sql.tree.BinaryLiteral; @@ -38,14 +39,20 @@ import io.trino.sql.tree.BooleanLiteral; import io.trino.sql.tree.Call; import io.trino.sql.tree.CallArgument; +import io.trino.sql.tree.CaseStatement; +import io.trino.sql.tree.CaseStatementWhenClause; import io.trino.sql.tree.Cast; import io.trino.sql.tree.CharLiteral; import io.trino.sql.tree.CoalesceExpression; import io.trino.sql.tree.ColumnDefinition; import io.trino.sql.tree.Comment; +import io.trino.sql.tree.CommentCharacteristic; import io.trino.sql.tree.Commit; import io.trino.sql.tree.ComparisonExpression; +import io.trino.sql.tree.CompoundStatement; +import io.trino.sql.tree.ControlStatement; import io.trino.sql.tree.CreateCatalog; +import io.trino.sql.tree.CreateFunction; import io.trino.sql.tree.CreateMaterializedView; import io.trino.sql.tree.CreateRole; import io.trino.sql.tree.CreateSchema; @@ -69,14 +76,18 @@ import io.trino.sql.tree.DescribeOutput; import io.trino.sql.tree.Descriptor; import io.trino.sql.tree.DescriptorField; +import io.trino.sql.tree.DeterministicCharacteristic; import io.trino.sql.tree.DoubleLiteral; import io.trino.sql.tree.DropCatalog; import io.trino.sql.tree.DropColumn; +import io.trino.sql.tree.DropFunction; import io.trino.sql.tree.DropMaterializedView; import io.trino.sql.tree.DropRole; import io.trino.sql.tree.DropSchema; import io.trino.sql.tree.DropTable; import io.trino.sql.tree.DropView; +import io.trino.sql.tree.ElseClause; +import io.trino.sql.tree.ElseIfClause; import io.trino.sql.tree.EmptyPattern; import io.trino.sql.tree.EmptyTableTreatment; import io.trino.sql.tree.EmptyTableTreatment.Treatment; @@ -97,6 +108,7 @@ import io.trino.sql.tree.FrameBound; import io.trino.sql.tree.FunctionCall; import io.trino.sql.tree.FunctionCall.NullTreatment; +import io.trino.sql.tree.FunctionSpecification; import io.trino.sql.tree.GenericDataType; import io.trino.sql.tree.GenericLiteral; import io.trino.sql.tree.Grant; @@ -109,6 +121,7 @@ import io.trino.sql.tree.GroupingSets; import io.trino.sql.tree.Identifier; import io.trino.sql.tree.IfExpression; +import io.trino.sql.tree.IfStatement; import io.trino.sql.tree.InListExpression; import io.trino.sql.tree.InPredicate; import io.trino.sql.tree.Insert; @@ -118,6 +131,7 @@ import io.trino.sql.tree.IsNotNullPredicate; import io.trino.sql.tree.IsNullPredicate; import io.trino.sql.tree.Isolation; +import io.trino.sql.tree.IterateStatement; import io.trino.sql.tree.Join; import io.trino.sql.tree.JoinCriteria; import io.trino.sql.tree.JoinOn; @@ -141,12 +155,15 @@ import io.trino.sql.tree.JsonValue; import io.trino.sql.tree.LambdaArgumentDeclaration; import io.trino.sql.tree.LambdaExpression; +import io.trino.sql.tree.LanguageCharacteristic; import io.trino.sql.tree.Lateral; +import io.trino.sql.tree.LeaveStatement; import io.trino.sql.tree.LikeClause; import io.trino.sql.tree.LikePredicate; import io.trino.sql.tree.Limit; import io.trino.sql.tree.LogicalExpression; import io.trino.sql.tree.LongLiteral; +import io.trino.sql.tree.LoopStatement; import io.trino.sql.tree.MeasureDefinition; import io.trino.sql.tree.Merge; import io.trino.sql.tree.MergeCase; @@ -159,6 +176,7 @@ import io.trino.sql.tree.NodeLocation; import io.trino.sql.tree.NotExpression; import io.trino.sql.tree.NullIfExpression; +import io.trino.sql.tree.NullInputCharacteristic; import io.trino.sql.tree.NullLiteral; import io.trino.sql.tree.NumericParameter; import io.trino.sql.tree.Offset; @@ -166,6 +184,7 @@ import io.trino.sql.tree.OrderBy; import io.trino.sql.tree.OrdinalityColumn; import io.trino.sql.tree.Parameter; +import io.trino.sql.tree.ParameterDeclaration; import io.trino.sql.tree.PathElement; import io.trino.sql.tree.PathSpecification; import io.trino.sql.tree.PatternAlternation; @@ -199,17 +218,22 @@ import io.trino.sql.tree.RenameSchema; import io.trino.sql.tree.RenameTable; import io.trino.sql.tree.RenameView; +import io.trino.sql.tree.RepeatStatement; import io.trino.sql.tree.ResetSession; import io.trino.sql.tree.ResetSessionAuthorization; +import io.trino.sql.tree.ReturnStatement; +import io.trino.sql.tree.ReturnsClause; import io.trino.sql.tree.Revoke; import io.trino.sql.tree.RevokeRoles; import io.trino.sql.tree.Rollback; +import io.trino.sql.tree.RoutineCharacteristic; import io.trino.sql.tree.Row; import io.trino.sql.tree.RowDataType; import io.trino.sql.tree.RowPattern; import io.trino.sql.tree.SampledRelation; import io.trino.sql.tree.SaveMode; import io.trino.sql.tree.SearchedCaseExpression; +import io.trino.sql.tree.SecurityCharacteristic; import io.trino.sql.tree.Select; import io.trino.sql.tree.SelectItem; import io.trino.sql.tree.SetColumnType; @@ -266,8 +290,10 @@ import io.trino.sql.tree.Use; import io.trino.sql.tree.ValueColumn; import io.trino.sql.tree.Values; +import io.trino.sql.tree.VariableDeclaration; import io.trino.sql.tree.VariableDefinition; import io.trino.sql.tree.WhenClause; +import io.trino.sql.tree.WhileStatement; import io.trino.sql.tree.Window; import io.trino.sql.tree.WindowDefinition; import io.trino.sql.tree.WindowFrame; @@ -380,6 +406,12 @@ public Node visitStandaloneRowPattern(SqlBaseParser.StandaloneRowPatternContext return visit(context.rowPattern()); } + @Override + public Node visitStandaloneFunctionSpecification(SqlBaseParser.StandaloneFunctionSpecificationContext context) + { + return visit(context.functionSpecification()); + } + // ******************* statements ********************** @Override @@ -520,7 +552,7 @@ public Node visitCreateTableAsSelect(SqlBaseParser.CreateTableAsSelectContext co return new CreateTableAsSelect( getLocation(context), getQualifiedName(context.qualifiedName()), - (Query) visit(context.query()), + (Query) visit(context.rootQuery()), toSaveMode(context.REPLACE(), context.EXISTS()), properties, context.NO() == null, @@ -572,7 +604,7 @@ public Node visitCreateMaterializedView(SqlBaseParser.CreateMaterializedViewCont return new CreateMaterializedView( Optional.of(getLocation(context)), getQualifiedName(context.qualifiedName()), - (Query) visit(context.query()), + (Query) visit(context.rootQuery()), context.REPLACE() != null, context.EXISTS() != null, gracePeriod, @@ -624,7 +656,7 @@ public Node visitInsertInto(SqlBaseParser.InsertIntoContext context) return new Insert( new Table(getQualifiedName(context.qualifiedName())), columnAliases, - (Query) visit(context.query())); + (Query) visit(context.rootQuery())); } @Override @@ -856,7 +888,7 @@ else if (context.INVOKER() != null) { return new CreateView( getLocation(context), getQualifiedName(context.qualifiedName()), - (Query) visit(context.query()), + (Query) visit(context.rootQuery()), context.REPLACE() != null, comment, security); @@ -893,6 +925,25 @@ public Node visitSetMaterializedViewProperties(SqlBaseParser.SetMaterializedView visit(context.propertyAssignments().property(), Property.class)); } + @Override + public Node visitCreateFunction(SqlBaseParser.CreateFunctionContext context) + { + return new CreateFunction( + getLocation(context), + (FunctionSpecification) visit(context.functionSpecification()), + context.REPLACE() != null); + } + + @Override + public Node visitDropFunction(SqlBaseParser.DropFunctionContext context) + { + return new DropFunction( + getLocation(context), + getQualifiedName(context.functionDeclaration().qualifiedName()), + visit(context.functionDeclaration().parameterDeclaration(), ParameterDeclaration.class), + context.EXISTS() != null); + } + @Override public Node visitStartTransaction(SqlBaseParser.StartTransactionContext context) { @@ -1022,6 +1073,24 @@ public Node visitProperty(SqlBaseParser.PropertyContext context) // ********************** query expressions ******************** + @Override + public Node visitRootQuery(SqlBaseParser.RootQueryContext context) + { + Query query = (Query) visit(context.query()); + + return new Query( + getLocation(context), + Optional.ofNullable(context.withFunction()) + .map(SqlBaseParser.WithFunctionContext::functionSpecification) + .map(contexts -> visit(contexts, FunctionSpecification.class)) + .orElseGet(ImmutableList::of), + query.getWith(), + query.getQueryBody(), + query.getOrderBy(), + query.getOffset(), + query.getLimit()); + } + @Override public Node visitQuery(SqlBaseParser.QueryContext context) { @@ -1029,6 +1098,7 @@ public Node visitQuery(SqlBaseParser.QueryContext context) return new Query( getLocation(context), + ImmutableList.of(), visitIfPresent(context.with(), With.class), body.getQueryBody(), body.getOrderBy(), @@ -1124,6 +1194,7 @@ else if (context.limit.rowCount().INTEGER_VALUE() != null) { return new Query( getLocation(context), + ImmutableList.of(), Optional.empty(), new QuerySpecification( getLocation(context), @@ -1143,6 +1214,7 @@ else if (context.limit.rowCount().INTEGER_VALUE() != null) { return new Query( getLocation(context), + ImmutableList.of(), Optional.empty(), term, orderBy, @@ -1391,7 +1463,7 @@ public Node visitShowStats(SqlBaseParser.ShowStatsContext context) @Override public Node visitShowStatsForQuery(SqlBaseParser.ShowStatsForQueryContext context) { - Query query = (Query) visit(context.query()); + Query query = (Query) visit(context.rootQuery()); return new ShowStats(Optional.of(getLocation(context)), new TableSubquery(query)); } @@ -3647,6 +3719,215 @@ public Node visitJsonTableDefaultPlan(SqlBaseParser.JsonTableDefaultPlanContext return new JsonTableDefaultPlan(getLocation(context), parentChildPlanType, siblingsPlanType); } + // ***************** functions & stored procedures ***************** + + @Override + public Node visitFunctionSpecification(SqlBaseParser.FunctionSpecificationContext context) + { + ControlStatement statement = (ControlStatement) visit(context.controlStatement()); + if (!(statement instanceof ReturnStatement || statement instanceof CompoundStatement)) { + throw parseError("Function body must start with RETURN or BEGIN", context.controlStatement()); + } + return new FunctionSpecification( + getLocation(context), + getQualifiedName(context.functionDeclaration().qualifiedName()), + visit(context.functionDeclaration().parameterDeclaration(), ParameterDeclaration.class), + (ReturnsClause) visit(context.returnsClause()), + visit(context.routineCharacteristic(), RoutineCharacteristic.class), + statement); + } + + @Override + public Node visitParameterDeclaration(SqlBaseParser.ParameterDeclarationContext context) + { + return new ParameterDeclaration( + getLocation(context), + getIdentifierIfPresent(context.identifier()), + (DataType) visit(context.type())); + } + + @Override + public Node visitReturnsClause(SqlBaseParser.ReturnsClauseContext context) + { + return new ReturnsClause(getLocation(context), (DataType) visit(context.type())); + } + + @Override + public Node visitLanguageCharacteristic(SqlBaseParser.LanguageCharacteristicContext context) + { + return new LanguageCharacteristic(getLocation(context), (Identifier) visit(context.identifier())); + } + + @Override + public Node visitDeterministicCharacteristic(SqlBaseParser.DeterministicCharacteristicContext context) + { + return new DeterministicCharacteristic(getLocation(context), context.NOT() == null); + } + + @Override + public Node visitReturnsNullOnNullInputCharacteristic(SqlBaseParser.ReturnsNullOnNullInputCharacteristicContext context) + { + return NullInputCharacteristic.returnsNullOnNullInput(getLocation(context)); + } + + @Override + public Node visitCalledOnNullInputCharacteristic(SqlBaseParser.CalledOnNullInputCharacteristicContext context) + { + return NullInputCharacteristic.calledOnNullInput(getLocation(context)); + } + + @Override + public Node visitSecurityCharacteristic(SqlBaseParser.SecurityCharacteristicContext context) + { + return new SecurityCharacteristic(getLocation(context), (context.INVOKER() != null) + ? SecurityCharacteristic.Security.INVOKER + : SecurityCharacteristic.Security.DEFINER); + } + + @Override + public Node visitCommentCharacteristic(SqlBaseParser.CommentCharacteristicContext context) + { + return new CommentCharacteristic(getLocation(context), ((StringLiteral) visit(context.string())).getValue()); + } + + @Override + public Node visitReturnStatement(SqlBaseParser.ReturnStatementContext context) + { + return new ReturnStatement(getLocation(context), (Expression) visit(context.valueExpression())); + } + + @Override + public Node visitAssignmentStatement(SqlBaseParser.AssignmentStatementContext context) + { + return new AssignmentStatement( + getLocation(context), + (Identifier) visit(context.identifier()), + (Expression) visit(context.expression())); + } + + @Override + public Node visitSimpleCaseStatement(SqlBaseParser.SimpleCaseStatementContext context) + { + return new CaseStatement( + getLocation(context), + visitIfPresent(context.expression(), Expression.class), + visit(context.caseStatementWhenClause(), CaseStatementWhenClause.class), + visitIfPresent(context.elseClause(), ElseClause.class)); + } + + @Override + public Node visitSearchedCaseStatement(SqlBaseParser.SearchedCaseStatementContext context) + { + return new CaseStatement( + getLocation(context), + Optional.empty(), + visit(context.caseStatementWhenClause(), CaseStatementWhenClause.class), + visitIfPresent(context.elseClause(), ElseClause.class)); + } + + @Override + public Node visitCaseStatementWhenClause(SqlBaseParser.CaseStatementWhenClauseContext context) + { + return new CaseStatementWhenClause( + getLocation(context), + (Expression) visit(context.expression()), + visit(context.sqlStatementList().controlStatement(), ControlStatement.class)); + } + + @Override + public Node visitIfStatement(SqlBaseParser.IfStatementContext context) + { + return new IfStatement( + getLocation(context), + (Expression) visit(context.expression()), + visit(context.sqlStatementList().controlStatement(), ControlStatement.class), + visit(context.elseIfClause(), ElseIfClause.class), + visitIfPresent(context.elseClause(), ElseClause.class)); + } + + @Override + public Node visitElseIfClause(SqlBaseParser.ElseIfClauseContext context) + { + return new ElseIfClause( + getLocation(context), + (Expression) visit(context.expression()), + visit(context.sqlStatementList().controlStatement(), ControlStatement.class)); + } + + @Override + public Node visitElseClause(SqlBaseParser.ElseClauseContext context) + { + return new ElseClause( + getLocation(context), + visit(context.sqlStatementList().controlStatement(), ControlStatement.class)); + } + + @Override + public Node visitIterateStatement(SqlBaseParser.IterateStatementContext context) + { + return new IterateStatement( + getLocation(context), + (Identifier) visit(context.identifier())); + } + + @Override + public Node visitLeaveStatement(SqlBaseParser.LeaveStatementContext context) + { + return new LeaveStatement( + getLocation(context), + (Identifier) visit(context.identifier())); + } + + @Override + public Node visitVariableDeclaration(SqlBaseParser.VariableDeclarationContext context) + { + return new VariableDeclaration( + getLocation(context), + visit(context.identifier(), Identifier.class), + (DataType) visit(context.type()), + visitIfPresent(context.valueExpression(), Expression.class)); + } + + @Override + public Node visitCompoundStatement(SqlBaseParser.CompoundStatementContext context) + { + return new CompoundStatement( + getLocation(context), + visit(context.variableDeclaration(), VariableDeclaration.class), + visit(Optional.ofNullable(context.sqlStatementList()) + .map(SqlBaseParser.SqlStatementListContext::controlStatement) + .orElse(ImmutableList.of()), ControlStatement.class)); + } + + @Override + public Node visitLoopStatement(SqlBaseParser.LoopStatementContext context) + { + return new LoopStatement( + getLocation(context), + getIdentifierIfPresent(context.label), + visit(context.sqlStatementList().controlStatement(), ControlStatement.class)); + } + + @Override + public Node visitWhileStatement(SqlBaseParser.WhileStatementContext context) + { + return new WhileStatement( + getLocation(context), + getIdentifierIfPresent(context.label), + (Expression) visit(context.expression()), + visit(context.sqlStatementList().controlStatement(), ControlStatement.class)); + } + + @Override + public Node visitRepeatStatement(SqlBaseParser.RepeatStatementContext context) + { + return new RepeatStatement( + getLocation(context), + getIdentifierIfPresent(context.label), + visit(context.sqlStatementList().controlStatement(), ControlStatement.class), + (Expression) visit(context.expression())); + } + // ***************** helpers ***************** @Override diff --git a/core/trino-parser/src/main/java/io/trino/sql/parser/SqlParser.java b/core/trino-parser/src/main/java/io/trino/sql/parser/SqlParser.java index cbb534c0ef52..9c4507093d9d 100644 --- a/core/trino-parser/src/main/java/io/trino/sql/parser/SqlParser.java +++ b/core/trino-parser/src/main/java/io/trino/sql/parser/SqlParser.java @@ -18,6 +18,7 @@ import io.trino.grammar.sql.SqlBaseParser; import io.trino.sql.tree.DataType; import io.trino.sql.tree.Expression; +import io.trino.sql.tree.FunctionSpecification; import io.trino.sql.tree.Node; import io.trino.sql.tree.NodeLocation; import io.trino.sql.tree.PathSpecification; @@ -114,6 +115,11 @@ public RowPattern createRowPattern(String pattern) return (RowPattern) invokeParser("row pattern", pattern, SqlBaseParser::standaloneRowPattern); } + public FunctionSpecification createFunctionSpecification(String sql) + { + return (FunctionSpecification) invokeParser("function specification", sql, SqlBaseParser::standaloneFunctionSpecification); + } + private Node invokeParser(String name, String sql, Function parseFunction) { return invokeParser(name, sql, Optional.empty(), parseFunction); diff --git a/core/trino-parser/src/main/java/io/trino/sql/tree/AssignmentStatement.java b/core/trino-parser/src/main/java/io/trino/sql/tree/AssignmentStatement.java new file mode 100644 index 000000000000..6a0d5afd1653 --- /dev/null +++ b/core/trino-parser/src/main/java/io/trino/sql/tree/AssignmentStatement.java @@ -0,0 +1,81 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.tree; + +import com.google.common.collect.ImmutableList; + +import java.util.List; +import java.util.Objects; + +import static com.google.common.base.MoreObjects.toStringHelper; +import static java.util.Objects.requireNonNull; + +public final class AssignmentStatement + extends ControlStatement +{ + private final Identifier target; + private final Expression value; + + public AssignmentStatement(NodeLocation location, Identifier target, Expression value) + { + super(location); + this.target = requireNonNull(target, "target is null"); + this.value = requireNonNull(value, "value is null"); + } + + public Identifier getTarget() + { + return target; + } + + public Expression getValue() + { + return value; + } + + @Override + public R accept(AstVisitor visitor, C context) + { + return visitor.visitAssignmentStatement(this, context); + } + + @Override + public List getChildren() + { + return ImmutableList.of(value, target); + } + + @Override + public boolean equals(Object obj) + { + return (obj instanceof AssignmentStatement other) && + Objects.equals(target, other.target) && + Objects.equals(value, other.value); + } + + @Override + public int hashCode() + { + return Objects.hash(target, value); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("target", target) + .add("value", value) + .toString(); + } +} diff --git a/core/trino-parser/src/main/java/io/trino/sql/tree/AstVisitor.java b/core/trino-parser/src/main/java/io/trino/sql/tree/AstVisitor.java index 2f975f54ee16..272d6633dadd 100644 --- a/core/trino-parser/src/main/java/io/trino/sql/tree/AstVisitor.java +++ b/core/trino-parser/src/main/java/io/trino/sql/tree/AstVisitor.java @@ -1226,4 +1226,124 @@ protected R visitJsonTableDefaultPlan(JsonTableDefaultPlan node, C context) { return visitNode(node, context); } + + protected R visitCreateFunction(CreateFunction node, C context) + { + return visitStatement(node, context); + } + + protected R visitDropFunction(DropFunction node, C context) + { + return visitStatement(node, context); + } + + protected R visitFunctionSpecification(FunctionSpecification node, C context) + { + return visitNode(node, context); + } + + protected R visitParameterDeclaration(ParameterDeclaration node, C context) + { + return visitNode(node, context); + } + + protected R visitReturnClause(ReturnsClause node, C context) + { + return visitNode(node, context); + } + + protected R visitLanguageCharacteristic(LanguageCharacteristic node, C context) + { + return visitNode(node, context); + } + + protected R visitDeterministicCharacteristic(DeterministicCharacteristic node, C context) + { + return visitNode(node, context); + } + + protected R visitNullInputCharacteristic(NullInputCharacteristic node, C context) + { + return visitNode(node, context); + } + + protected R visitSecurityCharacteristic(SecurityCharacteristic node, C context) + { + return visitNode(node, context); + } + + protected R visitCommentCharacteristic(CommentCharacteristic node, C context) + { + return visitNode(node, context); + } + + protected R visitReturnStatement(ReturnStatement node, C context) + { + return visitNode(node, context); + } + + protected R visitCompoundStatement(CompoundStatement node, C context) + { + return visitNode(node, context); + } + + protected R visitVariableDeclaration(VariableDeclaration node, C context) + { + return visitNode(node, context); + } + + protected R visitAssignmentStatement(AssignmentStatement node, C context) + { + return visitNode(node, context); + } + + protected R visitCaseStatement(CaseStatement node, C context) + { + return visitNode(node, context); + } + + protected R visitCaseStatementWhenClause(CaseStatementWhenClause node, C context) + { + return visitNode(node, context); + } + + protected R visitIfStatement(IfStatement node, C context) + { + return visitNode(node, context); + } + + protected R visitElseClause(ElseClause node, C context) + { + return visitNode(node, context); + } + + protected R visitElseIfClause(ElseIfClause node, C context) + { + return visitNode(node, context); + } + + protected R visitIterateStatement(IterateStatement node, C context) + { + return visitNode(node, context); + } + + protected R visitLeaveStatement(LeaveStatement node, C context) + { + return visitNode(node, context); + } + + protected R visitWhileStatement(WhileStatement node, C context) + { + return visitNode(node, context); + } + + protected R visitLoopStatement(LoopStatement node, C context) + { + return visitNode(node, context); + } + + protected R visitRepeatStatement(RepeatStatement node, C context) + { + return visitNode(node, context); + } } diff --git a/core/trino-parser/src/main/java/io/trino/sql/tree/CaseStatement.java b/core/trino-parser/src/main/java/io/trino/sql/tree/CaseStatement.java new file mode 100644 index 000000000000..69a7a02fe271 --- /dev/null +++ b/core/trino-parser/src/main/java/io/trino/sql/tree/CaseStatement.java @@ -0,0 +1,99 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.tree; + +import com.google.common.collect.ImmutableList; + +import java.util.List; +import java.util.Objects; +import java.util.Optional; + +import static com.google.common.base.MoreObjects.toStringHelper; +import static java.util.Objects.requireNonNull; + +public final class CaseStatement + extends ControlStatement +{ + private final Optional expression; + private final List whenClauses; + private final Optional elseClause; + + public CaseStatement( + NodeLocation location, + Optional expression, + List whenClauses, + Optional elseClause) + { + super(location); + this.expression = requireNonNull(expression, "expression is null"); + this.whenClauses = requireNonNull(whenClauses, "whenClauses is null"); + this.elseClause = requireNonNull(elseClause, "elseClause is null"); + } + + public Optional getExpression() + { + return expression; + } + + public List getWhenClauses() + { + return whenClauses; + } + + public Optional getElseClause() + { + return elseClause; + } + + @Override + public R accept(AstVisitor visitor, C context) + { + return visitor.visitCaseStatement(this, context); + } + + @Override + public List getChildren() + { + ImmutableList.Builder children = ImmutableList.builder(); + expression.ifPresent(children::add); + children.addAll(whenClauses); + elseClause.ifPresent(children::add); + return children.build(); + } + + @Override + public boolean equals(Object obj) + { + return (obj instanceof CaseStatement other) && + Objects.equals(expression, other.expression) && + Objects.equals(whenClauses, other.whenClauses) && + Objects.equals(elseClause, other.elseClause); + } + + @Override + public int hashCode() + { + return Objects.hash(expression, whenClauses, elseClause); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("expression", expression) + .add("whenClauses", whenClauses) + .add("elseClause", elseClause) + .toString(); + } +} diff --git a/core/trino-parser/src/main/java/io/trino/sql/tree/CaseStatementWhenClause.java b/core/trino-parser/src/main/java/io/trino/sql/tree/CaseStatementWhenClause.java new file mode 100644 index 000000000000..29b9fd5d6762 --- /dev/null +++ b/core/trino-parser/src/main/java/io/trino/sql/tree/CaseStatementWhenClause.java @@ -0,0 +1,95 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.tree; + +import com.google.common.collect.ImmutableList; + +import java.util.List; +import java.util.Objects; +import java.util.Optional; + +import static com.google.common.base.MoreObjects.toStringHelper; +import static java.util.Objects.requireNonNull; + +public final class CaseStatementWhenClause + extends Node +{ + private final Expression expression; + private final List statements; + + public CaseStatementWhenClause(Expression expression, List statements) + { + this(Optional.empty(), expression, statements); + } + + public CaseStatementWhenClause(NodeLocation location, Expression expression, List statements) + { + this(Optional.of(location), expression, statements); + } + + private CaseStatementWhenClause(Optional location, Expression expression, List statements) + { + super(location); + this.expression = requireNonNull(expression, "expression is null"); + this.statements = requireNonNull(statements, "statements is null"); + } + + public Expression getExpression() + { + return expression; + } + + public List getStatements() + { + return statements; + } + + @Override + public R accept(AstVisitor visitor, C context) + { + return visitor.visitCaseStatementWhenClause(this, context); + } + + @Override + public List getChildren() + { + return ImmutableList.builder() + .add(expression) + .addAll(statements) + .build(); + } + + @Override + public boolean equals(Object obj) + { + return (obj instanceof CaseStatementWhenClause other) && + Objects.equals(expression, other.expression) && + Objects.equals(statements, other.statements); + } + + @Override + public int hashCode() + { + return Objects.hash(expression, statements); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("expression", expression) + .add("statements", statements) + .toString(); + } +} diff --git a/core/trino-parser/src/main/java/io/trino/sql/tree/CommentCharacteristic.java b/core/trino-parser/src/main/java/io/trino/sql/tree/CommentCharacteristic.java new file mode 100644 index 000000000000..df4cf9627a61 --- /dev/null +++ b/core/trino-parser/src/main/java/io/trino/sql/tree/CommentCharacteristic.java @@ -0,0 +1,82 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.tree; + +import com.google.common.collect.ImmutableList; + +import java.util.List; +import java.util.Optional; + +import static com.google.common.base.MoreObjects.toStringHelper; +import static java.util.Objects.requireNonNull; + +public final class CommentCharacteristic + extends RoutineCharacteristic +{ + private final String comment; + + public CommentCharacteristic(String comment) + { + this(Optional.empty(), comment); + } + + public CommentCharacteristic(NodeLocation location, String comment) + { + this(Optional.of(location), comment); + } + + private CommentCharacteristic(Optional location, String comment) + { + super(location); + this.comment = requireNonNull(comment, "comment is null"); + } + + public String getComment() + { + return comment; + } + + @Override + protected R accept(AstVisitor visitor, C context) + { + return visitor.visitCommentCharacteristic(this, context); + } + + @Override + public List getChildren() + { + return ImmutableList.of(); + } + + @Override + public boolean equals(Object obj) + { + return (obj instanceof CommentCharacteristic other) && + comment.equals(other.comment); + } + + @Override + public int hashCode() + { + return comment.hashCode(); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("comment", comment) + .toString(); + } +} diff --git a/core/trino-parser/src/main/java/io/trino/sql/tree/CompoundStatement.java b/core/trino-parser/src/main/java/io/trino/sql/tree/CompoundStatement.java new file mode 100644 index 000000000000..ee4ec9b555b7 --- /dev/null +++ b/core/trino-parser/src/main/java/io/trino/sql/tree/CompoundStatement.java @@ -0,0 +1,88 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.tree; + +import com.google.common.collect.ImmutableList; + +import java.util.List; +import java.util.Objects; + +import static com.google.common.base.MoreObjects.toStringHelper; +import static java.util.Objects.requireNonNull; + +public final class CompoundStatement + extends ControlStatement +{ + private final List variableDeclarations; + private final List statements; + + public CompoundStatement( + NodeLocation location, + List variableDeclarations, + List statements) + { + super(location); + this.variableDeclarations = requireNonNull(variableDeclarations, "variableDeclarations is null"); + this.statements = requireNonNull(statements, "statements is null"); + } + + public List getStatements() + { + return statements; + } + + public List getVariableDeclarations() + { + return variableDeclarations; + } + + @Override + public R accept(AstVisitor visitor, C context) + { + return visitor.visitCompoundStatement(this, context); + } + + @Override + public List getChildren() + { + return ImmutableList.builder() + .addAll(statements) + .addAll(variableDeclarations) + .build(); + } + + @Override + public boolean equals(Object obj) + { + return (obj instanceof CompoundStatement other) && + Objects.equals(variableDeclarations, other.variableDeclarations) && + Objects.equals(statements, other.statements); + } + + @Override + public int hashCode() + { + return Objects.hash(variableDeclarations, statements); + } + + @Override + + public String toString() + { + return toStringHelper(this) + .add("variableDeclarations", variableDeclarations) + .add("statements", statements) + .toString(); + } +} diff --git a/core/trino-parser/src/main/java/io/trino/sql/tree/ControlStatement.java b/core/trino-parser/src/main/java/io/trino/sql/tree/ControlStatement.java new file mode 100644 index 000000000000..ae69ecb5d057 --- /dev/null +++ b/core/trino-parser/src/main/java/io/trino/sql/tree/ControlStatement.java @@ -0,0 +1,28 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.tree; + +import java.util.Optional; + +public abstract sealed class ControlStatement + extends Node + permits AssignmentStatement, CaseStatement, CompoundStatement, + IfStatement, IterateStatement, LeaveStatement, LoopStatement, + RepeatStatement, ReturnStatement, VariableDeclaration, WhileStatement +{ + protected ControlStatement(NodeLocation location) + { + super(Optional.of(location)); + } +} diff --git a/core/trino-parser/src/main/java/io/trino/sql/tree/CreateFunction.java b/core/trino-parser/src/main/java/io/trino/sql/tree/CreateFunction.java new file mode 100644 index 000000000000..3018cdea951b --- /dev/null +++ b/core/trino-parser/src/main/java/io/trino/sql/tree/CreateFunction.java @@ -0,0 +1,92 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.tree; + +import com.google.common.collect.ImmutableList; + +import java.util.List; +import java.util.Objects; +import java.util.Optional; + +import static com.google.common.base.MoreObjects.toStringHelper; +import static java.util.Objects.requireNonNull; + +public final class CreateFunction + extends Statement +{ + private final FunctionSpecification specification; + private final boolean replace; + + public CreateFunction(FunctionSpecification specification, boolean replace) + { + this(Optional.empty(), specification, replace); + } + + public CreateFunction(NodeLocation location, FunctionSpecification specification, boolean replace) + { + this(Optional.of(location), specification, replace); + } + + private CreateFunction(Optional location, FunctionSpecification specification, boolean replace) + { + super(location); + this.specification = requireNonNull(specification, "specification is null"); + this.replace = replace; + } + + public FunctionSpecification getSpecification() + { + return specification; + } + + public boolean isReplace() + { + return replace; + } + + @Override + public R accept(AstVisitor visitor, C context) + { + return visitor.visitCreateFunction(this, context); + } + + @Override + public List getChildren() + { + return ImmutableList.of(specification); + } + + @Override + public boolean equals(Object obj) + { + return (obj instanceof CreateFunction other) && + Objects.equals(specification, other.specification) && + Objects.equals(replace, other.replace); + } + + @Override + public int hashCode() + { + return Objects.hash(specification, replace); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("specification", specification) + .add("replace", replace) + .toString(); + } +} diff --git a/core/trino-parser/src/main/java/io/trino/sql/tree/DeterministicCharacteristic.java b/core/trino-parser/src/main/java/io/trino/sql/tree/DeterministicCharacteristic.java new file mode 100644 index 000000000000..790b0c892adc --- /dev/null +++ b/core/trino-parser/src/main/java/io/trino/sql/tree/DeterministicCharacteristic.java @@ -0,0 +1,82 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.tree; + +import com.google.common.collect.ImmutableList; + +import java.util.List; +import java.util.Objects; +import java.util.Optional; + +import static com.google.common.base.MoreObjects.toStringHelper; + +public final class DeterministicCharacteristic + extends RoutineCharacteristic +{ + private final boolean deterministic; + + public DeterministicCharacteristic(boolean deterministic) + { + this(Optional.empty(), deterministic); + } + + public DeterministicCharacteristic(NodeLocation location, boolean deterministic) + { + this(Optional.of(location), deterministic); + } + + private DeterministicCharacteristic(Optional location, boolean deterministic) + { + super(location); + this.deterministic = deterministic; + } + + public boolean isDeterministic() + { + return deterministic; + } + + @Override + public R accept(AstVisitor visitor, C context) + { + return visitor.visitDeterministicCharacteristic(this, context); + } + + @Override + public List getChildren() + { + return ImmutableList.of(); + } + + @Override + public boolean equals(Object obj) + { + return (obj instanceof DeterministicCharacteristic other) && + (deterministic == other.deterministic); + } + + @Override + public int hashCode() + { + return Objects.hash(deterministic); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("deterministic", deterministic) + .toString(); + } +} diff --git a/core/trino-parser/src/main/java/io/trino/sql/tree/DropFunction.java b/core/trino-parser/src/main/java/io/trino/sql/tree/DropFunction.java new file mode 100644 index 000000000000..5587f6e68ad2 --- /dev/null +++ b/core/trino-parser/src/main/java/io/trino/sql/tree/DropFunction.java @@ -0,0 +1,105 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.tree; + +import com.google.common.collect.ImmutableList; + +import java.util.List; +import java.util.Objects; +import java.util.Optional; + +import static com.google.common.base.MoreObjects.toStringHelper; +import static java.util.Objects.requireNonNull; + +public class DropFunction + extends Statement +{ + private final QualifiedName name; + private final List parameters; + private final boolean exists; + + public DropFunction(QualifiedName name, List parameters, boolean exists) + { + this(Optional.empty(), name, parameters, exists); + } + + public DropFunction(NodeLocation location, QualifiedName name, List parameters, boolean exists) + { + this(Optional.of(location), name, parameters, exists); + } + + private DropFunction(Optional location, QualifiedName name, List parameters, boolean exists) + { + super(location); + this.name = requireNonNull(name, "name is null"); + this.parameters = ImmutableList.copyOf(requireNonNull(parameters, "parameters is null")); + this.exists = exists; + } + + public QualifiedName getName() + { + return name; + } + + public List getParameters() + { + return parameters; + } + + public boolean isExists() + { + return exists; + } + + @Override + public R accept(AstVisitor visitor, C context) + { + return visitor.visitDropFunction(this, context); + } + + @Override + public List getChildren() + { + return ImmutableList.of(); + } + + @Override + public int hashCode() + { + return Objects.hash(name, exists); + } + + @Override + public boolean equals(Object obj) + { + if (this == obj) { + return true; + } + if ((obj == null) || (getClass() != obj.getClass())) { + return false; + } + DropFunction o = (DropFunction) obj; + return Objects.equals(name, o.name) + && (exists == o.exists); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("name", name) + .add("exists", exists) + .toString(); + } +} diff --git a/core/trino-parser/src/main/java/io/trino/sql/tree/ElseClause.java b/core/trino-parser/src/main/java/io/trino/sql/tree/ElseClause.java new file mode 100644 index 000000000000..68e1f8b656e2 --- /dev/null +++ b/core/trino-parser/src/main/java/io/trino/sql/tree/ElseClause.java @@ -0,0 +1,81 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.tree; + +import java.util.List; +import java.util.Objects; +import java.util.Optional; + +import static com.google.common.base.MoreObjects.toStringHelper; +import static java.util.Objects.requireNonNull; + +public final class ElseClause + extends Node +{ + private final List statements; + + public ElseClause(List statements) + { + this(Optional.empty(), statements); + } + + public ElseClause(NodeLocation location, List statements) + { + this(Optional.of(location), statements); + } + + private ElseClause(Optional location, List statements) + { + super(location); + this.statements = requireNonNull(statements, "statements is null"); + } + + public List getStatements() + { + return statements; + } + + @Override + public R accept(AstVisitor visitor, C context) + { + return visitor.visitElseClause(this, context); + } + + @Override + public List getChildren() + { + return statements; + } + + @Override + public boolean equals(Object obj) + { + return (obj instanceof ElseClause other) && + Objects.equals(statements, other.statements); + } + + @Override + public int hashCode() + { + return Objects.hash(statements); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("statements", statements) + .toString(); + } +} diff --git a/core/trino-parser/src/main/java/io/trino/sql/tree/ElseIfClause.java b/core/trino-parser/src/main/java/io/trino/sql/tree/ElseIfClause.java new file mode 100644 index 000000000000..853742f63e6c --- /dev/null +++ b/core/trino-parser/src/main/java/io/trino/sql/tree/ElseIfClause.java @@ -0,0 +1,95 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.tree; + +import com.google.common.collect.ImmutableList; + +import java.util.List; +import java.util.Objects; +import java.util.Optional; + +import static com.google.common.base.MoreObjects.toStringHelper; +import static java.util.Objects.requireNonNull; + +public final class ElseIfClause + extends Node +{ + private final Expression expression; + private final List statements; + + public ElseIfClause(Expression expression, List statements) + { + this(Optional.empty(), expression, statements); + } + + public ElseIfClause(NodeLocation location, Expression expression, List statements) + { + this(Optional.of(location), expression, statements); + } + + private ElseIfClause(Optional location, Expression expression, List statements) + { + super(location); + this.expression = requireNonNull(expression, "expression is null"); + this.statements = requireNonNull(statements, "statements is null"); + } + + public Expression getExpression() + { + return expression; + } + + public List getStatements() + { + return statements; + } + + @Override + public R accept(AstVisitor visitor, C context) + { + return visitor.visitElseIfClause(this, context); + } + + @Override + public List getChildren() + { + return ImmutableList.builder() + .add(expression) + .addAll(statements) + .build(); + } + + @Override + public boolean equals(Object obj) + { + return (obj instanceof ElseIfClause other) && + Objects.equals(expression, other.expression) && + Objects.equals(statements, other.statements); + } + + @Override + public int hashCode() + { + return Objects.hash(expression, statements); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("expression", expression) + .add("statements", statements) + .toString(); + } +} diff --git a/core/trino-parser/src/main/java/io/trino/sql/tree/FunctionSpecification.java b/core/trino-parser/src/main/java/io/trino/sql/tree/FunctionSpecification.java new file mode 100644 index 000000000000..6c6339f3247d --- /dev/null +++ b/core/trino-parser/src/main/java/io/trino/sql/tree/FunctionSpecification.java @@ -0,0 +1,141 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.tree; + +import com.google.common.collect.ImmutableList; + +import java.util.List; +import java.util.Objects; +import java.util.Optional; + +import static com.google.common.base.MoreObjects.toStringHelper; +import static java.util.Objects.requireNonNull; + +public final class FunctionSpecification + extends Node +{ + private final QualifiedName name; + private final List parameters; + private final ReturnsClause returnsClause; + private final List routineCharacteristics; + private final ControlStatement statement; + + public FunctionSpecification( + QualifiedName name, + List parameters, + ReturnsClause returnsClause, + List routineCharacteristics, + ControlStatement statement) + { + this(Optional.empty(), name, parameters, returnsClause, routineCharacteristics, statement); + } + + public FunctionSpecification( + NodeLocation location, + QualifiedName name, + List parameters, + ReturnsClause returnsClause, + List routineCharacteristics, + ControlStatement statement) + { + this(Optional.of(location), name, parameters, returnsClause, routineCharacteristics, statement); + } + + private FunctionSpecification( + Optional location, + QualifiedName name, + List parameters, + ReturnsClause returnsClause, + List routineCharacteristics, + ControlStatement statement) + { + super(location); + this.name = requireNonNull(name, "name is null"); + this.parameters = ImmutableList.copyOf(requireNonNull(parameters, "parameters is null")); + this.returnsClause = requireNonNull(returnsClause, "returnClause is null"); + this.routineCharacteristics = ImmutableList.copyOf(requireNonNull(routineCharacteristics, "routineCharacteristics is null")); + this.statement = requireNonNull(statement, "statement is null"); + } + + public QualifiedName getName() + { + return name; + } + + public List getParameters() + { + return parameters; + } + + public ReturnsClause getReturnsClause() + { + return returnsClause; + } + + public List getRoutineCharacteristics() + { + return routineCharacteristics; + } + + public ControlStatement getStatement() + { + return statement; + } + + @Override + public List getChildren() + { + return ImmutableList.builder() + .addAll(parameters) + .add(returnsClause) + .addAll(routineCharacteristics) + .add(statement) + .build(); + } + + @Override + public R accept(AstVisitor visitor, C context) + { + return visitor.visitFunctionSpecification(this, context); + } + + @Override + public boolean equals(Object obj) + { + return (obj instanceof FunctionSpecification other) && + Objects.equals(name, other.name) && + Objects.equals(parameters, other.parameters) && + Objects.equals(returnsClause, other.returnsClause) && + Objects.equals(routineCharacteristics, other.routineCharacteristics) && + Objects.equals(statement, other.statement); + } + + @Override + public int hashCode() + { + return Objects.hash(name, parameters, returnsClause, routineCharacteristics, statement); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("name", name) + .add("parameters", parameters) + .add("returnsClause", returnsClause) + .add("routineCharacteristics", routineCharacteristics) + .add("statement", statement) + .toString(); + } +} diff --git a/core/trino-parser/src/main/java/io/trino/sql/tree/IfStatement.java b/core/trino-parser/src/main/java/io/trino/sql/tree/IfStatement.java new file mode 100644 index 000000000000..3f03ea3b0afc --- /dev/null +++ b/core/trino-parser/src/main/java/io/trino/sql/tree/IfStatement.java @@ -0,0 +1,110 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.tree; + +import com.google.common.collect.ImmutableList; + +import java.util.List; +import java.util.Objects; +import java.util.Optional; + +import static com.google.common.base.MoreObjects.toStringHelper; +import static java.util.Objects.requireNonNull; + +public final class IfStatement + extends ControlStatement +{ + private final Expression expression; + private final List statements; + private final List elseIfClauses; + private final Optional elseClause; + + public IfStatement( + NodeLocation location, + Expression expression, + List statements, + List elseIfClauses, + Optional elseClause) + { + super(location); + this.expression = requireNonNull(expression, "expression is null"); + this.statements = requireNonNull(statements, "statements is null"); + this.elseIfClauses = requireNonNull(elseIfClauses, "elseIfClauses is null"); + this.elseClause = requireNonNull(elseClause, "elseClause is null"); + } + + public Expression getExpression() + { + return expression; + } + + public List getStatements() + { + return statements; + } + + public List getElseIfClauses() + { + return elseIfClauses; + } + + public Optional getElseClause() + { + return elseClause; + } + + @Override + public List getChildren() + { + ImmutableList.Builder children = ImmutableList.builder() + .add(expression) + .addAll(statements) + .addAll(elseIfClauses); + elseClause.ifPresent(children::add); + return children.build(); + } + + @Override + public R accept(AstVisitor visitor, C context) + { + return visitor.visitIfStatement(this, context); + } + + @Override + public boolean equals(Object obj) + { + return (obj instanceof IfStatement other) && + Objects.equals(expression, other.expression) && + Objects.equals(statements, other.statements) && + Objects.equals(elseIfClauses, other.elseIfClauses) && + Objects.equals(elseClause, other.elseClause); + } + + @Override + public int hashCode() + { + return Objects.hash(expression, statements, elseIfClauses, elseClause); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("expression", expression) + .add("statements", statements) + .add("elseIfClauses", elseIfClauses) + .add("elseClause", elseClause) + .toString(); + } +} diff --git a/core/trino-parser/src/main/java/io/trino/sql/tree/IterateStatement.java b/core/trino-parser/src/main/java/io/trino/sql/tree/IterateStatement.java new file mode 100644 index 000000000000..e49c6052c3b3 --- /dev/null +++ b/core/trino-parser/src/main/java/io/trino/sql/tree/IterateStatement.java @@ -0,0 +1,78 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.tree; + +import com.google.common.collect.ImmutableList; + +import java.util.List; +import java.util.Objects; + +import static com.google.common.base.MoreObjects.toStringHelper; +import static java.util.Objects.requireNonNull; + +public final class IterateStatement + extends ControlStatement +{ + private final Identifier label; + + public IterateStatement(NodeLocation location, Identifier label) + { + super(location); + this.label = requireNonNull(label, "label is null"); + } + + public Identifier getLabel() + { + return label; + } + + @Override + public R accept(AstVisitor visitor, C context) + { + return visitor.visitIterateStatement(this, context); + } + + @Override + public List getChildren() + { + return ImmutableList.of(); + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + IterateStatement that = (IterateStatement) o; + return Objects.equals(label, that.label); + } + + @Override + public int hashCode() + { + return Objects.hash(label); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("label", label) + .toString(); + } +} diff --git a/core/trino-parser/src/main/java/io/trino/sql/tree/LanguageCharacteristic.java b/core/trino-parser/src/main/java/io/trino/sql/tree/LanguageCharacteristic.java new file mode 100644 index 000000000000..b89cc1fde7ec --- /dev/null +++ b/core/trino-parser/src/main/java/io/trino/sql/tree/LanguageCharacteristic.java @@ -0,0 +1,82 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.tree; + +import com.google.common.collect.ImmutableList; + +import java.util.List; +import java.util.Optional; + +import static com.google.common.base.MoreObjects.toStringHelper; +import static java.util.Objects.requireNonNull; + +public final class LanguageCharacteristic + extends RoutineCharacteristic +{ + private final Identifier language; + + public LanguageCharacteristic(Identifier language) + { + this(Optional.empty(), language); + } + + public LanguageCharacteristic(NodeLocation location, Identifier language) + { + this(Optional.of(location), language); + } + + private LanguageCharacteristic(Optional location, Identifier language) + { + super(location); + this.language = requireNonNull(language, "comment is null"); + } + + public Identifier getLanguage() + { + return language; + } + + @Override + protected R accept(AstVisitor visitor, C context) + { + return visitor.visitLanguageCharacteristic(this, context); + } + + @Override + public List getChildren() + { + return ImmutableList.of(); + } + + @Override + public boolean equals(Object obj) + { + return (obj instanceof LanguageCharacteristic other) && + language.equals(other.language); + } + + @Override + public int hashCode() + { + return language.hashCode(); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("language", language) + .toString(); + } +} diff --git a/core/trino-parser/src/main/java/io/trino/sql/tree/LeaveStatement.java b/core/trino-parser/src/main/java/io/trino/sql/tree/LeaveStatement.java new file mode 100644 index 000000000000..041753c5eb0b --- /dev/null +++ b/core/trino-parser/src/main/java/io/trino/sql/tree/LeaveStatement.java @@ -0,0 +1,73 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.tree; + +import com.google.common.collect.ImmutableList; + +import java.util.List; +import java.util.Objects; + +import static com.google.common.base.MoreObjects.toStringHelper; +import static java.util.Objects.requireNonNull; + +public final class LeaveStatement + extends ControlStatement +{ + private final Identifier label; + + public LeaveStatement(NodeLocation location, Identifier label) + { + super(location); + this.label = requireNonNull(label, "label is null"); + } + + public Identifier getLabel() + { + return label; + } + + @Override + public R accept(AstVisitor visitor, C context) + { + return visitor.visitLeaveStatement(this, context); + } + + @Override + public List getChildren() + { + return ImmutableList.of(); + } + + @Override + public boolean equals(Object obj) + { + return (obj instanceof LeaveStatement other) && + Objects.equals(label, other.label); + } + + @Override + public int hashCode() + { + return Objects.hash(label); + } + + @Override + + public String toString() + { + return toStringHelper(this) + .add("label", label) + .toString(); + } +} diff --git a/core/trino-parser/src/main/java/io/trino/sql/tree/LoopStatement.java b/core/trino-parser/src/main/java/io/trino/sql/tree/LoopStatement.java new file mode 100644 index 000000000000..08cdc000c667 --- /dev/null +++ b/core/trino-parser/src/main/java/io/trino/sql/tree/LoopStatement.java @@ -0,0 +1,81 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.tree; + +import java.util.List; +import java.util.Objects; +import java.util.Optional; + +import static com.google.common.base.MoreObjects.toStringHelper; +import static java.util.Objects.requireNonNull; + +public final class LoopStatement + extends ControlStatement +{ + private final Optional label; + private final List statements; + + public LoopStatement(NodeLocation location, Optional label, List statements) + { + super(location); + this.label = requireNonNull(label, "label is null"); + this.statements = requireNonNull(statements, "statements is null"); + } + + public Optional getLabel() + { + return label; + } + + public List getStatements() + { + return statements; + } + + @Override + public R accept(AstVisitor visitor, C context) + { + return visitor.visitLoopStatement(this, context); + } + + @Override + public List getChildren() + { + return statements; + } + + @Override + public boolean equals(Object obj) + { + return (obj instanceof LoopStatement other) && + Objects.equals(label, other.label) && + Objects.equals(statements, other.statements); + } + + @Override + public int hashCode() + { + return Objects.hash(label, statements); + } + + @Override + + public String toString() + { + return toStringHelper(this) + .add("label", label) + .add("statements", statements) + .toString(); + } +} diff --git a/core/trino-parser/src/main/java/io/trino/sql/tree/NullInputCharacteristic.java b/core/trino-parser/src/main/java/io/trino/sql/tree/NullInputCharacteristic.java new file mode 100644 index 000000000000..ab69f12bf2c1 --- /dev/null +++ b/core/trino-parser/src/main/java/io/trino/sql/tree/NullInputCharacteristic.java @@ -0,0 +1,92 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.tree; + +import com.google.common.collect.ImmutableList; + +import java.util.List; +import java.util.Objects; +import java.util.Optional; + +import static com.google.common.base.MoreObjects.toStringHelper; + +public final class NullInputCharacteristic + extends RoutineCharacteristic +{ + public static NullInputCharacteristic returnsNullOnNullInput() + { + return new NullInputCharacteristic(Optional.empty(), false); + } + + public static NullInputCharacteristic returnsNullOnNullInput(NodeLocation location) + { + return new NullInputCharacteristic(Optional.of(location), false); + } + + public static NullInputCharacteristic calledOnNullInput() + { + return new NullInputCharacteristic(Optional.empty(), true); + } + + public static NullInputCharacteristic calledOnNullInput(NodeLocation location) + { + return new NullInputCharacteristic(Optional.of(location), true); + } + + private final boolean calledOnNull; + + private NullInputCharacteristic(Optional location, boolean calledOnNull) + { + super(location); + this.calledOnNull = calledOnNull; + } + + public boolean isCalledOnNull() + { + return calledOnNull; + } + + @Override + public R accept(AstVisitor visitor, C context) + { + return visitor.visitNullInputCharacteristic(this, context); + } + + @Override + public List getChildren() + { + return ImmutableList.of(); + } + + @Override + public boolean equals(Object obj) + { + return (obj instanceof NullInputCharacteristic other) && + (calledOnNull == other.calledOnNull); + } + + @Override + public int hashCode() + { + return Objects.hash(calledOnNull); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("calledOnNull", calledOnNull) + .toString(); + } +} diff --git a/core/trino-parser/src/main/java/io/trino/sql/tree/ParameterDeclaration.java b/core/trino-parser/src/main/java/io/trino/sql/tree/ParameterDeclaration.java new file mode 100644 index 000000000000..afccf350d47c --- /dev/null +++ b/core/trino-parser/src/main/java/io/trino/sql/tree/ParameterDeclaration.java @@ -0,0 +1,92 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.tree; + +import com.google.common.collect.ImmutableList; + +import java.util.List; +import java.util.Objects; +import java.util.Optional; + +import static com.google.common.base.MoreObjects.toStringHelper; +import static java.util.Objects.requireNonNull; + +public final class ParameterDeclaration + extends Node +{ + private final Optional name; + private final DataType type; + + public ParameterDeclaration(Optional name, DataType type) + { + this(Optional.empty(), name, type); + } + + public ParameterDeclaration(NodeLocation location, Optional name, DataType type) + { + this(Optional.of(location), name, type); + } + + private ParameterDeclaration(Optional location, Optional name, DataType type) + { + super(location); + this.name = requireNonNull(name, "name is null"); + this.type = requireNonNull(type, "type is null"); + } + + public Optional getName() + { + return name; + } + + public DataType getType() + { + return type; + } + + @Override + public List getChildren() + { + return ImmutableList.of(); + } + + @Override + public R accept(AstVisitor visitor, C context) + { + return visitor.visitParameterDeclaration(this, context); + } + + @Override + public boolean equals(Object obj) + { + return (obj instanceof ParameterDeclaration other) && + Objects.equals(name, other.name) && + Objects.equals(type, other.type); + } + + @Override + public int hashCode() + { + return Objects.hash(name, type); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("name", name) + .add("type", type) + .toString(); + } +} diff --git a/core/trino-parser/src/main/java/io/trino/sql/tree/Query.java b/core/trino-parser/src/main/java/io/trino/sql/tree/Query.java index 19afd477eb82..31f8da6698bf 100644 --- a/core/trino-parser/src/main/java/io/trino/sql/tree/Query.java +++ b/core/trino-parser/src/main/java/io/trino/sql/tree/Query.java @@ -26,6 +26,7 @@ public class Query extends Statement { + private final List functions; private final Optional with; private final QueryBody queryBody; private final Optional orderBy; @@ -33,28 +34,31 @@ public class Query private final Optional limit; public Query( + List functions, Optional with, QueryBody queryBody, Optional orderBy, Optional offset, Optional limit) { - this(Optional.empty(), with, queryBody, orderBy, offset, limit); + this(Optional.empty(), functions, with, queryBody, orderBy, offset, limit); } public Query( NodeLocation location, + List functions, Optional with, QueryBody queryBody, Optional orderBy, Optional offset, Optional limit) { - this(Optional.of(location), with, queryBody, orderBy, offset, limit); + this(Optional.of(location), functions, with, queryBody, orderBy, offset, limit); } private Query( Optional location, + List functions, Optional with, QueryBody queryBody, Optional orderBy, @@ -62,6 +66,7 @@ private Query( Optional limit) { super(location); + requireNonNull(functions, "function si snull"); requireNonNull(with, "with is null"); requireNonNull(queryBody, "queryBody is null"); requireNonNull(orderBy, "orderBy is null"); @@ -69,6 +74,7 @@ private Query( requireNonNull(limit, "limit is null"); checkArgument(!limit.isPresent() || limit.get() instanceof FetchFirst || limit.get() instanceof Limit, "limit must be optional of either FetchFirst or Limit type"); + this.functions = ImmutableList.copyOf(functions); this.with = with; this.queryBody = queryBody; this.orderBy = orderBy; @@ -76,6 +82,11 @@ private Query( this.limit = limit; } + public List getFunctions() + { + return functions; + } + public Optional getWith() { return with; @@ -111,6 +122,7 @@ public R accept(AstVisitor visitor, C context) public List getChildren() { ImmutableList.Builder nodes = ImmutableList.builder(); + nodes.addAll(functions); with.ifPresent(nodes::add); nodes.add(queryBody); orderBy.ifPresent(nodes::add); @@ -123,6 +135,7 @@ public List getChildren() public String toString() { return toStringHelper(this) + .add("functions", functions.isEmpty() ? null : functions) .add("with", with.orElse(null)) .add("queryBody", queryBody) .add("orderBy", orderBy) @@ -142,7 +155,8 @@ public boolean equals(Object obj) return false; } Query o = (Query) obj; - return Objects.equals(with, o.with) && + return Objects.equals(functions, o.functions) && + Objects.equals(with, o.with) && Objects.equals(queryBody, o.queryBody) && Objects.equals(orderBy, o.orderBy) && Objects.equals(offset, o.offset) && @@ -152,7 +166,7 @@ public boolean equals(Object obj) @Override public int hashCode() { - return Objects.hash(with, queryBody, orderBy, offset, limit); + return Objects.hash(functions, with, queryBody, orderBy, offset, limit); } @Override diff --git a/core/trino-parser/src/main/java/io/trino/sql/tree/RepeatStatement.java b/core/trino-parser/src/main/java/io/trino/sql/tree/RepeatStatement.java new file mode 100644 index 000000000000..6dc0067710a5 --- /dev/null +++ b/core/trino-parser/src/main/java/io/trino/sql/tree/RepeatStatement.java @@ -0,0 +1,94 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.tree; + +import com.google.common.collect.ImmutableList; + +import java.util.List; +import java.util.Objects; +import java.util.Optional; + +import static com.google.common.base.MoreObjects.toStringHelper; +import static java.util.Objects.requireNonNull; + +public final class RepeatStatement + extends ControlStatement +{ + private final Optional label; + private final List statements; + private final Expression condition; + + public RepeatStatement(NodeLocation location, Optional label, List statements, Expression condition) + { + super(location); + this.label = requireNonNull(label, "label is null"); + this.statements = requireNonNull(statements, "statements is null"); + this.condition = requireNonNull(condition, "condition is null"); + } + + public Optional getLabel() + { + return label; + } + + public List getStatements() + { + return statements; + } + + public Expression getCondition() + { + return condition; + } + + @Override + public R accept(AstVisitor visitor, C context) + { + return visitor.visitRepeatStatement(this, context); + } + + @Override + public List getChildren() + { + return ImmutableList.builder() + .addAll(statements) + .add(condition) + .build(); + } + + @Override + public boolean equals(Object obj) + { + return (obj instanceof RepeatStatement other) && + Objects.equals(label, other.label) && + Objects.equals(statements, other.statements) && + Objects.equals(condition, other.condition); + } + + @Override + public int hashCode() + { + return Objects.hash(label, statements, condition); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("label", label) + .add("statements", statements) + .add("condition", condition) + .toString(); + } +} diff --git a/core/trino-parser/src/main/java/io/trino/sql/tree/ReturnStatement.java b/core/trino-parser/src/main/java/io/trino/sql/tree/ReturnStatement.java new file mode 100644 index 000000000000..3b29b8d17c96 --- /dev/null +++ b/core/trino-parser/src/main/java/io/trino/sql/tree/ReturnStatement.java @@ -0,0 +1,72 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.tree; + +import com.google.common.collect.ImmutableList; + +import java.util.List; +import java.util.Objects; + +import static com.google.common.base.MoreObjects.toStringHelper; +import static java.util.Objects.requireNonNull; + +public final class ReturnStatement + extends ControlStatement +{ + private final Expression value; + + public ReturnStatement(NodeLocation location, Expression value) + { + super(location); + this.value = requireNonNull(value, "value is null"); + } + + public Expression getValue() + { + return value; + } + + @Override + public R accept(AstVisitor visitor, C context) + { + return visitor.visitReturnStatement(this, context); + } + + @Override + public List getChildren() + { + return ImmutableList.of(value); + } + + @Override + public boolean equals(Object obj) + { + return (obj instanceof ReturnStatement other) && + Objects.equals(value, other.value); + } + + @Override + public int hashCode() + { + return Objects.hash(value); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("value", value) + .toString(); + } +} diff --git a/core/trino-parser/src/main/java/io/trino/sql/tree/ReturnsClause.java b/core/trino-parser/src/main/java/io/trino/sql/tree/ReturnsClause.java new file mode 100644 index 000000000000..03798ebabd22 --- /dev/null +++ b/core/trino-parser/src/main/java/io/trino/sql/tree/ReturnsClause.java @@ -0,0 +1,72 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.tree; + +import com.google.common.collect.ImmutableList; + +import java.util.List; +import java.util.Optional; + +import static com.google.common.base.MoreObjects.toStringHelper; +import static java.util.Objects.requireNonNull; + +public final class ReturnsClause + extends Node +{ + private final DataType returnType; + + public ReturnsClause(NodeLocation location, DataType returnType) + { + super(Optional.of(location)); + this.returnType = requireNonNull(returnType, "returnType is null"); + } + + public DataType getReturnType() + { + return returnType; + } + + @Override + public R accept(AstVisitor visitor, C context) + { + return visitor.visitReturnClause(this, context); + } + + @Override + public List getChildren() + { + return ImmutableList.of(); + } + + @Override + public boolean equals(Object obj) + { + return (obj instanceof ReturnsClause other) && + returnType.equals(other.returnType); + } + + @Override + public int hashCode() + { + return returnType.hashCode(); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("returnType", returnType) + .toString(); + } +} diff --git a/core/trino-parser/src/main/java/io/trino/sql/tree/RoutineCharacteristic.java b/core/trino-parser/src/main/java/io/trino/sql/tree/RoutineCharacteristic.java new file mode 100644 index 000000000000..771428342d70 --- /dev/null +++ b/core/trino-parser/src/main/java/io/trino/sql/tree/RoutineCharacteristic.java @@ -0,0 +1,26 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.tree; + +import java.util.Optional; + +public abstract sealed class RoutineCharacteristic + extends Node + permits CommentCharacteristic, DeterministicCharacteristic, LanguageCharacteristic, NullInputCharacteristic, SecurityCharacteristic +{ + protected RoutineCharacteristic(Optional location) + { + super(location); + } +} diff --git a/core/trino-parser/src/main/java/io/trino/sql/tree/SecurityCharacteristic.java b/core/trino-parser/src/main/java/io/trino/sql/tree/SecurityCharacteristic.java new file mode 100644 index 000000000000..561d61913112 --- /dev/null +++ b/core/trino-parser/src/main/java/io/trino/sql/tree/SecurityCharacteristic.java @@ -0,0 +1,87 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.tree; + +import com.google.common.collect.ImmutableList; + +import java.util.List; +import java.util.Optional; + +import static com.google.common.base.MoreObjects.toStringHelper; +import static java.util.Objects.requireNonNull; + +public final class SecurityCharacteristic + extends RoutineCharacteristic +{ + public enum Security + { + INVOKER, DEFINER + } + + private final Security security; + + public SecurityCharacteristic(Security security) + { + this(Optional.empty(), security); + } + + public SecurityCharacteristic(NodeLocation location, Security security) + { + this(Optional.of(location), security); + } + + private SecurityCharacteristic(Optional location, Security security) + { + super(location); + this.security = requireNonNull(security, "security is null"); + } + + public Security getSecurity() + { + return security; + } + + @Override + protected R accept(AstVisitor visitor, C context) + { + return visitor.visitSecurityCharacteristic(this, context); + } + + @Override + public List getChildren() + { + return ImmutableList.of(); + } + + @Override + public boolean equals(Object obj) + { + return (obj instanceof SecurityCharacteristic other) && + (security == other.security); + } + + @Override + public int hashCode() + { + return security.hashCode(); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("security", security) + .toString(); + } +} diff --git a/core/trino-parser/src/main/java/io/trino/sql/tree/VariableDeclaration.java b/core/trino-parser/src/main/java/io/trino/sql/tree/VariableDeclaration.java new file mode 100644 index 000000000000..ab2c3dd7031d --- /dev/null +++ b/core/trino-parser/src/main/java/io/trino/sql/tree/VariableDeclaration.java @@ -0,0 +1,91 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.tree; + +import com.google.common.collect.ImmutableList; + +import java.util.List; +import java.util.Objects; +import java.util.Optional; + +import static com.google.common.base.MoreObjects.toStringHelper; +import static java.util.Objects.requireNonNull; + +public final class VariableDeclaration + extends ControlStatement +{ + private final List names; + private final DataType type; + private final Optional defaultValue; + + public VariableDeclaration(NodeLocation location, List names, DataType type, Optional defaultValue) + { + super(location); + this.names = requireNonNull(names, "name is null"); + this.type = requireNonNull(type, "type is null"); + this.defaultValue = requireNonNull(defaultValue, "defaultValue is null"); + } + + public List getNames() + { + return names; + } + + public DataType getType() + { + return type; + } + + public Optional getDefaultValue() + { + return defaultValue; + } + + @Override + public R accept(AstVisitor visitor, C context) + { + return visitor.visitVariableDeclaration(this, context); + } + + @Override + public List getChildren() + { + return defaultValue.map(ImmutableList::of).orElse(ImmutableList.of()); + } + + @Override + public boolean equals(Object obj) + { + return (obj instanceof VariableDeclaration other) && + Objects.equals(names, other.names) && + Objects.equals(type, other.type) && + Objects.equals(defaultValue, other.defaultValue); + } + + @Override + public int hashCode() + { + return Objects.hash(names, type, defaultValue); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("names", names) + .add("type", type) + .add("defaultValue", defaultValue) + .toString(); + } +} diff --git a/core/trino-parser/src/main/java/io/trino/sql/tree/WhileStatement.java b/core/trino-parser/src/main/java/io/trino/sql/tree/WhileStatement.java new file mode 100644 index 000000000000..360d965b21aa --- /dev/null +++ b/core/trino-parser/src/main/java/io/trino/sql/tree/WhileStatement.java @@ -0,0 +1,94 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.tree; + +import com.google.common.collect.ImmutableList; + +import java.util.List; +import java.util.Objects; +import java.util.Optional; + +import static com.google.common.base.MoreObjects.toStringHelper; +import static java.util.Objects.requireNonNull; + +public final class WhileStatement + extends ControlStatement +{ + private final Optional label; + private final Expression expression; + private final List statements; + + public WhileStatement(NodeLocation location, Optional label, Expression expression, List statements) + { + super(location); + this.label = requireNonNull(label, "label is null"); + this.expression = requireNonNull(expression, "expression is null"); + this.statements = requireNonNull(statements, "statements is null"); + } + + public Optional getLabel() + { + return label; + } + + public Expression getExpression() + { + return expression; + } + + public List getStatements() + { + return statements; + } + + @Override + public R accept(AstVisitor visitor, C context) + { + return visitor.visitWhileStatement(this, context); + } + + @Override + public List getChildren() + { + return ImmutableList.builder() + .add(expression) + .addAll(statements) + .build(); + } + + @Override + public boolean equals(Object obj) + { + return (obj instanceof WhileStatement other) && + Objects.equals(label, other.label) && + Objects.equals(expression, other.expression) && + Objects.equals(statements, other.statements); + } + + @Override + public int hashCode() + { + return Objects.hash(label, expression, statements); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("label", label) + .add("expression", expression) + .add("statements", statements) + .toString(); + } +} diff --git a/core/trino-parser/src/test/java/io/trino/sql/parser/ParserAssert.java b/core/trino-parser/src/test/java/io/trino/sql/parser/ParserAssert.java index d55c23c009e4..56d10ef0301f 100644 --- a/core/trino-parser/src/test/java/io/trino/sql/parser/ParserAssert.java +++ b/core/trino-parser/src/test/java/io/trino/sql/parser/ParserAssert.java @@ -65,6 +65,11 @@ public static AssertProvider rowPattern(String sql) return createAssertion(new SqlParser()::createRowPattern, sql); } + public static AssertProvider functionSpecification(String sql) + { + return createAssertion(new SqlParser()::createFunctionSpecification, sql); + } + private static Expression createExpression(String expression) { return new SqlParser().createExpression(expression); diff --git a/core/trino-parser/src/test/java/io/trino/sql/parser/TestSqlParser.java b/core/trino-parser/src/test/java/io/trino/sql/parser/TestSqlParser.java index 79dfb4005699..4d922f76d7da 100644 --- a/core/trino-parser/src/test/java/io/trino/sql/parser/TestSqlParser.java +++ b/core/trino-parser/src/test/java/io/trino/sql/parser/TestSqlParser.java @@ -2026,6 +2026,7 @@ public void testCreateTableAsSelect() assertThat(statement("CREATE TABLE foo AS SELECT * FROM t")) .isEqualTo(new CreateTableAsSelect(location(1, 1), qualifiedName(location(1, 14), "foo"), new Query( location(1, 21), + ImmutableList.of(), Optional.empty(), new QuerySpecification( location(1, 21), @@ -2050,6 +2051,7 @@ public void testCreateTableAsSelect() assertThat(statement("CREATE TABLE foo(x) AS SELECT a FROM t")) .isEqualTo(new CreateTableAsSelect(location(1, 1), qualifiedName(location(1, 14), "foo"), new Query( location(1, 24), + ImmutableList.of(), Optional.empty(), new QuerySpecification( location(1, 24), @@ -2074,6 +2076,7 @@ public void testCreateTableAsSelect() assertThat(statement("CREATE TABLE foo(x,y) AS SELECT a,b FROM t")) .isEqualTo(new CreateTableAsSelect(location(1, 1), qualifiedName(location(1, 14), "foo"), new Query( location(1, 26), + ImmutableList.of(), Optional.empty(), new QuerySpecification( location(1, 26), @@ -2102,6 +2105,7 @@ public void testCreateTableAsSelect() assertThat(statement("CREATE OR REPLACE TABLE foo AS SELECT * FROM t")) .isEqualTo(new CreateTableAsSelect(location(1, 1), qualifiedName(location(1, 25), "foo"), new Query( location(1, 32), + ImmutableList.of(), Optional.empty(), new QuerySpecification( location(1, 32), @@ -2126,6 +2130,7 @@ public void testCreateTableAsSelect() assertThat(statement("CREATE OR REPLACE TABLE foo(x) AS SELECT a FROM t")) .isEqualTo(new CreateTableAsSelect(location(1, 1), qualifiedName(location(1, 25), "foo"), new Query( location(1, 35), + ImmutableList.of(), Optional.empty(), new QuerySpecification( location(1, 35), @@ -2150,6 +2155,7 @@ public void testCreateTableAsSelect() assertThat(statement("CREATE OR REPLACE TABLE foo(x,y) AS SELECT a,b FROM t")) .isEqualTo(new CreateTableAsSelect(location(1, 1), qualifiedName(location(1, 25), "foo"), new Query( location(1, 37), + ImmutableList.of(), Optional.empty(), new QuerySpecification( location(1, 37), @@ -2178,6 +2184,7 @@ public void testCreateTableAsSelect() assertThat(statement("CREATE TABLE IF NOT EXISTS foo AS SELECT * FROM t")) .isEqualTo(new CreateTableAsSelect(location(1, 1), qualifiedName(location(1, 28), "foo"), new Query( location(1, 35), + ImmutableList.of(), Optional.empty(), new QuerySpecification( location(1, 35), @@ -2202,6 +2209,7 @@ public void testCreateTableAsSelect() assertThat(statement("CREATE TABLE IF NOT EXISTS foo(x) AS SELECT a FROM t")) .isEqualTo(new CreateTableAsSelect(location(1, 1), qualifiedName(location(1, 28), "foo"), new Query( location(1, 38), + ImmutableList.of(), Optional.empty(), new QuerySpecification( location(1, 38), @@ -2226,6 +2234,7 @@ public void testCreateTableAsSelect() assertThat(statement("CREATE TABLE IF NOT EXISTS foo(x,y) AS SELECT a,b FROM t")) .isEqualTo(new CreateTableAsSelect(location(1, 1), qualifiedName(location(1, 28), "foo"), new Query( location(1, 40), + ImmutableList.of(), Optional.empty(), new QuerySpecification( location(1, 40), @@ -2254,6 +2263,7 @@ public void testCreateTableAsSelect() assertThat(statement("CREATE TABLE foo AS SELECT * FROM t WITH NO DATA")) .isEqualTo(new CreateTableAsSelect(location(1, 1), qualifiedName(location(1, 14), "foo"), new Query( location(1, 21), + ImmutableList.of(), Optional.empty(), new QuerySpecification( location(1, 21), @@ -2278,6 +2288,7 @@ public void testCreateTableAsSelect() assertThat(statement("CREATE TABLE foo(x) AS SELECT a FROM t WITH NO DATA")) .isEqualTo(new CreateTableAsSelect(location(1, 1), qualifiedName(location(1, 14), "foo"), new Query( location(1, 24), + ImmutableList.of(), Optional.empty(), new QuerySpecification( location(1, 24), @@ -2302,6 +2313,7 @@ public void testCreateTableAsSelect() assertThat(statement("CREATE TABLE foo(x,y) AS SELECT a,b FROM t WITH NO DATA")) .isEqualTo(new CreateTableAsSelect(location(1, 1), qualifiedName(location(1, 14), "foo"), new Query( location(1, 26), + ImmutableList.of(), Optional.empty(), new QuerySpecification( location(1, 26), @@ -2338,6 +2350,7 @@ public void testCreateTableAsSelect() qualifiedName(location(1, 14), "foo"), new Query( location(4, 1), + ImmutableList.of(), Optional.empty(), new QuerySpecification( location(4, 1), @@ -2380,6 +2393,7 @@ CREATE TABLE foo(x) qualifiedName(location(1, 14), "foo"), new Query( location(4, 1), + ImmutableList.of(), Optional.empty(), new QuerySpecification( location(4, 1), @@ -2419,6 +2433,7 @@ CREATE TABLE foo(x,y) qualifiedName(location(1, 14), "foo"), new Query( location(4, 1), + ImmutableList.of(), Optional.empty(), new QuerySpecification( location(4, 1), @@ -2459,6 +2474,7 @@ CREATE TABLE foo(x,y) qualifiedName(location(1, 14), "foo"), new Query( location(4, 1), + ImmutableList.of(), Optional.empty(), new QuerySpecification( location(4, 1), @@ -2502,6 +2518,7 @@ CREATE TABLE foo(x) qualifiedName(location(1, 14), "foo"), new Query( location(4, 1), + ImmutableList.of(), Optional.empty(), new QuerySpecification( location(4, 1), @@ -2542,6 +2559,7 @@ CREATE TABLE foo(x,y) qualifiedName(location(1, 14), "foo"), new Query( location(4, 1), + ImmutableList.of(), Optional.empty(), new QuerySpecification( location(4, 1), @@ -2582,6 +2600,7 @@ CREATE TABLE foo(x,y) qualifiedName(location(1, 14), "foo"), new Query( location(4, 1), + ImmutableList.of(), Optional.empty(), new QuerySpecification( location(4, 1), @@ -2625,6 +2644,7 @@ CREATE TABLE foo(x) COMMENT 'test' qualifiedName(location(1, 14), "foo"), new Query( location(4, 1), + ImmutableList.of(), Optional.empty(), new QuerySpecification( location(4, 1), @@ -2665,6 +2685,7 @@ CREATE TABLE foo(x,y) COMMENT 'test' qualifiedName(location(1, 14), "foo"), new Query( location(4, 1), + ImmutableList.of(), Optional.empty(), new QuerySpecification( location(4, 1), @@ -2705,6 +2726,7 @@ CREATE TABLE foo(x,y) COMMENT 'test' qualifiedName(location(1, 14), "foo"), new Query( location(4, 1), + ImmutableList.of(), Optional.empty(), new QuerySpecification( location(4, 1), @@ -2769,6 +2791,7 @@ WITH t(x) AS (VALUES 1) QualifiedName table = QualifiedName.of("foo"); Query query = new Query( + ImmutableList.of(), Optional.of(new With(false, ImmutableList.of( new WithQuery( identifier("t"), @@ -3550,6 +3573,7 @@ public void testWith() { assertStatement("WITH a (t, u) AS (SELECT * FROM x), b AS (SELECT * FROM y) TABLE z", new Query( + ImmutableList.of(), Optional.of(new With(false, ImmutableList.of( new WithQuery( identifier("a"), @@ -3570,6 +3594,7 @@ public void testWith() assertStatement("WITH RECURSIVE a AS (SELECT * FROM x) TABLE y", new Query( + ImmutableList.of(), Optional.of(new With(true, ImmutableList.of( new WithQuery( identifier("a"), @@ -4114,6 +4139,7 @@ public void testShowStatsForQuery() new TableSubquery( new Query( location(1, 17), + ImmutableList.of(), Optional.empty(), new QuerySpecification( location(1, 17), @@ -4143,6 +4169,7 @@ public void testShowStatsForQuery() new TableSubquery( new Query( location(1, 17), + ImmutableList.of(), Optional.empty(), new QuerySpecification( location(1, 17), @@ -4177,6 +4204,7 @@ WITH t AS (SELECT 1 ) new TableSubquery( new Query( location(2, 4), + ImmutableList.of(), Optional.of( new With( location(2, 4), @@ -4187,6 +4215,7 @@ WITH t AS (SELECT 1 ) new Identifier(location(2, 9), "t", false), new Query( location(2, 15), + ImmutableList.of(), Optional.empty(), new QuerySpecification( location(2, 15), @@ -4560,6 +4589,7 @@ public void testCreateMaterializedView() QualifiedName.of(ImmutableList.of(new Identifier(new NodeLocation(1, 26), "a", false))), new Query( new NodeLocation(1, 31), + ImmutableList.of(), Optional.empty(), new QuerySpecification( new NodeLocation(1, 31), @@ -4597,6 +4627,7 @@ public void testCreateMaterializedView() new Identifier(new NodeLocation(1, 52), "matview", false))), new Query( new NodeLocation(1, 100), + ImmutableList.of(), Optional.empty(), new QuerySpecification( new NodeLocation(1, 100), @@ -4633,6 +4664,7 @@ public void testCreateMaterializedView() QualifiedName.of(ImmutableList.of(new Identifier(new NodeLocation(1, 26), "a", false))), new Query( new NodeLocation(1, 61), + ImmutableList.of(), Optional.empty(), new QuerySpecification( new NodeLocation(1, 61), @@ -4673,6 +4705,7 @@ public void testCreateMaterializedView() new Identifier(new NodeLocation(1, 52), "matview", false))), new Query( new NodeLocation(3, 5), + ImmutableList.of(), Optional.empty(), new QuerySpecification( new NodeLocation(3, 5), @@ -4721,6 +4754,7 @@ AS WITH a (t, u) AS (SELECT * FROM x), b AS (SELECT * FROM a) TABLE b new Identifier(new NodeLocation(1, 52), "matview", false))), new Query( new NodeLocation(3, 5), + ImmutableList.of(), Optional.of(new With( new NodeLocation(3, 5), false, @@ -4730,6 +4764,7 @@ AS WITH a (t, u) AS (SELECT * FROM x), b AS (SELECT * FROM a) TABLE b new Identifier(new NodeLocation(3, 10), "a", false), new Query( new NodeLocation(3, 23), + ImmutableList.of(), Optional.empty(), new QuerySpecification( new NodeLocation(3, 23), @@ -4758,6 +4793,7 @@ AS WITH a (t, u) AS (SELECT * FROM x), b AS (SELECT * FROM a) TABLE b new Identifier(new NodeLocation(3, 41), "b", false), new Query( new NodeLocation(3, 47), + ImmutableList.of(), Optional.empty(), new QuerySpecification( new NodeLocation(3, 47), @@ -5195,6 +5231,7 @@ public void testQueryPeriod() .isEqualTo( new Query( location(1, 1), + ImmutableList.of(), Optional.empty(), new QuerySpecification( location(1, 1), @@ -5225,6 +5262,7 @@ public void testQueryPeriod() .isEqualTo( new Query( location(1, 1), + ImmutableList.of(), Optional.empty(), new QuerySpecification( location(1, 1), @@ -5567,6 +5605,7 @@ private static Query selectAllFrom(Relation relation) { return new Query( location(1, 1), + ImmutableList.of(), Optional.empty(), new QuerySpecification( location(1, 1), diff --git a/core/trino-parser/src/test/java/io/trino/sql/parser/TestSqlParserErrorHandling.java b/core/trino-parser/src/test/java/io/trino/sql/parser/TestSqlParserErrorHandling.java index 15bf0ddc766e..8631b75e870a 100644 --- a/core/trino-parser/src/test/java/io/trino/sql/parser/TestSqlParserErrorHandling.java +++ b/core/trino-parser/src/test/java/io/trino/sql/parser/TestSqlParserErrorHandling.java @@ -42,10 +42,10 @@ private static Stream statements() return Stream.of( Arguments.of("", "line 1:1: mismatched input ''. Expecting: 'ALTER', 'ANALYZE', 'CALL', 'COMMENT', 'COMMIT', 'CREATE', 'DEALLOCATE', 'DELETE', 'DENY', 'DESC', 'DESCRIBE', 'DROP', 'EXECUTE', 'EXPLAIN', 'GRANT', " + - "'INSERT', 'MERGE', 'PREPARE', 'REFRESH', 'RESET', 'REVOKE', 'ROLLBACK', 'SET', 'SHOW', 'START', 'TRUNCATE', 'UPDATE', 'USE', "), + "'INSERT', 'MERGE', 'PREPARE', 'REFRESH', 'RESET', 'REVOKE', 'ROLLBACK', 'SET', 'SHOW', 'START', 'TRUNCATE', 'UPDATE', 'USE', 'WITH', "), Arguments.of("@select", "line 1:1: mismatched input '@'. Expecting: 'ALTER', 'ANALYZE', 'CALL', 'COMMENT', 'COMMIT', 'CREATE', 'DEALLOCATE', 'DELETE', 'DENY', 'DESC', 'DESCRIBE', 'DROP', 'EXECUTE', 'EXPLAIN', 'GRANT', " + - "'INSERT', 'MERGE', 'PREPARE', 'REFRESH', 'RESET', 'REVOKE', 'ROLLBACK', 'SET', 'SHOW', 'START', 'TRUNCATE', 'UPDATE', 'USE', "), + "'INSERT', 'MERGE', 'PREPARE', 'REFRESH', 'RESET', 'REVOKE', 'ROLLBACK', 'SET', 'SHOW', 'START', 'TRUNCATE', 'UPDATE', 'USE', 'WITH', "), Arguments.of("select * from foo where @what", "line 1:25: mismatched input '@'. Expecting: "), Arguments.of("select * from 'oops", @@ -86,7 +86,7 @@ private static Stream statements() Arguments.of("select foo(DISTINCT ,1)", "line 1:21: mismatched input ','. Expecting: "), Arguments.of("CREATE )", - "line 1:8: mismatched input ')'. Expecting: 'CATALOG', 'MATERIALIZED', 'OR', 'ROLE', 'SCHEMA', 'TABLE', 'VIEW'"), + "line 1:8: mismatched input ')'. Expecting: 'CATALOG', 'FUNCTION', 'MATERIALIZED', 'OR', 'ROLE', 'SCHEMA', 'TABLE', 'VIEW'"), Arguments.of("CREATE TABLE ) AS (VALUES 1)", "line 1:14: mismatched input ')'. Expecting: 'IF', "), Arguments.of("CREATE TABLE foo ", diff --git a/core/trino-parser/src/test/java/io/trino/sql/parser/TestSqlParserRoutines.java b/core/trino-parser/src/test/java/io/trino/sql/parser/TestSqlParserRoutines.java new file mode 100644 index 000000000000..143a495d0601 --- /dev/null +++ b/core/trino-parser/src/test/java/io/trino/sql/parser/TestSqlParserRoutines.java @@ -0,0 +1,356 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.parser; + +import com.google.common.collect.ImmutableList; +import io.trino.sql.tree.ArithmeticBinaryExpression; +import io.trino.sql.tree.AssignmentStatement; +import io.trino.sql.tree.CommentCharacteristic; +import io.trino.sql.tree.ComparisonExpression; +import io.trino.sql.tree.CompoundStatement; +import io.trino.sql.tree.ControlStatement; +import io.trino.sql.tree.CreateFunction; +import io.trino.sql.tree.DataType; +import io.trino.sql.tree.DeterministicCharacteristic; +import io.trino.sql.tree.ElseIfClause; +import io.trino.sql.tree.Expression; +import io.trino.sql.tree.FunctionCall; +import io.trino.sql.tree.FunctionSpecification; +import io.trino.sql.tree.GenericDataType; +import io.trino.sql.tree.Identifier; +import io.trino.sql.tree.IfStatement; +import io.trino.sql.tree.LanguageCharacteristic; +import io.trino.sql.tree.LogicalExpression; +import io.trino.sql.tree.LongLiteral; +import io.trino.sql.tree.NodeLocation; +import io.trino.sql.tree.ParameterDeclaration; +import io.trino.sql.tree.QualifiedName; +import io.trino.sql.tree.Query; +import io.trino.sql.tree.QuerySpecification; +import io.trino.sql.tree.ReturnStatement; +import io.trino.sql.tree.ReturnsClause; +import io.trino.sql.tree.SecurityCharacteristic; +import io.trino.sql.tree.Select; +import io.trino.sql.tree.StringLiteral; +import io.trino.sql.tree.VariableDeclaration; +import io.trino.sql.tree.WhileStatement; +import org.junit.jupiter.api.Test; + +import java.util.List; +import java.util.Optional; + +import static io.trino.sql.QueryUtil.functionCall; +import static io.trino.sql.QueryUtil.identifier; +import static io.trino.sql.QueryUtil.selectList; +import static io.trino.sql.parser.ParserAssert.functionSpecification; +import static io.trino.sql.parser.ParserAssert.statement; +import static io.trino.sql.tree.NullInputCharacteristic.calledOnNullInput; +import static io.trino.sql.tree.NullInputCharacteristic.returnsNullOnNullInput; +import static io.trino.sql.tree.SecurityCharacteristic.Security.DEFINER; +import static io.trino.sql.tree.SecurityCharacteristic.Security.INVOKER; +import static org.assertj.core.api.Assertions.assertThat; + +class TestSqlParserRoutines +{ + @Test + public void testStandaloneFunction() + { + assertThat(functionSpecification("FUNCTION foo() RETURNS bigint RETURN 42")) + .ignoringLocation() + .isEqualTo(new FunctionSpecification( + QualifiedName.of("foo"), + ImmutableList.of(), + returns(type("bigint")), + ImmutableList.of(), + new ReturnStatement(location(), literal(42)))); + } + + @Test + void testInlineFunction() + { + assertThat(statement(""" + WITH + FUNCTION answer() + RETURNS BIGINT + RETURN 42 + SELECT answer() + """)) + .ignoringLocation() + .isEqualTo(query( + new FunctionSpecification( + QualifiedName.of("answer"), + ImmutableList.of(), + returns(type("BIGINT")), + ImmutableList.of(), + new ReturnStatement(location(), literal(42))), + selectList(new FunctionCall(QualifiedName.of("answer"), ImmutableList.of())))); + } + + @Test + void testSimpleFunction() + { + assertThat(statement(""" + CREATE FUNCTION hello(s VARCHAR) + RETURNS varchar + LANGUAGE SQL + DETERMINISTIC + CALLED ON NULL INPUT + SECURITY INVOKER + COMMENT 'hello world function' + RETURN CONCAT('Hello, ', s, '!') + """)) + .ignoringLocation() + .isEqualTo(new CreateFunction( + new FunctionSpecification( + QualifiedName.of("hello"), + ImmutableList.of(parameter("s", type("VARCHAR"))), + returns(type("varchar")), + ImmutableList.of( + new LanguageCharacteristic(identifier("SQL")), + new DeterministicCharacteristic(true), + calledOnNullInput(), + new SecurityCharacteristic(INVOKER), + new CommentCharacteristic("hello world function")), + new ReturnStatement(location(), functionCall( + "CONCAT", + literal("Hello, "), + identifier("s"), + literal("!")))), + false)); + } + + @Test + void testEmptyFunction() + { + assertThat(statement(""" + CREATE OR REPLACE FUNCTION answer() + RETURNS bigint + RETURN 42 + """)) + .ignoringLocation() + .isEqualTo(new CreateFunction( + new FunctionSpecification( + QualifiedName.of("answer"), + ImmutableList.of(), + returns(type("bigint")), + ImmutableList.of(), + new ReturnStatement(location(), literal(42))), + true)); + } + + @Test + void testFibFunction() + { + assertThat(statement(""" + CREATE FUNCTION fib(n bigint) + RETURNS bigint + BEGIN + DECLARE a bigint DEFAULT 1; + DECLARE b bigint DEFAULT 1; + DECLARE c bigint; + IF n <= 2 THEN + RETURN 1; + END IF; + WHILE n > 2 DO + SET n = n - 1; + SET c = a + b; + SET a = b; + SET b = c; + END WHILE; + RETURN c; + END + """)) + .ignoringLocation() + .isEqualTo(new CreateFunction( + new FunctionSpecification( + QualifiedName.of("fib"), + ImmutableList.of(parameter("n", type("bigint"))), + returns(type("bigint")), + ImmutableList.of(), + beginEnd( + ImmutableList.of( + declare("a", type("bigint"), literal(1)), + declare("b", type("bigint"), literal(1)), + declare("c", type("bigint"))), + new IfStatement( + location(), + lte("n", literal(2)), + ImmutableList.of(new ReturnStatement(location(), literal(1))), + ImmutableList.of(), + Optional.empty()), + new WhileStatement( + location(), + Optional.empty(), + gt("n", literal(2)), + ImmutableList.of( + assign("n", minus(identifier("n"), literal(1))), + assign("c", plus(identifier("a"), identifier("b"))), + assign("a", identifier("b")), + assign("b", identifier("c")))), + new ReturnStatement(location(), identifier("c")))), + false)); + } + + @Test + void testFunctionWithIfElseIf() + { + assertThat(statement(""" + CREATE FUNCTION CustomerLevel(p_creditLimit DOUBLE) + RETURNS varchar + RETURNS NULL ON NULL INPUT + SECURITY DEFINER + BEGIN + DECLARE lvl VarChar; + IF p_creditLimit > 50000 THEN + SET lvl = 'PLATINUM'; + ELSEIF (p_creditLimit <= 50000 AND p_creditLimit >= 10000) THEN + SET lvl = 'GOLD'; + ELSEIF p_creditLimit < 10000 THEN + SET lvl = 'SILVER'; + END IF; + RETURN (lvl); + END + """)) + .ignoringLocation() + .isEqualTo(new CreateFunction( + new FunctionSpecification( + QualifiedName.of("CustomerLevel"), + ImmutableList.of(parameter("p_creditLimit", type("DOUBLE"))), + returns(type("varchar")), + ImmutableList.of( + returnsNullOnNullInput(), + new SecurityCharacteristic(DEFINER)), + beginEnd( + ImmutableList.of(declare("lvl", type("VarChar"))), + new IfStatement( + location(), + gt("p_creditLimit", literal(50000)), + ImmutableList.of(assign("lvl", literal("PLATINUM"))), + ImmutableList.of( + elseIf(LogicalExpression.and( + lte("p_creditLimit", literal(50000)), + gte("p_creditLimit", literal(10000))), + assign("lvl", literal("GOLD"))), + elseIf(lt("p_creditLimit", literal(10000)), + assign("lvl", literal("SILVER")))), + Optional.empty()), + new ReturnStatement(location(), identifier("lvl")))), + false)); + } + + private static DataType type(String identifier) + { + return new GenericDataType(Optional.empty(), new Identifier(identifier, false), ImmutableList.of()); + } + + private static ReturnsClause returns(DataType type) + { + return new ReturnsClause(location(), type); + } + + private static VariableDeclaration declare(String name, DataType type) + { + return new VariableDeclaration(location(), ImmutableList.of(new Identifier(name)), type, Optional.empty()); + } + + private static VariableDeclaration declare(String name, DataType type, Expression defaultValue) + { + return new VariableDeclaration(location(), ImmutableList.of(new Identifier(name)), type, Optional.of(defaultValue)); + } + + private static ParameterDeclaration parameter(String name, DataType type) + { + return new ParameterDeclaration(Optional.of(new Identifier(name)), type); + } + + private static AssignmentStatement assign(String name, Expression value) + { + return new AssignmentStatement(location(), new Identifier(name), value); + } + + private static ArithmeticBinaryExpression plus(Expression left, Expression right) + { + return new ArithmeticBinaryExpression(ArithmeticBinaryExpression.Operator.ADD, left, right); + } + + private static ArithmeticBinaryExpression minus(Expression left, Expression right) + { + return new ArithmeticBinaryExpression(ArithmeticBinaryExpression.Operator.SUBTRACT, left, right); + } + + private static ComparisonExpression lt(String name, Expression expression) + { + return new ComparisonExpression(ComparisonExpression.Operator.LESS_THAN, identifier(name), expression); + } + + private static ComparisonExpression lte(String name, Expression expression) + { + return new ComparisonExpression(ComparisonExpression.Operator.LESS_THAN_OR_EQUAL, identifier(name), expression); + } + + private static ComparisonExpression gt(String name, Expression expression) + { + return new ComparisonExpression(ComparisonExpression.Operator.GREATER_THAN, identifier(name), expression); + } + + private static ComparisonExpression gte(String name, Expression expression) + { + return new ComparisonExpression(ComparisonExpression.Operator.GREATER_THAN_OR_EQUAL, identifier(name), expression); + } + + private static StringLiteral literal(String literal) + { + return new StringLiteral(literal); + } + + private static LongLiteral literal(long literal) + { + return new LongLiteral(String.valueOf(literal)); + } + + private static CompoundStatement beginEnd(List variableDeclarations, ControlStatement... statements) + { + return new CompoundStatement(location(), variableDeclarations, ImmutableList.copyOf(statements)); + } + + private static ElseIfClause elseIf(Expression expression, ControlStatement... statements) + { + return new ElseIfClause(expression, ImmutableList.copyOf(statements)); + } + + private static Query query(FunctionSpecification function, Select select) + { + return new Query( + ImmutableList.of(function), + Optional.empty(), + new QuerySpecification( + select, + Optional.empty(), + Optional.empty(), + Optional.empty(), + Optional.empty(), + ImmutableList.of(), + Optional.empty(), + Optional.empty(), + Optional.empty()), + Optional.empty(), + Optional.empty(), + Optional.empty()); + } + + private static NodeLocation location() + { + return new NodeLocation(1, 1); + } +} diff --git a/core/trino-spi/src/main/java/io/trino/spi/StandardErrorCode.java b/core/trino-spi/src/main/java/io/trino/spi/StandardErrorCode.java index e16f80af4ca9..9500f3e4a66b 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/StandardErrorCode.java +++ b/core/trino-spi/src/main/java/io/trino/spi/StandardErrorCode.java @@ -147,6 +147,7 @@ public enum StandardErrorCode INVALID_CHECK_CONSTRAINT(123, USER_ERROR), INVALID_CATALOG_PROPERTY(124, USER_ERROR), CATALOG_UNAVAILABLE(125, USER_ERROR), + MISSING_RETURN(126, USER_ERROR), GENERIC_INTERNAL_ERROR(65536, INTERNAL_ERROR), TOO_MANY_REQUESTS_FAILED(65537, INTERNAL_ERROR), diff --git a/service/trino-verifier/src/main/java/io/trino/verifier/QueryRewriter.java b/service/trino-verifier/src/main/java/io/trino/verifier/QueryRewriter.java index ee78aad47ec7..a7b1c684194a 100644 --- a/service/trino-verifier/src/main/java/io/trino/verifier/QueryRewriter.java +++ b/service/trino-verifier/src/main/java/io/trino/verifier/QueryRewriter.java @@ -206,10 +206,10 @@ private List getColumns(Connection connection, CreateTableAsSelect creat querySpecification.getOffset(), Optional.of(new Limit(new LongLiteral("0")))); - zeroRowsQuery = new io.trino.sql.tree.Query(createSelectClause.getWith(), innerQuery, Optional.empty(), Optional.empty(), Optional.empty()); + zeroRowsQuery = new io.trino.sql.tree.Query(ImmutableList.of(), createSelectClause.getWith(), innerQuery, Optional.empty(), Optional.empty(), Optional.empty()); } else { - zeroRowsQuery = new io.trino.sql.tree.Query(createSelectClause.getWith(), innerQuery, Optional.empty(), Optional.empty(), Optional.of(new Limit(new LongLiteral("0")))); + zeroRowsQuery = new io.trino.sql.tree.Query(ImmutableList.of(), createSelectClause.getWith(), innerQuery, Optional.empty(), Optional.empty(), Optional.of(new Limit(new LongLiteral("0")))); } ImmutableList.Builder columns = ImmutableList.builder(); diff --git a/testing/trino-tests/src/test/java/io/trino/execution/TestEventListenerBasic.java b/testing/trino-tests/src/test/java/io/trino/execution/TestEventListenerBasic.java index 53fd58625c51..d968cb2ffe45 100644 --- a/testing/trino-tests/src/test/java/io/trino/execution/TestEventListenerBasic.java +++ b/testing/trino-tests/src/test/java/io/trino/execution/TestEventListenerBasic.java @@ -251,7 +251,7 @@ public void testAnalysisFailure() public void testParseError() throws Exception { - assertFailedQuery("You shall not parse!", "line 1:1: mismatched input 'You'. Expecting: 'ALTER', 'ANALYZE', 'CALL', 'COMMENT', 'COMMIT', 'CREATE', 'DEALLOCATE', 'DELETE', 'DENY', 'DESC', 'DESCRIBE', 'DROP', 'EXECUTE', 'EXPLAIN', 'GRANT', 'INSERT', 'MERGE', 'PREPARE', 'REFRESH', 'RESET', 'REVOKE', 'ROLLBACK', 'SET', 'SHOW', 'START', 'TRUNCATE', 'UPDATE', 'USE', "); + assertFailedQuery("You shall not parse!", "line 1:1: mismatched input 'You'. Expecting: 'ALTER', 'ANALYZE', 'CALL', 'COMMENT', 'COMMIT', 'CREATE', 'DEALLOCATE', 'DELETE', 'DENY', 'DESC', 'DESCRIBE', 'DROP', 'EXECUTE', 'EXPLAIN', 'GRANT', 'INSERT', 'MERGE', 'PREPARE', 'REFRESH', 'RESET', 'REVOKE', 'ROLLBACK', 'SET', 'SHOW', 'START', 'TRUNCATE', 'UPDATE', 'USE', 'WITH', "); } @Test