diff --git a/.changelog/9cf590d7376444d08d7d959651964337.json b/.changelog/9cf590d7376444d08d7d959651964337.json new file mode 100644 index 00000000000..9e7a0566fa6 --- /dev/null +++ b/.changelog/9cf590d7376444d08d7d959651964337.json @@ -0,0 +1,8 @@ +{ + "id": "9cf590d7-3764-44d0-8d7d-959651964337", + "type": "feature", + "description": "Add durability checks to validate part count and range for upload/download. You can disable this with `DisableValidateParts` in upload/download options, though doing so is not recommended because it damages the durability posture of your application.", + "modules": [ + "feature/s3/manager" + ] +} \ No newline at end of file diff --git a/feature/s3/manager/download.go b/feature/s3/manager/download.go index 8acd9a27aea..8e3b6a309fb 100644 --- a/feature/s3/manager/download.go +++ b/feature/s3/manager/download.go @@ -77,6 +77,14 @@ type Downloader struct { // operation requests made by the downloader. ClientOptions []func(*s3.Options) + // By default, the downloader verifies that individual part ranges align + // based on the configured part size. + // + // You can disable that with this flag, however, Amazon S3 recommends + // against doing so because it damages the durability posture of object + // downloads. + DisableValidateParts bool + // Defines the buffer strategy used when downloading a part. 
// // If a WriterReadFromProvider is given the Download manager @@ -404,6 +412,15 @@ func (d *downloader) tryDownloadChunk(params *s3.GetObjectInput, w io.Writer) (i if err != nil { return 0, err } + + if !d.cfg.DisableValidateParts && params.Range != nil && resp.ContentRange != nil { + expectStart, expectEnd := parseContentRange(*params.Range) + actualStart, actualEnd := parseContentRange(*resp.ContentRange) + if isRangeMismatch(expectStart, expectEnd, actualStart, actualEnd) { + return 0, fmt.Errorf("invalid content range: expect %d-%d, got %d-%d", expectStart, expectEnd, actualStart, actualEnd) + } + } + d.setTotalBytes(resp) // Set total if not yet set. d.once.Do(func() { d.etag = aws.ToString(resp.ETag) @@ -422,6 +439,46 @@ func (d *downloader) tryDownloadChunk(params *s3.GetObjectInput, w io.Writer) (i return n, nil }

// parseContentRange extracts the first and last byte positions from either a
// Range request value ("bytes=0-99") or a Content-Range response value
// ("bytes 0-99/1234"). Everything from the first "/" onward (the total
// length, when present) is ignored.
//
// It returns (-1, -1) whenever the value is not exactly one "start-end" pair
// of integers: an empty string, a suffix range ("bytes=-5"), an open-ended
// range ("bytes=5-"), an unsatisfied range ("bytes */1234"), or anything with
// non-numeric positions. Because "-" is the separator, a successfully parsed
// position is never negative, so -1 is unambiguous as the "unknown" marker.
//
// NOTE(review): strconv.Atoi parses into the platform int; on 32-bit builds
// positions above the int range fail to parse and are reported as unknown.
func parseContentRange(v string) (int, int) {
	// chop the total off, if it's there
	if i := strings.IndexByte(v, '/'); i >= 0 {
		v = v[:i]
	}

	// we send "bytes=" but S3 appears to return "bytes ", handle both
	v = strings.TrimPrefix(v, "bytes ")
	v = strings.TrimPrefix(v, "bytes=")

	bounds := strings.Split(v, "-")
	if len(bounds) != 2 {
		return -1, -1
	}

	start, err := strconv.Atoi(bounds[0])
	if err != nil {
		return -1, -1
	}

	end, err := strconv.Atoi(bounds[1])
	if err != nil {
		return -1, -1
	}

	return start, end
}

// isRangeMismatch reports whether the byte range returned for a part differs
// from the range that was requested.
//
// Any argument equal to -1 means that range was missing or unparseable; in
// that case nothing can be verified and no mismatch is reported.
//
// A returned range that starts where expected but ends early is accepted,
// provided it is itself well formed (its end is not before its start). A
// returned range whose end precedes its start is always a mismatch.
func isRangeMismatch(expectStart, expectEnd, actualStart, actualEnd int) bool {
	if expectStart == -1 || expectEnd == -1 || actualStart == -1 || actualEnd == -1 {
		return false // we don't know, one of the ranges was missing or unparseable
	}

	if expectStart != actualStart {
		return true
	}

	// for the final chunk (or the first chunk if it's smaller) we still
	// request a full chunk but we get back the actual final part of the
	// object, which will be smaller. An inverted range (end before start) is
	// never a legitimate short chunk, so it must not slip through here.
	if actualEnd < expectEnd {
		return actualEnd < actualStart
	}

	return actualEnd != expectEnd
}

// getTotalBytes is a thread-safe
getter for retrieving the total byte status. func (d *downloader) getTotalBytes() int64 { d.m.Lock() diff --git a/feature/s3/manager/download_test.go b/feature/s3/manager/download_test.go index 78b65bf7ddb..112e36def4d 100644 --- a/feature/s3/manager/download_test.go +++ b/feature/s3/manager/download_test.go @@ -94,6 +94,30 @@ func newDownloadNonRangeClient(data []byte) (*downloadCaptureClient, *int) { return capture, &capture.GetObjectInvocations } +func newDownloadBadRangeClient(data []byte) (*downloadCaptureClient, *int, *[]string) { + capture := &downloadCaptureClient{} + + capture.GetObjectFn = func(_ context.Context, params *s3.GetObjectInput, _ ...func(*s3.Options)) (*s3.GetObjectOutput, error) { + start, fin := parseRange(aws.ToString(params.Range)) + fin++ + + if fin >= int64(len(data)) { + fin = int64(len(data)) + } + + bodyBytes := data[start:fin] + + return &s3.GetObjectOutput{ + Body: ioutil.NopCloser(bytes.NewReader(bodyBytes)), + // offset start by 1 to make it wrong + ContentRange: aws.String(fmt.Sprintf("bytes %d-%d/%d", start+1, fin-1, len(data))), + ContentLength: aws.Int64(int64(len(bodyBytes))), + }, nil + } + + return capture, &capture.GetObjectInvocations, &capture.RetrievedRanges +} + func newDownloadVersionClient(data []byte) (*downloadCaptureClient, *int, *[]string, *[]string) { capture := &downloadCaptureClient{} diff --git a/feature/s3/manager/upload.go b/feature/s3/manager/upload.go index 6f513f537cf..dcc439c9d78 100644 --- a/feature/s3/manager/upload.go +++ b/feature/s3/manager/upload.go @@ -269,6 +269,14 @@ type Uploader struct { // Note: S3 Express buckets always require CRC32 checksums regardless of this setting. RequestChecksumCalculation aws.RequestChecksumCalculation + // By default, the uploader verifies that the number of expected uploaded + // parts matches the actual count at the end of an upload. 
+ // + // You can disable that with this flag, however, Amazon S3 recommends + // against doing so because it damages the durability posture of object + // uploads. + DisableValidateParts bool + // partPool allows for the re-usage of streaming payload part buffers between upload calls partPool byteSlicePool } @@ -362,8 +370,9 @@ type uploader struct { in *s3.PutObjectInput - readerPos int64 // current reader position - totalSize int64 // set to -1 if the size is not known + readerPos int64 // current reader position + totalSize int64 // set to -1 if the size is not known + expectParts int64 } // internal logic for deciding whether to upload a single part or use a @@ -446,6 +455,11 @@ func (u *uploader) initSize() error { // during the size calculation. e.g odd number of bytes. u.cfg.PartSize = (u.totalSize / int64(u.cfg.MaxUploadParts)) + 1 } + + u.expectParts = u.totalSize / u.cfg.PartSize + if u.totalSize%u.cfg.PartSize != 0 { + u.expectParts++ + } } return nil @@ -877,6 +891,17 @@ func (u *multiuploader) complete() *s3.CompleteMultipartUploadOutput { u.fail() } + // expectParts == 0 means we didn't know the content length upfront and + // therefore we can't validate this at all + if u.expectParts == 0 || u.cfg.DisableValidateParts { + return resp + } + + if len(u.parts) != int(u.expectParts) { + u.seterr(fmt.Errorf("uploaded part count mismatch: expected %d, got %d", u.expectParts, len(u.parts))) + u.fail() + } + return resp } diff --git a/feature/s3/manager/validate_parts_test.go b/feature/s3/manager/validate_parts_test.go new file mode 100644 index 00000000000..ef356aeb9bb --- /dev/null +++ b/feature/s3/manager/validate_parts_test.go @@ -0,0 +1,52 @@ +package manager_test + +import ( + "context" + "strings" + "testing" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/feature/s3/manager" + "github.com/aws/aws-sdk-go-v2/service/s3" +) + +type invalidRangeClient struct { +} + +func TestDownload_RangeMismatch(t *testing.T) { + c, _, _ := 
newDownloadBadRangeClient(buf12MB) + + d := manager.NewDownloader(c, func(d *manager.Downloader) { + d.Concurrency = 1 + }) + + w := manager.NewWriteAtBuffer(make([]byte, len(buf12MB))) + _, err := d.Download(context.Background(), w, &s3.GetObjectInput{ + Bucket: aws.String("bucket"), + Key: aws.String("key"), + }) + if err == nil { + t.Fatalf("expect err, got none") + } + if !strings.Contains(err.Error(), "invalid content range") { + t.Errorf("error mismatch:\n%v !=\n%v", err, "invalid content range") + } +} + +func TestDownload_RangeMismatchDisabled(t *testing.T) { + c, _, _ := newDownloadBadRangeClient(buf12MB) + + d := manager.NewDownloader(c, func(d *manager.Downloader) { + d.Concurrency = 1 + d.DisableValidateParts = true + }) + + w := manager.NewWriteAtBuffer(make([]byte, len(buf12MB))) + _, err := d.Download(context.Background(), w, &s3.GetObjectInput{ + Bucket: aws.String("bucket"), + Key: aws.String("key"), + }) + if err != nil { + t.Fatalf("expect no err, got %v", err) + } +}