Skip to content

Commit

Permalink
Make prompts that shouldn't be optional required across the board, fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
danenania committed May 6, 2024
1 parent a940d2f commit c0ad196
Show file tree
Hide file tree
Showing 11 changed files with 43 additions and 19 deletions.
6 changes: 3 additions & 3 deletions app/cli/auth/account.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,19 +149,19 @@ func promptSignInNewAccount() error {
var email string

if selected == SignInCloudOption {
email, err = term.GetUserStringInput("Your email:")
email, err = term.GetRequiredUserStringInput("Your email:")

if err != nil {
return fmt.Errorf("error prompting email: %v", err)
}
} else {
host, err = term.GetUserStringInput("Host:")
host, err = term.GetRequiredUserStringInput("Host:")

if err != nil {
return fmt.Errorf("error prompting host: %v", err)
}

email, err = term.GetUserStringInput("Your email:")
email, err = term.GetRequiredUserStringInput("Your email:")

if err != nil {
return fmt.Errorf("error prompting email: %v", err)
Expand Down
2 changes: 1 addition & 1 deletion app/cli/auth/org.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ func promptNoOrgs() (*shared.Org, error) {
}

func createOrg() (*shared.Org, error) {
name, err := term.GetUserStringInput("Org name:")
name, err := term.GetRequiredUserStringInput("Org name:")
if err != nil {
return nil, fmt.Errorf("error prompting org name: %v", err)
}
Expand Down
4 changes: 2 additions & 2 deletions app/cli/auth/trial.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
)

func ConvertTrial() error {
email, err := term.GetUserStringInput("Your email:")
email, err := term.GetRequiredUserStringInput("Your email:")

if err != nil {
return fmt.Errorf("error prompting email: %v", err)
Expand All @@ -31,7 +31,7 @@ func ConvertTrial() error {
return fmt.Errorf("error prompting name: %v", err)
}

orgName, err := term.GetUserStringInput("Org name:")
orgName, err := term.GetRequiredUserStringInput("Org name:")

if err != nil {
return fmt.Errorf("error prompting org name: %v", err)
Expand Down
2 changes: 1 addition & 1 deletion app/cli/cmd/checkout.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ func checkout(cmd *cobra.Command, args []string) {
}

if selected == OptCreateNewBranch {
branchName, err = term.GetUserStringInput("Branch name")
branchName, err = term.GetRequiredUserStringInput("Branch name")
if err != nil {
term.OutputErrorAndExit("Error getting branch name: %v", err)
return
Expand Down
4 changes: 2 additions & 2 deletions app/cli/cmd/invite.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,14 @@ func invite(cmd *cobra.Command, args []string) {

if email == "" {
var err error
email, err = term.GetUserStringInput("Email:")
email, err = term.GetRequiredUserStringInput("Email:")
if err != nil {
term.OutputErrorAndExit("Failed to get email: %v", err)
}
}
if name == "" {
var err error
name, err = term.GetUserStringInput("Name:")
name, err = term.GetRequiredUserStringInput("Name:")
if err != nil {
term.OutputErrorAndExit("Failed to get name: %v", err)
}
Expand Down
2 changes: 1 addition & 1 deletion app/cli/cmd/model_packs.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ func createModelPack(cmd *cobra.Command, args []string) {

mp := &shared.ModelPack{}

name, err := term.GetUserStringInput("Enter model pack name:")
name, err := term.GetRequiredUserStringInput("Enter model pack name:")
if err != nil {
term.OutputErrorAndExit("Error reading model pack name: %v", err)
return
Expand Down
20 changes: 13 additions & 7 deletions app/cli/cmd/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ func createCustomModel(cmd *cobra.Command, args []string) {
model.Provider = shared.ModelProvider(provider)

if model.Provider == shared.ModelProviderCustom {
customProvider, err := term.GetUserStringInput("Custom provider:")
customProvider, err := term.GetRequiredUserStringInput("Custom provider:")
if err != nil {
term.OutputErrorAndExit("Error reading custom provider: %v", err)
return
Expand All @@ -91,7 +91,7 @@ func createCustomModel(cmd *cobra.Command, args []string) {
}

fmt.Println("For model name, be sure to enter the exact, case-sensitive name of the model as it appears in the provider's API docs. Ex: 'gpt-4-turbo', 'meta-llama/Llama-3-70b-chat-hf'")
modelName, err := term.GetUserStringInput("Model name:")
modelName, err := term.GetRequiredUserStringInput("Model name:")
if err != nil {
term.OutputErrorAndExit("Error reading model name: %v", err)
return
Expand All @@ -107,7 +107,7 @@ func createCustomModel(cmd *cobra.Command, args []string) {
model.Description = description

if model.Provider == shared.ModelProviderCustom {
baseUrl, err := term.GetUserStringInput("Base URL:")
baseUrl, err := term.GetRequiredUserStringInput("Base URL:")
if err != nil {
term.OutputErrorAndExit("Error reading base URL: %v", err)
return
Expand All @@ -118,7 +118,13 @@ func createCustomModel(cmd *cobra.Command, args []string) {
}

apiKeyDefault := shared.ApiKeyByProvider[model.Provider]
apiKeyEnvVar, err := term.GetUserStringInputWithDefault("API key environment variable:", apiKeyDefault)
var apiKeyEnvVar string
if apiKeyDefault == "" {
apiKeyEnvVar, err = term.GetRequiredUserStringInput("API key environment variable:")
} else {
apiKeyEnvVar, err = term.GetUserStringInputWithDefault("API key environment variable:", apiKeyDefault)
}

if err != nil {
term.OutputErrorAndExit("Error reading API key environment variable: %v", err)
return
Expand All @@ -127,7 +133,7 @@ func createCustomModel(cmd *cobra.Command, args []string) {

fmt.Println("Max Tokens is the total maximum context size of the model.")

maxTokensStr, err := term.GetUserStringInput("Max Tokens:")
maxTokensStr, err := term.GetRequiredUserStringInput("Max Tokens:")
if err != nil {
term.OutputErrorAndExit("Error reading max tokens: %v", err)
return
Expand All @@ -140,7 +146,7 @@ func createCustomModel(cmd *cobra.Command, args []string) {
model.MaxTokens = maxTokens

fmt.Println("'Default Max Convo Tokens' is the default maximum size a conversation can reach in the 'planner' role before it is shortened by summarization. For models with 8k context, ~2500 is recommended. For 128k context, ~10000 is recommended.")
maxConvoTokensStr, err := term.GetUserStringInput("Default Max Convo Tokens:")
maxConvoTokensStr, err := term.GetRequiredUserStringInput("Default Max Convo Tokens:")
if err != nil {
term.OutputErrorAndExit("Error reading max convo tokens: %v", err)
return
Expand All @@ -153,7 +159,7 @@ func createCustomModel(cmd *cobra.Command, args []string) {
model.DefaultMaxConvoTokens = maxConvoTokens

fmt.Println("'Default Reserved Output Tokens' is the default number of tokens reserved for model output in the 'planner' role. This ensures the model has enough tokens to generate a response. For models with 8k context, ~1000 is recommended. For 128k context, ~4000 is recommended.")
reservedOutputTokensStr, err := term.GetUserStringInput("Default Reserved Output Tokens:")
reservedOutputTokensStr, err := term.GetRequiredUserStringInput("Default Reserved Output Tokens:")
if err != nil {
term.OutputErrorAndExit("Error reading reserved output tokens: %v", err)
return
Expand Down
2 changes: 1 addition & 1 deletion app/cli/cmd/rename.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func rename(cmd *cobra.Command, args []string) {
newName = args[0]
} else {
var err error
newName, err = term.GetUserStringInput("New name:")
newName, err = term.GetRequiredUserStringInput("New name:")
if err != nil {
term.OutputErrorAndExit("Error reading new name: %v", err)
}
Expand Down
2 changes: 1 addition & 1 deletion app/cli/cmd/set_model.go
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ func updateModelSettings(args []string, originalSettings *shared.PlanSettings) *
msg += "top-p (0.0 to 1.0)"
}
var err error
value, err = term.GetUserStringInput(msg)
value, err = term.GetRequiredUserStringInput(msg)
if err != nil {
if err.Error() == "interrupt" {
return nil
Expand Down
14 changes: 14 additions & 0 deletions app/cli/term/prompt.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,20 @@ import (
"github.com/fatih/color"
)

func GetRequiredUserStringInput(msg string) (string, error) {
res, err := GetUserStringInput(msg)
if err != nil {
return "", fmt.Errorf("failed to get user input: %s", err)
}

if res == "" {
color.New(color.Bold, ColorHiRed).Println("🚨 This input is required")
return GetRequiredUserStringInput(msg)
}

return res, nil
}

func GetUserStringInput(msg string) (string, error) {
return GetUserStringInputWithDefault(msg, "")
}
Expand Down
4 changes: 4 additions & 0 deletions app/server/db/git.go
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,10 @@ func gitRemoveIndexLockFileIfExists(repoDir string) error {

if err == nil {
if err := os.Remove(lockFilePath); err != nil {
if os.IsNotExist(err) {
return nil
}

return fmt.Errorf("error removing lock file: %v", err)
}
} else if !os.IsNotExist(err) {
Expand Down

0 comments on commit c0ad196

Please sign in to comment.