From 39501a498507d5086b8c53e56652954eda511186 Mon Sep 17 00:00:00 2001 From: Manan Gupta Date: Wed, 9 Apr 2025 13:04:01 +0300 Subject: [PATCH] feat: fix split statement for create procedure to account for definers Signed-off-by: Manan Gupta --- .../endtoend/vtgate/unsharded/main_test.go | 6 ++ go/vt/sqlparser/parser.go | 63 ++++++++++++++----- go/vt/sqlparser/parser_test.go | 33 ++++++++-- 3 files changed, 80 insertions(+), 22 deletions(-) diff --git a/go/test/endtoend/vtgate/unsharded/main_test.go b/go/test/endtoend/vtgate/unsharded/main_test.go index 7fe8b864fd8..007be182c76 100644 --- a/go/test/endtoend/vtgate/unsharded/main_test.go +++ b/go/test/endtoend/vtgate/unsharded/main_test.go @@ -145,6 +145,12 @@ BEGIN insert into allDefaults(id) values (128); select 128 into val from dual; END; +`, + `CREATE DEFINER=current_user() PROCEDURE with_definer(OUT val int) +BEGIN + insert into allDefaults(id) values (128); + select 128 into val from dual; +END; `} ) diff --git a/go/vt/sqlparser/parser.go b/go/vt/sqlparser/parser.go index 0fe43db947a..7c1d00416f3 100644 --- a/go/vt/sqlparser/parser.go +++ b/go/vt/sqlparser/parser.go @@ -242,6 +242,38 @@ func (p *Parser) SplitStatement(blob string) (string, string, error) { return blob, "", nil } +var validCreatePrefixes = [][]int{ + // These are the tokens (in order) for valid "create procedure" forms. + {CREATE, PROCEDURE}, + {CREATE, DEFINER, '=', CURRENT_USER, PROCEDURE}, + {CREATE, DEFINER, '=', CURRENT_USER, '(', ')', PROCEDURE}, + {CREATE, DEFINER, '=', STRING, PROCEDURE}, + {CREATE, DEFINER, '=', STRING, AT_ID, PROCEDURE}, + {CREATE, DEFINER, '=', ID, PROCEDURE}, + {CREATE, DEFINER, '=', ID, AT_ID, PROCEDURE}, +} + +// matchesCreateProcedurePrefix checks if the given token sequence +// is a create procedure statement or not. +func matchesCreateProcedurePrefix(tokens []int) bool { + // Check each candidate sequence. 
+ for _, pattern := range validCreatePrefixes { + if len(tokens) >= len(pattern) { + match := true + for i, tok := range pattern { + if tokens[i] != tok { + match = false + break + } + } + if match { + return true + } + } + } + return false +} + // SplitStatementToPieces splits raw sql statement that may have multi sql pieces to sql pieces // returns the sql pieces blob contains; or error if sql cannot be parsed. func (p *Parser) SplitStatementToPieces(blob string) (pieces []string, err error) { @@ -263,27 +295,25 @@ func (p *Parser) SplitStatementToPieces(blob string) (pieces []string, err error var stmt string stmtBegin := 0 emptyStatement := true - var prevToken int - var isCreateProcedureStatement bool + var startTokens []int // holds the first tokens of the current statement + loop: for { tkn, _ = tokenizer.Scan() switch tkn { case ';': + // Potential end of the statement. stmt = blob[stmtBegin : tokenizer.Pos-1] - // We now try to parse the statement to see if its complete. - // If it is a create procedure, then it might not be complete, and we - // would need to scan to the next ; - if isCreateProcedureStatement && p.IsStatementIncomplete(stmt) { + // If it's a create procedure statement and is incomplete, skip appending. + if matchesCreateProcedurePrefix(startTokens) && p.IsStatementIncomplete(stmt) { continue } if !emptyStatement { pieces = append(pieces, stmt) // We can now reset the variables for the next statement. - // It starts off as an empty statement and we don't know if it is - // a create procedure statement yet. + // It starts off as an empty statement. emptyStatement = true - isCreateProcedureStatement = false + startTokens = startTokens[:0] // clear token slice } stmtBegin = tokenizer.Pos case 0, eofChar: @@ -296,16 +326,15 @@ loop: } break loop case COMMENT: - // We want to ignore comments and not store them in the prevToken for knowing - // if the current statement is a create procedure statement. 
+ // Skip comments entirely without altering the token list. + continue - case PROCEDURE: - if prevToken == CREATE { - isCreateProcedureStatement = true - } - fallthrough default: - prevToken = tkn + // If we're at the very start of a statement, or we haven't filled out enough tokens + // for our valid prefix match (the cap of 10 exceeds our longest valid sequence, which is 7 tokens), + // accumulate the token. + if len(startTokens) < 10 { + startTokens = append(startTokens, tkn) + } emptyStatement = false } } diff --git a/go/vt/sqlparser/parser_test.go b/go/vt/sqlparser/parser_test.go index 107bb5ae34c..3e86b10ba0e 100644 --- a/go/vt/sqlparser/parser_test.go +++ b/go/vt/sqlparser/parser_test.go @@ -137,12 +137,35 @@ func TestSplitStatementToPieces(t *testing.T) { // Test that we don't split on semicolons inside create procedure calls. input: "select * from t1;create procedure p1 (in country CHAR(3), out cities INT) begin select count(*) from x where d = e; end;select * from t2", lenWanted: 3, + }, { + // Create procedure with comments. + input: "select * from t1; /* comment1 */ create /* comment2 */ procedure /* comment3 */ p1 (in country CHAR(3), out cities INT) begin select count(*) from x where d = e; end;select * from t2", + lenWanted: 3, + }, { + // Create procedure with definer current_user. + input: "create DEFINER=CURRENT_USER procedure p1 (in country CHAR(3)) begin declare abc DECIMAL(14,2); DECLARE def DECIMAL(14,2); end", + lenWanted: 1, + }, { + // Create procedure with definer current_user(). + input: "create DEFINER=CURRENT_USER() procedure p1 (in country CHAR(3)) begin declare abc DECIMAL(14,2); DECLARE def DECIMAL(14,2); end", + lenWanted: 1, + }, { + // Create procedure with definer string. + input: "create DEFINER='root' procedure p1 (in country CHAR(3)) begin declare abc DECIMAL(14,2); DECLARE def DECIMAL(14,2); end", + lenWanted: 1, + }, { + // Create procedure with definer string at_id. 
+ input: "create DEFINER='root'@localhost procedure p1 (in country CHAR(3)) begin declare abc DECIMAL(14,2); DECLARE def DECIMAL(14,2); end", + lenWanted: 1, + }, { + // Create procedure with definer id. + input: "create DEFINER=`root` procedure p1 (in country CHAR(3)) begin declare abc DECIMAL(14,2); DECLARE def DECIMAL(14,2); end", + lenWanted: 1, + }, { + // Create procedure with definer id at_id. + input: "create DEFINER=`root`@`localhost` procedure p1 (in country CHAR(3)) begin declare abc DECIMAL(14,2); DECLARE def DECIMAL(14,2); end", + lenWanted: 1, }, - { - // Create procedure with comments. - input: "select * from t1; /* comment1 */ create /* comment2 */ procedure /* comment3 */ p1 (in country CHAR(3), out cities INT) begin select count(*) from x where d = e; end;select * from t2", - lenWanted: 3, - }, } parser := NewTestParser()