Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 69 additions & 72 deletions go/cmd/dolt/commands/stash.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ const (
PushCmdRef = "push"
PopCmdRef = "pop"
DropCmdRef = "drop"
ApplyCmdRef = "apply"
ClearCmdRef = "clear"
ListCmdRef = "list"
)
Expand Down Expand Up @@ -116,103 +117,113 @@ func (cmd StashCmd) Exec(ctx context.Context, commandStr string, args []string,
}
}

// List queries a system table, unlike the procedure the other commands use, so we handle it in a special case
if subcommand == ListCmdRef {
return stashList(ctx, cliCtx)
}

// Pre-query: Pop, Apply, and Drop commands need to confirm that the given id is valid.
// We'll also check that the subcommand is valid
var stash *doltdb.Stash
switch subcommand {
case PushCmdRef:
err = stashPush(queryist.Queryist, queryist.Context, apr, subcommand)
case PopCmdRef, DropCmdRef:
err = stashRemove(queryist.Queryist, queryist.Context, cliCtx, apr, subcommand, idx)
case ListCmdRef:
err = stashList(ctx, cliCtx)
case ClearCmdRef:
err = stashClear(queryist.Queryist, queryist.Context, apr, subcommand)
case PopCmdRef, ApplyCmdRef, DropCmdRef:
stash, err = validateStashAtIdx(queryist.Queryist, queryist.Context, idx)
if err != nil {
cli.PrintErrln(errhand.VerboseErrorFromError(err))
return 1
}
case PushCmdRef, ClearCmdRef:
default:
err = fmt.Errorf("unknown stash subcommand %s", subcommand)
cli.PrintErrln(errhand.VerboseErrorFromError(err))
return 1
}

// Now build the call to DOLT_STASH and run it
interpolatedQuery, err := generateStashSql(apr, subcommand)
if err != nil {
cli.PrintErrln(errhand.VerboseErrorFromError(err))
return 1
}
_, rowIter, _, err := queryist.Queryist.Query(queryist.Context, interpolatedQuery)
if err != nil {
cli.PrintErrln(errhand.VerboseErrorFromError(err))
if strings.Contains(err.Error(), "No local changes to save") {
return 0
}
return 1
}
return 0
}

