diff --git a/cache/cache.go b/cache/cache.go deleted file mode 100644 index f0ed5871..00000000 --- a/cache/cache.go +++ /dev/null @@ -1,294 +0,0 @@ -package cache - -import ( - "context" - "crypto/sha1" - "encoding/hex" - "fmt" - "os" - "runtime" - "time" - - "github.com/numtide/treefmt/stats" - - "github.com/numtide/treefmt/format" - "github.com/numtide/treefmt/walk" - - "github.com/charmbracelet/log" - - "github.com/adrg/xdg" - "github.com/vmihailenco/msgpack/v5" - bolt "go.etcd.io/bbolt" -) - -const ( - pathsBucket = "paths" - formattersBucket = "formatters" -) - -// Entry represents a cache entry, indicating the last size and modified time for a file path. -type Entry struct { - Size int64 - Modified time.Time -} - -var ( - db *bolt.DB - logger *log.Logger - - ReadBatchSize = 1024 * runtime.NumCPU() -) - -// Open creates an instance of bolt.DB for a given treeRoot path. -// If clean is true, Open will delete any existing data in the cache. -// -// The database will be located in `XDG_CACHE_DIR/treefmt/eval-cache/.db`, where is determined by hashing -// the treeRoot path. This associates a given treeRoot with a given instance of the cache. 
-func Open(treeRoot string, clean bool, formatters map[string]*format.Formatter) (err error) { - logger = log.WithPrefix("cache") - - // determine a unique and consistent db name for the tree root - h := sha1.New() - h.Write([]byte(treeRoot)) - digest := h.Sum(nil) - - name := hex.EncodeToString(digest) - path, err := xdg.CacheFile(fmt.Sprintf("treefmt/eval-cache/%v.db", name)) - if err != nil { - return fmt.Errorf("could not resolve local path for the cache: %w", err) - } - - // attempt to open the db, but timeout after 1 second - db, err = bolt.Open(path, 0o600, &bolt.Options{Timeout: 1 * time.Second}) - if err != nil { - return fmt.Errorf("failed to open cache at %v: %w", path, err) - } - - err = db.Update(func(tx *bolt.Tx) error { - // create bucket for tracking paths - pathsBucket, err := tx.CreateBucketIfNotExists([]byte(pathsBucket)) - if err != nil { - return fmt.Errorf("failed to create paths bucket: %w", err) - } - - // create bucket for tracking formatters - formattersBucket, err := tx.CreateBucketIfNotExists([]byte(formattersBucket)) - if err != nil { - return fmt.Errorf("failed to create formatters bucket: %w", err) - } - - // check for any newly configured or modified formatters - for name, formatter := range formatters { - - stat, err := os.Lstat(formatter.Executable()) - if err != nil { - return fmt.Errorf("failed to stat formatter executable %v: %w", formatter.Executable(), err) - } - - entry, err := getEntry(formattersBucket, name) - if err != nil { - return fmt.Errorf("failed to retrieve cache entry for formatter %v: %w", name, err) - } - - isNew := entry == nil - hasChanged := entry != nil && !(entry.Size == stat.Size() && entry.Modified == stat.ModTime()) - - if isNew { - logger.Debugf("formatter '%s' is new", name) - } else if hasChanged { - logger.Debug("formatter '%s' has changed", - name, - "size", stat.Size(), - "modTime", stat.ModTime(), - "cachedSize", entry.Size, - "cachedModTime", entry.Modified, - ) - } - - // update overall clean 
flag - clean = clean || isNew || hasChanged - - // record formatters info - entry = &Entry{ - Size: stat.Size(), - Modified: stat.ModTime(), - } - - if err = putEntry(formattersBucket, name, entry); err != nil { - return fmt.Errorf("failed to write cache entry for formatter %v: %w", name, err) - } - } - - // check for any removed formatters - if err = formattersBucket.ForEach(func(key []byte, _ []byte) error { - _, ok := formatters[string(key)] - if !ok { - // remove the formatter entry from the cache - if err = formattersBucket.Delete(key); err != nil { - return fmt.Errorf("failed to remove cache entry for formatter %v: %w", key, err) - } - // indicate a clean is required - clean = true - } - return nil - }); err != nil { - return fmt.Errorf("failed to check cache for removed formatters: %w", err) - } - - if clean { - // remove all path entries - c := pathsBucket.Cursor() - for k, v := c.First(); !(k == nil && v == nil); k, v = c.Next() { - if err = c.Delete(); err != nil { - return fmt.Errorf("failed to remove path entry: %w", err) - } - } - } - - return nil - }) - - return -} - -// Close closes any open instance of the cache. -func Close() error { - if db == nil { - return nil - } - return db.Close() -} - -// getEntry is a helper for reading cache entries from bolt. -func getEntry(bucket *bolt.Bucket, path string) (*Entry, error) { - b := bucket.Get([]byte(path)) - if b != nil { - var cached Entry - if err := msgpack.Unmarshal(b, &cached); err != nil { - return nil, fmt.Errorf("failed to unmarshal cache info for path '%v': %w", path, err) - } - return &cached, nil - } else { - return nil, nil - } -} - -// putEntry is a helper for writing cache entries into bolt. 
-func putEntry(bucket *bolt.Bucket, path string, entry *Entry) error { - bytes, err := msgpack.Marshal(entry) - if err != nil { - return fmt.Errorf("failed to marshal cache path %v: %w", path, err) - } - - if err = bucket.Put([]byte(path), bytes); err != nil { - return fmt.Errorf("failed to put cache path %v: %w", path, err) - } - return nil -} - -// ChangeSet is used to walk a filesystem, starting at root, and outputting any new or changed paths using pathsCh. -// It determines if a path is new or has changed by comparing against cache entries. -func ChangeSet(ctx context.Context, statz *stats.Stats, walker walk.Walker, filesCh chan<- *walk.File) error { - start := time.Now() - - defer func() { - logger.Debugf("finished generating change set in %v", time.Since(start)) - }() - - var tx *bolt.Tx - var bucket *bolt.Bucket - var processed int - - defer func() { - // close any pending read tx - if tx != nil { - _ = tx.Rollback() - } - }() - - return walker.Walk(ctx, func(file *walk.File, err error) error { - select { - case <-ctx.Done(): - return ctx.Err() - default: - if err != nil { - return fmt.Errorf("failed to walk path: %w", err) - } else if file.Info.IsDir() { - // ignore directories - return nil - } - } - - // open a new read tx if there isn't one in progress - // we have to periodically open a new read tx to prevent writes from being blocked - if tx == nil { - tx, err = db.Begin(false) - if err != nil { - return fmt.Errorf("failed to open a new cache read tx: %w", err) - } - bucket = tx.Bucket([]byte(pathsBucket)) - } - - cached, err := getEntry(bucket, file.RelPath) - if err != nil { - return err - } - - changedOrNew := cached == nil || !(cached.Modified == file.Info.ModTime() && cached.Size == file.Info.Size()) - - statz.Add(stats.Traversed, 1) - if !changedOrNew { - // no change - return nil - } - - statz.Add(stats.Emitted, 1) - - // pass on the path - select { - case <-ctx.Done(): - return ctx.Err() - default: - filesCh <- file - } - - // close the current 
tx if we have reached the batch size - processed += 1 - if processed == ReadBatchSize { - err = tx.Rollback() - tx = nil - return err - } - - return nil - }) -} - -// Update is used to record updated cache information for the specified list of paths. -func Update(files []*walk.File) error { - start := time.Now() - defer func() { - logger.Debugf("finished processing %v paths in %v", len(files), time.Since(start)) - }() - - if len(files) == 0 { - return nil - } - - return db.Update(func(tx *bolt.Tx) error { - bucket := tx.Bucket([]byte(pathsBucket)) - - for _, f := range files { - entry := Entry{ - Size: f.Info.Size(), - Modified: f.Info.ModTime(), - } - - if err := putEntry(bucket, f.RelPath, &entry); err != nil { - return err - } - } - - return nil - }) -} diff --git a/cmd/format/format.go b/cmd/format/format.go index 3b7010a9..03f06d22 100644 --- a/cmd/format/format.go +++ b/cmd/format/format.go @@ -13,9 +13,11 @@ import ( "syscall" "time" + "github.com/numtide/treefmt/walk/cache" + bolt "go.etcd.io/bbolt" + "github.com/charmbracelet/log" "github.com/gobwas/glob" - "github.com/numtide/treefmt/cache" "github.com/numtide/treefmt/config" "github.com/numtide/treefmt/format" "github.com/numtide/treefmt/stats" @@ -89,17 +91,18 @@ func Run(v *viper.Viper, statz *stats.Stats, cmd *cobra.Command, paths []string) cfg.Walk = "filesystem" // update paths with temp file - paths[0] = file.Name() + paths[0], err = filepath.Rel(os.TempDir(), file.Name()) + if err != nil { + return fmt.Errorf("failed to get relative path for temp file: %w", err) + } } else { // checks all paths are contained within the tree root - for idx, path := range paths { + for _, path := range paths { rootPath := filepath.Join(cfg.TreeRoot, path) if _, err = os.Stat(rootPath); err != nil { return fmt.Errorf("path %s not found within the tree root %s", path, cfg.TreeRoot) } - // update the path entry with an absolute path - paths[idx] = filepath.Clean(rootPath) } } @@ -122,13 +125,6 @@ func Run(v 
*viper.Viper, statz *stats.Stats, cmd *cobra.Command, paths []string) // create a prefixed logger log.SetPrefix("format") - // ensure cache is closed on return - defer func() { - if err := cache.Close(); err != nil { - log.Errorf("failed to close cache: %v", err) - } - }() - // compile global exclude globs globalExcludes, err := format.CompileGlobs(cfg.Excludes) if err != nil { @@ -154,12 +150,32 @@ func Run(v *viper.Viper, statz *stats.Stats, cmd *cobra.Command, paths []string) formatters[name] = formatter } - // open the cache if configured + var db *bolt.DB + if !cfg.NoCache { - if err = cache.Open(cfg.TreeRoot, cfg.ClearCache, formatters); err != nil { - // if we can't open the cache, we log a warning and fallback to no cache - log.Warnf("failed to open cache: %v", err) - cfg.NoCache = true + // open the db + db, err = cache.Open(cfg.TreeRoot) + if err != nil { + return fmt.Errorf("failed to open cache: %w", err) + } + + // ensure db is closed after we're finished + defer func() { + if err := db.Close(); err != nil { + log.Errorf("failed to close cache: %v", err) + } + }() + + // clear the cache if desired + if cfg.ClearCache { + if err = cache.Clear(db); err != nil { + return fmt.Errorf("failed to clear cache: %w", err) + } + } + + // Compare formatters, clearing paths if they have changed, and recording their latest info in the db + if err = format.CompareFormatters(db, formatters); err != nil { + return fmt.Errorf("failed to compare formatters: %w", err) } } @@ -184,93 +200,58 @@ func Run(v *viper.Viper, statz *stats.Stats, cmd *cobra.Command, paths []string) // create a channel for files that have been formatted formattedCh := make(chan *format.Task, cap(filesCh)) - // create a channel for files that have been processed - processedCh := make(chan *format.Task, cap(filesCh)) - // start concurrent processing tasks in reverse order - eg.Go(updateCache(ctx, cfg, statz, processedCh)) - eg.Go(detectFormatted(ctx, cfg, statz, formattedCh, processedCh)) + 
eg.Go(postProcessing(ctx, cfg, statz, formattedCh)) eg.Go(applyFormatters(ctx, cfg, statz, globalExcludes, formatters, filesCh, formattedCh)) - eg.Go(walkFilesystem(ctx, cfg, statz, paths, filesCh)) - // wait for everything to complete - return eg.Wait() -} + // + walkType, err := walk.TypeString(cfg.Walk) + if err != nil { + return fmt.Errorf("invalid walk type: %w", err) + } -func walkFilesystem( - ctx context.Context, - cfg *config.Config, - statz *stats.Stats, - paths []string, - filesCh chan *walk.File, -) func() error { - return func() error { - // close the files channel when we're done walking the file system - defer close(filesCh) + reader, err := walk.NewReader(walkType, cfg.TreeRoot, paths, db, statz) + if err != nil { + return fmt.Errorf("failed to create walker: %w", err) + } - eg, ctx := errgroup.WithContext(ctx) - pathsCh := make(chan string, BatchSize) + // - // By default, we use the cli arg, but if the stdin flag has been set we force a filesystem walk - // since we will only be processing one file from a temp directory - walkerType, err := walk.TypeString(cfg.Walk) - if err != nil { - return fmt.Errorf("invalid walk type: %w", err) - } + files := make([]*walk.File, BatchSize) + for { + ctx, cancel := context.WithTimeout(ctx, 1*time.Second) - walkPaths := func() error { - defer close(pathsCh) - - var idx int - for idx < len(paths) { - select { - case <-ctx.Done(): - return ctx.Err() - default: - pathsCh <- paths[idx] - idx += 1 - } - } + n, err := reader.Read(ctx, files) - return nil - } + for idx := 0; idx < n; idx++ { + file := files[idx] - if len(paths) > 0 { - eg.Go(walkPaths) - } else { - // no explicit paths to process, so we only need to process root - pathsCh <- cfg.TreeRoot - close(pathsCh) + // check if this file is new or has changed when compared to the cache entry + if file.Cache == nil || file.Cache.HasChanged(file.Info) { + filesCh <- file + statz.Add(stats.Emitted, 1) + } } - // create a filesystem walker - walker, err := 
walk.New(walkerType, cfg.TreeRoot, pathsCh) - if err != nil { - return fmt.Errorf("failed to create walker: %w", err) - } + cancel() - // if no cache has been configured, or we are processing from stdin, we invoke the walker directly - if cfg.NoCache || cfg.Stdin { - return walker.Walk(ctx, func(file *walk.File, err error) error { - select { - case <-ctx.Done(): - return ctx.Err() - default: - statz.Add(stats.Traversed, 1) - statz.Add(stats.Emitted, 1) - filesCh <- file - return nil - } - }) + if errors.Is(err, io.EOF) { + break + } else if err != nil { + log.Errorf("failed to read files: %v", err) + cancel() + break } + } - // otherwise we pass the walker to the cache and have it generate files for processing based on whether or not - // they have been added/changed since the last invocation - if err = cache.ChangeSet(ctx, statz, walker, filesCh); err != nil { - return fmt.Errorf("failed to generate change set: %w", err) - } - return nil + close(filesCh) + + // wait for everything to complete + if err = eg.Wait(); err != nil { + return err } + + return reader.Close() } // applyFormatters @@ -408,35 +389,30 @@ func applyFormatters( } } -func detectFormatted( +func postProcessing( ctx context.Context, cfg *config.Config, statz *stats.Stats, formattedCh chan *format.Task, - processedCh chan *format.Task, ) func() error { return func() error { - defer func() { - // close formatted channel - close(processedCh) - }() - + LOOP: for { select { // detect ctx cancellation case <-ctx.Done(): return ctx.Err() + // take the next task that has been processed case task, ok := <-formattedCh: if !ok { - // channel has been closed, no further files to process - return nil + break LOOP } // check if the file has changed file := task.File - changed, newInfo, err := file.HasChanged() + changed, newInfo, err := file.Stat() if err != nil { return err } @@ -464,59 +440,10 @@ func detectFormatted( file.Info = newInfo } - // mark as processed - processedCh <- task - } - } - } -} - -func 
updateCache( - ctx context.Context, - cfg *config.Config, - statz *stats.Stats, - processedCh chan *format.Task, -) func() error { - return func() error { - // used to batch updates for more efficient txs - batch := make([]*format.Task, 0, BatchSize) - - // apply a batch - processBatch := func() error { - // pass the batch to the cache for updating - files := make([]*walk.File, len(batch)) - for idx := range batch { - files[idx] = batch[idx].File - } - if err := cache.Update(files); err != nil { - return err - } - batch = batch[:0] - return nil - } - - // if we are processing from stdin that means we are outputting to stdout, no caching involved - // if f.NoCache is set that means either the user explicitly disabled the cache or we failed to open on - if cfg.Stdin || cfg.NoCache { - // do nothing - processBatch = func() error { return nil } - } - - LOOP: - for { - select { - // detect ctx cancellation - case <-ctx.Done(): - return ctx.Err() - // respond to formatted files - case task, ok := <-processedCh: - if !ok { - // channel has been closed, no further files to process - break LOOP + if file.Release != nil { + file.Release() } - file := task.File - if cfg.Stdin { // dump file into stdout f, err := os.Open(file.Path) @@ -529,30 +456,10 @@ func updateCache( if err = os.Remove(f.Name()); err != nil { return fmt.Errorf("failed to remove temp file %s: %w", file.Path, err) } - - continue - } - - // Append to batch and process if we have enough. - // We do not cache any files that were part of a pipeline in which one or more formatters failed. - // This is to ensure those files are re-processed in later invocations after the user has potentially - // resolved the issue, e.g. fixed a config problem. 
- if len(task.Errors) == 0 { - batch = append(batch, task) - if len(batch) == BatchSize { - if err := processBatch(); err != nil { - return err - } - } } } } - // final flush - if err := processBatch(); err != nil { - return err - } - // if fail on change has been enabled, check that no files were actually formatted, throwing an error if so if cfg.FailOnChange && statz.Value(stats.Formatted) != 0 { return ErrFailOnChange diff --git a/cmd/root_test.go b/cmd/root_test.go index 050450d3..4e6fed8c 100644 --- a/cmd/root_test.go +++ b/cmd/root_test.go @@ -1,4 +1,4 @@ -package cmd +package cmd_test import ( "bufio" @@ -12,6 +12,8 @@ import ( "testing" "time" + "github.com/numtide/treefmt/cmd" + "github.com/numtide/treefmt/config" "github.com/charmbracelet/log" @@ -58,7 +60,7 @@ func TestOnUnmatched(t *testing.T) { // - "haskell/treefmt.toml" } - _, _, err = cmd(t, "-C", tempDir, "--allow-missing-formatter", "--on-unmatched", "fatal") + _, _, err = treefmt(t, "-C", tempDir, "--allow-missing-formatter", "--on-unmatched", "fatal") as.ErrorContains(err, fmt.Sprintf("no formatter for path: %s", paths[0])) checkOutput := func(level string, output []byte) { @@ -70,24 +72,24 @@ func TestOnUnmatched(t *testing.T) { var out []byte // default is warn - out, _, err = cmd(t, "-C", tempDir, "--allow-missing-formatter", "-c") + out, _, err = treefmt(t, "-C", tempDir, "--allow-missing-formatter", "-c") as.NoError(err) checkOutput("WARN", out) - out, _, err = cmd(t, "-C", tempDir, "--allow-missing-formatter", "-c", "--on-unmatched", "warn") + out, _, err = treefmt(t, "-C", tempDir, "--allow-missing-formatter", "-c", "--on-unmatched", "warn") as.NoError(err) checkOutput("WARN", out) - out, _, err = cmd(t, "-C", tempDir, "--allow-missing-formatter", "-c", "-u", "error") + out, _, err = treefmt(t, "-C", tempDir, "--allow-missing-formatter", "-c", "-u", "error") as.NoError(err) checkOutput("ERRO", out) - out, _, err = cmd(t, "-C", tempDir, "--allow-missing-formatter", "-c", "-v", 
"--on-unmatched", "info") + out, _, err = treefmt(t, "-C", tempDir, "--allow-missing-formatter", "-c", "-v", "--on-unmatched", "info") as.NoError(err) checkOutput("INFO", out) t.Setenv("TREEFMT_ON_UNMATCHED", "debug") - out, _, err = cmd(t, "-C", tempDir, "--allow-missing-formatter", "-c", "-vv") + out, _, err = treefmt(t, "-C", tempDir, "--allow-missing-formatter", "-c", "-vv") as.NoError(err) checkOutput("DEBU", out) } @@ -105,14 +107,14 @@ func TestCpuProfile(t *testing.T) { as.NoError(os.Chdir(cwd)) }) - _, _, err = cmd(t, "-C", tempDir, "--allow-missing-formatter", "--cpu-profile", "cpu.pprof") + _, _, err = treefmt(t, "-C", tempDir, "--allow-missing-formatter", "--cpu-profile", "cpu.pprof") as.NoError(err) as.FileExists(filepath.Join(tempDir, "cpu.pprof")) _, err = os.Stat(filepath.Join(tempDir, "cpu.pprof")) as.NoError(err) t.Setenv("TREEFMT_CPU_PROFILE", "env.pprof") - _, _, err = cmd(t, "-C", tempDir, "--allow-missing-formatter") + _, _, err = treefmt(t, "-C", tempDir, "--allow-missing-formatter") as.NoError(err) as.FileExists(filepath.Join(tempDir, "env.pprof")) _, err = os.Stat(filepath.Join(tempDir, "env.pprof")) @@ -133,14 +135,14 @@ func TestAllowMissingFormatter(t *testing.T) { }, }) - _, _, err := cmd(t, "--config-file", configPath, "--tree-root", tempDir, "-vv") + _, _, err := treefmt(t, "--config-file", configPath, "--tree-root", tempDir, "-vv") as.ErrorIs(err, format.ErrCommandNotFound) - _, _, err = cmd(t, "--config-file", configPath, "--tree-root", tempDir, "--allow-missing-formatter") + _, _, err = treefmt(t, "--config-file", configPath, "--tree-root", tempDir, "--allow-missing-formatter") as.NoError(err) t.Setenv("TREEFMT_ALLOW_MISSING_FORMATTER", "true") - _, _, err = cmd(t, "--config-file", configPath, "--tree-root", tempDir) + _, _, err = treefmt(t, "--config-file", configPath, "--tree-root", tempDir) as.NoError(err) } @@ -178,32 +180,32 @@ func TestSpecifyingFormatters(t *testing.T) { } setup() - _, statz, err := cmd(t, "-c", 
"--config-file", configPath, "--tree-root", tempDir) + _, statz, err := treefmt(t, "-c", "--config-file", configPath, "--tree-root", tempDir) as.NoError(err) assertStats(t, as, statz, 32, 32, 3, 3) setup() - _, statz, err = cmd(t, "-c", "--config-file", configPath, "--tree-root", tempDir, "--formatters", "elm,nix") + _, statz, err = treefmt(t, "-c", "--config-file", configPath, "--tree-root", tempDir, "--formatters", "elm,nix") as.NoError(err) assertStats(t, as, statz, 32, 32, 2, 2) setup() - _, statz, err = cmd(t, "-c", "--config-file", configPath, "--tree-root", tempDir, "-f", "ruby,nix") + _, statz, err = treefmt(t, "-c", "--config-file", configPath, "--tree-root", tempDir, "-f", "ruby,nix") as.NoError(err) assertStats(t, as, statz, 32, 32, 2, 2) setup() - _, statz, err = cmd(t, "-c", "--config-file", configPath, "--tree-root", tempDir, "--formatters", "nix") + _, statz, err = treefmt(t, "-c", "--config-file", configPath, "--tree-root", tempDir, "--formatters", "nix") as.NoError(err) assertStats(t, as, statz, 32, 32, 1, 1) // test bad names setup() - _, _, err = cmd(t, "-c", "--config-file", configPath, "--tree-root", tempDir, "--formatters", "foo") + _, _, err = treefmt(t, "-c", "--config-file", configPath, "--tree-root", tempDir, "--formatters", "foo") as.Errorf(err, "formatter not found in config: foo") t.Setenv("TREEFMT_FORMATTERS", "bar,foo") - _, _, err = cmd(t, "-c", "--config-file", configPath, "--tree-root", tempDir) + _, _, err = treefmt(t, "-c", "--config-file", configPath, "--tree-root", tempDir) as.Errorf(err, "formatter not found in config: bar") } @@ -224,7 +226,7 @@ func TestIncludesAndExcludes(t *testing.T) { } test.WriteConfig(t, configPath, cfg) - _, statz, err := cmd(t, "-c", "--config-file", configPath, "--tree-root", tempDir) + _, statz, err := treefmt(t, "-c", "--config-file", configPath, "--tree-root", tempDir) as.NoError(err) assertStats(t, as, statz, 32, 32, 32, 0) @@ -232,7 +234,7 @@ func TestIncludesAndExcludes(t *testing.T) { 
cfg.Excludes = []string{"*.nix"} test.WriteConfig(t, configPath, cfg) - _, statz, err = cmd(t, "-c", "--config-file", configPath, "--tree-root", tempDir) + _, statz, err = treefmt(t, "-c", "--config-file", configPath, "--tree-root", tempDir) as.NoError(err) assertStats(t, as, statz, 32, 32, 31, 0) @@ -240,7 +242,7 @@ func TestIncludesAndExcludes(t *testing.T) { cfg.Excludes = []string{"*.nix", "*.hs"} test.WriteConfig(t, configPath, cfg) - _, statz, err = cmd(t, "-c", "--config-file", configPath, "--tree-root", tempDir) + _, statz, err = treefmt(t, "-c", "--config-file", configPath, "--tree-root", tempDir) as.NoError(err) assertStats(t, as, statz, 32, 32, 25, 0) @@ -250,7 +252,7 @@ func TestIncludesAndExcludes(t *testing.T) { echo.Excludes = []string{"*.py"} test.WriteConfig(t, configPath, cfg) - _, statz, err = cmd(t, "-c", "--config-file", configPath, "--tree-root", tempDir) + _, statz, err = treefmt(t, "-c", "--config-file", configPath, "--tree-root", tempDir) as.NoError(err) assertStats(t, as, statz, 32, 32, 23, 0) @@ -258,7 +260,7 @@ func TestIncludesAndExcludes(t *testing.T) { t.Setenv("TREEFMT_FORMATTER_ECHO_EXCLUDES", "*.py,*.go") test.WriteConfig(t, configPath, cfg) - _, statz, err = cmd(t, "-c", "--config-file", configPath, "--tree-root", tempDir) + _, statz, err = treefmt(t, "-c", "--config-file", configPath, "--tree-root", tempDir) as.NoError(err) assertStats(t, as, statz, 32, 32, 22, 0) @@ -268,7 +270,7 @@ func TestIncludesAndExcludes(t *testing.T) { echo.Includes = []string{"*.elm"} test.WriteConfig(t, configPath, cfg) - _, statz, err = cmd(t, "-c", "--config-file", configPath, "--tree-root", tempDir) + _, statz, err = treefmt(t, "-c", "--config-file", configPath, "--tree-root", tempDir) as.NoError(err) assertStats(t, as, statz, 32, 32, 1, 0) @@ -276,7 +278,7 @@ func TestIncludesAndExcludes(t *testing.T) { t.Setenv("TREEFMT_FORMATTER_ECHO_INCLUDES", "*.elm,*.js") test.WriteConfig(t, configPath, cfg) - _, statz, err = cmd(t, "-c", "--config-file", 
configPath, "--tree-root", tempDir) + _, statz, err = treefmt(t, "-c", "--config-file", configPath, "--tree-root", tempDir) as.NoError(err) assertStats(t, as, statz, 32, 32, 2, 0) } @@ -299,7 +301,7 @@ func TestPrjRootEnvVariable(t *testing.T) { test.WriteConfig(t, configPath, cfg) t.Setenv("PRJ_ROOT", tempDir) - _, statz, err := cmd(t, "--config-file", configPath) + _, statz, err := treefmt(t, "--config-file", configPath) as.NoError(err) assertStats(t, as, statz, 32, 32, 32, 0) } @@ -323,34 +325,34 @@ func TestCache(t *testing.T) { var err error test.WriteConfig(t, configPath, cfg) - _, statz, err := cmd(t, "--config-file", configPath, "--tree-root", tempDir) + _, statz, err := treefmt(t, "--config-file", configPath, "--tree-root", tempDir) as.NoError(err) assertStats(t, as, statz, 32, 32, 32, 0) - _, statz, err = cmd(t, "--config-file", configPath, "--tree-root", tempDir) + _, statz, err = treefmt(t, "--config-file", configPath, "--tree-root", tempDir) as.NoError(err) assertStats(t, as, statz, 32, 0, 0, 0) // clear cache - _, statz, err = cmd(t, "--config-file", configPath, "--tree-root", tempDir, "-c") + _, statz, err = treefmt(t, "--config-file", configPath, "--tree-root", tempDir, "-c") as.NoError(err) assertStats(t, as, statz, 32, 32, 32, 0) - _, statz, err = cmd(t, "--config-file", configPath, "--tree-root", tempDir) + _, statz, err = treefmt(t, "--config-file", configPath, "--tree-root", tempDir) as.NoError(err) assertStats(t, as, statz, 32, 0, 0, 0) // clear cache - _, statz, err = cmd(t, "--config-file", configPath, "--tree-root", tempDir, "-c") + _, statz, err = treefmt(t, "--config-file", configPath, "--tree-root", tempDir, "-c") as.NoError(err) assertStats(t, as, statz, 32, 32, 32, 0) - _, statz, err = cmd(t, "--config-file", configPath, "--tree-root", tempDir) + _, statz, err = treefmt(t, "--config-file", configPath, "--tree-root", tempDir) as.NoError(err) assertStats(t, as, statz, 32, 0, 0, 0) // no cache - _, statz, err = cmd(t, "--config-file", 
configPath, "--tree-root", tempDir, "--no-cache") + _, statz, err = treefmt(t, "--config-file", configPath, "--tree-root", tempDir, "--no-cache") as.NoError(err) assertStats(t, as, statz, 32, 32, 32, 0) } @@ -384,13 +386,13 @@ func TestChangeWorkingDirectory(t *testing.T) { // by default, we look for ./treefmt.toml and use the cwd for the tree root // this should fail if the working directory hasn't been changed first - _, statz, err := cmd(t, "-C", tempDir) + _, statz, err := treefmt(t, "-C", tempDir) as.NoError(err) assertStats(t, as, statz, 32, 32, 32, 0) // use env t.Setenv("TREEFMT_WORKING_DIR", tempDir) - _, statz, err = cmd(t, "-c") + _, statz, err = treefmt(t, "-c") as.NoError(err) assertStats(t, as, statz, 32, 32, 32, 0) } @@ -412,7 +414,7 @@ func TestFailOnChange(t *testing.T) { } test.WriteConfig(t, configPath, cfg) - _, _, err := cmd(t, "--fail-on-change", "--config-file", configPath, "--tree-root", tempDir) + _, _, err := treefmt(t, "--fail-on-change", "--config-file", configPath, "--tree-root", tempDir) as.ErrorIs(err, format2.ErrFailOnChange) // we have second precision mod time tracking @@ -421,7 +423,7 @@ func TestFailOnChange(t *testing.T) { // test with no cache t.Setenv("TREEFMT_FAIL_ON_CHANGE", "true") test.WriteConfig(t, configPath, cfg) - _, _, err = cmd(t, "--config-file", configPath, "--tree-root", tempDir, "--no-cache") + _, _, err = treefmt(t, "--config-file", configPath, "--tree-root", tempDir, "--no-cache") as.ErrorIs(err, format2.ErrFailOnChange) } @@ -463,31 +465,31 @@ func TestBustCacheOnFormatterChange(t *testing.T) { test.WriteConfig(t, configPath, cfg) args := []string{"--config-file", configPath, "--tree-root", tempDir} - _, statz, err := cmd(t, args...) + _, statz, err := treefmt(t, args...) as.NoError(err) assertStats(t, as, statz, 32, 32, 3, 0) // tweak mod time of elm formatter as.NoError(test.RecreateSymlink(t, binPath+"/"+"elm-format")) - _, statz, err = cmd(t, args...) + _, statz, err = treefmt(t, args...) 
as.NoError(err) assertStats(t, as, statz, 32, 32, 3, 0) // check cache is working - _, statz, err = cmd(t, args...) + _, statz, err = treefmt(t, args...) as.NoError(err) assertStats(t, as, statz, 32, 0, 0, 0) // tweak mod time of python formatter as.NoError(test.RecreateSymlink(t, binPath+"/"+"black")) - _, statz, err = cmd(t, args...) + _, statz, err = treefmt(t, args...) as.NoError(err) assertStats(t, as, statz, 32, 32, 3, 0) // check cache is working - _, statz, err = cmd(t, args...) + _, statz, err = treefmt(t, args...) as.NoError(err) assertStats(t, as, statz, 32, 0, 0, 0) @@ -499,12 +501,12 @@ func TestBustCacheOnFormatterChange(t *testing.T) { } test.WriteConfig(t, configPath, cfg) - _, statz, err = cmd(t, args...) + _, statz, err = treefmt(t, args...) as.NoError(err) assertStats(t, as, statz, 32, 32, 4, 0) // check cache is working - _, statz, err = cmd(t, args...) + _, statz, err = treefmt(t, args...) as.NoError(err) assertStats(t, as, statz, 32, 0, 0, 0) @@ -512,12 +514,12 @@ func TestBustCacheOnFormatterChange(t *testing.T) { delete(cfg.FormatterConfigs, "python") test.WriteConfig(t, configPath, cfg) - _, statz, err = cmd(t, args...) + _, statz, err = treefmt(t, args...) as.NoError(err) assertStats(t, as, statz, 32, 32, 2, 0) // check cache is working - _, statz, err = cmd(t, args...) + _, statz, err = treefmt(t, args...) as.NoError(err) assertStats(t, as, statz, 32, 0, 0, 0) @@ -525,12 +527,12 @@ func TestBustCacheOnFormatterChange(t *testing.T) { delete(cfg.FormatterConfigs, "elm") test.WriteConfig(t, configPath, cfg) - _, statz, err = cmd(t, args...) + _, statz, err = treefmt(t, args...) as.NoError(err) assertStats(t, as, statz, 32, 32, 1, 0) // check cache is working - _, statz, err = cmd(t, args...) + _, statz, err = treefmt(t, args...) 
as.NoError(err) assertStats(t, as, statz, 32, 0, 0, 0) } @@ -568,7 +570,7 @@ func TestGitWorktree(t *testing.T) { as.NoError(err, "failed to get git worktree") run := func(traversed int32, emitted int32, matched int32, formatted int32) { - _, statz, err := cmd(t, "-c", "--config-file", configPath, "--tree-root", tempDir) + _, statz, err := treefmt(t, "-c", "--config-file", configPath, "--tree-root", tempDir) as.NoError(err) assertStats(t, as, statz, traversed, emitted, matched, formatted) } @@ -590,7 +592,7 @@ func TestGitWorktree(t *testing.T) { run(28, 28, 28, 0) // walk with filesystem instead of git - _, statz, err := cmd(t, "-c", "--config-file", configPath, "--tree-root", tempDir, "--walk", "filesystem") + _, statz, err := treefmt(t, "-c", "--config-file", configPath, "--tree-root", tempDir, "--walk", "filesystem") as.NoError(err) assertStats(t, as, statz, 60, 60, 60, 0) @@ -604,31 +606,31 @@ func TestGitWorktree(t *testing.T) { }) // format specific sub paths - _, statz, err = cmd(t, "-C", tempDir, "-c", "go", "-vv") + _, statz, err = treefmt(t, "-C", tempDir, "-c", "go", "-vv") as.NoError(err) assertStats(t, as, statz, 2, 2, 2, 0) - _, statz, err = cmd(t, "-C", tempDir, "-c", "go", "haskell") + _, statz, err = treefmt(t, "-C", tempDir, "-c", "go", "haskell") as.NoError(err) assertStats(t, as, statz, 9, 9, 9, 0) - _, statz, err = cmd(t, "-C", tempDir, "-c", "go", "haskell", "ruby") + _, statz, err = treefmt(t, "-C", tempDir, "-c", "go", "haskell", "ruby") as.NoError(err) assertStats(t, as, statz, 10, 10, 10, 0) // try with a bad path - _, _, err = cmd(t, "-C", tempDir, "-c", "haskell", "foo") + _, _, err = treefmt(t, "-C", tempDir, "-c", "haskell", "foo") as.ErrorContains(err, "path foo not found within the tree root") // try with a path not in the git index, e.g. 
it is skipped _, err = os.Create(filepath.Join(tempDir, "foo.txt")) as.NoError(err) - _, statz, err = cmd(t, "-C", tempDir, "-c", "haskell", "foo.txt") + _, statz, err = treefmt(t, "-C", tempDir, "-c", "haskell", "foo.txt") as.NoError(err) assertStats(t, as, statz, 7, 7, 7, 0) - _, statz, err = cmd(t, "-C", tempDir, "-c", "foo.txt") + _, statz, err = treefmt(t, "-C", tempDir, "-c", "foo.txt") as.NoError(err) assertStats(t, as, statz, 0, 0, 0, 0) } @@ -664,22 +666,22 @@ func TestPathsArg(t *testing.T) { test.WriteConfig(t, configPath, cfg) // without any path args - _, statz, err := cmd(t) + _, statz, err := treefmt(t) as.NoError(err) assertStats(t, as, statz, 32, 32, 32, 0) // specify some explicit paths - _, statz, err = cmd(t, "-c", "elm/elm.json", "haskell/Nested/Foo.hs") + _, statz, err = treefmt(t, "-c", "elm/elm.json", "haskell/Nested/Foo.hs") as.NoError(err) assertStats(t, as, statz, 2, 2, 2, 0) // specify a bad path - _, _, err = cmd(t, "-c", "elm/elm.json", "haskell/Nested/Bar.hs") + _, _, err = treefmt(t, "-c", "elm/elm.json", "haskell/Nested/Bar.hs") as.ErrorContains(err, "path haskell/Nested/Bar.hs not found within the tree root") // specify a path outside the tree root externalPath := filepath.Join(cwd, "go.mod") - _, _, err = cmd(t, "-c", externalPath) + _, _, err = treefmt(t, "-c", externalPath) as.ErrorContains(err, fmt.Sprintf("path %s not found within the tree root", externalPath)) } @@ -706,7 +708,7 @@ func TestStdin(t *testing.T) { contents := `{ foo, ... }: "hello"` os.Stdin = test.TempFile(t, "", "stdin", &contents) // we get an error about the missing filename parameter. - out, _, err := cmd(t, "-C", tempDir, "--allow-missing-formatter", "--stdin") + out, _, err := treefmt(t, "-C", tempDir, "--allow-missing-formatter", "--stdin") as.EqualError(err, "exactly one path should be specified when using the --stdin flag") as.Equal("", string(out)) @@ -714,7 +716,7 @@ func TestStdin(t *testing.T) { contents = `{ foo, ... 
}: "hello"` os.Stdin = test.TempFile(t, "", "stdin", &contents) - out, statz, err := cmd(t, "-C", tempDir, "--allow-missing-formatter", "--stdin", "test.nix") + out, statz, err := treefmt(t, "-C", tempDir, "--allow-missing-formatter", "--stdin", "test.nix") as.NoError(err) assertStats(t, as, statz, 1, 1, 1, 1) @@ -731,7 +733,7 @@ func TestStdin(t *testing.T) { ` os.Stdin = test.TempFile(t, "", "stdin", &contents) - out, statz, err = cmd(t, "-C", tempDir, "--allow-missing-formatter", "--stdin", "test.md") + out, statz, err = treefmt(t, "-C", tempDir, "--allow-missing-formatter", "--stdin", "test.md") as.NoError(err) assertStats(t, as, statz, 1, 1, 1, 1) @@ -780,7 +782,7 @@ func TestDeterministicOrderingInPipeline(t *testing.T) { }, }, }) - _, _, err = cmd(t, "-C", tempDir) + _, _, err = treefmt(t, "-C", tempDir) as.NoError(err) matcher := regexp.MustCompile("^fmt-(.*)") @@ -845,17 +847,17 @@ func TestRunInSubdir(t *testing.T) { test.WriteConfig(t, configPath, cfg) // without any path args, should reformat the whole tree - _, statz, err := cmd(t) + _, statz, err := treefmt(t) as.NoError(err) assertStats(t, as, statz, 32, 32, 32, 0) // specify some explicit paths, relative to the tree root - _, statz, err = cmd(t, "-c", "elm/elm.json", "haskell/Nested/Foo.hs") + _, statz, err = treefmt(t, "-c", "elm/elm.json", "haskell/Nested/Foo.hs") as.NoError(err) assertStats(t, as, statz, 2, 2, 2, 0) } -func cmd(t *testing.T, args ...string) ([]byte, *stats.Stats, error) { +func treefmt(t *testing.T, args ...string) ([]byte, *stats.Stats, error) { t.Helper() tempDir := t.TempDir() @@ -879,7 +881,7 @@ func cmd(t *testing.T, args ...string) ([]byte, *stats.Stats, error) { }() // run the command - root, statz := NewRoot() + root, statz := cmd.NewRoot() if args == nil { // we must pass an empty array otherwise cobra with use os.Args[1:] diff --git a/config/config_test.go b/config/config_test.go index ecbfd3cd..78ee49eb 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ 
-1,4 +1,4 @@ -package config +package config_test import ( "bufio" @@ -8,6 +8,8 @@ import ( "path/filepath" "testing" + "github.com/numtide/treefmt/config" + "github.com/BurntSushi/toml" "github.com/spf13/pflag" "github.com/spf13/viper" @@ -17,7 +19,7 @@ import ( func newViper(t *testing.T) (*viper.Viper, *pflag.FlagSet) { t.Helper() - v, err := NewViper() + v, err := config.NewViper() if err != nil { t.Fatal(err) } @@ -26,7 +28,7 @@ func newViper(t *testing.T) (*viper.Viper, *pflag.FlagSet) { v.SetConfigFile(filepath.Join(tempDir, "treefmt.toml")) flags := pflag.NewFlagSet("test", pflag.ContinueOnError) - SetFlags(flags) + config.SetFlags(flags) if err := v.BindPFlags(flags); err != nil { t.Fatal(err) @@ -34,7 +36,7 @@ func newViper(t *testing.T) (*viper.Viper, *pflag.FlagSet) { return v, flags } -func readValue(t *testing.T, v *viper.Viper, cfg *Config, test func(*Config)) { +func readValue(t *testing.T, v *viper.Viper, cfg *config.Config, test func(*config.Config)) { t.Helper() // serialise the config and read it into viper @@ -47,7 +49,7 @@ func readValue(t *testing.T, v *viper.Viper, cfg *Config, test func(*Config)) { } // - decodedCfg, err := FromViper(v) + decodedCfg, err := config.FromViper(v) if err != nil { t.Fatal(fmt.Errorf("failed to unmarshal config from viper: %w", err)) } @@ -58,11 +60,11 @@ func readValue(t *testing.T, v *viper.Viper, cfg *Config, test func(*Config)) { func TestAllowMissingFormatter(t *testing.T) { as := require.New(t) - cfg := &Config{} + cfg := &config.Config{} v, flags := newViper(t) checkValue := func(expected bool) { - readValue(t, v, cfg, func(cfg *Config) { + readValue(t, v, cfg, func(cfg *config.Config) { as.Equal(expected, cfg.AllowMissingFormatter) }) } @@ -86,11 +88,11 @@ func TestAllowMissingFormatter(t *testing.T) { func TestCI(t *testing.T) { as := require.New(t) - cfg := &Config{} + cfg := &config.Config{} v, flags := newViper(t) checkValues := func(ci bool, noCache bool, failOnChange bool, verbosity uint8) { - 
readValue(t, v, cfg, func(cfg *Config) { + readValue(t, v, cfg, func(cfg *config.Config) { as.Equal(ci, cfg.CI) as.Equal(noCache, cfg.NoCache) as.Equal(failOnChange, cfg.FailOnChange) @@ -122,11 +124,11 @@ func TestCI(t *testing.T) { func TestClearCache(t *testing.T) { as := require.New(t) - cfg := &Config{} + cfg := &config.Config{} v, flags := newViper(t) checkValue := func(expected bool) { - readValue(t, v, cfg, func(cfg *Config) { + readValue(t, v, cfg, func(cfg *config.Config) { as.Equal(expected, cfg.ClearCache) }) } @@ -151,11 +153,11 @@ func TestClearCache(t *testing.T) { func TestCpuProfile(t *testing.T) { as := require.New(t) - cfg := &Config{} + cfg := &config.Config{} v, flags := newViper(t) checkValue := func(expected string) { - readValue(t, v, cfg, func(cfg *Config) { + readValue(t, v, cfg, func(cfg *config.Config) { as.Equal(expected, cfg.CpuProfile) }) } @@ -179,11 +181,11 @@ func TestCpuProfile(t *testing.T) { func TestExcludes(t *testing.T) { as := require.New(t) - cfg := &Config{} + cfg := &config.Config{} v, flags := newViper(t) checkValue := func(expected []string) { - readValue(t, v, cfg, func(cfg *Config) { + readValue(t, v, cfg, func(cfg *config.Config) { as.Equal(expected, cfg.Excludes) }) } @@ -212,11 +214,11 @@ func TestExcludes(t *testing.T) { func TestFailOnChange(t *testing.T) { as := require.New(t) - cfg := &Config{} + cfg := &config.Config{} v, flags := newViper(t) checkValue := func(expected bool) { - readValue(t, v, cfg, func(cfg *Config) { + readValue(t, v, cfg, func(cfg *config.Config) { as.Equal(expected, cfg.FailOnChange) }) } @@ -240,11 +242,11 @@ func TestFailOnChange(t *testing.T) { func TestFormatters(t *testing.T) { as := require.New(t) - cfg := &Config{} + cfg := &config.Config{} v, flags := newViper(t) checkValue := func(expected []string) { - readValue(t, v, cfg, func(cfg *Config) { + readValue(t, v, cfg, func(cfg *config.Config) { as.Equal(expected, cfg.Formatters) }) } @@ -253,7 +255,7 @@ func TestFormatters(t 
*testing.T) { checkValue([]string{}) // set config value - cfg.FormatterConfigs = map[string]*Formatter{ + cfg.FormatterConfigs = map[string]*config.Formatter{ "echo": { Command: "echo", }, @@ -278,18 +280,18 @@ func TestFormatters(t *testing.T) { // bad formatter name as.NoError(flags.Set("formatters", "foo,echo,date")) - _, err := FromViper(v) + _, err := config.FromViper(v) as.ErrorContains(err, "formatter foo not found in config") } func TestNoCache(t *testing.T) { as := require.New(t) - cfg := &Config{} + cfg := &config.Config{} v, flags := newViper(t) checkValue := func(expected bool) { - readValue(t, v, cfg, func(cfg *Config) { + readValue(t, v, cfg, func(cfg *config.Config) { as.Equal(expected, cfg.NoCache) }) } @@ -314,11 +316,11 @@ func TestNoCache(t *testing.T) { func TestOnUnmatched(t *testing.T) { as := require.New(t) - cfg := &Config{} + cfg := &config.Config{} v, flags := newViper(t) checkValue := func(expected string) { - readValue(t, v, cfg, func(cfg *Config) { + readValue(t, v, cfg, func(cfg *config.Config) { as.Equal(expected, cfg.OnUnmatched) }) } @@ -342,11 +344,11 @@ func TestOnUnmatched(t *testing.T) { func TestTreeRoot(t *testing.T) { as := require.New(t) - cfg := &Config{} + cfg := &config.Config{} v, flags := newViper(t) checkValue := func(expected string) { - readValue(t, v, cfg, func(cfg *Config) { + readValue(t, v, cfg, func(cfg *config.Config) { as.Equal(expected, cfg.TreeRoot) }) } @@ -371,7 +373,7 @@ func TestTreeRoot(t *testing.T) { func TestTreeRootFile(t *testing.T) { as := require.New(t) - cfg := &Config{} + cfg := &config.Config{} v, flags := newViper(t) // create a directory structure with config files at various levels @@ -383,7 +385,7 @@ func TestTreeRootFile(t *testing.T) { as.NoError(os.WriteFile(filepath.Join(tempDir, ".git", "config"), []byte{}, 0o644)) checkValue := func(treeRoot string, treeRootFile string) { - readValue(t, v, cfg, func(cfg *Config) { + readValue(t, v, cfg, func(cfg *config.Config) { as.Equal(treeRoot, 
cfg.TreeRoot) as.Equal(treeRootFile, cfg.TreeRootFile) }) @@ -415,11 +417,11 @@ func TestTreeRootFile(t *testing.T) { func TestVerbosity(t *testing.T) { as := require.New(t) - cfg := &Config{} + cfg := &config.Config{} v, _ := newViper(t) checkValue := func(expected uint8) { - readValue(t, v, cfg, func(cfg *Config) { + readValue(t, v, cfg, func(cfg *config.Config) { as.Equal(expected, cfg.Verbose) }) } @@ -444,11 +446,11 @@ func TestVerbosity(t *testing.T) { func TestWalk(t *testing.T) { as := require.New(t) - cfg := &Config{} + cfg := &config.Config{} v, flags := newViper(t) checkValue := func(expected string) { - readValue(t, v, cfg, func(cfg *Config) { + readValue(t, v, cfg, func(cfg *config.Config) { as.Equal(expected, cfg.Walk) }) } @@ -472,11 +474,11 @@ func TestWalk(t *testing.T) { func TestWorkingDirectory(t *testing.T) { as := require.New(t) - cfg := &Config{} + cfg := &config.Config{} v, flags := newViper(t) checkValue := func(expected string) { - readValue(t, v, cfg, func(cfg *Config) { + readValue(t, v, cfg, func(cfg *config.Config) { as.Equal(expected, cfg.WorkingDirectory) }) } @@ -507,11 +509,11 @@ func TestWorkingDirectory(t *testing.T) { func TestStdin(t *testing.T) { as := require.New(t) - cfg := &Config{} + cfg := &config.Config{} v, flags := newViper(t) checkValues := func(stdin bool) { - readValue(t, v, cfg, func(cfg *Config) { + readValue(t, v, cfg, func(cfg *config.Config) { as.Equal(stdin, cfg.Stdin) }) } @@ -540,7 +542,7 @@ func TestSampleConfigFile(t *testing.T) { v.SetConfigFile("../test/examples/treefmt.toml") as.NoError(v.ReadInConfig(), "failed to read config file") - cfg, err := FromViper(v) + cfg, err := config.FromViper(v) as.NoError(err, "failed to unmarshal config from viper") as.NotNil(cfg) diff --git a/format/cache.go b/format/cache.go new file mode 100644 index 00000000..0bd59cf9 --- /dev/null +++ b/format/cache.go @@ -0,0 +1,93 @@ +package format + +import ( + "fmt" + "os" + + "github.com/charmbracelet/log" + 
"github.com/numtide/treefmt/walk/cache" + bolt "go.etcd.io/bbolt" +) + +func CompareFormatters(db *bolt.DB, formatters map[string]*Formatter) error { + return db.Update(func(tx *bolt.Tx) error { + clearPaths := false + + pathsBucket, err := cache.BucketPaths(tx) + if err != nil { + return fmt.Errorf("failed to get paths bucket from cache: %w", err) + } + + formattersBucket, err := cache.BucketFormatters(tx) + if err != nil { + return fmt.Errorf("failed to get formatters bucket from cache: %w", err) + } + + // check for any newly configured or modified formatters + for name, formatter := range formatters { + + stat, err := os.Lstat(formatter.Executable()) + if err != nil { + return fmt.Errorf("failed to stat formatter executable %v: %w", formatter.Executable(), err) + } + + entry, err := formattersBucket.Get(name) + if err != nil { + return fmt.Errorf("failed to retrieve cache entry for formatter %v: %w", name, err) + } + + isNew := entry == nil + hasChanged := entry != nil && !(entry.Size == stat.Size() && entry.Modified == stat.ModTime()) + + if isNew { + log.Debugf("formatter '%s' is new", name) + } else if hasChanged { + log.Debug("formatter '%s' has changed", + name, + "size", stat.Size(), + "modTime", stat.ModTime(), + "cachedSize", entry.Size, + "cachedModTime", entry.Modified, + ) + } + + // update overall flag + clearPaths = clearPaths || isNew || hasChanged + + // record formatters info + entry = &cache.Entry{ + Size: stat.Size(), + Modified: stat.ModTime(), + } + + if err = formattersBucket.Put(name, entry); err != nil { + return fmt.Errorf("failed to write cache entry for formatter %v: %w", name, err) + } + } + + // check for any removed formatters + if err = formattersBucket.ForEach(func(key string, _ *cache.Entry) error { + _, ok := formatters[key] + if !ok { + // remove the formatter entry from the cache + if err = formattersBucket.Delete(key); err != nil { + return fmt.Errorf("failed to remove cache entry for formatter %v: %w", key, err) + } + // 
indicate a clean is required + clearPaths = true + } + return nil + }); err != nil { + return fmt.Errorf("failed to check cache for removed formatters: %w", err) + } + + if clearPaths { + // remove all path entries + if err := pathsBucket.DeleteAll(); err != nil { + return fmt.Errorf("failed to remove all path entries from cache: %w", err) + } + } + + return nil + }) +} diff --git a/format/formatter.go b/format/formatter.go index a77d00d7..d7ebd8ec 100644 --- a/format/formatter.go +++ b/format/formatter.go @@ -8,9 +8,10 @@ import ( "os/exec" "time" - "github.com/numtide/treefmt/config" "github.com/numtide/treefmt/walk" + "github.com/numtide/treefmt/config" + "github.com/charmbracelet/log" "github.com/gobwas/glob" "mvdan.cc/sh/v3/expand" diff --git a/walk/cache/bucket.go b/walk/cache/bucket.go new file mode 100644 index 00000000..824514b4 --- /dev/null +++ b/walk/cache/bucket.go @@ -0,0 +1,88 @@ +package cache + +import ( + "fmt" + + "github.com/vmihailenco/msgpack/v5" + bolt "go.etcd.io/bbolt" +) + +const ( + bucketPaths = "paths" + bucketFormatters = "formatters" +) + +type Bucket[V any] struct { + bucket *bolt.Bucket +} + +func (b *Bucket[V]) Size() int { + return b.bucket.Stats().KeyN +} + +func (b *Bucket[V]) Get(key string) (*V, error) { + bytes := b.bucket.Get([]byte(key)) + if bytes == nil { + return nil, nil + } + var value V + if err := msgpack.Unmarshal(bytes, &value); err != nil { + return nil, fmt.Errorf("failed to unmarshal cache entry for key '%v': %w", key, err) + } + return &value, nil +} + +func (b *Bucket[V]) Put(key string, value *V) error { + if bytes, err := msgpack.Marshal(value); err != nil { + return fmt.Errorf("failed to marshal cache entry for key %v: %w", key, err) + } else if err = b.bucket.Put([]byte(key), bytes); err != nil { + return fmt.Errorf("failed to put cache entry for key %v: %w", key, err) + } + return nil +} + +func (b *Bucket[V]) Delete(key string) error { + return b.bucket.Delete([]byte(key)) +} + +func (b *Bucket[V]) 
DeleteAll() error { + c := b.bucket.Cursor() + for k, v := c.First(); !(k == nil && v == nil); k, v = c.Next() { + if err := c.Delete(); err != nil { + return fmt.Errorf("failed to remove cache entry for key %s: %w", string(k), err) + } + } + return nil +} + +func (b *Bucket[V]) ForEach(f func(string, *V) error) error { + return b.bucket.ForEach(func(key, bytes []byte) error { + var value V + if err := msgpack.Unmarshal(bytes, &value); err != nil { + return fmt.Errorf("failed to unmarshal cache entry for key '%v': %w", key, err) + } + return f(string(key), &value) + }) +} + +func BucketPaths(tx *bolt.Tx) (*Bucket[Entry], error) { + return cacheBucket(bucketPaths, tx) +} + +func BucketFormatters(tx *bolt.Tx) (*Bucket[Entry], error) { + return cacheBucket(bucketFormatters, tx) +} + +func cacheBucket(name string, tx *bolt.Tx) (*Bucket[Entry], error) { + var b *bolt.Bucket + var err error + if tx.Writable() { + b, err = tx.CreateBucketIfNotExists([]byte(name)) + } else { + b = tx.Bucket([]byte(name)) + } + if err != nil { + return nil, fmt.Errorf("failed to get/create bucket %s: %w", bucketPaths, err) + } + return &Bucket[Entry]{b}, nil +} diff --git a/walk/cache/cache.go b/walk/cache/cache.go new file mode 100644 index 00000000..47dca987 --- /dev/null +++ b/walk/cache/cache.go @@ -0,0 +1,65 @@ +package cache + +import ( + "crypto/sha1" + "encoding/hex" + "fmt" + "io/fs" + "time" + + "github.com/adrg/xdg" + bolt "go.etcd.io/bbolt" +) + +type Entry struct { + Size int64 + Modified time.Time +} + +func (e *Entry) HasChanged(info fs.FileInfo) bool { + return !(e.Modified == info.ModTime() && e.Size == info.Size()) +} + +func Open(root string) (*bolt.DB, error) { + var err error + var path string + + // Otherwise, the database will be located in `XDG_CACHE_DIR/treefmt/eval-cache/.db`, where is + // determined by hashing the treeRoot path. + // This associates a given treeRoot with a given instance of the cache. 
// CachedReader reads files from a delegate Reader, attaching a cache Entry to each file on read (if one exists)
// and updating the cache after the file has been released.
type CachedReader struct {
	// db is the bolt database holding the paths bucket that is read from and written to.
	db *bolt.DB
	// log is the logger used for debug output.
	log *log.Logger
	// batchSize is the number of released files accumulated before they are flushed to db in one transaction.
	batchSize int

	// delegate is a Reader instance that performs the actual reading operations for the CachedReader.
	delegate Reader

	// eg runs the background processing loop; Close waits on it.
	eg *errgroup.Group
	// releaseCh contains files which have been released after processing and can be updated in the cache.
	releaseCh chan *File
}
+func (c *CachedReader) process() error { + var batch []*File + + flush := func() error { + // check for an empty batch + if len(batch) == 0 { + return nil + } + + return c.db.Update(func(tx *bolt.Tx) error { + // get the paths bucket + bucket, err := cache.BucketPaths(tx) + if err != nil { + return fmt.Errorf("failed to get bucket: %w", err) + } + + // for each file in the batch, add a new cache entry with update size and mod time. + for _, file := range batch { + entry := &cache.Entry{ + Size: file.Info.Size(), + Modified: file.Info.ModTime(), + } + if err = bucket.Put(file.RelPath, entry); err != nil { + return fmt.Errorf("failed to put entry for path %s: %w", file.RelPath, err) + } + } + return nil + }) + } + + for file := range c.releaseCh { + batch = append(batch, file) + if len(batch) == c.batchSize { + if err := flush(); err != nil { + return err + } + batch = batch[:0] + } + } + + // flush final partial batch + return flush() +} + +func (c *CachedReader) Read(ctx context.Context, files []*File) (n int, err error) { + err = c.db.View(func(tx *bolt.Tx) error { + // get paths bucket + bucket, err := cache.BucketPaths(tx) + if err != nil { + return fmt.Errorf("failed to get bucket: %w", err) + } + + // perform a read on the underlying reader + n, err = c.delegate.Read(ctx, files) + c.log.Debugf("read %d files from delegate", n) + + for i := 0; i < n; i++ { + file := files[i] + + // lookup cache entry and append to the file + file.Cache, err = bucket.Get(file.RelPath) + if err != nil { + return err + } + + // set a release function which inserts this file into the release channel for updating + file.Release = func() { + c.releaseCh <- file + } + } + + if errors.Is(err, io.EOF) { + return err + } else if err != nil { + return fmt.Errorf("failed to read files from delegate: %w", err) + } + + return nil + }) + + return n, err +} + +// Close waits for any processing to complete. 
+func (c *CachedReader) Close() error { + // close the release channel + close(c.releaseCh) + + // wait for any pending releases to be processed + return c.eg.Wait() +} + +// NewCachedReader creates a cache Reader instance, backed by a bolt DB and delegating reads to delegate. +func NewCachedReader(db *bolt.DB, batchSize int, delegate Reader) (*CachedReader, error) { + // force the creation of the necessary buckets if we're dealing with an empty db + if err := cache.EnsureBuckets(db); err != nil { + return nil, fmt.Errorf("failed to create cache buckets: %w", err) + } + + // create an error group for managing the processing loop + eg := &errgroup.Group{} + + r := &CachedReader{ + db: db, + batchSize: batchSize, + delegate: delegate, + log: log.WithPrefix("walk[cache]"), + eg: eg, + releaseCh: make(chan *File, batchSize*runtime.NumCPU()), + } + + // start the processing loop + eg.Go(r.process) + + return r, nil +} diff --git a/walk/cached_test.go b/walk/cached_test.go new file mode 100644 index 00000000..4f00854c --- /dev/null +++ b/walk/cached_test.go @@ -0,0 +1,126 @@ +package walk_test + +import ( + "context" + "errors" + "io" + "os" + "path/filepath" + "testing" + "time" + + "github.com/numtide/treefmt/stats" + "github.com/numtide/treefmt/test" + "github.com/numtide/treefmt/walk" + "github.com/numtide/treefmt/walk/cache" + "github.com/stretchr/testify/require" +) + +func TestCachedReader(t *testing.T) { + as := require.New(t) + + batchSize := 1024 + tempDir := test.TempExamples(t) + + readAll := func(paths []string) (totalCount, newCount, changeCount int, statz stats.Stats) { + statz = stats.New() + + db, err := cache.Open(tempDir) + as.NoError(err) + defer db.Close() + + delegate := walk.NewFilesystemReader(tempDir, paths, &statz, batchSize) + reader, err := walk.NewCachedReader(db, batchSize, delegate) + as.NoError(err) + + for { + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + + files := make([]*walk.File, 8) + n, err := 
reader.Read(ctx, files) + + totalCount += n + + for idx := 0; idx < n; idx++ { + file := files[idx] + + if file.Cache == nil { + newCount++ + } else if file.Cache.HasChanged(file.Info) { + changeCount++ + } + file.Release() + } + + cancel() + + if errors.Is(err, io.EOF) { + break + } + } + + as.NoError(reader.Close()) + + return totalCount, newCount, changeCount, statz + } + + totalCount, newCount, changeCount, _ := readAll([]string{"."}) + as.Equal(32, totalCount) + as.Equal(32, newCount) + as.Equal(0, changeCount) + + // read again, should be no changes + totalCount, newCount, changeCount, _ = readAll([]string{"."}) + as.Equal(32, totalCount) + as.Equal(0, newCount) + as.Equal(0, changeCount) + + // change mod times on some files and try again + // we subtract a second to account for the 1 second granularity of modtime according to POSIX + modTime := time.Now().Add(-1 * time.Second) + + as.NoError(os.Chtimes(filepath.Join(tempDir, "treefmt.toml"), time.Now(), modTime)) + as.NoError(os.Chtimes(filepath.Join(tempDir, "shell/foo.sh"), time.Now(), modTime)) + as.NoError(os.Chtimes(filepath.Join(tempDir, "haskell/Nested/Foo.hs"), time.Now(), modTime)) + + totalCount, newCount, changeCount, _ = readAll([]string{"."}) + as.Equal(32, totalCount) + as.Equal(0, newCount) + as.Equal(3, changeCount) + + // create some files and try again + _, err := os.Create(filepath.Join(tempDir, "new.txt")) + as.NoError(err) + + _, err = os.Create(filepath.Join(tempDir, "fizz.go")) + as.NoError(err) + + totalCount, newCount, changeCount, _ = readAll([]string{"."}) + as.Equal(34, totalCount) + as.Equal(2, newCount) + as.Equal(0, changeCount) + + // modify some files + f, err := os.OpenFile(filepath.Join(tempDir, "new.txt"), os.O_WRONLY, 0o644) + as.NoError(err) + _, err = f.Write([]byte("foo")) + as.NoError(err) + as.NoError(f.Close()) + + f, err = os.OpenFile(filepath.Join(tempDir, "fizz.go"), os.O_WRONLY, 0o644) + as.NoError(err) + _, err = f.Write([]byte("bla")) + as.NoError(err) + 
// FilesystemReader traverses and reads files from a specified root directory and its subdirectories.
type FilesystemReader struct {
	// log is the logger used for debug output while files are queued.
	log *log.Logger
	// root is the directory against which paths are resolved and relative file paths are computed.
	root string
	// paths are the locations to walk, relative to root.
	paths []string
	// batchSize is the batch size supplied at construction; the constructor uses it to size filesCh.
	batchSize int

	// eg runs the background processing loop; Close waits on it.
	eg *errgroup.Group

	// stats receives a Traversed increment for every file handed out by Read.
	stats *stats.Stats
	// filesCh carries files from the processing loop to Read and is closed when the loop returns.
	filesCh chan *File
}
+func (f *FilesystemReader) process() error { + // ensure filesCh is closed on return + defer func() { + close(f.filesCh) + }() + + walkFn := func(path string, info fs.FileInfo, err error) error { + // return errors immediately + if err != nil { + return err } // ignore directories and symlinks @@ -33,20 +46,38 @@ func (f filesystemWalker) Walk(_ context.Context, fn WalkFunc) error { return nil } - relPath, err := f.relPath(path) + // determine a path relative to the root + relPath, err := filepath.Rel(f.root, path) if err != nil { return fmt.Errorf("failed to determine a relative path for %s: %w", path, err) } + // create a new file and pass to the files channel file := File{ Path: path, RelPath: relPath, Info: info, } - return fn(&file, err) + + f.filesCh <- &file + + f.log.Debugf("file queued %s", file.RelPath) + + return nil } - for path := range f.pathsCh { + // walk each path specified + for idx := range f.paths { + // f.paths are relative to the root, so we create a fully qualified version + // we also clean the path up in case there are any ../../ components etc. + path := filepath.Clean(filepath.Join(f.root, f.paths[idx])) + + // ensure the path is within the root + if !strings.HasPrefix(path, f.root) { + return fmt.Errorf("path '%s' is outside of the root '%s'", path, f.root) + } + + // walk the path if err := filepath.Walk(path, walkFn); err != nil { return err } @@ -55,10 +86,75 @@ func (f filesystemWalker) Walk(_ context.Context, fn WalkFunc) error { return nil } -func NewFilesystem(root string, paths chan string) (Walker, error) { - return filesystemWalker{ - root: root, - pathsCh: paths, - relPathOffset: len(root) + 1, - }, nil +// Read populates the provided files array with as many files are available until the provided context is cancelled. +// You must ensure to pass a context with a timeout otherwise this will block until files is full. 
+func (f *FilesystemReader) Read(ctx context.Context, files []*File) (n int, err error) { + idx := 0 + +LOOP: + // fill the files array up to it's length + for idx < len(files) { + select { + + // exit early if the context was cancelled + case <-ctx.Done(): + return idx, ctx.Err() + + // read the next file from the files channel + case file, ok := <-f.filesCh: + if !ok { + // channel was closed + err = io.EOF + break LOOP + } + + // set the next file entry + files[idx] = file + idx++ + + // record that we traversed a file + f.stats.Add(stats.Traversed, 1) + } + } + + return idx, err +} + +// Close waits for all filesystem processing to complete. +func (f *FilesystemReader) Close() error { + return f.eg.Wait() +} + +// NewFilesystemReader creates a new instance of FilesystemReader to traverse and read files from the specified paths +// and root. +func NewFilesystemReader( + root string, + paths []string, + statz *stats.Stats, + batchSize int, +) *FilesystemReader { + // if no paths are specified, we default to the root path + if len(paths) == 0 { + paths = []string{"."} + } + + // create an error group for managing the processing loop + eg := errgroup.Group{} + + r := FilesystemReader{ + log: log.WithPrefix("walk[filesystem]"), + root: root, + paths: paths, + batchSize: batchSize, + + eg: &eg, + + stats: statz, + filesCh: make(chan *File, batchSize*runtime.NumCPU()), + } + + // start processing loop + eg.Go(r.process) + + return &r } diff --git a/walk/filesystem_test.go b/walk/filesystem_test.go index b33404b8..02775a4d 100644 --- a/walk/filesystem_test.go +++ b/walk/filesystem_test.go @@ -1,9 +1,15 @@ -package walk +package walk_test import ( "context" - "os" + "errors" + "io" "testing" + "time" + + "github.com/numtide/treefmt/stats" + + "github.com/numtide/treefmt/walk" "github.com/numtide/treefmt/test" "github.com/stretchr/testify/require" @@ -44,33 +50,38 @@ var examplesPaths = []string{ "yaml/test.yaml", } -func TestFilesystemWalker_Walk(t *testing.T) { +func 
TestFilesystemReader(t *testing.T) { + as := require.New(t) + tempDir := test.TempExamples(t) + statz := stats.New() - paths := make(chan string, 1) - go func() { - paths <- tempDir - close(paths) - }() + r := walk.NewFilesystemReader(tempDir, nil, &statz, 1024) - as := require.New(t) + count := 0 + + for { + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + + files := make([]*walk.File, 8) + n, err := r.Read(ctx, files) + + for i := count; i < count+n; i++ { + as.Equal(examplesPaths[i], files[i-count].RelPath) + } + + count += n + + cancel() + + if errors.Is(err, io.EOF) { + break + } + } - walker, err := NewFilesystem(tempDir, paths) - as.NoError(err) - - idx := 0 - err = walker.Walk(context.Background(), func(file *File, err error) error { - as.Equal(examplesPaths[idx], file.RelPath) - idx += 1 - return nil - }) - as.NoError(err) - - // capture current cwd, so we can replace it after the test is finished - cwd, err := os.Getwd() - as.NoError(err) - t.Cleanup(func() { - // return to the previous working directory - as.NoError(os.Chdir(cwd)) - }) + as.Equal(32, count) + as.Equal(int32(32), statz.Value(stats.Traversed)) + as.Equal(int32(0), statz.Value(stats.Emitted)) + as.Equal(int32(0), statz.Value(stats.Matched)) + as.Equal(int32(0), statz.Value(stats.Formatted)) } diff --git a/walk/filetree.go b/walk/filetree.go new file mode 100644 index 00000000..02139db4 --- /dev/null +++ b/walk/filetree.go @@ -0,0 +1,62 @@ +package walk + +import ( + "path/filepath" + "strings" + + "github.com/go-git/go-git/v5/plumbing/format/index" +) + +// filetree represents a hierarchical file structure with directories and files. +type filetree struct { + name string + entries map[string]*filetree +} + +// add inserts a file path into the filetree structure, creating necessary parent directories if they do not exist. 
+func (n *filetree) add(path []string) { + if len(path) == 0 { + return + } else if n.entries == nil { + n.entries = make(map[string]*filetree) + } + + name := path[0] + child, ok := n.entries[name] + if !ok { + child = &filetree{name: name} + n.entries[name] = child + } + child.add(path[1:]) +} + +// addPath splits the given path by the filepath separator and inserts it into the filetree structure. +func (n *filetree) addPath(path string) { + n.add(strings.Split(path, string(filepath.Separator))) +} + +// has returns true if the specified path exists in the filetree, false otherwise. +func (n *filetree) has(path []string) bool { + if len(path) == 0 { + return true + } else if len(n.entries) == 0 { + return false + } + child, ok := n.entries[path[0]] + if !ok { + return false + } + return child.has(path[1:]) +} + +// hasPath splits the given path by the filepath separator and checks if it exists in the filetree. +func (n *filetree) hasPath(path string) bool { + return n.has(strings.Split(path, string(filepath.Separator))) +} + +// readIndex traverses the index entries and adds each file path to the filetree structure. 
+func (n *filetree) readIndex(idx *index.Index) { + for _, entry := range idx.Entries { + n.addPath(entry.Name) + } +} diff --git a/walk/filetree_test.go b/walk/filetree_test.go new file mode 100644 index 00000000..6ef68d34 --- /dev/null +++ b/walk/filetree_test.go @@ -0,0 +1,31 @@ +package walk + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestFileTree(t *testing.T) { + as := require.New(t) + + node := &filetree{name: ""} + node.addPath("foo/bar/baz") + node.addPath("fizz/buzz") + node.addPath("hello/world") + node.addPath("foo/bar/fizz") + node.addPath("foo/fizz/baz") + + as.True(node.hasPath("foo")) + as.True(node.hasPath("foo/bar")) + as.True(node.hasPath("foo/bar/baz")) + as.True(node.hasPath("fizz")) + as.True(node.hasPath("fizz/buzz")) + as.True(node.hasPath("hello")) + as.True(node.hasPath("hello/world")) + as.True(node.hasPath("foo/bar/fizz")) + as.True(node.hasPath("foo/fizz/baz")) + + as.False(node.hasPath("fo")) + as.False(node.hasPath("world")) +} diff --git a/walk/git.go b/walk/git.go index 4e3c6505..3f29ca46 100644 --- a/walk/git.go +++ b/walk/git.go @@ -3,160 +3,111 @@ package walk import ( "context" "fmt" + "io" "io/fs" "os" "path/filepath" + "runtime" "strings" "github.com/charmbracelet/log" - "github.com/go-git/go-git/v5/plumbing/filemode" - "github.com/go-git/go-git/v5/plumbing/format/index" - "github.com/go-git/go-git/v5" + "github.com/go-git/go-git/v5/plumbing/filemode" + "github.com/numtide/treefmt/stats" + "golang.org/x/sync/errgroup" ) -// fileTree represents a hierarchical file structure with directories and files. -type fileTree struct { - name string - entries map[string]*fileTree -} +type GitReader struct { + root string + paths []string + stats *stats.Stats + batchSize int -// add inserts a file path into the fileTree structure, creating necessary parent directories if they do not exist. 
-func (n *fileTree) add(path []string) { - if len(path) == 0 { - return - } else if n.entries == nil { - n.entries = make(map[string]*fileTree) - } - - name := path[0] - child, ok := n.entries[name] - if !ok { - child = &fileTree{name: name} - n.entries[name] = child - } - child.add(path[1:]) -} + log *log.Logger + repo *git.Repository -// addPath splits the given path by the filepath separator and inserts it into the fileTree structure. -func (n *fileTree) addPath(path string) { - n.add(strings.Split(path, string(filepath.Separator))) -} + filesCh chan *File -// has returns true if the specified path exists in the fileTree, false otherwise. -func (n *fileTree) has(path []string) bool { - if len(path) == 0 { - return true - } else if len(n.entries) == 0 { - return false - } - child, ok := n.entries[path[0]] - if !ok { - return false - } - return child.has(path[1:]) + eg *errgroup.Group } -// hasPath splits the given path by the filepath separator and checks if it exists in the fileTree. -func (n *fileTree) hasPath(path string) bool { - return n.has(strings.Split(path, string(filepath.Separator))) -} +func (g *GitReader) process() error { + defer func() { + close(g.filesCh) + }() -// readIndex traverses the index entries and adds each file path to the fileTree structure. 
-func (n *fileTree) readIndex(idx *index.Index) { - for _, entry := range idx.Entries { - n.addPath(entry.Name) - } -} - -type gitWalker struct { - log *log.Logger - root string - paths chan string - repo *git.Repository - relPathOffset int -} - -func (g gitWalker) Root() string { - return g.root -} - -func (g gitWalker) relPath(path string) (string, error) { // - return filepath.Rel(g.root, path) -} - -func (g gitWalker) Walk(ctx context.Context, fn WalkFunc) error { - idx, err := g.repo.Storer.Index() + gitIndex, err := g.repo.Storer.Index() if err != nil { return fmt.Errorf("failed to open git index: %w", err) } // if we need to walk a path that is not the root of the repository, we will read the directory structure of the // git index into memory for faster lookups - var cache *fileTree + var idxCache *filetree + + for pathIdx := range g.paths { + + path := filepath.Clean(filepath.Join(g.root, g.paths[pathIdx])) + if !strings.HasPrefix(path, g.root) { + return fmt.Errorf("path '%s' is outside of the root '%s'", path, g.root) + } - for path := range g.paths { switch path { case g.root: // we can just iterate the index entries - for _, entry := range idx.Entries { - select { - case <-ctx.Done(): - return ctx.Err() - default: - // we only want regular files, not directories or symlinks - if entry.Mode == filemode.Dir || entry.Mode == filemode.Symlink { - continue - } - - // stat the file - path := filepath.Join(g.root, entry.Name) - - info, err := os.Lstat(path) - if os.IsNotExist(err) { - // the underlying file might have been removed without the change being staged yet - g.log.Warnf("Path %s is in the index but appears to have been removed from the filesystem", path) - continue - } else if err != nil { - return fmt.Errorf("failed to stat %s: %w", path, err) - } - - // determine a relative path - relPath, err := g.relPath(path) - if err != nil { - return fmt.Errorf("failed to determine a relative path for %s: %w", path, err) - } - - file := File{ - Path: path, - 
RelPath: relPath, - Info: info, - } - - if err = fn(&file, err); err != nil { - return err - } + for _, entry := range gitIndex.Entries { + + // we only want regular files, not directories or symlinks + if entry.Mode == filemode.Dir || entry.Mode == filemode.Symlink { + continue + } + + // stat the file + path := filepath.Join(g.root, entry.Name) + + info, err := os.Lstat(path) + if os.IsNotExist(err) { + // the underlying file might have been removed without the change being staged yet + g.log.Warnf("Path %s is in the index but appears to have been removed from the filesystem", path) + continue + } else if err != nil { + return fmt.Errorf("failed to stat %s: %w", path, err) + } + + // determine a relative path + relPath, err := filepath.Rel(g.root, path) + if err != nil { + return fmt.Errorf("failed to determine a relative path for %s: %w", path, err) + } + + file := File{ + Path: path, + RelPath: relPath, + Info: info, } + + g.stats.Add(stats.Traversed, 1) + g.filesCh <- &file } default: // read the git index into memory if it hasn't already - if cache == nil { - cache = &fileTree{name: ""} - cache.readIndex(idx) + if idxCache == nil { + idxCache = &filetree{name: ""} + idxCache.readIndex(gitIndex) } // git index entries are relative to the repository root, so we need to determine a relative path for the // one we are currently processing before checking if it exists within the git index - relPath, err := g.relPath(path) + relPath, err := filepath.Rel(g.root, path) if err != nil { return fmt.Errorf("failed to find root relative path for %v: %w", path, err) } - if !cache.hasPath(relPath) { + if !idxCache.hasPath(relPath) { log.Debugf("path %s not found in git index, skipping", relPath) continue } @@ -168,12 +119,12 @@ func (g gitWalker) Walk(ctx context.Context, fn WalkFunc) error { } // determine a path relative to g.root before checking presence in the git index - relPath, err := g.relPath(path) + relPath, err := filepath.Rel(g.root, path) if err != nil { return 
fmt.Errorf("failed to determine a relative path for %s: %w", path, err) } - if !cache.hasPath(relPath) { + if !idxCache.hasPath(relPath) { log.Debugf("path %v not found in git index, skipping", relPath) return nil } @@ -184,7 +135,9 @@ func (g gitWalker) Walk(ctx context.Context, fn WalkFunc) error { Info: info, } - return fn(&file, err) + g.stats.Add(stats.Traversed, 1) + g.filesCh <- &file + return nil }) if err != nil { return fmt.Errorf("failed to walk %s: %w", path, err) @@ -195,16 +148,60 @@ func (g gitWalker) Walk(ctx context.Context, fn WalkFunc) error { return nil } -func NewGit(root string, paths chan string) (Walker, error) { +func (g *GitReader) Read(ctx context.Context, files []*File) (n int, err error) { + idx := 0 + +LOOP: + for idx < len(files) { + select { + case <-ctx.Done(): + return 0, ctx.Err() + case file, ok := <-g.filesCh: + if !ok { + err = io.EOF + break LOOP + } + files[idx] = file + idx++ + } + } + + return idx, err +} + +func (g *GitReader) Close() error { + return g.eg.Wait() +} + +func NewGitReader( + root string, + paths []string, + statz *stats.Stats, + batchSize int, +) (*GitReader, error) { repo, err := git.PlainOpen(root) if err != nil { - return nil, fmt.Errorf("failed to open git repo: %w", err) + return nil, fmt.Errorf("failed to open git repository: %w", err) } - return &gitWalker{ - log: log.WithPrefix("walker[git]"), - root: root, - paths: paths, - repo: repo, - relPathOffset: len(root) + 1, - }, nil + + eg := &errgroup.Group{} + + if len(paths) == 0 { + paths = []string{"."} + } + + r := &GitReader{ + root: root, + paths: paths, + stats: statz, + batchSize: batchSize, + log: log.WithPrefix("walk[git]"), + repo: repo, + filesCh: make(chan *File, batchSize*runtime.NumCPU()), + eg: eg, + } + + eg.Go(r.process) + + return r, nil } diff --git a/walk/git_test.go b/walk/git_test.go index e557eae3..66dc645b 100644 --- a/walk/git_test.go +++ b/walk/git_test.go @@ -1,31 +1,68 @@ -package walk +package walk_test import ( + "context" 
+ "errors" + "io" + "path" "testing" + "time" + "github.com/go-git/go-billy/v5/osfs" + "github.com/go-git/go-git/v5" + "github.com/go-git/go-git/v5/plumbing/cache" + "github.com/go-git/go-git/v5/storage/filesystem" + "github.com/numtide/treefmt/stats" + "github.com/numtide/treefmt/test" + "github.com/numtide/treefmt/walk" "github.com/stretchr/testify/require" ) -func TestFileTree(t *testing.T) { +func TestGitReader(t *testing.T) { as := require.New(t) - node := &fileTree{name: ""} - node.addPath("foo/bar/baz") - node.addPath("fizz/buzz") - node.addPath("hello/world") - node.addPath("foo/bar/fizz") - node.addPath("foo/fizz/baz") - - as.True(node.hasPath("foo")) - as.True(node.hasPath("foo/bar")) - as.True(node.hasPath("foo/bar/baz")) - as.True(node.hasPath("fizz")) - as.True(node.hasPath("fizz/buzz")) - as.True(node.hasPath("hello")) - as.True(node.hasPath("hello/world")) - as.True(node.hasPath("foo/bar/fizz")) - as.True(node.hasPath("foo/fizz/baz")) - - as.False(node.hasPath("fo")) - as.False(node.hasPath("world")) + tempDir := test.TempExamples(t) + + // init a git repo + repo, err := git.Init( + filesystem.NewStorage( + osfs.New(path.Join(tempDir, ".git")), + cache.NewObjectLRUDefault(), + ), + osfs.New(tempDir), + ) + as.NoError(err, "failed to init git repository") + + // get worktree and add everything to it + wt, err := repo.Worktree() + as.NoError(err, "failed to get git worktree") + as.NoError(wt.AddGlob(".")) + + statz := stats.New() + + reader, err := walk.NewGitReader(tempDir, nil, &statz, 1024) + as.NoError(err) + + count := 0 + + for { + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + + files := make([]*walk.File, 8) + n, err := reader.Read(ctx, files) + + count += n + + cancel() + + if errors.Is(err, io.EOF) { + break + } + } + + as.Equal(32, count) + as.Equal(int32(32), statz.Value(stats.Traversed)) + as.Equal(int32(0), statz.Value(stats.Emitted)) + as.Equal(int32(0), statz.Value(stats.Matched)) + as.Equal(int32(0), 
statz.Value(stats.Formatted)) } diff --git a/walk/walk.go b/walk/walk.go new file mode 100644 index 00000000..86c3c8ee --- /dev/null +++ b/walk/walk.go @@ -0,0 +1,117 @@ +package walk + +import ( + "context" + "fmt" + "io/fs" + "os" + "time" + + "github.com/numtide/treefmt/stats" + "github.com/numtide/treefmt/walk/cache" + bolt "go.etcd.io/bbolt" +) + +//go:generate enumer -type=Type -text -transform=snake -output=./type_enum.go +type Type int + +const ( + Auto Type = iota + Git + Filesystem + + BatchSize = 1024 +) + +// File represents a file object with its path, relative path, file info, and potential cached entry. +// It provides an optional release function to trigger a cache update after processing. +type File struct { + Path string + RelPath string + Info fs.FileInfo + + // Cache is the latest entry found for this file, if one exists. + Cache *cache.Entry + + // An optional function to be invoked when this File has finished processing. + // Typically used to trigger a cache update. + Release func() +} + +// Stat checks if the file has changed by comparing its current state (size, mod time) to when it was first read. +// It returns a boolean indicating if the file has changed, the current file info, and an error if any. +func (f File) Stat() (bool, fs.FileInfo, error) { + // Get the file's current state + current, err := os.Stat(f.Path) + if err != nil { + return false, nil, fmt.Errorf("failed to stat %s: %w", f.Path, err) + } + + // Check the size first + if f.Info.Size() != current.Size() { + return true, current, nil + } + + // POSIX specifies EPOCH time for Mod time, but some filesystems give more precision. + // Some formatters mess with the mod time (e.g. dos2unix) but not to the same precision, + // triggering false positives. + // We truncate everything below a second. 
+ if f.Info.ModTime().Truncate(time.Second) != current.ModTime().Truncate(time.Second) { + return true, current, nil + } + + return false, nil, nil +} + +// String returns the file's path as a string. +func (f File) String() string { + return f.Path +} + +// Reader is an interface for reading files. +type Reader interface { + Read(ctx context.Context, files []*File) (n int, err error) + Close() error +} + +// NewReader creates a new instance of Reader based on the given walkType (Auto, Git, Filesystem). +func NewReader( + walkType Type, + root string, + paths []string, + db *bolt.DB, + statz *stats.Stats, +) (Reader, error) { + var ( + err error + reader Reader + ) + + switch walkType { + case Auto: + // for now, we keep it simple and try git first, filesystem second + reader, err = NewReader(Git, root, paths, db, statz) + if err != nil { + reader, err = NewReader(Filesystem, root, paths, db, statz) + } + return reader, err + case Git: + reader, err = NewGitReader(root, paths, statz, BatchSize) + case Filesystem: + reader = NewFilesystemReader(root, paths, statz, BatchSize) + default: + return nil, fmt.Errorf("unknown walk type: %v", walkType) + } + + if err != nil { + return nil, err + } + + if db != nil { + // wrap with cached reader + // db will be null if --no-cache is enabled + reader, err = NewCachedReader(db, BatchSize, reader) + } + + return reader, err +} diff --git a/walk/walker.go b/walk/walker.go deleted file mode 100644 index ebf3e670..00000000 --- a/walk/walker.go +++ /dev/null @@ -1,80 +0,0 @@ -package walk - -import ( - "context" - "fmt" - "io/fs" - "os" - "time" -) - -//go:generate enumer -type=Type -text -transform=snake -output=./type_enum.go -type Type int - -const ( - Auto Type = iota - Git - Filesystem -) - -type File struct { - Path string - RelPath string - Info fs.FileInfo -} - -func (f File) HasChanged() (bool, fs.FileInfo, error) { - // get the file's current state - current, err := os.Stat(f.Path) - if err != nil { - return false, nil, 
fmt.Errorf("failed to stat %s: %w", f.Path, err) - } - - // check the size first - if f.Info.Size() != current.Size() { - return true, current, nil - } - - // POSIX specifies EPOCH time for Mod time, but some filesystems give more precision. - // Some formatters mess with the mod time (e.g. dos2unix) but not to the same precision, - // triggering false positives. - // We truncate everything below a second. - if f.Info.ModTime().Truncate(time.Second) != current.ModTime().Truncate(time.Second) { - return true, current, nil - } - - return false, nil, nil -} - -func (f File) String() string { - return f.Path -} - -type WalkFunc func(file *File, err error) error - -type Walker interface { - Root() string - Walk(ctx context.Context, fn WalkFunc) error -} - -func New(walkerType Type, root string, pathsCh chan string) (Walker, error) { - switch walkerType { - case Git: - return NewGit(root, pathsCh) - case Auto: - return Detect(root, pathsCh) - case Filesystem: - return NewFilesystem(root, pathsCh) - default: - return nil, fmt.Errorf("unknown walker type: %v", walkerType) - } -} - -func Detect(root string, pathsCh chan string) (Walker, error) { - // for now, we keep it simple and try git first, filesystem second - w, err := NewGit(root, pathsCh) - if err == nil { - return w, err - } - return NewFilesystem(root, pathsCh) -}