Skip to content
113 changes: 94 additions & 19 deletions go/cmd/dolt/commands/sql.go
Original file line number Diff line number Diff line change
Expand Up @@ -682,10 +682,15 @@ func buildBatchSqlErr(stmtStartLine int, query string, err error) error {
// be updated by any queries which were processed.
func execShell(sqlCtx *sql.Context, qryist cli.Queryist, format engine.PrintResultFormat) error {
_ = iohelp.WriteLine(cli.CliOut, welcomeMsg)

historyFile := filepath.Join(".sqlhistory") // history file written to working dir
initialPrompt := fmt.Sprintf("%s> ", sqlCtx.GetCurrentDatabase())
initialMultilinePrompt := fmt.Sprintf(fmt.Sprintf("%%%ds", len(initialPrompt)), "-> ")

db, branch, _ := getDBBranchFromSession(sqlCtx, qryist)
dirty := false
if branch != "" {
dirty, _ = isDirty(sqlCtx, qryist)
}

initialPrompt, initialMultilinePrompt := formattedPrompts(db, branch, dirty)

rlConf := readline.Config{
Prompt: initialPrompt,
Expand Down Expand Up @@ -764,6 +769,7 @@ func execShell(sqlCtx *sql.Context, qryist cli.Queryist, format engine.PrintResu
}

var nextPrompt string
var multiPrompt string
var sqlSch sql.Schema
var rowIter sql.RowIter

Expand All @@ -789,11 +795,14 @@ func execShell(sqlCtx *sql.Context, qryist cli.Queryist, format engine.PrintResu
}
}

db, ok := getDBFromSession(sqlCtx, qryist)
db, branch, ok := getDBBranchFromSession(sqlCtx, qryist)
if ok {
sqlCtx.SetCurrentDatabase(db)
}
nextPrompt = fmt.Sprintf("%s> ", sqlCtx.GetCurrentDatabase())
if branch != "" {
dirty, _ = isDirty(sqlCtx, qryist)
}
nextPrompt, multiPrompt = formattedPrompts(db, branch, dirty)

return true
}()
Expand All @@ -803,7 +812,7 @@ func execShell(sqlCtx *sql.Context, qryist cli.Queryist, format engine.PrintResu
}

shell.SetPrompt(nextPrompt)
shell.SetMultiPrompt(fmt.Sprintf(fmt.Sprintf("%%%ds", len(nextPrompt)), "-> "))
shell.SetMultiPrompt(multiPrompt)
})

shell.Run()
Expand All @@ -812,30 +821,96 @@ func execShell(sqlCtx *sql.Context, qryist cli.Queryist, format engine.PrintResu
return nil
}

