diff --git a/pkg/file/tarutil.go b/pkg/file/tarutil.go index 374c60f6..2b3ed217 100644 --- a/pkg/file/tarutil.go +++ b/pkg/file/tarutil.go @@ -6,8 +6,10 @@ import ( "io" "os" "path/filepath" + "strings" "github.com/pkg/errors" + "github.com/spf13/afero" "github.com/anchore/stereoscope/internal/log" ) @@ -124,40 +126,61 @@ func MetadataFromTar(reader io.ReadCloser, tarPath string) (Metadata, error) { return *metadata, nil } -// UntarToDirectory writes the contents of the given tar reader to the given destination +// UntarToDirectory writes the contents of the given tar reader to the given destination. Note: this is meant to handle +// archives for images (not image contents) thus intentionally does not handle links or any kinds of special files. func UntarToDirectory(reader io.Reader, dst string) error { - visitor := func(entry TarFileEntry) error { - target := filepath.Join(dst, entry.Header.Name) - - switch entry.Header.Typeflag { - case tar.TypeDir: - if _, err := os.Stat(target); err != nil { - if err := os.MkdirAll(target, 0755); err != nil { - return err - } - } + return IterateTar( + reader, + tarVisitor{ + fs: afero.NewOsFs(), + destination: dst, + }.visit, + ) +} + +type tarVisitor struct { + fs afero.Fs + destination string +} + +func (v tarVisitor) visit(entry TarFileEntry) error { + target := filepath.Join(v.destination, entry.Header.Name) - case tar.TypeReg: - f, err := os.OpenFile(target, os.O_CREATE|os.O_RDWR, os.FileMode(entry.Header.Mode)) - if err != nil { + // we should not allow for any destination path to be outside of where we are unarchiving to + if !strings.HasPrefix(target, v.destination) { + return fmt.Errorf("potential path traversal attack with entry: %q", entry.Header.Name) + } + + switch entry.Header.Typeflag { + case tar.TypeSymlink, tar.TypeLink: + // we don't handle this is to prevent any potential traversal attacks + log.WithFields("path", entry.Header.Name).Trace("skipping symlink/link entry in image tar") + + case tar.TypeDir: + if _, err := v.fs.Stat(target); err != nil { + if err := v.fs.MkdirAll(target, 0755); err != nil { return err } + } - // limit the reader on each file read to prevent decompression bomb attacks - numBytes, err := io.Copy(f, io.LimitReader(entry.Reader, perFileReadLimit)) - if numBytes >= perFileReadLimit || errors.Is(err, io.EOF) { - return fmt.Errorf("zip read limit hit (potential decompression bomb attack)") - } - if err != nil { - return fmt.Errorf("unable to copy file: %w", err) - } + case tar.TypeReg: + f, err := v.fs.OpenFile(target, os.O_CREATE|os.O_RDWR, os.FileMode(entry.Header.Mode)) + if err != nil { + return err + } - if err = f.Close(); err != nil { - log.Errorf("failed to close file during untar of path=%q: %w", f.Name(), err) - } + // limit the reader on each file read to prevent decompression bomb attacks + numBytes, err := io.Copy(f, io.LimitReader(entry.Reader, perFileReadLimit)) + if numBytes >= perFileReadLimit || errors.Is(err, io.EOF) { + return fmt.Errorf("zip read limit hit (potential decompression bomb attack)") + } + if err != nil { + return fmt.Errorf("unable to copy file: %w", err) } - return nil - } - return IterateTar(reader, visitor) + if err = f.Close(); err != nil { + log.Errorf("failed to close file during untar of path=%q: %w", f.Name(), err) + } + } + return nil } diff --git a/pkg/file/tarutil_test.go b/pkg/file/tarutil_test.go index 036c0c5c..9e6032a0 100644 --- a/pkg/file/tarutil_test.go +++ b/pkg/file/tarutil_test.go @@ -4,6 +4,7 @@ package file import ( + "archive/tar" "crypto/sha256" "fmt" "io" @@ -11,10 +12,16 @@ import ( "os/exec" "path" "path/filepath" + "sort" + "strings" "testing" "time" + "github.com/google/go-cmp/cmp" + "github.com/scylladb/go-set/strset" + "github.com/spf13/afero" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) const ( @@ -181,3 +188,217 @@ func fileExists(t testing.TB, filename string) bool { } return !info.IsDir() } + +func Test_tarVisitor_visit(t *testing.T) { + assertNoFilesInRoot := func(t testing.TB, fs afero.Fs) { + t.Helper() + + allowableFiles := strset.New("tmp") + + // list all files in root + files, err := afero.ReadDir(fs, "/") + require.NoError(t, err) + + for _, f := range files { + assert.True(t, allowableFiles.Has(f.Name()), "unexpected file in root: %s", f.Name()) + } + } + + assertPaths := func(expectedFiles []string, expectedDirs []string) func(t testing.TB, fs afero.Fs) { + return func(t testing.TB, fs afero.Fs) { + t.Helper() + + sort.Strings(expectedFiles) + haveFiles := strset.New() + haveDirs := strset.New() + err := afero.Walk(fs, "/", func(path string, info os.FileInfo, err error) error { + require.NoError(t, err) + if info.IsDir() { + haveDirs.Add(path) + } else { + haveFiles.Add(path) + } + return nil + }) + + haveFilesList := haveFiles.List() + sort.Strings(haveFilesList) + + haveDirsList := haveDirs.List() + sort.Strings(haveDirsList) + + require.NoError(t, err) + + if d := cmp.Diff(expectedFiles, haveFilesList); d != "" { + t.Errorf("unexpected files (-want +got):\n%s", d) + } + + if d := cmp.Diff(expectedDirs, haveDirsList); d != "" { + t.Errorf("unexpected dirs (-want +got):\n%s", d) + } + + } + } + + tests := []struct { + name string + entry TarFileEntry + wantErr require.ErrorAssertionFunc + assertFs []func(t testing.TB, fs afero.Fs) + }{ + { + name: "regular file is written", + entry: TarFileEntry{ + Sequence: 0, + Header: tar.Header{ + Typeflag: tar.TypeReg, + Name: "file.txt", + Linkname: "", + Size: 2, + }, + Reader: strings.NewReader("hi"), + }, + assertFs: []func(t testing.TB, fs afero.Fs){ + assertPaths( + []string{"/tmp/file.txt"}, + []string{"/", "/tmp"}, + ), + }, + }, + { + name: "regular file with possible path traversal errors out", + entry: TarFileEntry{ + Sequence: 0, + Header: tar.Header{ + Typeflag: tar.TypeReg, + Name: "../file.txt", + Linkname: "", + Size: 2, + }, + Reader: strings.NewReader("hi"), + }, + wantErr: require.Error, + }, + { + name: "directory is created", + entry: TarFileEntry{ + Sequence: 0, + Header: tar.Header{ + Typeflag: tar.TypeDir, + Name: "dir", + Linkname: "", + }, + }, + assertFs: []func(t testing.TB, fs afero.Fs){ + assertPaths( + []string{}, + []string{"/", "/tmp", "/tmp/dir"}, + ), + }, + }, + { + name: "symlink is ignored", + entry: TarFileEntry{ + Sequence: 0, + Header: tar.Header{ + Typeflag: tar.TypeSymlink, + Name: "symlink", + Linkname: "./../to-location", + }, + }, + assertFs: []func(t testing.TB, fs afero.Fs){ + assertPaths( + []string{}, + []string{"/"}, + ), + }, + }, + { + name: "hardlink is ignored", + entry: TarFileEntry{ + Sequence: 0, + Header: tar.Header{ + Typeflag: tar.TypeLink, + Name: "link", + Linkname: "./../to-location", + }, + }, + assertFs: []func(t testing.TB, fs afero.Fs){ + assertPaths( + []string{}, + []string{"/"}, + ), + }, + }, + { + name: "device is ignored", + entry: TarFileEntry{ + Sequence: 0, + Header: tar.Header{ + Typeflag: tar.TypeChar, + Name: "device", + }, + }, + assertFs: []func(t testing.TB, fs afero.Fs){ + assertPaths( + []string{}, + []string{"/"}, + ), + }, + }, + { + name: "block device is ignored", + entry: TarFileEntry{ + Sequence: 0, + Header: tar.Header{ + Typeflag: tar.TypeBlock, + Name: "device", + }, + }, + assertFs: []func(t testing.TB, fs afero.Fs){ + assertPaths( + []string{}, + []string{"/"}, + ), + }, + }, + { + name: "pipe is ignored", + entry: TarFileEntry{ + Sequence: 0, + Header: tar.Header{ + Typeflag: tar.TypeFifo, + Name: "pipe", + }, + }, + assertFs: []func(t testing.TB, fs afero.Fs){ + assertPaths( + []string{}, + []string{"/"}, + ), + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.wantErr == nil { + tt.wantErr = require.NoError + } + v := tarVisitor{ + fs: afero.NewMemMapFs(), + destination: "/tmp", + } + err := v.visit(tt.entry) + tt.wantErr(t, err) + if err != nil { + return + } + for _, fn := range tt.assertFs { + fn(t, v.fs) + } + + // even if the test has no other assertions, check that the root is empty + assertNoFilesInRoot(t, v.fs) + }) + } +}