func stashPush(queryist cli.Queryist, sqlCtx *sql.Context, apr *argparser.ArgParseResults, subcommand string) error {
rowIter, err := stashQuery(queryist, sqlCtx, apr, subcommand)
if err != nil {
return err
// Finally, print out any relevant status message and consume the row iterator
switch subcommand {
case PushCmdRef:
stashes, err := getStashesSQL(queryist.Context, queryist.Queryist, 1)
if err != nil {
cli.PrintErrln(errhand.VerboseErrorFromError(err))
return 1
}
stash := stashes[0]
cli.Println(fmt.Sprintf("Saved working directory and index state WIP on %s: %s %s", stash.BranchReference, stash.CommitHash, stash.Description))
case PopCmdRef:
err = PrintStatus(queryist.Context, false, cliCtx)
if err != nil {
cli.Println("The stash entry is kept in case you need it again.")
cli.PrintErrln(errhand.VerboseErrorFromError(err))
return 1
}
cli.Println(fmt.Sprintf("Dropped refs/stash@{%v} (%s)", idx, stash.CommitHash))
case DropCmdRef:
cli.Println(fmt.Sprintf("Dropped refs/stash@{%v} (%s)", idx, stash.CommitHash))
case ApplyCmdRef:
err = PrintStatus(queryist.Context, false, cliCtx)
if err != nil {
cli.Println("The stash entry is kept in case you need it again.")
cli.PrintErrln(errhand.VerboseErrorFromError(err))
return 1
}
}

stashes, err := getStashesSQL(sqlCtx, queryist, 1)
_, err = sql.RowIterToRows(queryist.Context, rowIter)

if err != nil {
return err
cli.PrintErrln(errhand.VerboseErrorFromError(err))
return 1
}
stash := stashes[0]
cli.Println(fmt.Sprintf("Saved working directory and index state WIP on %s: %s %s", stash.BranchReference, stash.CommitHash, stash.Description))
_, err = sql.RowIterToRows(sqlCtx, rowIter)
return err
return 0
}

func stashRemove(queryist cli.Queryist, sqlCtx *sql.Context, cliCtx cli.CliContext, apr *argparser.ArgParseResults, subcommand string, idx int) error {
// validateStashAtIdx verifies that the given number is within the range of stashes, then returns the stash at that id.
func validateStashAtIdx(queryist cli.Queryist, sqlCtx *sql.Context, idx int) (*doltdb.Stash, error) {
stashes, err := getStashesSQL(sqlCtx, queryist, 0)
if err != nil {
return err
return nil, err
}
if len(stashes) == 0 {
return fmt.Errorf("No stash entries found.")
return nil, fmt.Errorf("no stash entries found")
}
if len(stashes)-1 < idx {
return fmt.Errorf("stash index stash@{%d} does not exist", idx)
}

interpolatedQuery, err := generateStashSql(apr, subcommand)
if err != nil {
return err
}
_, rowIter, _, err := queryist.Query(sqlCtx, interpolatedQuery)
if err != nil {
return err
return nil, fmt.Errorf("stash index stash@{%d} does not exist", idx)
}

if subcommand == PopCmdRef {
ret := StatusCmd{}.Exec(sqlCtx, StatusCmd{}.Name(), []string{}, nil, cliCtx)
if ret != 0 {
cli.Println("The stash entry is kept in case you need it again.")
return err
}
}

cli.Println(fmt.Sprintf("Dropped refs/stash@{%v} (%s)", idx, stashes[idx].CommitHash))
_, err = sql.RowIterToRows(sqlCtx, rowIter)
return err
return stashes[idx], nil
}

func stashList(ctx context.Context, cliCtx cli.CliContext) error {
func stashList(ctx context.Context, cliCtx cli.CliContext) int {
queryist, err := cliCtx.QueryEngine(ctx)
if err != nil {
return err
cli.PrintErrln(errhand.VerboseErrorFromError(err))
return 1
}

stashes, err := getStashesSQL(queryist.Context, queryist.Queryist, 0)
if err != nil {
return err
cli.PrintErrln(errhand.VerboseErrorFromError(err))
return 1
}
for _, stash := range stashes {
cli.Println(fmt.Sprintf("%s: WIP on %s: %s %s", stash.Name, stash.BranchReference, stash.CommitHash, stash.Description))
}

return nil
}

func stashClear(queryist cli.Queryist, sqlCtx *sql.Context, apr *argparser.ArgParseResults, subcommand string) error {
rowIter, err := stashQuery(queryist, sqlCtx, apr, subcommand)
if err != nil {
return err
}
_, err = sql.RowIterToRows(sqlCtx, rowIter)
return err
return 0
}

// getStashesSQL queries the dolt_stashes system table to return the requested number of stashes. A limit of 0 will get all stashes
Expand Down Expand Up @@ -287,20 +298,6 @@ func generateStashSql(apr *argparser.ArgParseResults, subcommand string) (string
return interpolatedQuery, err
}

func stashQuery(queryist cli.Queryist, sqlCtx *sql.Context, apr *argparser.ArgParseResults, subcommand string) (sql.RowIter, error) {
interpolatedQuery, err := generateStashSql(apr, subcommand)
if err != nil {
return nil, err
}

_, rowIter, _, err := queryist.Query(sqlCtx, interpolatedQuery)
if err != nil {
return nil, err
}

return rowIter, nil
}

func parseStashIndex(stashID string) (int, error) {
var err error
stashID = strings.TrimSuffix(strings.TrimPrefix(stashID, "stash@{"), "}")
Expand Down
82 changes: 24 additions & 58 deletions go/cmd/dolt/commands/status.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ import (
"strings"

"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/types"
"github.com/fatih/color"

"github.com/dolthub/dolt/go/cmd/dolt/cli"
Expand All @@ -43,21 +42,18 @@ type printData struct {

conflictsPresent,
showIgnoredTables,
statusPresent,
mergeActive bool

stagedTables,
unstagedTables,
untrackedTables,
filteredUntrackedTables,
unmergedTables map[string]string
filteredUntrackedTables map[string]string

constraintViolationTables,
dataConflictTables,
schemaConflictTables map[string]bool

ignorePatterns doltdb.IgnorePatterns
ignoredTables doltdb.IgnoredTables
ignoredTables doltdb.IgnoredTables
}

var statusDocs = cli.CommandDocumentationContent{
Expand Down Expand Up @@ -104,27 +100,35 @@ func (cmd StatusCmd) Exec(ctx context.Context, commandStr string, args []string,

showIgnoredTables := apr.Contains(cli.ShowIgnoredFlag)

// configure SQL engine
queryist, err := cliCtx.QueryEngine(ctx)
err := PrintStatus(ctx, showIgnoredTables, cliCtx)
if err != nil {
return handleStatusVErr(err)
}

return 0
}

func PrintStatus(ctx context.Context, showIgnoredTables bool, cliCtx cli.CliContext) error {
// configure SQL engine
qe, err := cliCtx.QueryEngine(ctx)
if err != nil {
return err
}

// get status information from the database
pd, err := createPrintData(err, queryist.Queryist, queryist.Context, showIgnoredTables, cliCtx)
pd, err := createPrintData(qe.Queryist, qe.Context, showIgnoredTables, cliCtx)
if err != nil {
return handleStatusVErr(err)
return err
}

err = printEverything(pd)
if err != nil {
return handleStatusVErr(err)
return err
}

return 0
return nil
}

func createPrintData(err error, queryist cli.Queryist, sqlCtx *sql.Context, showIgnoredTables bool, cliCtx cli.CliContext) (*printData, error) {
func createPrintData(queryist cli.Queryist, sqlCtx *sql.Context, showIgnoredTables bool, cliCtx cli.CliContext) (*printData, error) {
var branchName string

if brName, hasBranch := cliCtx.GlobalArgs().GetValue(cli.BranchParam); hasBranch {
Expand Down Expand Up @@ -169,7 +173,7 @@ func createPrintData(err error, queryist cli.Queryist, sqlCtx *sql.Context, show
return nil, err
}

ahead, behind, err := getUpstreamInfo(queryist, sqlCtx, branchName, remoteName, remoteBranchName, currentBranchCommit)
ahead, behind, err := getUpstreamInfo(queryist, sqlCtx, remoteName, remoteBranchName, currentBranchCommit)
if err != nil {
return nil, err
}
Expand All @@ -180,8 +184,6 @@ func createPrintData(err error, queryist cli.Queryist, sqlCtx *sql.Context, show
}
statusPresent := len(statusRows) > 0

conflictedTables := getConflictedTables(statusRows)

// sort tables into categories
conflictsPresent := false
stagedTables := map[string]string{}
Expand Down Expand Up @@ -248,13 +250,10 @@ func createPrintData(err error, queryist cli.Queryist, sqlCtx *sql.Context, show
if isStaged {
stagedTables[tableName] = status
} else {
isTableConflicted := conflictedTables[tableName]
if !isTableConflicted {
if status == "new table" {
untrackedTables[tableName] = status
} else {
unstagedTables[tableName] = status
}
if status == "new table" {
untrackedTables[tableName] = status
} else {
unstagedTables[tableName] = status
}
}
case "schema conflict":
Expand Down Expand Up @@ -297,35 +296,20 @@ func createPrintData(err error, queryist cli.Queryist, sqlCtx *sql.Context, show
behind: behind,
conflictsPresent: conflictsPresent,
showIgnoredTables: showIgnoredTables,
statusPresent: statusPresent,
mergeActive: mergeActive,
stagedTables: stagedTables,
unstagedTables: unstagedTables,
untrackedTables: untrackedTables,
filteredUntrackedTables: filteredUntrackedTables,
unmergedTables: unmergedTables,
ignoredTables: ignoredTables,
constraintViolationTables: constraintViolationTables,
schemaConflictTables: schemaConflictTables,
ignorePatterns: ignorePatterns,
dataConflictTables: dataConflictTables,
}
return &pd, nil
}

func getConflictedTables(statusRows []sql.Row) map[string]bool {
conflictedTables := make(map[string]bool)
for _, row := range statusRows {
tableName := row[0].(string)
status := row[2].(string)
if status == "conflict" {
conflictedTables[tableName] = true
}
}
return conflictedTables
}

func getUpstreamInfo(queryist cli.Queryist, sqlCtx *sql.Context, branchName string, remoteName string, upstreamBranchName string, currentBranchCommit string) (ahead int64, behind int64, err error) {
func getUpstreamInfo(queryist cli.Queryist, sqlCtx *sql.Context, remoteName string, upstreamBranchName string, currentBranchCommit string) (ahead int64, behind int64, err error) {
ahead = 0
behind = 0
if len(remoteName) > 0 || len(upstreamBranchName) > 0 {
Expand Down Expand Up @@ -679,21 +663,3 @@ func handleStatusVErr(err error) int {
}
return 1
}

// getJsonDocumentColAsString returns the value of a JSONDocument column as a string
// This is necessary because Queryist may return a tinyint column as a bool (when using SQLEngine)
// or as a string (when using ConnectionQueryist).
func getJsonDocumentColAsString(sqlCtx *sql.Context, col interface{}) (string, error) {
switch v := col.(type) {
case string:
return v, nil
case types.JSONDocument:
text, err := v.JSONString()
if err != nil {
return "", err
}
return text, nil
default:
return "", fmt.Errorf("unexpected type %T, was expecting JSONDocument or string", v)
}
}
Loading
Loading