From c0ad196aa398949423d6a48896f639ce39fe95a6 Mon Sep 17 00:00:00 2001 From: Dane Schneider Date: Mon, 6 May 2024 11:13:39 -0700 Subject: [PATCH] Make prompts that shouldn't be optional required across the board, fixes #108 --- app/cli/auth/account.go | 6 +++--- app/cli/auth/org.go | 2 +- app/cli/auth/trial.go | 4 ++-- app/cli/cmd/checkout.go | 2 +- app/cli/cmd/invite.go | 4 ++-- app/cli/cmd/model_packs.go | 2 +- app/cli/cmd/models.go | 20 +++++++++++++------- app/cli/cmd/rename.go | 2 +- app/cli/cmd/set_model.go | 2 +- app/cli/term/prompt.go | 14 ++++++++++++++ app/server/db/git.go | 4 ++++ 11 files changed, 43 insertions(+), 19 deletions(-) diff --git a/app/cli/auth/account.go b/app/cli/auth/account.go index 51ab2888..e7877550 100644 --- a/app/cli/auth/account.go +++ b/app/cli/auth/account.go @@ -149,19 +149,19 @@ func promptSignInNewAccount() error { var email string if selected == SignInCloudOption { - email, err = term.GetUserStringInput("Your email:") + email, err = term.GetRequiredUserStringInput("Your email:") if err != nil { return fmt.Errorf("error prompting email: %v", err) } } else { - host, err = term.GetUserStringInput("Host:") + host, err = term.GetRequiredUserStringInput("Host:") if err != nil { return fmt.Errorf("error prompting host: %v", err) } - email, err = term.GetUserStringInput("Your email:") + email, err = term.GetRequiredUserStringInput("Your email:") if err != nil { return fmt.Errorf("error prompting email: %v", err) diff --git a/app/cli/auth/org.go b/app/cli/auth/org.go index 21f23a0d..b58a19c2 100644 --- a/app/cli/auth/org.go +++ b/app/cli/auth/org.go @@ -59,7 +59,7 @@ func promptNoOrgs() (*shared.Org, error) { } func createOrg() (*shared.Org, error) { - name, err := term.GetUserStringInput("Org name:") + name, err := term.GetRequiredUserStringInput("Org name:") if err != nil { return nil, fmt.Errorf("error prompting org name: %v", err) } diff --git a/app/cli/auth/trial.go b/app/cli/auth/trial.go index 77bad928..c6611564 100644 --- a/app/cli/auth/trial.go +++ b/app/cli/auth/trial.go @@ -9,7 +9,7 @@ import ( ) func ConvertTrial() error { - email, err := term.GetUserStringInput("Your email:") + email, err := term.GetRequiredUserStringInput("Your email:") if err != nil { return fmt.Errorf("error prompting email: %v", err) @@ -31,7 +31,7 @@ func ConvertTrial() error { return fmt.Errorf("error prompting name: %v", err) } - orgName, err := term.GetUserStringInput("Org name:") + orgName, err := term.GetRequiredUserStringInput("Org name:") if err != nil { return fmt.Errorf("error prompting org name: %v", err) diff --git a/app/cli/cmd/checkout.go b/app/cli/cmd/checkout.go index 77be5f3e..10f0b2ef 100644 --- a/app/cli/cmd/checkout.go +++ b/app/cli/cmd/checkout.go @@ -107,7 +107,7 @@ func checkout(cmd *cobra.Command, args []string) { } if selected == OptCreateNewBranch { - branchName, err = term.GetUserStringInput("Branch name") + branchName, err = term.GetRequiredUserStringInput("Branch name") if err != nil { term.OutputErrorAndExit("Error getting branch name: %v", err) return diff --git a/app/cli/cmd/invite.go b/app/cli/cmd/invite.go index ff85e162..db1ce895 100644 --- a/app/cli/cmd/invite.go +++ b/app/cli/cmd/invite.go @@ -45,14 +45,14 @@ func invite(cmd *cobra.Command, args []string) { if email == "" { var err error - email, err = term.GetUserStringInput("Email:") + email, err = term.GetRequiredUserStringInput("Email:") if err != nil { term.OutputErrorAndExit("Failed to get email: %v", err) } } if name == "" { var err error - name, err = term.GetUserStringInput("Name:") + name, err = term.GetRequiredUserStringInput("Name:") if err != nil { term.OutputErrorAndExit("Failed to get name: %v", err) } diff --git a/app/cli/cmd/model_packs.go b/app/cli/cmd/model_packs.go index 4d15bf1b..34970315 100644 --- a/app/cli/cmd/model_packs.go +++ b/app/cli/cmd/model_packs.go @@ -177,7 +177,7 @@ func createModelPack(cmd *cobra.Command, args []string) { mp := &shared.ModelPack{} - name, err := term.GetUserStringInput("Enter model pack name:") + name, err := term.GetRequiredUserStringInput("Enter model pack name:") if err != nil { term.OutputErrorAndExit("Error reading model pack name: %v", err) return diff --git a/app/cli/cmd/models.go b/app/cli/cmd/models.go index 22349661..4e84a1dd 100644 --- a/app/cli/cmd/models.go +++ b/app/cli/cmd/models.go @@ -82,7 +82,7 @@ func createCustomModel(cmd *cobra.Command, args []string) { model.Provider = shared.ModelProvider(provider) if model.Provider == shared.ModelProviderCustom { - customProvider, err := term.GetUserStringInput("Custom provider:") + customProvider, err := term.GetRequiredUserStringInput("Custom provider:") if err != nil { term.OutputErrorAndExit("Error reading custom provider: %v", err) return @@ -91,7 +91,7 @@ func createCustomModel(cmd *cobra.Command, args []string) { } fmt.Println("For model name, be sure to enter the exact, case-sensitive name of the model as it appears in the provider's API docs. Ex: 'gpt-4-turbo', 'meta-llama/Llama-3-70b-chat-hf'") - modelName, err := term.GetUserStringInput("Model name:") + modelName, err := term.GetRequiredUserStringInput("Model name:") if err != nil { term.OutputErrorAndExit("Error reading model name: %v", err) return @@ -107,7 +107,7 @@ func createCustomModel(cmd *cobra.Command, args []string) { model.Description = description if model.Provider == shared.ModelProviderCustom { - baseUrl, err := term.GetUserStringInput("Base URL:") + baseUrl, err := term.GetRequiredUserStringInput("Base URL:") if err != nil { term.OutputErrorAndExit("Error reading base URL: %v", err) return @@ -118,7 +118,13 @@ func createCustomModel(cmd *cobra.Command, args []string) { } apiKeyDefault := shared.ApiKeyByProvider[model.Provider] - apiKeyEnvVar, err := term.GetUserStringInputWithDefault("API key environment variable:", apiKeyDefault) + var apiKeyEnvVar string + if apiKeyDefault == "" { + apiKeyEnvVar, err = term.GetRequiredUserStringInput("API key environment variable:") + } else { + apiKeyEnvVar, err = term.GetUserStringInputWithDefault("API key environment variable:", apiKeyDefault) + } + if err != nil { term.OutputErrorAndExit("Error reading API key environment variable: %v", err) return @@ -127,7 +133,7 @@ func createCustomModel(cmd *cobra.Command, args []string) { fmt.Println("Max Tokens is the total maximum context size of the model.") - maxTokensStr, err := term.GetUserStringInput("Max Tokens:") + maxTokensStr, err := term.GetRequiredUserStringInput("Max Tokens:") if err != nil { term.OutputErrorAndExit("Error reading max tokens: %v", err) return @@ -140,7 +146,7 @@ func createCustomModel(cmd *cobra.Command, args []string) { model.MaxTokens = maxTokens fmt.Println("'Default Max Convo Tokens' is the default maximum size a conversation can reach in the 'planner' role before it is shortened by summarization. For models with 8k context, ~2500 is recommended. For 128k context, ~10000 is recommended.") - maxConvoTokensStr, err := term.GetUserStringInput("Default Max Convo Tokens:") + maxConvoTokensStr, err := term.GetRequiredUserStringInput("Default Max Convo Tokens:") if err != nil { term.OutputErrorAndExit("Error reading max convo tokens: %v", err) return @@ -153,7 +159,7 @@ func createCustomModel(cmd *cobra.Command, args []string) { model.DefaultMaxConvoTokens = maxConvoTokens fmt.Println("'Default Reserved Output Tokens' is the default number of tokens reserved for model output in the 'planner' role. This ensures the model has enough tokens to generate a response. For models with 8k context, ~1000 is recommended. For 128k context, ~4000 is recommended.") - reservedOutputTokensStr, err := term.GetUserStringInput("Default Reserved Output Tokens:") + reservedOutputTokensStr, err := term.GetRequiredUserStringInput("Default Reserved Output Tokens:") if err != nil { term.OutputErrorAndExit("Error reading reserved output tokens: %v", err) return diff --git a/app/cli/cmd/rename.go b/app/cli/cmd/rename.go index dad6b62e..4c45cc11 100644 --- a/app/cli/cmd/rename.go +++ b/app/cli/cmd/rename.go @@ -36,7 +36,7 @@ func rename(cmd *cobra.Command, args []string) { newName = args[0] } else { var err error - newName, err = term.GetUserStringInput("New name:") + newName, err = term.GetRequiredUserStringInput("New name:") if err != nil { term.OutputErrorAndExit("Error reading new name: %v", err) } diff --git a/app/cli/cmd/set_model.go b/app/cli/cmd/set_model.go index 194e2014..479e2844 100644 --- a/app/cli/cmd/set_model.go +++ b/app/cli/cmd/set_model.go @@ -409,7 +409,7 @@ func updateModelSettings(args []string, originalSettings *shared.PlanSettings) * msg += "top-p (0.0 to 1.0)" } var err error - value, err = term.GetUserStringInput(msg) + value, err = term.GetRequiredUserStringInput(msg) if err != nil { if err.Error() == "interrupt" { return nil diff --git a/app/cli/term/prompt.go b/app/cli/term/prompt.go index e2551638..ba98b6d7 100644 --- a/app/cli/term/prompt.go +++ b/app/cli/term/prompt.go @@ -10,6 +10,20 @@ import ( "github.com/fatih/color" ) +func GetRequiredUserStringInput(msg string) (string, error) { + res, err := GetUserStringInput(msg) + if err != nil { + return "", fmt.Errorf("failed to get user input: %s", err) + } + + if res == "" { + color.New(color.Bold, ColorHiRed).Println("🚨 This input is required") + return GetRequiredUserStringInput(msg) + } + + return res, nil +} + func GetUserStringInput(msg string) (string, error) { return GetUserStringInputWithDefault(msg, "") } diff --git a/app/server/db/git.go b/app/server/db/git.go index a04b336e..dcac1018 100644 --- a/app/server/db/git.go +++ b/app/server/db/git.go @@ -338,6 +338,10 @@ func gitRemoveIndexLockFileIfExists(repoDir string) error { if err == nil { if err := os.Remove(lockFilePath); err != nil { + if os.IsNotExist(err) { + return nil + } + return fmt.Errorf("error removing lock file: %v", err) } } else if !os.IsNotExist(err) {