Skip to content

Commit

Permalink
feat: improve paths handling
Browse files Browse the repository at this point in the history
Changes how we handle paths which are provided as arguments:

- if it's a file, we attempt to format it based on the matching options
- if it's a directory, we traverse it using the provided walk type

This is more consistent with what a user expects and makes it easier to integrate with tools such as `none-ls`.

Close #435
  • Loading branch information
brianmcgee committed Oct 12, 2024
1 parent 21fbdc0 commit 57abb78
Show file tree
Hide file tree
Showing 12 changed files with 392 additions and 211 deletions.
81 changes: 20 additions & 61 deletions cmd/format/format.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,49 +63,6 @@ func Run(v *viper.Viper, statz *stats.Stats, cmd *cobra.Command, paths []string)
<-time.After(time.Until(startAfter))
}

if cfg.Stdin {
// check we have only received one path arg which we use for the file extension / matching to formatters
if len(paths) != 1 {
return fmt.Errorf("exactly one path should be specified when using the --stdin flag")
}

// read stdin into a temporary file with the same file extension
pattern := fmt.Sprintf("*%s", filepath.Ext(paths[0]))

file, err := os.CreateTemp("", pattern)
if err != nil {
return fmt.Errorf("failed to create a temporary file for processing stdin: %w", err)
}

if _, err = io.Copy(file, os.Stdin); err != nil {
return fmt.Errorf("failed to copy stdin into a temporary file")
}

// set the tree root to match the temp directory
cfg.TreeRoot, err = filepath.Abs(filepath.Dir(file.Name()))
if err != nil {
return fmt.Errorf("failed to get absolute path for tree root: %w", err)
}

// configure filesystem walker to traverse the temporary tree root
cfg.Walk = "filesystem"

// update paths with temp file
paths[0], err = filepath.Rel(os.TempDir(), file.Name())
if err != nil {
return fmt.Errorf("failed to get relative path for temp file: %w", err)
}

} else {
// checks all paths are contained within the tree root
for _, path := range paths {
rootPath := filepath.Join(cfg.TreeRoot, path)
if _, err = os.Stat(rootPath); err != nil {
return fmt.Errorf("path %s not found within the tree root %s", path, cfg.TreeRoot)
}
}
}

// cpu profiling
if cfg.CpuProfile != "" {
cpuProfile, err := os.Create(cfg.CpuProfile)
Expand Down Expand Up @@ -204,13 +161,29 @@ func Run(v *viper.Viper, statz *stats.Stats, cmd *cobra.Command, paths []string)
eg.Go(postProcessing(ctx, cfg, statz, formattedCh))
eg.Go(applyFormatters(ctx, cfg, statz, globalExcludes, formatters, filesCh, formattedCh))

//
// parse the walk type
walkType, err := walk.TypeString(cfg.Walk)
if err != nil {
return fmt.Errorf("invalid walk type: %w", err)
}

reader, err := walk.NewReader(walkType, cfg.TreeRoot, paths, db, statz)
if walkType == walk.Stdin {
// check we have only received one path arg which we use for the file extension / matching to formatters
if len(paths) != 1 {
return fmt.Errorf("exactly one path should be specified when using the --stdin flag")
}
} else {
// check that every path exists within the tree root
for _, path := range paths {
rootPath := filepath.Join(cfg.TreeRoot, path)
if _, err = os.Stat(rootPath); err != nil {
return fmt.Errorf("path %s not found within the tree root %s", path, cfg.TreeRoot)
}
}
}

// create a new reader for traversing the paths
reader, err := walk.NewCompositeReader(walkType, cfg.TreeRoot, paths, db, statz)
if err != nil {
return fmt.Errorf("failed to create walker: %w", err)
}
Expand Down Expand Up @@ -440,22 +413,8 @@ func postProcessing(
file.Info = newInfo
}

if file.Release != nil {
file.Release()
}

if cfg.Stdin {
// dump file into stdout
f, err := os.Open(file.Path)
if err != nil {
return fmt.Errorf("failed to open %s: %w", file.Path, err)
}
if _, err = io.Copy(os.Stdout, f); err != nil {
return fmt.Errorf("failed to copy %s to stdout: %w", file.Path, err)
}
if err = os.Remove(f.Name()); err != nil {
return fmt.Errorf("failed to remove temp file %s: %w", file.Path, err)
}
if err := file.Release(); err != nil {
return fmt.Errorf("failed to release file: %w", err)
}
}
}
Expand Down
4 changes: 2 additions & 2 deletions cmd/root_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -628,11 +628,11 @@ func TestGitWorktree(t *testing.T) {

_, statz, err = treefmt(t, "-C", tempDir, "-c", "haskell", "foo.txt")
as.NoError(err)
assertStats(t, as, statz, 7, 7, 7, 0)
assertStats(t, as, statz, 8, 8, 8, 0)

_, statz, err = treefmt(t, "-C", tempDir, "-c", "foo.txt")
as.NoError(err)
assertStats(t, as, statz, 0, 0, 0, 0)
assertStats(t, as, statz, 1, 1, 1, 0)
}

