diff --git a/pkg/session/session.go b/pkg/session/session.go index 5b0fbbfa72bb9..419ae8d5b65d7 100644 --- a/pkg/session/session.go +++ b/pkg/session/session.go @@ -2625,7 +2625,148 @@ func (s *session) GetSessionVars() *variable.SessionVars { return s.sessionVars } +<<<<<<< HEAD func (s *session) AuthPluginForUser(user *auth.UserIdentity) (string, error) { +======= +// GetPlanCtx returns the PlanContext. +func (s *session) GetPlanCtx() planctx.PlanContext { + return s.pctx +} + +// GetExprCtx returns the expression context of the session. +func (s *session) GetExprCtx() exprctx.ExprContext { + return s.exprctx +} + +// GetTableCtx returns the table.MutateContext +func (s *session) GetTableCtx() tblctx.MutateContext { + return s.tblctx +} + +// GetDistSQLCtx returns the context used in DistSQL +func (s *session) GetDistSQLCtx() *distsqlctx.DistSQLContext { + vars := s.GetSessionVars() + sc := vars.StmtCtx + + dctx := sc.GetOrInitDistSQLFromCache(func() *distsqlctx.DistSQLContext { + return &distsqlctx.DistSQLContext{ + WarnHandler: sc.WarnHandler, + InRestrictedSQL: sc.InRestrictedSQL, + Client: s.GetClient(), + + EnabledRateLimitAction: vars.EnabledRateLimitAction, + EnableChunkRPC: vars.EnableChunkRPC, + OriginalSQL: sc.OriginalSQL, + KVVars: vars.KVVars, + KvExecCounter: sc.KvExecCounter, + SessionMemTracker: vars.MemTracker, + + Location: sc.TimeZone(), + RuntimeStatsColl: sc.RuntimeStatsColl, + SQLKiller: &vars.SQLKiller, + CPUUsage: &vars.SQLCPUUsages, + ErrCtx: sc.ErrCtx(), + + TiFlashReplicaRead: vars.TiFlashReplicaRead, + TiFlashMaxThreads: vars.TiFlashMaxThreads, + TiFlashMaxBytesBeforeExternalJoin: vars.TiFlashMaxBytesBeforeExternalJoin, + TiFlashMaxBytesBeforeExternalGroupBy: vars.TiFlashMaxBytesBeforeExternalGroupBy, + TiFlashMaxBytesBeforeExternalSort: vars.TiFlashMaxBytesBeforeExternalSort, + TiFlashMaxQueryMemoryPerNode: vars.TiFlashMaxQueryMemoryPerNode, + TiFlashQuerySpillRatio: vars.TiFlashQuerySpillRatio, + TiFlashHashJoinVersion: vars.TiFlashHashJoinVersion, + + DistSQLConcurrency: vars.DistSQLScanConcurrency(), + ReplicaReadType: vars.GetReplicaRead(), + WeakConsistency: sc.WeakConsistency, + RCCheckTS: sc.RCCheckTS, + NotFillCache: sc.NotFillCache, + TaskID: sc.TaskID, + Priority: sc.Priority, + ResourceGroupTagger: sc.GetResourceGroupTagger(), + EnablePaging: vars.EnablePaging, + MinPagingSize: vars.MinPagingSize, + MaxPagingSize: vars.MaxPagingSize, + RequestSourceType: vars.RequestSourceType, + ExplicitRequestSourceType: vars.ExplicitRequestSourceType, + StoreBatchSize: vars.StoreBatchSize, + ResourceGroupName: sc.ResourceGroupName, + LoadBasedReplicaReadThreshold: vars.LoadBasedReplicaReadThreshold, + RunawayChecker: sc.RunawayChecker, + TiKVClientReadTimeout: vars.GetTiKVClientReadTimeout(), + MaxExecutionTime: vars.GetMaxExecutionTime(), + + ReplicaClosestReadThreshold: vars.ReplicaClosestReadThreshold, + ConnectionID: vars.ConnectionID, + SessionAlias: vars.SessionAlias, + + ExecDetails: &sc.SyncExecDetails, + } + }) + + // Check if the runaway checker is updated. This is to avoid that evaluating a non-correlated subquery + // during the optimization phase will cause the `*distsqlctx.DistSQLContext` to be created before the + // runaway checker is set later at the execution phase. + // Ref: https://github.com/pingcap/tidb/issues/61899 + if dctx.RunawayChecker != sc.RunawayChecker { + dctx.RunawayChecker = sc.RunawayChecker + } + + return dctx +} + +// GetRangerCtx returns the context used in `ranger` related functions +func (s *session) GetRangerCtx() *rangerctx.RangerContext { + vars := s.GetSessionVars() + sc := vars.StmtCtx + + rctx := sc.GetOrInitRangerCtxFromCache(func() any { + return &rangerctx.RangerContext{ + ExprCtx: s.GetExprCtx(), + TypeCtx: s.GetSessionVars().StmtCtx.TypeCtx(), + ErrCtx: s.GetSessionVars().StmtCtx.ErrCtx(), + + InPreparedPlanBuilding: s.GetSessionVars().StmtCtx.InPreparedPlanBuilding, + RegardNULLAsPoint: s.GetSessionVars().RegardNULLAsPoint, + OptPrefixIndexSingleScan: s.GetSessionVars().OptPrefixIndexSingleScan, + OptimizerFixControl: s.GetSessionVars().OptimizerFixControl, + + PlanCacheTracker: &s.GetSessionVars().StmtCtx.PlanCacheTracker, + RangeFallbackHandler: &s.GetSessionVars().StmtCtx.RangeFallbackHandler, + } + }) + + return rctx.(*rangerctx.RangerContext) +} + +// GetBuildPBCtx returns the context used in `ToPB` method +func (s *session) GetBuildPBCtx() *planctx.BuildPBContext { + vars := s.GetSessionVars() + sc := vars.StmtCtx + + bctx := sc.GetOrInitBuildPBCtxFromCache(func() any { + return &planctx.BuildPBContext{ + ExprCtx: s.GetExprCtx(), + Client: s.GetClient(), + + TiFlashFastScan: s.GetSessionVars().TiFlashFastScan, + TiFlashFineGrainedShuffleBatchSize: s.GetSessionVars().TiFlashFineGrainedShuffleBatchSize, + + // the following fields are used to build `expression.PushDownContext`. + // TODO: it'd be better to embed `expression.PushDownContext` in `BuildPBContext`. But `expression` already + // depends on this package, so we need to move `expression.PushDownContext` to a standalone package first. + GroupConcatMaxLen: s.GetSessionVars().GroupConcatMaxLen, + InExplainStmt: s.GetSessionVars().StmtCtx.InExplainStmt, + WarnHandler: s.GetSessionVars().StmtCtx.WarnHandler, + ExtraWarnghandler: s.GetSessionVars().StmtCtx.ExtraWarnHandler, + } + }) + + return bctx.(*planctx.BuildPBContext) +} + +func (s *session) AuthPluginForUser(ctx context.Context, user *auth.UserIdentity) (string, error) { +>>>>>>> 49f4868f59b (fix(runaway): ensure DistSQLContext's checker is synchronized with session variables (#61907)) pm := privilege.GetPrivilegeManager(s) authplugin, err := pm.GetAuthPluginForConnection(user.Username, user.Hostname) if err != nil {