Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,9 @@

## 1.1.0-beta.10 (Unreleased)

### Features Added

### Breaking Changes

### Bugs Fixed

### Other Changes
- Added sanity check for manifest size at download time - if manifest is bigger than 4MB, `RequestFailedException` will be thrown.

## 1.1.0-beta.9 (2023-04-11)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ namespace Azure.Containers.ContainerRegistry
internal class BlobHelper
{
private const string ClientAndServerDigestsDontMatchMessage = "The server-computed digest does not match the client-computed digest.";
internal const string ContentDigestDoesntMatchRequestedMessage = "The digest computed from the downloaded content does not match the requested digest.";
internal const string ManifestDigestDoestMatchRequestedMessage = "The digest of the received manifest does not match the requested digest reference.";

internal static string ComputeDigest(Stream stream)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ namespace Azure.Containers.ContainerRegistry
public class ContainerRegistryContentClient
{
private const int DefaultChunkSize = 4 * 1024 * 1024; // 4MB
private const int MaxManifestSize = 4 * 1024 * 1024;

private const string InvalidContentLengthMessage = "Missing or invalid 'Content-Length' header in the response.";
private const string InvalidContentRangeMessage = "Missing or invalid 'Content-Range' header in the response.";

private readonly Uri _endpoint;
private readonly string _registryName;
Expand Down Expand Up @@ -532,10 +536,22 @@ private static string GetContentRange(long offset, long length)
return FormattableString.Invariant($"{offset}-{endRange}");
}

private static long GetBlobLengthFromContentRange(string contentRange)
private static long GetBlobSize(Response response)
{
string size = contentRange.Split('/')[1];
return long.Parse(size, CultureInfo.InvariantCulture);
if (!response.Headers.TryGetValue("Content-Range", out string contentRange) ||
contentRange == null)
{
throw new RequestFailedException(response.Status, InvalidContentRangeMessage);
}

int index = contentRange.IndexOf('/');
if (!long.TryParse(contentRange.Substring(index + 1), NumberStyles.Integer, CultureInfo.InvariantCulture, out long size) ||
size <= 0)
{
throw new RequestFailedException(response.Status, InvalidContentRangeMessage);
}

return size;
}

// Some streams will throw if you try to access their length so we wrap
Expand Down Expand Up @@ -573,26 +589,7 @@ public virtual Response<GetManifestResult> GetManifest(string tagOrDigest, Cance
scope.Start();
try
{
string accept = GetAcceptHeader();

Response<ManifestWrapper> response = _restClient.GetManifest(_repositoryName, tagOrDigest, accept, cancellationToken);
Response rawResponse = response.GetRawResponse();

rawResponse.Headers.TryGetValue("Docker-Content-Digest", out string digest);
rawResponse.Headers.TryGetValue("Content-Type", out string contentType);

var contentDigest = BlobHelper.ComputeDigest(rawResponse.ContentStream);

if (ReferenceIsDigest(tagOrDigest))
{
BlobHelper.ValidateDigest(contentDigest, tagOrDigest, BlobHelper.ManifestDigestDoestMatchRequestedMessage);
}
else
{
BlobHelper.ValidateDigest(contentDigest, digest);
}

return Response.FromValue(new GetManifestResult(digest, contentType, rawResponse.Content), rawResponse);
return GetManifestInternalAsync(tagOrDigest, false, cancellationToken).EnsureCompleted();
}
catch (Exception e)
{
Expand All @@ -617,32 +614,39 @@ public virtual async Task<Response<GetManifestResult>> GetManifestAsync(string t
scope.Start();
try
{
string accept = GetAcceptHeader();
return await GetManifestInternalAsync(tagOrDigest, true, cancellationToken).ConfigureAwait(false);
}
catch (Exception e)
{
scope.Failed(e);
throw;
}
}

Response<ManifestWrapper> response = await _restClient.GetManifestAsync(_repositoryName, tagOrDigest, accept, cancellationToken).ConfigureAwait(false);
Response rawResponse = response.GetRawResponse();
private async Task<Response<GetManifestResult>> GetManifestInternalAsync(string reference, bool async, CancellationToken cancellationToken)
{
string accept = GetAcceptHeader();

rawResponse.Headers.TryGetValue("Docker-Content-Digest", out var digest);
rawResponse.Headers.TryGetValue("Content-Type", out string contentType);
Response<ManifestWrapper> response = async ?
await _restClient.GetManifestAsync(_repositoryName, reference, accept, cancellationToken).ConfigureAwait(false) :
_restClient.GetManifest(_repositoryName, reference, accept, cancellationToken);
Response rawResponse = response.GetRawResponse();

var contentDigest = BlobHelper.ComputeDigest(rawResponse.ContentStream);
CheckManifestSize(rawResponse);

if (ReferenceIsDigest(tagOrDigest))
{
BlobHelper.ValidateDigest(contentDigest, tagOrDigest, BlobHelper.ManifestDigestDoestMatchRequestedMessage);
}
else
{
BlobHelper.ValidateDigest(contentDigest, digest);
}
rawResponse.Headers.TryGetValue("Docker-Content-Digest", out string responseHeaderDigest);
rawResponse.Headers.TryGetValue("Content-Type", out string contentType);

return Response.FromValue(new GetManifestResult(digest, contentType, rawResponse.Content), rawResponse);
}
catch (Exception e)
string computedDigest = BlobHelper.ComputeDigest(rawResponse.ContentStream);

BlobHelper.ValidateDigest(computedDigest, responseHeaderDigest);

if (ReferenceIsDigest(reference))
{
scope.Failed(e);
throw;
BlobHelper.ValidateDigest(computedDigest, reference, BlobHelper.ManifestDigestDoestMatchRequestedMessage);
}

return Response.FromValue(new GetManifestResult(responseHeaderDigest, contentType, rawResponse.Content), rawResponse);
}

private static string GetAcceptHeader()
Expand Down Expand Up @@ -671,6 +675,30 @@ private static bool ReferenceIsDigest(string reference)
return reference.StartsWith("sha256:", StringComparison.OrdinalIgnoreCase);
}

private static void CheckContentLength(Response response)
{
if (response.Headers.ContentLength == null ||
response.Headers.ContentLength <= 0)
{
throw new RequestFailedException(response.Status, InvalidContentLengthMessage);
}
}

private static void CheckManifestSize(Response response)
{
// This check is to address part of the service threat model.
// If a manifest does not have a proper content length or is too big,
// it indicates a malicious or faulty service and should not be trusted.
CheckContentLength(response);

int? size = response.Headers.ContentLength;

if (size > MaxManifestSize)
{
throw new RequestFailedException(response.Status, "Manifest size is bigger than max allowed size of 4MB.");
}
}

/// <summary>
/// Download a container registry blob.
/// This API is a prefered way to fetch blobs that can fit into memory.
Expand Down Expand Up @@ -735,14 +763,17 @@ private async Task<Response<DownloadRegistryBlobResult>> DownloadBlobContentInte
await _blobRestClient.GetBlobAsync(_repositoryName, digest, cancellationToken).ConfigureAwait(false) :
_blobRestClient.GetBlob(_repositoryName, digest, cancellationToken);

Response response = blobResult.GetRawResponse();
CheckContentLength(response);

BinaryData data = async ?
await BinaryData.FromStreamAsync(blobResult.Value, cancellationToken).ConfigureAwait(false) :
BinaryData.FromStream(blobResult.Value);

string contentDigest = BlobHelper.ComputeDigest(data);
BlobHelper.ValidateDigest(contentDigest, digest);
BlobHelper.ValidateDigest(contentDigest, digest, BlobHelper.ContentDigestDoesntMatchRequestedMessage);

return Response.FromValue(new DownloadRegistryBlobResult(digest, data), blobResult.GetRawResponse());
return Response.FromValue(new DownloadRegistryBlobResult(digest, data), response);
}

