diff --git a/go/test/endtoend/vtgate/queries/transaction_timeout/schema.sql b/go/test/endtoend/vtgate/queries/transaction_timeout/schema.sql new file mode 100644 index 00000000000..ceac0c07e6d --- /dev/null +++ b/go/test/endtoend/vtgate/queries/transaction_timeout/schema.sql @@ -0,0 +1,5 @@ +create table if not exists t1( + id1 bigint, + id2 bigint, + primary key(id1) +) Engine=InnoDB; \ No newline at end of file diff --git a/go/test/endtoend/vtgate/queries/transaction_timeout/transaction_timeout_test.go b/go/test/endtoend/vtgate/queries/transaction_timeout/transaction_timeout_test.go new file mode 100644 index 00000000000..1b14229f7d3 --- /dev/null +++ b/go/test/endtoend/vtgate/queries/transaction_timeout/transaction_timeout_test.go @@ -0,0 +1,146 @@ +/* +Copyright 2025 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package transactiontimeout + +import ( + "context" + _ "embed" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "vitess.io/vitess/go/mysql" + "vitess.io/vitess/go/test/endtoend/cluster" + "vitess.io/vitess/go/test/endtoend/utils" +) + +var ( + clusterInstance *cluster.LocalProcessCluster + vtParams mysql.ConnParams + uks = "uks" + cell = "test_misc" + + //go:embed uschema.sql + uschemaSQL string +) + +func createCluster(t *testing.T, vttabletArgs ...string) func() { + clusterInstance = cluster.NewCluster(cell, "localhost") + + err := clusterInstance.StartTopo() + require.NoError(t, err) + + clusterInstance.VtTabletExtraArgs = append(clusterInstance.VtTabletExtraArgs, vttabletArgs...) + + ukeyspace := &cluster.Keyspace{ + Name: uks, + SchemaSQL: uschemaSQL, + } + err = clusterInstance.StartUnshardedKeyspace(*ukeyspace, 0, false) + require.NoError(t, err) + + err = clusterInstance.StartVtgate() + require.NoError(t, err) + + vtParams = clusterInstance.GetVTParams(uks) + + _, closer, err := utils.NewMySQL(clusterInstance, uks, uschemaSQL) + require.NoError(t, err) + + return func() { + clusterInstance.Teardown() + closer() + } +} + +func TestTransactionTimeout(t *testing.T) { + utils.SkipIfBinaryIsBelowVersion(t, 21, "vttablet") + + // Start cluster with no vtgate or vttablet timeouts + teardown := createCluster(t) + defer teardown() + + conn, err := mysql.Connect(context.Background(), &vtParams) + require.NoError(t, err) + defer conn.Close() + + // No timeout set, transaction shouldn't timeout + utils.Exec(t, conn, "begin") + utils.Exec(t, conn, "insert into uks.unsharded(id1) values (1),(2),(3),(4),(sleep(0.5))") + utils.Exec(t, conn, "commit") + + // Set session transaction timeout + utils.Exec(t, conn, "set transaction_timeout=100") + + // Sleeping outside of query will allow the transaction killer to kill the transaction + utils.Exec(t, conn, "begin") + utils.Exec(t, conn, "insert into uks.unsharded(id1) values (1),(2),(3),(4),(5)") + time.Sleep(3 * time.Second) + _, err = utils.ExecAllowError(t, conn, "insert into uks.unsharded(id1) values (1),(2),(3),(4),(5)") + require.ErrorContains(t, err, "Aborted") + + // Sleeping in MySQL will cause a context timeout instead (different error) + utils.Exec(t, conn, "begin") + utils.Exec(t, conn, "insert into uks.unsharded(id1) values (1),(2),(3),(4),(5)") + _, err = utils.ExecAllowError(t, conn, "insert into uks.unsharded(id1) values (1),(2),(3),(4),(sleep(0.5))") + require.ErrorContains(t, err, "Query execution was interrupted") + + // Get new connection + conn, err = mysql.Connect(context.Background(), &vtParams) + require.NoError(t, err) + + // Set session transaction timeout to 0 + utils.Exec(t, conn, "set transaction_timeout=0") + + // Should time out using tablet transaction timeout + utils.Exec(t, conn, "begin") + utils.Exec(t, conn, "insert into uks.unsharded(id1) values (1),(2),(3),(4),(5)") + utils.Exec(t, conn, "insert into uks.unsharded(id1) values (1),(2),(3),(4),(sleep(2))") + utils.Exec(t, conn, "commit") +} + +func TestSmallerTimeout(t *testing.T) { + utils.SkipIfBinaryIsBelowVersion(t, 21, "vttablet") + + // Start vttablet with a transaction timeout + teardown := createCluster(t, "--queryserver-config-transaction-timeout", "1s") + defer teardown() + + conn, err := mysql.Connect(context.Background(), &vtParams) + require.NoError(t, err) + + // Set session transaction timeout larger than tablet transaction timeout + utils.Exec(t, conn, "set transaction_timeout=2000") + + // Transaction should get killed with lower timeout + utils.Exec(t, conn, "begin") + utils.Exec(t, conn, "insert into uks.unsharded(id1) values (1),(2),(3),(4),(5)") + time.Sleep(1500 * time.Millisecond) + _, err = utils.ExecAllowError(t, conn, "insert into uks.unsharded(id1) values (1),(2),(3),(4),(5)") + require.ErrorContains(t, err, "Aborted") + + // Set session transaction timeout smaller than tablet transaction timeout + utils.Exec(t, conn, "set transaction_timeout=250") + + // Session timeout should be used this time + utils.Exec(t, conn, "begin") + utils.Exec(t, conn, "insert into uks.unsharded(id1) values (1),(2),(3),(4),(5)") + time.Sleep(500 * time.Millisecond) + _, err = utils.ExecAllowError(t, conn, "insert into uks.unsharded(id1) values (1),(2),(3),(4),(5)") + require.ErrorContains(t, err, "Aborted") +} diff --git a/go/test/endtoend/vtgate/queries/transaction_timeout/uschema.sql b/go/test/endtoend/vtgate/queries/transaction_timeout/uschema.sql new file mode 100644 index 00000000000..6ba158b134e --- /dev/null +++ b/go/test/endtoend/vtgate/queries/transaction_timeout/uschema.sql @@ -0,0 +1,5 @@ +create table unsharded( + id1 bigint, + id2 bigint, + key(id1) +) Engine=InnoDB; \ No newline at end of file diff --git a/go/vt/proto/query/query.pb.go b/go/vt/proto/query/query.pb.go index 73440c8cbcd..6f86f5ad99d 100644 --- a/go/vt/proto/query/query.pb.go +++ b/go/vt/proto/query/query.pb.go @@ -1402,8 +1402,10 @@ type ExecuteOptions struct { FetchLastInsertId bool `protobuf:"varint,18,opt,name=fetch_last_insert_id,json=fetchLastInsertId,proto3" json:"fetch_last_insert_id,omitempty"` // in_dml_execution indicates that the query is being executed as part of a DML execution. InDmlExecution bool `protobuf:"varint,19,opt,name=in_dml_execution,json=inDmlExecution,proto3" json:"in_dml_execution,omitempty"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache + // transaction_timeout specifies the transaction timeout in milliseconds. If not set, the default timeout is used. + TransactionTimeout *int64 `protobuf:"varint,20,opt,name=transaction_timeout,json=transactionTimeout,proto3,oneof" json:"transaction_timeout,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache } func (x *ExecuteOptions) Reset() { @@ -1550,6 +1552,13 @@ func (x *ExecuteOptions) GetInDmlExecution() bool { return false } +func (x *ExecuteOptions) GetTransactionTimeout() int64 { + if x != nil && x.TransactionTimeout != nil { + return *x.TransactionTimeout + } + return 0 +} + type isExecuteOptions_Timeout interface { isExecuteOptions_Timeout() } @@ -5817,7 +5826,7 @@ const file_query_proto_rawDesc = "" + "\x0ebind_variables\x18\x02 \x03(\v2$.query.BoundQuery.BindVariablesEntryR\rbindVariables\x1aU\n" + "\x12BindVariablesEntry\x12\x10\n" + "\x03key\x18\x01 \x01(\tR\x03key\x12)\n" + - "\x05value\x18\x02 \x01(\v2\x13.query.BindVariableR\x05value:\x028\x01\"\xb5\f\n" + + "\x05value\x18\x02 \x01(\v2\x13.query.BindVariableR\x05value:\x028\x01\"\x83\r\n" + "\x0eExecuteOptions\x12M\n" + "\x0fincluded_fields\x18\x04 \x01(\x0e2$.query.ExecuteOptions.IncludedFieldsR\x0eincludedFields\x12*\n" + "\x11client_found_rows\x18\x05 \x01(\bR\x0fclientFoundRows\x12:\n" + @@ -5834,7 +5843,8 @@ const file_query_proto_rawDesc = "" + "\bpriority\x18\x10 \x01(\tR\bpriority\x125\n" + "\x15authoritative_timeout\x18\x11 \x01(\x03H\x00R\x14authoritativeTimeout\x12/\n" + "\x14fetch_last_insert_id\x18\x12 \x01(\bR\x11fetchLastInsertId\x12(\n" + - "\x10in_dml_execution\x18\x13 \x01(\bR\x0einDmlExecution\";\n" + + "\x10in_dml_execution\x18\x13 \x01(\bR\x0einDmlExecution\x124\n" + + "\x13transaction_timeout\x18\x14 \x01(\x03H\x01R\x12transactionTimeout\x88\x01\x01\";\n" + "\x0eIncludedFields\x12\x11\n" + "\rTYPE_AND_NAME\x10\x00\x12\r\n" + "\tTYPE_ONLY\x10\x01\x12\a\n" + @@ -5873,7 +5883,8 @@ const file_query_proto_rawDesc = "" + "\n" + "READ_WRITE\x10\x01\x12\r\n" + "\tREAD_ONLY\x10\x02B\t\n" + - "\atimeoutJ\x04\b\x01\x10\x02J\x04\b\x02\x10\x03J\x04\b\x03\x10\x04\"\xb8\x02\n" + + "\atimeoutB\x16\n" + + "\x14_transaction_timeoutJ\x04\b\x01\x10\x02J\x04\b\x02\x10\x03J\x04\b\x03\x10\x04\"\xb8\x02\n" + "\x05Field\x12\x12\n" + "\x04name\x18\x01 \x01(\tR\x04name\x12\x1f\n" + "\x04type\x18\x02 \x01(\x0e2\v.query.TypeR\x04type\x12\x14\n" + diff --git a/go/vt/proto/query/query_vtproto.pb.go b/go/vt/proto/query/query_vtproto.pb.go index b2a752a3ed4..4a735071e57 100644 --- a/go/vt/proto/query/query_vtproto.pb.go +++ b/go/vt/proto/query/query_vtproto.pb.go @@ -188,6 +188,10 @@ func (m *ExecuteOptions) CloneVT() *ExecuteOptions { CloneVT() isExecuteOptions_Timeout }).CloneVT() } + if rhs := m.TransactionTimeout; rhs != nil { + tmpVal := *rhs + r.TransactionTimeout = &tmpVal + } if len(m.unknownFields) > 0 { r.unknownFields = make([]byte, len(m.unknownFields)) copy(r.unknownFields, m.unknownFields) @@ -1896,6 +1900,13 @@ func (m *ExecuteOptions) MarshalToSizedBufferVT(dAtA []byte) (int, error) { } i -= size } + if m.TransactionTimeout != nil { + i = protohelpers.EncodeVarint(dAtA, i, uint64(*m.TransactionTimeout)) + i-- + dAtA[i] = 0x1 + i-- + dAtA[i] = 0xa0 + } if m.InDmlExecution { i-- if m.InDmlExecution { @@ -6212,6 +6223,9 @@ func (m *ExecuteOptions) SizeVT() (n int) { if m.InDmlExecution { n += 3 } + if m.TransactionTimeout != nil { + n += 2 + protohelpers.SizeOfVarint(uint64(*m.TransactionTimeout)) + } n += len(m.unknownFields) return n } @@ -9000,6 +9014,26 @@ func (m *ExecuteOptions) UnmarshalVT(dAtA []byte) error { } } m.InDmlExecution = bool(v != 0) + case 20: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field TransactionTimeout", wireType) + } + var v int64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return protohelpers.ErrIntOverflow + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + v |= int64(b&0x7F) << shift + if b < 0x80 { + break + } + } + m.TransactionTimeout = &v default: iNdEx = preIndex skippy, err := protohelpers.Skip(dAtA[iNdEx:]) diff --git a/go/vt/sqlparser/normalizer.go b/go/vt/sqlparser/normalizer.go index e5f1b9bc040..a985b99f6a7 100644 --- a/go/vt/sqlparser/normalizer.go +++ b/go/vt/sqlparser/normalizer.go @@ -743,6 +743,7 @@ func (nz *normalizer) sysVarRewrite(cursor *Cursor, node *Variable) { sysvars.Version.Name, sysvars.VersionComment.Name, sysvars.QueryTimeout.Name, + sysvars.TransactionTimeout.Name, sysvars.Workload.Name: found = true } diff --git a/go/vt/sqlparser/normalizer_test.go b/go/vt/sqlparser/normalizer_test.go index 1a59ae57c13..0b426ff148c 100644 --- a/go/vt/sqlparser/normalizer_test.go +++ b/go/vt/sqlparser/normalizer_test.go @@ -587,7 +587,7 @@ type myTestCase struct { ddlStrategy, migrationContext, sessionUUID, sessionEnableSystemSettings bool udv int autocommit, foreignKeyChecks, clientFoundRows, skipQueryPlanCache, socket, queryTimeout bool - sqlSelectLimit, transactionMode, workload, version, versionComment bool + sqlSelectLimit, transactionMode, workload, version, versionComment, transactionTimeout bool } func TestRewrites(in *testing.T) { @@ -603,6 +603,10 @@ func TestRewrites(in *testing.T) { in: "SELECT @@query_timeout", expected: "SELECT :__vtquery_timeout as `@@query_timeout`", queryTimeout: true, + }, { + in: "SELECT @@transaction_timeout", + expected: "SELECT :__vttransaction_timeout as `@@transaction_timeout`", + transactionTimeout: true, }, { in: "SELECT @@version_comment", expected: "SELECT :__vtversion_comment as `@@version_comment`", @@ -862,6 +866,7 @@ func TestRewrites(in *testing.T) { sessTrackGTID: true, socket: true, queryTimeout: true, + transactionTimeout: true, }, { in: "SHOW GLOBAL VARIABLES", expected: "SHOW GLOBAL VARIABLES", @@ -883,6 +888,7 @@ func TestRewrites(in *testing.T) { sessTrackGTID: true, socket: true, queryTimeout: true, + transactionTimeout: true, }} parser := NewTestParser() for _, tc := range tests { @@ -924,6 +930,7 @@ func TestRewrites(in *testing.T) { assert.Equal(tc.transactionMode, result.NeedsSysVar(sysvars.TransactionMode.Name), "should need :__vttransactionMode") assert.Equal(tc.workload, result.NeedsSysVar(sysvars.Workload.Name), "should need :__vtworkload") assert.Equal(tc.queryTimeout, result.NeedsSysVar(sysvars.QueryTimeout.Name), "should need :__vtquery_timeout") + assert.Equal(tc.transactionTimeout, result.NeedsSysVar(sysvars.TransactionTimeout.Name), "should need :__vttransaction_timeout") assert.Equal(tc.ddlStrategy, result.NeedsSysVar(sysvars.DDLStrategy.Name), "should need ddlStrategy") assert.Equal(tc.migrationContext, result.NeedsSysVar(sysvars.MigrationContext.Name), "should need migrationContext") assert.Equal(tc.sessionUUID, result.NeedsSysVar(sysvars.SessionUUID.Name), "should need sessionUUID") diff --git a/go/vt/sysvars/sysvars.go b/go/vt/sysvars/sysvars.go index 17e495ee9ca..2af5f820abc 100644 --- a/go/vt/sysvars/sysvars.go +++ b/go/vt/sysvars/sysvars.go @@ -73,6 +73,7 @@ var ( TxReadOnly = SystemVariable{Name: "tx_read_only", IsBoolean: true, Default: off} Workload = SystemVariable{Name: "workload", IdentifierAsString: true} QueryTimeout = SystemVariable{Name: "query_timeout"} + TransactionTimeout = SystemVariable{Name: "transaction_timeout"} // Online DDL DDLStrategy = SystemVariable{Name: "ddl_strategy", IdentifierAsString: true} @@ -109,6 +110,7 @@ var ( ReadAfterWriteTimeOut, SessionTrackGTIDs, QueryTimeout, + TransactionTimeout, } ReadOnly = []SystemVariable{ diff --git a/go/vt/vtgate/engine/fake_vcursor_test.go b/go/vt/vtgate/engine/fake_vcursor_test.go index 5ff6b5f04c8..9a95dce893e 100644 --- a/go/vt/vtgate/engine/fake_vcursor_test.go +++ b/go/vt/vtgate/engine/fake_vcursor_test.go @@ -321,6 +321,8 @@ func (t *noopVCursor) SetClientFoundRows(context.Context, bool) error { func (t *noopVCursor) SetQueryTimeout(maxExecutionTime int64) { } +func (t *noopVCursor) SetTransactionTimeout(timeout int64) {} + func (t *noopVCursor) SetSkipQueryPlanCache(context.Context, bool) error { panic("implement me") } @@ -414,6 +416,7 @@ func (t *noopVCursor) DisableLogging() {} func (t *noopVCursor) GetVExplainLogs() []ExecuteEntry { return nil } + func (t *noopVCursor) GetLogs() ([]ExecuteEntry, error) { return nil, nil } diff --git a/go/vt/vtgate/engine/primitive.go b/go/vt/vtgate/engine/primitive.go index ce77769d38b..f017045ef99 100644 --- a/go/vt/vtgate/engine/primitive.go +++ b/go/vt/vtgate/engine/primitive.go @@ -231,6 +231,9 @@ type ( // SetQueryTimeout sets the query timeout SetQueryTimeout(queryTimeout int64) + // SetTransactionTimeout sets the transaction timeout. + SetTransactionTimeout(transactionTimeout int64) + // InTransaction returns true if the session has already opened transaction or // will start a transaction on the query execution. InTransaction() bool diff --git a/go/vt/vtgate/engine/set.go b/go/vt/vtgate/engine/set.go index f2087000d49..7924a2b76aa 100644 --- a/go/vt/vtgate/engine/set.go +++ b/go/vt/vtgate/engine/set.go @@ -476,6 +476,12 @@ func (svss *SysVarSetAware) Execute(ctx context.Context, vcursor VCursor, env *e return err } vcursor.Session().SetQueryTimeout(queryTimeout) + case sysvars.TransactionTimeout.Name: + transactionTimeout, err := svss.evalAsInt64(env, vcursor) + if err != nil { + return err + } + vcursor.Session().SetTransactionTimeout(transactionTimeout) case sysvars.SessionEnableSystemSettings.Name: err = svss.setBoolSysVar(ctx, env, vcursor.Session().SetSessionEnableSystemSettings) case sysvars.Charset.Name, sysvars.Names.Name: diff --git a/go/vt/vtgate/executor.go b/go/vt/vtgate/executor.go index 1680f4e9235..84bdef0c1ff 100644 --- a/go/vt/vtgate/executor.go +++ b/go/vt/vtgate/executor.go @@ -488,6 +488,12 @@ func (e *Executor) addNeededBindVars(vcursor *econtext.VCursorImpl, bindVarNeeds bindVars[key] = sqltypes.BoolBindVariable(session.Autocommit) case sysvars.QueryTimeout.Name: bindVars[key] = sqltypes.Int64BindVariable(session.GetQueryTimeout()) + case sysvars.TransactionTimeout.Name: + var v int64 + ifOptionsExist(session, func(options *querypb.ExecuteOptions) { + v = options.GetTransactionTimeout() + }) + bindVars[key] = sqltypes.Int64BindVariable(v) case sysvars.ClientFoundRows.Name: var v bool ifOptionsExist(session, func(options *querypb.ExecuteOptions) { diff --git a/go/vt/vtgate/executor_select_test.go b/go/vt/vtgate/executor_select_test.go index 4310e7c98eb..bd873315f92 100644 --- a/go/vt/vtgate/executor_select_test.go +++ b/go/vt/vtgate/executor_select_test.go @@ -805,7 +805,8 @@ func TestSelectSystemVariables(t *testing.T) { sql := "select @@autocommit, @@client_found_rows, @@skip_query_plan_cache, @@enable_system_settings, " + "@@sql_select_limit, @@transaction_mode, @@workload, @@read_after_write_gtid, " + - "@@read_after_write_timeout, @@session_track_gtids, @@ddl_strategy, @@migration_context, @@socket, @@query_timeout" + "@@read_after_write_timeout, @@session_track_gtids, @@ddl_strategy, @@migration_context, @@socket, @@query_timeout, " + + "@@transaction_timeout" result, err := executorExec(ctx, executor, session, sql, map[string]*querypb.BindVariable{}) wantResult := &sqltypes.Result{ @@ -824,6 +825,7 @@ func TestSelectSystemVariables(t *testing.T) { {Name: "@@migration_context", Type: sqltypes.VarChar, Charset: uint32(collations.MySQL8().DefaultConnectionCharset())}, {Name: "@@socket", Type: sqltypes.VarChar, Charset: uint32(collations.MySQL8().DefaultConnectionCharset())}, {Name: "@@query_timeout", Type: sqltypes.Int64, Charset: collations.CollationBinaryID, Flags: uint32(querypb.MySqlFlag_NUM_FLAG)}, + {Name: "@@transaction_timeout", Type: sqltypes.Int64, Charset: collations.CollationBinaryID, Flags: uint32(querypb.MySqlFlag_NUM_FLAG)}, }, Rows: [][]sqltypes.Value{{ // the following are the uninitialised session values @@ -842,6 +844,7 @@ func TestSelectSystemVariables(t *testing.T) { sqltypes.NewVarChar(""), sqltypes.NewVarChar(""), sqltypes.NewInt64(0), + sqltypes.NewInt64(0), }}, } require.NoError(t, err) diff --git a/go/vt/vtgate/executor_set_test.go b/go/vt/vtgate/executor_set_test.go index c48232c5dd8..88160492445 100644 --- a/go/vt/vtgate/executor_set_test.go +++ b/go/vt/vtgate/executor_set_test.go @@ -21,6 +21,7 @@ import ( "testing" "vitess.io/vitess/go/mysql/sqlerror" + "vitess.io/vitess/go/ptr" querypb "vitess.io/vitess/go/vt/proto/query" econtext "vitess.io/vitess/go/vt/vtgate/executorcontext" @@ -264,6 +265,12 @@ func TestExecutorSet(t *testing.T) { }, { in: "set @@query_timeout = 50, query_timeout = 75", out: &vtgatepb.Session{Autocommit: true, QueryTimeout: 75}, + }, { + in: "set @@transaction_timeout = 50", + out: &vtgatepb.Session{Autocommit: true, Options: &querypb.ExecuteOptions{TransactionTimeout: ptr.Of(int64(50))}}, + }, { + in: "set @@transaction_timeout = 50, transaction_timeout = 75", + out: &vtgatepb.Session{Autocommit: true, Options: &querypb.ExecuteOptions{TransactionTimeout: ptr.Of(int64(75))}}, }} for i, tcase := range testcases { t.Run(fmt.Sprintf("%d-%s", i, tcase.in), func(t *testing.T) { diff --git a/go/vt/vtgate/executorcontext/vcursor_impl.go b/go/vt/vtgate/executorcontext/vcursor_impl.go index 6fb8568f668..cb09d7742c4 100644 --- a/go/vt/vtgate/executorcontext/vcursor_impl.go +++ b/go/vt/vtgate/executorcontext/vcursor_impl.go @@ -173,6 +173,7 @@ type ( vm VSchemaOperator semTable *semantics.SemTable queryTimeout time.Duration + transactionTimeout time.Duration warnings []*querypb.QueryWarning // any warnings that are accumulated during the planning phase are stored here @@ -1152,6 +1153,11 @@ func (vc *VCursorImpl) SetQueryTimeout(maxExecutionTime int64) { vc.SafeSession.QueryTimeout = maxExecutionTime } +// SetTransactionTimeout implements the SessionActions interface +func (vc *VCursorImpl) SetTransactionTimeout(transactionTimeout int64) { + vc.SafeSession.GetOrCreateOptions().TransactionTimeout = &transactionTimeout +} + // SetClientFoundRows implements the SessionActions interface func (vc *VCursorImpl) SetClientFoundRows(_ context.Context, clientFoundRows bool) error { vc.SafeSession.GetOrCreateOptions().ClientFoundRows = clientFoundRows diff --git a/go/vt/vttablet/tabletserver/stateful_connection_pool.go b/go/vt/vttablet/tabletserver/stateful_connection_pool.go index 88fbc56fd0c..b5bd1d40332 100644 --- a/go/vt/vttablet/tabletserver/stateful_connection_pool.go +++ b/go/vt/vttablet/tabletserver/stateful_connection_pool.go @@ -200,7 +200,8 @@ func (sf *StatefulConnectionPool) NewConn(ctx context.Context, options *querypb. enforceTimeout: options.GetWorkload() != querypb.ExecuteOptions_DBA, } // This will set both the timeout and initialize the expiryTime. - sfConn.SetTimeout(sf.env.Config().TxTimeoutForWorkload(options.GetWorkload())) + timeout := getTransactionTimeout(options, sf.env.Config(), options.GetWorkload()) + sfConn.SetTimeout(timeout) err = sf.active.Register(sfConn.ConnID, sfConn) if err != nil { diff --git a/go/vt/vttablet/tabletserver/tabletserver.go b/go/vt/vttablet/tabletserver/tabletserver.go index 0a91b0dc7d6..a03d0c1b8aa 100644 --- a/go/vt/vttablet/tabletserver/tabletserver.go +++ b/go/vt/vttablet/tabletserver/tabletserver.go @@ -248,7 +248,7 @@ func (tsv *TabletServer) loadQueryTimeoutWithTxAndOptions(txID int64, options *q } // fetch the transaction timeout. - txTimeout := tsv.config.TxTimeoutForWorkload(querypb.ExecuteOptions_OLTP) + txTimeout := getTransactionTimeout(options, tsv.config, querypb.ExecuteOptions_OLTP) // Use the smaller of the two values (0 means infinity). return smallerTimeout(timeout, txTimeout) @@ -996,7 +996,7 @@ func (tsv *TabletServer) streamExecute(ctx context.Context, target *querypb.Targ allowOnShutdown = true // Use the transaction timeout. StreamExecute calls happen for OLAP only, // so we can directly fetch the OLAP TX timeout. - timeout = tsv.config.TxTimeoutForWorkload(querypb.ExecuteOptions_OLAP) + timeout = getTransactionTimeout(options, tsv.config, querypb.ExecuteOptions_OLAP) } return tsv.execRequest( diff --git a/go/vt/vttablet/tabletserver/tx_pool.go b/go/vt/vttablet/tabletserver/tx_pool.go index 302a3d41050..cca44056608 100644 --- a/go/vt/vttablet/tabletserver/tx_pool.go +++ b/go/vt/vttablet/tabletserver/tx_pool.go @@ -241,7 +241,7 @@ func (tp *TxPool) Begin(ctx context.Context, options *querypb.ExecuteOptions, re return nil, "", "", vterrors.Errorf(vtrpcpb.Code_ABORTED, "transaction %d: %v", reservedID, err) } // Update conn timeout. - timeout := tp.env.Config().TxTimeoutForWorkload(options.GetWorkload()) + timeout := getTransactionTimeout(options, tp.env.Config(), options.GetWorkload()) conn.SetTimeout(timeout) } else { immediateCaller := callerid.ImmediateCallerIDFromContext(ctx) @@ -275,6 +275,19 @@ func (tp *TxPool) Begin(ctx context.Context, options *querypb.ExecuteOptions, re return conn, sql, sessionStateChanges, nil } +// getTransactionTimeout gets the smaller transaction timeout of either the timeout set in the options +// or the one configured for the current workload. +func getTransactionTimeout(options *querypb.ExecuteOptions, config *tabletenv.TabletConfig, workload querypb.ExecuteOptions_Workload) time.Duration { + workloadTimeout := config.TxTimeoutForWorkload(workload) + + if options != nil && options.TransactionTimeout != nil { + sessionTimeout := time.Duration(options.GetTransactionTimeout()) * time.Millisecond + return smallerTimeout(sessionTimeout, workloadTimeout) + } + + return workloadTimeout +} + func (tp *TxPool) begin(ctx context.Context, options *querypb.ExecuteOptions, readOnly bool, conn *StatefulConnection) (string, string, error) { immediateCaller := callerid.ImmediateCallerIDFromContext(ctx) effectiveCaller := callerid.EffectiveCallerIDFromContext(ctx) diff --git a/go/vt/vttablet/tabletserver/tx_pool_test.go b/go/vt/vttablet/tabletserver/tx_pool_test.go index 22810d4c422..5a6371fb5c0 100644 --- a/go/vt/vttablet/tabletserver/tx_pool_test.go +++ b/go/vt/vttablet/tabletserver/tx_pool_test.go @@ -23,6 +23,7 @@ import ( "testing" "time" + "vitess.io/vitess/go/ptr" "vitess.io/vitess/go/vt/callerid" "vitess.io/vitess/go/vt/dbconfigs" vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" @@ -787,6 +788,37 @@ func TestTxPoolBeginStatements(t *testing.T) { } } +func TestGetTransactionTimeout(t *testing.T) { + _, txPool, _, closer := setup(t) + defer closer() + + txPool.env.Config().Oltp.TxTimeout = 5 * time.Millisecond + + // No options should use workload timeout + timeout := getTransactionTimeout(nil, txPool.env.Config(), querypb.ExecuteOptions_OLTP) + require.Equal(t, 5*time.Millisecond, timeout) + + // Options with no timeout should use workload timeout + options := &querypb.ExecuteOptions{Workload: querypb.ExecuteOptions_OLTP} + timeout = getTransactionTimeout(options, txPool.env.Config(), options.Workload) + require.Equal(t, 5*time.Millisecond, timeout) + + // Options with larger timeout should use smaller workload timeout + options.TransactionTimeout = ptr.Of(int64(10)) // ms + timeout = getTransactionTimeout(options, txPool.env.Config(), options.Workload) + require.Equal(t, 5*time.Millisecond, timeout) + + // Options with smaller timeout should use smaller session timeout + options.TransactionTimeout = ptr.Of(int64(3)) // ms + timeout = getTransactionTimeout(options, txPool.env.Config(), options.Workload) + require.Equal(t, 3*time.Millisecond, timeout) + + // Options with explicit zero timeout should use larger workload timeout + options.TransactionTimeout = ptr.Of(int64(0)) + timeout = getTransactionTimeout(options, txPool.env.Config(), options.Workload) + require.Equal(t, 5*time.Millisecond, timeout) +} + func newTxPool() (*TxPool, *fakeLimiter) { return newTxPoolWithEnv(newEnv("TabletServerTest")) } diff --git a/proto/query.proto b/proto/query.proto index b4ccb61af59..6c15ff22bfb 100644 --- a/proto/query.proto +++ b/proto/query.proto @@ -376,6 +376,9 @@ message ExecuteOptions { // in_dml_execution indicates that the query is being executed as part of a DML execution. bool in_dml_execution = 19; + + // transaction_timeout specifies the transaction timeout in milliseconds. If not set, the default timeout is used. + optional int64 transaction_timeout = 20; } // Field describes a single column returned by a query diff --git a/web/vtadmin/src/proto/vtadmin.d.ts b/web/vtadmin/src/proto/vtadmin.d.ts index 67e8dd462da..00cd6142e40 100644 --- a/web/vtadmin/src/proto/vtadmin.d.ts +++ b/web/vtadmin/src/proto/vtadmin.d.ts @@ -41676,6 +41676,9 @@ export namespace query { /** ExecuteOptions in_dml_execution */ in_dml_execution?: (boolean|null); + + /** ExecuteOptions transaction_timeout */ + transaction_timeout?: (number|Long|null); } /** Represents an ExecuteOptions. */ @@ -41732,6 +41735,9 @@ export namespace query { /** ExecuteOptions in_dml_execution. */ public in_dml_execution: boolean; + /** ExecuteOptions transaction_timeout. */ + public transaction_timeout?: (number|Long|null); + /** ExecuteOptions timeout. */ public timeout?: "authoritative_timeout"; diff --git a/web/vtadmin/src/proto/vtadmin.js b/web/vtadmin/src/proto/vtadmin.js index 86d78c1c718..3e511887a80 100644 --- a/web/vtadmin/src/proto/vtadmin.js +++ b/web/vtadmin/src/proto/vtadmin.js @@ -99385,6 +99385,7 @@ export const query = $root.query = (() => { * @property {number|Long|null} [authoritative_timeout] ExecuteOptions authoritative_timeout * @property {boolean|null} [fetch_last_insert_id] ExecuteOptions fetch_last_insert_id * @property {boolean|null} [in_dml_execution] ExecuteOptions in_dml_execution + * @property {number|Long|null} [transaction_timeout] ExecuteOptions transaction_timeout */ /** @@ -99523,6 +99524,14 @@ export const query = $root.query = (() => { */ ExecuteOptions.prototype.in_dml_execution = false; + /** + * ExecuteOptions transaction_timeout. + * @member {number|Long|null|undefined} transaction_timeout + * @memberof query.ExecuteOptions + * @instance + */ + ExecuteOptions.prototype.transaction_timeout = null; + // OneOf field names bound to virtual getters and setters let $oneOfFields; @@ -99537,6 +99546,12 @@ export const query = $root.query = (() => { set: $util.oneOfSetter($oneOfFields) }); + // Virtual OneOf for proto3 optional field + Object.defineProperty(ExecuteOptions.prototype, "_transaction_timeout", { + get: $util.oneOfGetter($oneOfFields = ["transaction_timeout"]), + set: $util.oneOfSetter($oneOfFields) + }); + /** * Creates a new ExecuteOptions instance using the specified properties. * @function create @@ -99595,6 +99610,8 @@ export const query = $root.query = (() => { writer.uint32(/* id 18, wireType 0 =*/144).bool(message.fetch_last_insert_id); if (message.in_dml_execution != null && Object.hasOwnProperty.call(message, "in_dml_execution")) writer.uint32(/* id 19, wireType 0 =*/152).bool(message.in_dml_execution); + if (message.transaction_timeout != null && Object.hasOwnProperty.call(message, "transaction_timeout")) + writer.uint32(/* id 20, wireType 0 =*/160).int64(message.transaction_timeout); return writer; }; @@ -99696,6 +99713,10 @@ export const query = $root.query = (() => { message.in_dml_execution = reader.bool(); break; } + case 20: { + message.transaction_timeout = reader.int64(); + break; + } default: reader.skipType(tag & 7); break; @@ -99830,6 +99851,11 @@ export const query = $root.query = (() => { if (message.in_dml_execution != null && message.hasOwnProperty("in_dml_execution")) if (typeof message.in_dml_execution !== "boolean") return "in_dml_execution: boolean expected"; + if (message.transaction_timeout != null && message.hasOwnProperty("transaction_timeout")) { + properties._transaction_timeout = 1; + if (!$util.isInteger(message.transaction_timeout) && !(message.transaction_timeout && $util.isInteger(message.transaction_timeout.low) && $util.isInteger(message.transaction_timeout.high))) + return "transaction_timeout: integer|Long expected"; + } return null; }; @@ -100046,6 +100072,15 @@ export const query = $root.query = (() => { message.fetch_last_insert_id = Boolean(object.fetch_last_insert_id); if (object.in_dml_execution != null) message.in_dml_execution = Boolean(object.in_dml_execution); + if (object.transaction_timeout != null) + if ($util.Long) + (message.transaction_timeout = $util.Long.fromValue(object.transaction_timeout)).unsigned = false; + else if (typeof object.transaction_timeout === "string") + message.transaction_timeout = parseInt(object.transaction_timeout, 10); + else if (typeof object.transaction_timeout === "number") + message.transaction_timeout = object.transaction_timeout; + else if (typeof object.transaction_timeout === "object") + message.transaction_timeout = new $util.LongBits(object.transaction_timeout.low >>> 0, object.transaction_timeout.high >>> 0).toNumber(); return message; }; @@ -100125,6 +100160,14 @@ export const query = $root.query = (() => { object.fetch_last_insert_id = message.fetch_last_insert_id; if (message.in_dml_execution != null && message.hasOwnProperty("in_dml_execution")) object.in_dml_execution = message.in_dml_execution; + if (message.transaction_timeout != null && message.hasOwnProperty("transaction_timeout")) { + if (typeof message.transaction_timeout === "number") + object.transaction_timeout = options.longs === String ? String(message.transaction_timeout) : message.transaction_timeout; + else + object.transaction_timeout = options.longs === String ? $util.Long.prototype.toString.call(message.transaction_timeout) : options.longs === Number ? new $util.LongBits(message.transaction_timeout.low >>> 0, message.transaction_timeout.high >>> 0).toNumber() : message.transaction_timeout; + if (options.oneofs) + object._transaction_timeout = "transaction_timeout"; + } return object; };