// getDBFromSession returns the current database name for the session, handling all the errors along the way by printing
// red error messages to the CLI. If there was an issue getting the db name, the second return value is false.
func getDBFromSession(sqlCtx *sql.Context, qryist cli.Queryist) (db string, ok bool) {
_, resp, err := qryist.Query(sqlCtx, "select database()")
// formattedPrompts returns the prompt and multiline prompt for the current session. If the db is empty, the prompt will
// be "> ", otherwise it will be "db> ". If the branch is empty, the multiline prompt will be "-> ", left padded for
// alignment with the prompt.
func formattedPrompts(db, branch string, dirty bool) (string, string) {
if db == "" {
return "> ", "-> "
}
if branch == "" {
// +2 Allows for the "->" to lineup correctly
multi := fmt.Sprintf(fmt.Sprintf("%%%ds", len(db)+2), "-> ")
cyanDb := color.CyanString(db)
return fmt.Sprintf("%s> ", cyanDb), multi
}

// +3 is for the "/" and "->" to lineup correctly
promptLen := len(db) + len(branch) + 3
dirtyStr := ""
if dirty {
dirtyStr = color.RedString("*")
promptLen += 1
}

multi := fmt.Sprintf(fmt.Sprintf("%%%ds", promptLen), "-> ")

cyanDb := color.CyanString(db)
yellowBr := color.YellowString(branch)
return fmt.Sprintf("%s/%s%s> ", cyanDb, yellowBr, dirtyStr), multi
}

// getDBBranchFromSession returns the current database name and current branch for the session, handling all the errors
// along the way by printing red error messages to the CLI. If there was an issue getting the db name, the ok return
// value will be false and the strings will be empty.
func getDBBranchFromSession(sqlCtx *sql.Context, qryist cli.Queryist) (db string, branch string, ok bool) {
_, resp, err := qryist.Query(sqlCtx, "select database() as db, active_branch() as branch")
if err != nil {
cli.Println(color.RedString("Failure to get DB Name for session" + err.Error()))
return db, false
cli.Println(color.RedString("Failure to get DB Name for session: " + err.Error()))
return db, branch, false
}
// Expect single row/single column result with the db name.
// Expect single row result, with two columns: db name, branch name.
row, err := resp.Next(sqlCtx)
if err != nil {
cli.Println(color.RedString("Failure to get DB Name for session" + err.Error()))
return db, false
cli.Println(color.RedString("Failure to get DB Name for session: " + err.Error()))
return db, branch, false
}
if len(row) != 1 {
cli.Println(color.RedString("Failure to get DB Name for session" + err.Error()))
return db, false
if len(row) != 2 {
cli.Println(color.RedString("Runtime error. Invalid column count."))
return db, branch, false
}

if row[1] == nil {
branch = ""
} else {
branch = row[1].(string)
}
if row[0] == nil {
db = ""
} else {
db = row[0].(string)

// It is possible to `use mydb/branch`, and as far as your session is concerned your database is mydb/branch. We
// allow that, but also want to show the user the branch name in the prompt. So we munge the DB in this case.
if strings.HasSuffix(db, "/"+branch) {
db = db[:len(db)-len(branch)-1]
}
}
return db, true

return db, branch, true
}

// isDirty returns true if the workspace is dirty, false otherwise. This function _assumes_ you are on a database
// with a branch. If you are not, you will get an error.
func isDirty(sqlCtx *sql.Context, qryist cli.Queryist) (bool, error) {
_, resp, err := qryist.Query(sqlCtx, "select count(table_name) > 0 as dirty from dolt_Status")

if err != nil {
cli.Println(color.RedString("Failure to get DB Name for session: " + err.Error()))
return false, err
}
// Expect single row result, with one boolean column.
row, err := resp.Next(sqlCtx)
if err != nil {
cli.Println(color.RedString("Failure to get DB Name for session: " + err.Error()))
return false, err
}
if len(row) != 1 {
cli.Println(color.RedString("Runtime error. Invalid column count."))
return false, fmt.Errorf("invalid column count")
}

return getStrBoolColAsBool(row[0])
}

// Returns a new auto completer with table names, column names, and SQL keywords.
Expand Down
13 changes: 13 additions & 0 deletions go/cmd/dolt/commands/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,19 @@ func getInt64ColAsInt64(col interface{}) (int64, error) {
}
}

// getStringColAsString returns the value of the input as a bool. This is required because depending on if we
// go over the wire or not we may get a string or a bool when we expect a bool.
func getStrBoolColAsBool(col interface{}) (bool, error) {
switch v := col.(type) {
case bool:
return col.(bool), nil
case string:
return strings.ToLower(col.(string)) == "true", nil
default:
return false, fmt.Errorf("unexpected type %T, was expecting bool or string", v)
}
}

func getActiveBranchName(sqlCtx *sql.Context, queryEngine cli.Queryist) (string, error) {
query := "SELECT active_branch()"
rows, err := GetRowsForSql(queryEngine, sqlCtx, query)
Expand Down
10 changes: 8 additions & 2 deletions go/libraries/doltcore/sqle/dfunctions/active_branch.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,18 @@ func NewActiveBranchFunc() sql.Expression {
// Eval implements the Expression interface.
func (ab *ActiveBranchFunc) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
dbName := ctx.GetCurrentDatabase()
if dbName == "" {
// it is possible to have no current database in some contexts.
// When you first connect to a sql server, which has no databases, for example.
return nil, nil
}

dSess := dsess.DSessFromSess(ctx.Session)

ddb, ok := dSess.GetDoltDB(ctx, dbName)

if !ok {
return nil, sql.ErrDatabaseNotFound.New(dbName)
// Not all databases are dolt databases. information_schema and mysql, for example.
return nil, nil
}

currentBranchRef, err := dSess.CWBHeadRef(ctx, dbName)
Expand Down
14 changes: 7 additions & 7 deletions integration-tests/bats/sql-delimiter.expect
Original file line number Diff line number Diff line change
Expand Up @@ -3,32 +3,32 @@
set timeout 2
spawn dolt sql
expect {
"doltsql> " { send "CREATE TABLE test(pk BIGINT PRIMARY KEY, v1 BIGINT);\r"; }
"> " { send "CREATE TABLE test(pk BIGINT PRIMARY KEY, v1 BIGINT);\r"; }
timeout { exit 1; }
failed { exit 1; }
}
expect {
"doltsql> " { send "INSERT INTO test VALUES (0,0);\r"; }
"> " { send "INSERT INTO test VALUES (0,0);\r"; }
timeout { exit 1; }
failed { exit 1; }
}
expect {
"doltsql> " { send "DELIMITER $$\r"; }
"> " { send "DELIMITER $$\r"; }
timeout { exit 1; }
failed { exit 1; }
}
expect {
"doltsql> " { send "INSERT INTO test VALUES (1,1)$$\r"; }
"> " { send "INSERT INTO test VALUES (1,1)$$\r"; }
timeout { exit 1; }
failed { exit 1; }
}
expect {
"doltsql> " { send "delimiter #\r"; }
"> " { send "delimiter #\r"; }
timeout { exit 1; }
failed { exit 1; }
}
expect {
"doltsql> " { send "CREATE TRIGGER tt BEFORE INSERT ON test FOR EACH ROW\r"; }
"> " { send "CREATE TRIGGER tt BEFORE INSERT ON test FOR EACH ROW\r"; }
timeout { exit 1; }
failed { exit 1; }
}
Expand All @@ -53,7 +53,7 @@ expect {
failed { exit 1; }
}
expect {
"doltsql> " { send "DeLiMiTeR ;\r"; }
"> " { send "DeLiMiTeR ;\r"; }
timeout { exit 1; }
failed { exit 1; }
}
Expand Down
50 changes: 50 additions & 0 deletions integration-tests/bats/sql-shell-prompt.expect
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
#!/usr/bin/expect

set timeout 5
spawn dolt sql

expect {
-re "> " { send "create database mydb;\r"; }
timeout { exit 1; }
failed { exit 1; }
}
expect {
-re "> " { send "use mydb;\r"; }
timeout { exit 1; }
failed { exit 1; }
}
expect {
-re ".*mydb.*/.*main.*> " { send "create table tbl (i int);\r"; }
timeout { exit 1; }
failed { exit 1; }
}

# Dirty workspace should show in prompt as a "*" before the ">"
# (all the .* instances here are to account for ansi colors chars.
expect {
-re ".*mydb.*/.*main.*\\*.*> " { send "call dolt_commit('-Am', 'msg');\r"; }
timeout { exit 1; }
failed { exit 1; }
}

expect {
-re ".*mydb.*/.*main.*> " { send "call dolt_checkout('-b','other','HEAD');\r"; }
timeout { exit 1; }
failed { exit 1; }
}

expect {
-re ".*mydb.*/.*main.*> " { send "use mysql;\r"; }
timeout { exit 1; }
failed { exit 1; }
}

# using a non dolt db should result in a prompt without a slash. The brackets
# are required to get expect to properly parse this regex.
expect {
-re {.*mysql[^\\/]*> } { send "exit;\r"; }
timeout { exit 1; }
failed { exit 1; }
}

expect eof
15 changes: 15 additions & 0 deletions integration-tests/bats/sql-shell.bats
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,21 @@ teardown() {
[[ "$output" =~ "+---------------------" ]] || false
}

# bats test_tags=no_lambda
@test "sql-shell: sql shell prompt updates" {
skiponwindows "Need to install expect and make this script work on windows."
if [ "$SQL_ENGINE" = "remote-engine" ]; then
skip "Presently sql command will not connect to remote server due to lack of lock file where there are not DBs."
fi

# start in an empty directory
rm -rf .dolt
mkdir sql_shell_test
cd sql_shell_test

$BATS_TEST_DIRNAME/sql-shell-prompt.expect
}

# bats test_tags=no_lambda
@test "sql-shell: shell works after failing query" {
skiponwindows "Need to install expect and make this script work on windows."
Expand Down
16 changes: 8 additions & 8 deletions integration-tests/bats/sql-unique-error.expect
Original file line number Diff line number Diff line change
Expand Up @@ -4,34 +4,34 @@ set timeout 2
spawn dolt sql

expect {
"doltsql> " { send "CREATE TABLE test(pk BIGINT PRIMARY KEY, v1 BIGINT UNIQUE);\r"; }
"> " { send "CREATE TABLE test(pk BIGINT PRIMARY KEY, v1 BIGINT UNIQUE);\r"; }
timeout { exit 1; }
failed { exit 1; }
}

expect {
"doltsql> " { send "INSERT INTO test VALUES (0,0);\r"; }
"> " { send "INSERT INTO test VALUES (0,0);\r"; }
timeout { exit 1; }
failed { exit 1; }
}

expect {
"doltsql> " { send "INSERT INTO test VALUES (1,0);\r"; }
"> " { send "INSERT INTO test VALUES (1,0);\r"; }
timeout { exit 1; }
"UNIQUE" { exp_continue; }
failed { exp_continue; }
}

expect {
"doltsql> " { send "INSERT INTO test VALUES (1,1);\r"; }
"> " { send "INSERT INTO test VALUES (1,1);\r"; }
timeout { exit 1; }
failed { exit 1; }
}

expect {
"doltsql> " { send "INSERT INTO test VALUES (2,2);\r"; }
timeout { exit 1; }
failed { exit 1; }
}
"> " { send "INSERT INTO test VALUES (2,2);\r"; }
timeout { exit 1; }
failed { exit 1; }
}

expect eof
Loading