func TestPathsArg(t *testing.T) {
Expand Down
7 changes: 7 additions & 0 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import (
"path/filepath"
"strings"

"github.com/numtide/treefmt/walk"

"github.com/spf13/pflag"
"github.com/spf13/viper"
)
Expand Down Expand Up @@ -175,6 +177,11 @@ func FromViper(v *viper.Viper) (*Config, error) {
return nil, fmt.Errorf("failed to get absolute path for working directory: %w", err)
}

// if the stdin flag was passed, we force the stdin walk type
if cfg.Stdin {
cfg.Walk = walk.Stdin.String()
}

// determine the tree root
if cfg.TreeRoot == "" {
// if none was specified, we first try with tree-root-file
Expand Down
5 changes: 3 additions & 2 deletions walk/cached.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,10 @@ func (c *CachedReader) Read(ctx context.Context, files []*File) (n int, err erro
}

// set a release function which inserts this file into the release channel for updating
file.Release = func() {
file.AddReleaseFunc(func() error {
c.releaseCh <- file
}
return nil
})
}

if errors.Is(err, io.EOF) {
Expand Down
31 changes: 21 additions & 10 deletions walk/cached_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,14 @@ func TestCachedReader(t *testing.T) {
batchSize := 1024
tempDir := test.TempExamples(t)

readAll := func(paths []string) (totalCount, newCount, changeCount int, statz stats.Stats) {
readAll := func(path string) (totalCount, newCount, changeCount int, statz stats.Stats) {
statz = stats.New()

db, err := cache.Open(tempDir)
as.NoError(err)
defer db.Close()

delegate := walk.NewFilesystemReader(tempDir, paths, &statz, batchSize)
delegate := walk.NewFilesystemReader(tempDir, path, &statz, batchSize)
reader, err := walk.NewCachedReader(db, batchSize, delegate)
as.NoError(err)

Expand All @@ -49,7 +49,8 @@ func TestCachedReader(t *testing.T) {
} else if file.Cache.HasChanged(file.Info) {
changeCount++
}
file.Release()

as.NoError(file.Release())
}

cancel()
Expand All @@ -64,13 +65,13 @@ func TestCachedReader(t *testing.T) {
return totalCount, newCount, changeCount, statz
}

totalCount, newCount, changeCount, _ := readAll([]string{"."})
totalCount, newCount, changeCount, _ := readAll("")
as.Equal(32, totalCount)
as.Equal(32, newCount)
as.Equal(0, changeCount)

// read again, should be no changes
totalCount, newCount, changeCount, _ = readAll([]string{"."})
totalCount, newCount, changeCount, _ = readAll("")
as.Equal(32, totalCount)
as.Equal(0, newCount)
as.Equal(0, changeCount)
Expand All @@ -83,7 +84,7 @@ func TestCachedReader(t *testing.T) {
as.NoError(os.Chtimes(filepath.Join(tempDir, "shell/foo.sh"), time.Now(), modTime))
as.NoError(os.Chtimes(filepath.Join(tempDir, "haskell/Nested/Foo.hs"), time.Now(), modTime))

totalCount, newCount, changeCount, _ = readAll([]string{"."})
totalCount, newCount, changeCount, _ = readAll("")
as.Equal(32, totalCount)
as.Equal(0, newCount)
as.Equal(3, changeCount)
Expand All @@ -95,7 +96,7 @@ func TestCachedReader(t *testing.T) {
_, err = os.Create(filepath.Join(tempDir, "fizz.go"))
as.NoError(err)

totalCount, newCount, changeCount, _ = readAll([]string{"."})
totalCount, newCount, changeCount, _ = readAll("")
as.Equal(34, totalCount)
as.Equal(2, newCount)
as.Equal(0, changeCount)
Expand All @@ -113,14 +114,24 @@ func TestCachedReader(t *testing.T) {
as.NoError(err)
as.NoError(f.Close())

totalCount, newCount, changeCount, _ = readAll([]string{"."})
totalCount, newCount, changeCount, _ = readAll("")
as.Equal(34, totalCount)
as.Equal(0, newCount)
as.Equal(2, changeCount)

// read some paths within the root
totalCount, newCount, changeCount, _ = readAll([]string{"go", "elm/src", "haskell"})
as.Equal(10, totalCount)
totalCount, newCount, changeCount, _ = readAll("go")
as.Equal(2, totalCount)
as.Equal(0, newCount)
as.Equal(0, changeCount)

totalCount, newCount, changeCount, _ = readAll("elm/src")
as.Equal(1, totalCount)
as.Equal(0, newCount)
as.Equal(0, changeCount)

totalCount, newCount, changeCount, _ = readAll("haskell")
as.Equal(7, totalCount)
as.Equal(0, newCount)
as.Equal(0, changeCount)
}
44 changes: 15 additions & 29 deletions walk/filesystem.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import (
type FilesystemReader struct {
log *log.Logger
root string
paths []string
path string
batchSize int

eg *errgroup.Group
Expand All @@ -35,7 +35,17 @@ func (f *FilesystemReader) process() error {
close(f.filesCh)
}()

walkFn := func(path string, info fs.FileInfo, err error) error {
// f.path is relative to the root, so we create a fully qualified version
// we also clean the path up in case there are any ../../ components etc.
path := filepath.Clean(filepath.Join(f.root, f.path))

// ensure the path is within the root
if !strings.HasPrefix(path, f.root) {
return fmt.Errorf("path '%s' is outside of the root '%s'", path, f.root)
}

// walk the path
return filepath.Walk(path, func(path string, info fs.FileInfo, err error) error {
// return errors immediately
if err != nil {
return err
Expand Down Expand Up @@ -64,26 +74,7 @@ func (f *FilesystemReader) process() error {
f.log.Debugf("file queued %s", file.RelPath)

return nil
}

// walk each path specified
for idx := range f.paths {
// f.paths are relative to the root, so we create a fully qualified version
// we also clean the path up in case there are any ../../ components etc.
path := filepath.Clean(filepath.Join(f.root, f.paths[idx]))

// ensure the path is within the root
if !strings.HasPrefix(path, f.root) {
return fmt.Errorf("path '%s' is outside of the root '%s'", path, f.root)
}

// walk the path
if err := filepath.Walk(path, walkFn); err != nil {
return err
}
}

return nil
})
}

// Read populates the provided files array with as many files as are available until the provided context is cancelled.
Expand Down Expand Up @@ -129,22 +120,17 @@ func (f *FilesystemReader) Close() error {
// and root.
func NewFilesystemReader(
root string,
paths []string,
path string,
statz *stats.Stats,
batchSize int,
) *FilesystemReader {
// if no paths are specified, we default to the root path
if len(paths) == 0 {
paths = []string{"."}
}

// create an error group for managing the processing loop
eg := errgroup.Group{}

r := FilesystemReader{
log: log.WithPrefix("walk[filesystem]"),
root: root,
paths: paths,
path: path,
batchSize: batchSize,

eg: &eg,
Expand Down
2 changes: 1 addition & 1 deletion walk/filesystem_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ func TestFilesystemReader(t *testing.T) {
tempDir := test.TempExamples(t)
statz := stats.New()

r := walk.NewFilesystemReader(tempDir, nil, &statz, 1024)
r := walk.NewFilesystemReader(tempDir, "", &statz, 1024)

count := 0

Expand Down
Loading

0 comments on commit 57abb78

Please sign in to comment.