diff --git a/go/cmd/dolt/cli/query_helpers.go b/go/cmd/dolt/cli/query_helpers.go new file mode 100644 index 00000000000..2bcc213f8ba --- /dev/null +++ b/go/cmd/dolt/cli/query_helpers.go @@ -0,0 +1,79 @@ +// Copyright 2025 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cli + +import ( + "fmt" + + "github.com/dolthub/go-mysql-server/sql" +) + +// GetInt8ColAsBool returns the value of an int8 column as a bool +// This is necessary because Queryist may return an int8 column as a bool (when using SQLEngine) +// or as a string (when using ConnectionQueryist). +func GetInt8ColAsBool(col interface{}) (bool, error) { + switch v := col.(type) { + case int8: + return v != 0, nil + case string: + return v != "0", nil + default: + return false, fmt.Errorf("unexpected type %T, was expecting int8", v) + } +} + +// SetSystemVar sets the @@dolt_show_system_tables variable if necessary, and returns a function +// resetting the variable for after the commands completion, if necessary. +func SetSystemVar(queryist Queryist, sqlCtx *sql.Context, newVal bool) (func() error, error) { + _, rowIter, _, err := queryist.Query(sqlCtx, "SHOW VARIABLES WHERE VARIABLE_NAME='dolt_show_system_tables'") + if err != nil { + return nil, err + } + + row, err := sql.RowIterToRows(sqlCtx, rowIter) + if err != nil { + return nil, err + } + prevVal, err := GetInt8ColAsBool(row[0][1]) + if err != nil { + return nil, err + } + + var update func() error + if newVal != prevVal { + query := fmt.Sprintf("SET @@dolt_show_system_tables = %t", newVal) + _, _, _, err = queryist.Query(sqlCtx, query) + update = func() error { + query := fmt.Sprintf("SET @@dolt_show_system_tables = %t", prevVal) + _, _, _, err := queryist.Query(sqlCtx, query) + return err + } + } + + return update, err +} + +func GetRowsForSql(queryist Queryist, sqlCtx *sql.Context, query string) ([]sql.Row, error) { + _, rowIter, _, err := queryist.Query(sqlCtx, query) + if err != nil { + return nil, err + } + rows, err := sql.RowIterToRows(sqlCtx, rowIter) + if err != nil { + return nil, err + } + + return rows, nil +} diff --git a/go/cmd/dolt/commands/checkout.go b/go/cmd/dolt/commands/checkout.go index 00fbc02ca6c..fdc632158b9 100644 --- a/go/cmd/dolt/commands/checkout.go +++ b/go/cmd/dolt/commands/checkout.go @@ -145,7 +145,7 @@ func (cmd CheckoutCmd) Exec(ctx context.Context, commandStr string, args []strin return HandleVErrAndExitCode(errhand.VerboseErrorFromError(err), usage) } - rows, err := GetRowsForSql(queryEngine, sqlCtx, sqlQuery) + rows, err := cli.GetRowsForSql(queryEngine, sqlCtx, sqlQuery) if err != nil { // In fringe cases the server can't start because the default branch doesn't exist, `dolt checkout ` diff --git a/go/cmd/dolt/commands/cherry-pick.go b/go/cmd/dolt/commands/cherry-pick.go index a18c29fd044..d5c8ae0c12f 100644 --- a/go/cmd/dolt/commands/cherry-pick.go +++ b/go/cmd/dolt/commands/cherry-pick.go @@ -142,12 +142,12 @@ func cherryPick(queryist cli.Queryist, sqlCtx *sql.Context, apr *argparser.ArgPa hint: commit your changes (dolt commit -am \"\") or reset them (dolt reset --hard) to proceed.`) } - _, err = GetRowsForSql(queryist, sqlCtx, "set @@dolt_allow_commit_conflicts = 1") + _, err = cli.GetRowsForSql(queryist, sqlCtx, "set @@dolt_allow_commit_conflicts = 1") if err != nil { return fmt.Errorf("error: failed to set @@dolt_allow_commit_conflicts: %w", err) } - _, err = GetRowsForSql(queryist, sqlCtx, "set @@dolt_force_transaction_commit = 1") + _, err = cli.GetRowsForSql(queryist, sqlCtx, "set @@dolt_force_transaction_commit = 1") if err != nil { return fmt.Errorf("error: failed to set @@dolt_force_transaction_commit: %w", err) } @@ -156,7 +156,7 @@ hint: commit your changes (dolt commit -am \"\") or reset them (dolt re if err != nil { return fmt.Errorf("error: failed to interpolate query: %w", err) } - rows, err := GetRowsForSql(queryist, sqlCtx, q) + rows, err := cli.GetRowsForSql(queryist, sqlCtx, q) if err != nil { errorText := err.Error() switch { @@ -225,7 +225,7 @@ hint: commit your changes (dolt commit -am \"\") or reset them (dolt re func cherryPickAbort(queryist cli.Queryist, sqlCtx *sql.Context) error { query := "call dolt_cherry_pick('--abort')" - _, err := GetRowsForSql(queryist, sqlCtx, query) + _, err := cli.GetRowsForSql(queryist, sqlCtx, query) if err != nil { errorText := err.Error() switch errorText { diff --git a/go/cmd/dolt/commands/ci/run.go b/go/cmd/dolt/commands/ci/run.go index eb3a605ad14..3701b45e0d9 100644 --- a/go/cmd/dolt/commands/ci/run.go +++ b/go/cmd/dolt/commands/ci/run.go @@ -148,7 +148,7 @@ func runCIQuery(queryist cli.Queryist, sqlCtx *sql.Context, step dolt_ci.Step, q return nil, fmt.Errorf("Could not find saved query: %s", step.SavedQueryName.Value) } - rows, err := commands.GetRowsForSql(queryist, sqlCtx, query) + rows, err := cli.GetRowsForSql(queryist, sqlCtx, query) if err != nil { statementErr := fmt.Sprintf("Ran query: %s", query) queryErr := fmt.Sprintf("Query error: %s", err.Error()) diff --git a/go/cmd/dolt/commands/ci/view.go b/go/cmd/dolt/commands/ci/view.go index 128e8634c75..d7147f37044 100644 --- a/go/cmd/dolt/commands/ci/view.go +++ b/go/cmd/dolt/commands/ci/view.go @@ -171,7 +171,7 @@ func updateConfigQueryStatements(config *dolt_ci.WorkflowConfig, savedQueries ma func getSavedQueries(sqlCtx *sql.Context, queryist cli.Queryist) (map[string]string, error) { savedQueries := make(map[string]string) - resetFunc, err := commands.SetSystemVar(queryist, sqlCtx, true) + resetFunc, err := cli.SetSystemVar(queryist, sqlCtx, true) if err != nil { return nil, err } diff --git a/go/cmd/dolt/commands/clean.go b/go/cmd/dolt/commands/clean.go index 9cd0a29f1e3..726c7ef9172 100644 --- a/go/cmd/dolt/commands/clean.go +++ b/go/cmd/dolt/commands/clean.go @@ -116,7 +116,7 @@ func (cmd CleanCmd) Exec(ctx context.Context, commandStr string, args []string, } } - _, err = GetRowsForSql(queryist, sqlCtx, query) + _, err = cli.GetRowsForSql(queryist, sqlCtx, query) if err != nil { cli.Println(err.Error()) return 1 diff --git a/go/cmd/dolt/commands/cnfcmds/auto_resolve.go b/go/cmd/dolt/commands/cnfcmds/auto_resolve.go index c45970c71fc..60d530b6af8 100644 --- a/go/cmd/dolt/commands/cnfcmds/auto_resolve.go +++ b/go/cmd/dolt/commands/cnfcmds/auto_resolve.go @@ -23,7 +23,6 @@ import ( "github.com/gocraft/dbr/v2/dialect" "github.com/dolthub/dolt/go/cmd/dolt/cli" - "github.com/dolthub/dolt/go/cmd/dolt/commands" ) type AutoResolveStrategy int @@ -53,7 +52,7 @@ func AutoResolveTables(queryist cli.Queryist, sqlCtx *sql.Context, strategy Auto if err != nil { return fmt.Errorf("error interpolating resolve conflicts query for table %s: %w", tableName, err) } - _, err = commands.GetRowsForSql(queryist, sqlCtx, q) + _, err = cli.GetRowsForSql(queryist, sqlCtx, q) if err != nil { return fmt.Errorf("error resolving conflicts for table %s: %w", tableName, err) } diff --git a/go/cmd/dolt/commands/cnfcmds/cat.go b/go/cmd/dolt/commands/cnfcmds/cat.go index 0f5a6b5754b..bcbb446dc6a 100644 --- a/go/cmd/dolt/commands/cnfcmds/cat.go +++ b/go/cmd/dolt/commands/cnfcmds/cat.go @@ -288,7 +288,7 @@ func writeConflictResults( func getMergeStatus(queryist cli.Queryist, sqlCtx *sql.Context) (mergeStatus, error) { ms := mergeStatus{} q := "select * from dolt_merge_status;" - rows, err := commands.GetRowsForSql(queryist, sqlCtx, q) + rows, err := cli.GetRowsForSql(queryist, sqlCtx, q) if err != nil { return ms, err } @@ -315,7 +315,7 @@ func getMergeStatus(queryist cli.Queryist, sqlCtx *sql.Context) (mergeStatus, er func getSchemaConflictsExist(queryist cli.Queryist, sqlCtx *sql.Context) (bool, error) { q := "select * from dolt_schema_conflicts limit 1;" - rows, err := commands.GetRowsForSql(queryist, sqlCtx, q) + rows, err := cli.GetRowsForSql(queryist, sqlCtx, q) if err != nil { return false, err } @@ -332,7 +332,7 @@ func getTableDataConflictsExist(queryist cli.Queryist, sqlCtx *sql.Context, tabl if err != nil { return false, err } - rows, err := commands.GetRowsForSql(queryist, sqlCtx, q) + rows, err := cli.GetRowsForSql(queryist, sqlCtx, q) if err != nil { return false, err } diff --git a/go/cmd/dolt/commands/diff.go b/go/cmd/dolt/commands/diff.go index abd1d26d470..a1e2476f69f 100644 --- a/go/cmd/dolt/commands/diff.go +++ b/go/cmd/dolt/commands/diff.go @@ -208,7 +208,7 @@ func (cmd DiffCmd) Exec(ctx context.Context, commandStr string, args []string, _ defer closeFunc() } - updateSystemVar, err := SetSystemVar(queryist, sqlCtx, apr.Contains(cli.SystemFlag)) + updateSystemVar, err := cli.SetSystemVar(queryist, sqlCtx, apr.Contains(cli.SystemFlag)) if err != nil { return HandleVErrAndExitCode(errhand.VerboseErrorFromError(err), usage) } @@ -382,7 +382,7 @@ func getTableNamesAtRef(queryist cli.Queryist, sqlCtx *sql.Context, ref string) if err != nil { return nil, fmt.Errorf("error interpolating query: %w", err) } - rows, err := GetRowsForSql(queryist, sqlCtx, q) + rows, err := cli.GetRowsForSql(queryist, sqlCtx, q) if err != nil { return nil, err } @@ -415,7 +415,7 @@ func getTableNamesAtRef(queryist cli.Queryist, sqlCtx *sql.Context, ref string) if err != nil { return nil, fmt.Errorf("error interpolating query: %w", err) } - result, err := GetRowsForSql(queryist, sqlCtx, interpolatedQuery) + result, err := cli.GetRowsForSql(queryist, sqlCtx, interpolatedQuery) if err != nil { return nil, fmt.Errorf("error getting system table %s: %w", sysTable, err) } @@ -546,7 +546,7 @@ func getCommonAncestor(queryist cli.Queryist, sqlCtx *sql.Context, c1, c2 string if err != nil { return "", fmt.Errorf("error interpolating query: %w", err) } - rows, err := GetRowsForSql(queryist, sqlCtx, q) + rows, err := cli.GetRowsForSql(queryist, sqlCtx, q) if err != nil { return "", err } @@ -675,7 +675,7 @@ func getSchemaDiffSummariesBetweenRefs(queryist cli.Queryist, sqlCtx *sql.Contex if err != nil { return nil, fmt.Errorf("error: unable to interpolate query: %w", err) } - schemaDiffRows, err := GetRowsForSql(queryist, sqlCtx, q) + schemaDiffRows, err := cli.GetRowsForSql(queryist, sqlCtx, q) if err != nil { return nil, fmt.Errorf("error: unable to get schema diff from %s to %s: %w", fromRef, toRef, err) } @@ -714,7 +714,7 @@ func getSchemaDiffSummariesBetweenRefs(queryist cli.Queryist, sqlCtx *sql.Contex if err != nil { return nil, fmt.Errorf("error: unable to interpolate dolt_patch query: %w", err) } - patchRows, err := GetRowsForSql(queryist, sqlCtx, q) + patchRows, err := cli.GetRowsForSql(queryist, sqlCtx, q) if err != nil { return nil, fmt.Errorf("error: unable to get dolt_patch rows from %s to %s: %w", fromRef, toRef, err) } @@ -741,7 +741,7 @@ func getSchemaDiffSummariesBetweenRefs(queryist cli.Queryist, sqlCtx *sql.Contex func getDiffSummariesBetweenRefs(queryist cli.Queryist, sqlCtx *sql.Context, fromRef, toRef string) ([]diff.TableDeltaSummary, error) { q, err := dbr.InterpolateForDialect("select * from dolt_diff_summary(?, ?)", []interface{}{fromRef, toRef}, dialect.MySQL) - dataDiffRows, err := GetRowsForSql(queryist, sqlCtx, q) + dataDiffRows, err := cli.GetRowsForSql(queryist, sqlCtx, q) if err != nil { return nil, fmt.Errorf("error: unable to get diff summary from %s to %s: %w", fromRef, toRef, err) } @@ -907,7 +907,7 @@ func getTableSchemaAtRef(queryist cli.Queryist, sqlCtx *sql.Context, tableName s if err != nil { return sch, createStmt, fmt.Errorf("error interpolating query: %w", err) } - rows, err = GetRowsForSql(queryist, sqlCtx, interpolatedQuery) + rows, err = cli.GetRowsForSql(queryist, sqlCtx, interpolatedQuery) if err != nil { return sch, createStmt, fmt.Errorf("error: unable to get create table statement for table '%s': %w", tableName, err) } @@ -952,7 +952,7 @@ func getDatabaseSchemaAtRef(queryist cli.Queryist, sqlCtx *sql.Context, tableNam if err != nil { return "", fmt.Errorf("error interpolating query: %w", err) } - rows, err = GetRowsForSql(queryist, sqlCtx, interpolatedQuery) + rows, err = cli.GetRowsForSql(queryist, sqlCtx, interpolatedQuery) if err != nil { return "", fmt.Errorf("error: unable to get create database statement for database '%s': %w", tableName, err) } @@ -1047,7 +1047,7 @@ func getTableDiffStats(queryist cli.Queryist, sqlCtx *sql.Context, tableName, fr if err != nil { return nil, fmt.Errorf("error interpolating query: %w", err) } - rows, err := GetRowsForSql(queryist, sqlCtx, q) + rows, err := cli.GetRowsForSql(queryist, sqlCtx, q) if err != nil { return nil, fmt.Errorf("error running diff stats query: %w", err) } diff --git a/go/cmd/dolt/commands/log.go b/go/cmd/dolt/commands/log.go index 673364e4444..a918b9acc85 100644 --- a/go/cmd/dolt/commands/log.go +++ b/go/cmd/dolt/commands/log.go @@ -117,7 +117,7 @@ func (cmd LogCmd) logWithLoggerFunc(ctx context.Context, commandStr string, args if err != nil { return handleErrAndExit(err) } - logRows, err := GetRowsForSql(queryist, sqlCtx, query) + logRows, err := cli.GetRowsForSql(queryist, sqlCtx, query) if err != nil { return handleErrAndExit(err) } @@ -140,7 +140,7 @@ func collectRevisions(apr *argparser.ArgParseResults, queryist cli.Queryist, sql tablesIndex++ revisions[arg] = true } else { - _, err := GetRowsForSql(queryist, sqlCtx, "select hashof('"+arg+"')") + _, err := cli.GetRowsForSql(queryist, sqlCtx, "select hashof('"+arg+"')") // Once we get a non-revision argument, we treat the remaining args as tables if _, ok := revisions[arg]; ok || err != nil { @@ -288,7 +288,7 @@ func getExistingTables(revisions []string, queryist cli.Queryist, sqlCtx *sql.Co } for _, rev := range revisions { - rows, err := GetRowsForSql(queryist, sqlCtx, "show tables as of '"+rev+"'") + rows, err := cli.GetRowsForSql(queryist, sqlCtx, "show tables as of '"+rev+"'") if err != nil { return nil, err } diff --git a/go/cmd/dolt/commands/ls.go b/go/cmd/dolt/commands/ls.go index ffbb7ba097f..3fd0a9b599e 100644 --- a/go/cmd/dolt/commands/ls.go +++ b/go/cmd/dolt/commands/ls.go @@ -102,7 +102,7 @@ func (cmd LsCmd) Exec(ctx context.Context, commandStr string, args []string, dEn if err != nil { return HandleVErrAndExitCode(errhand.VerboseErrorFromError(err), usage) } - rows, err := GetRowsForSql(queryist, sqlCtx, query) + rows, err := cli.GetRowsForSql(queryist, sqlCtx, query) if err != nil { return HandleVErrAndExitCode(errhand.VerboseErrorFromError(err), usage) } @@ -168,7 +168,7 @@ func printUserTables(tableNames []string, apr *argparser.ArgParseResults, queryi label = "working set" } else { query := fmt.Sprintf("select hashof('%s')", apr.Arg(0)) - row, err := GetRowsForSql(queryist, sqlCtx, query) + row, err := cli.GetRowsForSql(queryist, sqlCtx, query) if err != nil { return err } @@ -197,7 +197,7 @@ func printUserTables(tableNames []string, apr *argparser.ArgParseResults, queryi func printTableVerbose(table string, queryist cli.Queryist, sqlCtx *sql.Context) error { query := fmt.Sprintf("select count(*) from `%s`", table) - row, err := GetRowsForSql(queryist, sqlCtx, query) + row, err := cli.GetRowsForSql(queryist, sqlCtx, query) if err != nil { return err } diff --git a/go/cmd/dolt/commands/merge.go b/go/cmd/dolt/commands/merge.go index a93c68265fa..a788504ea76 100644 --- a/go/cmd/dolt/commands/merge.go +++ b/go/cmd/dolt/commands/merge.go @@ -410,7 +410,7 @@ func printMergeStats(fastForward bool, // calculateMergeConflicts calculates the count of conflicts that occurred during the merge. Returns a map of table name to MergeStats, // a bool indicating whether there were any conflicts, and a bool indicating whether calculation was successful. func calculateMergeConflicts(queryist cli.Queryist, sqlCtx *sql.Context, mergeStats map[string]*merge.MergeStats) (map[string]*merge.MergeStats, bool, error) { - dataConflicts, err := GetRowsForSql(queryist, sqlCtx, "SELECT `table`, num_conflicts FROM dolt_conflicts") + dataConflicts, err := cli.GetRowsForSql(queryist, sqlCtx, "SELECT `table`, num_conflicts FROM dolt_conflicts") if err != nil { return nil, false, err } @@ -429,7 +429,7 @@ func calculateMergeConflicts(queryist cli.Queryist, sqlCtx *sql.Context, mergeSt } } - schemaConflicts, err := GetRowsForSql(queryist, sqlCtx, "SELECT table_name FROM dolt_schema_conflicts") + schemaConflicts, err := cli.GetRowsForSql(queryist, sqlCtx, "SELECT table_name FROM dolt_schema_conflicts") if err != nil { return nil, false, err } @@ -442,7 +442,7 @@ func calculateMergeConflicts(queryist cli.Queryist, sqlCtx *sql.Context, mergeSt } } - constraintViolations, err := GetRowsForSql(queryist, sqlCtx, "SELECT `table`, num_violations FROM dolt_constraint_violations") + constraintViolations, err := cli.GetRowsForSql(queryist, sqlCtx, "SELECT `table`, num_violations FROM dolt_constraint_violations") if err != nil { return nil, false, err } @@ -674,7 +674,7 @@ func fillStringWithChar(ch rune, strLen int) string { } func handleMergeErr(sqlCtx *sql.Context, queryist cli.Queryist, mergeErr error, hasConflicts, hasConstraintViolations bool, usage cli.UsagePrinter) int { - unmergedTables, err := GetRowsForSql(queryist, sqlCtx, "select unmerged_tables from dolt_merge_status") + unmergedTables, err := cli.GetRowsForSql(queryist, sqlCtx, "select unmerged_tables from dolt_merge_status") if err != nil { cli.PrintErrln(err.Error()) return 1 diff --git a/go/cmd/dolt/commands/merge_base.go b/go/cmd/dolt/commands/merge_base.go index 158c19ab1c0..515f5273b94 100644 --- a/go/cmd/dolt/commands/merge_base.go +++ b/go/cmd/dolt/commands/merge_base.go @@ -87,7 +87,7 @@ func (cmd MergeBaseCmd) Exec(ctx context.Context, commandStr string, args []stri return HandleVErrAndExitCode(errhand.VerboseErrorFromError(err), usage) } - row, err := GetRowsForSql(queryist, sqlCtx, interpolatedQuery) + row, err := cli.GetRowsForSql(queryist, sqlCtx, interpolatedQuery) if err != nil { return HandleVErrAndExitCode(errhand.VerboseErrorFromError(err), usage) } diff --git a/go/cmd/dolt/commands/pull.go b/go/cmd/dolt/commands/pull.go index b3fedbf5c80..3970bcdfb31 100644 --- a/go/cmd/dolt/commands/pull.go +++ b/go/cmd/dolt/commands/pull.go @@ -301,7 +301,7 @@ func getRemoteHashForPull(apr *argparser.ArgParseResults, sqlCtx *sql.Context, q remote = apr.Args[0] } - rows, err := GetRowsForSql(queryist, sqlCtx, "select name from dolt_remote_branches") + rows, err := cli.GetRowsForSql(queryist, sqlCtx, "select name from dolt_remote_branches") if err != nil { return "", "", err } @@ -334,7 +334,7 @@ func getRemoteHashForPull(apr *argparser.ArgParseResults, sqlCtx *sql.Context, q // getDefaultRemote gets the name of the default remote. func getDefaultRemote(sqlCtx *sql.Context, queryist cli.Queryist) (string, error) { - rows, err := GetRowsForSql(queryist, sqlCtx, "select name from dolt_remotes") + rows, err := cli.GetRowsForSql(queryist, sqlCtx, "select name from dolt_remotes") if err != nil { return "", err } diff --git a/go/cmd/dolt/commands/rebase.go b/go/cmd/dolt/commands/rebase.go index 52c1bd4df58..f25afd6434b 100644 --- a/go/cmd/dolt/commands/rebase.go +++ b/go/cmd/dolt/commands/rebase.go @@ -98,7 +98,7 @@ func (cmd RebaseCmd) Exec(ctx context.Context, commandStr string, args []string, // Set @@dolt_allow_commit_conflicts in case there are data conflicts that need to be resolved by the caller. // Without this, the conflicts can't be committed to the branch working set, and the caller can't access them. - if _, err = GetRowsForSql(queryist, sqlCtx, "set @@dolt_allow_commit_conflicts=1;"); err != nil { + if _, err = cli.GetRowsForSql(queryist, sqlCtx, "set @@dolt_allow_commit_conflicts=1;"); err != nil { return HandleVErrAndExitCode(errhand.VerboseErrorFromError(err), usage) } @@ -112,7 +112,7 @@ func (cmd RebaseCmd) Exec(ctx context.Context, commandStr string, args []string, return HandleVErrAndExitCode(errhand.VerboseErrorFromError(err), usage) } - rows, err := GetRowsForSql(queryist, sqlCtx, query) + rows, err := cli.GetRowsForSql(queryist, sqlCtx, query) if err != nil { return HandleVErrAndExitCode(errhand.VerboseErrorFromError(err), usage) } @@ -146,7 +146,7 @@ func (cmd RebaseCmd) Exec(ctx context.Context, commandStr string, args []string, // if all uncommented lines are deleted in the editor, abort the rebase if rebasePlan == nil || rebasePlan.Steps == nil || len(rebasePlan.Steps) == 0 { - rows, err := GetRowsForSql(queryist, sqlCtx, "CALL DOLT_REBASE('--abort');") + rows, err := cli.GetRowsForSql(queryist, sqlCtx, "CALL DOLT_REBASE('--abort');") if err != nil { return HandleVErrAndExitCode(errhand.VerboseErrorFromError(err), usage) } @@ -169,7 +169,7 @@ func (cmd RebaseCmd) Exec(ctx context.Context, commandStr string, args []string, return HandleVErrAndExitCode(errhand.VerboseErrorFromError(err), usage) } - rows, err = GetRowsForSql(queryist, sqlCtx, "CALL DOLT_REBASE('--continue');") + rows, err = cli.GetRowsForSql(queryist, sqlCtx, "CALL DOLT_REBASE('--continue');") if err != nil { // If the error is a data conflict, don't abort the rebase, but let the caller resolve the conflicts if dprocedures.ErrRebaseDataConflict.Is(err) || strings.Contains(err.Error(), dprocedures.ErrRebaseDataConflict.Message[:40]) { @@ -244,7 +244,7 @@ func getRebasePlan(cliCtx cli.CliContext, sqlCtx *sql.Context, queryist cli.Quer func buildInitialRebaseMsg(sqlCtx *sql.Context, queryist cli.Queryist, rebaseBranch, currentBranch string) (string, error) { var buffer bytes.Buffer - rows, err := GetRowsForSql(queryist, sqlCtx, "SELECT action, commit_hash, commit_message FROM dolt_rebase ORDER BY rebase_order") + rows, err := cli.GetRowsForSql(queryist, sqlCtx, "SELECT action, commit_hash, commit_message FROM dolt_rebase ORDER BY rebase_order") if err != nil { return "", err } @@ -337,13 +337,13 @@ func parseRebaseMessage(rebaseMsg string) (*rebase.RebasePlan, error) { // insertRebasePlanIntoDoltRebaseTable inserts the rebase plan into the dolt_rebase table by re-building the dolt_rebase // table from scratch. func insertRebasePlanIntoDoltRebaseTable(plan *rebase.RebasePlan, sqlCtx *sql.Context, queryist cli.Queryist) error { - _, err := GetRowsForSql(queryist, sqlCtx, "TRUNCATE TABLE dolt_rebase") + _, err := cli.GetRowsForSql(queryist, sqlCtx, "TRUNCATE TABLE dolt_rebase") if err != nil { return err } for i, step := range plan.Steps { - _, err := GetRowsForSql(queryist, sqlCtx, fmt.Sprintf("INSERT INTO dolt_rebase VALUES (%d, '%s', '%s', '%s')", i+1, step.Action, step.CommitHash, step.CommitMsg)) + _, err := cli.GetRowsForSql(queryist, sqlCtx, fmt.Sprintf("INSERT INTO dolt_rebase VALUES (%d, '%s', '%s', '%s')", i+1, step.Action, step.CommitHash, step.CommitMsg)) if err != nil { return err } diff --git a/go/cmd/dolt/commands/reflog.go b/go/cmd/dolt/commands/reflog.go index 34c117f2081..368778ea04f 100644 --- a/go/cmd/dolt/commands/reflog.go +++ b/go/cmd/dolt/commands/reflog.go @@ -95,7 +95,7 @@ func (cmd ReflogCmd) Exec(ctx context.Context, commandStr string, args []string, return HandleVErrAndExitCode(errhand.VerboseErrorFromError(err), usage) } - rows, err := GetRowsForSql(queryist, sqlCtx, query) + rows, err := cli.GetRowsForSql(queryist, sqlCtx, query) if err != nil { return HandleVErrAndExitCode(errhand.VerboseErrorFromError(err), usage) } @@ -138,7 +138,7 @@ func printReflog(rows []sql.Row, queryist cli.Queryist, sqlCtx *sql.Context) int // Get the hash of HEAD for the `HEAD ->` decoration headHash := "" - res, err := GetRowsForSql(queryist, sqlCtx, "SELECT hashof('HEAD')") + res, err := cli.GetRowsForSql(queryist, sqlCtx, "SELECT hashof('HEAD')") if err == nil { // still print the reflog even if we can't get the hash headHash = res[0][0].(string) diff --git a/go/cmd/dolt/commands/remote.go b/go/cmd/dolt/commands/remote.go index dadd5c9fc6e..a7d5c1fe279 100644 --- a/go/cmd/dolt/commands/remote.go +++ b/go/cmd/dolt/commands/remote.go @@ -232,7 +232,7 @@ func callSQLRemoteAdd(sqlCtx *sql.Context, queryist cli.Queryist, remoteName, re return err } - _, err = GetRowsForSql(queryist, sqlCtx, qry) + _, err = cli.GetRowsForSql(queryist, sqlCtx, qry) return err } @@ -243,7 +243,7 @@ func callSQLRemoteRemove(sqlCtxe *sql.Context, queryist cli.Queryist, remoteName return err } - _, err = GetRowsForSql(queryist, sqlCtxe, qry) + _, err = cli.GetRowsForSql(queryist, sqlCtxe, qry) return err } @@ -255,7 +255,7 @@ type remote struct { func getRemotesSQL(sqlCtx *sql.Context, queryist cli.Queryist) ([]remote, error) { qry := "select name,url,params from dolt_remotes" - rows, err := GetRowsForSql(queryist, sqlCtx, qry) + rows, err := cli.GetRowsForSql(queryist, sqlCtx, qry) if err != nil { return nil, err } diff --git a/go/cmd/dolt/commands/sql.go b/go/cmd/dolt/commands/sql.go index d93774f4438..921483d7a01 100644 --- a/go/cmd/dolt/commands/sql.go +++ b/go/cmd/dolt/commands/sql.go @@ -364,7 +364,7 @@ func executeSavedQuery(ctx *sql.Context, qryist cli.Queryist, savedQueryName str buffer.WriteString("SELECT query FROM dolt_query_catalog where id = ?") searchQuery, err := dbr.InterpolateForDialect(buffer.String(), []interface{}{savedQueryName}, dialect.MySQL) - rows, err := GetRowsForSql(qryist, ctx, searchQuery) + rows, err := cli.GetRowsForSql(qryist, ctx, searchQuery) if err != nil { return sqlHandleVErrAndExitCode(qryist, errhand.VerboseErrorFromError(err), usage) } else if len(rows) == 0 { @@ -417,7 +417,7 @@ func SaveQuery(ctx *sql.Context, qryist cli.Queryist, apr *argparser.ArgParseRes } order := int32(1) - rows, err := GetRowsForSql(qryist, ctx, "SELECT MAX(display_order) FROM dolt_query_catalog") + rows, err := cli.GetRowsForSql(qryist, ctx, "SELECT MAX(display_order) FROM dolt_query_catalog") if err != nil { return sqlHandleVErrAndExitCode(qryist, errhand.VerboseErrorFromError(err), usage) } @@ -438,7 +438,7 @@ func SaveQuery(ctx *sql.Context, qryist cli.Queryist, apr *argparser.ArgParseRes return sqlHandleVErrAndExitCode(qryist, errhand.VerboseErrorFromError(err), usage) } - _, err = GetRowsForSql(qryist, ctx, insertQuery) + _, err = cli.GetRowsForSql(qryist, ctx, insertQuery) return sqlHandleVErrAndExitCode(qryist, errhand.VerboseErrorFromError(err), usage) } diff --git a/go/cmd/dolt/commands/stash.go b/go/cmd/dolt/commands/stash.go index ff470408819..de9fab1871e 100644 --- a/go/cmd/dolt/commands/stash.go +++ b/go/cmd/dolt/commands/stash.go @@ -229,7 +229,7 @@ func getStashesSQL(sqlCtx *sql.Context, queryist cli.Queryist, limit int) ([]*do } qry := fmt.Sprintf("select stash_id, branch, hash, commit_message from dolt_stashes where name = '%s' order by stash_id ASC %s;", doltdb.DoltCliRef, limitStr) - rows, err := GetRowsForSql(queryist, sqlCtx, qry) + rows, err := cli.GetRowsForSql(queryist, sqlCtx, qry) if err != nil { return nil, err } diff --git a/go/cmd/dolt/commands/status.go b/go/cmd/dolt/commands/status.go index 6604bdada2c..a37cbf5178a 100644 --- a/go/cmd/dolt/commands/status.go +++ b/go/cmd/dolt/commands/status.go @@ -183,7 +183,7 @@ func createPrintData(err error, queryist cli.Queryist, sqlCtx *sql.Context, show return nil, err } - statusRows, err := GetRowsForSql(queryist, sqlCtx, "select table_name,staged,status from dolt_status;") + statusRows, err := cli.GetRowsForSql(queryist, sqlCtx, "select table_name,staged,status from dolt_status;") if err != nil { return nil, err } @@ -348,7 +348,7 @@ func getUpstreamInfo(queryist cli.Queryist, sqlCtx *sql.Context, branchName stri q = fmt.Sprintf("select name, hash from dolt_branches where name = '%s'", upstreamBranchName) } - upstreamBranches, err := GetRowsForSql(queryist, sqlCtx, q) + upstreamBranches, err := cli.GetRowsForSql(queryist, sqlCtx, q) if err != nil { return ahead, behind, err } @@ -360,7 +360,7 @@ func getUpstreamInfo(queryist cli.Queryist, sqlCtx *sql.Context, branchName stri upstreamBranchCommit := upstreamBranches[0][1].(string) q = fmt.Sprintf("call dolt_count_commits('--from', '%s', '--to', '%s')", currentBranchCommit, upstreamBranchCommit) - rows, err := GetRowsForSql(queryist, sqlCtx, q) + rows, err := cli.GetRowsForSql(queryist, sqlCtx, q) if err != nil { return ahead, behind, err } @@ -387,7 +387,7 @@ func getLocalBranchInfo(queryist cli.Queryist, sqlCtx *sql.Context, branchName s currentBranchCommit = "" remoteBranchName = "" - localBranches, err := GetRowsForSql(queryist, sqlCtx, "select name, hash, remote, branch from dolt_branches;") + localBranches, err := cli.GetRowsForSql(queryist, sqlCtx, "select name, hash, remote, branch from dolt_branches;") if err != nil { return remoteName, remoteBranchName, currentBranchCommit, err } @@ -406,7 +406,7 @@ func getLocalBranchInfo(queryist cli.Queryist, sqlCtx *sql.Context, branchName s } func getMergeStatus(queryist cli.Queryist, sqlCtx *sql.Context) (bool, error) { - mergeRows, err := GetRowsForSql(queryist, sqlCtx, "select is_merging from dolt_merge_status;") + mergeRows, err := cli.GetRowsForSql(queryist, sqlCtx, "select is_merging from dolt_merge_status;") if err != nil { return false, err } @@ -426,7 +426,7 @@ func getMergeStatus(queryist cli.Queryist, sqlCtx *sql.Context) (bool, error) { func getDataConflictsTables(queryist cli.Queryist, sqlCtx *sql.Context) (map[string]bool, error) { dataConflictTables := make(map[string]bool) - dataConflicts, err := GetRowsForSql(queryist, sqlCtx, "select * from dolt_conflicts;") + dataConflicts, err := cli.GetRowsForSql(queryist, sqlCtx, "select * from dolt_conflicts;") if err != nil { return nil, err } @@ -439,7 +439,7 @@ func getDataConflictsTables(queryist cli.Queryist, sqlCtx *sql.Context) (map[str func getConstraintViolationTables(queryist cli.Queryist, sqlCtx *sql.Context) (map[string]bool, error) { constraintViolationTables := make(map[string]bool) - constraintViolations, err := GetRowsForSql(queryist, sqlCtx, "select * from dolt_constraint_violations;") + constraintViolations, err := cli.GetRowsForSql(queryist, sqlCtx, "select * from dolt_constraint_violations;") if err != nil { return nil, err } @@ -453,7 +453,7 @@ func getConstraintViolationTables(queryist cli.Queryist, sqlCtx *sql.Context) (m func getWorkingStagedTables(queryist cli.Queryist, sqlCtx *sql.Context) (map[string]bool, map[string]bool, error) { stagedTableNames := make(map[string]bool) workingTableNames := make(map[string]bool) - diffs, err := GetRowsForSql(queryist, sqlCtx, "select * from dolt_diff where commit_hash='WORKING' OR commit_hash='STAGED';") + diffs, err := cli.GetRowsForSql(queryist, sqlCtx, "select * from dolt_diff where commit_hash='WORKING' OR commit_hash='STAGED';") if err != nil { return nil, nil, err } @@ -471,7 +471,7 @@ func getWorkingStagedTables(queryist cli.Queryist, sqlCtx *sql.Context) (map[str func getIgnoredTablePatternsFromSql(queryist cli.Queryist, sqlCtx *sql.Context) (doltdb.IgnorePatterns, error) { var ignorePatterns []doltdb.IgnorePattern - ignoreRows, err := GetRowsForSql(queryist, sqlCtx, fmt.Sprintf("select * from %s", doltdb.IgnoreTableName)) + ignoreRows, err := cli.GetRowsForSql(queryist, sqlCtx, fmt.Sprintf("select * from %s", doltdb.IgnoreTableName)) if err != nil { return nil, err } diff --git a/go/cmd/dolt/commands/tag.go b/go/cmd/dolt/commands/tag.go index a2d9bad34aa..1beff0ae8af 100644 --- a/go/cmd/dolt/commands/tag.go +++ b/go/cmd/dolt/commands/tag.go @@ -209,7 +209,7 @@ func verboseTagPrint(tag tagInfo) { } func getTagInfos(queryist cli.Queryist, sqlCtx *sql.Context) ([]tagInfo, error) { - rows, err := GetRowsForSql(queryist, sqlCtx, "SELECT * FROM dolt_tags") + rows, err := cli.GetRowsForSql(queryist, sqlCtx, "SELECT * FROM dolt_tags") if err != nil { return nil, err } diff --git a/go/cmd/dolt/commands/utils.go b/go/cmd/dolt/commands/utils.go index 5bd34a3a621..811972e6c34 100644 --- a/go/cmd/dolt/commands/utils.go +++ b/go/cmd/dolt/commands/utils.go @@ -324,19 +324,6 @@ func newLateBindingEngine( return lateBinder, nil } -func GetRowsForSql(queryist cli.Queryist, sqlCtx *sql.Context, query string) ([]sql.Row, error) { - _, rowIter, _, err := queryist.Query(sqlCtx, query) - if err != nil { - return nil, err - } - rows, err := sql.RowIterToRows(sqlCtx, rowIter) - if err != nil { - return nil, err - } - - return rows, nil -} - // InterpolateAndRunQuery interpolates a query, executes it, and returns the result rows. // Since this method does not return a schema, this method should be used only for fire-and-forget types of queries. func InterpolateAndRunQuery(queryist cli.Queryist, sqlCtx *sql.Context, queryTemplate string, params ...interface{}) ([]sql.Row, error) { @@ -344,7 +331,7 @@ func InterpolateAndRunQuery(queryist cli.Queryist, sqlCtx *sql.Context, queryTem if err != nil { return nil, fmt.Errorf("error interpolating query: %w", err) } - return GetRowsForSql(queryist, sqlCtx, query) + return cli.GetRowsForSql(queryist, sqlCtx, query) } // GetTinyIntColAsBool returns the value of a tinyint column as a bool @@ -363,20 +350,6 @@ func GetTinyIntColAsBool(col interface{}) (bool, error) { } } -// GetInt8ColAsBool returns the value of an int8 column as a bool -// This is necessary because Queryist may return an int8 column as a bool (when using SQLEngine) -// or as a string (when using ConnectionQueryist). -func GetInt8ColAsBool(col interface{}) (bool, error) { - switch v := col.(type) { - case int8: - return v != 0, nil - case string: - return v != "0", nil - default: - return false, fmt.Errorf("unexpected type %T, was expecting int8", v) - } -} - // getInt64ColAsInt64 returns the value of an int64 column as an int64 // This is necessary because Queryist may return an int64 column as an int64 (when using SQLEngine) // or as a string (when using ConnectionQueryist). @@ -458,7 +431,7 @@ func getStrBoolColAsBool(col interface{}) (bool, error) { func getActiveBranchName(sqlCtx *sql.Context, queryEngine cli.Queryist) (string, error) { query := "SELECT active_branch()" - rows, err := GetRowsForSql(queryEngine, sqlCtx, query) + rows, err := cli.GetRowsForSql(queryEngine, sqlCtx, query) if err != nil { return "", err } @@ -571,7 +544,7 @@ func GetDoltStatus(queryist cli.Queryist, sqlCtx *sql.Context) (stagedChangedTab } var statusRows []sql.Row - statusRows, err = GetRowsForSql(queryist, sqlCtx, "select table_name,staged from dolt_status;") + statusRows, err = cli.GetRowsForSql(queryist, sqlCtx, "select table_name,staged from dolt_status;") if err != nil { return stagedChangedTables, unstagedChangedTables, fmt.Errorf("error: failed to get dolt status: %w", err) } @@ -729,7 +702,7 @@ func getCommitInfoWithOptions(queryist cli.Queryist, sqlCtx *sql.Context, ref st } } - rows, err := GetRowsForSql(queryist, sqlCtx, q) + rows, err := cli.GetRowsForSql(queryist, sqlCtx, q) if err != nil { return nil, fmt.Errorf("error getting logs for ref '%s': %v", ref, err) } @@ -821,7 +794,7 @@ func getBranchesForHash(queryist cli.Queryist, sqlCtx *sql.Context, targetHash s if err != nil { return nil, err } - rows, err := GetRowsForSql(queryist, sqlCtx, q) + rows, err := cli.GetRowsForSql(queryist, sqlCtx, q) if err != nil { return nil, err } @@ -839,7 +812,7 @@ func getTagsForHash(queryist cli.Queryist, sqlCtx *sql.Context, targetHash strin if err != nil { return nil, err } - rows, err := GetRowsForSql(queryist, sqlCtx, q) + rows, err := cli.GetRowsForSql(queryist, sqlCtx, q) if err != nil { return nil, err } @@ -873,7 +846,7 @@ func getHashOf(queryist cli.Queryist, sqlCtx *sql.Context, ref string) (string, if err != nil { return "", fmt.Errorf("error interpolating hashof query: %v", err) } - rows, err := GetRowsForSql(queryist, sqlCtx, q) + rows, err := cli.GetRowsForSql(queryist, sqlCtx, q) if err != nil { return "", fmt.Errorf("error getting hash of ref '%s': %v", ref, err) } @@ -1037,34 +1010,3 @@ func execEditor(initialMsg string, suffix string, cliCtx cli.CliContext) (edited return editedMsg, nil } - -// SetSystemVar sets the @@dolt_show_system_tables variable if necessary, and returns a function -// resetting the variable for after the commands completion, if necessary. -func SetSystemVar(queryist cli.Queryist, sqlCtx *sql.Context, newVal bool) (func() error, error) { - _, rowIter, _, err := queryist.Query(sqlCtx, "SHOW VARIABLES WHERE VARIABLE_NAME='dolt_show_system_tables'") - if err != nil { - return nil, err - } - - row, err := sql.RowIterToRows(sqlCtx, rowIter) - if err != nil { - return nil, err - } - prevVal, err := GetInt8ColAsBool(row[0][1]) - if err != nil { - return nil, err - } - - var update func() error - if newVal != prevVal { - query := fmt.Sprintf("SET @@dolt_show_system_tables = %t", newVal) - _, _, _, err = queryist.Query(sqlCtx, query) - update = func() error { - query := fmt.Sprintf("SET @@dolt_show_system_tables = %t", prevVal) - _, _, _, err := queryist.Query(sqlCtx, query) - return err - } - } - - return update, err -} diff --git a/go/libraries/doltcore/env/actions/dolt_ci/schema.go b/go/libraries/doltcore/env/actions/dolt_ci/schema.go index 47f8e87a8d0..b37abfd2be9 100644 --- a/go/libraries/doltcore/env/actions/dolt_ci/schema.go +++ b/go/libraries/doltcore/env/actions/dolt_ci/schema.go @@ -20,7 +20,6 @@ import ( "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/dolt/go/cmd/dolt/cli" - "github.com/dolthub/dolt/go/cmd/dolt/commands" "github.com/dolthub/dolt/go/libraries/doltcore/doltdb" ) @@ -69,12 +68,12 @@ func HasDoltCITables(queryist cli.Queryist, sqlCtx *sql.Context) (bool, error) { for _, tableName := range tableNames { query := fmt.Sprintf("SHOW TABLES LIKE '%s';", tableName.Name) - resetFunc, err := commands.SetSystemVar(queryist, sqlCtx, true) + resetFunc, err := cli.SetSystemVar(queryist, sqlCtx, true) if err != nil { return false, err } - rows, err := commands.GetRowsForSql(queryist, sqlCtx, query) + rows, err := cli.GetRowsForSql(queryist, sqlCtx, query) if err != nil { return false, err } @@ -108,14 +107,14 @@ func commitCIDestroy(queryist cli.Queryist, sqlCtx *sql.Context, tableNames []do // are staged before parent tables for i := len(tableNames) - 1; i >= 0; i-- { tn := tableNames[i] - _, err := commands.GetRowsForSql(queryist, sqlCtx, fmt.Sprintf("CALL DOLT_ADD('%s');", tn.Name)) + _, err := cli.GetRowsForSql(queryist, sqlCtx, fmt.Sprintf("CALL DOLT_ADD('%s');", tn.Name)) if err != nil { return err } } query := fmt.Sprintf("CALL DOLT_COMMIT('-m', 'Successfully destroyed Dolt CI', '--author', '%s <%s>');", name, email) - _, err := commands.GetRowsForSql(queryist, sqlCtx, query) + _, err := cli.GetRowsForSql(queryist, sqlCtx, query) return err } @@ -125,14 +124,14 @@ func commitCIInit(sqlCtx *sql.Context, queryist cli.Queryist, tableNames []doltd for i := len(tableNames) - 1; i >= 0; i-- { tn := tableNames[i] query := fmt.Sprintf("CALL DOLT_ADD('%s');", tn.Name) - _, err := commands.GetRowsForSql(queryist, sqlCtx, query) + _, err := cli.GetRowsForSql(queryist, sqlCtx, query) if err != nil { return err } } query := fmt.Sprintf("CALL DOLT_COMMIT('-m', 'Successfully initialized Dolt CI', '--author', '%s <%s>');", name, email) - _, err := commands.GetRowsForSql(queryist, sqlCtx, query) + _, err := cli.GetRowsForSql(queryist, sqlCtx, query) return err } @@ -147,7 +146,7 @@ func DestroyDoltCITables(queryist cli.Queryist, sqlCtx *sql.Context, name, email ciTables := ExpectedDoltCITablesOrdered.ActiveTableNames() for _, tableName := range ciTables { query := fmt.Sprintf("DROP TABLE IF EXISTS %s", tableName.Name) - _, err := commands.GetRowsForSql(queryist, sqlCtx, query) + _, err := cli.GetRowsForSql(queryist, sqlCtx, query) if err != nil { return err } @@ -179,7 +178,7 @@ func CreateDoltCITables(queryist cli.Queryist, sqlCtx *sql.Context, name, email } for _, query := range orderedCreateTableQueries { - _, err = commands.GetRowsForSql(queryist, sqlCtx, query) + _, err = cli.GetRowsForSql(queryist, sqlCtx, query) if err != nil { return err }