Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 14 additions & 17 deletions extract_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"context"
"fmt"
"io"
"io/ioutil"
"net/http"
"os"
"path/filepath"
Expand Down Expand Up @@ -187,9 +186,9 @@ func TestArchiveFailure(t *testing.T) {

func TestExtract(t *testing.T) {
for _, test := range ExtractCases {
dir, _ := ioutil.TempDir("", "")
dir, _ := os.MkdirTemp("", "")
dir = filepath.Join(dir, "test")
data, err := ioutil.ReadFile(test.Archive)
data, err := os.ReadFile(test.Archive)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -222,8 +221,8 @@ func TestExtract(t *testing.T) {
}

func BenchmarkArchive(b *testing.B) {
dir, _ := ioutil.TempDir("", "")
data, _ := ioutil.ReadFile("testdata/archive.tar.bz2")
dir, _ := os.MkdirTemp("", "")
data, _ := os.ReadFile("testdata/archive.tar.bz2")

b.StartTimer()

Expand All @@ -244,8 +243,8 @@ func BenchmarkArchive(b *testing.B) {
}

func BenchmarkTarBz2(b *testing.B) {
dir, _ := ioutil.TempDir("", "")
data, _ := ioutil.ReadFile("testdata/archive.tar.bz2")
dir, _ := os.MkdirTemp("", "")
data, _ := os.ReadFile("testdata/archive.tar.bz2")

b.StartTimer()

Expand All @@ -266,8 +265,8 @@ func BenchmarkTarBz2(b *testing.B) {
}

func BenchmarkTarGz(b *testing.B) {
dir, _ := ioutil.TempDir("", "")
data, _ := ioutil.ReadFile("testdata/archive.tar.gz")
dir, _ := os.MkdirTemp("", "")
data, _ := os.ReadFile("testdata/archive.tar.gz")

b.StartTimer()

Expand All @@ -288,8 +287,8 @@ func BenchmarkTarGz(b *testing.B) {
}

func BenchmarkZip(b *testing.B) {
dir, _ := ioutil.TempDir("", "")
data, _ := ioutil.ReadFile("testdata/archive.zip")
dir, _ := os.MkdirTemp("", "")
data, _ := os.ReadFile("testdata/archive.zip")

b.StartTimer()

Expand Down Expand Up @@ -319,7 +318,7 @@ func testWalk(t *testing.T, dir string, testFiles Files) {
} else if info.Mode()&os.ModeSymlink != 0 {
files[path] = "link"
} else {
data, err := ioutil.ReadFile(filepath.Join(dir, path))
data, err := os.ReadFile(filepath.Join(dir, path))
require.NoError(t, err)
files[path] = strings.TrimSpace(string(data))
}
Expand Down Expand Up @@ -370,7 +369,7 @@ func TestTarGzMemoryConsumption(t *testing.T) {
runtime.GC()
runtime.ReadMemStats(&m)

err = extract.Gz(context.Background(), f, tmpDir.String(), nil)
err = extract.Archive(context.Background(), f, tmpDir.String(), nil)
require.NoError(t, err)

runtime.ReadMemStats(&m2)
Expand Down Expand Up @@ -398,7 +397,7 @@ func TestZipMemoryConsumption(t *testing.T) {
runtime.GC()
runtime.ReadMemStats(&m)

err = extract.Zip(context.Background(), f, tmpDir.String(), nil)
err = extract.Archive(context.Background(), f, tmpDir.String(), nil)
require.NoError(t, err)

runtime.ReadMemStats(&m2)
Expand All @@ -407,9 +406,7 @@ func TestZipMemoryConsumption(t *testing.T) {
heapUsed = 0
}
fmt.Println("Heap memory used during the test:", heapUsed)
// the .zip file require random access, so the full io.Reader content must be cached, since
// the test file is 130MB, that's the reason for the high memory consumed.
require.True(t, heapUsed < 250000000, "heap consumption should be less than 250M but is %d", heapUsed)
require.True(t, heapUsed < 10000000, "heap consumption should be less than 10M but is %d", heapUsed)
}

func download(t require.TestingT, url string, file *paths.Path) error {
Expand Down
40 changes: 32 additions & 8 deletions extractor.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ import (
"compress/bzip2"
"compress/gzip"
"context"
"fmt"
"io"
"io/ioutil"
"os"
"path/filepath"
"strings"
Expand Down Expand Up @@ -237,11 +237,27 @@ func (e *Extractor) Tar(ctx context.Context, body io.Reader, location string, re
// Zip extracts a .zip archived stream of data in the specified location.
// It accepts a rename function to handle the names of the files (see the example).
func (e *Extractor) Zip(ctx context.Context, body io.Reader, location string, rename Renamer) error {
// read the whole body into a buffer. Not sure this is the best way to do it
buffer := bytes.NewBuffer([]byte{})
copyCancel(ctx, buffer, body)

archive, err := zip.NewReader(bytes.NewReader(buffer.Bytes()), int64(buffer.Len()))
var bodySize int64
bodyReaderAt, isReaderAt := (body).(io.ReaderAt)
if bodySeeker, isSeeker := (body).(io.Seeker); isReaderAt && isSeeker {
// get the size by seeking to the end
endPos, err := bodySeeker.Seek(0, io.SeekEnd)
if err != nil {
return fmt.Errorf("failed to seek to the end of the body: %s", err)
}
// reset the reader to the beginning
if _, err := bodySeeker.Seek(0, io.SeekStart); err != nil {
return fmt.Errorf("failed to seek to the beginning of the body: %w", err)
}
bodySize = endPos
} else {
// read the whole body into a buffer. Not sure this is the best way to do it
buffer := bytes.NewBuffer([]byte{})
copyCancel(ctx, buffer, body)
bodyReaderAt = bytes.NewReader(buffer.Bytes())
bodySize = int64(buffer.Len())
}
archive, err := zip.NewReader(bodyReaderAt, bodySize)
if err != nil {
return errors.Annotatef(err, "Read the zip file")
}
Expand Down Expand Up @@ -290,7 +306,7 @@ func (e *Extractor) Zip(ctx context.Context, body io.Reader, location string, re
case info.Mode()&os.ModeSymlink != 0:
if f, err := header.Open(); err != nil {
return errors.Annotatef(err, "Open link %s", path)
} else if name, err := ioutil.ReadAll(f); err != nil {
} else if name, err := io.ReadAll(f); err != nil {
return errors.Annotatef(err, "Read address of link %s", path)
} else {
links = append(links, link{Path: path, Name: string(name)})
Expand Down Expand Up @@ -347,7 +363,15 @@ func match(r io.Reader) (io.Reader, types.Type, error) {
return nil, types.Unknown, err
}

r = io.MultiReader(bytes.NewBuffer(buffer[:n]), r)
if seeker, ok := r.(io.Seeker); ok {
// if the stream is seekable, we just rewind it
if _, err := seeker.Seek(0, io.SeekStart); err != nil {
return nil, types.Unknown, err
}
} else {
// otherwise we create a new reader that will prepend the buffer
r = io.MultiReader(bytes.NewBuffer(buffer[:n]), r)
}

typ, err := filetype.Match(buffer)

Expand Down