Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion go/cmd/dolt/commands/rebase.go
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ func insertRebasePlanIntoDoltRebaseTable(plan *rebase.RebasePlan, sqlCtx *sql.Co
// to be back on the branch being rebased (e.g. t1).
func syncCliBranchToSqlSessionBranch(ctx *sql.Context, dEnv *env.DoltEnv) error {
doltSession := dsess.DSessFromSess(ctx.Session)
currentBranch, err := doltSession.GetBranch()
currentBranch, err := doltSession.GetBranch(ctx)
if err != nil {
return err
}
Expand Down
54 changes: 53 additions & 1 deletion go/libraries/doltcore/branch_control/branch_control.go
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,12 @@ func AddAdminForContext(ctx context.Context, branchName string) error {
// then nil is returned.
func GetBranchAwareSession(ctx context.Context) Context {
if sqlCtx, ok := ctx.(*sql.Context); ok {
if bas, ok := sqlCtx.Session.(Context); ok {
if bas, ok := sqlCtx.Session.(ContextConvertible); ok {
return &convertedContext{
convertible: bas,
ctx: sqlCtx,
}
} else if bas, ok := sqlCtx.Session.(Context); ok {
return bas
}
} else if bas, ok := ctx.(Context); ok {
Expand All @@ -357,6 +362,53 @@ func GetBranchAwareSession(ctx context.Context) Context {
return nil
}

// GetBranchAwareSession will return a context converted from this if
// the Session supports this GetBranch call instead.
type ContextConvertible interface {
GetBranch(*sql.Context) (string, error)
GetCurrentDatabase() string
GetUser() string
GetHost() string
GetPrivilegeSet() (sql.PrivilegeSet, uint64)
GetController() *Controller
GetFileSystem() filesys.Filesys
}

type convertedContext struct {
convertible ContextConvertible
ctx *sql.Context
}

var _ Context = (*convertedContext)(nil)

func (cc *convertedContext) GetBranch() (string, error) {
return cc.convertible.GetBranch(cc.ctx)
}

func (cc *convertedContext) GetCurrentDatabase() string {
return cc.convertible.GetCurrentDatabase()
}

func (cc *convertedContext) GetUser() string {
return cc.convertible.GetUser()
}

func (cc *convertedContext) GetHost() string {
return cc.convertible.GetHost()
}

func (cc *convertedContext) GetPrivilegeSet() (sql.PrivilegeSet, uint64) {
return cc.convertible.GetPrivilegeSet()
}

func (cc *convertedContext) GetController() *Controller {
return cc.convertible.GetController()
}

func (cc *convertedContext) GetFileSystem() filesys.Filesys {
return cc.convertible.GetFileSystem()
}

// HasDatabasePrivileges returns whether the given context's user has the correct privileges to modify any table entries
// that match the given database. The following are the required privileges:
//
Expand Down
2 changes: 1 addition & 1 deletion go/libraries/doltcore/sqle/dprocedures/dolt_checkout.go
Original file line number Diff line number Diff line change
Expand Up @@ -509,7 +509,7 @@ func checkoutTablesFromCommit(
return fmt.Errorf("Could not load database %s", ctx.GetCurrentDatabase())
}

currentBranch, err := dSess.GetBranch()
currentBranch, err := dSess.GetBranch(ctx)
if err != nil {
return err
}
Expand Down
6 changes: 2 additions & 4 deletions go/libraries/doltcore/sqle/dsess/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ type DoltSession struct {
var _ sql.Session = (*DoltSession)(nil)
var _ sql.PersistableSession = (*DoltSession)(nil)
var _ sql.TransactionSession = (*DoltSession)(nil)
var _ branch_control.Context = (*DoltSession)(nil)
var _ branch_control.ContextConvertible = (*DoltSession)(nil)

// DefaultSession creates a DoltSession with default values
func DefaultSession(pro DoltDatabaseProvider, sessFunc WriteSessFunc) *DoltSession {
Expand Down Expand Up @@ -1703,9 +1703,7 @@ func (d *DoltSession) SystemVariablesInConfig() ([]sql.SystemVariable, error) {
}

// GetBranch implements the interface branch_control.Context.
func (d *DoltSession) GetBranch() (string, error) {
// TODO: creating a new SQL context here is expensive
ctx := sql.NewContext(context.Background(), sql.WithSession(d))
func (d *DoltSession) GetBranch(ctx *sql.Context) (string, error) {
currentDb := d.Session.GetCurrentDatabase()

// no branch if there's no current db
Expand Down
4 changes: 2 additions & 2 deletions go/libraries/doltcore/sqle/statspro/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ func (sc *StatsController) AnalyzeTable(ctx *sql.Context, table sql.Table, dbNam
}
if branch == "" {
var err error
branch, err = dSess.GetBranch()
branch, err = dSess.GetBranch(ctx)
if err != nil {
return err
}
Expand Down Expand Up @@ -411,7 +411,7 @@ func (sc *StatsController) DropDbStats(ctx *sql.Context, dbName string, flush bo

func (sc *StatsController) statsKey(ctx *sql.Context, dbName, table string) (tableIndexesKey, error) {
dSess := dsess.DSessFromSess(ctx.Session)
branch, err := dSess.GetBranch()
branch, err := dSess.GetBranch(ctx)
if err != nil {
return tableIndexesKey{}, err
}
Expand Down
Loading