Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions go/cmd/dolt/commands/assist.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,9 @@ func (a *Assist) Exec(ctx context.Context, commandStr string, args []string, dEn
if err != nil {
return 1
}
defer sql.SessionEnd(sqlCtx.Session)
sql.SessionCommandBegin(sqlCtx.Session)
defer sql.SessionCommandEnd(sqlCtx.Session)
sqlCtx.SetCurrentDatabase(dbName)

a.messages, err = getInitialPrompt(sqlCtx, sqlEng, dEnv)
Expand Down
26 changes: 14 additions & 12 deletions go/cmd/dolt/commands/cvcmds/verify_constraints.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,9 +127,17 @@ func (cmd VerifyConstraintsCmd) Exec(ctx context.Context, commandStr string, arg
if err != nil {
return commands.HandleVErrAndExitCode(errhand.BuildDError("Failed to build sql engine.").AddCause(err).Build(), nil)
}
sqlCtx, err := eng.NewLocalContext(ctx)
if err != nil {
return commands.HandleVErrAndExitCode(errhand.BuildDError("Failed to build sql context.").AddCause(err).Build(), nil)
}
defer sql.SessionEnd(sqlCtx.Session)
sql.SessionCommandBegin(sqlCtx.Session)
defer sql.SessionCommandEnd(sqlCtx.Session)
sqlCtx.SetCurrentDatabase(dbName)

for _, tableName := range tablesWithViolations.AsSortedSlice() {
tbl, ok, err := endRoot.GetTable(ctx, tableName)
tbl, ok, err := endRoot.GetTable(sqlCtx, tableName)
if err != nil {
return commands.HandleVErrAndExitCode(errhand.BuildDError("Error loading table.").AddCause(err).Build(), nil)
}
Expand All @@ -138,14 +146,14 @@ func (cmd VerifyConstraintsCmd) Exec(ctx context.Context, commandStr string, arg
}
cli.Println("")
cli.Println(doltdb.DoltConstViolTablePrefix + tableName.Name)
dErr := printViolationsForTable(ctx, dbName, tableName.Name, tbl, eng)
dErr := printViolationsForTable(sqlCtx, dbName, tableName.Name, tbl, eng)
if dErr != nil {
return commands.HandleVErrAndExitCode(dErr, nil)
}
}

if outputOnly {
err = dEnv.UpdateWorkingRoot(ctx, working)
err = dEnv.UpdateWorkingRoot(sqlCtx, working)
if err != nil {
return commands.HandleVErrAndExitCode(errhand.BuildDError("Unable to undo written constraint violations").AddCause(err).Build(), nil)
}
Expand All @@ -157,7 +165,7 @@ func (cmd VerifyConstraintsCmd) Exec(ctx context.Context, commandStr string, arg
return 0
}

func printViolationsForTable(ctx context.Context, dbName, tblName string, tbl *doltdb.Table, eng *engine.SqlEngine) errhand.VerboseError {
func printViolationsForTable(ctx *sql.Context, dbName, tblName string, tbl *doltdb.Table, eng *engine.SqlEngine) errhand.VerboseError {
sch, err := tbl.GetSchema(ctx)
if err != nil {
return errhand.BuildDError("Error loading table schema").AddCause(err).Build()
Expand All @@ -166,20 +174,14 @@ func printViolationsForTable(ctx context.Context, dbName, tblName string, tbl *d
colNames := strings.Join(sch.GetAllCols().GetColumnNames(), ", ")
query := fmt.Sprintf("SELECT violation_type, %s, violation_info from dolt_constraint_violations_%s", colNames, tblName)

sCtx, err := eng.NewLocalContext(ctx)
if err != nil {
return errhand.BuildDError("Error making sql context").AddCause(err).Build()
}
sCtx.SetCurrentDatabase(dbName)

sqlSch, sqlItr, _, err := eng.Query(sCtx, query)
sqlSch, sqlItr, _, err := eng.Query(ctx, query)
if err != nil {
return errhand.BuildDError("Error querying constraint violations").AddCause(err).Build()
}

limitItr := &sqlLimitIter{itr: sqlItr, limit: 50}

err = engine.PrettyPrintResults(sCtx, engine.FormatTabular, sqlSch, limitItr, false)
err = engine.PrettyPrintResults(ctx, engine.FormatTabular, sqlSch, limitItr, false)
if err != nil {
return errhand.BuildDError("Error outputting rows").AddCause(err).Build()
}
Expand Down
13 changes: 11 additions & 2 deletions go/cmd/dolt/commands/docscmds/diff.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"context"

textdiff "github.com/andreyvit/diff"
"github.com/dolthub/go-mysql-server/sql"

"github.com/dolthub/dolt/go/cmd/dolt/cli"
"github.com/dolthub/dolt/go/cmd/dolt/commands"
Expand Down Expand Up @@ -89,13 +90,21 @@ func diffDoltDoc(ctx context.Context, dEnv *env.DoltEnv, docName string) error {
if err != nil {
return err
}
sqlCtx, err := eng.NewLocalContext(ctx)
if err != nil {
return err
}
defer sql.SessionEnd(sqlCtx.Session)
sql.SessionCommandBegin(sqlCtx.Session)
defer sql.SessionCommandEnd(sqlCtx.Session)
sqlCtx.SetCurrentDatabase(dbName)

working, err := readDocFromTable(ctx, eng, dbName, docName)
working, err := readDocFromTable(sqlCtx, eng, docName)
if err != nil {
return err
}

head, err := readDocFromTableAsOf(ctx, eng, dbName, docName, "HEAD")
head, err := readDocFromTableAsOf(sqlCtx, eng, docName, "HEAD")
if err != nil {
return err
}
Expand Down
29 changes: 14 additions & 15 deletions go/cmd/dolt/commands/docscmds/read.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,16 @@ func readDoltDoc(ctx context.Context, dEnv *env.DoltEnv, docName, fileName strin
if err != nil {
return err
}
sqlCtx, err := eng.NewLocalContext(ctx)
if err != nil {
return err
}
defer sql.SessionEnd(sqlCtx.Session)
sql.SessionCommandBegin(sqlCtx.Session)
defer sql.SessionCommandEnd(sqlCtx.Session)
sqlCtx.SetCurrentDatabase(dbName)

err = writeDocToTable(ctx, eng, dbName, docName, string(update))
err = writeDocToTable(sqlCtx, eng, docName, string(update))
if err != nil {
return err
}
Expand All @@ -131,30 +139,21 @@ const (
writeDocTemplate = `REPLACE INTO dolt_docs VALUES ("%s", "%s")`
)

func writeDocToTable(ctx context.Context, eng *engine.SqlEngine, dbName, docName, content string) error {
func writeDocToTable(ctx *sql.Context, eng *engine.SqlEngine, docName, content string) error {
var (
sctx *sql.Context
err error
err error
)

sctx, err = eng.NewDefaultContext(ctx)
err = ctx.Session.SetSessionVariable(ctx, sql.AutoCommitSessionVar, 1)
if err != nil {
return err
}

sctx.SetCurrentDatabase(dbName)

err = sctx.Session.SetSessionVariable(sctx, sql.AutoCommitSessionVar, 1)
if err != nil {
return err
}

sctx.Session.SetClient(sql.Client{User: "root", Address: "%", Capabilities: 0})
ctx.Session.SetClient(sql.Client{User: "root", Address: "%", Capabilities: 0})

content = strings.ReplaceAll(content, `"`, `\"`)
update := fmt.Sprintf(writeDocTemplate, docName, content)

return execQuery(sctx, eng, update)
return execQuery(ctx, eng, update)
}

func execQuery(sctx *sql.Context, eng *engine.SqlEngine, q string) (err error) {
Expand Down
31 changes: 16 additions & 15 deletions go/cmd/dolt/commands/docscmds/write.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,16 @@ func writeDoltDoc(ctx context.Context, dEnv *env.DoltEnv, docName string) error
if err != nil {
return err
}
sqlCtx, err := eng.NewLocalContext(ctx)
if err != nil {
return err
}
defer sql.SessionEnd(sqlCtx.Session)
sql.SessionCommandBegin(sqlCtx.Session)
defer sql.SessionCommandEnd(sqlCtx.Session)
sqlCtx.SetCurrentDatabase(dbName)

doc, err := readDocFromTable(ctx, eng, dbName, docName)
doc, err := readDocFromTable(sqlCtx, eng, docName)
if err != nil {
return err
}
Expand All @@ -108,13 +116,12 @@ const (
"FROM dolt_docs %s WHERE " + doltdb.DocPkColumnName + " = '%s'"
)

func readDocFromTable(ctx context.Context, eng *engine.SqlEngine, dbName, docName string) (string, error) {
return readDocFromTableAsOf(ctx, eng, dbName, docName, "")
func readDocFromTable(ctx *sql.Context, eng *engine.SqlEngine, docName string) (string, error) {
return readDocFromTableAsOf(ctx, eng, docName, "")
}

func readDocFromTableAsOf(ctx context.Context, eng *engine.SqlEngine, dbName, docName, asOf string) (doc string, err error) {
func readDocFromTableAsOf(ctx *sql.Context, eng *engine.SqlEngine, docName, asOf string) (doc string, err error) {
var (
sctx *sql.Context
iter sql.RowIter
row sql.Row
)
Expand All @@ -124,13 +131,7 @@ func readDocFromTableAsOf(ctx context.Context, eng *engine.SqlEngine, dbName, do
}
query := fmt.Sprintf(readDocTemplate, asOf, docName)

sctx, err = eng.NewLocalContext(ctx)
if err != nil {
return "", err
}
sctx.SetCurrentDatabase(dbName)

_, iter, _, err = eng.Query(sctx, query)
_, iter, _, err = eng.Query(ctx, query)
if sql.ErrTableNotFound.Is(err) {
return "", errors.New("no dolt docs in this database")
}
Expand All @@ -139,12 +140,12 @@ func readDocFromTableAsOf(ctx context.Context, eng *engine.SqlEngine, dbName, do
}

defer func() {
if cerr := iter.Close(sctx); err == nil {
if cerr := iter.Close(ctx); err == nil {
err = cerr
}
}()

row, err = iter.Next(sctx)
row, err = iter.Next(ctx)
if err == io.EOF {
// doc does not exist
return "", nil
Expand All @@ -155,7 +156,7 @@ func readDocFromTableAsOf(ctx context.Context, eng *engine.SqlEngine, dbName, do

doc = row[0].(string)

_, eof := iter.Next(sctx)
_, eof := iter.Next(ctx)
if eof != io.EOF && eof != nil {
return "", eof
}
Expand Down
3 changes: 3 additions & 0 deletions go/libraries/doltcore/merge/schema_merge_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1836,6 +1836,9 @@ func sch(definition string) namedSchema {
root, _ := doltdb.EmptyRootValue(ctx, vrw, ns)
eng, dbName, _ := engine.NewSqlEngineForEnv(ctx, denv)
sqlCtx, _ := eng.NewDefaultContext(ctx)
defer sql.SessionEnd(sqlCtx.Session)
sql.SessionCommandBegin(sqlCtx.Session)
defer sql.SessionCommandEnd(sqlCtx.Session)
sqlCtx.SetCurrentDatabase(dbName)
// TODO: ParseCreateTableStatement silently drops any indexes or check constraints in the definition
name, s, err := sqlutil.ParseCreateTableStatement(sqlCtx, root, eng.GetUnderlyingEngine(), definition)
Expand Down
4 changes: 4 additions & 0 deletions go/libraries/doltcore/schema/encoding/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"strings"
"testing"

"github.com/dolthub/go-mysql-server/sql"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

Expand Down Expand Up @@ -78,6 +79,9 @@ func parseSchemaString(t *testing.T, s string) schema.Schema {
require.NoError(t, err)
sqlCtx, err := eng.NewDefaultContext(ctx)
require.NoError(t, err)
defer sql.SessionEnd(sqlCtx.Session)
sql.SessionCommandBegin(sqlCtx.Session)
defer sql.SessionCommandEnd(sqlCtx.Session)
sqlCtx.SetCurrentDatabase(db)
_, sch, err := sqlutil.ParseCreateTableStatement(sqlCtx, root, eng.GetUnderlyingEngine(), s)
require.NoError(t, err)
Expand Down
Loading