diff --git a/client/ai/ai.go b/client/ai/ai.go index bf180ef5..f283c698 100644 --- a/client/ai/ai.go +++ b/client/ai/ai.go @@ -6,7 +6,9 @@ import ( "fmt" "os" "os/exec" + "regexp" "runtime" + "strings" "time" "github.com/ddworken/hishtory/client/data" @@ -26,11 +28,61 @@ func DebouncedGetAiSuggestions(ctx context.Context, shellName, query string, num return nil, nil } +func extractFilenames(text string) []string { + if strings.Count(text, "`")%2 != 0 { + return []string{} + } + pattern := "`([^`]*)`" + re := regexp.MustCompile(pattern) + matches := re.FindAllStringSubmatch(text, -1) + filenames := make([]string, 0, len(matches)) + for _, match := range matches { + if len(match) > 1 { + filename := match[1] + if len(filename) == 0 || len(filename) > 50 { + return []string{} + } + _, err := os.Stat(filename) + if err != nil { + return []string{} + } + filenames = append(filenames, filename) + } + } + + return filenames +} + +func augmentQuery(ctx context.Context, query string) string { + if !hctx.GetConf(ctx).BetaMode { + return query + } + filenames := extractFilenames(query) + if len(filenames) == 0 { + return query + } + newQuery := "Context:\n" + for _, filename := range filenames { + newQuery += "The file `" + filename + "` has contents like:\n```" + contents, err := os.ReadFile(filename) + if err != nil { + hctx.GetLogger().Warnf("while augmenting OpenAI query=%#v, failed to read the contents of %#v: %v", query, filename, err) + return query + } + lines := strings.Split(string(contents), "\n") + trimmed := lines[:min(5, len(lines))] + newQuery += strings.Join(trimmed, "\n") + newQuery += "\n...```\n\n" + } + newQuery += query + return newQuery +} + func GetAiSuggestions(ctx context.Context, shellName, query string, numberCompletions int) ([]string, error) { if os.Getenv("OPENAI_API_KEY") == "" && hctx.GetConf(ctx).AiCompletionEndpoint == ai.DefaultOpenAiEndpoint { - return GetAiSuggestionsViaHishtoryApi(ctx, shellName, query, numberCompletions) + return GetAiSuggestionsViaHishtoryApi(ctx, shellName, augmentQuery(ctx, query), numberCompletions) } else { - suggestions, _, err := ai.GetAiSuggestionsViaOpenAiApi(hctx.GetConf(ctx).AiCompletionEndpoint, query, shellName, getOsName(), os.Getenv("OPENAI_API_MODEL"), numberCompletions) + suggestions, _, err := ai.GetAiSuggestionsViaOpenAiApi(hctx.GetConf(ctx).AiCompletionEndpoint, augmentQuery(ctx, query), shellName, getOsName(), os.Getenv("OPENAI_API_MODEL"), numberCompletions) return suggestions, err } }