Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 79 additions & 0 deletions go/cmd/dolt/cli/query_helpers.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
// Copyright 2025 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package cli

import (
"fmt"

"github.com/dolthub/go-mysql-server/sql"
)

// GetInt8ColAsBool returns the value of an int8 column as a bool
// This is necessary because Queryist may return an int8 column as a bool (when using SQLEngine)
// or as a string (when using ConnectionQueryist).
func GetInt8ColAsBool(col interface{}) (bool, error) {
switch v := col.(type) {
case int8:
return v != 0, nil
case string:
return v != "0", nil
default:
return false, fmt.Errorf("unexpected type %T, was expecting int8", v)
}
}

// SetSystemVar sets the @@dolt_show_system_tables variable if necessary, and returns a function
// resetting the variable for after the commands completion, if necessary.
func SetSystemVar(queryist Queryist, sqlCtx *sql.Context, newVal bool) (func() error, error) {
_, rowIter, _, err := queryist.Query(sqlCtx, "SHOW VARIABLES WHERE VARIABLE_NAME='dolt_show_system_tables'")
if err != nil {
return nil, err
}

row, err := sql.RowIterToRows(sqlCtx, rowIter)
if err != nil {
return nil, err
}
prevVal, err := GetInt8ColAsBool(row[0][1])
if err != nil {
return nil, err
}

var update func() error
if newVal != prevVal {
query := fmt.Sprintf("SET @@dolt_show_system_tables = %t", newVal)
_, _, _, err = queryist.Query(sqlCtx, query)
update = func() error {
query := fmt.Sprintf("SET @@dolt_show_system_tables = %t", prevVal)
_, _, _, err := queryist.Query(sqlCtx, query)
return err
}
}

return update, err
}

func GetRowsForSql(queryist Queryist, sqlCtx *sql.Context, query string) ([]sql.Row, error) {
_, rowIter, _, err := queryist.Query(sqlCtx, query)
if err != nil {
return nil, err
}
rows, err := sql.RowIterToRows(sqlCtx, rowIter)
if err != nil {
return nil, err
}

return rows, nil
}
2 changes: 1 addition & 1 deletion go/cmd/dolt/commands/checkout.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ func (cmd CheckoutCmd) Exec(ctx context.Context, commandStr string, args []strin
return HandleVErrAndExitCode(errhand.VerboseErrorFromError(err), usage)
}

rows, err := GetRowsForSql(queryEngine, sqlCtx, sqlQuery)
rows, err := cli.GetRowsForSql(queryEngine, sqlCtx, sqlQuery)

