diff --git a/go/cmd/dolt/commands/sql.go b/go/cmd/dolt/commands/sql.go index 1510115afb7..bf77226cc3e 100644 --- a/go/cmd/dolt/commands/sql.go +++ b/go/cmd/dolt/commands/sql.go @@ -682,10 +682,15 @@ func buildBatchSqlErr(stmtStartLine int, query string, err error) error { // be updated by any queries which were processed. func execShell(sqlCtx *sql.Context, qryist cli.Queryist, format engine.PrintResultFormat) error { _ = iohelp.WriteLine(cli.CliOut, welcomeMsg) - historyFile := filepath.Join(".sqlhistory") // history file written to working dir - initialPrompt := fmt.Sprintf("%s> ", sqlCtx.GetCurrentDatabase()) - initialMultilinePrompt := fmt.Sprintf(fmt.Sprintf("%%%ds", len(initialPrompt)), "-> ") + + db, branch, _ := getDBBranchFromSession(sqlCtx, qryist) + dirty := false + if branch != "" { + dirty, _ = isDirty(sqlCtx, qryist) + } + + initialPrompt, initialMultilinePrompt := formattedPrompts(db, branch, dirty) rlConf := readline.Config{ Prompt: initialPrompt, @@ -764,6 +769,7 @@ func execShell(sqlCtx *sql.Context, qryist cli.Queryist, format engine.PrintResu } var nextPrompt string + var multiPrompt string var sqlSch sql.Schema var rowIter sql.RowIter @@ -789,11 +795,14 @@ func execShell(sqlCtx *sql.Context, qryist cli.Queryist, format engine.PrintResu } } - db, ok := getDBFromSession(sqlCtx, qryist) + db, branch, ok := getDBBranchFromSession(sqlCtx, qryist) if ok { sqlCtx.SetCurrentDatabase(db) } - nextPrompt = fmt.Sprintf("%s> ", sqlCtx.GetCurrentDatabase()) + if branch != "" { + dirty, _ = isDirty(sqlCtx, qryist) + } + nextPrompt, multiPrompt = formattedPrompts(db, branch, dirty) return true }() @@ -803,7 +812,7 @@ func execShell(sqlCtx *sql.Context, qryist cli.Queryist, format engine.PrintResu } shell.SetPrompt(nextPrompt) - shell.SetMultiPrompt(fmt.Sprintf(fmt.Sprintf("%%%ds", len(nextPrompt)), "-> ")) + shell.SetMultiPrompt(multiPrompt) }) shell.Run() @@ -812,30 +821,96 @@ func execShell(sqlCtx *sql.Context, qryist cli.Queryist, format engine.PrintResu return nil } -// getDBFromSession returns the current database name for the session, handling all the errors along the way by printing -// red error messages to the CLI. If there was an issue getting the db name, the second return value is false. -func getDBFromSession(sqlCtx *sql.Context, qryist cli.Queryist) (db string, ok bool) { - _, resp, err := qryist.Query(sqlCtx, "select database()") +// formattedPrompts returns the prompt and multiline prompt for the current session. If the db is empty, the prompt will +// be "> ", otherwise it will be "db> ". If the branch is empty, the multiline prompt will be "-> ", left padded for +// alignment with the prompt. +func formattedPrompts(db, branch string, dirty bool) (string, string) { + if db == "" { + return "> ", "-> " + } + if branch == "" { + // +2 Allows for the "->" to lineup correctly + multi := fmt.Sprintf(fmt.Sprintf("%%%ds", len(db)+2), "-> ") + cyanDb := color.CyanString(db) + return fmt.Sprintf("%s> ", cyanDb), multi + } + + // +3 is for the "/" and "->" to lineup correctly + promptLen := len(db) + len(branch) + 3 + dirtyStr := "" + if dirty { + dirtyStr = color.RedString("*") + promptLen += 1 + } + + multi := fmt.Sprintf(fmt.Sprintf("%%%ds", promptLen), "-> ") + + cyanDb := color.CyanString(db) + yellowBr := color.YellowString(branch) + return fmt.Sprintf("%s/%s%s> ", cyanDb, yellowBr, dirtyStr), multi +} + +// getDBBranchFromSession returns the current database name and current branch for the session, handling all the errors +// along the way by printing red error messages to the CLI. If there was an issue getting the db name, the ok return +// value will be false and the strings will be empty. +func getDBBranchFromSession(sqlCtx *sql.Context, qryist cli.Queryist) (db string, branch string, ok bool) { + _, resp, err := qryist.Query(sqlCtx, "select database() as db, active_branch() as branch") if err != nil { - cli.Println(color.RedString("Failure to get DB Name for session" + err.Error())) - return db, false + cli.Println(color.RedString("Failure to get DB Name for session: " + err.Error())) + return db, branch, false } - // Expect single row/single column result with the db name. + // Expect single row result, with two columns: db name, branch name. row, err := resp.Next(sqlCtx) if err != nil { - cli.Println(color.RedString("Failure to get DB Name for session" + err.Error())) - return db, false + cli.Println(color.RedString("Failure to get DB Name for session: " + err.Error())) + return db, branch, false } - if len(row) != 1 { - cli.Println(color.RedString("Failure to get DB Name for session" + err.Error())) - return db, false + if len(row) != 2 { + cli.Println(color.RedString("Runtime error. Invalid column count.")) + return db, branch, false + } + + if row[1] == nil { + branch = "" + } else { + branch = row[1].(string) } if row[0] == nil { db = "" } else { db = row[0].(string) + + // It is possible to `use mydb/branch`, and as far as your session is concerned your database is mydb/branch. We + // allow that, but also want to show the user the branch name in the prompt. So we munge the DB in this case. + if strings.HasSuffix(db, "/"+branch) { + db = db[:len(db)-len(branch)-1] + } } - return db, true + + return db, branch, true +} + +// isDirty returns true if the workspace is dirty, false otherwise. This function _assumes_ you are on a database +// with a branch. If you are not, you will get an error. +func isDirty(sqlCtx *sql.Context, qryist cli.Queryist) (bool, error) { + _, resp, err := qryist.Query(sqlCtx, "select count(table_name) > 0 as dirty from dolt_Status") + + if err != nil { + cli.Println(color.RedString("Failure to get DB Name for session: " + err.Error())) + return false, err + } + // Expect single row result, with one boolean column. + row, err := resp.Next(sqlCtx) + if err != nil { + cli.Println(color.RedString("Failure to get DB Name for session: " + err.Error())) + return false, err + } + if len(row) != 1 { + cli.Println(color.RedString("Runtime error. Invalid column count.")) + return false, fmt.Errorf("invalid column count") + } + + return getStrBoolColAsBool(row[0]) } // Returns a new auto completer with table names, column names, and SQL keywords. diff --git a/go/cmd/dolt/commands/utils.go b/go/cmd/dolt/commands/utils.go index 588a8c1c53a..e260dbf3fa8 100644 --- a/go/cmd/dolt/commands/utils.go +++ b/go/cmd/dolt/commands/utils.go @@ -368,6 +368,19 @@ func getInt64ColAsInt64(col interface{}) (int64, error) { } } +// getStringColAsString returns the value of the input as a bool. This is required because depending on if we +// go over the wire or not we may get a string or a bool when we expect a bool. +func getStrBoolColAsBool(col interface{}) (bool, error) { + switch v := col.(type) { + case bool: + return col.(bool), nil + case string: + return strings.ToLower(col.(string)) == "true", nil + default: + return false, fmt.Errorf("unexpected type %T, was expecting bool or string", v) + } +} + func getActiveBranchName(sqlCtx *sql.Context, queryEngine cli.Queryist) (string, error) { query := "SELECT active_branch()" rows, err := GetRowsForSql(queryEngine, sqlCtx, query) diff --git a/go/libraries/doltcore/sqle/dfunctions/active_branch.go b/go/libraries/doltcore/sqle/dfunctions/active_branch.go index 21168abfbd8..30b51c59572 100644 --- a/go/libraries/doltcore/sqle/dfunctions/active_branch.go +++ b/go/libraries/doltcore/sqle/dfunctions/active_branch.go @@ -36,12 +36,18 @@ func NewActiveBranchFunc() sql.Expression { // Eval implements the Expression interface. func (ab *ActiveBranchFunc) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { dbName := ctx.GetCurrentDatabase() + if dbName == "" { + // it is possible to have no current database in some contexts. + // When you first connect to a sql server, which has no databases, for example. + return nil, nil + } + dSess := dsess.DSessFromSess(ctx.Session) ddb, ok := dSess.GetDoltDB(ctx, dbName) - if !ok { - return nil, sql.ErrDatabaseNotFound.New(dbName) + // Not all databases are dolt databases. information_schema and mysql, for example. + return nil, nil } currentBranchRef, err := dSess.CWBHeadRef(ctx, dbName) diff --git a/integration-tests/bats/sql-delimiter.expect b/integration-tests/bats/sql-delimiter.expect index d0b6b507e7a..563f8d2a62b 100755 --- a/integration-tests/bats/sql-delimiter.expect +++ b/integration-tests/bats/sql-delimiter.expect @@ -3,32 +3,32 @@ set timeout 2 spawn dolt sql expect { - "doltsql> " { send "CREATE TABLE test(pk BIGINT PRIMARY KEY, v1 BIGINT);\r"; } + "> " { send "CREATE TABLE test(pk BIGINT PRIMARY KEY, v1 BIGINT);\r"; } timeout { exit 1; } failed { exit 1; } } expect { - "doltsql> " { send "INSERT INTO test VALUES (0,0);\r"; } + "> " { send "INSERT INTO test VALUES (0,0);\r"; } timeout { exit 1; } failed { exit 1; } } expect { - "doltsql> " { send "DELIMITER $$\r"; } + "> " { send "DELIMITER $$\r"; } timeout { exit 1; } failed { exit 1; } } expect { - "doltsql> " { send "INSERT INTO test VALUES (1,1)$$\r"; } + "> " { send "INSERT INTO test VALUES (1,1)$$\r"; } timeout { exit 1; } failed { exit 1; } } expect { - "doltsql> " { send "delimiter #\r"; } + "> " { send "delimiter #\r"; } timeout { exit 1; } failed { exit 1; } } expect { - "doltsql> " { send "CREATE TRIGGER tt BEFORE INSERT ON test FOR EACH ROW\r"; } + "> " { send "CREATE TRIGGER tt BEFORE INSERT ON test FOR EACH ROW\r"; } timeout { exit 1; } failed { exit 1; } } @@ -53,7 +53,7 @@ expect { failed { exit 1; } } expect { - "doltsql> " { send "DeLiMiTeR ;\r"; } + "> " { send "DeLiMiTeR ;\r"; } timeout { exit 1; } failed { exit 1; } } diff --git a/integration-tests/bats/sql-shell-prompt.expect b/integration-tests/bats/sql-shell-prompt.expect new file mode 100755 index 00000000000..1b4f727a96d --- /dev/null +++ b/integration-tests/bats/sql-shell-prompt.expect @@ -0,0 +1,50 @@ +#!/usr/bin/expect + +set timeout 5 +spawn dolt sql + +expect { + -re "> " { send "create database mydb;\r"; } + timeout { exit 1; } + failed { exit 1; } +} +expect { + -re "> " { send "use mydb;\r"; } + timeout { exit 1; } + failed { exit 1; } +} +expect { + -re ".*mydb.*/.*main.*> " { send "create table tbl (i int);\r"; } + timeout { exit 1; } + failed { exit 1; } +} + +# Dirty workspace should show in prompt as a "*" before the ">" +# (all the .* instances here are to account for ansi colors chars. +expect { + -re ".*mydb.*/.*main.*\\*.*> " { send "call dolt_commit('-Am', 'msg');\r"; } + timeout { exit 1; } + failed { exit 1; } +} + +expect { + -re ".*mydb.*/.*main.*> " { send "call dolt_checkout('-b','other','HEAD');\r"; } + timeout { exit 1; } + failed { exit 1; } +} + +expect { + -re ".*mydb.*/.*main.*> " { send "use mysql;\r"; } + timeout { exit 1; } + failed { exit 1; } +} + +# using a non dolt db should result in a prompt without a slash. The brackets +# are required to get expect to properly parse this regex. +expect { + -re {.*mysql[^\\/]*> } { send "exit;\r"; } + timeout { exit 1; } + failed { exit 1; } +} + +expect eof diff --git a/integration-tests/bats/sql-shell.bats b/integration-tests/bats/sql-shell.bats index dad617a59d8..4e01ffd2636 100644 --- a/integration-tests/bats/sql-shell.bats +++ b/integration-tests/bats/sql-shell.bats @@ -67,6 +67,21 @@ teardown() { [[ "$output" =~ "+---------------------" ]] || false } +# bats test_tags=no_lambda +@test "sql-shell: sql shell prompt updates" { + skiponwindows "Need to install expect and make this script work on windows." + if [ "$SQL_ENGINE" = "remote-engine" ]; then + skip "Presently sql command will not connect to remote server due to lack of lock file where there are not DBs." + fi + + # start in an empty directory + rm -rf .dolt + mkdir sql_shell_test + cd sql_shell_test + + $BATS_TEST_DIRNAME/sql-shell-prompt.expect +} + # bats test_tags=no_lambda @test "sql-shell: shell works after failing query" { skiponwindows "Need to install expect and make this script work on windows." diff --git a/integration-tests/bats/sql-unique-error.expect b/integration-tests/bats/sql-unique-error.expect index 68573794234..645115f6f06 100755 --- a/integration-tests/bats/sql-unique-error.expect +++ b/integration-tests/bats/sql-unique-error.expect @@ -4,34 +4,34 @@ set timeout 2 spawn dolt sql expect { - "doltsql> " { send "CREATE TABLE test(pk BIGINT PRIMARY KEY, v1 BIGINT UNIQUE);\r"; } + "> " { send "CREATE TABLE test(pk BIGINT PRIMARY KEY, v1 BIGINT UNIQUE);\r"; } timeout { exit 1; } failed { exit 1; } } expect { - "doltsql> " { send "INSERT INTO test VALUES (0,0);\r"; } + "> " { send "INSERT INTO test VALUES (0,0);\r"; } timeout { exit 1; } failed { exit 1; } } expect { - "doltsql> " { send "INSERT INTO test VALUES (1,0);\r"; } + "> " { send "INSERT INTO test VALUES (1,0);\r"; } timeout { exit 1; } "UNIQUE" { exp_continue; } failed { exp_continue; } } expect { - "doltsql> " { send "INSERT INTO test VALUES (1,1);\r"; } + "> " { send "INSERT INTO test VALUES (1,1);\r"; } timeout { exit 1; } failed { exit 1; } } expect { - "doltsql> " { send "INSERT INTO test VALUES (2,2);\r"; } - timeout { exit 1; } - failed { exit 1; } -} + "> " { send "INSERT INTO test VALUES (2,2);\r"; } + timeout { exit 1; } + failed { exit 1; } + } expect eof diff --git a/integration-tests/bats/sql-use.expect b/integration-tests/bats/sql-use.expect index 9371534a98c..d12502730e8 100644 --- a/integration-tests/bats/sql-use.expect +++ b/integration-tests/bats/sql-use.expect @@ -7,56 +7,56 @@ spawn dolt sql # error output includes the line of the failed test expectation expect { - "*doltsql> " { send -- "use `doltsql/test`;\r"; } + -re ".*doltsql.*/.*main.*> " { send -- "use `doltsql/test`;\r"; } timeout { puts "$TESTFAILURE"; } } expect { - "*doltsql/test> " { send -- "show tables;\r"; } + -re ".*doltsql.*/.*test.*> " { send -- "show tables;\r"; } timeout { puts "$TESTFAILURE"; } } expect { - "*doltsql/test> " { send -- "use information_schema;\r"; } + -re ".*doltsql.*/.*test.*> " { send -- "use information_schema;\r"; } timeout { puts "$TESTFAILURE"; } } expect { - "*information_schema> " { send -- "show tables;\r"; } + -re ".*information_schema.*> " { send -- "show tables;\r"; } timeout { puts "$TESTFAILURE"; } } expect { - "*information_schema> " { send -- "CREATE DATABASE mydb;\r"; } + -re ".*information_schema.*> " { send -- "CREATE DATABASE mydb;\r"; } timeout { puts "$TESTFAILURE"; } } expect { - "*information_schema> " { send -- "use db1;\r"; } + -re ".*information_schema.*> " { send -- "use db1;\r"; } timeout { puts "$TESTFAILURE"; } } expect { - -re "|.*db1.*|\r.*db1> " { send -- "select database();\r"; } + -re "|.*db1.*|\r.*db1.*> " { send -- "select database();\r"; } timeout { puts "$TESTFAILURE"; } } expect { - "*db1>" { send -- "use db2;\r"; } + -re ".*db1.*/.*main.*>" { send -- "use db2;\r"; } timeout { puts "$TESTFAILURE"; } } expect { - "*db2> " { send -- "select database();\r"; } + -re ".*db2.*/.*main.*> " { send -- "select database();\r"; } timeout { puts "$TESTFAILURE"; } } expect { - -re "|.*db2.*|.*\rdb2>" { send -- "use mydb;\r"; } + -re "|.*db2.*|.*\rdb2/main>" { send -- "use mydb;\r"; } timeout { puts "$TESTFAILURE"; } } expect { - "mydb> " { send -- "exit ;\r"; } + -re ".*mydb.*/.*main.*> " { send -- "exit ;\r"; } timeout { puts "$TESTFAILURE"; } }