diff --git a/go/test/endtoend/vtgate/unsharded/main_test.go b/go/test/endtoend/vtgate/unsharded/main_test.go index 007be182c76..21d83e85b86 100644 --- a/go/test/endtoend/vtgate/unsharded/main_test.go +++ b/go/test/endtoend/vtgate/unsharded/main_test.go @@ -151,7 +151,10 @@ BEGIN insert into allDefaults(id) values (128); select 128 into val from dual; END; -`} +`, + `CREATE PROCEDURE p1 (in x BIGINT) BEGIN declare y DECIMAL(14,2); set y = 4.2; END`, + `CREATE PROCEDURE p2 (in x BIGINT) BEGIN START TRANSACTION; SELECT 128 from dual; COMMIT; END`, + } ) func TestMain(m *testing.M) { diff --git a/go/vt/sqlparser/ast.go b/go/vt/sqlparser/ast.go index a81e2f276f4..013b7d8d9e2 100644 --- a/go/vt/sqlparser/ast.go +++ b/go/vt/sqlparser/ast.go @@ -640,8 +640,15 @@ type ( // TxAccessMode is an enum for Transaction Access Mode TxAccessMode int8 + // BeginType is an enum for the type of BEGIN statement. + BeginType int8 + // Begin represents a Begin statement. Begin struct { + // We need to differentiate between BEGIN and START TRANSACTION statements + // because inside a stored procedure the former is considered part of a BEGIN...END statement, + // while the latter starts a transaction. + Type BeginType TxAccessModes []TxAccessMode } diff --git a/go/vt/sqlparser/ast_equals.go b/go/vt/sqlparser/ast_equals.go index 0af55f326b2..b5e54309115 100644 --- a/go/vt/sqlparser/ast_equals.go +++ b/go/vt/sqlparser/ast_equals.go @@ -2071,7 +2071,8 @@ func (cmp *Comparator) RefOfBegin(a, b *Begin) bool { if a == nil || b == nil { return false } - return cmp.SliceOfTxAccessMode(a.TxAccessModes, b.TxAccessModes) + return a.Type == b.Type && + cmp.SliceOfTxAccessMode(a.TxAccessModes, b.TxAccessModes) } // RefOfBeginEndStatement does deep equals between the two objects. diff --git a/go/vt/sqlparser/ast_format.go b/go/vt/sqlparser/ast_format.go index 589b0d50103..0163ca59ed7 100644 --- a/go/vt/sqlparser/ast_format.go +++ b/go/vt/sqlparser/ast_format.go @@ -1184,7 +1184,7 @@ func (node *Commit) Format(buf *TrackedBuffer) { // Format formats the node. func (node *Begin) Format(buf *TrackedBuffer) { - if node.TxAccessModes == nil { + if node.Type == BeginStmt { buf.literal("begin") return } diff --git a/go/vt/sqlparser/ast_format_fast.go b/go/vt/sqlparser/ast_format_fast.go index 3cc95f7bee5..b015d12b27d 100644 --- a/go/vt/sqlparser/ast_format_fast.go +++ b/go/vt/sqlparser/ast_format_fast.go @@ -1560,7 +1560,7 @@ func (node *Commit) FormatFast(buf *TrackedBuffer) { // FormatFast formats the node. func (node *Begin) FormatFast(buf *TrackedBuffer) { - if node.TxAccessModes == nil { + if node.Type == BeginStmt { buf.WriteString("begin") return } diff --git a/go/vt/sqlparser/cached_size.go b/go/vt/sqlparser/cached_size.go index 7743a86172b..8b3edeab970 100644 --- a/go/vt/sqlparser/cached_size.go +++ b/go/vt/sqlparser/cached_size.go @@ -421,7 +421,7 @@ func (cached *Begin) CachedSize(alloc bool) int64 { } size := int64(0) if alloc { - size += int64(24) + size += int64(32) } // field TxAccessModes []vitess.io/vitess/go/vt/sqlparser.TxAccessMode { diff --git a/go/vt/sqlparser/constants.go b/go/vt/sqlparser/constants.go index 6889141c9d2..f1b614a12cf 100644 --- a/go/vt/sqlparser/constants.go +++ b/go/vt/sqlparser/constants.go @@ -995,6 +995,12 @@ const ( ReadOnly ) +// BEGIN statement type +const ( + BeginStmt BeginType = iota + StartTransactionStmt +) + // Enum Types of WKT functions const ( GeometryFromText GeomFromWktType = iota diff --git a/go/vt/sqlparser/normalizer.go b/go/vt/sqlparser/normalizer.go index e435ad0957d..e5002d02f4c 100644 --- a/go/vt/sqlparser/normalizer.go +++ b/go/vt/sqlparser/normalizer.go @@ -174,7 +174,7 @@ func (nz *normalizer) determineQueryRewriteStrategy(in Statement) { func (nz *normalizer) walkDown(node, _ SQLNode) bool { switch node := node.(type) { case *Begin, *Commit, *Rollback, *Savepoint, *SRollback, *Release, *OtherAdmin, *Analyze, - *PrepareStmt, *ExecuteStmt, *FramePoint, *ColName, TableName, *ConvertType: + *PrepareStmt, *ExecuteStmt, *FramePoint, *ColName, TableName, *ConvertType, *CreateProcedure: // These statement do not need normalizing return false case *AssignmentExpr: diff --git a/go/vt/sqlparser/normalizer_test.go b/go/vt/sqlparser/normalizer_test.go index 596df245c26..ce15e64e74d 100644 --- a/go/vt/sqlparser/normalizer_test.go +++ b/go/vt/sqlparser/normalizer_test.go @@ -441,6 +441,11 @@ func TestNormalize(t *testing.T) { "bv1": sqltypes.Int64BindVariable(1), "bv2": sqltypes.Int64BindVariable(0), }, + }, { + // Verify we don't change anything in the normalization of create procedures. + in: "CREATE PROCEDURE p2 (in x BIGINT) BEGIN declare y DECIMAL(14,2); START TRANSACTION; set y = 4.2; SELECT 128 from dual; COMMIT; END", + outstmt: "create procedure p2 (in x BIGINT) begin declare y DECIMAL(14,2); start transaction; set y = 4.2; select 128 from dual; commit; end;", + outbv: map[string]*querypb.BindVariable{}, }} parser := NewTestParser() for _, tc := range testcases { diff --git a/go/vt/sqlparser/parse_test.go b/go/vt/sqlparser/parse_test.go index 94421da8bca..4099e9f91fe 100644 --- a/go/vt/sqlparser/parse_test.go +++ b/go/vt/sqlparser/parse_test.go @@ -2208,6 +2208,8 @@ var ( }, { input: "create procedure ConditionWithSignalAndHandler() begin declare custom_error condition for sqlstate '45000'; declare exit handler for custom_error begin select 'Handled with custom condition and signal'; end; signal sqlstate '45000' set message_text = 'Custom signal triggered'; end;", output: "create procedure ConditionWithSignalAndHandler () begin declare custom_error condition for sqlstate '45000'; declare exit handler for custom_error begin select 'Handled with custom condition and signal' from dual; end; signal sqlstate '45000' set message_text = 'Custom signal triggered'; end;", + }, { + input: "create procedure t1 (in x BIGINT) begin start transaction; insert into unsharded_a values (1, 'a', 'a'); commit; end;", }, { input: "create /*vt+ strategy=online */ or replace view v as select a, b, c from t", }, { @@ -2885,8 +2887,7 @@ var ( input: "begin;", output: "begin", }, { - input: "start transaction", - output: "begin", + input: "start transaction", }, { input: "start transaction with consistent snapshot", }, { diff --git a/go/vt/sqlparser/sql.go b/go/vt/sqlparser/sql.go index 2ed1d56e6da..5df698e434d 100644 --- a/go/vt/sqlparser/sql.go +++ b/go/vt/sqlparser/sql.go @@ -17649,7 +17649,7 @@ yydefault: var yyLOCAL Statement //line sql.y:4770 { - yyLOCAL = &Begin{} + yyLOCAL = &Begin{Type: BeginStmt} } yyVAL.union = yyLOCAL case 882: @@ -17657,7 +17657,7 @@ yydefault: var yyLOCAL Statement //line sql.y:4774 { - yyLOCAL = &Begin{TxAccessModes: yyDollar[3].txAccessModesUnion()} + yyLOCAL = &Begin{Type: StartTransactionStmt, TxAccessModes: yyDollar[3].txAccessModesUnion()} } yyVAL.union = yyLOCAL case 883: diff --git a/go/vt/sqlparser/sql.y b/go/vt/sqlparser/sql.y index 371cdd6fff7..a8ed0716422 100644 --- a/go/vt/sqlparser/sql.y +++ b/go/vt/sqlparser/sql.y @@ -4768,11 +4768,11 @@ use_table_name: begin_statement: BEGIN { - $$ = &Begin{} + $$ = &Begin{Type: BeginStmt} } | START TRANSACTION tx_chacteristics_opt { - $$ = &Begin{TxAccessModes: $3} + $$ = &Begin{Type: StartTransactionStmt, TxAccessModes: $3} } tx_chacteristics_opt: diff --git a/go/vt/vtgate/planbuilder/testdata/ddl_cases.json b/go/vt/vtgate/planbuilder/testdata/ddl_cases.json index 19da42db9b8..3b84004ca99 100644 --- a/go/vt/vtgate/planbuilder/testdata/ddl_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/ddl_cases.json @@ -139,6 +139,46 @@ ] } }, + { + "comment": "Create procedure with set statement", + "query": "create procedure main.t1 (in x BIGINT) begin declare y DECIMAL(14,2); set y = 4.2; end;", + "plan": { + "Type": "DirectDDL", + "QueryType": "DDL", + "Original": "create procedure main.t1 (in x BIGINT) begin declare y DECIMAL(14,2); set y = 4.2; end;", + "Instructions": { + "OperatorType": "DDL", + "Keyspace": { + "Name": "main", + "Sharded": false + }, + "Query": "create procedure t1 (in x BIGINT) begin declare y DECIMAL(14,2); set y = 4.2; end;" + }, + "TablesUsed": [ + "main.t1" + ] + } + }, + { + "comment": "Create procedure with a transaction inside", + "query": "create procedure main.t1 (in x BIGINT) begin start transaction; insert into unsharded_a values (1, 'a', 'a'); commit; end;", + "plan": { + "Type": "DirectDDL", + "QueryType": "DDL", + "Original": "create procedure main.t1 (in x BIGINT) begin start transaction; insert into unsharded_a values (1, 'a', 'a'); commit; end;", + "Instructions": { + "OperatorType": "DDL", + "Keyspace": { + "Name": "main", + "Sharded": false + }, + "Query": "create procedure t1 (in x BIGINT) begin start transaction; insert into unsharded_a values (1, 'a', 'a'); commit; end;" + }, + "TablesUsed": [ + "main.t1" + ] + } + }, { "comment": "DDL", "query": "create table a(id int)",