if err != nil {
// In fringe cases the server can't start because the default branch doesn't exist, `dolt checkout <existing branch>`
Expand Down
8 changes: 4 additions & 4 deletions go/cmd/dolt/commands/cherry-pick.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,12 +142,12 @@ func cherryPick(queryist cli.Queryist, sqlCtx *sql.Context, apr *argparser.ArgPa
hint: commit your changes (dolt commit -am \"<message>\") or reset them (dolt reset --hard) to proceed.`)
}

_, err = GetRowsForSql(queryist, sqlCtx, "set @@dolt_allow_commit_conflicts = 1")
_, err = cli.GetRowsForSql(queryist, sqlCtx, "set @@dolt_allow_commit_conflicts = 1")
if err != nil {
return fmt.Errorf("error: failed to set @@dolt_allow_commit_conflicts: %w", err)
}

_, err = GetRowsForSql(queryist, sqlCtx, "set @@dolt_force_transaction_commit = 1")
_, err = cli.GetRowsForSql(queryist, sqlCtx, "set @@dolt_force_transaction_commit = 1")
if err != nil {
return fmt.Errorf("error: failed to set @@dolt_force_transaction_commit: %w", err)
}
Expand All @@ -156,7 +156,7 @@ hint: commit your changes (dolt commit -am \"<message>\") or reset them (dolt re
if err != nil {
return fmt.Errorf("error: failed to interpolate query: %w", err)
}
rows, err := GetRowsForSql(queryist, sqlCtx, q)
rows, err := cli.GetRowsForSql(queryist, sqlCtx, q)
if err != nil {
errorText := err.Error()
switch {
Expand Down Expand Up @@ -225,7 +225,7 @@ hint: commit your changes (dolt commit -am \"<message>\") or reset them (dolt re

func cherryPickAbort(queryist cli.Queryist, sqlCtx *sql.Context) error {
query := "call dolt_cherry_pick('--abort')"
_, err := GetRowsForSql(queryist, sqlCtx, query)
_, err := cli.GetRowsForSql(queryist, sqlCtx, query)
if err != nil {
errorText := err.Error()
switch errorText {
Expand Down
2 changes: 1 addition & 1 deletion go/cmd/dolt/commands/ci/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ func runCIQuery(queryist cli.Queryist, sqlCtx *sql.Context, step dolt_ci.Step, q
return nil, fmt.Errorf("Could not find saved query: %s", step.SavedQueryName.Value)
}

rows, err := commands.GetRowsForSql(queryist, sqlCtx, query)
rows, err := cli.GetRowsForSql(queryist, sqlCtx, query)
if err != nil {
statementErr := fmt.Sprintf("Ran query: %s", query)
queryErr := fmt.Sprintf("Query error: %s", err.Error())
Expand Down
2 changes: 1 addition & 1 deletion go/cmd/dolt/commands/ci/view.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ func updateConfigQueryStatements(config *dolt_ci.WorkflowConfig, savedQueries ma

func getSavedQueries(sqlCtx *sql.Context, queryist cli.Queryist) (map[string]string, error) {
savedQueries := make(map[string]string)
resetFunc, err := commands.SetSystemVar(queryist, sqlCtx, true)
resetFunc, err := cli.SetSystemVar(queryist, sqlCtx, true)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion go/cmd/dolt/commands/clean.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ func (cmd CleanCmd) Exec(ctx context.Context, commandStr string, args []string,
}
}

_, err = GetRowsForSql(queryist, sqlCtx, query)
_, err = cli.GetRowsForSql(queryist, sqlCtx, query)
if err != nil {
cli.Println(err.Error())
return 1
Expand Down
3 changes: 1 addition & 2 deletions go/cmd/dolt/commands/cnfcmds/auto_resolve.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ import (
"github.com/gocraft/dbr/v2/dialect"

"github.com/dolthub/dolt/go/cmd/dolt/cli"
"github.com/dolthub/dolt/go/cmd/dolt/commands"
)

type AutoResolveStrategy int
Expand Down Expand Up @@ -53,7 +52,7 @@ func AutoResolveTables(queryist cli.Queryist, sqlCtx *sql.Context, strategy Auto
if err != nil {
return fmt.Errorf("error interpolating resolve conflicts query for table %s: %w", tableName, err)
}
_, err = commands.GetRowsForSql(queryist, sqlCtx, q)
_, err = cli.GetRowsForSql(queryist, sqlCtx, q)
if err != nil {
return fmt.Errorf("error resolving conflicts for table %s: %w", tableName, err)
}
Expand Down
6 changes: 3 additions & 3 deletions go/cmd/dolt/commands/cnfcmds/cat.go
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ func writeConflictResults(
func getMergeStatus(queryist cli.Queryist, sqlCtx *sql.Context) (mergeStatus, error) {
ms := mergeStatus{}
q := "select * from dolt_merge_status;"
rows, err := commands.GetRowsForSql(queryist, sqlCtx, q)
rows, err := cli.GetRowsForSql(queryist, sqlCtx, q)
if err != nil {
return ms, err
}
Expand All @@ -315,7 +315,7 @@ func getMergeStatus(queryist cli.Queryist, sqlCtx *sql.Context) (mergeStatus, er

func getSchemaConflictsExist(queryist cli.Queryist, sqlCtx *sql.Context) (bool, error) {
q := "select * from dolt_schema_conflicts limit 1;"
rows, err := commands.GetRowsForSql(queryist, sqlCtx, q)
rows, err := cli.GetRowsForSql(queryist, sqlCtx, q)
if err != nil {
return false, err
}
Expand All @@ -332,7 +332,7 @@ func getTableDataConflictsExist(queryist cli.Queryist, sqlCtx *sql.Context, tabl
if err != nil {
return false, err
}
rows, err := commands.GetRowsForSql(queryist, sqlCtx, q)
rows, err := cli.GetRowsForSql(queryist, sqlCtx, q)
if err != nil {
return false, err
}
Expand Down
20 changes: 10 additions & 10 deletions go/cmd/dolt/commands/diff.go
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ func (cmd DiffCmd) Exec(ctx context.Context, commandStr string, args []string, _
defer closeFunc()
}

updateSystemVar, err := SetSystemVar(queryist, sqlCtx, apr.Contains(cli.SystemFlag))
updateSystemVar, err := cli.SetSystemVar(queryist, sqlCtx, apr.Contains(cli.SystemFlag))
if err != nil {
return HandleVErrAndExitCode(errhand.VerboseErrorFromError(err), usage)
}
Expand Down Expand Up @@ -382,7 +382,7 @@ func getTableNamesAtRef(queryist cli.Queryist, sqlCtx *sql.Context, ref string)
if err != nil {
return nil, fmt.Errorf("error interpolating query: %w", err)
}
rows, err := GetRowsForSql(queryist, sqlCtx, q)
rows, err := cli.GetRowsForSql(queryist, sqlCtx, q)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -415,7 +415,7 @@ func getTableNamesAtRef(queryist cli.Queryist, sqlCtx *sql.Context, ref string)
if err != nil {
return nil, fmt.Errorf("error interpolating query: %w", err)
}
result, err := GetRowsForSql(queryist, sqlCtx, interpolatedQuery)
result, err := cli.GetRowsForSql(queryist, sqlCtx, interpolatedQuery)
if err != nil {
return nil, fmt.Errorf("error getting system table %s: %w", sysTable, err)
}
Expand Down Expand Up @@ -546,7 +546,7 @@ func getCommonAncestor(queryist cli.Queryist, sqlCtx *sql.Context, c1, c2 string
if err != nil {
return "", fmt.Errorf("error interpolating query: %w", err)
}
rows, err := GetRowsForSql(queryist, sqlCtx, q)
rows, err := cli.GetRowsForSql(queryist, sqlCtx, q)
if err != nil {
return "", err
}
Expand Down Expand Up @@ -675,7 +675,7 @@ func getSchemaDiffSummariesBetweenRefs(queryist cli.Queryist, sqlCtx *sql.Contex
if err != nil {
return nil, fmt.Errorf("error: unable to interpolate query: %w", err)
}
schemaDiffRows, err := GetRowsForSql(queryist, sqlCtx, q)
schemaDiffRows, err := cli.GetRowsForSql(queryist, sqlCtx, q)
if err != nil {
return nil, fmt.Errorf("error: unable to get schema diff from %s to %s: %w", fromRef, toRef, err)
}
Expand Down Expand Up @@ -714,7 +714,7 @@ func getSchemaDiffSummariesBetweenRefs(queryist cli.Queryist, sqlCtx *sql.Contex
if err != nil {
return nil, fmt.Errorf("error: unable to interpolate dolt_patch query: %w", err)
}
patchRows, err := GetRowsForSql(queryist, sqlCtx, q)
patchRows, err := cli.GetRowsForSql(queryist, sqlCtx, q)
if err != nil {
return nil, fmt.Errorf("error: unable to get dolt_patch rows from %s to %s: %w", fromRef, toRef, err)
}
Expand All @@ -741,7 +741,7 @@ func getSchemaDiffSummariesBetweenRefs(queryist cli.Queryist, sqlCtx *sql.Contex

func getDiffSummariesBetweenRefs(queryist cli.Queryist, sqlCtx *sql.Context, fromRef, toRef string) ([]diff.TableDeltaSummary, error) {
q, err := dbr.InterpolateForDialect("select * from dolt_diff_summary(?, ?)", []interface{}{fromRef, toRef}, dialect.MySQL)
dataDiffRows, err := GetRowsForSql(queryist, sqlCtx, q)
dataDiffRows, err := cli.GetRowsForSql(queryist, sqlCtx, q)
if err != nil {
return nil, fmt.Errorf("error: unable to get diff summary from %s to %s: %w", fromRef, toRef, err)
}
Expand Down Expand Up @@ -907,7 +907,7 @@ func getTableSchemaAtRef(queryist cli.Queryist, sqlCtx *sql.Context, tableName s
if err != nil {
return sch, createStmt, fmt.Errorf("error interpolating query: %w", err)
}
rows, err = GetRowsForSql(queryist, sqlCtx, interpolatedQuery)
rows, err = cli.GetRowsForSql(queryist, sqlCtx, interpolatedQuery)
if err != nil {
return sch, createStmt, fmt.Errorf("error: unable to get create table statement for table '%s': %w", tableName, err)
}
Expand Down Expand Up @@ -952,7 +952,7 @@ func getDatabaseSchemaAtRef(queryist cli.Queryist, sqlCtx *sql.Context, tableNam
if err != nil {
return "", fmt.Errorf("error interpolating query: %w", err)
}
rows, err = GetRowsForSql(queryist, sqlCtx, interpolatedQuery)
rows, err = cli.GetRowsForSql(queryist, sqlCtx, interpolatedQuery)
if err != nil {
return "", fmt.Errorf("error: unable to get create database statement for database '%s': %w", tableName, err)
}
Expand Down Expand Up @@ -1047,7 +1047,7 @@ func getTableDiffStats(queryist cli.Queryist, sqlCtx *sql.Context, tableName, fr
if err != nil {
return nil, fmt.Errorf("error interpolating query: %w", err)
}
rows, err := GetRowsForSql(queryist, sqlCtx, q)
rows, err := cli.GetRowsForSql(queryist, sqlCtx, q)
if err != nil {
return nil, fmt.Errorf("error running diff stats query: %w", err)
}
Expand Down
6 changes: 3 additions & 3 deletions go/cmd/dolt/commands/log.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ func (cmd LogCmd) logWithLoggerFunc(ctx context.Context, commandStr string, args
if err != nil {
return handleErrAndExit(err)
}
logRows, err := GetRowsForSql(queryist, sqlCtx, query)
logRows, err := cli.GetRowsForSql(queryist, sqlCtx, query)
if err != nil {
return handleErrAndExit(err)
}
Expand All @@ -140,7 +140,7 @@ func collectRevisions(apr *argparser.ArgParseResults, queryist cli.Queryist, sql
tablesIndex++
revisions[arg] = true
} else {
_, err := GetRowsForSql(queryist, sqlCtx, "select hashof('"+arg+"')")
_, err := cli.GetRowsForSql(queryist, sqlCtx, "select hashof('"+arg+"')")

// Once we get a non-revision argument, we treat the remaining args as tables
if _, ok := revisions[arg]; ok || err != nil {
Expand Down Expand Up @@ -288,7 +288,7 @@ func getExistingTables(revisions []string, queryist cli.Queryist, sqlCtx *sql.Co
}

for _, rev := range revisions {
rows, err := GetRowsForSql(queryist, sqlCtx, "show tables as of '"+rev+"'")
rows, err := cli.GetRowsForSql(queryist, sqlCtx, "show tables as of '"+rev+"'")
if err != nil {
return nil, err
}
Expand Down
6 changes: 3 additions & 3 deletions go/cmd/dolt/commands/ls.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ func (cmd LsCmd) Exec(ctx context.Context, commandStr string, args []string, dEn
if err != nil {
return HandleVErrAndExitCode(errhand.VerboseErrorFromError(err), usage)
}
rows, err := GetRowsForSql(queryist, sqlCtx, query)
rows, err := cli.GetRowsForSql(queryist, sqlCtx, query)
if err != nil {
return HandleVErrAndExitCode(errhand.VerboseErrorFromError(err), usage)
}
Expand Down Expand Up @@ -168,7 +168,7 @@ func printUserTables(tableNames []string, apr *argparser.ArgParseResults, queryi
label = "working set"
} else {
query := fmt.Sprintf("select hashof('%s')", apr.Arg(0))
row, err := GetRowsForSql(queryist, sqlCtx, query)
row, err := cli.GetRowsForSql(queryist, sqlCtx, query)
if err != nil {
return err
}
Expand Down Expand Up @@ -197,7 +197,7 @@ func printUserTables(tableNames []string, apr *argparser.ArgParseResults, queryi

func printTableVerbose(table string, queryist cli.Queryist, sqlCtx *sql.Context) error {
query := fmt.Sprintf("select count(*) from `%s`", table)
row, err := GetRowsForSql(queryist, sqlCtx, query)
row, err := cli.GetRowsForSql(queryist, sqlCtx, query)
if err != nil {
return err
}
Expand Down
8 changes: 4 additions & 4 deletions go/cmd/dolt/commands/merge.go
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,7 @@ func printMergeStats(fastForward bool,
// calculateMergeConflicts calculates the count of conflicts that occurred during the merge. Returns a map of table name to MergeStats,
// a bool indicating whether there were any conflicts, and a bool indicating whether calculation was successful.
func calculateMergeConflicts(queryist cli.Queryist, sqlCtx *sql.Context, mergeStats map[string]*merge.MergeStats) (map[string]*merge.MergeStats, bool, error) {
dataConflicts, err := GetRowsForSql(queryist, sqlCtx, "SELECT `table`, num_conflicts FROM dolt_conflicts")
dataConflicts, err := cli.GetRowsForSql(queryist, sqlCtx, "SELECT `table`, num_conflicts FROM dolt_conflicts")
if err != nil {
return nil, false, err
}
Expand All @@ -429,7 +429,7 @@ func calculateMergeConflicts(queryist cli.Queryist, sqlCtx *sql.Context, mergeSt
}
}

schemaConflicts, err := GetRowsForSql(queryist, sqlCtx, "SELECT table_name FROM dolt_schema_conflicts")
schemaConflicts, err := cli.GetRowsForSql(queryist, sqlCtx, "SELECT table_name FROM dolt_schema_conflicts")
if err != nil {
return nil, false, err
}
Expand All @@ -442,7 +442,7 @@ func calculateMergeConflicts(queryist cli.Queryist, sqlCtx *sql.Context, mergeSt
}
}

constraintViolations, err := GetRowsForSql(queryist, sqlCtx, "SELECT `table`, num_violations FROM dolt_constraint_violations")
constraintViolations, err := cli.GetRowsForSql(queryist, sqlCtx, "SELECT `table`, num_violations FROM dolt_constraint_violations")
if err != nil {
return nil, false, err
}
Expand Down Expand Up @@ -674,7 +674,7 @@ func fillStringWithChar(ch rune, strLen int) string {
}

func handleMergeErr(sqlCtx *sql.Context, queryist cli.Queryist, mergeErr error, hasConflicts, hasConstraintViolations bool, usage cli.UsagePrinter) int {
unmergedTables, err := GetRowsForSql(queryist, sqlCtx, "select unmerged_tables from dolt_merge_status")
unmergedTables, err := cli.GetRowsForSql(queryist, sqlCtx, "select unmerged_tables from dolt_merge_status")
if err != nil {
cli.PrintErrln(err.Error())
return 1
Expand Down
2 changes: 1 addition & 1 deletion go/cmd/dolt/commands/merge_base.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ func (cmd MergeBaseCmd) Exec(ctx context.Context, commandStr string, args []stri
return HandleVErrAndExitCode(errhand.VerboseErrorFromError(err), usage)
}

row, err := GetRowsForSql(queryist, sqlCtx, interpolatedQuery)
row, err := cli.GetRowsForSql(queryist, sqlCtx, interpolatedQuery)
if err != nil {
return HandleVErrAndExitCode(errhand.VerboseErrorFromError(err), usage)
}
Expand Down
4 changes: 2 additions & 2 deletions go/cmd/dolt/commands/pull.go
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ func getRemoteHashForPull(apr *argparser.ArgParseResults, sqlCtx *sql.Context, q
remote = apr.Args[0]
}

rows, err := GetRowsForSql(queryist, sqlCtx, "select name from dolt_remote_branches")
rows, err := cli.GetRowsForSql(queryist, sqlCtx, "select name from dolt_remote_branches")
if err != nil {
return "", "", err
}
Expand Down Expand Up @@ -334,7 +334,7 @@ func getRemoteHashForPull(apr *argparser.ArgParseResults, sqlCtx *sql.Context, q

// getDefaultRemote gets the name of the default remote.
func getDefaultRemote(sqlCtx *sql.Context, queryist cli.Queryist) (string, error) {
rows, err := GetRowsForSql(queryist, sqlCtx, "select name from dolt_remotes")
rows, err := cli.GetRowsForSql(queryist, sqlCtx, "select name from dolt_remotes")
if err != nil {
return "", err
}
Expand Down
Loading
Loading