Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions .changelog/9cf590d7376444d08d7d959651964337.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
{
"id": "9cf590d7-3764-44d0-8d7d-959651964337",
"type": "feature",
"description": "Add durability checks to validate part count and range for upload/download. You can disable this with `DisableValidateParts` in upload/download options, though doing so is not recommended because it damages the durability posture of your application.",
"modules": [
"feature/s3/manager"
]
}
57 changes: 57 additions & 0 deletions feature/s3/manager/download.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,14 @@ type Downloader struct {
// operation requests made by the downloader.
ClientOptions []func(*s3.Options)

// By default, the downloader verifies that individual part ranges align
// based on the configured part size.
//
// You can disable that with this flag, however, Amazon S3 recommends
// against doing so because it damages the durability posture of object
// downloads.
DisableValidateParts bool

// Defines the buffer strategy used when downloading a part.
//
// If a WriterReadFromProvider is given the Download manager
Expand Down Expand Up @@ -404,6 +412,15 @@ func (d *downloader) tryDownloadChunk(params *s3.GetObjectInput, w io.Writer) (i
if err != nil {
return 0, err
}

if !d.cfg.DisableValidateParts && params.Range != nil && resp.ContentRange != nil {
Copy link
Contributor

@wty-Bryant wty-Bryant Oct 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: IIRC params.Range will always be set at old line 361 in downloadChunk

expectStart, expectEnd := parseContentRange(*params.Range)
actualStart, actualEnd := parseContentRange(*resp.ContentRange)
if isRangeMismatch(expectStart, expectEnd, actualStart, actualEnd) {
return 0, fmt.Errorf("invalid content range: expect %d-%d, got %d-%d", expectStart, expectEnd, actualStart, actualEnd)
}
}

d.setTotalBytes(resp) // Set total if not yet set.
d.once.Do(func() {
d.etag = aws.ToString(resp.ETag)
Expand All @@ -422,6 +439,46 @@ func (d *downloader) tryDownloadChunk(params *s3.GetObjectInput, w io.Writer) (i
return n, nil
}

// parseContentRange extracts the start and end byte offsets from an HTTP
// range value such as "bytes=0-99" (the form sent on requests) or
// "bytes 0-99/1000" (the form S3 appears to return). A trailing "/total"
// component is ignored.
//
// It returns (-1, -1) when the value does not hold exactly one "start-end"
// pair of integers.
func parseContentRange(v string) (int, int) {
	// drop the "/total" suffix, if there is one
	spec := v
	if slash := strings.Index(spec, "/"); slash >= 0 {
		spec = spec[:slash]
	}

	// requests carry "bytes=" while responses carry "bytes ", strip either
	spec = strings.TrimPrefix(spec, "bytes ")
	spec = strings.TrimPrefix(spec, "bytes=")

	bounds := strings.Split(spec, "-")
	if len(bounds) != 2 {
		return -1, -1
	}

	first, firstErr := strconv.Atoi(bounds[0])
	last, lastErr := strconv.Atoi(bounds[1])
	if firstErr != nil || lastErr != nil {
		return -1, -1
	}

	return first, last
}

// isRangeMismatch reports whether the byte range returned by the service
// (actualStart-actualEnd) is inconsistent with the range that was requested
// (expectStart-expectEnd).
//
// A value of -1 for any offset means that range was missing or could not be
// parsed; nothing can be verified in that case, so no mismatch is reported.
func isRangeMismatch(expectStart, expectEnd, actualStart, actualEnd int) bool {
	if expectStart == -1 || expectEnd == -1 || actualStart == -1 || actualEnd == -1 {
		return false // we don't know, one of the ranges was missing or unparseable
	}

	// a returned range whose last position precedes its first position is
	// not a valid byte range, so it must not slip through the short-chunk
	// allowance below
	if actualEnd < actualStart {
		return true
	}

	// for the final chunk (or the first chunk if it's smaller) we still
	// request a full chunk but we get back the actual final part of the
	// object, which will be smaller
	if expectStart == actualStart && actualEnd < expectEnd {
		return false
	}

	return expectStart != actualStart || expectEnd != actualEnd
}

// getTotalBytes is a thread-safe getter for retrieving the total byte status.
func (d *downloader) getTotalBytes() int64 {
d.m.Lock()
Expand Down
24 changes: 24 additions & 0 deletions feature/s3/manager/download_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,30 @@ func newDownloadNonRangeClient(data []byte) (*downloadCaptureClient, *int) {
return capture, &capture.GetObjectInvocations
}

// newDownloadBadRangeClient builds a capture client whose GetObject serves the
// requested slice of data but reports a Content-Range whose start is shifted
// forward by one byte, so it does not match the range that was asked for.
//
// It returns the client together with pointers to its GetObjectInvocations
// and RetrievedRanges fields.
func newDownloadBadRangeClient(data []byte) (*downloadCaptureClient, *int, *[]string) {
	capture := &downloadCaptureClient{}

	capture.GetObjectFn = func(_ context.Context, params *s3.GetObjectInput, _ ...func(*s3.Options)) (*s3.GetObjectOutput, error) {
		total := int64(len(data))

		first, last := parseRange(aws.ToString(params.Range))

		// turn the parsed end into an exclusive bound and clamp it to the
		// size of the data
		stop := last + 1
		if stop > total {
			stop = total
		}

		chunk := data[first:stop]

		// offset start by 1 to make it wrong
		reported := fmt.Sprintf("bytes %d-%d/%d", first+1, stop-1, len(data))

		out := &s3.GetObjectOutput{
			Body:          ioutil.NopCloser(bytes.NewReader(chunk)),
			ContentRange:  aws.String(reported),
			ContentLength: aws.Int64(int64(len(chunk))),
		}
		return out, nil
	}

	return capture, &capture.GetObjectInvocations, &capture.RetrievedRanges
}

func newDownloadVersionClient(data []byte) (*downloadCaptureClient, *int, *[]string, *[]string) {
capture := &downloadCaptureClient{}

Expand Down
29 changes: 27 additions & 2 deletions feature/s3/manager/upload.go
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,14 @@ type Uploader struct {
// Note: S3 Express buckets always require CRC32 checksums regardless of this setting.
RequestChecksumCalculation aws.RequestChecksumCalculation

// By default, the uploader verifies that the number of expected uploaded
// parts matches the actual count at the end of an upload.
//
// You can disable that with this flag, however, Amazon S3 recommends
// against doing so because it damages the durability posture of object
// uploads.
DisableValidateParts bool

// partPool allows for the re-usage of streaming payload part buffers between upload calls
partPool byteSlicePool
}
Expand Down Expand Up @@ -362,8 +370,9 @@ type uploader struct {

in *s3.PutObjectInput

readerPos int64 // current reader position
totalSize int64 // set to -1 if the size is not known
readerPos int64 // current reader position
totalSize int64 // set to -1 if the size is not known
expectParts int64
}

// internal logic for deciding whether to upload a single part or use a
Expand Down Expand Up @@ -446,6 +455,11 @@ func (u *uploader) initSize() error {
// during the size calculation. e.g odd number of bytes.
u.cfg.PartSize = (u.totalSize / int64(u.cfg.MaxUploadParts)) + 1
}

u.expectParts = u.totalSize / u.cfg.PartSize
if u.totalSize%u.cfg.PartSize != 0 {
u.expectParts++
}
}

return nil
Expand Down Expand Up @@ -877,6 +891,17 @@ func (u *multiuploader) complete() *s3.CompleteMultipartUploadOutput {
u.fail()
}

// expectParts == 0 means we didn't know the content length upfront and
// therefore we can't validate this at all
if u.expectParts == 0 || u.cfg.DisableValidateParts {
return resp
}

if len(u.parts) != int(u.expectParts) {
u.seterr(fmt.Errorf("uploaded part count mismatch: expected %d, got %d", u.expectParts, len(u.parts)))
u.fail()
}

return resp
}

Expand Down
52 changes: 52 additions & 0 deletions feature/s3/manager/validate_parts_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package manager_test

import (
"context"
"strings"
"testing"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/feature/s3/manager"
"github.com/aws/aws-sdk-go-v2/service/s3"
)

// invalidRangeClient is an empty placeholder type.
//
// NOTE(review): nothing in the visible tests references it — confirm whether
// it is used elsewhere in the package or can be removed.
type invalidRangeClient struct {
}

// TestDownload_RangeMismatch downloads through a client built by
// newDownloadBadRangeClient, which reports a Content-Range start that is off
// by one, and expects Download to fail with an error containing
// "invalid content range".
func TestDownload_RangeMismatch(t *testing.T) {
c, _, _ := newDownloadBadRangeClient(buf12MB)

d := manager.NewDownloader(c, func(d *manager.Downloader) {
d.Concurrency = 1
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we add a test when we useDisableValidateParts and ensure this does not fail?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes

})

w := manager.NewWriteAtBuffer(make([]byte, len(buf12MB)))
_, err := d.Download(context.Background(), w, &s3.GetObjectInput{
Bucket: aws.String("bucket"),
Key: aws.String("key"),
})
// the download must be rejected rather than silently accepted
if err == nil {
t.Fatalf("expect err, got none")
}
// and the failure must come from the content range validation
if !strings.Contains(err.Error(), "invalid content range") {
t.Errorf("error mismatch:\n%v !=\n%v", err, "invalid content range")
}
}

// TestDownload_RangeMismatchDisabled downloads through the same bad-range
// client used by the mismatch test, but with DisableValidateParts set, and
// expects the download to complete without an error.
func TestDownload_RangeMismatchDisabled(t *testing.T) {
	client, _, _ := newDownloadBadRangeClient(buf12MB)

	withoutValidation := func(d *manager.Downloader) {
		d.Concurrency = 1
		d.DisableValidateParts = true
	}
	downloader := manager.NewDownloader(client, withoutValidation)

	dst := manager.NewWriteAtBuffer(make([]byte, len(buf12MB)))
	input := &s3.GetObjectInput{
		Bucket: aws.String("bucket"),
		Key:    aws.String("key"),
	}

	if _, err := downloader.Download(context.Background(), dst, input); err != nil {
		t.Fatalf("expect no err, got %v", err)
	}
}
Loading