Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 23 additions & 9 deletions pkg/archive/archive.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,27 @@ import (
"github.com/chainguard-dev/malcontent/pkg/programkind"
)

// Size limits shared by the extractors in this package.
const (
	// bufferSize is the length of every pooled copy buffer: 32 KiB.
	bufferSize = 32 << 10
	// maxBytes is the cap handed to io.LimitReader when copying a single
	// extracted file: 512 MiB.
	maxBytes = 512 << 20
)

// bufferPool hands out reusable scratch buffers for io.CopyBuffer calls.
// Entries are *[]byte, so callers must type-assert the value returned by Get
// and dereference it before passing it to io.CopyBuffer.
var bufferPool = sync.Pool{
	New: func() any {
		scratch := make([]byte, bufferSize)
		return &scratch
	},
}

// IsValidPath reports whether target, once cleaned, is dir itself or a path
// located beneath dir.
//
// The comparison is made on path components rather than on a raw string
// prefix, so a sibling such as "/tmp/out-evil" is not mistaken for a child of
// "/tmp/out", and "../x" is not accepted for dir ".". The check is purely
// lexical: symbolic links are not resolved.
func IsValidPath(target, dir string) bool {
	rel, err := filepath.Rel(dir, target)
	if err != nil {
		// Rel fails when target cannot be expressed relative to dir (for
		// example one path is absolute and the other is not); such a target
		// cannot be shown to live inside dir.
		return false
	}
	// Any relative path that has to climb out of dir starts with a ".."
	// component. Names that merely begin with two dots ("..data") are fine.
	return rel != ".." && !strings.HasPrefix(rel, ".."+string(filepath.Separator))
}

const maxBytes = 1 << 29 // 512MB

