Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.antlr.v4.runtime.LexerNoViableAltException;
import org.antlr.v4.runtime.Token;

import java.util.HashSet;
import java.util.Set;

/**
Expand All @@ -30,17 +31,26 @@
* The code in nextToken() is a copy of the implementation in org.antlr.v4.runtime.Lexer, with a
* bit added to match the token before the default behavior is invoked.
*/
class DelimiterLexer
public class DelimiterLexer
extends SqlBaseLexer
{
private final Set<String> delimiters;
private final boolean useSemicolon;

public DelimiterLexer(CharStream input, Set<String> delimiters)
{
super(input);
delimiters = new HashSet<>(delimiters);
this.useSemicolon = delimiters.remove(";");
this.delimiters = ImmutableSet.copyOf(delimiters);
}

public boolean isDelimiter(Token token)
{
return (token.getType() == SqlBaseParser.DELIMITER) ||
(useSemicolon && (token.getType() == SEMICOLON));
}

@Override
public Token nextToken()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,17 @@

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import io.trino.grammar.sql.SqlBaseBaseVisitor;
import io.trino.grammar.sql.SqlBaseLexer;
import io.trino.grammar.sql.SqlBaseParser;
import io.trino.grammar.sql.SqlBaseParser.FunctionSpecificationContext;
import org.antlr.v4.runtime.CharStreams;
import org.antlr.v4.runtime.CommonTokenStream;
import org.antlr.v4.runtime.ParserRuleContext;
import org.antlr.v4.runtime.Token;
import org.antlr.v4.runtime.TokenSource;
import org.antlr.v4.runtime.tree.ParseTree;
import org.antlr.v4.runtime.tree.RuleNode;

import java.util.List;
import java.util.Objects;
Expand All @@ -39,25 +45,50 @@ public StatementSplitter(String sql)

public StatementSplitter(String sql, Set<String> delimiters)
{
TokenSource tokens = getLexer(sql, delimiters);
DelimiterLexer lexer = getLexer(sql, delimiters);
CommonTokenStream tokenStream = new CommonTokenStream(lexer);
tokenStream.fill();

SqlBaseParser parser = new SqlBaseParser(tokenStream);
parser.removeErrorListeners();

ImmutableList.Builder<Statement> list = ImmutableList.builder();
StringBuilder sb = new StringBuilder();
while (true) {
Token token = tokens.nextToken();
if (token.getType() == Token.EOF) {
break;
}
if (token.getType() == SqlBaseParser.DELIMITER) {
String statement = sb.toString().trim();
if (!statement.isEmpty()) {
list.add(new Statement(statement, token.getText()));
int index = 0;

while (index < tokenStream.size()) {
ParserRuleContext context = parser.statement();

if (containsFunction(context)) {
Token stop = context.getStop();
if ((stop != null) && (stop.getTokenIndex() >= index)) {
int endIndex = stop.getTokenIndex();
while (index <= endIndex) {
Token token = tokenStream.get(index);
index++;
sb.append(token.getText());
}
}
sb = new StringBuilder();
}
else {

while (index < tokenStream.size()) {
Token token = tokenStream.get(index);
index++;
if (token.getType() == Token.EOF) {
break;
}
if (lexer.isDelimiter(token)) {
String statement = sb.toString().trim();
if (!statement.isEmpty()) {
list.add(new Statement(statement, token.getText()));
}
sb = new StringBuilder();
break;
}
sb.append(token.getText());
}
}

this.completeStatements = list.build();
this.partialStatement = sb.toString().trim();
}
Expand Down Expand Up @@ -105,12 +136,42 @@ public static boolean isEmptyStatement(String sql)
}
}

public static TokenSource getLexer(String sql, Set<String> terminators)
public static DelimiterLexer getLexer(String sql, Set<String> terminators)
{
requireNonNull(sql, "sql is null");
return new DelimiterLexer(CharStreams.fromString(sql), terminators);
}

private static boolean containsFunction(ParseTree tree)
{
return new SqlBaseBaseVisitor<Boolean>()
{
@Override
protected Boolean defaultResult()
{
return false;
}

@Override
protected Boolean aggregateResult(Boolean aggregate, Boolean nextResult)
{
return aggregate || nextResult;
}

@Override
protected boolean shouldVisitNextChild(RuleNode node, Boolean currentResult)
{
return !currentResult;
}

@Override
public Boolean visitFunctionSpecification(FunctionSpecificationContext context)
{
return true;
}
}.visit(tree);
}

public static class Statement
{
private final String statement;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,151 @@ public void testSplitterSelectItemsWithoutComma()
assertThat(splitter.getPartialStatement()).isEmpty();
}

@Test
public void testSplitterSimpleInlineFunction()
{
String function = "WITH FUNCTION abc() RETURNS int RETURN 42 SELECT abc() FROM t";
String sql = function + "; SELECT 456;";
StatementSplitter splitter = new StatementSplitter(sql);
assertThat(splitter.getCompleteStatements()).containsExactly(
statement(function),
statement("SELECT 456"));
assertThat(splitter.getPartialStatement()).isEmpty();
}

@Test
public void testSplitterSimpleInlineFunctionWithIncompleteSelect()
{
String function = "WITH FUNCTION abc() RETURNS int RETURN 42 SELECT abc(), FROM t";
String sql = function + "; SELECT 456;";
StatementSplitter splitter = new StatementSplitter(sql);
assertThat(splitter.getCompleteStatements()).containsExactly(
statement(function),
statement("SELECT 456"));
assertThat(splitter.getPartialStatement()).isEmpty();
}

@Test
public void testSplitterSimpleInlineFunctionWithComments()
{
String function = "/* start */ WITH FUNCTION abc() RETURNS int /* middle */ RETURN 42 SELECT abc() FROM t /* end */";
String sql = function + "; SELECT 456;";
StatementSplitter splitter = new StatementSplitter(sql);
assertThat(splitter.getCompleteStatements()).containsExactly(
statement(function),
statement("SELECT 456"));
assertThat(splitter.getPartialStatement()).isEmpty();
}

@Test
public void testSplitterCreateFunction()
{
String function = "CREATE FUNCTION fib(n int) RETURNS int BEGIN IF false THEN RETURN 0; END IF; RETURN 1; END";
String sql = function + "; SELECT 123;";
StatementSplitter splitter = new StatementSplitter(sql);
assertThat(splitter.getCompleteStatements()).containsExactly(
statement(function),
statement("SELECT 123"));
assertThat(splitter.getPartialStatement()).isEmpty();
}

@Test
public void testSplitterCreateFunctionInvalidThen()
{
String function = "CREATE FUNCTION fib(n int) RETURNS int BEGIN IF false THEN oops; END IF; RETURN 1; END";
String sql = function + "; SELECT 123;";
StatementSplitter splitter = new StatementSplitter(sql);
assertThat(splitter.getCompleteStatements()).containsExactly(
statement(function),
statement("SELECT 123"));
assertThat(splitter.getPartialStatement()).isEmpty();
}

@Test
public void testSplitterCreateFunctionInvalidReturn()
{
String function = "CREATE FUNCTION fib(n int) RETURNS int BEGIN IF false THEN oops; END IF; RETURN 1 xxx; END";
String sql = function + "; SELECT 123;";
StatementSplitter splitter = new StatementSplitter(sql);
assertThat(splitter.getCompleteStatements()).containsExactly(
statement(function),
statement("SELECT 123"));
assertThat(splitter.getPartialStatement()).isEmpty();
}

@Test
public void testSplitterCreateFunctionInvalidBegin()
{
String function = "CREATE FUNCTION fib(n int) RETURNS int BEGIN xxx IF false THEN oops; END IF; RETURN 1; END";
String sql = function + "; SELECT 123;";
StatementSplitter splitter = new StatementSplitter(sql);
assertThat(splitter.getCompleteStatements()).containsExactly(
statement("CREATE FUNCTION fib(n int) RETURNS int BEGIN xxx IF false THEN oops; END IF"),
statement("RETURN 1"),
statement("END"),
statement("SELECT 123"));
assertThat(splitter.getPartialStatement()).isEmpty();
}

@Test
public void testSplitterCreateFunctionInvalidDelimitedThen()
{
String function = "CREATE FUNCTION fib(n int) RETURNS int BEGIN IF false THEN; oops; END IF; RETURN 1; END";
String sql = function + "; SELECT 123;";
StatementSplitter splitter = new StatementSplitter(sql);
assertThat(splitter.getCompleteStatements()).containsExactly(
statement(function),
statement("SELECT 123"));
assertThat(splitter.getPartialStatement()).isEmpty();
}

@Test
public void testSplitterComplexCreateFunction()
{
String function = "" +
"CREATE FUNCTION fib(n bigint)\n" +
"RETURNS bigint\n" +
"BEGIN\n" +
" DECLARE a bigint DEFAULT 1;\n" +
" DECLARE b bigint DEFAULT 1;\n" +
" DECLARE c bigint;\n" +
" IF n <= 2 THEN\n" +
" RETURN 1;\n" +
" END IF;\n" +
" WHILE n > 2 DO\n" +
" SET n = n - 1;\n" +
" SET c = a + b;\n" +
" SET a = b;\n" +
" SET b = c;\n" +
" END WHILE;\n" +
" RETURN c;\n" +
"END";
String sql = function + ";\nSELECT 123;\nSELECT 456;\n";
StatementSplitter splitter = new StatementSplitter(sql);
assertThat(splitter.getCompleteStatements()).containsExactly(
statement(function),
statement("SELECT 123"),
statement("SELECT 456"));
assertThat(splitter.getPartialStatement()).isEmpty();
}

@Test
public void testSplitterMultipleFunctions()
{
String function1 = "CREATE FUNCTION f1() RETURNS int BEGIN IF false THEN RETURN 0; END IF; RETURN 1; END";
String function2 = "CREATE FUNCTION f2() RETURNS int BEGIN IF false THEN RETURN 0; END IF; RETURN 1; END";
String sql = "SELECT 11;" + function1 + ";" + function2 + ";SELECT 22;" + function2 + ";SELECT 33;";
StatementSplitter splitter = new StatementSplitter(sql);
assertThat(splitter.getCompleteStatements()).containsExactly(
statement("SELECT 11"),
statement(function1),
statement(function2),
statement("SELECT 22"),
statement(function2),
statement("SELECT 33"));
assertThat(splitter.getPartialStatement()).isEmpty();
}

@Test
public void testIsEmptyStatement()
{
Expand Down
Loading