Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion enginetest/engine_only_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1037,7 +1037,7 @@ func newDatabase() (*sql2.DB, func()) {
Protocol: "tcp",
Address: fmt.Sprintf("localhost:%d", port),
}
srv, err := server.NewServer(cfg, engine, harness.SessionBuilder(), nil)
srv, err := server.NewServer(cfg, engine, sql.NewContext, harness.SessionBuilder(), nil)
if err != nil {
panic(err)
}
Expand Down
6 changes: 3 additions & 3 deletions enginetest/enginetests.go
Original file line number Diff line number Diff line change
Expand Up @@ -2142,7 +2142,7 @@ func TestUserAuthentication(t *testing.T, h Harness) {
require.FailNow(t, "harness must implement ServerHarness")
}

s, err := server.NewServer(serverConfig, engine, serverHarness.SessionBuilder(), nil)
s, err := server.NewServer(serverConfig, engine, sql.NewContext, serverHarness.SessionBuilder(), nil)
require.NoError(t, err)
go func() {
err := s.Start()
Expand Down Expand Up @@ -5695,7 +5695,7 @@ func testCharsetCollationWire(t *testing.T, h Harness, sessionBuilder server.Ses
defer engine.Close()
engine.EngineAnalyzer().Catalog.MySQLDb.AddRootAccount()

s, err := server.NewServer(serverConfig, engine, sessionBuilder, nil)
s, err := server.NewServer(serverConfig, engine, sql.NewContext, sessionBuilder, nil)
require.NoError(t, err)
go func() {
err := s.Start()
Expand Down Expand Up @@ -5811,7 +5811,7 @@ func TestTypesOverWire(t *testing.T, harness ClientHarness, sessionBuilder serve
Address: fmt.Sprintf("localhost:%d", port),
MaxConnections: 1000,
}
s, err := server.NewServer(serverConfig, engine, sessionBuilder, nil)
s, err := server.NewServer(serverConfig, engine, sql.NewContext, sessionBuilder, nil)
require.NoError(t, err)
go func() {
err := s.Start()
Expand Down
2 changes: 1 addition & 1 deletion enginetest/server_engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ func NewServerQueryEngine(t *testing.T, engine *sqle.Engine, builder server.Sess
Protocol: "tcp",
Address: fmt.Sprintf("%s:%d", address, p),
}
s, err := server.NewServer(config, engine, builder, nil)
s, err := server.NewServer(config, engine, sql.NewContext, builder, nil)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion enginetest/server_engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ func initTestServer(port int) (*server.Server, error) {
sessBuilder := func(ctx context.Context, conn *mysql.Conn, addr string) (sql.Session, error) {
return memory.NewSession(sql.NewBaseSession(), pro), nil
}
s, err := server.NewServer(config, engine, sessBuilder, nil)
s, err := server.NewServer(config, engine, sql.NewContext, sessBuilder, nil)
if err != nil {
return nil, err
}
Expand Down
25 changes: 11 additions & 14 deletions server/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,12 @@ type SessionManager struct {
sessions map[uint32]sql.Session
connections map[uint32]*mysql.Conn
lastPid uint64
ctxFactory sql.ContextFactory
}

// NewSessionManager creates a SessionManager with the given SessionBuilder.
func NewSessionManager(
ctxFactory sql.ContextFactory,
builder SessionBuilder,
tracer trace.Tracer,
getDbFunc func(ctx *sql.Context, db string) (sql.Database, error),
Expand All @@ -69,6 +71,7 @@ func NewSessionManager(
builder: builder,
sessions: make(map[uint32]sql.Session),
connections: make(map[uint32]*mysql.Conn),
ctxFactory: ctxFactory,
}
}

Expand Down Expand Up @@ -125,28 +128,27 @@ func (s *SessionManager) NewSession(ctx context.Context, conn *mysql.Conn) error

// SetDB sets the current database of the given connection session.
// If the session does not exist, it creates a new session with given connection.
func (s *SessionManager) SetDB(conn *mysql.Conn, dbName string) error {
sess, err := s.getOrCreateSession(context.Background(), conn)
func (s *SessionManager) SetDB(ctx context.Context, conn *mysql.Conn, dbName string) error {
sess, err := s.getOrCreateSession(ctx, conn)
if err != nil {
return err
}

err = sql.SessionCommandBegin(sess)
if err != nil {
sql.SessionEnd(sess)
return err
}
defer sql.SessionCommandEnd(sess)

ctx := sql.NewContext(context.Background(), sql.WithSession(sess))
ctx, err = s.processlist.BeginOperation(ctx)
sqlCtx := s.ctxFactory(ctx, sql.WithSession(sess))
sqlCtx, err = s.processlist.BeginOperation(sqlCtx)
if err != nil {
return err
}
defer s.processlist.EndOperation(ctx)
defer s.processlist.EndOperation(sqlCtx)
var db sql.Database
if dbName != "" {
db, err = s.getDbFunc(ctx, dbName)
db, err = s.getDbFunc(sqlCtx, dbName)
if err != nil {
return err
}
Expand All @@ -157,7 +159,7 @@ func (s *SessionManager) SetDB(conn *mysql.Conn, dbName string) error {
if pdb, ok := db.(mysql_db.PrivilegedDatabase); ok {
db = pdb.Unwrap()
}
err = sess.UseDatabase(ctx, db)
err = sess.UseDatabase(sqlCtx, db)
if err != nil {
return err
}
Expand Down Expand Up @@ -200,11 +202,6 @@ func (s *SessionManager) session(conn *mysql.Conn) sql.Session {
return s.sessions[conn.ConnectionID]
}

// NewContext creates a new context for the session at the given conn.
func (s *SessionManager) NewContext(ctx context.Context, conn *mysql.Conn, query string) (*sql.Context, error) {
return s.NewContextWithQuery(ctx, conn, query)
}

func (s *SessionManager) getOrCreateSession(ctx context.Context, conn *mysql.Conn) (sql.Session, error) {
s.mu.Lock()
sess, ok := s.sessions[conn.ConnectionID]
Expand Down Expand Up @@ -236,7 +233,7 @@ func (s *SessionManager) NewContextWithQuery(ctx context.Context, conn *mysql.Co

ctx, span := s.tracer.Start(ctx, "query")

context := sql.NewContext(
context := s.ctxFactory(
ctx,
sql.WithSession(sess),
sql.WithTracer(s.tracer),
Expand Down
17 changes: 7 additions & 10 deletions server/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,6 @@ type Handler struct {
var _ mysql.Handler = (*Handler)(nil)
var _ mysql.ExtendedHandler = (*Handler)(nil)
var _ mysql.BinlogReplicaHandler = (*Handler)(nil)
var _ sql.ContextProvider = (*Handler)(nil)

// NewConnection reports that a new connection has been established.
func (h *Handler) NewConnection(c *mysql.Conn) {
Expand All @@ -103,7 +102,7 @@ func (h *Handler) ConnectionAborted(_ *mysql.Conn, _ string) error {

func (h *Handler) ComInitDB(c *mysql.Conn, schemaName string) error {
// SetDB itself handles session and processlist operation lifecycle callbacks.
err := h.sm.SetDB(c, schemaName)
err := h.sm.SetDB(context.Background(), c, schemaName)
if err != nil {
logrus.WithField("database", schemaName).Errorf("unable to process ComInitDB: %s", err.Error())
err = sql.CastSQLError(err)
Expand Down Expand Up @@ -202,10 +201,6 @@ func (h *Handler) ComPrepareParsed(ctx context.Context, c *mysql.Conn, query str
return analyzed, fields, nil
}

func (h *Handler) NewContext(ctx context.Context, c *mysql.Conn, query string) (*sql.Context, error) {
return h.sm.NewContext(ctx, c, query)
}

func (h *Handler) ComBind(ctx context.Context, c *mysql.Conn, query string, parsedQuery mysql.ParsedQuery, prepare *mysql.PrepareData) (mysql.BoundQuery, []*querypb.Field, error) {
sqlCtx, err := h.sm.NewContextWithQuery(ctx, c, query)
if err != nil {
Expand Down Expand Up @@ -273,17 +268,19 @@ func (h *Handler) ComResetConnection(c *mysql.Conn) error {
h.maybeReleaseAllLocks(c)
h.e.CloseSession(c.ConnectionID)

ctx := context.Background()

// Create a new session and set the current database
err := h.sm.NewSession(context.Background(), c)
err := h.sm.NewSession(ctx, c)
if err != nil {
return err
}

return h.sm.SetDB(c, db)
return h.sm.SetDB(ctx, c, db)
}

func (h *Handler) ParserOptionsForConnection(c *mysql.Conn) (sqlparser.ParserOptions, error) {
ctx, err := h.sm.NewContext(context.Background(), c, "")
ctx, err := h.sm.NewContextWithQuery(context.Background(), c, "")
if err != nil {
return sqlparser.ParserOptions{}, err
}
Expand Down Expand Up @@ -406,7 +403,7 @@ func (h *Handler) doQuery(
qFlags *sql.QueryFlags,
) (remainder string, err error) {
var sqlCtx *sql.Context
sqlCtx, err = h.sm.NewContext(ctx, c, query)
sqlCtx, err = h.sm.NewContextWithQuery(ctx, c, query)
if err != nil {
return "", err
}
Expand Down
Loading
Loading