diff --git a/cmd/format/format.go b/cmd/format/format.go index f3cd281b..40dcc094 100644 --- a/cmd/format/format.go +++ b/cmd/format/format.go @@ -8,14 +8,12 @@ import ( "os" "os/signal" "path/filepath" - "runtime" "runtime/pprof" "strings" "syscall" "time" "github.com/charmbracelet/log" - "github.com/gobwas/glob" "github.com/numtide/treefmt/config" "github.com/numtide/treefmt/format" "github.com/numtide/treefmt/stats" @@ -24,18 +22,13 @@ import ( "github.com/spf13/cobra" "github.com/spf13/viper" bolt "go.etcd.io/bbolt" - "golang.org/x/sync/errgroup" - "mvdan.cc/sh/v3/expand" ) const ( BatchSize = 1024 ) -var ( - ErrFailOnChange = errors.New("unexpected changes detected, --fail-on-change is enabled") - ErrFormattingFailures = errors.New("formatting failures detected") -) +var ErrFailOnChange = errors.New("unexpected changes detected, --fail-on-change is enabled") func Run(v *viper.Viper, statz *stats.Stats, cmd *cobra.Command, paths []string) error { cmd.SilenceUsage = true @@ -63,7 +56,7 @@ func Run(v *viper.Viper, statz *stats.Stats, cmd *cobra.Command, paths []string) // This can fail in CI between checkout and running treefmt if everything happens too quickly. // For humans, the second level precision should not be a problem as they are unlikely to run treefmt in // sub-second succession. - <-time.After(time.Until(startAfter)) + time.Sleep(time.Until(startAfter)) } // cpu profiling @@ -84,35 +77,9 @@ func Run(v *viper.Viper, statz *stats.Stats, cmd *cobra.Command, paths []string) }() } - // create a prefixed logger + // set a prefix on the default logger log.SetPrefix("format") - // compile global exclude globs - globalExcludes, err := format.CompileGlobs(cfg.Excludes) - if err != nil { - return fmt.Errorf("failed to compile global excludes: %w", err) - } - - // initialise formatters - formatters := make(map[string]*format.Formatter) - - env := expand.ListEnviron(os.Environ()...) - - for name, formatterCfg := range cfg.FormatterConfigs { - formatter, err := format.NewFormatter(name, cfg.TreeRoot, env, formatterCfg) - - if errors.Is(err, format.ErrCommandNotFound) && cfg.AllowMissingFormatter { - log.Debugf("formatter command not found: %v", name) - - continue - } else if err != nil { - return fmt.Errorf("%w: failed to initialise formatter: %v", err, name) - } - - // store formatter by name - formatters[name] = formatter - } - var db *bolt.DB // open the db unless --no-cache was specified @@ -137,17 +104,13 @@ func Run(v *viper.Viper, statz *stats.Stats, cmd *cobra.Command, paths []string) return fmt.Errorf("failed to clear cache: %w", err) } } - - // Compare formatters, clearing paths if they have changed, and recording their latest info in the db - if err = format.CompareFormatters(db, formatters); err != nil { - return fmt.Errorf("failed to compare formatters: %w", err) - } } - // create an app context and listen for shutdown + // create an overall app context ctx, cancel := context.WithCancel(context.Background()) defer cancel() + // listen for shutdown signal and cancel the context go func() { exit := make(chan os.Signal, 1) signal.Notify(exit, os.Interrupt, syscall.SIGTERM) @@ -155,20 +118,6 @@ func Run(v *viper.Viper, statz *stats.Stats, cmd *cobra.Command, paths []string) cancel() }() - // create an overall error group for executing high level tasks concurrently - eg, ctx := errgroup.WithContext(ctx) - - // create a channel for files needing to be processed - // we use a multiple of batch size here as a rudimentary concurrency optimization based on the host machine - filesCh := make(chan *walk.File, BatchSize*runtime.NumCPU()) - - // create a channel for files that have been formatted - formattedCh := make(chan *format.Task, cap(filesCh)) - - // start concurrent processing tasks in reverse order - eg.Go(postProcessing(ctx, cfg, statz, formattedCh)) - eg.Go(applyFormatters(ctx, cfg, statz, globalExcludes, formatters, filesCh, formattedCh)) - // parse the walk type walkType, err := walk.TypeString(cfg.Walk) if err != nil { @@ -206,8 +155,21 @@ func Run(v *viper.Viper, statz *stats.Stats, cmd *cobra.Command, paths []string) } } - // create a new reader for traversing the paths - reader, err := walk.NewCompositeReader(walkType, cfg.TreeRoot, paths, db, statz) + // create a composite formatter which will handle applying the correct formatters to each file we traverse + formatter, err := format.NewCompositeFormatter(cfg, statz, BatchSize) + if err != nil { + return fmt.Errorf("failed to create composite formatter: %w", err) + } + + if db != nil { + // compare formatters with the db, busting the cache if the formatters have changed + if err := formatter.BustCache(db); err != nil { + return fmt.Errorf("failed to compare formatters: %w", err) + } + } + + // create a new walker for traversing the paths + walker, err := walk.NewCompositeReader(walkType, cfg.TreeRoot, paths, db, statz) if err != nil { return fmt.Errorf("failed to create walker: %w", err) } @@ -217,15 +179,15 @@ func Run(v *viper.Viper, statz *stats.Stats, cmd *cobra.Command, paths []string) for { // read the next batch - ctx, cancel := context.WithTimeout(ctx, 1*time.Second) - n, err := reader.Read(ctx, files) + readCtx, cancel := context.WithTimeout(ctx, 1*time.Second) + n, err := walker.Read(readCtx, files) // ensure context is cancelled to release resources cancel() - // pass each file into the file channel for processing - for idx := 0; idx < n; idx++ { - filesCh <- files[idx] + // format + if err := formatter.Apply(ctx, files[:n]); err != nil { + return fmt.Errorf("formatting failure: %w", err) } if errors.Is(err, io.EOF) { @@ -237,261 +199,26 @@ func Run(v *viper.Viper, statz *stats.Stats, cmd *cobra.Command, paths []string) } } - // indicate no further files for processing - close(filesCh) + // finalize formatting + formatErr := formatter.Close(ctx) - // wait for everything to complete - if err = eg.Wait(); err != nil { - return err + // close the walker, ensuring any pending file release hooks finish + if err = walker.Close(); err != nil { + return fmt.Errorf("failed to close walker: %w", err) } - return reader.Close() -} - -func applyFormatters( - ctx context.Context, - cfg *config.Config, - statz *stats.Stats, - globalExcludes []glob.Glob, - formatters map[string]*format.Formatter, - filesCh chan *walk.File, - formattedCh chan *format.Task, -) func() error { - // create our own errgroup for concurrent formatting tasks. - // we don't want a cancel clause, in order to let formatters run up to the end. - fg := errgroup.Group{} - // simple optimization to avoid too many concurrent formatting tasks - // we can queue them up faster than the formatters can process them, this paces things a bit - fg.SetLimit(runtime.NumCPU()) - - // track batches of formatting task based on their batch keys, which are determined by the unique sequence of - // formatters which should be applied to their respective files - batches := make(map[string][]*format.Task) - - // apply check if the given batch key has enough tasks to trigger processing - // flush is used to force processing regardless of the number of tasks - apply := func(key string, flush bool) { - // lookup the batch and exit early if it's empty - batch := batches[key] - if len(batch) == 0 { - return - } - - // process the batch if it's full, or we've been asked to flush partial batches - if flush || len(batch) == BatchSize { - // copy the batch as we re-use it for the next batch - tasks := make([]*format.Task, len(batch)) - copy(tasks, batch) - - // asynchronously apply the sequence formatters to the batch - fg.Go(func() error { - // Iterate the formatters, applying them in sequence to the batch of tasks. - // We get the formatter list from the first task since they have all the same formatters list. - formatters := tasks[0].Formatters - - var formatErrors []error - - for idx := range formatters { - if err := formatters[idx].Apply(ctx, tasks); err != nil { - formatErrors = append(formatErrors, err) - } - } - - // pass each file to the formatted channel - for _, task := range tasks { - task.Errors = formatErrors - formattedCh <- task - } - - return nil - }) - - // reset the batch - batches[key] = batch[:0] - } + // print stats to stdout, unless we are processing from stdin and therefore outputting the results to stdout + if !cfg.Stdin { + statz.Print() } - // tryApply batches tasks by their batch key and processes the batch if there is enough ready - tryApply := func(task *format.Task) { - // append to batch - key := task.BatchKey - batches[key] = append(batches[key], task) - // try to apply - apply(key, false) - } - - return func() error { - defer func() { - // indicate processing has finished - close(formattedCh) - }() - - // parse unmatched log level - unmatchedLevel, err := log.ParseLevel(cfg.OnUnmatched) - if err != nil { - return fmt.Errorf("invalid on-unmatched value: %w", err) - } - - // iterate the file channel - for file := range filesCh { - // a list of formatters that match this file - var matches []*format.Formatter - - // first check if this file has been globally excluded - if format.PathMatches(file.RelPath, globalExcludes) { - log.Debugf("path matched global excludes: %s", file.RelPath) - } else { - // otherwise, check if any formatters are interested in it - for _, formatter := range formatters { - if formatter.Wants(file) { - matches = append(matches, formatter) - } - } - } - - // indicates no further processing - var release bool - - // check if there were no matches - if len(matches) == 0 { - // log that there was no match, exiting with an error if the unmatched level was set to fatal - if unmatchedLevel == log.FatalLevel { - return fmt.Errorf("no formatter for path: %s", file.RelPath) - } - - log.Logf(unmatchedLevel, "no formatter for path: %s", file.RelPath) - - // no further processing - release = true - } else { - // record there was a match - statz.Add(stats.Matched, 1) - - // check if the file is new or has changed when compared to the cache entry - if file.Cache == nil || file.Cache.HasChanged(file.Info) { - // if so, generate a format task, add it to the relevant batch (by batch key) and try to process - task := format.NewTask(file, matches) - tryApply(&task) - } else { - // indicate no further processing - release = true - } - } - - if release { - // release the file as there is no more processing to be done on it - if err := file.Release(nil); err != nil { - return fmt.Errorf("failed to release file: %w", err) - } - } - } - - // flush any partial batches which remain - for key := range batches { - apply(key, true) - } - - // wait for all outstanding formatting tasks to complete - if err := fg.Wait(); err != nil { - return fmt.Errorf("formatting failure: %w", err) - } - - return nil - } -} - -func postProcessing( - ctx context.Context, - cfg *config.Config, - statz *stats.Stats, - formattedCh chan *format.Task, -) func() error { - return func() error { - var formattingFailures bool // track if there were any formatting failures - - LOOP: - for { - select { - // detect ctx cancellation - case <-ctx.Done(): - return ctx.Err() - - // take the next task that has been processed - case task, ok := <-formattedCh: - if !ok { - break LOOP - } - - // grab the underlying file reference - file := task.File - - if len(task.Errors) > 0 { - formattingFailures = true - - // release the file, passing the first task error - // note: task errors are related to the batch in which a task was applied - // this does not necessarily indicate this file had a problem being formatted, but this approach - // serves our purpose for now of indicating some sort of error condition to the release hooks - if err := file.Release(task.Errors[0]); err != nil { - return fmt.Errorf("failed to release file: %w", err) - } - - // continue processing next task - continue - } - - // check if the file has changed - changed, newInfo, err := file.Stat() - if err != nil { - return err - } - - statz.Add(stats.Formatted, 1) - - if changed { - // record that a change in the underlying file occurred - statz.Add(stats.Changed, 1) - - logMethod := log.Debug - if cfg.FailOnChange { - // surface the changed file more obviously - logMethod = log.Error - } - - // log the change - logMethod( - "file has changed", - "path", file.RelPath, - "prev_size", file.Info.Size(), - "prev_mod_time", file.Info.ModTime().Truncate(time.Second), - "current_size", newInfo.Size(), - "current_mod_time", newInfo.ModTime().Truncate(time.Second), - ) - // update the file info - file.Info = newInfo - } - - if err := file.Release(nil); err != nil { - return fmt.Errorf("failed to release file: %w", err) - } - } - } - - // print stats to stdout unless we are processing stdin and printing the results to stdout - if !cfg.Stdin { - statz.Print() - } - + if formatErr != nil { // return an error if any formatting failures were detected - if formattingFailures { - return ErrFormattingFailures - } - + return formatErr + } else if cfg.FailOnChange && statz.Value(stats.Changed) != 0 { // if fail on change has been enabled, check that no files were actually changed, throwing an error if so - if cfg.FailOnChange && statz.Value(stats.Changed) != 0 { - return ErrFailOnChange - } - - return nil + return ErrFailOnChange } + + return nil } diff --git a/cmd/root_test.go b/cmd/root_test.go index e1303caf..1425dd32 100644 --- a/cmd/root_test.go +++ b/cmd/root_test.go @@ -483,7 +483,7 @@ func TestCache(t *testing.T) { // running should match but not format anything _, statz, err = treefmt(t, "--config-file", configPath, "--tree-root", tempDir) - as.ErrorIs(err, formatCmd.ErrFormattingFailures) + as.ErrorIs(err, format.ErrFormattingFailures) assertStats(t, as, statz, map[stats.Type]int{ stats.Traversed: 32, @@ -494,7 +494,7 @@ func TestCache(t *testing.T) { // running again should provide the same result _, statz, err = treefmt(t, "--config-file", configPath, "--tree-root", tempDir) - as.ErrorIs(err, formatCmd.ErrFormattingFailures) + as.ErrorIs(err, format.ErrFormattingFailures) assertStats(t, as, statz, map[stats.Type]int{ stats.Traversed: 32, diff --git a/format/cache.go b/format/cache.go deleted file mode 100644 index 2d6e3e61..00000000 --- a/format/cache.go +++ /dev/null @@ -1,94 +0,0 @@ -package format - -import ( - "errors" - "fmt" - "os" - - "github.com/charmbracelet/log" - "github.com/numtide/treefmt/walk/cache" - bolt "go.etcd.io/bbolt" -) - -func CompareFormatters(db *bolt.DB, formatters map[string]*Formatter) error { - return db.Update(func(tx *bolt.Tx) error { - clearPaths := false - - pathsBucket, err := cache.BucketPaths(tx) - if err != nil { - return fmt.Errorf("failed to get paths bucket from cache: %w", err) - } - - formattersBucket, err := cache.BucketFormatters(tx) - if err != nil { - return fmt.Errorf("failed to get formatters bucket from cache: %w", err) - } - - // check for any newly configured or modified formatters - for name, formatter := range formatters { - stat, err := os.Lstat(formatter.Executable()) - if err != nil { - return fmt.Errorf("failed to stat formatter executable %v: %w", formatter.Executable(), err) - } - - entry, err := formattersBucket.Get(name) - if !(err == nil || errors.Is(err, cache.ErrKeyNotFound)) { - return fmt.Errorf("failed to retrieve cache entry for formatter %v: %w", name, err) - } - - isNew := errors.Is(err, cache.ErrKeyNotFound) - hasChanged := !(isNew || (entry.Size == stat.Size() && entry.Modified == stat.ModTime())) - - if isNew { - log.Debugf("formatter '%s' is new", name) - } else if hasChanged { - log.Debug("formatter '%s' has changed", - name, - "size", stat.Size(), - "modTime", stat.ModTime(), - "cachedSize", entry.Size, - "cachedModTime", entry.Modified, - ) - } - - // update overall flag - clearPaths = clearPaths || isNew || hasChanged - - // record formatters info - entry = &cache.Entry{ - Size: stat.Size(), - Modified: stat.ModTime(), - } - - if err = formattersBucket.Put(name, entry); err != nil { - return fmt.Errorf("failed to write cache entry for formatter %v: %w", name, err) - } - } - - // check for any removed formatters - if err = formattersBucket.ForEach(func(key string, _ *cache.Entry) error { - _, ok := formatters[key] - if !ok { - // remove the formatter entry from the cache - if err = formattersBucket.Delete(key); err != nil { - return fmt.Errorf("failed to remove cache entry for formatter %v: %w", key, err) - } - // indicate a clean is required - clearPaths = true - } - - return nil - }); err != nil { - return fmt.Errorf("failed to check cache for removed formatters: %w", err) - } - - if clearPaths { - // remove all path entries - if err := pathsBucket.DeleteAll(); err != nil { - return fmt.Errorf("failed to remove all path entries from cache: %w", err) - } - } - - return nil - }) -} diff --git a/format/format.go b/format/format.go new file mode 100644 index 00000000..86528122 --- /dev/null +++ b/format/format.go @@ -0,0 +1,407 @@ +package format + +import ( + "cmp" + "context" + "errors" + "fmt" + "os" + "runtime" + "slices" + "strings" + "sync/atomic" + "time" + + "github.com/charmbracelet/log" + "github.com/gobwas/glob" + "github.com/numtide/treefmt/config" + "github.com/numtide/treefmt/stats" + "github.com/numtide/treefmt/walk" + "github.com/numtide/treefmt/walk/cache" + bolt "go.etcd.io/bbolt" + "golang.org/x/sync/errgroup" + "mvdan.cc/sh/v3/expand" +) + +const ( + batchKeySeparator = ":" +) + +var ErrFormattingFailures = errors.New("formatting failures detected") + +// batchKey represents the unique sequence of formatters to be applied to a batch of files. +// For example, "deadnix:statix:nixpkgs-fmt" indicates that deadnix should be applied first, statix second and +// nixpkgs-fmt third. +// Files are batched based on their formatting sequence, as determined by the priority and includes/excludes in the +// formatter configuration. +type batchKey string + +// sequence returns the list of formatters, by name, to be applied to a batch of files. +func (b batchKey) sequence() []string { + return strings.Split(string(b), batchKeySeparator) +} + +func newBatchKey(formatters []*Formatter) batchKey { + components := make([]string, 0, len(formatters)) + for _, f := range formatters { + components = append(components, f.Name()) + } + + return batchKey(strings.Join(components, batchKeySeparator)) +} + +// batchMap maintains a mapping between batchKey and a slice of pointers to walk.File, used to organize files into +// batches based on the sequence of formatters to be applied. +type batchMap map[batchKey][]*walk.File + +func formatterSortFunc(a, b *Formatter) int { + // sort by priority in ascending order + priorityA := a.Priority() + priorityB := b.Priority() + + result := priorityA - priorityB + if result == 0 { + // formatters with the same priority are sorted lexicographically to ensure a deterministic outcome + result = cmp.Compare(a.Name(), b.Name()) + } + + return result +} + +// Append adds a file to the batch corresponding to the given sequence of formatters and returns the updated batch. +func (b batchMap) Append(file *walk.File, matches []*Formatter) (key batchKey, batch []*walk.File) { + slices.SortFunc(matches, formatterSortFunc) + + // construct a batch key based on the sequence of formatters + key = newBatchKey(matches) + + // append to the batch + b[key] = append(b[key], file) + + // return the batch + return key, b[key] +} + +// CompositeFormatter handles the application of multiple Formatter instances based on global excludes and individual +// formatter configuration. +type CompositeFormatter struct { + stats *stats.Stats + batchSize int + globalExcludes []glob.Glob + + changeLevel log.Level + unmatchedLevel log.Level + + formatters map[string]*Formatter + + eg *errgroup.Group + batches batchMap + + // formatError indicates if at least one formatting error occurred + formatError *atomic.Bool +} + +func (c *CompositeFormatter) apply(ctx context.Context, key batchKey, batch []*walk.File) { + c.eg.Go(func() error { + var formatErrors []error + + // apply the formatters in sequence + for _, name := range key.sequence() { + formatter := c.formatters[name] + + if err := formatter.Apply(ctx, batch); err != nil { + formatErrors = append(formatErrors, err) + } + } + + // record if a format error occurred + hasErrors := len(formatErrors) > 0 + c.formatError.Store(hasErrors) + + if !hasErrors { + // record that the file was formatted + c.stats.Add(stats.Formatted, len(batch)) + } + + // Create a release context. + // We set no-cache based on whether any formatting errors occurred in this batch. + // This is to communicate with any caching layer, if used when reading files for this batch, that it should not + // update the state of any file in this batch, as we want to re-process them in later invocations. + releaseCtx := walk.SetNoCache(ctx, hasErrors) + + // post-processing + for _, file := range batch { + // check if the file has changed + changed, newInfo, err := file.Stat() + if err != nil { + return err + } + + if changed { + // record that a change in the underlying file occurred + c.stats.Add(stats.Changed, 1) + + log.Log( + c.changeLevel, "file has changed", + "path", file.RelPath, + "prev_size", file.Info.Size(), + "prev_mod_time", file.Info.ModTime().Truncate(time.Second), + "current_size", newInfo.Size(), + "current_mod_time", newInfo.ModTime().Truncate(time.Second), + ) + + // update the file info + file.Info = newInfo + } + + // release the file as there is no further processing to be done on it + if err := file.Release(releaseCtx); err != nil { + return fmt.Errorf("failed to release file: %w", err) + } + } + + return nil + }) +} + +// match filters the file against global excludes and returns a list of formatters that want to process the file. +func (c *CompositeFormatter) match(file *walk.File) []*Formatter { + // first check if this file has been globally excluded + if pathMatches(file.RelPath, c.globalExcludes) { + log.Debugf("path matched global excludes: %s", file.RelPath) + + return nil + } + + // a list of formatters that match this file + var matches []*Formatter + + // otherwise, check if any formatters are interested in it + for _, formatter := range c.formatters { + if formatter.Wants(file) { + matches = append(matches, formatter) + } + } + + return matches +} + +// Apply applies the configured formatters to the given files. +func (c *CompositeFormatter) Apply(ctx context.Context, files []*walk.File) error { + var toRelease []*walk.File + + for _, file := range files { + matches := c.match(file) // match the file against the formatters + + // check if there were no matches + if len(matches) == 0 { + // log that there was no match, exiting with an error if the unmatched level was set to fatal + if c.unmatchedLevel == log.FatalLevel { + return fmt.Errorf("no formatter for path: %s", file.RelPath) + } + + log.Logf(c.unmatchedLevel, "no formatter for path: %s", file.RelPath) + + // no further processing to be done, append to the release list + toRelease = append(toRelease, file) + + // continue to the next file + continue + } + + // record there was a match + c.stats.Add(stats.Matched, 1) + + // check if the file is new or has changed when compared to the cache entry + if file.Cache == nil || file.Cache.HasChanged(file.Info) { + // add this file to a batch and if it's full, apply formatters to the batch + if key, batch := c.batches.Append(file, matches); len(batch) == c.batchSize { + c.apply(ctx, newBatchKey(matches), batch) + // reset the batch + c.batches[key] = make([]*walk.File, 0, c.batchSize) + } + } else { + // no further processing to be done, append to the release list + toRelease = append(toRelease, file) + } + } + + // release files that require no further processing + // we set noCache to true as there's no need to update the cache, since we skipped those files + releaseCtx := walk.SetNoCache(ctx, true) + + for _, file := range toRelease { + if err := file.Release(releaseCtx); err != nil { + return fmt.Errorf("failed to release file: %w", err) + } + } + + return nil +} + +// BustCache compares the currently configured formatters with their respective entries in the db. +// If a formatter was added, removed or modified, we clear any path entries from the cache, ensuring that all paths +// get formatted with the most recent formatter set. +func (c *CompositeFormatter) BustCache(db *bolt.DB) error { + return db.Update(func(tx *bolt.Tx) error { + clearPaths := false + + pathsBucket, err := cache.BucketPaths(tx) + if err != nil { + return fmt.Errorf("failed to get paths bucket from cache: %w", err) + } + + formattersBucket, err := cache.BucketFormatters(tx) + if err != nil { + return fmt.Errorf("failed to get formatters bucket from cache: %w", err) + } + + // check for any newly configured or modified formatters + for name, formatter := range c.formatters { + stat, err := os.Lstat(formatter.Executable()) + if err != nil { + return fmt.Errorf("failed to stat formatter executable %v: %w", formatter.Executable(), err) + } + + entry, err := formattersBucket.Get(name) + if !(err == nil || errors.Is(err, cache.ErrKeyNotFound)) { + return fmt.Errorf("failed to retrieve cache entry for formatter %v: %w", name, err) + } + + isNew := errors.Is(err, cache.ErrKeyNotFound) + hasChanged := !(isNew || (entry.Size == stat.Size() && entry.Modified == stat.ModTime())) + + if isNew { + log.Debugf("formatter '%s' is new", name) + } else if hasChanged { + log.Debug("formatter '%s' has changed", + name, + "size", stat.Size(), + "modTime", stat.ModTime(), + "cachedSize", entry.Size, + "cachedModTime", entry.Modified, + ) + } + + // update overall flag + clearPaths = clearPaths || isNew || hasChanged + + // record formatters info + entry = &cache.Entry{ + Size: stat.Size(), + Modified: stat.ModTime(), + } + + if err = formattersBucket.Put(name, entry); err != nil { + return fmt.Errorf("failed to write cache entry for formatter %v: %w", name, err) + } + } + + // check for any removed formatters + if err = formattersBucket.ForEach(func(key string, _ *cache.Entry) error { + _, ok := c.formatters[key] + if !ok { + // remove the formatter entry from the cache + if err = formattersBucket.Delete(key); err != nil { + return fmt.Errorf("failed to remove cache entry for formatter %v: %w", key, err) + } + // indicate a clean is required + clearPaths = true + } + + return nil + }); err != nil { + return fmt.Errorf("failed to check cache for removed formatters: %w", err) + } + + if clearPaths { + // remove all path entries + if err := pathsBucket.DeleteAll(); err != nil { + return fmt.Errorf("failed to remove all path entries from cache: %w", err) + } + } + + return nil + }) +} + +// Close finalizes the processing of the CompositeFormatter, ensuring that any remaining batches are applied and +// all formatters have completed their tasks. It returns an error if any formatting failures were detected. +func (c *CompositeFormatter) Close(ctx context.Context) error { + // flush any partial batches that remain + for key, batch := range c.batches { + if len(batch) > 0 { + c.apply(ctx, key, batch) + } + } + + // wait for processing to complete + if err := c.eg.Wait(); err != nil { + return fmt.Errorf("failed to wait for formatters: %w", err) + } else if c.formatError.Load() { + return ErrFormattingFailures + } + + return nil +} + +func NewCompositeFormatter( + cfg *config.Config, + statz *stats.Stats, + batchSize int, +) (*CompositeFormatter, error) { + // compile global exclude globs + globalExcludes, err := compileGlobs(cfg.Excludes) + if err != nil { + return nil, fmt.Errorf("failed to compile global excludes: %w", err) + } + + // parse unmatched log level + unmatchedLevel, err := log.ParseLevel(cfg.OnUnmatched) + if err != nil { + return nil, fmt.Errorf("invalid on-unmatched value: %w", err) + } + + // create a composite formatter, adjusting the change logging based on --fail-on-change + changeLevel := log.DebugLevel + if cfg.FailOnChange { + changeLevel = log.ErrorLevel + } + + // create formatters + formatters := make(map[string]*Formatter) + + env := expand.ListEnviron(os.Environ()...) + + for name, formatterCfg := range cfg.FormatterConfigs { + formatter, err := newFormatter(name, cfg.TreeRoot, env, formatterCfg) + + if errors.Is(err, ErrCommandNotFound) && cfg.AllowMissingFormatter { + log.Debugf("formatter command not found: %v", name) + + continue + } else if err != nil { + return nil, fmt.Errorf("failed to initialise formatter %v: %w", name, err) + } + + // store formatter by name + formatters[name] = formatter + } + + // create an errgroup for asynchronously formatting + eg := errgroup.Group{} + // we use a simple heuristic to avoid too much contention by limiting the concurrency to runtime.NumCPU() + eg.SetLimit(runtime.NumCPU()) + + return &CompositeFormatter{ + stats: statz, + batchSize: batchSize, + globalExcludes: globalExcludes, + changeLevel: changeLevel, + unmatchedLevel: unmatchedLevel, + formatters: formatters, + eg: &eg, + batches: make(batchMap), + formatError: new(atomic.Bool), + }, nil +} diff --git a/format/formatter.go b/format/formatter.go index 41648b5f..5254565d 100644 --- a/format/formatter.go +++ b/format/formatter.go @@ -6,6 +6,7 @@ import ( "fmt" "os" "os/exec" + "regexp" "time" "github.com/charmbracelet/log" @@ -16,8 +17,13 @@ import ( "mvdan.cc/sh/v3/interp" ) -// ErrCommandNotFound is returned when the Command for a Formatter is not available. -var ErrCommandNotFound = errors.New("formatter command not found in PATH") +var ( + ErrInvalidName = errors.New("formatter name must only contain alphanumeric characters, `_` or `-`") + // ErrCommandNotFound is returned when the Command for a Formatter is not available. + ErrCommandNotFound = errors.New("formatter command not found in PATH") + + nameRegex = regexp.MustCompile("^[a-zA-Z0-9_-]+$") +) // Formatter represents a command which should be applied to a filesystem. type Formatter struct { @@ -28,16 +34,11 @@ type Formatter struct { executable string // path to the executable described by Command workingDir string - // internal compiled versions of Includes and Excludes. + // internal, compiled versions of Includes and Excludes. includes []glob.Glob excludes []glob.Glob } -// Executable returns the path to the executable defined by Command. -func (f *Formatter) Executable() string { - return f.executable -} - func (f *Formatter) Name() string { return f.name } @@ -46,20 +47,25 @@ func (f *Formatter) Priority() int { return f.config.Priority } -func (f *Formatter) Apply(ctx context.Context, tasks []*Task) error { +// Executable returns the path to the executable defined by Command. +func (f *Formatter) Executable() string { + return f.executable +} + +func (f *Formatter) Apply(ctx context.Context, files []*walk.File) error { start := time.Now() // construct args, starting with config args := f.config.Options // exit early if nothing to process - if len(tasks) == 0 { + if len(files) == 0 { return nil } // append paths to the args - for _, task := range tasks { - args = append(args, task.File.RelPath) + for _, file := range files { + args = append(args, file.RelPath) } // execute the command @@ -83,15 +89,16 @@ func (f *Formatter) Apply(ctx context.Context, tasks []*Task) error { return fmt.Errorf("formatter '%s' with options '%v' failed to apply: %w", f.config.Command, f.config.Options, err) } - f.log.Infof("%v file(s) processed in %v", len(tasks), time.Since(start)) + f.log.Infof("%v file(s) processed in %v", len(files), time.Since(start)) return nil } -// Wants is used to test if a Formatter wants a path based on it's configured Includes and Excludes patterns. -// Returns true if the Formatter should be applied to path, false otherwise. +// Wants is used to determine if a Formatter wants to process a path based on it's configured Includes and Excludes +// patterns. +// Returns true if the Formatter should be applied to file, false otherwise. func (f *Formatter) Wants(file *walk.File) bool { - match := !PathMatches(file.RelPath, f.excludes) && PathMatches(file.RelPath, f.includes) + match := !pathMatches(file.RelPath, f.excludes) && pathMatches(file.RelPath, f.includes) if match { f.log.Debugf("match: %v", file) } @@ -99,8 +106,8 @@ func (f *Formatter) Wants(file *walk.File) bool { return match } -// NewFormatter is used to create a new Formatter. -func NewFormatter( +// newFormatter is used to create a new Formatter. +func newFormatter( name string, treeRoot string, env expand.Environ, @@ -108,6 +115,11 @@ func NewFormatter( ) (*Formatter, error) { var err error + // check the name is valid + if !nameRegex.MatchString(name) { + return nil, ErrInvalidName + } + f := Formatter{} // capture config and the formatter's name @@ -130,12 +142,12 @@ func NewFormatter( f.log = log.WithPrefix(fmt.Sprintf("format | %s", name)) } - f.includes, err = CompileGlobs(cfg.Includes) + f.includes, err = compileGlobs(cfg.Includes) if err != nil { return nil, fmt.Errorf("failed to compile formatter '%v' includes: %w", f.name, err) } - f.excludes, err = CompileGlobs(cfg.Excludes) + f.excludes, err = compileGlobs(cfg.Excludes) if err != nil { return nil, fmt.Errorf("failed to compile formatter '%v' excludes: %w", f.name, err) } diff --git a/format/formatter_test.go b/format/formatter_test.go new file mode 100644 index 00000000..795ed4f2 --- /dev/null +++ b/format/formatter_test.go @@ -0,0 +1,49 @@ +package format_test + +import ( + "testing" + + "github.com/numtide/treefmt/config" + "github.com/numtide/treefmt/format" + "github.com/numtide/treefmt/stats" + "github.com/stretchr/testify/require" +) + +func TestInvalidFormatterName(t *testing.T) { + as := require.New(t) + + const batchSize = 1024 + + cfg := &config.Config{} + cfg.OnUnmatched = "info" + + statz := stats.New() + + // simple "empty" config + _, err := format.NewCompositeFormatter(cfg, &statz, batchSize) + as.NoError(err) + + // valid name using all the acceptable characters + cfg.FormatterConfigs = map[string]*config.Formatter{ + "echo_command-1234567890": { + Command: "echo", + }, + } + + _, err = format.NewCompositeFormatter(cfg, &statz, batchSize) + as.NoError(err) + + // test with some bad examples + for _, character := range []string{ + " ", ":", "?", "*", "[", "]", "(", ")", "|", "&", "<", ">", "\\", "/", "%", "$", "#", "@", "`", "'", + } { + cfg.FormatterConfigs = map[string]*config.Formatter{ + "touch_" + character: { + Command: "touch", + }, + } + + _, err = format.NewCompositeFormatter(cfg, &statz, batchSize) + as.ErrorIs(err, format.ErrInvalidName) + } +} diff --git a/format/glob.go b/format/glob.go index 296bd23a..c7104e30 100644 --- a/format/glob.go +++ b/format/glob.go @@ -6,8 +6,8 @@ import ( "github.com/gobwas/glob" ) -// CompileGlobs prepares the globs, where the patterns are all right-matching. -func CompileGlobs(patterns []string) ([]glob.Glob, error) { +// compileGlobs prepares the globs, where the patterns are all right-matching. +func compileGlobs(patterns []string) ([]glob.Glob, error) { globs := make([]glob.Glob, len(patterns)) for i, pattern := range patterns { @@ -22,7 +22,7 @@ func CompileGlobs(patterns []string) ([]glob.Glob, error) { return globs, nil } -func PathMatches(path string, globs []glob.Glob) bool { +func pathMatches(path string, globs []glob.Glob) bool { for idx := range globs { if globs[idx].Match(path) { return true diff --git a/format/glob_test.go b/format/glob_test.go index 6a85d3c4..03e44bc9 100644 --- a/format/glob_test.go +++ b/format/glob_test.go @@ -1,10 +1,10 @@ -package format_test +//nolint:testpackage +package format import ( "testing" "github.com/gobwas/glob" - "github.com/numtide/treefmt/format" "github.com/stretchr/testify/require" ) @@ -17,24 +17,24 @@ func TestGlobs(t *testing.T) { ) // File extension - globs, err = format.CompileGlobs([]string{"*.txt"}) + globs, err = compileGlobs([]string{"*.txt"}) r.NoError(err) - r.True(format.PathMatches("test/foo/bar.txt", globs)) - r.False(format.PathMatches("test/foo/bar.txtz", globs)) - r.False(format.PathMatches("test/foo/bar.flob", globs)) + r.True(pathMatches("test/foo/bar.txt", globs)) + r.False(pathMatches("test/foo/bar.txtz", globs)) + r.False(pathMatches("test/foo/bar.flob", globs)) // Prefix matching - globs, err = format.CompileGlobs([]string{"test/*"}) + globs, err = compileGlobs([]string{"test/*"}) r.NoError(err) - r.True(format.PathMatches("test/bar.txt", globs)) - r.True(format.PathMatches("test/foo/bar.txt", globs)) - r.False(format.PathMatches("/test/foo/bar.txt", globs)) + r.True(pathMatches("test/bar.txt", globs)) + r.True(pathMatches("test/foo/bar.txt", globs)) + r.False(pathMatches("/test/foo/bar.txt", globs)) // Exact matches // File extension - globs, err = format.CompileGlobs([]string{"LICENSE"}) + globs, err = compileGlobs([]string{"LICENSE"}) r.NoError(err) - r.True(format.PathMatches("LICENSE", globs)) - r.False(format.PathMatches("test/LICENSE", globs)) - r.False(format.PathMatches("LICENSE.txt", globs)) + r.True(pathMatches("LICENSE", globs)) + r.False(pathMatches("test/LICENSE", globs)) + r.False(pathMatches("LICENSE.txt", globs)) } diff --git a/walk/cached.go b/walk/cached.go index ce451f20..cfecc902 100644 --- a/walk/cached.go +++ b/walk/cached.go @@ -13,6 +13,18 @@ import ( "golang.org/x/sync/errgroup" ) +type ctxKeyNoCache struct{} + +func SetNoCache(ctx context.Context, noCache bool) context.Context { + return context.WithValue(ctx, ctxKeyNoCache{}, noCache) +} + +func GetNoCache(ctx context.Context) bool { + noCache, ok := ctx.Value(ctxKeyNoCache{}).(bool) + + return ok && noCache +} + // CachedReader reads files from a delegate Reader, appending a cache Entry on read (if on exists) and updating the // cache after the file has been processed. type CachedReader struct { @@ -101,10 +113,8 @@ func (c *CachedReader) Read(ctx context.Context, files []*File) (n int, err erro } // set a release function which inserts this file into the update channel - file.AddReleaseFunc(func(formatErr error) error { - // in the event of a formatting error, we do not want to update this file in the cache - // this ensures later invocations will try and re-format this file - if formatErr == nil { + file.AddReleaseFunc(func(ctx context.Context) error { + if !GetNoCache(ctx) { c.updateCh <- file } diff --git a/walk/cached_test.go b/walk/cached_test.go index 66f1eca2..d74f981f 100644 --- a/walk/cached_test.go +++ b/walk/cached_test.go @@ -50,7 +50,7 @@ func TestCachedReader(t *testing.T) { changeCount++ } - as.NoError(file.Release(nil)) + as.NoError(file.Release(ctx)) } cancel() diff --git a/walk/stdin.go b/walk/stdin.go index a74d82e0..72dda174 100644 --- a/walk/stdin.go +++ b/walk/stdin.go @@ -54,23 +54,20 @@ func (s StdinReader) Read(_ context.Context, files []*File) (n int, err error) { } // dump the temp file to stdout and remove it once the file is finished being processed - files[0].AddReleaseFunc(func(formatErr error) error { - // if formatting was successful, we dump its contents into os.Stdout - if formatErr == nil { - // open the temp file - file, err := os.Open(file.Name()) - if err != nil { - return fmt.Errorf("failed to open temp file %s: %w", file.Name(), err) - } - - // dump file into stdout - if _, err = io.Copy(os.Stdout, file); err != nil { - return fmt.Errorf("failed to copy %s to stdout: %w", file.Name(), err) - } - - if err = file.Close(); err != nil { - return fmt.Errorf("failed to close temp file %s: %w", file.Name(), err) - } + files[0].AddReleaseFunc(func(_ context.Context) error { + // open the temp file + file, err := os.Open(file.Name()) + if err != nil { + return fmt.Errorf("failed to open temp file %s: %w", file.Name(), err) + } + + // dump file into stdout + if _, err = io.Copy(os.Stdout, file); err != nil { + return fmt.Errorf("failed to copy %s to stdout: %w", file.Name(), err) + } + + if err = file.Close(); err != nil { + return fmt.Errorf("failed to close temp file %s: %w", file.Name(), err) } // clean up the temp file diff --git a/walk/walk.go b/walk/walk.go index 4109b442..9480563f 100644 --- a/walk/walk.go +++ b/walk/walk.go @@ -27,7 +27,7 @@ const ( BatchSize = 1024 ) -type ReleaseFunc func(formatErr error) error +type ReleaseFunc func(ctx context.Context) error // File represents a file object with its path, relative path, file info, and potential cache entry. type File struct { @@ -42,10 +42,10 @@ type File struct { } // Release calls all registered release functions for the File and returns an error if any function fails. -// Accepts formatErr, which indicates if an error occurred when formatting this file. -func (f *File) Release(formatErr error) error { +// Accepts a context which can be used to pass parameters to the release hooks. +func (f *File) Release(ctx context.Context) error { for _, fn := range f.releaseFuncs { - if err := fn(formatErr); err != nil { + if err := fn(ctx); err != nil { return err } }