diff --git a/go/cmd/dolt/commands/assist.go b/go/cmd/dolt/commands/assist.go index 04de9739dd5..6393d95ad0a 100644 --- a/go/cmd/dolt/commands/assist.go +++ b/go/cmd/dolt/commands/assist.go @@ -93,6 +93,9 @@ func (a *Assist) Exec(ctx context.Context, commandStr string, args []string, dEn if err != nil { return 1 } + defer sql.SessionEnd(sqlCtx.Session) + sql.SessionCommandBegin(sqlCtx.Session) + defer sql.SessionCommandEnd(sqlCtx.Session) sqlCtx.SetCurrentDatabase(dbName) a.messages, err = getInitialPrompt(sqlCtx, sqlEng, dEnv) diff --git a/go/cmd/dolt/commands/cvcmds/verify_constraints.go b/go/cmd/dolt/commands/cvcmds/verify_constraints.go index 2d6b10a2c9e..2436f993c5f 100644 --- a/go/cmd/dolt/commands/cvcmds/verify_constraints.go +++ b/go/cmd/dolt/commands/cvcmds/verify_constraints.go @@ -127,9 +127,17 @@ func (cmd VerifyConstraintsCmd) Exec(ctx context.Context, commandStr string, arg if err != nil { return commands.HandleVErrAndExitCode(errhand.BuildDError("Failed to build sql engine.").AddCause(err).Build(), nil) } + sqlCtx, err := eng.NewLocalContext(ctx) + if err != nil { + return commands.HandleVErrAndExitCode(errhand.BuildDError("Failed to build sql context.").AddCause(err).Build(), nil) + } + defer sql.SessionEnd(sqlCtx.Session) + sql.SessionCommandBegin(sqlCtx.Session) + defer sql.SessionCommandEnd(sqlCtx.Session) + sqlCtx.SetCurrentDatabase(dbName) for _, tableName := range tablesWithViolations.AsSortedSlice() { - tbl, ok, err := endRoot.GetTable(ctx, tableName) + tbl, ok, err := endRoot.GetTable(sqlCtx, tableName) if err != nil { return commands.HandleVErrAndExitCode(errhand.BuildDError("Error loading table.").AddCause(err).Build(), nil) } @@ -138,14 +146,14 @@ func (cmd VerifyConstraintsCmd) Exec(ctx context.Context, commandStr string, arg } cli.Println("") cli.Println(doltdb.DoltConstViolTablePrefix + tableName.Name) - dErr := printViolationsForTable(ctx, dbName, tableName.Name, tbl, eng) + dErr := printViolationsForTable(sqlCtx, dbName, tableName.Name, tbl, eng) if dErr != nil { return commands.HandleVErrAndExitCode(dErr, nil) } } if outputOnly { - err = dEnv.UpdateWorkingRoot(ctx, working) + err = dEnv.UpdateWorkingRoot(sqlCtx, working) if err != nil { return commands.HandleVErrAndExitCode(errhand.BuildDError("Unable to undo written constraint violations").AddCause(err).Build(), nil) } @@ -157,7 +165,7 @@ func (cmd VerifyConstraintsCmd) Exec(ctx context.Context, commandStr string, arg return 0 } -func printViolationsForTable(ctx context.Context, dbName, tblName string, tbl *doltdb.Table, eng *engine.SqlEngine) errhand.VerboseError { +func printViolationsForTable(ctx *sql.Context, dbName, tblName string, tbl *doltdb.Table, eng *engine.SqlEngine) errhand.VerboseError { sch, err := tbl.GetSchema(ctx) if err != nil { return errhand.BuildDError("Error loading table schema").AddCause(err).Build() @@ -166,20 +174,14 @@ func printViolationsForTable(ctx context.Context, dbName, tblName string, tbl *d colNames := strings.Join(sch.GetAllCols().GetColumnNames(), ", ") query := fmt.Sprintf("SELECT violation_type, %s, violation_info from dolt_constraint_violations_%s", colNames, tblName) - sCtx, err := eng.NewLocalContext(ctx) - if err != nil { - return errhand.BuildDError("Error making sql context").AddCause(err).Build() - } - sCtx.SetCurrentDatabase(dbName) - - sqlSch, sqlItr, _, err := eng.Query(sCtx, query) + sqlSch, sqlItr, _, err := eng.Query(ctx, query) if err != nil { return errhand.BuildDError("Error querying constraint violations").AddCause(err).Build() } limitItr := &sqlLimitIter{itr: sqlItr, limit: 50} - err = engine.PrettyPrintResults(sCtx, engine.FormatTabular, sqlSch, limitItr, false) + err = engine.PrettyPrintResults(ctx, engine.FormatTabular, sqlSch, limitItr, false) if err != nil { return errhand.BuildDError("Error outputting rows").AddCause(err).Build() } diff --git a/go/cmd/dolt/commands/docscmds/diff.go b/go/cmd/dolt/commands/docscmds/diff.go index fa163f01311..5b177cf8ad0 100644 --- a/go/cmd/dolt/commands/docscmds/diff.go +++ b/go/cmd/dolt/commands/docscmds/diff.go @@ -18,6 +18,7 @@ import ( "context" textdiff "github.com/andreyvit/diff" + "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/dolt/go/cmd/dolt/cli" "github.com/dolthub/dolt/go/cmd/dolt/commands" @@ -89,13 +90,21 @@ func diffDoltDoc(ctx context.Context, dEnv *env.DoltEnv, docName string) error { if err != nil { return err } + sqlCtx, err := eng.NewLocalContext(ctx) + if err != nil { + return err + } + defer sql.SessionEnd(sqlCtx.Session) + sql.SessionCommandBegin(sqlCtx.Session) + defer sql.SessionCommandEnd(sqlCtx.Session) + sqlCtx.SetCurrentDatabase(dbName) - working, err := readDocFromTable(ctx, eng, dbName, docName) + working, err := readDocFromTable(sqlCtx, eng, docName) if err != nil { return err } - head, err := readDocFromTableAsOf(ctx, eng, dbName, docName, "HEAD") + head, err := readDocFromTableAsOf(sqlCtx, eng, docName, "HEAD") if err != nil { return err } diff --git a/go/cmd/dolt/commands/docscmds/read.go b/go/cmd/dolt/commands/docscmds/read.go index 2382fdb6648..7b0222df4b5 100644 --- a/go/cmd/dolt/commands/docscmds/read.go +++ b/go/cmd/dolt/commands/docscmds/read.go @@ -118,8 +118,16 @@ func readDoltDoc(ctx context.Context, dEnv *env.DoltEnv, docName, fileName strin if err != nil { return err } + sqlCtx, err := eng.NewLocalContext(ctx) + if err != nil { + return err + } + defer sql.SessionEnd(sqlCtx.Session) + sql.SessionCommandBegin(sqlCtx.Session) + defer sql.SessionCommandEnd(sqlCtx.Session) + sqlCtx.SetCurrentDatabase(dbName) - err = writeDocToTable(ctx, eng, dbName, docName, string(update)) + err = writeDocToTable(sqlCtx, eng, docName, string(update)) if err != nil { return err } @@ -131,30 +139,21 @@ const ( writeDocTemplate = `REPLACE INTO dolt_docs VALUES ("%s", "%s")` ) -func writeDocToTable(ctx context.Context, eng *engine.SqlEngine, dbName, docName, content string) error { +func writeDocToTable(ctx *sql.Context, eng *engine.SqlEngine, docName, content string) error { var ( - sctx *sql.Context - err error + err error ) - sctx, err = eng.NewDefaultContext(ctx) + err = ctx.Session.SetSessionVariable(ctx, sql.AutoCommitSessionVar, 1) if err != nil { return err } - - sctx.SetCurrentDatabase(dbName) - - err = sctx.Session.SetSessionVariable(sctx, sql.AutoCommitSessionVar, 1) - if err != nil { - return err - } - - sctx.Session.SetClient(sql.Client{User: "root", Address: "%", Capabilities: 0}) + ctx.Session.SetClient(sql.Client{User: "root", Address: "%", Capabilities: 0}) content = strings.ReplaceAll(content, `"`, `\"`) update := fmt.Sprintf(writeDocTemplate, docName, content) - return execQuery(sctx, eng, update) + return execQuery(ctx, eng, update) } func execQuery(sctx *sql.Context, eng *engine.SqlEngine, q string) (err error) { diff --git a/go/cmd/dolt/commands/docscmds/write.go b/go/cmd/dolt/commands/docscmds/write.go index f9c53d780a0..634446c8e57 100644 --- a/go/cmd/dolt/commands/docscmds/write.go +++ b/go/cmd/dolt/commands/docscmds/write.go @@ -93,8 +93,16 @@ func writeDoltDoc(ctx context.Context, dEnv *env.DoltEnv, docName string) error if err != nil { return err } + sqlCtx, err := eng.NewLocalContext(ctx) + if err != nil { + return err + } + defer sql.SessionEnd(sqlCtx.Session) + sql.SessionCommandBegin(sqlCtx.Session) + defer sql.SessionCommandEnd(sqlCtx.Session) + sqlCtx.SetCurrentDatabase(dbName) - doc, err := readDocFromTable(ctx, eng, dbName, docName) + doc, err := readDocFromTable(sqlCtx, eng, docName) if err != nil { return err } @@ -108,13 +116,12 @@ const ( "FROM dolt_docs %s WHERE " + doltdb.DocPkColumnName + " = '%s'" ) -func readDocFromTable(ctx context.Context, eng *engine.SqlEngine, dbName, docName string) (string, error) { - return readDocFromTableAsOf(ctx, eng, dbName, docName, "") +func readDocFromTable(ctx *sql.Context, eng *engine.SqlEngine, docName string) (string, error) { + return readDocFromTableAsOf(ctx, eng, docName, "") } -func readDocFromTableAsOf(ctx context.Context, eng *engine.SqlEngine, dbName, docName, asOf string) (doc string, err error) { +func readDocFromTableAsOf(ctx *sql.Context, eng *engine.SqlEngine, docName, asOf string) (doc string, err error) { var ( - sctx *sql.Context iter sql.RowIter row sql.Row ) @@ -124,13 +131,7 @@ func readDocFromTableAsOf(ctx context.Context, eng *engine.SqlEngine, dbName, do } query := fmt.Sprintf(readDocTemplate, asOf, docName) - sctx, err = eng.NewLocalContext(ctx) - if err != nil { - return "", err - } - sctx.SetCurrentDatabase(dbName) - - _, iter, _, err = eng.Query(sctx, query) + _, iter, _, err = eng.Query(ctx, query) if sql.ErrTableNotFound.Is(err) { return "", errors.New("no dolt docs in this database") } @@ -139,12 +140,12 @@ func readDocFromTableAsOf(ctx context.Context, eng *engine.SqlEngine, dbName, do } defer func() { - if cerr := iter.Close(sctx); err == nil { + if cerr := iter.Close(ctx); err == nil { err = cerr } }() - row, err = iter.Next(sctx) + row, err = iter.Next(ctx) if err == io.EOF { // doc does not exist return "", nil @@ -155,7 +156,7 @@ func readDocFromTableAsOf(ctx context.Context, eng *engine.SqlEngine, dbName, do doc = row[0].(string) - _, eof := iter.Next(sctx) + _, eof := iter.Next(ctx) if eof != io.EOF && eof != nil { return "", eof } diff --git a/go/libraries/doltcore/merge/schema_merge_test.go b/go/libraries/doltcore/merge/schema_merge_test.go index 38ae03d30b5..3d2d98eeb07 100644 --- a/go/libraries/doltcore/merge/schema_merge_test.go +++ b/go/libraries/doltcore/merge/schema_merge_test.go @@ -1836,6 +1836,9 @@ func sch(definition string) namedSchema { root, _ := doltdb.EmptyRootValue(ctx, vrw, ns) eng, dbName, _ := engine.NewSqlEngineForEnv(ctx, denv) sqlCtx, _ := eng.NewDefaultContext(ctx) + defer sql.SessionEnd(sqlCtx.Session) + sql.SessionCommandBegin(sqlCtx.Session) + defer sql.SessionCommandEnd(sqlCtx.Session) sqlCtx.SetCurrentDatabase(dbName) // TODO: ParseCreateTableStatement silently drops any indexes or check constraints in the definition name, s, err := sqlutil.ParseCreateTableStatement(sqlCtx, root, eng.GetUnderlyingEngine(), definition) diff --git a/go/libraries/doltcore/schema/encoding/integration_test.go b/go/libraries/doltcore/schema/encoding/integration_test.go index 935ea5471db..4fc4da624ec 100644 --- a/go/libraries/doltcore/schema/encoding/integration_test.go +++ b/go/libraries/doltcore/schema/encoding/integration_test.go @@ -19,6 +19,7 @@ import ( "strings" "testing" + "github.com/dolthub/go-mysql-server/sql" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -78,6 +79,9 @@ func parseSchemaString(t *testing.T, s string) schema.Schema { require.NoError(t, err) sqlCtx, err := eng.NewDefaultContext(ctx) require.NoError(t, err) + defer sql.SessionEnd(sqlCtx.Session) + sql.SessionCommandBegin(sqlCtx.Session) + defer sql.SessionCommandEnd(sqlCtx.Session) sqlCtx.SetCurrentDatabase(db) _, sch, err := sqlutil.ParseCreateTableStatement(sqlCtx, root, eng.GetUnderlyingEngine(), s) require.NoError(t, err)