/// <summary>
Expand Down Expand Up @@ -837,6 +868,9 @@ private async Task<Response<DownloadRegistryBlobStreamingResult>> DownloadBlobSt
await _blobRestClient.GetBlobAsync(_repositoryName, digest, cancellationToken).ConfigureAwait(false) :
_blobRestClient.GetBlob(_repositoryName, digest, cancellationToken);

Response response = blobResult.GetRawResponse();
CheckContentLength(response);

// Wrap the response Content in a RetriableStream so we
// can return it before it's finished downloading, but still
// allow retrying if it fails.
Expand All @@ -849,7 +883,7 @@ await _blobRestClient.GetBlobAsync(_repositoryName, digest, cancellationToken).C

ValidatingStream stream = new(retriableStream, (int)blobResult.Headers.ContentLength.Value, digest);

return Response.FromValue(new DownloadRegistryBlobStreamingResult(digest, stream), blobResult.GetRawResponse());
return Response.FromValue(new DownloadRegistryBlobStreamingResult(digest, stream), response);
}

/// <summary>
Expand Down Expand Up @@ -988,7 +1022,7 @@ private async Task<Response> DownloadBlobToInternalAsync(string digest, Stream d
using SHA256 sha256 = SHA256.Create();

long blobBytes = 0;
long? blobLength = default;
long? blobSize = default;

try
{
Expand All @@ -997,16 +1031,16 @@ private async Task<Response> DownloadBlobToInternalAsync(string digest, Stream d
do
{
// Request a chunk
long requestLength = blobLength.HasValue ?
(int)Math.Min(blobLength.Value - blobBytes, options.MaxChunkSize) :
long requestLength = blobSize.HasValue ?
(int)Math.Min(blobSize.Value - blobBytes, options.MaxChunkSize) :
options.MaxChunkSize;
string requestRange = new HttpRange(blobBytes, requestLength).ToString();

var getChunkResponse = async ?
ResponseWithHeaders<Stream, ContainerRegistryBlobGetChunkHeaders> getChunkResponse = async ?
await _blobRestClient.GetChunkAsync(_repositoryName, digest, requestRange, cancellationToken).ConfigureAwait(false) :
_blobRestClient.GetChunk(_repositoryName, digest, requestRange, cancellationToken);

blobLength ??= GetBlobLengthFromContentRange(getChunkResponse.Headers.ContentRange);
blobSize ??= GetBlobSize(getChunkResponse.GetRawResponse());

int chunkLength = (int)getChunkResponse.Headers.ContentLength.Value;
Stream responseStream = getChunkResponse.Value;
Expand Down Expand Up @@ -1037,12 +1071,12 @@ await responseStream.ReadAsync(buffer, chunkBytes, chunkLength - chunkBytes, can
blobBytes += chunkBytes;
result = getChunkResponse.GetRawResponse();
}
while (blobBytes < blobLength.Value);
while (blobBytes < blobSize.Value);

// Complete hash computation.
sha256.TransformFinalBlock(buffer, 0, 0);
string computedDigest = BlobHelper.FormatDigest(sha256.Hash);
BlobHelper.ValidateDigest(computedDigest, digest);
BlobHelper.ValidateDigest(computedDigest, digest, BlobHelper.ContentDigestDoesntMatchRequestedMessage);

if (async)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ private void ProcessIncrement(byte[] buffer, int offset, int length)
{
_sha256.TransformFinalBlock(Array.Empty<byte>(), 0, 0);
string computedDigest = BlobHelper.FormatDigest(_sha256.Hash);
BlobHelper.ValidateDigest(computedDigest, _digest);
BlobHelper.ValidateDigest(computedDigest, _digest, BlobHelper.ContentDigestDoesntMatchRequestedMessage);
_validated = true;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -523,7 +523,7 @@ public async Task CanDownloadBlobStreaming()

// Act
Response<DownloadRegistryBlobStreamingResult> downloadResult = await client.DownloadBlobStreamingAsync(digest);
Stream downloadedStream = downloadResult.Value.Content;
using Stream downloadedStream = downloadResult.Value.Content;
BinaryData content = BinaryData.FromStream(downloadedStream);

// Assert
Expand All @@ -532,7 +532,6 @@ public async Task CanDownloadBlobStreaming()
Assert.AreEqual(data, content.ToArray());

// Clean up
downloadedStream.Dispose();
await client.DeleteBlobAsync(digest);
}

Expand Down
Loading