diff --git a/extract_test.go b/extract_test.go
index 7666dce..3ad01f2 100644
--- a/extract_test.go
+++ b/extract_test.go
@@ -5,7 +5,6 @@ import (
 	"context"
 	"fmt"
 	"io"
-	"io/ioutil"
 	"net/http"
 	"os"
 	"path/filepath"
@@ -187,9 +186,9 @@ func TestArchiveFailure(t *testing.T) {
 
 func TestExtract(t *testing.T) {
 	for _, test := range ExtractCases {
-		dir, _ := ioutil.TempDir("", "")
+		dir, _ := os.MkdirTemp("", "")
 		dir = filepath.Join(dir, "test")
-		data, err := ioutil.ReadFile(test.Archive)
+		data, err := os.ReadFile(test.Archive)
 		if err != nil {
 			t.Fatal(err)
 		}
@@ -222,8 +221,8 @@ func TestExtract(t *testing.T) {
 }
 
 func BenchmarkArchive(b *testing.B) {
-	dir, _ := ioutil.TempDir("", "")
-	data, _ := ioutil.ReadFile("testdata/archive.tar.bz2")
+	dir, _ := os.MkdirTemp("", "")
+	data, _ := os.ReadFile("testdata/archive.tar.bz2")
 
 	b.StartTimer()
 
@@ -244,8 +243,8 @@ func BenchmarkArchive(b *testing.B) {
 }
 
 func BenchmarkTarBz2(b *testing.B) {
-	dir, _ := ioutil.TempDir("", "")
-	data, _ := ioutil.ReadFile("testdata/archive.tar.bz2")
+	dir, _ := os.MkdirTemp("", "")
+	data, _ := os.ReadFile("testdata/archive.tar.bz2")
 
 	b.StartTimer()
 
@@ -266,8 +265,8 @@ func BenchmarkTarBz2(b *testing.B) {
 }
 
 func BenchmarkTarGz(b *testing.B) {
-	dir, _ := ioutil.TempDir("", "")
-	data, _ := ioutil.ReadFile("testdata/archive.tar.gz")
+	dir, _ := os.MkdirTemp("", "")
+	data, _ := os.ReadFile("testdata/archive.tar.gz")
 
 	b.StartTimer()
 
@@ -288,8 +287,8 @@ func BenchmarkTarGz(b *testing.B) {
 }
 
 func BenchmarkZip(b *testing.B) {
-	dir, _ := ioutil.TempDir("", "")
-	data, _ := ioutil.ReadFile("testdata/archive.zip")
+	dir, _ := os.MkdirTemp("", "")
+	data, _ := os.ReadFile("testdata/archive.zip")
 
 	b.StartTimer()
 
@@ -319,7 +318,7 @@ func testWalk(t *testing.T, dir string, testFiles Files) {
 		} else if info.Mode()&os.ModeSymlink != 0 {
 			files[path] = "link"
 		} else {
-			data, err := ioutil.ReadFile(filepath.Join(dir, path))
+			data, err := os.ReadFile(filepath.Join(dir, path))
 			require.NoError(t, err)
 			files[path] = strings.TrimSpace(string(data))
 		}
@@ -370,7 +369,7 @@ func TestTarGzMemoryConsumption(t *testing.T) {
 	runtime.GC()
 	runtime.ReadMemStats(&m)
 
-	err = extract.Gz(context.Background(), f, tmpDir.String(), nil)
+	err = extract.Archive(context.Background(), f, tmpDir.String(), nil)
 	require.NoError(t, err)
 
 	runtime.ReadMemStats(&m2)
@@ -398,7 +397,7 @@ func TestZipMemoryConsumption(t *testing.T) {
 	runtime.GC()
 	runtime.ReadMemStats(&m)
 
-	err = extract.Zip(context.Background(), f, tmpDir.String(), nil)
+	err = extract.Archive(context.Background(), f, tmpDir.String(), nil)
 	require.NoError(t, err)
 
 	runtime.ReadMemStats(&m2)
@@ -407,9 +406,7 @@ func TestZipMemoryConsumption(t *testing.T) {
 		heapUsed = 0
 	}
 	fmt.Println("Heap memory used during the test:", heapUsed)
-	// the .zip file require random access, so the full io.Reader content must be cached, since
-	// the test file is 130MB, that's the reason for the high memory consumed.
-	require.True(t, heapUsed < 250000000, "heap consumption should be less than 250M but is %d", heapUsed)
+	require.True(t, heapUsed < 10000000, "heap consumption should be less than 10M but is %d", heapUsed)
 }
 
 func download(t require.TestingT, url string, file *paths.Path) error {
diff --git a/extractor.go b/extractor.go
index 2151573..ccd6b3f 100644
--- a/extractor.go
+++ b/extractor.go
@@ -7,8 +7,8 @@ import (
 	"compress/bzip2"
 	"compress/gzip"
 	"context"
+	"fmt"
 	"io"
-	"io/ioutil"
 	"os"
 	"path/filepath"
 	"strings"
@@ -237,11 +237,27 @@ func (e *Extractor) Tar(ctx context.Context, body io.Reader, location string, re
 // Zip extracts a .zip archived stream of data in the specified location.
 // It accepts a rename function to handle the names of the files (see the example).
 func (e *Extractor) Zip(ctx context.Context, body io.Reader, location string, rename Renamer) error {
-	// read the whole body into a buffer. Not sure this is the best way to do it
-	buffer := bytes.NewBuffer([]byte{})
-	copyCancel(ctx, buffer, body)
-
-	archive, err := zip.NewReader(bytes.NewReader(buffer.Bytes()), int64(buffer.Len()))
+	var bodySize int64
+	bodyReaderAt, isReaderAt := (body).(io.ReaderAt)
+	if bodySeeker, isSeeker := (body).(io.Seeker); isReaderAt && isSeeker {
+		// get the size by seeking to the end
+		endPos, err := bodySeeker.Seek(0, io.SeekEnd)
+		if err != nil {
+			return fmt.Errorf("failed to seek to the end of the body: %s", err)
+		}
+		// reset the reader to the beginning
+		if _, err := bodySeeker.Seek(0, io.SeekStart); err != nil {
+			return fmt.Errorf("failed to seek to the beginning of the body: %w", err)
+		}
+		bodySize = endPos
+	} else {
+		// read the whole body into a buffer. Not sure this is the best way to do it
+		buffer := bytes.NewBuffer([]byte{})
+		copyCancel(ctx, buffer, body)
+		bodyReaderAt = bytes.NewReader(buffer.Bytes())
+		bodySize = int64(buffer.Len())
+	}
+	archive, err := zip.NewReader(bodyReaderAt, bodySize)
 	if err != nil {
 		return errors.Annotatef(err, "Read the zip file")
 	}
@@ -290,7 +306,7 @@ func (e *Extractor) Zip(ctx context.Context, body io.Reader, location string, re
 		case info.Mode()&os.ModeSymlink != 0:
 			if f, err := header.Open(); err != nil {
 				return errors.Annotatef(err, "Open link %s", path)
-			} else if name, err := ioutil.ReadAll(f); err != nil {
+			} else if name, err := io.ReadAll(f); err != nil {
 				return errors.Annotatef(err, "Read address of link %s", path)
 			} else {
 				links = append(links, link{Path: path, Name: string(name)})
 			}
@@ -347,7 +363,15 @@ func match(r io.Reader) (io.Reader, types.Type, error) {
 		return nil, types.Unknown, err
 	}
 
-	r = io.MultiReader(bytes.NewBuffer(buffer[:n]), r)
+	if seeker, ok := r.(io.Seeker); ok {
+		// if the stream is seekable, we just rewind it
+		if _, err := seeker.Seek(0, io.SeekStart); err != nil {
+			return nil, types.Unknown, err
+		}
+	} else {
+		// otherwise we create a new reader that will prepend the buffer
+		r = io.MultiReader(bytes.NewBuffer(buffer[:n]), r)
+	}
 
 	typ, err := filetype.Match(buffer)
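Not part of the patch: a minimal caller-side sketch of the path this change enables. An *os.File satisfies both io.ReaderAt and io.Seeker, so Zip (reached here through extract.Archive) reads zip entries in place instead of copying the whole archive into memory, which is why the heap ceiling in TestZipMemoryConsumption could drop from 250 MB to 10 MB; a non-seekable reader such as an HTTP response body still takes the buffered fallback. The import path and the file locations below are assumptions for illustration, not taken from the diff.

// Illustrative sketch only; module path and paths are assumed.
package main

import (
	"context"
	"log"
	"os"

	"github.com/codeclysm/extract/v3" // assumed module path; adjust to the module you build against
)

func main() {
	// *os.File implements io.ReaderAt and io.Seeker, so the Zip extractor can
	// read entries directly from the file instead of buffering the archive.
	f, err := os.Open("testdata/archive.zip") // assumed path
	if err != nil {
		log.Fatal(err)
	}
	defer f.Close()

	// Archive sniffs the format from the first bytes and dispatches to Zip;
	// with the change in match(), it rewinds the seekable file rather than
	// wrapping it in an io.MultiReader.
	if err := extract.Archive(context.Background(), f, "/tmp/extracted", nil); err != nil {
		log.Fatal(err)
	}
}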