Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 27 additions & 53 deletions go/cmd/dolt/commands/dump.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"path/filepath"
"strings"

sqle "github.com/dolthub/go-mysql-server"
"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/plan"
"github.com/dolthub/go-mysql-server/sql/planbuilder"
Expand Down Expand Up @@ -142,6 +143,20 @@ func (cmd DumpCmd) Exec(ctx context.Context, commandStr string, args []string, d
return HandleVErrAndExitCode(vErr, usage)
}

engine, dbName, berr := engine.NewSqlEngineForEnv(ctx, dEnv)
if berr != nil {
return HandleVErrAndExitCode(errhand.VerboseErrorFromError(berr), usage)
}
defer engine.Close()
sqlCtx, berr := engine.NewLocalContext(ctx)
if berr != nil {
return HandleVErrAndExitCode(errhand.VerboseErrorFromError(berr), usage)
}
defer sql.SessionEnd(sqlCtx.Session)
sql.SessionCommandBegin(sqlCtx.Session)
defer sql.SessionCommandEnd(sqlCtx.Session)
sqlCtx.SetCurrentDatabase(dbName)

switch resFormat {
case emptyFileExt, sqlFileExt:
var defaultName string
Expand All @@ -166,10 +181,6 @@ func (cmd DumpCmd) Exec(ctx context.Context, commandStr string, args []string, d
}

if !apr.Contains(noCreateDbFlag) {
dbName, err := getActiveDatabaseName(ctx, dEnv)
if err != nil {
return HandleVErrAndExitCode(err, usage)
}
err = addCreateDatabaseHeader(dEnv, fPath, dbName)
if err != nil {
return HandleVErrAndExitCode(err, usage)
Expand All @@ -183,18 +194,18 @@ func (cmd DumpCmd) Exec(ctx context.Context, commandStr string, args []string, d

for _, tbl := range tblNames {
tblOpts := newTableArgs(tbl, dumpOpts.dest, !apr.Contains(noBatchFlag), apr.Contains(noAutocommitFlag), schemaOnly)
err = dumpTable(ctx, dEnv, tblOpts, fPath)
err = dumpTable(sqlCtx, dEnv, engine.GetUnderlyingEngine(), root, tblOpts, fPath)
if err != nil {
return HandleVErrAndExitCode(err, usage)
}
}

err = dumpSchemaElements(ctx, dEnv, fPath)
err = dumpSchemaElements(sqlCtx, engine, root, dEnv.FS, fPath)
if err != nil {
return HandleVErrAndExitCode(err, usage)
}
case csvFileExt, jsonFileExt, parquetFileExt:
err = dumpNonSqlTables(ctx, root, dEnv, force, tblNames, resFormat, outputFileOrDirName, false)
err = dumpNonSqlTables(sqlCtx, engine.GetUnderlyingEngine(), root, dEnv, force, tblNames, resFormat, outputFileOrDirName, false)
if err != nil {
return HandleVErrAndExitCode(errhand.VerboseErrorFromError(err), usage)
}
Expand All @@ -208,39 +219,23 @@ func (cmd DumpCmd) Exec(ctx context.Context, commandStr string, args []string, d
}

// dumpSchemaElements writes the non-table schema elements (views, triggers, procedures) to the file path given
func dumpSchemaElements(ctx context.Context, dEnv *env.DoltEnv, path string) errhand.VerboseError {
writer, err := dEnv.FS.OpenForWriteAppend(path, os.ModePerm)
func dumpSchemaElements(ctx *sql.Context, eng *engine.SqlEngine, root doltdb.RootValue, fs filesys.Filesys, path string) errhand.VerboseError {
writer, err := fs.OpenForWriteAppend(path, os.ModePerm)
if err != nil {
return errhand.VerboseErrorFromError(err)
}

engine, dbName, err := engine.NewSqlEngineForEnv(ctx, dEnv)
err = dumpViews(ctx, eng, root, writer)
if err != nil {
return errhand.VerboseErrorFromError(err)
}

sqlCtx, err := engine.NewLocalContext(ctx)
err = dumpTriggers(ctx, eng, root, writer)
if err != nil {
return errhand.VerboseErrorFromError(err)
}
sqlCtx.SetCurrentDatabase(dbName)

root, err := dEnv.WorkingRoot(ctx)
if err != nil {
return errhand.VerboseErrorFromError(err)
}

err = dumpViews(sqlCtx, engine, root, writer)
if err != nil {
return errhand.VerboseErrorFromError(err)
}

err = dumpTriggers(sqlCtx, engine, root, writer)
if err != nil {
return errhand.VerboseErrorFromError(err)
}

err = dumpProcedures(sqlCtx, engine, root, writer)
err = dumpProcedures(ctx, eng, root, writer)
if err != nil {
return errhand.VerboseErrorFromError(err)
}
Expand Down Expand Up @@ -560,8 +555,8 @@ func (m dumpOptions) DumpDestName() string {
}

// dumpTable dumps table in file given specific table and file location info
func dumpTable(ctx context.Context, dEnv *env.DoltEnv, tblOpts *tableOptions, filePath string) errhand.VerboseError {
rd, err := mvdata.NewSqlEngineReader(ctx, dEnv, tblOpts.tableName)
func dumpTable(ctx *sql.Context, dEnv *env.DoltEnv, engine *sqle.Engine, root doltdb.RootValue, tblOpts *tableOptions, filePath string) errhand.VerboseError {
rd, err := mvdata.NewSqlEngineReader(ctx, engine, root, tblOpts.tableName)
if err != nil {
return errhand.BuildDError("Error creating reader for %s.", tblOpts.SrcName()).AddCause(err).Build()
}
Expand Down Expand Up @@ -738,7 +733,7 @@ func newTableArgs(tblName string, destination mvdata.DataLocation, batched, auto

// dumpNonSqlTables returns nil if all tables is dumped successfully, and it returns err if there is one.
// It handles only csv and json file types(rf).
func dumpNonSqlTables(ctx context.Context, root doltdb.RootValue, dEnv *env.DoltEnv, force bool, tblNames []string, rf string, dirName string, batched bool) errhand.VerboseError {
func dumpNonSqlTables(ctx *sql.Context, engine *sqle.Engine, root doltdb.RootValue, dEnv *env.DoltEnv, force bool, tblNames []string, rf string, dirName string, batched bool) errhand.VerboseError {
var fName string
if dirName == emptyStr {
dirName = "doltdump/"
Expand All @@ -759,7 +754,7 @@ func dumpNonSqlTables(ctx context.Context, root doltdb.RootValue, dEnv *env.Dolt

tblOpts := newTableArgs(tbl, dumpOpts.dest, batched, false, false)

err = dumpTable(ctx, dEnv, tblOpts, fPath)
err = dumpTable(ctx, dEnv, engine, root, tblOpts, fPath)
if err != nil {
return err
}
Expand Down Expand Up @@ -809,24 +804,3 @@ func addCreateDatabaseHeader(dEnv *env.DoltEnv, fPath, dbName string) errhand.Ve

return nil
}

// TODO: find a more elegant way to get database name, possibly implement a method in DoltEnv
// getActiveDatabaseName returns the name of the current active database
func getActiveDatabaseName(ctx context.Context, dEnv *env.DoltEnv) (string, errhand.VerboseError) {
mrEnv, err := env.MultiEnvForDirectory(ctx, dEnv.Config.WriteableConfig(), dEnv.FS, dEnv.Version, dEnv)
if err != nil {
return "", errhand.VerboseErrorFromError(err)
}

// Choose the first DB as the current one. This will be the DB in the working dir if there was one there
var dbName string
err = mrEnv.Iter(func(name string, _ *env.DoltEnv) (stop bool, err error) {
dbName = name
return true, nil
})
if err != nil {
return "", errhand.VerboseErrorFromError(err)
}

return dbName, nil
}
17 changes: 12 additions & 5 deletions go/cmd/dolt/commands/engine/sqlengine.go
Original file line number Diff line number Diff line change
Expand Up @@ -470,21 +470,28 @@ func doltSessionFactory(pro *dsqle.DoltDatabaseProvider, statsPro sql.StatsProvi
}
}

type ConfigOption func(*SqlEngineConfig)

// NewSqlEngineForEnv returns a SqlEngine configured for the environment provided, with a single root user.
// Returns the new engine, the first database name, and any error that occurred.
func NewSqlEngineForEnv(ctx context.Context, dEnv *env.DoltEnv) (*SqlEngine, string, error) {
func NewSqlEngineForEnv(ctx context.Context, dEnv *env.DoltEnv, options ...ConfigOption) (*SqlEngine, string, error) {
mrEnv, err := env.MultiEnvForDirectory(ctx, dEnv.Config.WriteableConfig(), dEnv.FS, dEnv.Version, dEnv)
if err != nil {
return nil, "", err
}

config := &SqlEngineConfig{
ServerUser: "root",
ServerHost: "localhost",
}
for _, opt := range options {
opt(config)
}

engine, err := NewSqlEngine(
ctx,
mrEnv,
&SqlEngineConfig{
ServerUser: "root",
ServerHost: "localhost",
},
config,
)

return engine, mrEnv.GetFirstDatabase(), err
Expand Down
18 changes: 17 additions & 1 deletion go/cmd/dolt/commands/tblcmds/export.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,12 @@ import (
"os"
"path/filepath"

"github.com/dolthub/go-mysql-server/sql"
"github.com/fatih/color"

"github.com/dolthub/dolt/go/cmd/dolt/cli"
"github.com/dolthub/dolt/go/cmd/dolt/commands"
"github.com/dolthub/dolt/go/cmd/dolt/commands/engine"
"github.com/dolthub/dolt/go/cmd/dolt/errhand"
eventsapi "github.com/dolthub/dolt/go/gen/proto/dolt/services/eventsapi/v1alpha1"
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
Expand Down Expand Up @@ -197,7 +199,21 @@ func (cmd ExportCmd) Exec(ctx context.Context, commandStr string, args []string,
return commands.HandleVErrAndExitCode(verr, usage)
}

rd, err := mvdata.NewSqlEngineReader(ctx, dEnv, exOpts.tableName)
engine, dbName, berr := engine.NewSqlEngineForEnv(ctx, dEnv)
if berr != nil {
return commands.HandleVErrAndExitCode(errhand.VerboseErrorFromError(berr), usage)
}
defer engine.Close()
sqlCtx, berr := engine.NewLocalContext(ctx)
if berr != nil {
return commands.HandleVErrAndExitCode(errhand.VerboseErrorFromError(berr), usage)
}
defer sql.SessionEnd(sqlCtx.Session)
sql.SessionCommandBegin(sqlCtx.Session)
defer sql.SessionCommandEnd(sqlCtx.Session)
sqlCtx.SetCurrentDatabase(dbName)

rd, err := mvdata.NewSqlEngineReader(sqlCtx, engine.GetUnderlyingEngine(), root, exOpts.tableName)
if err != nil {
return commands.HandleVErrAndExitCode(errhand.BuildDError("Error creating reader for %s.", exOpts.SrcName()).AddCause(err).Build(), usage)
}
Expand Down
67 changes: 46 additions & 21 deletions go/cmd/dolt/commands/tblcmds/import.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"strings"
"sync/atomic"

sqle "github.com/dolthub/go-mysql-server"
"github.com/dolthub/go-mysql-server/sql"
gmstypes "github.com/dolthub/go-mysql-server/sql/types"
"github.com/dolthub/vitess/go/sqltypes"
Expand All @@ -35,6 +36,7 @@ import (

"github.com/dolthub/dolt/go/cmd/dolt/cli"
"github.com/dolthub/dolt/go/cmd/dolt/commands"
"github.com/dolthub/dolt/go/cmd/dolt/commands/engine"
"github.com/dolthub/dolt/go/cmd/dolt/commands/schcmds"
"github.com/dolthub/dolt/go/cmd/dolt/errhand"
eventsapi "github.com/dolthub/dolt/go/gen/proto/dolt/services/eventsapi/v1alpha1"
Expand Down Expand Up @@ -180,7 +182,7 @@ func (m importOptions) srcIsStream() bool {
return isStream
}

func getImportMoveOptions(ctx context.Context, apr *argparser.ArgParseResults, dEnv *env.DoltEnv) (*importOptions, errhand.VerboseError) {
func getImportMoveOptions(ctx *sql.Context, apr *argparser.ArgParseResults, dEnv *env.DoltEnv, engine *sqle.Engine) (*importOptions, errhand.VerboseError) {
tableName := apr.Arg(0)

path := ""
Expand Down Expand Up @@ -225,9 +227,19 @@ func getImportMoveOptions(ctx context.Context, apr *argparser.ArgParseResults, d
// table name must match sheet name currently
srcOpts = mvdata.XlsxOptions{SheetName: tableName}
} else if val.Format == mvdata.JsonFile {
srcOpts = mvdata.JSONOptions{TableName: tableName, SchFile: schemaFile}
opts := mvdata.JSONOptions{TableName: tableName, SchFile: schemaFile}
if schemaFile != "" {
opts.SqlCtx = ctx
opts.Engine = engine
}
srcOpts = opts
} else if val.Format == mvdata.ParquetFile {
srcOpts = mvdata.ParquetOptions{TableName: tableName, SchFile: schemaFile}
opts := mvdata.ParquetOptions{TableName: tableName, SchFile: schemaFile}
if schemaFile != "" {
opts.SqlCtx = ctx
opts.Engine = engine
}
srcOpts = opts
}

case mvdata.StreamDataLocation:
Expand Down Expand Up @@ -417,30 +429,48 @@ func (cmd ImportCmd) Exec(ctx context.Context, commandStr string, args []string,
return commands.HandleVErrAndExitCode(verr, usage)
}

mvOpts, verr := getImportMoveOptions(ctx, apr, dEnv)
if verr != nil {
root, err := dEnv.WorkingRoot(ctx)
if err != nil {
verr = errhand.BuildDError("Unable to get the working root value for this data repository.").AddCause(err).Build()
return commands.HandleVErrAndExitCode(verr, usage)
}

root, err := dEnv.WorkingRoot(ctx)
eng, dbName, err := engine.NewSqlEngineForEnv(ctx, dEnv, func(cfg *engine.SqlEngineConfig) {
cfg.Autocommit = false
cfg.Bulk = true
})
if err != nil {
verr = errhand.BuildDError("Unable to get the working root value for this data repository.").AddCause(err).Build()
verr = errhand.BuildDError("could not build sql engine for import").AddCause(err).Build()
return commands.HandleVErrAndExitCode(verr, usage)
}
sqlCtx, err := eng.NewLocalContext(ctx)
if err != nil {
verr = errhand.BuildDError("could not build sql context for import").AddCause(err).Build()
return commands.HandleVErrAndExitCode(verr, usage)
}
defer sql.SessionEnd(sqlCtx.Session)
sql.SessionCommandBegin(sqlCtx.Session)
defer sql.SessionCommandEnd(sqlCtx.Session)
sqlCtx.SetCurrentDatabase(dbName)

mvOpts, verr := getImportMoveOptions(sqlCtx, apr, dEnv, eng.GetUnderlyingEngine())
if verr != nil {
return commands.HandleVErrAndExitCode(verr, usage)
}

rd, nDMErr := newImportDataReader(ctx, root, dEnv, mvOpts)
rd, nDMErr := newImportDataReader(sqlCtx, root, dEnv, mvOpts)
if nDMErr != nil {
verr = newDataMoverErrToVerr(mvOpts, nDMErr)
return commands.HandleVErrAndExitCode(verr, usage)
}

wr, nDMErr := newImportSqlEngineMover(ctx, dEnv, rd.GetSchema(), mvOpts)
wr, nDMErr := newImportSqlEngineMover(sqlCtx, root, dEnv, rd.GetSchema(), eng.GetUnderlyingEngine(), mvOpts)
if nDMErr != nil {
verr = newDataMoverErrToVerr(mvOpts, nDMErr)
return commands.HandleVErrAndExitCode(verr, usage)
}

skipped, err := move(ctx, rd, wr, mvOpts)
skipped, err := move(sqlCtx, rd, wr, mvOpts)
if err != nil {
bdr := errhand.BuildDError("\nAn error occurred while moving data")
bdr.AddCause(err)
Expand Down Expand Up @@ -488,11 +518,11 @@ func newImportDataReader(ctx context.Context, root doltdb.RootValue, dEnv *env.D
return rd, nil
}

func newImportSqlEngineMover(ctx context.Context, dEnv *env.DoltEnv, rdSchema schema.Schema, imOpts *importOptions) (*mvdata.SqlEngineTableWriter, *mvdata.DataMoverCreationError) {
func newImportSqlEngineMover(ctx *sql.Context, root doltdb.RootValue, dEnv *env.DoltEnv, rdSchema schema.Schema, engine *sqle.Engine, imOpts *importOptions) (*mvdata.SqlEngineTableWriter, *mvdata.DataMoverCreationError) {
moveOps := &mvdata.MoverOptions{Force: imOpts.force, TableToWriteTo: imOpts.destTableName, ContinueOnErr: imOpts.contOnErr, Operation: imOpts.operation, DisableFks: imOpts.disableFkChecks}

// Returns the schema of the table to be created or the existing schema
tableSchema, dmce := getImportSchema(ctx, dEnv, imOpts)
tableSchema, dmce := getImportSchema(ctx, root, dEnv, engine, imOpts)
if dmce != nil {
return nil, dmce
}
Expand Down Expand Up @@ -537,7 +567,7 @@ func newImportSqlEngineMover(ctx context.Context, dEnv *env.DoltEnv, rdSchema sc
}
}

mv, err := mvdata.NewSqlEngineTableWriter(ctx, dEnv, tableSchema, rowOperationSchema, moveOps, importStatsCB)
mv, err := mvdata.NewSqlEngineTableWriter(ctx, engine, tableSchema, rowOperationSchema, moveOps, importStatsCB)
if err != nil {
return nil, &mvdata.DataMoverCreationError{ErrType: mvdata.CreateWriterErr, Cause: err}
}
Expand Down Expand Up @@ -687,9 +717,9 @@ func moveRows(
}
}

func getImportSchema(ctx context.Context, dEnv *env.DoltEnv, impOpts *importOptions) (schema.Schema, *mvdata.DataMoverCreationError) {
func getImportSchema(ctx *sql.Context, root doltdb.RootValue, dEnv *env.DoltEnv, engine *sqle.Engine, impOpts *importOptions) (schema.Schema, *mvdata.DataMoverCreationError) {
if impOpts.schFile != "" {
tn, out, err := mvdata.SchAndTableNameFromFile(ctx, impOpts.schFile, dEnv)
tn, out, err := mvdata.SchAndTableNameFromFile(ctx, impOpts.schFile, dEnv.FS, root, engine)
if err != nil {
return nil, &mvdata.DataMoverCreationError{ErrType: mvdata.SchemaErr, Cause: err}
}
Expand Down Expand Up @@ -728,11 +758,6 @@ func getImportSchema(ctx context.Context, dEnv *env.DoltEnv, impOpts *importOpti
return rd.GetSchema(), nil
}

root, err := dEnv.WorkingRoot(ctx)
if err != nil {
return nil, &mvdata.DataMoverCreationError{ErrType: mvdata.SchemaErr, Cause: err}
}

outSch, err := mvdata.InferSchema(ctx, root, rd, impOpts.destTableName, impOpts.primaryKeys, impOpts)
if err != nil {
return nil, &mvdata.DataMoverCreationError{ErrType: mvdata.SchemaErr, Cause: err}
Expand All @@ -742,7 +767,7 @@ func getImportSchema(ctx context.Context, dEnv *env.DoltEnv, impOpts *importOpti
}

// UpdateOp || ReplaceOp
tblRd, err := mvdata.NewSqlEngineReader(ctx, dEnv, impOpts.destTableName)
tblRd, err := mvdata.NewSqlEngineReader(ctx, engine, root, impOpts.destTableName)
if err != nil {
return nil, &mvdata.DataMoverCreationError{ErrType: mvdata.CreateReaderErr, Cause: err}
}
Expand Down
Loading
Loading