Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 58 additions & 1 deletion src/aws-cpp-sdk-transfer/source/transfer/TransferManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -825,6 +825,33 @@ namespace Aws
return rangeStream.str();
}

static bool VerifyContentRange(const Aws::String& requestedRange, const Aws::String& responseContentRange)
{
if (requestedRange.empty() || responseContentRange.empty())
{
return false;
}

if (requestedRange.find("bytes=") != 0)
{
return false;
}
Aws::String requestRange = requestedRange.substr(6);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

substr(6) seems like a "magic number", can we make this based on a search? a hardcoded index seems like it could break if anything changes

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

okay

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've updated the code to use strlen(requestPrefix) instead of the hardcoded value


if (responseContentRange.find("bytes ") != 0)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

!= 0 is weird but valid, for searching string in cpp you should prefer the npos value i.e.

if (responseContentRange.find("bytes ") != Aws::String::npos) {/*...*/}

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

okay

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, that makes sense - but since I'm specifically checking if the string starts with 'bytes', I'm using .substr(0, strlen(requestPrefix)) for the prefix comparison.

{
return false;
}
Aws::String responseRange = responseContentRange.substr(6);
size_t slashPos = responseRange.find('/');
if (slashPos != Aws::String::npos)
{
responseRange = responseRange.substr(0, slashPos);
}

return requestRange == responseRange;
}

void TransferManager::DoSinglePartDownload(const std::shared_ptr<TransferHandle>& handle)
{
auto queuedParts = handle->GetQueuedParts();
Expand Down Expand Up @@ -1091,7 +1118,6 @@ namespace Aws
const std::shared_ptr<const Aws::Client::AsyncCallerContext>& context)
{
AWS_UNREFERENCED_PARAM(client);
AWS_UNREFERENCED_PARAM(request);

std::shared_ptr<TransferHandleAsyncContext> transferContext =
std::const_pointer_cast<TransferHandleAsyncContext>(std::static_pointer_cast<const TransferHandleAsyncContext>(context));
Expand All @@ -1110,6 +1136,37 @@ namespace Aws
}
else
{
if (request.RangeHasBeenSet())
{
const auto& requestedRange = request.GetRange();
const auto& responseContentRange = outcome.GetResult().GetContentRange();

if (!responseContentRange.empty())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is there any instance where we set range on the request where s3 does not return a range? would that be a error?

{
if (!VerifyContentRange(requestedRange, responseContentRange))
{
Aws::Client::AWSError<Aws::S3::S3Errors> error(Aws::S3::S3Errors::INTERNAL_FAILURE,
"ContentRangeMismatch",
"ContentRange in response does not match requested range",
false);
AWS_LOGSTREAM_ERROR(CLASS_TAG, "Transfer handle [" << handle->GetId()
<< "] ContentRange mismatch. Requested: [" << requestedRange
<< "] Received: [" << responseContentRange << "]");
handle->ChangePartToFailed(partState);
handle->SetError(error);
TriggerErrorCallback(handle, error);
handle->Cancel();

if(partState->GetDownloadBuffer())
{
m_bufferManager.Release(partState->GetDownloadBuffer());
partState->SetDownloadBuffer(nullptr);
}
return;
}
}
}

if(handle->ShouldContinue())
{
Aws::IOStream* bufferStream = partState->GetDownloadPartStream();
Expand Down
34 changes: 34 additions & 0 deletions tests/aws-cpp-sdk-transfer-tests/TransferTests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2328,6 +2328,40 @@ TEST_P(TransferTests, TransferManager_TestRelativePrefix)
}
}

TEST_P(TransferTests, TransferManager_ContentRangeVerificationTest)
{
const Aws::String RandomFileName = Aws::Utils::UUID::RandomUUID();
Aws::String testFileName = MakeFilePath(RandomFileName.c_str());
ScopedTestFile testFile(testFileName, MEDIUM_TEST_SIZE, testString);

TransferManagerConfiguration transferManagerConfig(m_executor.get());
transferManagerConfig.s3Client = m_s3Clients[GetParam()];
auto transferManager = TransferManager::Create(transferManagerConfig);

std::shared_ptr<TransferHandle> uploadPtr = transferManager->UploadFile(testFileName, GetTestBucketName(), RandomFileName, "text/plain", Aws::Map<Aws::String, Aws::String>());
uploadPtr->WaitUntilFinished();
ASSERT_EQ(TransferStatus::COMPLETED, uploadPtr->GetStatus());
ASSERT_TRUE(WaitForObjectToPropagate(GetTestBucketName(), RandomFileName.c_str()));

auto downloadFileName = MakeDownloadFileName(RandomFileName);
auto createStreamFn = [=](){
#ifdef _MSC_VER
return Aws::New<Aws::FStream>(ALLOCATION_TAG, Aws::Utils::StringUtils::ToWString(downloadFileName.c_str()).c_str(), std::ios_base::out | std::ios_base::in | std::ios_base::binary | std::ios_base::trunc);
#else
return Aws::New<Aws::FStream>(ALLOCATION_TAG, downloadFileName.c_str(), std::ios_base::out | std::ios_base::in | std::ios_base::binary | std::ios_base::trunc);
#endif
};

uint64_t offset = 1024;
uint64_t partSize = 2048;
std::shared_ptr<TransferHandle> downloadPtr = transferManager->DownloadFile(GetTestBucketName(), RandomFileName, offset, partSize, createStreamFn);

downloadPtr->WaitUntilFinished();
ASSERT_EQ(TransferStatus::COMPLETED, downloadPtr->GetStatus());
ASSERT_EQ(partSize, downloadPtr->GetBytesTotalSize());
ASSERT_EQ(partSize, downloadPtr->GetBytesTransferred());
}

INSTANTIATE_TEST_SUITE_P(Https, TransferTests, testing::Values(TestType::Https));
INSTANTIATE_TEST_SUITE_P(Http, TransferTests, testing::Values(TestType::Http));

Expand Down
Loading