From be82f001c7eca23802440eb1294ddec6c5cbec33 Mon Sep 17 00:00:00 2001 From: zhijian Date: Thu, 30 Nov 2023 14:48:27 +0800 Subject: [PATCH] object: fix verify checksum (#4213) --- pkg/object/checksum.go | 12 +++++---- pkg/object/checksum_test.go | 53 +++++++++++++++++++++++++++++++++++++ pkg/object/cos.go | 7 ++++- pkg/object/oss.go | 8 +++++- pkg/object/s3.go | 6 ++++- 5 files changed, 78 insertions(+), 8 deletions(-) diff --git a/pkg/object/checksum.go b/pkg/object/checksum.go index 8ef4f27769a5..c7db5a38357a 100644 --- a/pkg/object/checksum.go +++ b/pkg/object/checksum.go @@ -54,20 +54,22 @@ func generateChecksum(in io.ReadSeeker) string { type checksumReader struct { io.ReadCloser - expected uint32 - checksum uint32 + expected uint32 + checksum uint32 + remainingLength int64 } func (c *checksumReader) Read(buf []byte) (n int, err error) { n, err = c.ReadCloser.Read(buf) c.checksum = crc32.Update(c.checksum, crc32c, buf[:n]) - if err == io.EOF && c.checksum != c.expected { + c.remainingLength -= int64(n) + if (err == io.EOF || c.remainingLength == 0) && c.checksum != c.expected { return 0, fmt.Errorf("verify checksum failed: %d != %d", c.checksum, c.expected) } return } -func verifyChecksum(in io.ReadCloser, checksum string) io.ReadCloser { +func verifyChecksum(in io.ReadCloser, checksum string, contentLength int64) io.ReadCloser { if checksum == "" { return in } @@ -76,5 +78,5 @@ func verifyChecksum(in io.ReadCloser, checksum string) io.ReadCloser { logger.Errorf("invalid crc32c: %s", checksum) return in } - return &checksumReader{in, uint32(expected), 0} + return &checksumReader{in, uint32(expected), 0, contentLength} } diff --git a/pkg/object/checksum_test.go b/pkg/object/checksum_test.go index 636917ed6233..350861938f65 100644 --- a/pkg/object/checksum_test.go +++ b/pkg/object/checksum_test.go @@ -18,8 +18,11 @@ package object import ( "bytes" + "crypto/rand" "hash/crc32" + "io" "strconv" + "strings" "testing" ) @@ -38,3 +41,53 @@ func TestChecksum(t *testing.T) { t.FailNow() } } + +func TestChecksumRead(t *testing.T) { + length := 10240 + content := make([]byte, length) + if _, err := rand.Read(content); err != nil { + t.Fatalf("Generate random content: %s", err) + } + actual := generateChecksum(bytes.NewReader(content)) + + // content length equal buff length case + lens := []int64{-1, int64(length)} + for _, contentLength := range lens { + reader := verifyChecksum(io.NopCloser(bytes.NewReader(content)), actual, contentLength) + n, err := reader.Read(make([]byte, length)) + if n != length || (err != nil && err != io.EOF) { + t.Fatalf("verify checksum shuold success") + } + } + + // verify success case + for _, contentLength := range lens { + reader := verifyChecksum(io.NopCloser(bytes.NewReader(content)), actual, contentLength) + n, err := reader.Read(make([]byte, length+100)) + if n != length || (err != nil && err != io.EOF) { + t.Fatalf("verify checksum shuold success") + } + } + + // verify failed case + for _, contentLength := range lens { + content[0] = 'a' + reader := verifyChecksum(io.NopCloser(bytes.NewReader(content)), actual, contentLength) + n, err := reader.Read(make([]byte, length)) + if contentLength == -1 && (err != nil && err != io.EOF || n != length) { + t.Fatalf("dont verify checksum when content length is -1") + } + if contentLength != -1 && (err == nil || err == io.EOF || !strings.HasPrefix(err.Error(), "verify checksum failed")) { + t.Fatalf("verify checksum should failed") + } + } + + // verify read length less than content length case + for _, contentLength := range lens { + reader := verifyChecksum(io.NopCloser(bytes.NewReader(content)), actual, contentLength) + n, err := reader.Read(make([]byte, length-100)) + if err != nil || n != length-100 { + t.Fatalf("error should be nil and read length should be %d", length-100) + } + } +} diff --git a/pkg/object/cos.go b/pkg/object/cos.go index 4103c5125894..35ba489fb03f 100644 --- a/pkg/object/cos.go +++ b/pkg/object/cos.go @@ -108,7 +108,12 @@ func (c *COS) Get(key string, off, limit int64) (io.ReadCloser, error) { return nil, err } if off == 0 && limit == -1 { - resp.Body = verifyChecksum(resp.Body, resp.Header.Get(cosChecksumKey)) + length, err := strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64) + if err != nil { + length = -1 + logger.Warnf("failed to parse content-length %s: %s", resp.Header.Get("Content-Length"), err) + } + resp.Body = verifyChecksum(resp.Body, resp.Header.Get(cosChecksumKey), length) } if resp != nil { ReqIDCache.put(key, resp.Header.Get(cosRequestIDKey)) diff --git a/pkg/object/oss.go b/pkg/object/oss.go index 59ec96996727..13f84d381d48 100644 --- a/pkg/object/oss.go +++ b/pkg/object/oss.go @@ -126,8 +126,14 @@ func (o *ossClient) Get(key string, off, limit int64) (resp io.ReadCloser, err e } else { resp, err = o.bucket.GetObject(key, oss.GetResponseHeader(&respHeader)) if err == nil { + length, err := strconv.ParseInt(resp.(*oss.Response).Headers.Get(oss.HTTPHeaderContentLength), 10, 64) + if err != nil { + length = -1 + logger.Warnf("failed to parse content-length %s: %s", resp.(*oss.Response).Headers.Get(oss.HTTPHeaderContentLength), err) + } resp = verifyChecksum(resp, - resp.(*oss.Response).Headers.Get(oss.HTTPHeaderOssMetaPrefix+checksumAlgr)) + resp.(*oss.Response).Headers.Get(oss.HTTPHeaderOssMetaPrefix+checksumAlgr), + length) } } ReqIDCache.put(key, respHeader.Get(oss.HTTPHeaderOssRequestID)) diff --git a/pkg/object/s3.go b/pkg/object/s3.go index 059818ce4086..0f5efd0f4328 100644 --- a/pkg/object/s3.go +++ b/pkg/object/s3.go @@ -136,8 +136,12 @@ func (s *s3client) Get(key string, off, limit int64) (io.ReadCloser, error) { } if off == 0 && limit == -1 { cs := resp.Metadata[checksumAlgr] + var length int64 = -1 + if resp.ContentLength != nil { + length = *resp.ContentLength + } if cs != nil { - resp.Body = verifyChecksum(resp.Body, *cs) + resp.Body = verifyChecksum(resp.Body, *cs, length) } } return resp.Body, nil