Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Don't reuse the Queryist sqlCtx #8478

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 14 additions & 14 deletions go/cmd/dolt/cli/cli_context.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package cli
import (
"context"
"errors"
"fmt"

"github.com/dolthub/go-mysql-server/sql"

Expand All @@ -35,7 +36,7 @@ import (
// connected to a remote server.
type LateBindQueryist func(ctx context.Context) (Queryist, *sql.Context, func(), error)

// CliContexct is used to pass top level command information down to subcommands.
// CliContext is used to pass top level command information down to subcommands.
type CliContext interface {
// GlobalArgs returns the arguments passed before the subcommand.
GlobalArgs() *argparser.ArgParseResults
Expand All @@ -49,21 +50,16 @@ func NewCliContext(args *argparser.ArgParseResults, config *env.DoltCliConfig, l
return nil, errhand.VerboseErrorFromError(errors.New("Invariant violated. args, config, and latebind must be non nil."))
}

return LateBindCliContext{globalArgs: args, config: config, activeContext: &QueryistContext{}, bind: latebind}, nil
}

type QueryistContext struct {
sqlCtx *sql.Context
qryist *Queryist
return LateBindCliContext{globalArgs: args, config: config, activeQueryist: nil, bind: latebind}, nil
}

// LateBindCliContext is a struct that implements CliContext. Its primary purpose is to wrap the global arguments and
// provide an implementation of the QueryEngine function. This instance is stateful to ensure that the Queryist is only
// created once.
type LateBindCliContext struct {
globalArgs *argparser.ArgParseResults
config *env.DoltCliConfig
activeContext *QueryistContext
globalArgs *argparser.ArgParseResults
config *env.DoltCliConfig
activeQueryist *Queryist

bind LateBindQueryist
}
Expand All @@ -77,17 +73,21 @@ func (lbc LateBindCliContext) GlobalArgs() *argparser.ArgParseResults {
// LateBindQueryist is made, and caches the result. Note that if this is called twice, the closer function returns will
// be nil, callers should check if is nil.
func (lbc LateBindCliContext) QueryEngine(ctx context.Context) (Queryist, *sql.Context, func(), error) {
if lbc.activeContext != nil && lbc.activeContext.qryist != nil && lbc.activeContext.sqlCtx != nil {
return *lbc.activeContext.qryist, lbc.activeContext.sqlCtx, nil, nil
if lbc.activeQueryist != nil {
sqlCtx, ok := ctx.(*sql.Context)
if !ok {
return nil, nil, nil, errors.New(fmt.Sprintf("type coercion failed. Require *sql.Context, got: %T", ctx))
}

return *lbc.activeQueryist, sqlCtx, nil, nil
}

qryist, sqlCtx, closer, err := lbc.bind(ctx)
if err != nil {
return nil, nil, nil, err
}

lbc.activeContext.qryist = &qryist
lbc.activeContext.sqlCtx = sqlCtx
lbc.activeQueryist = &qryist

return qryist, sqlCtx, closer, nil
}
Expand Down
12 changes: 6 additions & 6 deletions go/cmd/dolt/commands/sql.go
Original file line number Diff line number Diff line change
Expand Up @@ -778,7 +778,7 @@ func execShell(sqlCtx *sql.Context, qryist cli.Queryist, format engine.PrintResu
subCtx, stop := signal.NotifyContext(initialCtx, os.Interrupt, syscall.SIGTERM)
defer stop()

sqlCtx := sql.NewContext(subCtx, sql.WithSession(sqlCtx.Session))
subSqlCtx := sql.NewContext(subCtx, sql.WithSession(sqlCtx.Session))

cmdType, subCmd, newQuery, err := preprocessQuery(query, lastSqlCmd, cliCtx)
if err != nil {
Expand All @@ -787,7 +787,7 @@ func execShell(sqlCtx *sql.Context, qryist cli.Queryist, format engine.PrintResu
}

if cmdType == DoltCliCommand {
err := handleSlashCommand(sqlCtx, subCmd, query, cliCtx)
err := handleSlashCommand(subSqlCtx, subCmd, query, cliCtx)
if err != nil {
shell.Println(color.RedString(err.Error()))
}
Expand All @@ -799,15 +799,15 @@ func execShell(sqlCtx *sql.Context, qryist cli.Queryist, format engine.PrintResu
lastSqlCmd = query
var sqlSch sql.Schema
var rowIter sql.RowIter
if sqlSch, rowIter, _, err = processQuery(sqlCtx, query, qryist); err != nil {
if sqlSch, rowIter, _, err = processQuery(subSqlCtx, query, qryist); err != nil {
verr := formatQueryError("", err)
shell.Println(verr.Verbose())
} else if rowIter != nil {
switch closureFormat {
case engine.FormatTabular, engine.FormatVertical:
err = engine.PrettyPrintResultsExtended(sqlCtx, closureFormat, sqlSch, rowIter)
err = engine.PrettyPrintResultsExtended(subSqlCtx, closureFormat, sqlSch, rowIter)
default:
err = engine.PrettyPrintResults(sqlCtx, closureFormat, sqlSch, rowIter)
err = engine.PrettyPrintResults(subSqlCtx, closureFormat, sqlSch, rowIter)
}

if err != nil {
Expand All @@ -816,7 +816,7 @@ func execShell(sqlCtx *sql.Context, qryist cli.Queryist, format engine.PrintResu
}
}

nextPrompt, multiPrompt = postCommandUpdate(sqlCtx, qryist)
nextPrompt, multiPrompt = postCommandUpdate(subSqlCtx, qryist)

return true
}()
Expand Down
Loading