func extractNestedArchive(
ctx context.Context,
d string,
f string,
extracted *sync.Map,
) error {
func extractNestedArchive(ctx context.Context, d string, f string, extracted *sync.Map) error {
isArchive := false
// zlib-compressed files are also archives
ft, err := programkind.File(f)
Expand Down Expand Up @@ -223,6 +231,12 @@ func handleDirectory(target string) error {

// handleFile extracts valid files within .deb or .tar archives.
func handleFile(target string, tr *tar.Reader) error {
buf, ok := bufferPool.Get().(*[]byte)
if !ok {
return fmt.Errorf("failed to retrieve buffer")
}
defer bufferPool.Put(buf)

if err := os.MkdirAll(filepath.Dir(target), 0o700); err != nil {
return fmt.Errorf("failed to create parent directory: %w", err)
}
Expand All @@ -233,7 +247,7 @@ func handleFile(target string, tr *tar.Reader) error {
}
defer out.Close()

written, err := io.Copy(out, io.LimitReader(tr, maxBytes))
written, err := io.CopyBuffer(out, io.LimitReader(tr, maxBytes), *buf)
if err != nil {
return fmt.Errorf("failed to copy file: %w", err)
}
Expand Down
8 changes: 7 additions & 1 deletion pkg/archive/bz2.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,12 @@ func ExtractBz2(ctx context.Context, d, f string) error {
logger := clog.FromContext(ctx).With("dir", d, "file", f)
logger.Debug("extracting bzip2 file")

buf, ok := bufferPool.Get().(*[]byte)
if !ok {
return fmt.Errorf("failed to retrieve buffer")
}
defer bufferPool.Put(buf)

// Check if the file is valid
_, err := os.Stat(f)
if err != nil {
Expand Down Expand Up @@ -53,7 +59,7 @@ func ExtractBz2(ctx context.Context, d, f string) error {
}
defer out.Close()

written, err := io.Copy(out, io.LimitReader(br, maxBytes))
written, err := io.CopyBuffer(out, io.LimitReader(br, maxBytes), *buf)
if err != nil {
return fmt.Errorf("failed to copy file: %w", err)
}
Expand Down
6 changes: 6 additions & 0 deletions pkg/archive/deb.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,12 @@ func ExtractDeb(ctx context.Context, d, f string) error {
logger := clog.FromContext(ctx).With("dir", d, "file", f)
logger.Debug("extracting deb")

buf, ok := bufferPool.Get().(*[]byte)
if !ok {
return fmt.Errorf("failed to retrieve buffer")
}
defer bufferPool.Put(buf)

fd, err := os.Open(f)
if err != nil {
panic(err)
Expand Down
8 changes: 7 additions & 1 deletion pkg/archive/gzip.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,12 @@ func ExtractGzip(ctx context.Context, d string, f string) error {
logger := clog.FromContext(ctx).With("dir", d, "file", f)
logger.Debug("extracting gzip")

buf, ok := bufferPool.Get().(*[]byte)
if !ok {
return fmt.Errorf("failed to retrieve buffer")
}
defer bufferPool.Put(buf)

// Check if the file is valid
_, err := os.Stat(f)
if err != nil {
Expand Down Expand Up @@ -59,7 +65,7 @@ func ExtractGzip(ctx context.Context, d string, f string) error {
}
defer out.Close()

written, err := io.Copy(out, io.LimitReader(gr, maxBytes))
written, err := io.CopyBuffer(out, io.LimitReader(gr, maxBytes), *buf)
if err != nil {
return fmt.Errorf("failed to copy file: %w", err)
}
Expand Down
8 changes: 7 additions & 1 deletion pkg/archive/rpm.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@ func ExtractRPM(ctx context.Context, d, f string) error {
logger := clog.FromContext(ctx).With("dir", d, "file", f)
logger.Debug("extracting rpm")

buf, ok := bufferPool.Get().(*[]byte)
if !ok {
return fmt.Errorf("failed to retrieve buffer")
}
defer bufferPool.Put(buf)

rpmFile, err := os.Open(f)
if err != nil {
return fmt.Errorf("failed to open RPM file: %w", err)
Expand Down Expand Up @@ -106,7 +112,7 @@ func ExtractRPM(ctx context.Context, d, f string) error {
return fmt.Errorf("failed to create file: %w", err)
}

written, err := io.Copy(out, io.LimitReader(cr, maxBytes))
written, err := io.CopyBuffer(out, io.LimitReader(cr, maxBytes), *buf)
if err != nil {
return fmt.Errorf("failed to copy file: %w", err)
}
Expand Down
12 changes: 11 additions & 1 deletion pkg/archive/tar.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@ func ExtractTar(ctx context.Context, d string, f string) error {
logger := clog.FromContext(ctx).With("dir", d, "file", f)
logger.Debug("extracting tar")

buf, ok := bufferPool.Get().(*[]byte)
if !ok {
return fmt.Errorf("failed to retrieve buffer")
}
defer bufferPool.Put(buf)

// Check if the file is valid
_, err := os.Stat(f)
if err != nil {
Expand Down Expand Up @@ -83,9 +89,13 @@ func ExtractTar(ctx context.Context, d string, f string) error {
}
defer out.Close()

if _, err = io.Copy(out, xzStream); err != nil {
written, err := io.CopyBuffer(out, io.LimitReader(xzStream, maxBytes), *buf)
if err != nil {
return fmt.Errorf("failed to write decompressed xz output: %w", err)
}
if written >= maxBytes {
return fmt.Errorf("file exceeds maximum allowed size (%d bytes): %s", maxBytes, target)
}
return nil
case strings.Contains(filename, ".tar.bz2") || strings.Contains(filename, ".tbz"):
br := bzip2.NewReader(tf)
Expand Down
6 changes: 6 additions & 0 deletions pkg/archive/upx.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,12 @@ func ExtractUPX(ctx context.Context, d, f string) error {
logger := clog.FromContext(ctx).With("dir", d, "file", f)
logger.Debug("extracting upx")

buf, ok := bufferPool.Get().(*[]byte)
if !ok {
return fmt.Errorf("failed to retrieve buffer")
}
defer bufferPool.Put(buf)

// Check if the file is valid
_, err := os.Stat(f)
if err != nil {
Expand Down
125 changes: 74 additions & 51 deletions pkg/archive/zip.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,86 +7,109 @@ import (
"io"
"os"
"path/filepath"
"runtime"
"strings"

"github.com/chainguard-dev/clog"
"golang.org/x/sync/errgroup"
)

// ExtractZip extracts .jar and .zip archives.
//
// It stats and opens the archive at f, creates the destination directory d,
// and hands every archive entry to extractFile on a worker group limited to
// runtime.GOMAXPROCS(0) goroutines. A stat failure, an empty file, or an
// archive that cannot be opened is reported before any extraction starts;
// otherwise the error returned by the worker group is wrapped and returned.
func ExtractZip(ctx context.Context, d string, f string) error {
	logger := clog.FromContext(ctx).With("dir", d, "file", f)
	logger.Debug("extracting zip")

	// Check if the file is valid
	fi, err := os.Stat(f)
	if err != nil {
		return fmt.Errorf("failed to stat file %s: %w", f, err)
	}
	if fi.Size() == 0 {
		return fmt.Errorf("empty zip file: %s", f)
	}

	read, err := zip.OpenReader(f)
	if err != nil {
		return fmt.Errorf("failed to open zip file %s: %w", f, err)
	}
	defer read.Close()

	if err := os.MkdirAll(d, 0o700); err != nil {
		return fmt.Errorf("failed to create extraction directory: %w", err)
	}

	g, gCtx := errgroup.WithContext(ctx)
	g.SetLimit(runtime.GOMAXPROCS(0))

	for i := range read.File {
		// Bind the entry to a per-iteration variable so each goroutine
		// extracts its own file regardless of the toolchain's loop-variable
		// semantics.
		entry := read.File[i]
		g.Go(func() error {
			return extractFile(gCtx, entry, d, logger)
		})
	}

	if err := g.Wait(); err != nil {
		return fmt.Errorf("extraction failed: %w", err)
	}
	return nil
}

// Check if a directory with the same name exists
if info, err := os.Stat(target); err == nil && info.IsDir() {
continue
}
func extractFile(ctx context.Context, file *zip.File, destDir string, logger *clog.Logger) error {
buf, ok := bufferPool.Get().(*[]byte)
if !ok {
return fmt.Errorf("failed to retrieve buffer")
}
defer bufferPool.Put(buf)

if file.Mode().IsDir() {
err := os.MkdirAll(target, 0o700)
if err != nil {
return fmt.Errorf("failed to create directory: %w", err)
}
continue
}
clean := filepath.Clean(filepath.ToSlash(file.Name))
if strings.Contains(clean, "..") {
logger.Warnf("skipping potentially unsafe file path: %s", file.Name)
return nil
}

zf, err := file.Open()
if err != nil {
return fmt.Errorf("failed to open file in zip: %w", err)
}
target := filepath.Join(destDir, clean)
if !IsValidPath(target, destDir) {
logger.Warnf("skipping file path outside extraction directory: %s", target)
return nil
}

err = os.MkdirAll(filepath.Dir(target), 0o700)
if err != nil {
zf.Close()
return fmt.Errorf("failed to create directory: %w", err)
}
select {
case <-ctx.Done():
return ctx.Err()
default:
}

out, err := os.OpenFile(target, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o600)
if err != nil {
out.Close()
return fmt.Errorf("failed to create file: %w", err)
}
if file.Mode().IsDir() {
return os.MkdirAll(target, 0o700)
}

written, err := io.Copy(out, io.LimitReader(zf, maxBytes))
if err != nil {
return fmt.Errorf("failed to copy file: %w", err)
}
if written >= maxBytes {
return fmt.Errorf("file exceeds maximum allowed size (%d bytes): %s", maxBytes, target)
}
if err := os.MkdirAll(filepath.Dir(target), 0o700); err != nil {
return fmt.Errorf("failed to create directory structure: %w", err)
}

if err := out.Close(); err != nil {
return fmt.Errorf("failed to close file: %w", err)
}
src, err := file.Open()
if err != nil {
return fmt.Errorf("failed to open archived file: %w", err)
}
defer src.Close()

if err := zf.Close(); err != nil {
return fmt.Errorf("failed to close file: %w", err)
dst, err := os.OpenFile(target, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o600)
if err != nil {
return fmt.Errorf("failed to create destination file: %w", err)
}

var closeErr error
defer func() {
if cerr := dst.Close(); cerr != nil && closeErr == nil {
closeErr = cerr
}
}()

written, err := io.CopyBuffer(dst, io.LimitReader(src, maxBytes), *buf)
if err != nil {
return fmt.Errorf("failed to copy file contents: %w", err)
}
return nil
if written >= maxBytes {
return fmt.Errorf("file exceeds maximum allowed size (%d bytes): %s", maxBytes, target)
}

return closeErr
}
8 changes: 7 additions & 1 deletion pkg/archive/zlib.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,12 @@ func ExtractZlib(ctx context.Context, d string, f string) error {
logger := clog.FromContext(ctx).With("dir", d, "file", f)
logger.Debugf("extracting zlib")

buf, ok := bufferPool.Get().(*[]byte)
if !ok {
return fmt.Errorf("failed to retrieve buffer")
}
defer bufferPool.Put(buf)

// Check if the file is valid
_, err := os.Stat(f)
if err != nil {
Expand Down Expand Up @@ -43,7 +49,7 @@ func ExtractZlib(ctx context.Context, d string, f string) error {
}
defer out.Close()

written, err := io.Copy(out, io.LimitReader(zr, maxBytes))
written, err := io.CopyBuffer(out, io.LimitReader(zr, maxBytes), *buf)
if err != nil {
return fmt.Errorf("failed to copy file: %w", err)
}
Expand Down
Loading