Skip to content

Commit

Permalink
*: prevent cursor read from being cancelled by GC (#39950)
Browse files Browse the repository at this point in the history
close #39447
  • Loading branch information
zyguan authored Dec 16, 2022
1 parent 53572f8 commit 0fe61bd
Show file tree
Hide file tree
Showing 13 changed files with 375 additions and 26 deletions.
26 changes: 16 additions & 10 deletions ddl/db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1584,13 +1584,11 @@ func TestLogAndShowSlowLog(t *testing.T) {
}

func TestReportingMinStartTimestamp(t *testing.T) {
_, dom := testkit.CreateMockStoreAndDomainWithSchemaLease(t, dbTestLease)
store, dom := testkit.CreateMockStoreAndDomainWithSchemaLease(t, dbTestLease)
tk := testkit.NewTestKit(t, store)
se := tk.Session()

infoSyncer := dom.InfoSyncer()
sm := &testkit.MockSessionManager{
PS: make([]*util.ProcessInfo, 0),
}
infoSyncer.SetSessionManager(sm)
beforeTS := oracle.GoTimeToTS(time.Now())
infoSyncer.ReportMinStartTS(dom.Store())
afterTS := oracle.GoTimeToTS(time.Now())
Expand All @@ -1599,13 +1597,21 @@ func TestReportingMinStartTimestamp(t *testing.T) {
now := time.Now()
validTS := oracle.GoTimeToLowerLimitStartTS(now.Add(time.Minute), tikv.MaxTxnTimeUse)
lowerLimit := oracle.GoTimeToLowerLimitStartTS(now, tikv.MaxTxnTimeUse)
sm := se.GetSessionManager().(*testkit.MockSessionManager)
sm.PS = []*util.ProcessInfo{
{CurTxnStartTS: 0},
{CurTxnStartTS: math.MaxUint64},
{CurTxnStartTS: lowerLimit},
{CurTxnStartTS: validTS},
{CurTxnStartTS: 0, ProtectedTSList: &se.GetSessionVars().ProtectedTSList},
{CurTxnStartTS: math.MaxUint64, ProtectedTSList: &se.GetSessionVars().ProtectedTSList},
{CurTxnStartTS: lowerLimit, ProtectedTSList: &se.GetSessionVars().ProtectedTSList},
{CurTxnStartTS: validTS, ProtectedTSList: &se.GetSessionVars().ProtectedTSList},
}
infoSyncer.SetSessionManager(sm)
infoSyncer.ReportMinStartTS(dom.Store())
require.Equal(t, validTS, infoSyncer.GetMinStartTS())

unhold := se.GetSessionVars().ProtectedTSList.HoldTS(validTS - 1)
infoSyncer.ReportMinStartTS(dom.Store())
require.Equal(t, validTS-1, infoSyncer.GetMinStartTS())

unhold()
infoSyncer.ReportMinStartTS(dom.Store())
require.Equal(t, validTS, infoSyncer.GetMinStartTS())
}
Expand Down
16 changes: 2 additions & 14 deletions domain/infosync/info.go
Original file line number Diff line number Diff line change
Expand Up @@ -689,8 +689,6 @@ func (is *InfoSyncer) ReportMinStartTS(store kv.Storage) {
if sm == nil {
return
}
pl := sm.ShowProcessList()
innerSessionStartTSList := sm.GetInternalSessionStartTSList()

// Calculate the lower limit of the start timestamp to avoid extremely old transaction delaying GC.
currentVer, err := store.CurrentVersion(kv.GlobalTxnScope)
Expand All @@ -704,18 +702,8 @@ func (is *InfoSyncer) ReportMinStartTS(store kv.Storage) {
minStartTS := oracle.GoTimeToTS(now)
logutil.BgLogger().Debug("ReportMinStartTS", zap.Uint64("initial minStartTS", minStartTS),
zap.Uint64("StartTSLowerLimit", startTSLowerLimit))
for _, info := range pl {
if info.CurTxnStartTS > startTSLowerLimit && info.CurTxnStartTS < minStartTS {
minStartTS = info.CurTxnStartTS
}
}

for _, innerTS := range innerSessionStartTSList {
logutil.BgLogger().Debug("ReportMinStartTS", zap.Uint64("Internal Session Transaction StartTS", innerTS))
kv.PrintLongTimeInternalTxn(now, innerTS, false)
if innerTS > startTSLowerLimit && innerTS < minStartTS {
minStartTS = innerTS
}
if ts := sm.GetMinStartTS(startTSLowerLimit); ts > startTSLowerLimit && ts < minStartTS {
minStartTS = ts
}

is.minStartTS = kv.GetMinInnerTxnStartTS(now, startTSLowerLimit, minStartTS)
Expand Down
1 change: 1 addition & 0 deletions server/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ go_test(
"//util/plancodec",
"//util/resourcegrouptag",
"//util/rowcodec",
"//util/sqlexec",
"//util/topsql",
"//util/topsql/collector",
"//util/topsql/collector/mock",
Expand Down
10 changes: 9 additions & 1 deletion server/conn_stmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,9 @@ func (cc *clientConn) executePreparedStmtAndWriteResult(ctx context.Context, stm
if rs == nil {
return false, cc.writeOK(ctx)
}
if result, ok := rs.(*tidbResultSet); ok {
// since there are multiple implementations of ResultSet (the rs might be wrapped), we have to unwrap the rs before
// casting it to *tidbResultSet.
if result, ok := unwrapResultSet(rs).(*tidbResultSet); ok {
if planCacheStmt, ok := prepStmt.(*plannercore.PlanCacheStmt); ok {
result.preparedStmt = planCacheStmt
}
Expand All @@ -278,6 +280,12 @@ func (cc *clientConn) executePreparedStmtAndWriteResult(ctx context.Context, stm
if useCursor {
cc.initResultEncoder(ctx)
defer cc.rsEncoder.clean()
// fix https://github.com/pingcap/tidb/issues/39447. we need to hold the start-ts here because the process info
// will be set to sleep after fetch returned.
if pi := cc.ctx.ShowProcess(); pi != nil && pi.ProtectedTSList != nil && pi.CurTxnStartTS > 0 {
unhold := pi.HoldTS(pi.CurTxnStartTS)
rs = &rsWithHooks{ResultSet: rs, onClosed: unhold}
}
stmt.StoreResultSet(rs)
if err = cc.writeColumnInfo(rs.Columns()); err != nil {
return false, err
Expand Down
89 changes: 89 additions & 0 deletions server/conn_stmt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,15 @@
package server

import (
"context"
"encoding/binary"
"testing"

"github.com/pingcap/tidb/expression"
"github.com/pingcap/tidb/parser/mysql"
"github.com/pingcap/tidb/parser/terror"
"github.com/pingcap/tidb/sessionctx/stmtctx"
"github.com/pingcap/tidb/testkit"
"github.com/pingcap/tidb/types"
"github.com/stretchr/testify/require"
)
Expand Down Expand Up @@ -251,3 +254,89 @@ func TestParseStmtFetchCmd(t *testing.T) {
require.Equal(t, tc.err, err)
}
}

func TestCursorReadHoldTS(t *testing.T) {
store, dom := testkit.CreateMockStoreAndDomain(t)
srv := CreateMockServer(t, store)
srv.SetDomain(dom)
defer srv.Close()

appendUint32 := binary.LittleEndian.AppendUint32
ctx := context.Background()
c := CreateMockConn(t, srv)
tk := testkit.NewTestKitWithSession(t, store, c.Context().Session)
tk.MustExec("use test")
tk.MustExec("drop table if exists t")
tk.MustExec("create table t(a int primary key)")
tk.MustExec("insert into t values (1), (2), (3), (4), (5), (6), (7), (8)")
tk.MustQuery("select count(*) from t").Check(testkit.Rows("8"))

stmt, _, _, err := c.Context().Prepare("select * from t")
require.NoError(t, err)
require.Zero(t, tk.Session().ShowProcess().GetMinStartTS(0))

// should hold ts after executing stmt with cursor
require.NoError(t, c.Dispatch(ctx, append(
appendUint32([]byte{mysql.ComStmtExecute}, uint32(stmt.ID())),
mysql.CursorTypeReadOnly, 0x1, 0x0, 0x0, 0x0,
)))
ts := tk.Session().ShowProcess().GetMinStartTS(0)
require.Positive(t, ts)
// should unhold ts when result set exhausted
require.NoError(t, c.Dispatch(ctx, appendUint32(appendUint32([]byte{mysql.ComStmtFetch}, uint32(stmt.ID())), 5)))
require.Equal(t, ts, tk.Session().ShowProcess().GetMinStartTS(0))
require.Equal(t, ts, srv.GetMinStartTS(0))
require.NoError(t, c.Dispatch(ctx, appendUint32(appendUint32([]byte{mysql.ComStmtFetch}, uint32(stmt.ID())), 5)))
require.Equal(t, ts, tk.Session().ShowProcess().GetMinStartTS(0))
require.Equal(t, ts, srv.GetMinStartTS(0))
require.NoError(t, c.Dispatch(ctx, appendUint32(appendUint32([]byte{mysql.ComStmtFetch}, uint32(stmt.ID())), 5)))
require.Zero(t, tk.Session().ShowProcess().GetMinStartTS(0))

// should hold ts after executing stmt with cursor
require.NoError(t, c.Dispatch(ctx, append(
appendUint32([]byte{mysql.ComStmtExecute}, uint32(stmt.ID())),
mysql.CursorTypeReadOnly, 0x1, 0x0, 0x0, 0x0,
)))
require.Positive(t, tk.Session().ShowProcess().GetMinStartTS(0))
// should unhold ts when stmt reset
require.NoError(t, c.Dispatch(ctx, appendUint32([]byte{mysql.ComStmtReset}, uint32(stmt.ID()))))
require.Zero(t, tk.Session().ShowProcess().GetMinStartTS(0))

// should hold ts after executing stmt with cursor
require.NoError(t, c.Dispatch(ctx, append(
appendUint32([]byte{mysql.ComStmtExecute}, uint32(stmt.ID())),
mysql.CursorTypeReadOnly, 0x1, 0x0, 0x0, 0x0,
)))
require.Positive(t, tk.Session().ShowProcess().GetMinStartTS(0))
// should unhold ts when stmt closed
require.NoError(t, c.Dispatch(ctx, appendUint32([]byte{mysql.ComStmtClose}, uint32(stmt.ID()))))
require.Zero(t, tk.Session().ShowProcess().GetMinStartTS(0))

// create another 2 stmts and execute them
stmt1, _, _, err := c.Context().Prepare("select * from t")
require.NoError(t, err)
require.NoError(t, c.Dispatch(ctx, append(
appendUint32([]byte{mysql.ComStmtExecute}, uint32(stmt1.ID())),
mysql.CursorTypeReadOnly, 0x1, 0x0, 0x0, 0x0,
)))
ts1 := tk.Session().ShowProcess().GetMinStartTS(0)
require.Positive(t, ts1)
stmt2, _, _, err := c.Context().Prepare("select * from t")
require.NoError(t, err)
require.NoError(t, c.Dispatch(ctx, append(
appendUint32([]byte{mysql.ComStmtExecute}, uint32(stmt2.ID())),
mysql.CursorTypeReadOnly, 0x1, 0x0, 0x0, 0x0,
)))
ts2 := tk.Session().ShowProcess().GetMinStartTS(ts1)
require.Positive(t, ts2)

require.Less(t, ts1, ts2)
require.Equal(t, ts1, srv.GetMinStartTS(0))
require.Equal(t, ts2, srv.GetMinStartTS(ts1))
require.Zero(t, srv.GetMinStartTS(ts2))

// should unhold all when session closed
c.Close()
require.Zero(t, tk.Session().ShowProcess().GetMinStartTS(0))
require.Zero(t, srv.GetMinStartTS(0))
}
40 changes: 40 additions & 0 deletions server/driver_tidb.go
Original file line number Diff line number Diff line change
Expand Up @@ -480,6 +480,46 @@ func (trs *tidbResultSet) Columns() []*ColumnInfo {
return trs.columns
}

// rsWithHooks wraps a ResultSet with some hooks (currently only onClosed).
type rsWithHooks struct {
ResultSet
onClosed func()
}

// Close implements ResultSet#Close
func (rs *rsWithHooks) Close() error {
closed := rs.IsClosed()
err := rs.ResultSet.Close()
if !closed && rs.onClosed != nil {
rs.onClosed()
}
return err
}

// OnFetchReturned implements fetchNotifier#OnFetchReturned
func (rs *rsWithHooks) OnFetchReturned() {
if impl, ok := rs.ResultSet.(fetchNotifier); ok {
impl.OnFetchReturned()
}
}

// Unwrap returns the underlying result set
func (rs *rsWithHooks) Unwrap() ResultSet {
return rs.ResultSet
}

// unwrapResultSet likes errors.Cause but for ResultSet
func unwrapResultSet(rs ResultSet) ResultSet {
var unRS ResultSet
if u, ok := rs.(interface{ Unwrap() ResultSet }); ok {
unRS = u.Unwrap()
}
if unRS == nil {
return rs
}
return unwrapResultSet(unRS)
}

func convertColumnInfo(fld *ast.ResultField) (ci *ColumnInfo) {
ci = &ColumnInfo{
Name: fld.ColumnAsName.O,
Expand Down
25 changes: 25 additions & 0 deletions server/driver_tidb_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"github.com/pingcap/tidb/parser/model"
"github.com/pingcap/tidb/parser/mysql"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/sqlexec"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -95,3 +96,27 @@ func TestConvertColumnInfo(t *testing.T) {
colInfo = convertColumnInfo(&resultField)
require.Equal(t, uint32(4), colInfo.ColumnLength)
}

func TestRSWithHooks(t *testing.T) {
closeCount := 0
rs := &rsWithHooks{
ResultSet: &tidbResultSet{recordSet: new(sqlexec.SimpleRecordSet)},
onClosed: func() { closeCount++ },
}
require.Equal(t, 0, closeCount)
rs.Close()
require.Equal(t, 1, closeCount)
rs.Close()
require.Equal(t, 1, closeCount)
}

func TestUnwrapRS(t *testing.T) {
var nilRS ResultSet
require.Nil(t, unwrapResultSet(nilRS))
rs0 := new(tidbResultSet)
rs1 := &rsWithHooks{ResultSet: rs0}
rs2 := &rsWithHooks{ResultSet: rs1}
for _, rs := range []ResultSet{rs0, rs1, rs2} {
require.Equal(t, rs0, unwrapResultSet(rs))
}
}
34 changes: 34 additions & 0 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -970,3 +970,37 @@ func (s *Server) KillNonFlashbackClusterConn() {
s.Kill(id, false)
}
}

// GetMinStartTS implements SessionManager interface.
func (s *Server) GetMinStartTS(lowerBound uint64) (ts uint64) {
// sys processes
if s.dom != nil {
for _, pi := range s.dom.SysProcTracker().GetSysProcessList() {
if thisTS := pi.GetMinStartTS(lowerBound); thisTS > lowerBound && (thisTS < ts || ts == 0) {
ts = thisTS
}
}
}
// user sessions
func() {
s.rwlock.RLock()
defer s.rwlock.RUnlock()
for _, client := range s.clients {
if thisTS := client.ctx.ShowProcess().GetMinStartTS(lowerBound); thisTS > lowerBound && (thisTS < ts || ts == 0) {
ts = thisTS
}
}
}()
// internal sessions
func() {
s.sessionMapMutex.Lock()
defer s.sessionMapMutex.Unlock()
analyzeProcID := util.GetAutoAnalyzeProcID(s.ServerID)
for se := range s.internalSessions {
if thisTS, processInfoID := session.GetStartTSFromSession(se); processInfoID != analyzeProcID && thisTS > lowerBound && (thisTS < ts || ts == 0) {
ts = thisTS
}
}
}()
return
}
1 change: 1 addition & 0 deletions session/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -1599,6 +1599,7 @@ func (s *session) SetProcessInfo(sql string, t time.Time, command byte, maxExecu
OOMAlarmVariablesInfo: s.getOomAlarmVariablesInfo(),
MaxExecutionTime: maxExecutionTime,
RedactSQL: s.sessionVars.EnableRedactLog,
ProtectedTSList: &s.sessionVars.ProtectedTSList,
}
oldPi := s.ShowProcess()
if p == nil {
Expand Down
Loading

0 comments on commit 0fe61bd

Please sign in to comment.