diff --git a/go/cmd/dolt/commands/stash.go b/go/cmd/dolt/commands/stash.go index 5d06b68814d..92f2ac69dba 100644 --- a/go/cmd/dolt/commands/stash.go +++ b/go/cmd/dolt/commands/stash.go @@ -38,6 +38,7 @@ const ( PushCmdRef = "push" PopCmdRef = "pop" DropCmdRef = "drop" + ApplyCmdRef = "apply" ClearCmdRef = "clear" ListCmdRef = "list" ) @@ -116,19 +117,35 @@ func (cmd StashCmd) Exec(ctx context.Context, commandStr string, args []string, } } + // List queries a system table, unlike the procedure the other commands use, so we handle it in a special case + if subcommand == ListCmdRef { + return stashList(ctx, cliCtx) + } + + // Pre-query: Pop, Apply, and Drop commands need to confirm that the given id is valid. + // We'll also check that the subcommand is valid + var stash *doltdb.Stash switch subcommand { - case PushCmdRef: - err = stashPush(queryist.Queryist, queryist.Context, apr, subcommand) - case PopCmdRef, DropCmdRef: - err = stashRemove(queryist.Queryist, queryist.Context, cliCtx, apr, subcommand, idx) - case ListCmdRef: - err = stashList(ctx, cliCtx) - case ClearCmdRef: - err = stashClear(queryist.Queryist, queryist.Context, apr, subcommand) + case PopCmdRef, ApplyCmdRef, DropCmdRef: + stash, err = validateStashAtIdx(queryist.Queryist, queryist.Context, idx) + if err != nil { + cli.PrintErrln(errhand.VerboseErrorFromError(err)) + return 1 + } + case PushCmdRef, ClearCmdRef: default: err = fmt.Errorf("unknown stash subcommand %s", subcommand) + cli.PrintErrln(errhand.VerboseErrorFromError(err)) + return 1 } + // Now build the call to DOLT_STASH and run it + interpolatedQuery, err := generateStashSql(apr, subcommand) + if err != nil { + cli.PrintErrln(errhand.VerboseErrorFromError(err)) + return 1 + } + _, rowIter, _, err := queryist.Queryist.Query(queryist.Context, interpolatedQuery) if err != nil { cli.PrintErrln(errhand.VerboseErrorFromError(err)) if strings.Contains(err.Error(), "No local changes to save") { @@ -136,83 +153,77 @@ func (cmd StashCmd) Exec(ctx context.Context, commandStr string, args []string, } return 1 } - return 0 -} -func stashPush(queryist cli.Queryist, sqlCtx *sql.Context, apr *argparser.ArgParseResults, subcommand string) error { - rowIter, err := stashQuery(queryist, sqlCtx, apr, subcommand) - if err != nil { - return err + // Finally, print out any relevant status message and consume the row iterator + switch subcommand { + case PushCmdRef: + stashes, err := getStashesSQL(queryist.Context, queryist.Queryist, 1) + if err != nil { + cli.PrintErrln(errhand.VerboseErrorFromError(err)) + return 1 + } + stash := stashes[0] + cli.Println(fmt.Sprintf("Saved working directory and index state WIP on %s: %s %s", stash.BranchReference, stash.CommitHash, stash.Description)) + case PopCmdRef: + err = PrintStatus(queryist.Context, false, cliCtx) + if err != nil { + cli.Println("The stash entry is kept in case you need it again.") + cli.PrintErrln(errhand.VerboseErrorFromError(err)) + return 1 + } + cli.Println(fmt.Sprintf("Dropped refs/stash@{%v} (%s)", idx, stash.CommitHash)) + case DropCmdRef: + cli.Println(fmt.Sprintf("Dropped refs/stash@{%v} (%s)", idx, stash.CommitHash)) + case ApplyCmdRef: + err = PrintStatus(queryist.Context, false, cliCtx) + if err != nil { + cli.Println("The stash entry is kept in case you need it again.") + cli.PrintErrln(errhand.VerboseErrorFromError(err)) + return 1 + } } - stashes, err := getStashesSQL(sqlCtx, queryist, 1) + _, err = sql.RowIterToRows(queryist.Context, rowIter) + if err != nil { - return err + cli.PrintErrln(errhand.VerboseErrorFromError(err)) + return 1 } - stash := stashes[0] - cli.Println(fmt.Sprintf("Saved working directory and index state WIP on %s: %s %s", stash.BranchReference, stash.CommitHash, stash.Description)) - _, err = sql.RowIterToRows(sqlCtx, rowIter) - return err + return 0 } -func stashRemove(queryist cli.Queryist, sqlCtx *sql.Context, cliCtx cli.CliContext, apr *argparser.ArgParseResults, subcommand string, idx int) error { +// validateStashAtIdx verifies that the given number is within the range of stashes, then returns the stash at that id. +func validateStashAtIdx(queryist cli.Queryist, sqlCtx *sql.Context, idx int) (*doltdb.Stash, error) { stashes, err := getStashesSQL(sqlCtx, queryist, 0) if err != nil { - return err + return nil, err } if len(stashes) == 0 { - return fmt.Errorf("No stash entries found.") + return nil, fmt.Errorf("no stash entries found") } if len(stashes)-1 < idx { - return fmt.Errorf("stash index stash@{%d} does not exist", idx) - } - - interpolatedQuery, err := generateStashSql(apr, subcommand) - if err != nil { - return err - } - _, rowIter, _, err := queryist.Query(sqlCtx, interpolatedQuery) - if err != nil { - return err + return nil, fmt.Errorf("stash index stash@{%d} does not exist", idx) } - if subcommand == PopCmdRef { - ret := StatusCmd{}.Exec(sqlCtx, StatusCmd{}.Name(), []string{}, nil, cliCtx) - if ret != 0 { - cli.Println("The stash entry is kept in case you need it again.") - return err - } - } - - cli.Println(fmt.Sprintf("Dropped refs/stash@{%v} (%s)", idx, stashes[idx].CommitHash)) - _, err = sql.RowIterToRows(sqlCtx, rowIter) - return err + return stashes[idx], nil } - -func stashList(ctx context.Context, cliCtx cli.CliContext) error { +func stashList(ctx context.Context, cliCtx cli.CliContext) int { queryist, err := cliCtx.QueryEngine(ctx) if err != nil { - return err + cli.PrintErrln(errhand.VerboseErrorFromError(err)) + return 1 } stashes, err := getStashesSQL(queryist.Context, queryist.Queryist, 0) if err != nil { - return err + cli.PrintErrln(errhand.VerboseErrorFromError(err)) + return 1 } for _, stash := range stashes { cli.Println(fmt.Sprintf("%s: WIP on %s: %s %s", stash.Name, stash.BranchReference, stash.CommitHash, stash.Description)) } - return nil -} - -func stashClear(queryist cli.Queryist, sqlCtx *sql.Context, apr *argparser.ArgParseResults, subcommand string) error { - rowIter, err := stashQuery(queryist, sqlCtx, apr, subcommand) - if err != nil { - return err - } - _, err = sql.RowIterToRows(sqlCtx, rowIter) - return err + return 0 } // getStashesSQL queries the dolt_stashes system table to return the requested number of stashes. A limit of 0 will get all stashes @@ -287,20 +298,6 @@ func generateStashSql(apr *argparser.ArgParseResults, subcommand string) (string return interpolatedQuery, err } -func stashQuery(queryist cli.Queryist, sqlCtx *sql.Context, apr *argparser.ArgParseResults, subcommand string) (sql.RowIter, error) { - interpolatedQuery, err := generateStashSql(apr, subcommand) - if err != nil { - return nil, err - } - - _, rowIter, _, err := queryist.Query(sqlCtx, interpolatedQuery) - if err != nil { - return nil, err - } - - return rowIter, nil -} - func parseStashIndex(stashID string) (int, error) { var err error stashID = strings.TrimSuffix(strings.TrimPrefix(stashID, "stash@{"), "}") diff --git a/go/cmd/dolt/commands/status.go b/go/cmd/dolt/commands/status.go index 0fd5b0500ac..740696be323 100644 --- a/go/cmd/dolt/commands/status.go +++ b/go/cmd/dolt/commands/status.go @@ -20,7 +20,6 @@ import ( "strings" "github.com/dolthub/go-mysql-server/sql" - "github.com/dolthub/go-mysql-server/sql/types" "github.com/fatih/color" "github.com/dolthub/dolt/go/cmd/dolt/cli" @@ -43,21 +42,18 @@ type printData struct { conflictsPresent, showIgnoredTables, - statusPresent, mergeActive bool stagedTables, unstagedTables, untrackedTables, - filteredUntrackedTables, - unmergedTables map[string]string + filteredUntrackedTables map[string]string constraintViolationTables, dataConflictTables, schemaConflictTables map[string]bool - ignorePatterns doltdb.IgnorePatterns - ignoredTables doltdb.IgnoredTables + ignoredTables doltdb.IgnoredTables } var statusDocs = cli.CommandDocumentationContent{ @@ -104,27 +100,35 @@ func (cmd StatusCmd) Exec(ctx context.Context, commandStr string, args []string, showIgnoredTables := apr.Contains(cli.ShowIgnoredFlag) - // configure SQL engine - queryist, err := cliCtx.QueryEngine(ctx) + err := PrintStatus(ctx, showIgnoredTables, cliCtx) if err != nil { return handleStatusVErr(err) } + return 0 +} + +func PrintStatus(ctx context.Context, showIgnoredTables bool, cliCtx cli.CliContext) error { + // configure SQL engine + qe, err := cliCtx.QueryEngine(ctx) + if err != nil { + return err + } + // get status information from the database - pd, err := createPrintData(err, queryist.Queryist, queryist.Context, showIgnoredTables, cliCtx) + pd, err := createPrintData(qe.Queryist, qe.Context, showIgnoredTables, cliCtx) if err != nil { - return handleStatusVErr(err) + return err } err = printEverything(pd) if err != nil { - return handleStatusVErr(err) + return err } - - return 0 + return nil } -func createPrintData(err error, queryist cli.Queryist, sqlCtx *sql.Context, showIgnoredTables bool, cliCtx cli.CliContext) (*printData, error) { +func createPrintData(queryist cli.Queryist, sqlCtx *sql.Context, showIgnoredTables bool, cliCtx cli.CliContext) (*printData, error) { var branchName string if brName, hasBranch := cliCtx.GlobalArgs().GetValue(cli.BranchParam); hasBranch { @@ -169,7 +173,7 @@ func createPrintData(err error, queryist cli.Queryist, sqlCtx *sql.Context, show return nil, err } - ahead, behind, err := getUpstreamInfo(queryist, sqlCtx, branchName, remoteName, remoteBranchName, currentBranchCommit) + ahead, behind, err := getUpstreamInfo(queryist, sqlCtx, remoteName, remoteBranchName, currentBranchCommit) if err != nil { return nil, err } @@ -180,8 +184,6 @@ func createPrintData(err error, queryist cli.Queryist, sqlCtx *sql.Context, show } statusPresent := len(statusRows) > 0 - conflictedTables := getConflictedTables(statusRows) - // sort tables into categories conflictsPresent := false stagedTables := map[string]string{} @@ -248,13 +250,10 @@ func createPrintData(err error, queryist cli.Queryist, sqlCtx *sql.Context, show if isStaged { stagedTables[tableName] = status } else { - isTableConflicted := conflictedTables[tableName] - if !isTableConflicted { - if status == "new table" { - untrackedTables[tableName] = status - } else { - unstagedTables[tableName] = status - } + if status == "new table" { + untrackedTables[tableName] = status + } else { + unstagedTables[tableName] = status } } case "schema conflict": @@ -297,35 +296,20 @@ func createPrintData(err error, queryist cli.Queryist, sqlCtx *sql.Context, show behind: behind, conflictsPresent: conflictsPresent, showIgnoredTables: showIgnoredTables, - statusPresent: statusPresent, mergeActive: mergeActive, stagedTables: stagedTables, unstagedTables: unstagedTables, untrackedTables: untrackedTables, filteredUntrackedTables: filteredUntrackedTables, - unmergedTables: unmergedTables, ignoredTables: ignoredTables, constraintViolationTables: constraintViolationTables, schemaConflictTables: schemaConflictTables, - ignorePatterns: ignorePatterns, dataConflictTables: dataConflictTables, } return &pd, nil } -func getConflictedTables(statusRows []sql.Row) map[string]bool { - conflictedTables := make(map[string]bool) - for _, row := range statusRows { - tableName := row[0].(string) - status := row[2].(string) - if status == "conflict" { - conflictedTables[tableName] = true - } - } - return conflictedTables -} - -func getUpstreamInfo(queryist cli.Queryist, sqlCtx *sql.Context, branchName string, remoteName string, upstreamBranchName string, currentBranchCommit string) (ahead int64, behind int64, err error) { +func getUpstreamInfo(queryist cli.Queryist, sqlCtx *sql.Context, remoteName string, upstreamBranchName string, currentBranchCommit string) (ahead int64, behind int64, err error) { ahead = 0 behind = 0 if len(remoteName) > 0 || len(upstreamBranchName) > 0 { @@ -679,21 +663,3 @@ func handleStatusVErr(err error) int { } return 1 } - -// getJsonDocumentColAsString returns the value of a JSONDocument column as a string -// This is necessary because Queryist may return a tinyint column as a bool (when using SQLEngine) -// or as a string (when using ConnectionQueryist). -func getJsonDocumentColAsString(sqlCtx *sql.Context, col interface{}) (string, error) { - switch v := col.(type) { - case string: - return v, nil - case types.JSONDocument: - text, err := v.JSONString() - if err != nil { - return "", err - } - return text, nil - default: - return "", fmt.Errorf("unexpected type %T, was expecting JSONDocument or string", v) - } -} diff --git a/go/libraries/doltcore/sqle/dprocedures/dolt_stash.go b/go/libraries/doltcore/sqle/dprocedures/dolt_stash.go index 480985b9f27..5be6b449cad 100644 --- a/go/libraries/doltcore/sqle/dprocedures/dolt_stash.go +++ b/go/libraries/doltcore/sqle/dprocedures/dolt_stash.go @@ -91,6 +91,8 @@ func doDoltStash(ctx *sql.Context, args []string) (int, error) { err = doStashPop(ctx, dbData, stashName, idx) case "drop": err = doStashDrop(ctx, dbData, stashName, idx) + case "apply": + err = doStashApply(ctx, dbData, stashName, idx) case "clear": if apr.NArg() > 2 { // Clear does not take extra arguments return cmdFailure, fmt.Errorf("error: invalid arguments. Clear takes only subcommand and stash name") @@ -189,6 +191,32 @@ func doStashPop(ctx *sql.Context, dbData env.DbData[*sql.Context], stashName str return dbData.Ddb.RemoveStashAtIdx(ctx, idx, stashName) } +func doStashApply(ctx *sql.Context, dbData env.DbData[*sql.Context], stashName string, idx int) error { + headCommit, result, meta, err := handleMerge(ctx, dbData, stashName, idx) + if err != nil { + return err + } + + err = updateWorkingRoot(ctx, dbData, result.Root) + if err != nil { + return err + } + + roots, err := getRoots(ctx, dbData, headCommit) + if err != nil { + return err + } + + // added tables need to be staged + // since these tables are coming from a stash, don't filter for ignored table names. + roots, err = actions.StageTables(ctx, roots, doltdb.ToTableNames(meta.TablesToStage, doltdb.DefaultSchemaName), false) + if err != nil { + return err + } + + return updateWorkingSetFromRoots(ctx, dbData, roots) +} + func doStashDrop(ctx *sql.Context, dbData env.DbData[*sql.Context], stashName string, idx int) error { return dbData.Ddb.RemoveStashAtIdx(ctx, idx, stashName) } diff --git a/go/libraries/doltcore/sqle/enginetest/dolt_queries_stash.go b/go/libraries/doltcore/sqle/enginetest/dolt_queries_stash.go index 26d4941217e..825b32380ed 100644 --- a/go/libraries/doltcore/sqle/enginetest/dolt_queries_stash.go +++ b/go/libraries/doltcore/sqle/enginetest/dolt_queries_stash.go @@ -583,4 +583,81 @@ var DoltStashTests = []queries.ScriptTest{ }, }, }, + { + Name: "simple stash apply", + SetUpScript: []string{ + "CREATE TABLE test(id INT)", + "CALL DOLT_STASH('push', 'myStash', '-a')", + }, + Assertions: []queries.ScriptTestAssertion{ + { + Query: "SELECT * FROM DOLT_STATUS", + Expected: []sql.Row{}, + }, + { + Query: "CALL DOLT_STASH('apply', 'myStash');", + SkipResultsCheck: true, + }, + { + Query: "SELECT * FROM DOLT_STATUS", + Expected: []sql.Row{ + {"test", byte(0), "new table"}, + }, + }, + }, + }, + { + Name: "stash apply maintains stash entry", + SetUpScript: []string{ + "CALL DOLT_COMMIT('--allow-empty', '-m', 'First commit')", //Normal test initialization adds some weird commit + "CREATE TABLE test(id INT)", + "CALL DOLT_STASH('push', 'myStash', '-a')", + }, + Assertions: []queries.ScriptTestAssertion{ + { + Query: "SELECT * FROM DOLT_STASHES", + Expected: []sql.Row{ + {"myStash", "stash@{0}", "main", doltCommit, "First commit"}, + }, + }, + { + Query: "CALL DOLT_STASH('apply', 'myStash');", + SkipResultsCheck: true, + }, + { + Query: "SELECT * FROM DOLT_STASHES", + Expected: []sql.Row{ + {"myStash", "stash@{0}", "main", doltCommit, "First commit"}, + }, + }, + }, + }, + { + Name: "can apply specific stashes", + SetUpScript: []string{ + "CREATE TABLE test1 (id INT)", + "CALL DOLT_STASH('push', 'myStash', '-a')", + "CREATE TABLE test2 (id INT)", + "CALL DOLT_STASH('push', 'myStash', '-a')", + "CREATE TABLE test3 (id INT)", + "CALL DOLT_STASH('push', 'myStash', '-a')", + }, + Assertions: []queries.ScriptTestAssertion{ + { + Query: "CALL DOLT_STASH('apply', 'myStash', '1')", + SkipResultsCheck: true, + }, + { + Query: "CALL DOLT_STASH('apply', 'myStash', 'stash@{2}')", + SkipResultsCheck: true, + }, + { + Query: "Select * from dolt_status", + Expected: []sql.Row{ + {"test2", byte(0), "new table"}, + {"test1", byte(0), "new table"}, + }, + }, + }, + }, } diff --git a/integration-tests/bats/stash.bats b/integration-tests/bats/stash.bats index eaa016eb6ab..0e9d0ec8ef8 100644 --- a/integration-tests/bats/stash.bats +++ b/integration-tests/bats/stash.bats @@ -117,7 +117,7 @@ teardown() { run dolt stash pop [ "$status" -eq 1 ] - [[ "$output" =~ "No stash entries found." ]] || false + [[ "$output" =~ "no stash entries found" ]] || false dolt sql -q "INSERT INTO test VALUES (2, 'b')" dolt stash @@ -675,3 +675,16 @@ teardown() { [[ "$output" =~ "nothing to commit, working tree clean" ]] || false [[ "$output" =~ "Dropped refs/stash@{0}" ]] || false } + +@test "stash: can apply stashes at specific indices" { + dolt sql -q "create table test1 (i int)" + dolt stash -a + dolt sql -q "create table test2 (i int)" + dolt stash -a + dolt sql -q "create table test3 (i int)" + dolt stash -a + run dolt stash apply 1 + [ "$status" -eq 0 ] + run dolt stash apply stash@{2} + [ "$status" -eq 0 ] +}