diff --git a/cmd/format/format.go b/cmd/format/format.go index 03f06d22..8bfae813 100644 --- a/cmd/format/format.go +++ b/cmd/format/format.go @@ -63,49 +63,6 @@ func Run(v *viper.Viper, statz *stats.Stats, cmd *cobra.Command, paths []string) <-time.After(time.Until(startAfter)) } - if cfg.Stdin { - // check we have only received one path arg which we use for the file extension / matching to formatters - if len(paths) != 1 { - return fmt.Errorf("exactly one path should be specified when using the --stdin flag") - } - - // read stdin into a temporary file with the same file extension - pattern := fmt.Sprintf("*%s", filepath.Ext(paths[0])) - - file, err := os.CreateTemp("", pattern) - if err != nil { - return fmt.Errorf("failed to create a temporary file for processing stdin: %w", err) - } - - if _, err = io.Copy(file, os.Stdin); err != nil { - return fmt.Errorf("failed to copy stdin into a temporary file") - } - - // set the tree root to match the temp directory - cfg.TreeRoot, err = filepath.Abs(filepath.Dir(file.Name())) - if err != nil { - return fmt.Errorf("failed to get absolute path for tree root: %w", err) - } - - // configure filesystem walker to traverse the temporary tree root - cfg.Walk = "filesystem" - - // update paths with temp file - paths[0], err = filepath.Rel(os.TempDir(), file.Name()) - if err != nil { - return fmt.Errorf("failed to get relative path for temp file: %w", err) - } - - } else { - // checks all paths are contained within the tree root - for _, path := range paths { - rootPath := filepath.Join(cfg.TreeRoot, path) - if _, err = os.Stat(rootPath); err != nil { - return fmt.Errorf("path %s not found within the tree root %s", path, cfg.TreeRoot) - } - } - } - // cpu profiling if cfg.CpuProfile != "" { cpuProfile, err := os.Create(cfg.CpuProfile) @@ -204,13 +161,29 @@ func Run(v *viper.Viper, statz *stats.Stats, cmd *cobra.Command, paths []string) eg.Go(postProcessing(ctx, cfg, statz, formattedCh)) eg.Go(applyFormatters(ctx, cfg, statz, globalExcludes, formatters, filesCh, formattedCh)) - // + // parse the walk type walkType, err := walk.TypeString(cfg.Walk) if err != nil { return fmt.Errorf("invalid walk type: %w", err) } - reader, err := walk.NewReader(walkType, cfg.TreeRoot, paths, db, statz) + if walkType == walk.Stdin { + // check we have only received one path arg which we use for the file extension / matching to formatters + if len(paths) != 1 { + return fmt.Errorf("exactly one path should be specified when using the --stdin flag") + } + } else { + // checks all paths are contained within the tree root + for _, path := range paths { + rootPath := filepath.Join(cfg.TreeRoot, path) + if _, err = os.Stat(rootPath); err != nil { + return fmt.Errorf("path %s not found within the tree root %s", path, cfg.TreeRoot) + } + } + } + + // create a new reader for traversing the paths + reader, err := walk.NewCompositeReader(walkType, cfg.TreeRoot, paths, db, statz) if err != nil { return fmt.Errorf("failed to create walker: %w", err) } @@ -440,22 +413,8 @@ func postProcessing( file.Info = newInfo } - if file.Release != nil { - file.Release() - } - - if cfg.Stdin { - // dump file into stdout - f, err := os.Open(file.Path) - if err != nil { - return fmt.Errorf("failed to open %s: %w", file.Path, err) - } - if _, err = io.Copy(os.Stdout, f); err != nil { - return fmt.Errorf("failed to copy %s to stdout: %w", file.Path, err) - } - if err = os.Remove(f.Name()); err != nil { - return fmt.Errorf("failed to remove temp file %s: %w", file.Path, err) - } + if err := file.Release(); err != nil { + return fmt.Errorf("failed to release file: %w", err) } } } diff --git a/cmd/root_test.go b/cmd/root_test.go index 4e6fed8c..a3500a50 100644 --- a/cmd/root_test.go +++ b/cmd/root_test.go @@ -628,11 +628,11 @@ func TestGitWorktree(t *testing.T) { _, statz, err = treefmt(t, "-C", tempDir, "-c", "haskell", "foo.txt") as.NoError(err) - assertStats(t, as, statz, 7, 7, 7, 0) + assertStats(t, as, statz, 8, 8, 8, 0) _, statz, err = treefmt(t, "-C", tempDir, "-c", "foo.txt") as.NoError(err) - assertStats(t, as, statz, 0, 0, 0, 0) + assertStats(t, as, statz, 1, 1, 1, 0) } func TestPathsArg(t *testing.T) { diff --git a/config/config.go b/config/config.go index 2ea1e775..801a5d01 100644 --- a/config/config.go +++ b/config/config.go @@ -6,6 +6,8 @@ import ( "path/filepath" "strings" + "github.com/numtide/treefmt/walk" + "github.com/spf13/pflag" "github.com/spf13/viper" ) @@ -175,6 +177,11 @@ func FromViper(v *viper.Viper) (*Config, error) { return nil, fmt.Errorf("failed to get absolute path for working directory: %w", err) } + // if the stdin flag was passed, we force the stdin walk type + if cfg.Stdin { + cfg.Walk = walk.Stdin.String() + } + // determine the tree root if cfg.TreeRoot == "" { // if none was specified, we first try with tree-root-file diff --git a/walk/cached.go b/walk/cached.go index eff946ee..50a0e82a 100644 --- a/walk/cached.go +++ b/walk/cached.go @@ -95,9 +95,10 @@ func (c *CachedReader) Read(ctx context.Context, files []*File) (n int, err erro } // set a release function which inserts this file into the release channel for updating - file.Release = func() { + file.AddReleaseFunc(func() error { c.releaseCh <- file - } + return nil + }) } if errors.Is(err, io.EOF) { diff --git a/walk/cached_test.go b/walk/cached_test.go index 4f00854c..b26273a2 100644 --- a/walk/cached_test.go +++ b/walk/cached_test.go @@ -22,14 +22,14 @@ func TestCachedReader(t *testing.T) { batchSize := 1024 tempDir := test.TempExamples(t) - readAll := func(paths []string) (totalCount, newCount, changeCount int, statz stats.Stats) { + readAll := func(path string) (totalCount, newCount, changeCount int, statz stats.Stats) { statz = stats.New() db, err := cache.Open(tempDir) as.NoError(err) defer db.Close() - delegate := walk.NewFilesystemReader(tempDir, paths, &statz, batchSize) + delegate := walk.NewFilesystemReader(tempDir, path, &statz, batchSize) reader, err := walk.NewCachedReader(db, batchSize, delegate) as.NoError(err) @@ -49,7 +49,8 @@ func TestCachedReader(t *testing.T) { } else if file.Cache.HasChanged(file.Info) { changeCount++ } - file.Release() + + as.NoError(file.Release()) } cancel() @@ -64,13 +65,13 @@ func TestCachedReader(t *testing.T) { return totalCount, newCount, changeCount, statz } - totalCount, newCount, changeCount, _ := readAll([]string{"."}) + totalCount, newCount, changeCount, _ := readAll("") as.Equal(32, totalCount) as.Equal(32, newCount) as.Equal(0, changeCount) // read again, should be no changes - totalCount, newCount, changeCount, _ = readAll([]string{"."}) + totalCount, newCount, changeCount, _ = readAll("") as.Equal(32, totalCount) as.Equal(0, newCount) as.Equal(0, changeCount) @@ -83,7 +84,7 @@ func TestCachedReader(t *testing.T) { as.NoError(os.Chtimes(filepath.Join(tempDir, "shell/foo.sh"), time.Now(), modTime)) as.NoError(os.Chtimes(filepath.Join(tempDir, "haskell/Nested/Foo.hs"), time.Now(), modTime)) - totalCount, newCount, changeCount, _ = readAll([]string{"."}) + totalCount, newCount, changeCount, _ = readAll("") as.Equal(32, totalCount) as.Equal(0, newCount) as.Equal(3, changeCount) @@ -95,7 +96,7 @@ func TestCachedReader(t *testing.T) { _, err = os.Create(filepath.Join(tempDir, "fizz.go")) as.NoError(err) - totalCount, newCount, changeCount, _ = readAll([]string{"."}) + totalCount, newCount, changeCount, _ = readAll("") as.Equal(34, totalCount) as.Equal(2, newCount) as.Equal(0, changeCount) @@ -113,14 +114,24 @@ func TestCachedReader(t *testing.T) { as.NoError(err) as.NoError(f.Close()) - totalCount, newCount, changeCount, _ = readAll([]string{"."}) + totalCount, newCount, changeCount, _ = readAll("") as.Equal(34, totalCount) as.Equal(0, newCount) as.Equal(2, changeCount) // read some paths within the root - totalCount, newCount, changeCount, _ = readAll([]string{"go", "elm/src", "haskell"}) - as.Equal(10, totalCount) + totalCount, newCount, changeCount, _ = readAll("go") + as.Equal(2, totalCount) + as.Equal(0, newCount) + as.Equal(0, changeCount) + + totalCount, newCount, changeCount, _ = readAll("elm/src") + as.Equal(1, totalCount) + as.Equal(0, newCount) + as.Equal(0, changeCount) + + totalCount, newCount, changeCount, _ = readAll("haskell") + as.Equal(7, totalCount) as.Equal(0, newCount) as.Equal(0, changeCount) } diff --git a/walk/filesystem.go b/walk/filesystem.go index f6c441c2..b27afb5f 100644 --- a/walk/filesystem.go +++ b/walk/filesystem.go @@ -19,7 +19,7 @@ import ( type FilesystemReader struct { log *log.Logger root string - paths []string + path string batchSize int eg *errgroup.Group @@ -35,7 +35,17 @@ func (f *FilesystemReader) process() error { close(f.filesCh) }() - walkFn := func(path string, info fs.FileInfo, err error) error { + // f.path is relative to the root, so we create a fully qualified version + // we also clean the path up in case there are any ../../ components etc. + path := filepath.Clean(filepath.Join(f.root, f.path)) + + // ensure the path is within the root + if !strings.HasPrefix(path, f.root) { + return fmt.Errorf("path '%s' is outside of the root '%s'", path, f.root) + } + + // walk the path + return filepath.Walk(path, func(path string, info fs.FileInfo, err error) error { // return errors immediately if err != nil { return err @@ -64,26 +74,7 @@ func (f *FilesystemReader) process() error { f.log.Debugf("file queued %s", file.RelPath) return nil - } - - // walk each path specified - for idx := range f.paths { - // f.paths are relative to the root, so we create a fully qualified version - // we also clean the path up in case there are any ../../ components etc. - path := filepath.Clean(filepath.Join(f.root, f.paths[idx])) - - // ensure the path is within the root - if !strings.HasPrefix(path, f.root) { - return fmt.Errorf("path '%s' is outside of the root '%s'", path, f.root) - } - - // walk the path - if err := filepath.Walk(path, walkFn); err != nil { - return err - } - } - - return nil + }) } // Read populates the provided files array with as many files are available until the provided context is cancelled. @@ -129,22 +120,17 @@ func (f *FilesystemReader) Close() error { // and root. func NewFilesystemReader( root string, - paths []string, + path string, statz *stats.Stats, batchSize int, ) *FilesystemReader { - // if no paths are specified, we default to the root path - if len(paths) == 0 { - paths = []string{"."} - } - // create an error group for managing the processing loop eg := errgroup.Group{} r := FilesystemReader{ log: log.WithPrefix("walk[filesystem]"), root: root, - paths: paths, + path: path, batchSize: batchSize, eg: &eg, diff --git a/walk/filesystem_test.go b/walk/filesystem_test.go index 02775a4d..8b1a2e86 100644 --- a/walk/filesystem_test.go +++ b/walk/filesystem_test.go @@ -56,7 +56,7 @@ func TestFilesystemReader(t *testing.T) { tempDir := test.TempExamples(t) statz := stats.New() - r := walk.NewFilesystemReader(tempDir, nil, &statz, 1024) + r := walk.NewFilesystemReader(tempDir, "", &statz, 1024) count := 0 diff --git a/walk/git.go b/walk/git.go index 3f29ca46..91d13d72 100644 --- a/walk/git.go +++ b/walk/git.go @@ -19,7 +19,7 @@ import ( type GitReader struct { root string - paths []string + path string stats *stats.Stats batchSize int @@ -45,103 +45,100 @@ func (g *GitReader) process() error { // git index into memory for faster lookups var idxCache *filetree - for pathIdx := range g.paths { + path := filepath.Clean(filepath.Join(g.root, g.path)) + if !strings.HasPrefix(path, g.root) { + return fmt.Errorf("path '%s' is outside of the root '%s'", path, g.root) + } - path := filepath.Clean(filepath.Join(g.root, g.paths[pathIdx])) - if !strings.HasPrefix(path, g.root) { - return fmt.Errorf("path '%s' is outside of the root '%s'", path, g.root) - } + switch path { - switch path { + case g.root: - case g.root: + // we can just iterate the index entries + for _, entry := range gitIndex.Entries { - // we can just iterate the index entries - for _, entry := range gitIndex.Entries { + // we only want regular files, not directories or symlinks + if entry.Mode == filemode.Dir || entry.Mode == filemode.Symlink { + continue + } - // we only want regular files, not directories or symlinks - if entry.Mode == filemode.Dir || entry.Mode == filemode.Symlink { - continue - } + // stat the file + path := filepath.Join(g.root, entry.Name) - // stat the file - path := filepath.Join(g.root, entry.Name) + info, err := os.Lstat(path) + if os.IsNotExist(err) { + // the underlying file might have been removed without the change being staged yet + g.log.Warnf("Path %s is in the index but appears to have been removed from the filesystem", path) + continue + } else if err != nil { + return fmt.Errorf("failed to stat %s: %w", path, err) + } - info, err := os.Lstat(path) - if os.IsNotExist(err) { - // the underlying file might have been removed without the change being staged yet - g.log.Warnf("Path %s is in the index but appears to have been removed from the filesystem", path) - continue - } else if err != nil { - return fmt.Errorf("failed to stat %s: %w", path, err) - } + // determine a relative path + relPath, err := filepath.Rel(g.root, path) + if err != nil { + return fmt.Errorf("failed to determine a relative path for %s: %w", path, err) + } - // determine a relative path - relPath, err := filepath.Rel(g.root, path) - if err != nil { - return fmt.Errorf("failed to determine a relative path for %s: %w", path, err) - } + file := File{ + Path: path, + RelPath: relPath, + Info: info, + } - file := File{ - Path: path, - RelPath: relPath, - Info: info, - } + g.stats.Add(stats.Traversed, 1) + g.filesCh <- &file + } - g.stats.Add(stats.Traversed, 1) - g.filesCh <- &file - } + default: - default: + // read the git index into memory if it hasn't already + if idxCache == nil { + idxCache = &filetree{name: ""} + idxCache.readIndex(gitIndex) + } - // read the git index into memory if it hasn't already - if idxCache == nil { - idxCache = &filetree{name: ""} - idxCache.readIndex(gitIndex) + // git index entries are relative to the repository root, so we need to determine a relative path for the + // one we are currently processing before checking if it exists within the git index + relPath, err := filepath.Rel(g.root, path) + if err != nil { + return fmt.Errorf("failed to find root relative path for %v: %w", path, err) + } + + if !idxCache.hasPath(relPath) { + log.Debugf("path %s not found in git index, skipping", relPath) + return nil + } + + err = filepath.Walk(path, func(path string, info fs.FileInfo, _ error) error { + // skip directories + if info.IsDir() { + return nil } - // git index entries are relative to the repository root, so we need to determine a relative path for the - // one we are currently processing before checking if it exists within the git index + // determine a path relative to g.root before checking presence in the git index relPath, err := filepath.Rel(g.root, path) if err != nil { - return fmt.Errorf("failed to find root relative path for %v: %w", path, err) + return fmt.Errorf("failed to determine a relative path for %s: %w", path, err) } if !idxCache.hasPath(relPath) { - log.Debugf("path %s not found in git index, skipping", relPath) - continue + log.Debugf("path %v not found in git index, skipping", relPath) + return nil } - err = filepath.Walk(path, func(path string, info fs.FileInfo, _ error) error { - // skip directories - if info.IsDir() { - return nil - } - - // determine a path relative to g.root before checking presence in the git index - relPath, err := filepath.Rel(g.root, path) - if err != nil { - return fmt.Errorf("failed to determine a relative path for %s: %w", path, err) - } - - if !idxCache.hasPath(relPath) { - log.Debugf("path %v not found in git index, skipping", relPath) - return nil - } - - file := File{ - Path: path, - RelPath: relPath, - Info: info, - } - - g.stats.Add(stats.Traversed, 1) - g.filesCh <- &file - return nil - }) - if err != nil { - return fmt.Errorf("failed to walk %s: %w", path, err) + file := File{ + Path: path, + RelPath: relPath, + Info: info, } + + g.stats.Add(stats.Traversed, 1) + g.filesCh <- &file + return nil + }) + if err != nil { + return fmt.Errorf("failed to walk %s: %w", path, err) } } @@ -175,7 +172,7 @@ func (g *GitReader) Close() error { func NewGitReader( root string, - paths []string, + path string, statz *stats.Stats, batchSize int, ) (*GitReader, error) { @@ -186,13 +183,9 @@ func NewGitReader( eg := &errgroup.Group{} - if len(paths) == 0 { - paths = []string{"."} - } - r := &GitReader{ root: root, - paths: paths, + path: path, stats: statz, batchSize: batchSize, log: log.WithPrefix("walk[git]"), diff --git a/walk/git_test.go b/walk/git_test.go index 66dc645b..d7ed418f 100644 --- a/walk/git_test.go +++ b/walk/git_test.go @@ -40,7 +40,7 @@ func TestGitReader(t *testing.T) { statz := stats.New() - reader, err := walk.NewGitReader(tempDir, nil, &statz, 1024) + reader, err := walk.NewGitReader(tempDir, "", &statz, 1024) as.NoError(err) count := 0 diff --git a/walk/stdin.go b/walk/stdin.go new file mode 100644 index 00000000..385e60e0 --- /dev/null +++ b/walk/stdin.go @@ -0,0 +1,97 @@ +package walk + +import ( + "context" + "fmt" + "io" + "os" + "path/filepath" + + "github.com/numtide/treefmt/stats" +) + +type StdinReader struct { + root string + path string + stats stats.Stats + input *os.File + + complete bool +} + +func (s StdinReader) Read(_ context.Context, files []*File) (n int, err error) { + if s.complete { + return 0, io.EOF + } + + // read stdin into a temporary file with the same file extension + pattern := fmt.Sprintf("*%s", filepath.Ext(s.path)) + + file, err := os.CreateTemp(s.root, pattern) + if err != nil { + return 0, fmt.Errorf("failed to create a temporary file for processing stdin: %w", err) + } + defer file.Close() + + if _, err = io.Copy(file, s.input); err != nil { + return 0, fmt.Errorf("failed to copy stdin into a temporary file") + } + + info, err := file.Stat() + if err != nil { + return 0, fmt.Errorf("failed to get file info for temporary file: %w", err) + } + + relPath, err := filepath.Rel(s.root, file.Name()) + if err != nil { + return 0, fmt.Errorf("failed to get relative path for temporary file: %w", err) + } + + files[0] = &File{ + Path: file.Name(), + RelPath: relPath, + Info: info, + } + + // dump the temp file to stdout and remove it once the file is finished being processed + files[0].AddReleaseFunc(func() error { + // open the temp file + file, err := os.Open(file.Name()) + if err != nil { + return fmt.Errorf("failed to open temp file %s: %w", file.Name(), err) + } + + // dump file into stdout + if _, err = io.Copy(os.Stdout, file); err != nil { + return fmt.Errorf("failed to copy %s to stdout: %w", file.Name(), err) + } + + if err = file.Close(); err != nil { + return fmt.Errorf("failed to close temp file %s: %w", file.Name(), err) + } + + if err = os.Remove(file.Name()); err != nil { + return fmt.Errorf("failed to remove temp file %s: %w", file.Name(), err) + } + + return nil + }) + + s.complete = true + s.stats.Add(stats.Traversed, 1) + + return 1, io.EOF +} + +func (s StdinReader) Close() error { + return nil +} + +func NewStdinReader(root string, path string, statz *stats.Stats) Reader { + return StdinReader{ + root: root, + path: path, + stats: *statz, + input: os.Stdin, + } +} diff --git a/walk/type_enum.go b/walk/type_enum.go index f888fee8..ba377ab9 100644 --- a/walk/type_enum.go +++ b/walk/type_enum.go @@ -7,11 +7,11 @@ import ( "strings" ) -const _TypeName = "autogitfilesystem" +const _TypeName = "autogitfilesystemstdin" -var _TypeIndex = [...]uint8{0, 4, 7, 17} +var _TypeIndex = [...]uint8{0, 4, 7, 17, 22} -const _TypeLowerName = "autogitfilesystem" +const _TypeLowerName = "autogitfilesystemstdin" func (i Type) String() string { if i < 0 || i >= Type(len(_TypeIndex)-1) { @@ -27,23 +27,27 @@ func _TypeNoOp() { _ = x[Auto-(0)] _ = x[Git-(1)] _ = x[Filesystem-(2)] + _ = x[Stdin-(3)] } -var _TypeValues = []Type{Auto, Git, Filesystem} +var _TypeValues = []Type{Auto, Git, Filesystem, Stdin} var _TypeNameToValueMap = map[string]Type{ - _TypeName[0:4]: Auto, - _TypeLowerName[0:4]: Auto, - _TypeName[4:7]: Git, - _TypeLowerName[4:7]: Git, - _TypeName[7:17]: Filesystem, - _TypeLowerName[7:17]: Filesystem, + _TypeName[0:4]: Auto, + _TypeLowerName[0:4]: Auto, + _TypeName[4:7]: Git, + _TypeLowerName[4:7]: Git, + _TypeName[7:17]: Filesystem, + _TypeLowerName[7:17]: Filesystem, + _TypeName[17:22]: Stdin, + _TypeLowerName[17:22]: Stdin, } var _TypeNames = []string{ _TypeName[0:4], _TypeName[4:7], _TypeName[7:17], + _TypeName[17:22], } // TypeString retrieves an enum value from the enum constants string name. diff --git a/walk/walk.go b/walk/walk.go index 86c3c8ee..7805a438 100644 --- a/walk/walk.go +++ b/walk/walk.go @@ -2,9 +2,12 @@ package walk import ( "context" + "errors" "fmt" + "io" "io/fs" "os" + "path/filepath" "time" "github.com/numtide/treefmt/stats" @@ -19,12 +22,14 @@ const ( Auto Type = iota Git Filesystem + Stdin BatchSize = 1024 ) -// File represents a file object with its path, relative path, file info, and potential cached entry. -// It provides an optional release function to trigger a cache update after processing. +type ReleaseFunc func() error + +// File represents a file object with its path, relative path, file info, and potential cache entry. type File struct { Path string RelPath string @@ -33,14 +38,28 @@ type File struct { // Cache is the latest entry found for this file, if one exists. Cache *cache.Entry - // An optional function to be invoked when this File has finished processing. - // Typically used to trigger a cache update. - Release func() + releaseFuncs []ReleaseFunc +} + +// Release invokes all registered release functions for the File. +// If any release function returns an error, Release stops and returns that error. +func (f *File) Release() error { + for _, fn := range f.releaseFuncs { + if err := fn(); err != nil { + return err + } + } + return nil +} + +// AddReleaseFunc adds a release function to the File's list of release functions. +func (f *File) AddReleaseFunc(fn ReleaseFunc) { + f.releaseFuncs = append(f.releaseFuncs, fn) } // Stat checks if the file has changed by comparing its current state (size, mod time) to when it was first read. // It returns a boolean indicating if the file has changed, the current file info, and an error if any. -func (f File) Stat() (bool, fs.FileInfo, error) { +func (f *File) Stat() (changed bool, info fs.FileInfo, err error) { // Get the file's current state current, err := os.Stat(f.Path) if err != nil { @@ -64,7 +83,7 @@ func (f File) Stat() (bool, fs.FileInfo, error) { } // String returns the file's path as a string. -func (f File) String() string { +func (f *File) String() string { return f.Path } @@ -74,11 +93,56 @@ type Reader interface { Close() error } -// NewReader creates a new instance of Reader based on the given walkType (Auto, Git, Filesystem). +// CompositeReader combines multiple Readers into one. +// It iterates over the given readers, reading each until completion. +type CompositeReader struct { + idx int + current Reader + readers []Reader +} + +func (c *CompositeReader) Read(ctx context.Context, files []*File) (n int, err error) { + if c.current == nil { + // check if we have exhausted all the readers + if c.idx >= len(c.readers) { + return 0, io.EOF + } + + // if not, select the next reader + c.current = c.readers[c.idx] + c.idx++ + } + + // attempt a read + n, err = c.current.Read(ctx, files) + + // check if the current reader has been exhausted + if errors.Is(err, io.EOF) { + // reset the error if it's EOF + err = nil + // set the current reader to nil so we try to read from the next reader on the next call + c.current = nil + } else if err != nil { + err = fmt.Errorf("failed to read from current reader: %w", err) + } + + // return the number of files read in this call and any error + return n, err +} + +func (c *CompositeReader) Close() error { + for _, reader := range c.readers { + if err := reader.Close(); err != nil { + return fmt.Errorf("failed to close reader: %w", err) + } + } + return nil +} + func NewReader( walkType Type, root string, - paths []string, + path string, db *bolt.DB, statz *stats.Stats, ) (Reader, error) { @@ -90,15 +154,17 @@ func NewReader( switch walkType { case Auto: // for now, we keep it simple and try git first, filesystem second - reader, err = NewReader(Git, root, paths, db, statz) + reader, err = NewReader(Git, root, path, db, statz) if err != nil { - reader, err = NewReader(Filesystem, root, paths, db, statz) + reader, err = NewReader(Filesystem, root, path, db, statz) } return reader, err case Git: - reader, err = NewGitReader(root, paths, statz, BatchSize) + reader, err = NewGitReader(root, path, statz, BatchSize) case Filesystem: - reader = NewFilesystemReader(root, paths, statz, BatchSize) + reader = NewFilesystemReader(root, path, statz, BatchSize) + case Stdin: + return nil, fmt.Errorf("stdin walk type is not supported") default: return nil, fmt.Errorf("unknown walk type: %v", walkType) } @@ -115,3 +181,60 @@ func NewReader( return reader, err } + +func NewCompositeReader( + walkType Type, + root string, + paths []string, + db *bolt.DB, + statz *stats.Stats, +) (Reader, error) { + // if not paths are provided we default to processing the tree root + if len(paths) == 0 { + return NewReader(walkType, root, "", db, statz) + } + + readers := make([]Reader, len(paths)) + + // check we have received 1 path for the stdin walk type + if walkType == Stdin { + if len(paths) != 1 { + return nil, fmt.Errorf("stdin walk requires exactly one path") + } + + return NewStdinReader(root, paths[0], statz), nil + } + + // create a reader for each provided path + for idx, relPath := range paths { + var ( + err error + info os.FileInfo + ) + + // create a clean absolute path + path := filepath.Clean(filepath.Join(root, relPath)) + + // check the path exists + info, err = os.Lstat(path) + if err != nil { + return nil, fmt.Errorf("failed to stat %s: %w", path, err) + } + + if info.IsDir() { + // for directories, we honour the walk type as we traverse them + readers[idx], err = NewReader(walkType, root, relPath, db, statz) + } else { + // for files, we enforce a simple filesystem read + readers[idx], err = NewReader(Filesystem, root, relPath, db, statz) + } + + if err != nil { + return nil, fmt.Errorf("failed to create reader for %s: %w", relPath, err) + } + } + + return &CompositeReader{ + readers: readers, + }, nil +}