From f5d0f429e58f58c6f003af4d77a7015de468f7d9 Mon Sep 17 00:00:00 2001 From: egibs <20933572+egibs@users.noreply.github.com> Date: Sat, 25 Jan 2025 10:38:57 -0600 Subject: [PATCH 1/4] Improve efficiency and performance of zip extractions Signed-off-by: egibs <20933572+egibs@users.noreply.github.com> --- pkg/archive/zip.go | 134 ++++++++++++++++++++++++++++----------------- 1 file changed, 83 insertions(+), 51 deletions(-) diff --git a/pkg/archive/zip.go b/pkg/archive/zip.go index 619ae6dde..02d65b157 100644 --- a/pkg/archive/zip.go +++ b/pkg/archive/zip.go @@ -7,21 +7,34 @@ import ( "io" "os" "path/filepath" + "runtime" "strings" + "sync" "github.com/chainguard-dev/clog" + "golang.org/x/sync/errgroup" ) -// extractZip extracts .jar and .zip archives. +const bufferSize = 32 * 1024 + +var bufferPool = sync.Pool{ + New: func() interface{} { + b := make([]byte, bufferSize) + return &b + }, +} + func ExtractZip(ctx context.Context, d string, f string) error { logger := clog.FromContext(ctx).With("dir", d, "file", f) logger.Debug("extracting zip") - // Check if the file is valid - _, err := os.Stat(f) + fi, err := os.Stat(f) if err != nil { return fmt.Errorf("failed to stat file %s: %w", f, err) } + if fi.Size() == 0 { + return fmt.Errorf("empty zip file: %s", f) + } read, err := zip.OpenReader(f) if err != nil { @@ -29,64 +42,83 @@ func ExtractZip(ctx context.Context, d string, f string) error { } defer read.Close() + if err := os.MkdirAll(d, 0o700); err != nil { + return fmt.Errorf("failed to create extraction directory: %w", err) + } + + g, ctx := errgroup.WithContext(ctx) + g.SetLimit(runtime.GOMAXPROCS(0)) + for _, file := range read.File { - clean := filepath.Clean(filepath.ToSlash(file.Name)) - if strings.Contains(clean, "..") { - logger.Warnf("skipping potentially unsafe file path: %s", file.Name) - continue - } + g.Go(func() error { + return extractFile(ctx, file, d, logger) + }) + } - target := filepath.Join(d, clean) - if !IsValidPath(target, d) { - logger.Warnf("skipping file path outside extraction directory: %s", target) - continue - } + if err := g.Wait(); err != nil { + return fmt.Errorf("extraction failed: %w", err) + } + return nil +} - // Check if a directory with the same name exists - if info, err := os.Stat(target); err == nil && info.IsDir() { - continue - } +func extractFile(ctx context.Context, file *zip.File, destDir string, logger *clog.Logger) error { + buf, ok := bufferPool.Get().(*[]byte) + if !ok { + return fmt.Errorf("failed to retrieve buffer") + } + defer bufferPool.Put(buf) - if file.Mode().IsDir() { - err := os.MkdirAll(target, 0o700) - if err != nil { - return fmt.Errorf("failed to create directory: %w", err) - } - continue - } + clean := filepath.Clean(filepath.ToSlash(file.Name)) + if strings.Contains(clean, "..") { + logger.Warnf("skipping potentially unsafe file path: %s", file.Name) + return nil + } - zf, err := file.Open() - if err != nil { - return fmt.Errorf("failed to open file in zip: %w", err) - } + target := filepath.Join(destDir, clean) + if !IsValidPath(target, destDir) { + logger.Warnf("skipping file path outside extraction directory: %s", target) + return nil + } - err = os.MkdirAll(filepath.Dir(target), 0o700) - if err != nil { - zf.Close() - return fmt.Errorf("failed to create directory: %w", err) - } + select { + case <-ctx.Done(): + return ctx.Err() + default: + } - out, err := os.OpenFile(target, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o600) - if err != nil { - out.Close() - return fmt.Errorf("failed to create file: %w", err) - } + if file.Mode().IsDir() { + return os.MkdirAll(target, 0o700) + } - written, err := io.Copy(out, io.LimitReader(zf, maxBytes)) - if err != nil { - return fmt.Errorf("failed to copy file: %w", err) - } - if written >= maxBytes { - return fmt.Errorf("file exceeds maximum allowed size (%d bytes): %s", maxBytes, target) - } + if err := os.MkdirAll(filepath.Dir(target), 0o700); err != nil { + return fmt.Errorf("failed to create directory structure: %w", err) + } - if err := out.Close(); err != nil { - return fmt.Errorf("failed to close file: %w", err) - } + src, err := file.Open() + if err != nil { + return fmt.Errorf("failed to open archived file: %w", err) + } + defer src.Close() - if err := zf.Close(); err != nil { - return fmt.Errorf("failed to close file: %w", err) + dst, err := os.OpenFile(target, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o600) + if err != nil { + return fmt.Errorf("failed to create destination file: %w", err) + } + + var closeErr error + defer func() { + if cerr := dst.Close(); cerr != nil && closeErr == nil { + closeErr = cerr } + }() + + written, err := io.CopyBuffer(dst, io.LimitReader(src, maxBytes), *buf) + if err != nil { + return fmt.Errorf("failed to copy file contents: %w", err) } - return nil + if written >= maxBytes { + return fmt.Errorf("file exceeds maximum allowed size (%d bytes): %s", maxBytes, target) + } + + return closeErr } From aaa74d34237dbecd15eefa5d1da33e2d9a5fa5f4 Mon Sep 17 00:00:00 2001 From: egibs <20933572+egibs@users.noreply.github.com> Date: Sat, 25 Jan 2025 12:41:33 -0600 Subject: [PATCH 2/4] Use buffer for all io.Copy operations; add limit reader for .xz files Signed-off-by: egibs <20933572+egibs@users.noreply.github.com> --- pkg/archive/archive.go | 18 +++++++++++++++++- pkg/archive/bz2.go | 8 +++++++- pkg/archive/deb.go | 6 ++++++ pkg/archive/gzip.go | 8 +++++++- pkg/archive/rpm.go | 8 +++++++- pkg/archive/tar.go | 12 +++++++++++- pkg/archive/upx.go | 6 ++++++ pkg/archive/zip.go | 15 +++------------ pkg/archive/zlib.go | 8 +++++++- 9 files changed, 71 insertions(+), 18 deletions(-) diff --git a/pkg/archive/archive.go b/pkg/archive/archive.go index d21886427..1859e2c32 100644 --- a/pkg/archive/archive.go +++ b/pkg/archive/archive.go @@ -15,6 +15,16 @@ import ( "github.com/chainguard-dev/malcontent/pkg/programkind" ) +const bufferSize = 32 * 1024 + +// Shared buffer pool for io.CopyBuffer operations. +var bufferPool = sync.Pool{ + New: func() interface{} { + b := make([]byte, bufferSize) + return &b + }, +} + // isValidPath checks if the target file is within the given directory. func IsValidPath(target, dir string) bool { return strings.HasPrefix(filepath.Clean(target), filepath.Clean(dir)) @@ -223,6 +233,12 @@ func handleDirectory(target string) error { // handleFile extracts valid files within .deb or .tar archives. func handleFile(target string, tr *tar.Reader) error { + buf, ok := bufferPool.Get().(*[]byte) + if !ok { + return fmt.Errorf("failed to retrieve buffer") + } + defer bufferPool.Put(buf) + if err := os.MkdirAll(filepath.Dir(target), 0o700); err != nil { return fmt.Errorf("failed to create parent directory: %w", err) } @@ -233,7 +249,7 @@ func handleFile(target string, tr *tar.Reader) error { } defer out.Close() - written, err := io.Copy(out, io.LimitReader(tr, maxBytes)) + written, err := io.CopyBuffer(out, io.LimitReader(tr, maxBytes), *buf) if err != nil { return fmt.Errorf("failed to copy file: %w", err) } diff --git a/pkg/archive/bz2.go b/pkg/archive/bz2.go index d99e5b57a..3dabb434c 100644 --- a/pkg/archive/bz2.go +++ b/pkg/archive/bz2.go @@ -17,6 +17,12 @@ func ExtractBz2(ctx context.Context, d, f string) error { logger := clog.FromContext(ctx).With("dir", d, "file", f) logger.Debug("extracting bzip2 file") + buf, ok := bufferPool.Get().(*[]byte) + if !ok { + return fmt.Errorf("failed to retrieve buffer") + } + defer bufferPool.Put(buf) + // Check if the file is valid _, err := os.Stat(f) if err != nil { @@ -53,7 +59,7 @@ func ExtractBz2(ctx context.Context, d, f string) error { } defer out.Close() - written, err := io.Copy(out, io.LimitReader(br, maxBytes)) + written, err := io.CopyBuffer(out, io.LimitReader(br, maxBytes), *buf) if err != nil { return fmt.Errorf("failed to copy file: %w", err) } diff --git a/pkg/archive/deb.go b/pkg/archive/deb.go index d1c320619..03be06d4e 100644 --- a/pkg/archive/deb.go +++ b/pkg/archive/deb.go @@ -19,6 +19,12 @@ func ExtractDeb(ctx context.Context, d, f string) error { logger := clog.FromContext(ctx).With("dir", d, "file", f) logger.Debug("extracting deb") + buf, ok := bufferPool.Get().(*[]byte) + if !ok { + return fmt.Errorf("failed to retrieve buffer") + } + defer bufferPool.Put(buf) + fd, err := os.Open(f) if err != nil { panic(err) diff --git a/pkg/archive/gzip.go b/pkg/archive/gzip.go index 41976505b..f1e601326 100644 --- a/pkg/archive/gzip.go +++ b/pkg/archive/gzip.go @@ -29,6 +29,12 @@ func ExtractGzip(ctx context.Context, d string, f string) error { logger := clog.FromContext(ctx).With("dir", d, "file", f) logger.Debug("extracting gzip") + buf, ok := bufferPool.Get().(*[]byte) + if !ok { + return fmt.Errorf("failed to retrieve buffer") + } + defer bufferPool.Put(buf) + // Check if the file is valid _, err := os.Stat(f) if err != nil { @@ -59,7 +65,7 @@ func ExtractGzip(ctx context.Context, d string, f string) error { } defer out.Close() - written, err := io.Copy(out, io.LimitReader(gr, maxBytes)) + written, err := io.CopyBuffer(out, io.LimitReader(gr, maxBytes), *buf) if err != nil { return fmt.Errorf("failed to copy file: %w", err) } diff --git a/pkg/archive/rpm.go b/pkg/archive/rpm.go index e0629c6ad..400640d80 100644 --- a/pkg/archive/rpm.go +++ b/pkg/archive/rpm.go @@ -22,6 +22,12 @@ func ExtractRPM(ctx context.Context, d, f string) error { logger := clog.FromContext(ctx).With("dir", d, "file", f) logger.Debug("extracting rpm") + buf, ok := bufferPool.Get().(*[]byte) + if !ok { + return fmt.Errorf("failed to retrieve buffer") + } + defer bufferPool.Put(buf) + rpmFile, err := os.Open(f) if err != nil { return fmt.Errorf("failed to open RPM file: %w", err) @@ -106,7 +112,7 @@ func ExtractRPM(ctx context.Context, d, f string) error { return fmt.Errorf("failed to create file: %w", err) } - written, err := io.Copy(out, io.LimitReader(cr, maxBytes)) + written, err := io.CopyBuffer(out, io.LimitReader(cr, maxBytes), *buf) if err != nil { return fmt.Errorf("failed to copy file: %w", err) } diff --git a/pkg/archive/tar.go b/pkg/archive/tar.go index 60c12c8d0..2eb9fcb55 100644 --- a/pkg/archive/tar.go +++ b/pkg/archive/tar.go @@ -22,6 +22,12 @@ func ExtractTar(ctx context.Context, d string, f string) error { logger := clog.FromContext(ctx).With("dir", d, "file", f) logger.Debug("extracting tar") + buf, ok := bufferPool.Get().(*[]byte) + if !ok { + return fmt.Errorf("failed to retrieve buffer") + } + defer bufferPool.Put(buf) + // Check if the file is valid _, err := os.Stat(f) if err != nil { @@ -83,9 +89,13 @@ func ExtractTar(ctx context.Context, d string, f string) error { } defer out.Close() - if _, err = io.Copy(out, xzStream); err != nil { + written, err := io.CopyBuffer(out, io.LimitReader(xzStream, maxBytes), *buf) + if err != nil { return fmt.Errorf("failed to write decompressed xz output: %w", err) } + if written >= maxBytes { + return fmt.Errorf("file exceeds maximum allowed size (%d bytes): %s", maxBytes, target) + } return nil case strings.Contains(filename, ".tar.bz2") || strings.Contains(filename, ".tbz"): br := bzip2.NewReader(tf) diff --git a/pkg/archive/upx.go b/pkg/archive/upx.go index fad9ea327..d3493ca83 100644 --- a/pkg/archive/upx.go +++ b/pkg/archive/upx.go @@ -20,6 +20,12 @@ func ExtractUPX(ctx context.Context, d, f string) error { logger := clog.FromContext(ctx).With("dir", d, "file", f) logger.Debug("extracting upx") + buf, ok := bufferPool.Get().(*[]byte) + if !ok { + return fmt.Errorf("failed to retrieve buffer") + } + defer bufferPool.Put(buf) + // Check if the file is valid _, err := os.Stat(f) if err != nil { diff --git a/pkg/archive/zip.go b/pkg/archive/zip.go index 02d65b157..ba400f19f 100644 --- a/pkg/archive/zip.go +++ b/pkg/archive/zip.go @@ -9,21 +9,12 @@ import ( "path/filepath" "runtime" "strings" - "sync" "github.com/chainguard-dev/clog" "golang.org/x/sync/errgroup" ) -const bufferSize = 32 * 1024 - -var bufferPool = sync.Pool{ - New: func() interface{} { - b := make([]byte, bufferSize) - return &b - }, -} - +// ExtractZip extracts .jar and .zip archives. func ExtractZip(ctx context.Context, d string, f string) error { logger := clog.FromContext(ctx).With("dir", d, "file", f) logger.Debug("extracting zip") @@ -46,12 +37,12 @@ func ExtractZip(ctx context.Context, d string, f string) error { return fmt.Errorf("failed to create extraction directory: %w", err) } - g, ctx := errgroup.WithContext(ctx) + g, gCtx := errgroup.WithContext(ctx) g.SetLimit(runtime.GOMAXPROCS(0)) for _, file := range read.File { g.Go(func() error { - return extractFile(ctx, file, d, logger) + return extractFile(gCtx, file, d, logger) }) } diff --git a/pkg/archive/zlib.go b/pkg/archive/zlib.go index 2def0b4ce..679c1476a 100644 --- a/pkg/archive/zlib.go +++ b/pkg/archive/zlib.go @@ -16,6 +16,12 @@ func ExtractZlib(ctx context.Context, d string, f string) error { logger := clog.FromContext(ctx).With("dir", d, "file", f) logger.Debugf("extracting zlib") + buf, ok := bufferPool.Get().(*[]byte) + if !ok { + return fmt.Errorf("failed to retrieve buffer") + } + defer bufferPool.Put(buf) + // Check if the file is valid _, err := os.Stat(f) if err != nil { @@ -43,7 +49,7 @@ func ExtractZlib(ctx context.Context, d string, f string) error { } defer out.Close() - written, err := io.Copy(out, io.LimitReader(zr, maxBytes)) + written, err := io.CopyBuffer(out, io.LimitReader(zr, maxBytes), *buf) if err != nil { return fmt.Errorf("failed to copy file: %w", err) } From 2ed60e97d022ac306dc8d4599b6e73e28b10b473 Mon Sep 17 00:00:00 2001 From: egibs <20933572+egibs@users.noreply.github.com> Date: Sat, 25 Jan 2025 15:12:27 -0600 Subject: [PATCH 3/4] Consolidate consts Signed-off-by: egibs <20933572+egibs@users.noreply.github.com> --- pkg/archive/archive.go | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/pkg/archive/archive.go b/pkg/archive/archive.go index 1859e2c32..48ff70b2f 100644 --- a/pkg/archive/archive.go +++ b/pkg/archive/archive.go @@ -15,7 +15,12 @@ import ( "github.com/chainguard-dev/malcontent/pkg/programkind" ) -const bufferSize = 32 * 1024 +const ( + // 32KB buffer. + bufferSize = 32 * 1024 + // 512MB file limit. + maxBytes = 1 << 29 +) // Shared buffer pool for io.CopyBuffer operations. var bufferPool = sync.Pool{ @@ -30,8 +35,6 @@ func IsValidPath(target, dir string) bool { return strings.HasPrefix(filepath.Clean(target), filepath.Clean(dir)) } -const maxBytes = 1 << 29 // 512MB - func extractNestedArchive( ctx context.Context, d string, From 54864e1951c9ffc88a815a0d75f8e02daedb6458 Mon Sep 17 00:00:00 2001 From: egibs <20933572+egibs@users.noreply.github.com> Date: Sat, 25 Jan 2025 15:23:34 -0600 Subject: [PATCH 4/4] Keep parameters on a single line Signed-off-by: egibs <20933572+egibs@users.noreply.github.com> --- pkg/archive/archive.go | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/pkg/archive/archive.go b/pkg/archive/archive.go index 48ff70b2f..9cbe63788 100644 --- a/pkg/archive/archive.go +++ b/pkg/archive/archive.go @@ -35,12 +35,7 @@ func IsValidPath(target, dir string) bool { return strings.HasPrefix(filepath.Clean(target), filepath.Clean(dir)) } -func extractNestedArchive( - ctx context.Context, - d string, - f string, - extracted *sync.Map, -) error { +func extractNestedArchive(ctx context.Context, d string, f string, extracted *sync.Map) error { isArchive := false // zlib-compressed files are also archives ft, err := programkind.File(f)