diff --git a/sdk/core/Azure.Core/api/Azure.Core.net461.cs b/sdk/core/Azure.Core/api/Azure.Core.net461.cs index c0b60287caf0..9fe50a17e2e9 100644 --- a/sdk/core/Azure.Core/api/Azure.Core.net461.cs +++ b/sdk/core/Azure.Core/api/Azure.Core.net461.cs @@ -216,15 +216,18 @@ public static partial class Names { public static string Accept { get { throw null; } } public static string Authorization { get { throw null; } } + public static string ContentDisposition { get { throw null; } } public static string ContentLength { get { throw null; } } public static string ContentType { get { throw null; } } public static string Date { get { throw null; } } public static string ETag { get { throw null; } } + public static string Host { get { throw null; } } public static string IfMatch { get { throw null; } } public static string IfModifiedSince { get { throw null; } } public static string IfNoneMatch { get { throw null; } } public static string IfUnmodifiedSince { get { throw null; } } public static string Range { get { throw null; } } + public static string Referer { get { throw null; } } public static string UserAgent { get { throw null; } } public static string XMsDate { get { throw null; } } public static string XMsRange { get { throw null; } } diff --git a/sdk/core/Azure.Core/api/Azure.Core.netstandard2.0.cs b/sdk/core/Azure.Core/api/Azure.Core.netstandard2.0.cs index c0b60287caf0..9fe50a17e2e9 100644 --- a/sdk/core/Azure.Core/api/Azure.Core.netstandard2.0.cs +++ b/sdk/core/Azure.Core/api/Azure.Core.netstandard2.0.cs @@ -216,15 +216,18 @@ public static partial class Names { public static string Accept { get { throw null; } } public static string Authorization { get { throw null; } } + public static string ContentDisposition { get { throw null; } } public static string ContentLength { get { throw null; } } public static string ContentType { get { throw null; } } public static string Date { get { throw null; } } public static string ETag { get { throw null; } } + public static string Host { get { throw null; } } public static string IfMatch { get { throw null; } } public static string IfModifiedSince { get { throw null; } } public static string IfNoneMatch { get { throw null; } } public static string IfUnmodifiedSince { get { throw null; } } public static string Range { get { throw null; } } + public static string Referer { get { throw null; } } public static string UserAgent { get { throw null; } } public static string XMsDate { get { throw null; } } public static string XMsRange { get { throw null; } } diff --git a/sdk/core/Azure.Core/src/HttpHeader.cs b/sdk/core/Azure.Core/src/HttpHeader.cs index 5f075a98a0d4..cf748bfb3f45 100644 --- a/sdk/core/Azure.Core/src/HttpHeader.cs +++ b/sdk/core/Azure.Core/src/HttpHeader.cs @@ -131,9 +131,19 @@ public static class Names /// Returns. "If-Unmodified-Since" /// public static string IfUnmodifiedSince => "If-Unmodified-Since"; + /// + /// Returns. "Referer" + /// + public static string Referer => "Referer"; + /// + /// Returns. "Host" + /// + public static string Host => "Host"; - internal static string Referer => "Referer"; - internal static string Host => "Host"; + /// + /// Returns "Content-Disposition" + /// + public static string ContentDisposition => "Content-Disposition"; } #pragma warning disable CA1034 // Nested types should not be visible diff --git a/sdk/core/Azure.Core/src/Shared/Multipart/BufferedReadStream.cs b/sdk/core/Azure.Core/src/Shared/Multipart/BufferedReadStream.cs new file mode 100644 index 000000000000..e2f455765a02 --- /dev/null +++ b/sdk/core/Azure.Core/src/Shared/Multipart/BufferedReadStream.cs @@ -0,0 +1,440 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// Copied from https://github.com/aspnet/AspNetCore/tree/master/src/Http/WebUtilities/src + +using System; +using System.Buffers; +using System.IO; +using System.Text; +using System.Threading; +using System.Threading.Tasks; + +#pragma warning disable CA1305 // ToString Locale +#pragma warning disable CA1822 // Make member static +#pragma warning disable IDE0016 // Simplify null check +#pragma warning disable IDE0036 // Modifiers not ordered +#pragma warning disable IDE0054 // Use compound assignment +#pragma warning disable IDE0059 // Unnecessary assignment + +namespace Azure.Core +{ + /// + /// A Stream that wraps another stream and allows reading lines. + /// The data is buffered in memory. + /// + internal class BufferedReadStream : Stream + { + private const byte CR = (byte)'\r'; + private const byte LF = (byte)'\n'; + + private readonly Stream _inner; + private readonly byte[] _buffer; + private readonly ArrayPool _bytePool; + private int _bufferOffset = 0; + private int _bufferCount = 0; + private bool _disposed; + + /// + /// Creates a new stream. + /// + /// The stream to wrap. + /// Size of buffer in bytes. + public BufferedReadStream(Stream inner, int bufferSize) + : this(inner, bufferSize, ArrayPool.Shared) + { + } + + /// + /// Creates a new stream. + /// + /// The stream to wrap. + /// Size of buffer in bytes. + /// ArrayPool for the buffer. + public BufferedReadStream(Stream inner, int bufferSize, ArrayPool bytePool) + { + if (inner == null) + { + throw new ArgumentNullException(nameof(inner)); + } + + _inner = inner; + _bytePool = bytePool; + _buffer = bytePool.Rent(bufferSize); + } + + /// + /// The currently buffered data. + /// + public ArraySegment BufferedData + { + get { return new ArraySegment(_buffer, _bufferOffset, _bufferCount); } + } + + /// + public override bool CanRead + { + get { return _inner.CanRead || _bufferCount > 0; } + } + + /// + public override bool CanSeek + { + get { return _inner.CanSeek; } + } + + /// + public override bool CanTimeout + { + get { return _inner.CanTimeout; } + } + + /// + public override bool CanWrite + { + get { return _inner.CanWrite; } + } + + /// + public override long Length + { + get { return _inner.Length; } + } + + /// + public override long Position + { + get { return _inner.Position - _bufferCount; } + set + { + if (value < 0) + { + throw new ArgumentOutOfRangeException(nameof(value), value, "Position must be positive."); + } + if (value == Position) + { + return; + } + + // Backwards? + if (value <= _inner.Position) + { + // Forward within the buffer? + var innerOffset = (int)(_inner.Position - value); + if (innerOffset <= _bufferCount) + { + // Yes, just skip some of the buffered data + _bufferOffset += innerOffset; + _bufferCount -= innerOffset; + } + else + { + // No, reset the buffer + _bufferOffset = 0; + _bufferCount = 0; + _inner.Position = value; + } + } + else + { + // Forward, reset the buffer + _bufferOffset = 0; + _bufferCount = 0; + _inner.Position = value; + } + } + } + + /// + public override long Seek(long offset, SeekOrigin origin) + { + if (origin == SeekOrigin.Begin) + { + Position = offset; + } + else if (origin == SeekOrigin.Current) + { + Position = Position + offset; + } + else // if (origin == SeekOrigin.End) + { + Position = Length + offset; + } + return Position; + } + + /// + public override void SetLength(long value) + { + _inner.SetLength(value); + } + + /// + protected override void Dispose(bool disposing) + { + if (!_disposed) + { + _disposed = true; + _bytePool.Return(_buffer); + + if (disposing) + { + _inner.Dispose(); + } + } + } + + /// + public override void Flush() + { + _inner.Flush(); + } + + /// + public override Task FlushAsync(CancellationToken cancellationToken) + { + return _inner.FlushAsync(cancellationToken); + } + + /// + public override void Write(byte[] buffer, int offset, int count) + { + _inner.Write(buffer, offset, count); + } + + /// + public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + return _inner.WriteAsync(buffer, offset, count, cancellationToken); + } + + /// + public override int Read(byte[] buffer, int offset, int count) + { + ValidateBuffer(buffer, offset, count); + + // Drain buffer + if (_bufferCount > 0) + { + int toCopy = Math.Min(_bufferCount, count); + Buffer.BlockCopy(_buffer, _bufferOffset, buffer, offset, toCopy); + _bufferOffset += toCopy; + _bufferCount -= toCopy; + return toCopy; + } + + return _inner.Read(buffer, offset, count); + } + + /// + public async override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + ValidateBuffer(buffer, offset, count); + + // Drain buffer + if (_bufferCount > 0) + { + int toCopy = Math.Min(_bufferCount, count); + Buffer.BlockCopy(_buffer, _bufferOffset, buffer, offset, toCopy); + _bufferOffset += toCopy; + _bufferCount -= toCopy; + return toCopy; + } + + return await _inner.ReadAsync(buffer, offset, count, cancellationToken).ConfigureAwait(false); + } + + /// + /// Ensures that the buffer is not empty. + /// + /// Returns true if the buffer is not empty; false otherwise. + public bool EnsureBuffered() + { + if (_bufferCount > 0) + { + return true; + } + // Downshift to make room + _bufferOffset = 0; + _bufferCount = _inner.Read(_buffer, 0, _buffer.Length); + return _bufferCount > 0; + } + + /// + /// Ensures that the buffer is not empty. + /// + /// Cancellation token. + /// Returns true if the buffer is not empty; false otherwise. + public async Task EnsureBufferedAsync(CancellationToken cancellationToken) + { + if (_bufferCount > 0) + { + return true; + } + // Downshift to make room + _bufferOffset = 0; + _bufferCount = await _inner.ReadAsync(_buffer, 0, _buffer.Length, cancellationToken).ConfigureAwait(false); + return _bufferCount > 0; + } + + /// + /// Ensures that a minimum amount of buffered data is available. + /// + /// Minimum amount of buffered data. + /// Returns true if the minimum amount of buffered data is available; false otherwise. + public bool EnsureBuffered(int minCount) + { + if (minCount > _buffer.Length) + { + throw new ArgumentOutOfRangeException(nameof(minCount), minCount, "The value must be smaller than the buffer size: " + _buffer.Length.ToString()); + } + while (_bufferCount < minCount) + { + // Downshift to make room + if (_bufferOffset > 0) + { + if (_bufferCount > 0) + { + Buffer.BlockCopy(_buffer, _bufferOffset, _buffer, 0, _bufferCount); + } + _bufferOffset = 0; + } + int read = _inner.Read(_buffer, _bufferOffset + _bufferCount, _buffer.Length - _bufferCount - _bufferOffset); + _bufferCount += read; + if (read == 0) + { + return false; + } + } + return true; + } + + /// + /// Ensures that a minimum amount of buffered data is available. + /// + /// Minimum amount of buffered data. + /// Cancellation token. + /// Returns true if the minimum amount of buffered data is available; false otherwise. + public async Task EnsureBufferedAsync(int minCount, CancellationToken cancellationToken) + { + if (minCount > _buffer.Length) + { + throw new ArgumentOutOfRangeException(nameof(minCount), minCount, "The value must be smaller than the buffer size: " + _buffer.Length.ToString()); + } + while (_bufferCount < minCount) + { + // Downshift to make room + if (_bufferOffset > 0) + { + if (_bufferCount > 0) + { + Buffer.BlockCopy(_buffer, _bufferOffset, _buffer, 0, _bufferCount); + } + _bufferOffset = 0; + } + int read = await _inner.ReadAsync(_buffer, _bufferOffset + _bufferCount, _buffer.Length - _bufferCount - _bufferOffset, cancellationToken).ConfigureAwait(false); + _bufferCount += read; + if (read == 0) + { + return false; + } + } + return true; + } + + /// + /// Reads a line. A line is defined as a sequence of characters followed by + /// a carriage return immediately followed by a line feed. The resulting string does not + /// contain the terminating carriage return and line feed. + /// + /// Maximum allowed line length. + /// A line. + public string ReadLine(int lengthLimit) + { + CheckDisposed(); + using (var builder = new MemoryStream(200)) + { + bool foundCR = false, foundCRLF = false; + + while (!foundCRLF && EnsureBuffered()) + { + if (builder.Length > lengthLimit) + { + throw new InvalidDataException($"Line length limit {lengthLimit} exceeded."); + } + ProcessLineChar(builder, ref foundCR, ref foundCRLF); + } + + return DecodeLine(builder, foundCRLF); + } + } + + /// + /// Reads a line. A line is defined as a sequence of characters followed by + /// a carriage return immediately followed by a line feed. The resulting string does not + /// contain the terminating carriage return and line feed. + /// + /// Maximum allowed line length. + /// Cancellation token. + /// A line. + public async Task ReadLineAsync(int lengthLimit, CancellationToken cancellationToken) + { + CheckDisposed(); + using (var builder = new MemoryStream(200)) + { + bool foundCR = false, foundCRLF = false; + + while (!foundCRLF && await EnsureBufferedAsync(cancellationToken).ConfigureAwait(false)) + { + if (builder.Length > lengthLimit) + { + throw new InvalidDataException($"Line length limit {lengthLimit} exceeded."); + } + + ProcessLineChar(builder, ref foundCR, ref foundCRLF); + } + + return DecodeLine(builder, foundCRLF); + } + } + + private void ProcessLineChar(MemoryStream builder, ref bool foundCR, ref bool foundCRLF) + { + var b = _buffer[_bufferOffset]; + builder.WriteByte(b); + _bufferOffset++; + _bufferCount--; + if (b == LF && foundCR) + { + foundCRLF = true; + return; + } + foundCR = b == CR; + } + + private string DecodeLine(MemoryStream builder, bool foundCRLF) + { + // Drop the final CRLF, if any + var length = foundCRLF ? builder.Length - 2 : builder.Length; + return Encoding.UTF8.GetString(builder.ToArray(), 0, (int)length); + } + + private void CheckDisposed() + { + if (_disposed) + { + throw new ObjectDisposedException(nameof(BufferedReadStream)); + } + } + + private void ValidateBuffer(byte[] buffer, int offset, int count) + { + // Delegate most of our validation. + var ignored = new ArraySegment(buffer, offset, count); + if (count == 0) + { + throw new ArgumentOutOfRangeException(nameof(count), "The value must be greater than zero."); + } + } + } +} diff --git a/sdk/core/Azure.Core/src/Shared/Multipart/KeyValueAccumulator.cs b/sdk/core/Azure.Core/src/Shared/Multipart/KeyValueAccumulator.cs new file mode 100644 index 000000000000..2db059026d77 --- /dev/null +++ b/sdk/core/Azure.Core/src/Shared/Multipart/KeyValueAccumulator.cs @@ -0,0 +1,90 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// Copied from https://github.com/aspnet/AspNetCore/tree/master/src/Http/WebUtilities/src + +using System; +using System.Collections.Generic; + +#pragma warning disable IDE0008 // Use explicit type +#pragma warning disable IDE0018 // Inline declaration +#pragma warning disable IDE0034 // default can be simplified + +namespace Azure.Core +{ + internal struct KeyValueAccumulator + { + private Dictionary _accumulator; + private Dictionary> _expandingAccumulator; + + public void Append(string key, string value) + { + if (_accumulator == null) + { + _accumulator = new Dictionary(StringComparer.OrdinalIgnoreCase); + } + + string[] values; + if (_accumulator.TryGetValue(key, out values)) + { + if (values.Length == 0) + { + // Marker entry for this key to indicate entry already in expanding list dictionary + _expandingAccumulator[key].Add(value); + } + else if (values.Length == 1) + { + // Second value for this key + _accumulator[key] = new string[] { values[0], value }; + } + else + { + // Third value for this key + // Add zero count entry and move to data to expanding list dictionary + _accumulator[key] = null; + + if (_expandingAccumulator == null) + { + _expandingAccumulator = new Dictionary>(StringComparer.OrdinalIgnoreCase); + } + + // Already 3 entries so use starting allocated as 8; then use List's expansion mechanism for more + var list = new List(8); + + list.Add(values[0]); + list.Add(values[1]); + list.Add(value); + + _expandingAccumulator[key] = list; + } + } + else + { + // First value for this key + _accumulator[key] = new[] { value }; + } + + ValueCount++; + } + + public bool HasValues => ValueCount > 0; + + public int KeyCount => _accumulator?.Count ?? 0; + + public int ValueCount { get; private set; } + + public Dictionary GetResults() + { + if (_expandingAccumulator != null) + { + // Coalesce count 3+ multi-value entries into _accumulator dictionary + foreach (var entry in _expandingAccumulator) + { + _accumulator[entry.Key] = entry.Value.ToArray(); + } + } + + return _accumulator ?? new Dictionary(0, StringComparer.OrdinalIgnoreCase); + } + } +} diff --git a/sdk/core/Azure.Core/src/Shared/Multipart/MemoryResponse.cs b/sdk/core/Azure.Core/src/Shared/Multipart/MemoryResponse.cs new file mode 100644 index 000000000000..b63f440ece2e --- /dev/null +++ b/sdk/core/Azure.Core/src/Shared/Multipart/MemoryResponse.cs @@ -0,0 +1,164 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Text; + +#nullable disable + +namespace Azure.Core +{ + /// + /// A Response that can be constructed in memory without being tied to a + /// live request. + /// + internal class MemoryResponse : Response + { + private const int NoStatusCode = 0; + private const string XmsClientRequestIdName = "x-ms-client-request-id"; + + /// + /// The Response . + /// + private int _status = NoStatusCode; + + /// + /// The Response . + /// + private string _reasonPhrase = null; + + /// + /// The . + /// + private readonly IDictionary> _headers = + new Dictionary>(StringComparer.OrdinalIgnoreCase); + + /// + public override int Status => _status; + + /// + public override string ReasonPhrase => _reasonPhrase; + + /// + public override Stream ContentStream { get; set; } + + /// + public override string ClientRequestId + { + get => TryGetHeader(XmsClientRequestIdName, out string id) ? id : null; + set => SetHeader(XmsClientRequestIdName, value); + } + + /// + /// Set the Response . + /// + /// The Response status. + public void SetStatus(int status) => _status = status; + + /// + /// Set the Response . + /// + /// The Response ReasonPhrase. + public void SetReasonPhrase(string reasonPhrase) => _reasonPhrase = reasonPhrase; + + /// + /// Set the Response . + /// + /// The response content. + public void SetContent(byte[] content) => ContentStream = new MemoryStream(content); + + /// + /// Set the Response . + /// + /// The response content. + public void SetContent(string content) => SetContent(Encoding.UTF8.GetBytes(content)); + + /// + /// Dispose the Response. + /// + public override void Dispose() => ContentStream?.Dispose(); + + /// + /// Set the value of a response header (and overwrite any existing + /// values). + /// + /// The name of the response header. + /// The response header value. + public void SetHeader(string name, string value) => + SetHeader(name, new List { value }); + + /// + /// Set the values of a response header (and overwrite any existing + /// values). + /// + /// The name of the response header. + /// The response header values. + public void SetHeader(string name, IEnumerable values) => + _headers[name] = values.ToList(); + + /// + /// Add a response header value. + /// + /// The name of the response header. + /// The response header value. + public void AddHeader(string name, string value) + { + if (!_headers.TryGetValue(name, out List values)) + { + _headers[name] = values = new List(); + } + values.Add(value); + } + + /// +#if HAS_INTERNALS_VISIBLE_CORE + internal +#endif + protected override bool ContainsHeader(string name) => + _headers.ContainsKey(name); + + /// +#if HAS_INTERNALS_VISIBLE_CORE + internal +#endif + protected override IEnumerable EnumerateHeaders() => + _headers.Select(header => new HttpHeader(header.Key, JoinHeaderValues(header.Value))); + + /// +#if HAS_INTERNALS_VISIBLE_CORE + internal +#endif + protected override bool TryGetHeader(string name, out string value) + { + if (_headers.TryGetValue(name, out List headers)) + { + value = JoinHeaderValues(headers); + return true; + } + value = null; + return false; + } + + /// +#if HAS_INTERNALS_VISIBLE_CORE + internal +#endif + protected override bool TryGetHeaderValues(string name, out IEnumerable values) + { + bool found = _headers.TryGetValue(name, out List headers); + values = headers; + return found; + } + + /// + /// Join multiple header values together with a comma. + /// + /// The header values. + /// A single joined value. + private static string JoinHeaderValues(IEnumerable values) => + string.Join(",", values); + } +} diff --git a/sdk/core/Azure.Core/src/Shared/Multipart/Multipart.cs b/sdk/core/Azure.Core/src/Shared/Multipart/Multipart.cs new file mode 100644 index 000000000000..6047d8bb7f13 --- /dev/null +++ b/sdk/core/Azure.Core/src/Shared/Multipart/Multipart.cs @@ -0,0 +1,223 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.Globalization; +using System.IO; +using System.Threading; +using System.Threading.Tasks; + +#nullable disable + +namespace Azure.Core +{ + /// + /// Provides support for creating and parsing multipart/mixed content. + /// This is implementing a couple of layered standards as mentioned at + /// https://docs.microsoft.com/en-us/rest/api/storageservices/blob-batch and + /// https://docs.microsoft.com/en-us/rest/api/storageservices/performing-entity-group-transactions + /// including https://www.odata.org/documentation/odata-version-3-0/batch-processing/ + /// and https://www.ietf.org/rfc/rfc2046.txt. + /// + internal static class Multipart + { + private const int KB = 1024; + private const int ResponseLineSize = 4 * KB; + private const string MultipartContentTypePrefix = "multipart/mixed; boundary="; + private const string ContentIdName = "Content-ID"; + internal static InvalidOperationException InvalidBatchContentType(string contentType) => + new InvalidOperationException($"Expected {HttpHeader.Names.ContentType} to start with {MultipartContentTypePrefix} but received {contentType}"); + + internal static InvalidOperationException InvalidHttpStatusLine(string statusLine) => + new InvalidOperationException($"Expected an HTTP status line, not {statusLine}"); + + internal static InvalidOperationException InvalidHttpHeaderLine(string headerLine) => + new InvalidOperationException($"Expected an HTTP header line, not {headerLine}"); + + /// + /// Parse a multipart/mixed response body into several responses. + /// + /// The response content. + /// The response content type. + /// + /// Whether to invoke the operation asynchronously. + /// + /// + /// Optional to propagate notifications + /// that the operation should be cancelled. + /// + /// The parsed s. + internal static async Task ParseAsync( + Stream batchContent, + string batchContentType, + bool async, + CancellationToken cancellationToken) + { + // Get the batch boundary + if (!GetBoundary(batchContentType, out string batchBoundary)) + { + throw InvalidBatchContentType(batchContentType); + } + + // Collect the responses in a dictionary (in case the Content-ID + // values come back out of order) + Dictionary responses = new Dictionary(); + + // Collect responses without a Content-ID in a List + List responsesWithoutId = new List(); + + // Read through the batch body one section at a time until the + // reader returns null + MultipartReader reader = new MultipartReader(batchBoundary, batchContent); + for (MultipartSection section = await reader.GetNextSectionAsync(async, cancellationToken).ConfigureAwait(false); + section != null; + section = await reader.GetNextSectionAsync(async, cancellationToken).ConfigureAwait(false)) + { + bool contentIdFound = true; + if (section.Headers.TryGetValue(HttpHeader.Names.ContentType, out string [] contentTypeValues) && + contentTypeValues.Length == 1 && + GetBoundary(contentTypeValues[0], out string subBoundary)) + { + reader = new MultipartReader(subBoundary, section.Body); + continue; + } + // Get the Content-ID header + if (!section.Headers.TryGetValue(ContentIdName, out string [] contentIdValues) || + contentIdValues.Length != 1 || + !int.TryParse(contentIdValues[0], out int contentId)) + { + // If the header wasn't found, this is either: + // - a failed request with the details being sent as the first sub-operation + // - a tables batch request, which does not utilize Content-ID headers + // so we default the Content-ID to 0 and track that no Content-ID was found. + contentId = 0; + contentIdFound = false; + } + + // Build a response + MemoryResponse response = new MemoryResponse(); + if (contentIdFound) + { + // track responses by Content-ID + responses[contentId] = response; + } + else + { + // track responses without a Content-ID + responsesWithoutId.Add(response); + } + + // We're going to read the section's response body line by line + using var body = new BufferedReadStream(section.Body, ResponseLineSize); + + // The first line is the status like "HTTP/1.1 202 Accepted" + string line = await body.ReadLineAsync(async, cancellationToken).ConfigureAwait(false); + string[] status = line.Split(new char[] { ' ' }, 3, StringSplitOptions.RemoveEmptyEntries); + if (status.Length != 3) + { + throw InvalidHttpStatusLine(line); + } + response.SetStatus(int.Parse(status[1], CultureInfo.InvariantCulture)); + response.SetReasonPhrase(status[2]); + + // Continue reading headers until we reach a blank line + line = await body.ReadLineAsync(async, cancellationToken).ConfigureAwait(false); + while (!string.IsNullOrEmpty(line)) + { + // Split the header into the name and value + int splitIndex = line.IndexOf(':'); + if (splitIndex <= 0) + { + throw InvalidHttpHeaderLine(line); + } + var name = line.Substring(0, splitIndex); + var value = line.Substring(splitIndex + 1, line.Length - splitIndex - 1).Trim(); + response.AddHeader(name, value); + + line = await body.ReadLineAsync(async, cancellationToken).ConfigureAwait(false); + } + + // Copy the rest of the body as the response content + var responseContent = new MemoryStream(); + if (async) + { + await body.CopyToAsync(responseContent).ConfigureAwait(false); + } + else + { + body.CopyTo(responseContent); + } + responseContent.Seek(0, SeekOrigin.Begin); + response.ContentStream = responseContent; + } + + // Collect the responses and order by Content-ID, when available. + // Otherwise collect them as we received them. + Response[] ordered = new Response[responses.Count + responsesWithoutId.Count]; + for (int i = 0; i < responses.Count; i++) + { + ordered[i] = responses[i]; + } + for (int i = responses.Count; i < ordered.Length; i++) + { + ordered[i] = responsesWithoutId[i - responses.Count]; + } + return ordered; + } + + + /// + /// Read the next line of text. + /// + /// The stream to read from. + /// + /// Whether to invoke the operation asynchronously. + /// + /// + /// Optional to propagate notifications + /// that the operation should be cancelled. + /// + /// The next line of text. + internal static async Task ReadLineAsync( + this BufferedReadStream stream, + bool async, + CancellationToken cancellationToken) => + async ? + await stream.ReadLineAsync(ResponseLineSize, cancellationToken).ConfigureAwait(false) : + stream.ReadLine(ResponseLineSize); + + /// + /// Read the next multipart section. + /// + /// The reader to parse with. + /// + /// Whether to invoke the operation asynchronously. + /// + /// + /// Optional to propagate notifications + /// that the operation should be cancelled. + /// + /// The next multipart section. + internal static async Task GetNextSectionAsync( + this MultipartReader reader, + bool async, + CancellationToken cancellationToken) => + async ? + await reader.ReadNextSectionAsync(cancellationToken).ConfigureAwait(false) : +#pragma warning disable AZC0102 // Do not use GetAwaiter().GetResult(). Use the TaskExtensions.EnsureCompleted() extension method instead. + reader.ReadNextSectionAsync(cancellationToken).GetAwaiter().GetResult(); // #7972: Decide if we need a proper sync API here +#pragma warning restore AZC0102 // Do not use GetAwaiter().GetResult(). Use the TaskExtensions.EnsureCompleted() extension method instead. + + private static bool GetBoundary(string contentType, out string batchBoundary) + { + if (contentType == null || !contentType.StartsWith(MultipartContentTypePrefix, StringComparison.Ordinal)) + { + batchBoundary = null; + return false; + } + batchBoundary = contentType.Substring(MultipartContentTypePrefix.Length); + return true; + } + } +} diff --git a/sdk/core/Azure.Core/src/Shared/Multipart/MultipartBoundary.cs b/sdk/core/Azure.Core/src/Shared/Multipart/MultipartBoundary.cs new file mode 100644 index 000000000000..bba72d47719b --- /dev/null +++ b/sdk/core/Azure.Core/src/Shared/Multipart/MultipartBoundary.cs @@ -0,0 +1,74 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// Copied from https://github.com/aspnet/AspNetCore/tree/master/src/Http/WebUtilities/src + +using System; +using System.Text; + +namespace Azure.Core +{ + internal class MultipartBoundary + { + private readonly int[] _skipTable = new int[256]; + private readonly string _boundary; + private bool _expectLeadingCrlf; + + public MultipartBoundary(string boundary, bool expectLeadingCrlf = true) + { + if (boundary == null) + { + throw new ArgumentNullException(nameof(boundary)); + } + + _boundary = boundary; + _expectLeadingCrlf = expectLeadingCrlf; + Initialize(_boundary, _expectLeadingCrlf); + } + + private void Initialize(string boundary, bool expectLeadingCrlf) + { + if (expectLeadingCrlf) + { + BoundaryBytes = Encoding.UTF8.GetBytes("\r\n--" + boundary); + } + else + { + BoundaryBytes = Encoding.UTF8.GetBytes("--" + boundary); + } + FinalBoundaryLength = BoundaryBytes.Length + 2; // Include the final '--' terminator. + + var length = BoundaryBytes.Length; + for (var i = 0; i < _skipTable.Length; ++i) + { + _skipTable[i] = length; + } + for (var i = 0; i < length; ++i) + { + _skipTable[BoundaryBytes[i]] = Math.Max(1, length - 1 - i); + } + } + + public int GetSkipValue(byte input) + { + return _skipTable[input]; + } + + public bool ExpectLeadingCrlf + { + get { return _expectLeadingCrlf; } + set + { + if (value != _expectLeadingCrlf) + { + _expectLeadingCrlf = value; + Initialize(_boundary, _expectLeadingCrlf); + } + } + } + + public byte[] BoundaryBytes { get; private set; } = default!; // This gets initialized as part of Initialize called from in the ctor. + + public int FinalBoundaryLength { get; private set; } + } +} diff --git a/sdk/core/Azure.Core/src/MultipartFormDataContent.cs b/sdk/core/Azure.Core/src/Shared/Multipart/MultipartContent.cs similarity index 80% rename from sdk/core/Azure.Core/src/MultipartFormDataContent.cs rename to sdk/core/Azure.Core/src/Shared/Multipart/MultipartContent.cs index e83bfc2d567b..3604f62cb2a6 100644 --- a/sdk/core/Azure.Core/src/MultipartFormDataContent.cs +++ b/sdk/core/Azure.Core/src/Shared/Multipart/MultipartContent.cs @@ -1,4 +1,4 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. using System; @@ -8,43 +8,58 @@ using System.Threading; using System.Threading.Tasks; +#nullable disable + namespace Azure.Core { /// /// Provides a container for content encoded using multipart/form-data MIME type. /// - internal class MultipartFormDataContent : RequestContent + internal class MultipartContent : RequestContent { #region Fields private const string CrLf = "\r\n"; - private const string FormData = "form-data"; + private const string ColonSP = ": "; private static readonly int s_crlfLength = GetEncodedLength(CrLf); private static readonly int s_dashDashLength = GetEncodedLength("--"); - private static readonly int s_colonSpaceLength = GetEncodedLength(": "); + private static readonly int s_colonSpaceLength = GetEncodedLength(ColonSP); private readonly List _nestedContent; + private readonly string _subtype; private readonly string _boundary; + internal readonly Dictionary _headers; #endregion Fields #region Construction - /// - /// Initializes a new instance of the class. - /// - public MultipartFormDataContent() : this(GetDefaultBoundary()) + public MultipartContent() + : this("mixed", GetDefaultBoundary()) + { } + + public MultipartContent(string subtype) + : this(subtype, GetDefaultBoundary()) { } /// - /// Initializes a new instance of the class. + /// Initializes a new instance of the class. /// + /// The multipart sub type. /// The boundary string for the multipart form data content. - public MultipartFormDataContent(string boundary) + public MultipartContent(string subtype, string boundary) { ValidateBoundary(boundary); - _boundary = boundary; + _subtype = subtype; + + // see https://www.ietf.org/rfc/rfc1521.txt page 29. + _boundary = boundary.Contains(":") ? $"\"{boundary}\"" : boundary; + _headers = new Dictionary + { + [HttpHeader.Names.ContentType] = $"multipart/{_subtype}; boundary={_boundary}" + }; + _nestedContent = new List(); } @@ -97,7 +112,7 @@ private static string GetDefaultBoundary() /// The request. public void ApplyToRequest(Request request) { - request.Headers.Add("Content-Type", $"multipart/form-data;boundary=\"{_boundary}\""); + request.Headers.Add(HttpHeader.Names.ContentType, $"multipart/{_subtype}; boundary={_boundary}"); } /// @@ -105,10 +120,10 @@ public void ApplyToRequest(Request request) /// get serialized to multipart/form-data MIME type. /// /// The Request content to add to the collection. - public void Add(RequestContent content) + public virtual void Add(RequestContent content) { Argument.AssertNotNull(content, nameof(content)); - AddInternal(content, null, null, null); + AddInternal(content, null); } /// @@ -117,69 +132,20 @@ public void Add(RequestContent content) /// /// The Request content to add to the collection. /// The headers to add to the collection. - public void Add(RequestContent content, Dictionary headers) + public virtual void Add(RequestContent content, Dictionary headers) { Argument.AssertNotNull(content, nameof(content)); Argument.AssertNotNull(headers, nameof(headers)); - AddInternal(content, headers, null, null); + AddInternal(content, headers); } - /// - /// Add HTTP content to a collection of RequestContent objects that - /// get serialized to multipart/form-data MIME type. - /// - /// The Request content to add to the collection. - /// The name for the request content to add. - /// The headers to add to the collection. - public void Add(RequestContent content, string name, Dictionary? headers) - { - Argument.AssertNotNull(content, nameof(content)); - Argument.AssertNotNullOrWhiteSpace(name, nameof(name)); - - AddInternal(content, headers, name, null); - } - - /// - /// Add HTTP content to a collection of RequestContent objects that - /// get serialized to multipart/form-data MIME type. - /// - /// The Request content to add to the collection. - /// The name for the request content to add. - /// The file name for the reuest content to add to the collection. - /// The headers to add to the collection. - public void Add(RequestContent content, string name, string fileName, Dictionary? headers) - { - Argument.AssertNotNull(content, nameof(content)); - Argument.AssertNotNullOrWhiteSpace(name, nameof(name)); - Argument.AssertNotNullOrWhiteSpace(fileName, nameof(fileName)); - - AddInternal(content, headers, name, fileName); - } - - private void AddInternal(RequestContent content, Dictionary? headers, string? name, string? fileName) + private void AddInternal(RequestContent content, Dictionary headers) { if (headers == null) { headers = new Dictionary(); } - - if (!headers.ContainsKey("Content-Disposition")) - { - var value = FormData; - - if (name != null) - { - value = value + "; name=" + name; - } - if (fileName != null) - { - value = value + "; filename=" + fileName; - } - - headers.Add("Content-Disposition", value); - } - _nestedContent.Add(new MultipartRequestContent(content, headers)); } @@ -188,7 +154,7 @@ private void AddInternal(RequestContent content, Dictionary? hea #region Dispose /// - /// Frees resources held by the object. + /// Frees resources held by the object. /// public override void Dispose() { @@ -197,7 +163,6 @@ public override void Dispose() content.RequestContent.Dispose(); } _nestedContent.Clear(); - } #endregion Dispose diff --git a/sdk/core/Azure.Core/src/Shared/Multipart/MultipartFormDataContent.cs b/sdk/core/Azure.Core/src/Shared/Multipart/MultipartFormDataContent.cs new file mode 100644 index 000000000000..b99ee6ef83a6 --- /dev/null +++ b/sdk/core/Azure.Core/src/Shared/Multipart/MultipartFormDataContent.cs @@ -0,0 +1,122 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.Collections.Generic; + +#nullable disable + +namespace Azure.Core +{ + /// + /// Provides a container for content encoded using multipart/form-data MIME type. + /// + internal class MultipartFormDataContent : MultipartContent + { + #region Fields + + private const string FormData = "form-data"; + + #endregion Fields + + #region Construction + + /// + /// Initializes a new instance of the class. + /// + public MultipartFormDataContent() : base(FormData) + { } + + /// + /// Initializes a new instance of the class. + /// + /// The boundary string for the multipart form data content. + public MultipartFormDataContent(string boundary) : base(FormData, boundary) + { } + + #endregion Construction + + /// + /// Add HTTP content to a collection of RequestContent objects that + /// get serialized to multipart/form-data MIME type. + /// + /// The Request content to add to the collection. + public override void Add(RequestContent content) + { + Argument.AssertNotNull(content, nameof(content)); + AddInternal(content, null, null, null); + } + + /// + /// Add HTTP content to a collection of RequestContent objects that + /// get serialized to multipart/form-data MIME type. + /// + /// The Request content to add to the collection. + /// The headers to add to the collection. + public override void Add(RequestContent content, Dictionary headers) + { + Argument.AssertNotNull(content, nameof(content)); + Argument.AssertNotNull(headers, nameof(headers)); + + AddInternal(content, headers, null, null); + } + + /// + /// Add HTTP content to a collection of RequestContent objects that + /// get serialized to multipart/form-data MIME type. + /// + /// The Request content to add to the collection. + /// The name for the request content to add. + /// The headers to add to the collection. + public void Add(RequestContent content, string name, Dictionary headers) + { + Argument.AssertNotNull(content, nameof(content)); + Argument.AssertNotNullOrWhiteSpace(name, nameof(name)); + + AddInternal(content, headers, name, null); + } + + /// + /// Add HTTP content to a collection of RequestContent objects that + /// get serialized to multipart/form-data MIME type. + /// + /// The Request content to add to the collection. + /// The name for the request content to add. + /// The file name for the request content to add to the collection. + /// The headers to add to the collection. + public void Add(RequestContent content, string name, string fileName, Dictionary headers) + { + Argument.AssertNotNull(content, nameof(content)); + Argument.AssertNotNullOrWhiteSpace(name, nameof(name)); + Argument.AssertNotNullOrWhiteSpace(fileName, nameof(fileName)); + + AddInternal(content, headers, name, fileName); + } + + private void AddInternal(RequestContent content, Dictionary headers, string name, string fileName) + { + if (headers == null) + { + headers = new Dictionary(); + } + + if (!headers.ContainsKey("Content-Disposition")) + { + var value = FormData; + + if (name != null) + { + value = value + "; name=" + name; + } + if (fileName != null) + { + value = value + "; filename=" + fileName; + } + + headers.Add("Content-Disposition", value); + } + + base.Add(content, headers); + } + + } +} diff --git a/sdk/core/Azure.Core/src/Shared/Multipart/MultipartReader.cs b/sdk/core/Azure.Core/src/Shared/Multipart/MultipartReader.cs new file mode 100644 index 000000000000..6ac52f6aec03 --- /dev/null +++ b/sdk/core/Azure.Core/src/Shared/Multipart/MultipartReader.cs @@ -0,0 +1,122 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// Copied from https://github.com/aspnet/AspNetCore/tree/master/src/Http/WebUtilities/src + +using System; +using System.Collections.Generic; +using System.IO; +using System.Threading; +using System.Threading.Tasks; + +#pragma warning disable CA1001 // disposable fields +#pragma warning disable IDE0008 // Use explicit type +#nullable disable + +namespace Azure.Core +{ + // https://www.ietf.org/rfc/rfc2046.txt + internal class MultipartReader + { + public const int DefaultHeadersCountLimit = 16; + public const int DefaultHeadersLengthLimit = 1024 * 16; + private const int DefaultBufferSize = 1024 * 4; + + private readonly BufferedReadStream _stream; + private readonly MultipartBoundary _boundary; + private MultipartReaderStream _currentStream; + + public MultipartReader(string boundary, Stream stream) + : this(boundary, stream, DefaultBufferSize) + { + } + + public MultipartReader(string boundary, Stream stream, int bufferSize) + { + if (boundary == null) + { + throw new ArgumentNullException(nameof(boundary)); + } + + if (stream == null) + { + throw new ArgumentNullException(nameof(stream)); + } + + if (bufferSize < boundary.Length + 8) // Size of the boundary + leading and trailing CRLF + leading and trailing '--' markers. + { + throw new ArgumentOutOfRangeException(nameof(bufferSize), bufferSize, "Insufficient buffer space, the buffer must be larger than the boundary: " + boundary); + } + _stream = new BufferedReadStream(stream, bufferSize); + _boundary = new MultipartBoundary(boundary, false); + // This stream will drain any preamble data and remove the first boundary marker. + // TODO: HeadersLengthLimit can't be modified until after the constructor. + _currentStream = new MultipartReaderStream(_stream, _boundary) { LengthLimit = HeadersLengthLimit }; + } + + /// + /// The limit for the number of headers to read. + /// + public int HeadersCountLimit { get; set; } = DefaultHeadersCountLimit; + + /// + /// The combined size limit for headers per multipart section. + /// + public int HeadersLengthLimit { get; set; } = DefaultHeadersLengthLimit; + + /// + /// The optional limit for the total response body length. + /// + public long? BodyLengthLimit { get; set; } + + public async Task ReadNextSectionAsync(CancellationToken cancellationToken = new CancellationToken()) + { + // Drain the prior section. + await _currentStream.DrainAsync(cancellationToken).ConfigureAwait(false); + // If we're at the end return null + if (_currentStream.FinalBoundaryFound) + { + // There may be trailer data after the last boundary. + await _stream.DrainAsync(HeadersLengthLimit, cancellationToken).ConfigureAwait(false); + return null; + } + var headers = await ReadHeadersAsync(cancellationToken).ConfigureAwait(false); + _boundary.ExpectLeadingCrlf = true; + _currentStream = new MultipartReaderStream(_stream, _boundary) { LengthLimit = BodyLengthLimit }; + long? baseStreamOffset = _stream.CanSeek ? (long?)_stream.Position : null; + return new MultipartSection() { Headers = headers, Body = _currentStream, BaseStreamOffset = baseStreamOffset }; + } + + private async Task> ReadHeadersAsync(CancellationToken cancellationToken) + { + int totalSize = 0; + var accumulator = new KeyValueAccumulator(); + var line = await _stream.ReadLineAsync(HeadersLengthLimit - totalSize, cancellationToken).ConfigureAwait(false); + while (!string.IsNullOrEmpty(line)) + { + if (HeadersLengthLimit - totalSize < line.Length) + { + throw new InvalidDataException($"Multipart headers length limit {HeadersLengthLimit} exceeded."); + } + totalSize += line.Length; + int splitIndex = line.IndexOf(':'); + if (splitIndex <= 0) + { + throw new InvalidDataException($"Invalid header line: {line}"); + } + + var name = line.Substring(0, splitIndex); + var value = line.Substring(splitIndex + 1, line.Length - splitIndex - 1).Trim(); + accumulator.Append(name, value); + if (accumulator.KeyCount > HeadersCountLimit) + { + throw new InvalidDataException($"Multipart headers count limit {HeadersCountLimit} exceeded."); + } + + line = await _stream.ReadLineAsync(HeadersLengthLimit - totalSize, cancellationToken).ConfigureAwait(false); + } + + return accumulator.GetResults(); + } + } +} diff --git a/sdk/core/Azure.Core/src/Shared/Multipart/MultipartReaderStream.cs b/sdk/core/Azure.Core/src/Shared/Multipart/MultipartReaderStream.cs new file mode 100644 index 000000000000..157200bc2ed7 --- /dev/null +++ b/sdk/core/Azure.Core/src/Shared/Multipart/MultipartReaderStream.cs @@ -0,0 +1,348 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// Copied from https://github.com/aspnet/AspNetCore/tree/master/src/Http/WebUtilities/src + +using System; +using System.Buffers; +using System.Diagnostics; +using System.IO; +using System.Threading; +using System.Threading.Tasks; + +#pragma warning disable IDE0008 // Use explicit type +#pragma warning disable IDE0016 // Simplify null check +#pragma warning disable IDE0018 // Inline declaration +#pragma warning disable IDE0054 // Use compound assignment +#pragma warning disable IDE0059 // Unnecessary assignment + +namespace Azure.Core +{ + internal sealed class MultipartReaderStream : Stream + { + private readonly MultipartBoundary _boundary; + private readonly BufferedReadStream _innerStream; + private readonly ArrayPool _bytePool; + + private readonly long _innerOffset; + private long _position; + private long _observedLength; + private bool _finished; + + /// + /// Creates a stream that reads until it reaches the given boundary pattern. + /// + /// The . + /// The boundary pattern to use. + public MultipartReaderStream(BufferedReadStream stream, MultipartBoundary boundary) + : this(stream, boundary, ArrayPool.Shared) + { + } + + /// + /// Creates a stream that reads until it reaches the given boundary pattern. + /// + /// The . + /// The boundary pattern to use. + /// The ArrayPool pool to use for temporary byte arrays. + public MultipartReaderStream(BufferedReadStream stream, MultipartBoundary boundary, ArrayPool bytePool) + { + if (stream == null) + { + throw new ArgumentNullException(nameof(stream)); + } + + if (boundary == null) + { + throw new ArgumentNullException(nameof(boundary)); + } + + _bytePool = bytePool; + _innerStream = stream; + _innerOffset = _innerStream.CanSeek ? _innerStream.Position : 0; + _boundary = boundary; + } + + public bool FinalBoundaryFound { get; private set; } + + public long? LengthLimit { get; set; } + + public override bool CanRead + { + get { return true; } + } + + public override bool CanSeek + { + get { return _innerStream.CanSeek; } + } + + public override bool CanWrite + { + get { return false; } + } + + public override long Length + { + get { return _observedLength; } + } + + public override long Position + { + get { return _position; } + set + { + if (value < 0) + { + throw new ArgumentOutOfRangeException(nameof(value), value, "The Position must be positive."); + } + if (value > _observedLength) + { + throw new ArgumentOutOfRangeException(nameof(value), value, "The Position must be less than length."); + } + _position = value; + if (_position < _observedLength) + { + _finished = false; + } + } + } + + public override long Seek(long offset, SeekOrigin origin) + { + if (origin == SeekOrigin.Begin) + { + Position = offset; + } + else if (origin == SeekOrigin.Current) + { + Position = Position + offset; + } + else // if (origin == SeekOrigin.End) + { + Position = Length + offset; + } + return Position; + } + + public override void SetLength(long value) + { + throw new NotSupportedException(); + } + + public override void Write(byte[] buffer, int offset, int count) + { + throw new NotSupportedException(); + } + + public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + throw new NotSupportedException(); + } + + public override void Flush() + { + throw new NotSupportedException(); + } + + private void PositionInnerStream() + { + if (_innerStream.CanSeek && _innerStream.Position != (_innerOffset + _position)) + { + _innerStream.Position = _innerOffset + _position; + } + } + + private int UpdatePosition(int read) + { + _position += read; + if (_observedLength < _position) + { + _observedLength = _position; + if (LengthLimit.HasValue && _observedLength > LengthLimit.GetValueOrDefault()) + { + throw new InvalidDataException($"Multipart body length limit {LengthLimit.GetValueOrDefault()} exceeded."); + } + } + return read; + } + + public override int Read(byte[] buffer, int offset, int count) + { + if (_finished) + { + return 0; + } + + PositionInnerStream(); + if (!_innerStream.EnsureBuffered(_boundary.FinalBoundaryLength)) + { + throw new IOException("Unexpected end of Stream, the content may have already been read by another component. "); + } + var bufferedData = _innerStream.BufferedData; + + // scan for a boundary match, full or partial. + int read; + if (SubMatch(bufferedData, _boundary.BoundaryBytes, out var matchOffset, out var matchCount)) + { + // We found a possible match, return any data before it. + if (matchOffset > bufferedData.Offset) + { + read = _innerStream.Read(buffer, offset, Math.Min(count, matchOffset - bufferedData.Offset)); + return UpdatePosition(read); + } + + var length = _boundary.BoundaryBytes.Length; + Debug.Assert(matchCount == length); + + // "The boundary may be followed by zero or more characters of + // linear whitespace. It is then terminated by either another CRLF" + // or -- for the final boundary. + var boundary = _bytePool.Rent(length); + read = _innerStream.Read(boundary, 0, length); + _bytePool.Return(boundary); + Debug.Assert(read == length); // It should have all been buffered + + var remainder = _innerStream.ReadLine(lengthLimit: 100); // Whitespace may exceed the buffer. + remainder = remainder.Trim(); + if (string.Equals("--", remainder, StringComparison.Ordinal)) + { + FinalBoundaryFound = true; + } +#pragma warning disable CA1820 // Test for empty strings using string length + Debug.Assert(FinalBoundaryFound || string.Equals(string.Empty, remainder, StringComparison.Ordinal), "Un-expected data found on the boundary line: " + remainder); +#pragma warning restore CA1820 // Test for empty strings using string length + _finished = true; + return 0; + } + + // No possible boundary match within the buffered data, return the data from the buffer. + read = _innerStream.Read(buffer, offset, Math.Min(count, bufferedData.Count)); + return UpdatePosition(read); + } + + public override async Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + if (_finished) + { + return 0; + } + + PositionInnerStream(); + if (!await _innerStream.EnsureBufferedAsync(_boundary.FinalBoundaryLength, cancellationToken).ConfigureAwait(false)) + { + throw new IOException("Unexpected end of Stream, the content may have already been read by another component. "); + } + var bufferedData = _innerStream.BufferedData; + + // scan for a boundary match, full or partial. + int matchOffset; + int matchCount; + int read; + if (SubMatch(bufferedData, _boundary.BoundaryBytes, out matchOffset, out matchCount)) + { + // We found a possible match, return any data before it. + if (matchOffset > bufferedData.Offset) + { + // Sync, it's already buffered + read = _innerStream.Read(buffer, offset, Math.Min(count, matchOffset - bufferedData.Offset)); + return UpdatePosition(read); + } + + var length = _boundary.BoundaryBytes!.Length; + Debug.Assert(matchCount == length); + + // "The boundary may be followed by zero or more characters of + // linear whitespace. It is then terminated by either another CRLF" + // or -- for the final boundary. + var boundary = _bytePool.Rent(length); + read = _innerStream.Read(boundary, 0, length); + _bytePool.Return(boundary); + Debug.Assert(read == length); // It should have all been buffered + + var remainder = await _innerStream.ReadLineAsync(lengthLimit: 100, cancellationToken: cancellationToken).ConfigureAwait(false); // Whitespace may exceed the buffer. + remainder = remainder.Trim(); + if (string.Equals("--", remainder, StringComparison.Ordinal)) + { + FinalBoundaryFound = true; + } +#pragma warning disable CA1820 // Test for empty strings using string length + Debug.Assert(FinalBoundaryFound || string.Equals(string.Empty, remainder, StringComparison.Ordinal), "Un-expected data found on the boundary line: " + remainder); +#pragma warning restore CA1820 // Test for empty strings using string length + + _finished = true; + return 0; + } + + // No possible boundary match within the buffered data, return the data from the buffer. + read = _innerStream.Read(buffer, offset, Math.Min(count, bufferedData.Count)); + return UpdatePosition(read); + } + + // Does segment1 contain all of matchBytes, or does it end with the start of matchBytes? + // 1: AAAAABBBBBCCCCC + // 2: BBBBB + // Or: + // 1: AAAAABBB + // 2: BBBBB + private bool SubMatch(ArraySegment segment1, byte[] matchBytes, out int matchOffset, out int matchCount) + { + // clear matchCount to zero + matchCount = 0; + + // case 1: does segment1 fully contain matchBytes? + { + var matchBytesLengthMinusOne = matchBytes.Length - 1; + var matchBytesLastByte = matchBytes[matchBytesLengthMinusOne]; + var segmentEndMinusMatchBytesLength = segment1.Offset + segment1.Count - matchBytes.Length; + + matchOffset = segment1.Offset; + while (matchOffset < segmentEndMinusMatchBytesLength) + { + var lookaheadTailChar = segment1.Array![matchOffset + matchBytesLengthMinusOne]; + if (lookaheadTailChar == matchBytesLastByte && + CompareBuffers(segment1.Array, matchOffset, matchBytes, 0, matchBytesLengthMinusOne) == 0) + { + matchCount = matchBytes.Length; + return true; + } + matchOffset += _boundary.GetSkipValue(lookaheadTailChar); + } + } + + // case 2: does segment1 end with the start of matchBytes? + var segmentEnd = segment1.Offset + segment1.Count; + + matchCount = 0; + for (; matchOffset < segmentEnd; matchOffset++) + { + var countLimit = segmentEnd - matchOffset; + for (matchCount = 0; matchCount < matchBytes.Length && matchCount < countLimit; matchCount++) + { + if (matchBytes[matchCount] != segment1.Array![matchOffset + matchCount]) + { + matchCount = 0; + break; + } + } + if (matchCount > 0) + { + break; + } + } + return matchCount > 0; + } + + private static int CompareBuffers(byte[] buffer1, int offset1, byte[] buffer2, int offset2, int count) + { + for (; count-- > 0; offset1++, offset2++) + { + if (buffer1[offset1] != buffer2[offset2]) + { + return buffer1[offset1] - buffer2[offset2]; + } + } + return 0; + } + } +} diff --git a/sdk/core/Azure.Core/src/Shared/Multipart/MultipartSection.cs b/sdk/core/Azure.Core/src/Shared/Multipart/MultipartSection.cs new file mode 100644 index 000000000000..5849be42c018 --- /dev/null +++ b/sdk/core/Azure.Core/src/Shared/Multipart/MultipartSection.cs @@ -0,0 +1,27 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// Copied from https://github.com/aspnet/AspNetCore/tree/master/src/Http/WebUtilities/src + +using System.Collections.Generic; +using System.IO; +using System.Linq; + +namespace Azure.Core +{ + internal class MultipartSection + { + public Dictionary Headers { get; set; } + + /// + /// Gets or sets the body. + /// + public Stream Body { get; set; } = default!; + + /// + /// The position where the body starts in the total multipart body. + /// This may not be available if the total multipart body is not seekable. + /// + public long? BaseStreamOffset { get; set; } + } +} diff --git a/sdk/core/Azure.Core/src/Shared/Multipart/StreamHelperExtensions.cs b/sdk/core/Azure.Core/src/Shared/Multipart/StreamHelperExtensions.cs new file mode 100644 index 000000000000..6be8b9cf138c --- /dev/null +++ b/sdk/core/Azure.Core/src/Shared/Multipart/StreamHelperExtensions.cs @@ -0,0 +1,55 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// Copied from https://github.com/aspnet/AspNetCore/tree/master/src/Http/WebUtilities/src + +using System.Buffers; +using System.IO; +using System.Threading; +using System.Threading.Tasks; + +#pragma warning disable IDE1006 // Prefix _ unexpected + +namespace Azure.Core +{ + internal static class StreamHelperExtensions + { + private const int _maxReadBufferSize = 1024 * 4; + + public static Task DrainAsync(this Stream stream, CancellationToken cancellationToken) + { + return stream.DrainAsync(ArrayPool.Shared, null, cancellationToken); + } + + public static Task DrainAsync(this Stream stream, long? limit, CancellationToken cancellationToken) + { + return stream.DrainAsync(ArrayPool.Shared, limit, cancellationToken); + } + + public static async Task DrainAsync(this Stream stream, ArrayPool bytePool, long? limit, CancellationToken cancellationToken) + { + cancellationToken.ThrowIfCancellationRequested(); + var buffer = bytePool.Rent(_maxReadBufferSize); + long total = 0; + try + { + var read = await stream.ReadAsync(buffer, 0, buffer.Length, cancellationToken).ConfigureAwait(false); + while (read > 0) + { + // Not all streams support cancellation directly. + cancellationToken.ThrowIfCancellationRequested(); + if (limit.HasValue && limit.GetValueOrDefault() - total < read) + { + throw new InvalidDataException($"The stream exceeded the data limit {limit.GetValueOrDefault()}."); + } + total += read; + read = await stream.ReadAsync(buffer, 0, buffer.Length, cancellationToken).ConfigureAwait(false); + } + } + finally + { + bytePool.Return(buffer); + } + } + } +} diff --git a/sdk/core/Azure.Core/src/Shared/RequestRequestContent.cs b/sdk/core/Azure.Core/src/Shared/RequestRequestContent.cs new file mode 100644 index 000000000000..5ab190bd6846 --- /dev/null +++ b/sdk/core/Azure.Core/src/Shared/RequestRequestContent.cs @@ -0,0 +1,166 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.Collections.Generic; +using System.IO; +using System.Text; +using System.Threading; +using System.Threading.Tasks; + +#nullable disable + +namespace Azure.Core +{ + /// + /// Provides a container for content encoded using multipart/form-data MIME type. + /// + internal class RequestRequestContent : RequestContent + { + #region Fields + + private const string SP = " "; + private const string ColonSP = ": "; + private const string CRLF = "\r\n"; + private const int DefaultHeaderAllocation = 2 * 1024; + + private readonly Request _request; + + #endregion Fields + + #region Construction + + /// + /// Initializes a new instance of the class. + /// + /// The instance to encapsulate. + public RequestRequestContent(Request request) + { + Argument.AssertNotNull(request, nameof(request)); + this._request = request; + } + + #endregion Construction + + #region Dispose + + /// + /// Frees resources held by the object. + /// + public override void Dispose() + { + _request.Dispose(); + } + + #endregion Dispose + + #region Serialization + + public override async Task WriteToAsync(Stream stream, CancellationToken cancellationToken) + { + Argument.AssertNotNull(stream, nameof(stream)); + + byte[] header = SerializeHeader(); + await stream.WriteAsync(header, 0, header.Length).ConfigureAwait(false); + + if (_request.Content != null) + { + await _request.Content.WriteToAsync(stream, cancellationToken).ConfigureAwait(false); + } + } + + public override void WriteTo(Stream stream, CancellationToken cancellationToken) + { + Argument.AssertNotNull(stream, nameof(stream)); + + byte[] header = SerializeHeader(); + stream.Write(header, 0, header.Length); + + if (_request.Content != null) + { + _request.Content.WriteTo(stream, cancellationToken); + } + } + + /// + /// Computes the length of the stream if possible. + /// + /// The computed length of the stream. + /// true if the length has been computed; otherwise false. + public override bool TryComputeLength(out long length) + { + // We have four states we could be in: + // 1. We have content and the stream came back as a null or non-seekable + // 2. We have content and the stream is seekable, so we know its length + // 3. We don't have content + // + // For #1, we return false. + // For #2, we return true & the size of our headers + the content length + // For #3, we return true & the size of our headers + + bool hasContent = _request.Content != null; + length = 0; + + // Cases #1, #2, #3 + if (hasContent) + { + if (!_request!.Content!.TryComputeLength(out length)) + { + length = 0; + return false; + } + } + + // We serialize header to a StringBuilder so that we can determine the length + // following the pattern for HttpContent to try and determine the message length. + // The perf overhead is no larger than for the other HttpContent implementations. + byte[] header = SerializeHeader(); + length += header.Length; + return true; + } + + private byte[] SerializeHeader() + { + StringBuilder message = new StringBuilder(DefaultHeaderAllocation); + + SerializeRequestLine(message, _request); + SerializeHeaderFields(message, _request.Headers); + message.Append(CRLF); + return Encoding.UTF8.GetBytes(message.ToString()); + } + + /// + /// Serializes the HTTP request line. + /// + /// Where to write the request line. + /// The HTTP request. + private static void SerializeRequestLine(StringBuilder message, Request request) + { + Argument.AssertNotNull(message, nameof(message)); + message.Append(request.Method + SP); + message.Append(request.Uri.ToString() + SP); + message.Append("HTTP/1.1" + CRLF); + + // Only insert host header if not already present. + if (!request.Headers.TryGetValue("Host", out _)) + { + message.Append("Host" + ColonSP + request.Uri.Host + CRLF); + } + } + + /// + /// Serializes the header fields. + /// + /// Where to write the status line. + /// The headers to write. + private static void SerializeHeaderFields(StringBuilder message, RequestHeaders headers) + { + Argument.AssertNotNull(message, nameof(message)); + foreach (HttpHeader header in headers) + { + message.Append(header.Name + ColonSP + header.Value + CRLF); + } + } + + #endregion Serialization + } +} diff --git a/sdk/core/Azure.Core/tests/Azure.Core.Tests.csproj b/sdk/core/Azure.Core/tests/Azure.Core.Tests.csproj index a7a10b64fb89..7a6449b58fb6 100644 --- a/sdk/core/Azure.Core/tests/Azure.Core.Tests.csproj +++ b/sdk/core/Azure.Core/tests/Azure.Core.Tests.csproj @@ -2,6 +2,7 @@ {84491222-6C36-4FA7-BBAE-1FA804129151} $(RequiredTargetFrameworks) + $(DefineConstants);HAS_INTERNALS_VISIBLE_CORE true @@ -22,13 +23,17 @@ + + + - - + + diff --git a/sdk/core/Azure.Core/tests/HttpPipelineFunctionalTests.cs b/sdk/core/Azure.Core/tests/HttpPipelineFunctionalTests.cs index 92f490faa4e3..a80324172f01 100644 --- a/sdk/core/Azure.Core/tests/HttpPipelineFunctionalTests.cs +++ b/sdk/core/Azure.Core/tests/HttpPipelineFunctionalTests.cs @@ -20,11 +20,11 @@ namespace Azure.Core.Tests [TestFixture(typeof(HttpWebRequestTransport), true)] [TestFixture(typeof(HttpWebRequestTransport), false)] #endif - public class HttpPipelineFunctionalTests: PipelineTestBase + public class HttpPipelineFunctionalTests : PipelineTestBase { private readonly Type _transportType; - public HttpPipelineFunctionalTests(Type transportType, bool isAsync): base(isAsync) + public HttpPipelineFunctionalTests(Type transportType, bool isAsync) : base(isAsync) { _transportType = transportType; } @@ -32,7 +32,7 @@ public HttpPipelineFunctionalTests(Type transportType, bool isAsync): base(isAsy private TestOptions GetOptions() { var options = new TestOptions(); - options.Transport = (HttpPipelineTransport) Activator.CreateInstance(_transportType); + options.Transport = (HttpPipelineTransport)Activator.CreateInstance(_transportType); return options; } diff --git a/sdk/core/Azure.Core/tests/MultipartReaderTests.cs b/sdk/core/Azure.Core/tests/MultipartReaderTests.cs new file mode 100644 index 000000000000..a5386f68e348 --- /dev/null +++ b/sdk/core/Azure.Core/tests/MultipartReaderTests.cs @@ -0,0 +1,385 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#nullable disable warnings + +using System; +using System.IO; +using System.Text; +using System.Threading.Tasks; +using NUnit.Framework; + +namespace Azure.Core.Tests +{ + public class MultipartReaderTests + { + private const string Boundary = "9051914041544843365972754266"; + // Note that CRLF (\r\n) is required. You can't use multi-line C# strings here because the line breaks on Linux are just LF. + private const string OnePartBody = +"--9051914041544843365972754266\r\n" + +"Content-Disposition: form-data; name=\"text\"\r\n" + +"\r\n" + +"text default\r\n" + +"--9051914041544843365972754266--\r\n"; + private const string OnePartBodyTwoHeaders = +"--9051914041544843365972754266\r\n" + +"Content-Disposition: form-data; name=\"text\"\r\n" + +"Custom-header: custom-value\r\n" + +"\r\n" + +"text default\r\n" + +"--9051914041544843365972754266--\r\n"; + private const string OnePartBodyWithTrailingWhitespace = +"--9051914041544843365972754266 \r\n" + +"Content-Disposition: form-data; name=\"text\"\r\n" + +"\r\n" + +"text default\r\n" + +"--9051914041544843365972754266--\r\n"; + // It's non-compliant but common to leave off the last CRLF. + private const string OnePartBodyWithoutFinalCRLF = +"--9051914041544843365972754266\r\n" + +"Content-Disposition: form-data; name=\"text\"\r\n" + +"\r\n" + +"text default\r\n" + +"--9051914041544843365972754266--"; + private const string TwoPartBody = +"--9051914041544843365972754266\r\n" + +"Content-Disposition: form-data; name=\"text\"\r\n" + +"\r\n" + +"text default\r\n" + +"--9051914041544843365972754266\r\n" + +"Content-Disposition: form-data; name=\"file1\"; filename=\"a.txt\"\r\n" + +"Content-Type: text/plain\r\n" + +"\r\n" + +"Content of a.txt.\r\n" + +"\r\n" + +"--9051914041544843365972754266--\r\n"; + private const string TwoPartBodyWithUnicodeFileName = +"--9051914041544843365972754266\r\n" + +"Content-Disposition: form-data; name=\"text\"\r\n" + +"\r\n" + +"text default\r\n" + +"--9051914041544843365972754266\r\n" + +"Content-Disposition: form-data; name=\"file1\"; filename=\"a色.txt\"\r\n" + +"Content-Type: text/plain\r\n" + +"\r\n" + +"Content of a.txt.\r\n" + +"\r\n" + +"--9051914041544843365972754266--\r\n"; + private const string ThreePartBody = +"--9051914041544843365972754266\r\n" + +"Content-Disposition: form-data; name=\"text\"\r\n" + +"\r\n" + +"text default\r\n" + +"--9051914041544843365972754266\r\n" + +"Content-Disposition: form-data; name=\"file1\"; filename=\"a.txt\"\r\n" + +"Content-Type: text/plain\r\n" + +"\r\n" + +"Content of a.txt.\r\n" + +"\r\n" + +"--9051914041544843365972754266\r\n" + +"Content-Disposition: form-data; name=\"file2\"; filename=\"a.html\"\r\n" + +"Content-Type: text/html\r\n" + +"\r\n" + +"Content of a.html.\r\n" + +"\r\n" + +"--9051914041544843365972754266--\r\n"; + + private const string TwoPartBodyIncompleteBuffer = +"--9051914041544843365972754266\r\n" + +"Content-Disposition: form-data; name=\"text\"\r\n" + +"\r\n" + +"text default\r\n" + +"--9051914041544843365972754266\r\n" + +"Content-Disposition: form-data; name=\"file1\"; filename=\"a.txt\"\r\n" + +"Content-Type: text/plain\r\n" + +"\r\n" + +"Content of a.txt.\r\n" + +"\r\n" + +"--9051914041544843365"; + + private static MemoryStream MakeStream(string text) + { + return new MemoryStream(Encoding.UTF8.GetBytes(text)); + } + + private static string GetString(byte[] buffer, int count) + { + return Encoding.ASCII.GetString(buffer, 0, count); + } + + [Test] + public async Task MultipartReader_ReadSinglePartBody_Success() + { + var stream = MakeStream(OnePartBody); + var reader = new MultipartReader(Boundary, stream); + + var section = await reader.ReadNextSectionAsync(); + Assert.NotNull(section); + Assert.That(section.Headers.Count, Is.EqualTo(1)); + Assert.That(section.Headers["Content-Disposition"][0], Is.EqualTo("form-data; name=\"text\"")); + var buffer = new MemoryStream(); + await section.Body.CopyToAsync(buffer); + Assert.That(Encoding.ASCII.GetString(buffer.ToArray()), Is.EqualTo("text default")); + + Assert.Null(await reader.ReadNextSectionAsync()); + } + + [Test] + public void MultipartReader_HeaderCountExceeded_Throws() + { + var stream = MakeStream(OnePartBodyTwoHeaders); + var reader = new MultipartReader(Boundary, stream) + { + HeadersCountLimit = 1, + }; + + var exception = Assert.ThrowsAsync(async () => await reader.ReadNextSectionAsync()); + Assert.That(exception.Message, Is.EqualTo("Multipart headers count limit 1 exceeded.")); + } + + [Test] + public void MultipartReader_HeadersLengthExceeded_Throws() + { + var stream = MakeStream(OnePartBodyTwoHeaders); + var reader = new MultipartReader(Boundary, stream) + { + HeadersLengthLimit = 60, + }; + + var exception = Assert.ThrowsAsync(() => reader.ReadNextSectionAsync()); + Assert.That(exception.Message, Is.EqualTo("Line length limit 17 exceeded.")); + } + + [Test] + public async Task MultipartReader_ReadSinglePartBodyWithTrailingWhitespace_Success() + { + var stream = MakeStream(OnePartBodyWithTrailingWhitespace); + var reader = new MultipartReader(Boundary, stream); + + var section = await reader.ReadNextSectionAsync(); + Assert.NotNull(section); + Assert.That(section.Headers.Count, Is.EqualTo(1)); + Assert.That(section.Headers["Content-Disposition"][0], Is.EqualTo("form-data; name=\"text\"")); + var buffer = new MemoryStream(); + await section.Body.CopyToAsync(buffer); + Assert.That(Encoding.ASCII.GetString(buffer.ToArray()), Is.EqualTo("text default")); + + Assert.Null(await reader.ReadNextSectionAsync()); + } + + [Test] + public async Task MultipartReader_ReadSinglePartBodyWithoutLastCRLF_Success() + { + var stream = MakeStream(OnePartBodyWithoutFinalCRLF); + var reader = new MultipartReader(Boundary, stream); + + var section = await reader.ReadNextSectionAsync(); + Assert.NotNull(section); + Assert.That(section.Headers.Count, Is.EqualTo(1)); + Assert.That(section.Headers["Content-Disposition"][0], Is.EqualTo("form-data; name=\"text\"")); + var buffer = new MemoryStream(); + await section.Body.CopyToAsync(buffer); + Assert.That(Encoding.ASCII.GetString(buffer.ToArray()), Is.EqualTo("text default")); + + Assert.Null(await reader.ReadNextSectionAsync()); + } + + [Test] + public async Task MultipartReader_ReadTwoPartBody_Success() + { + var stream = MakeStream(TwoPartBody); + var reader = new MultipartReader(Boundary, stream); + + var section = await reader.ReadNextSectionAsync(); + Assert.NotNull(section); + Assert.That(section.Headers.Count, Is.EqualTo(1)); + Assert.That(section.Headers["Content-Disposition"][0], Is.EqualTo("form-data; name=\"text\"")); + var buffer = new MemoryStream(); + await section.Body.CopyToAsync(buffer); + Assert.That(Encoding.ASCII.GetString(buffer.ToArray()), Is.EqualTo("text default")); + + section = await reader.ReadNextSectionAsync(); + Assert.NotNull(section); + Assert.That(section.Headers.Count, Is.EqualTo(2)); + Assert.That(section.Headers["Content-Disposition"][0], Is.EqualTo("form-data; name=\"file1\"; filename=\"a.txt\"")); + Assert.That(section.Headers["Content-Type"][0], Is.EqualTo("text/plain")); + buffer = new MemoryStream(); + await section.Body.CopyToAsync(buffer); + Assert.That(Encoding.ASCII.GetString(buffer.ToArray()), Is.EqualTo("Content of a.txt.\r\n")); + + Assert.Null(await reader.ReadNextSectionAsync()); + } + + [Test] + public async Task MultipartReader_ReadTwoPartBodyWithUnicodeFileName_Success() + { + var stream = MakeStream(TwoPartBodyWithUnicodeFileName); + var reader = new MultipartReader(Boundary, stream); + + var section = await reader.ReadNextSectionAsync(); + Assert.NotNull(section); + Assert.That(section.Headers.Count, Is.EqualTo(1)); + Assert.That(section.Headers["Content-Disposition"][0], Is.EqualTo("form-data; name=\"text\"")); + var buffer = new MemoryStream(); + await section.Body.CopyToAsync(buffer); + Assert.That(Encoding.ASCII.GetString(buffer.ToArray()), Is.EqualTo("text default")); + + section = await reader.ReadNextSectionAsync(); + Assert.NotNull(section); + Assert.That(section.Headers.Count, Is.EqualTo(2)); + Assert.That(section.Headers["Content-Disposition"][0], Is.EqualTo("form-data; name=\"file1\"; filename=\"a色.txt\"")); + Assert.That(section.Headers["Content-Type"][0], Is.EqualTo("text/plain")); + buffer = new MemoryStream(); + await section.Body.CopyToAsync(buffer); + Assert.That(Encoding.ASCII.GetString(buffer.ToArray()), Is.EqualTo("Content of a.txt.\r\n")); + + Assert.Null(await reader.ReadNextSectionAsync()); + } + + [Test] + public async Task MultipartReader_ThreePartBody_Success() + { + var stream = MakeStream(ThreePartBody); + var reader = new MultipartReader(Boundary, stream); + + var section = await reader.ReadNextSectionAsync(); + Assert.NotNull(section); + Assert.That(section.Headers.Count, Is.EqualTo(1)); + Assert.That(section.Headers["Content-Disposition"][0], Is.EqualTo("form-data; name=\"text\"")); + var buffer = new MemoryStream(); + await section.Body.CopyToAsync(buffer); + Assert.That(Encoding.ASCII.GetString(buffer.ToArray()), Is.EqualTo("text default")); + + section = await reader.ReadNextSectionAsync(); + Assert.NotNull(section); + Assert.That(section.Headers.Count, Is.EqualTo(2)); + Assert.That(section.Headers["Content-Disposition"][0], Is.EqualTo("form-data; name=\"file1\"; filename=\"a.txt\"")); + Assert.That(section.Headers["Content-Type"][0], Is.EqualTo("text/plain")); + buffer = new MemoryStream(); + await section.Body.CopyToAsync(buffer); + Assert.That(Encoding.ASCII.GetString(buffer.ToArray()), Is.EqualTo("Content of a.txt.\r\n")); + + section = await reader.ReadNextSectionAsync(); + Assert.NotNull(section); + Assert.That(section.Headers.Count, Is.EqualTo(2)); + Assert.That(section.Headers["Content-Disposition"][0], Is.EqualTo("form-data; name=\"file2\"; filename=\"a.html\"")); + Assert.That(section.Headers["Content-Type"][0], Is.EqualTo("text/html")); + buffer = new MemoryStream(); + await section.Body.CopyToAsync(buffer); + Assert.That(Encoding.ASCII.GetString(buffer.ToArray()), Is.EqualTo("Content of a.html.\r\n")); + + Assert.Null(await reader.ReadNextSectionAsync()); + } + + [Test] + public void MultipartReader_BufferSizeMustBeLargerThanBoundary_Throws() + { + var stream = MakeStream(ThreePartBody); + Assert.Throws(() => + { + var reader = new MultipartReader(Boundary, stream, 5); + }); + } + + [Test] + public async Task MultipartReader_TwoPartBodyIncompleteBuffer_TwoSectionsReadSuccessfullyThirdSectionThrows() + { + var stream = MakeStream(TwoPartBodyIncompleteBuffer); + var reader = new MultipartReader(Boundary, stream); + var buffer = new byte[128]; + + //first section can be read successfully + var section = await reader.ReadNextSectionAsync(); + Assert.NotNull(section); + Assert.That(section.Headers.Count, Is.EqualTo(1)); + Assert.That(section.Headers["Content-Disposition"][0], Is.EqualTo("form-data; name=\"text\"")); + var read = section.Body.Read(buffer, 0, buffer.Length); + Assert.That(GetString(buffer, read), Is.EqualTo("text default")); + + //second section can be read successfully (even though the bottom boundary is truncated) + section = await reader.ReadNextSectionAsync(); + Assert.NotNull(section); + Assert.That(section.Headers.Count, Is.EqualTo(2)); + Assert.That(section.Headers["Content-Disposition"][0], Is.EqualTo("form-data; name=\"file1\"; filename=\"a.txt\"")); + Assert.That(section.Headers["Content-Type"][0], Is.EqualTo("text/plain")); + read = section.Body.Read(buffer, 0, buffer.Length); + Assert.That(GetString(buffer, read), Is.EqualTo("Content of a.txt.\r\n")); + + Assert.ThrowsAsync(async () => + { + // we'll be unable to ensure enough bytes are buffered to even contain a final boundary + section = await reader.ReadNextSectionAsync(); + }); + } + + [Test] + public async Task MultipartReader_ReadInvalidUtf8Header_ReplacementCharacters() + { + var body1 = +"--9051914041544843365972754266\r\n" + +"Content-Disposition: form-data; name=\"text\" filename=\"a"; + + var body2 = +".txt\"\r\n" + +"\r\n" + +"text default\r\n" + +"--9051914041544843365972754266--\r\n"; + var stream = new MemoryStream(); + var bytes = Encoding.UTF8.GetBytes(body1); + stream.Write(bytes, 0, bytes.Length); + + // Write an invalid utf-8 segment in the middle + stream.Write(new byte[] { 0xC1, 0x21 }, 0, 2); + + bytes = Encoding.UTF8.GetBytes(body2); + stream.Write(bytes, 0, bytes.Length); + stream.Seek(0, SeekOrigin.Begin); + var reader = new MultipartReader(Boundary, stream); + + var section = await reader.ReadNextSectionAsync(); + Assert.NotNull(section); + Assert.That(section.Headers.Count, Is.EqualTo(1)); + Assert.That(section.Headers["Content-Disposition"][0], Is.EqualTo("form-data; name=\"text\" filename=\"a\uFFFD!.txt\"")); + var buffer = new MemoryStream(); + await section.Body.CopyToAsync(buffer); + Assert.That(Encoding.ASCII.GetString(buffer.ToArray()), Is.EqualTo("text default")); + + Assert.Null(await reader.ReadNextSectionAsync()); + } + + [Test] + public async Task MultipartReader_ReadInvalidUtf8SurrogateHeader_ReplacementCharacters() + { + var body1 = +"--9051914041544843365972754266\r\n" + +"Content-Disposition: form-data; name=\"text\" filename=\"a"; + + var body2 = +".txt\"\r\n" + +"\r\n" + +"text default\r\n" + +"--9051914041544843365972754266--\r\n"; + var stream = new MemoryStream(); + var bytes = Encoding.UTF8.GetBytes(body1); + stream.Write(bytes, 0, bytes.Length); + + // Write an invalid utf-8 segment in the middle + stream.Write(new byte[] { 0xED, 0xA0, 85 }, 0, 3); + + bytes = Encoding.UTF8.GetBytes(body2); + stream.Write(bytes, 0, bytes.Length); + stream.Seek(0, SeekOrigin.Begin); + var reader = new MultipartReader(Boundary, stream); + + var section = await reader.ReadNextSectionAsync(); + Assert.NotNull(section); + Assert.That(section.Headers.Count, Is.EqualTo(1)); + Assert.That(section.Headers["Content-Disposition"][0], Is.EqualTo("form-data; name=\"text\" filename=\"a\uFFFDU.txt\"")); + var buffer = new MemoryStream(); + await section.Body.CopyToAsync(buffer); + Assert.That(Encoding.ASCII.GetString(buffer.ToArray()), Is.EqualTo("text default")); + + Assert.Null(await reader.ReadNextSectionAsync()); + } + } +} diff --git a/sdk/core/Azure.Core/tests/MultipartTests.cs b/sdk/core/Azure.Core/tests/MultipartTests.cs new file mode 100644 index 000000000000..a283bec4469c --- /dev/null +++ b/sdk/core/Azure.Core/tests/MultipartTests.cs @@ -0,0 +1,295 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#nullable disable warnings + +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Net; +using System.Text; +using System.Threading.Tasks; +using Azure.Core.TestFramework; +using NUnit.Framework; + +namespace Azure.Core.Tests +{ + public class MultipartTests + { + private const string Boundary = "batchresponse_6040fee7-a2b8-4e78-a674-02086369606a"; + private const string ContentType = "multipart/mixed; boundary=" + Boundary; + private const string Body = "{}"; + + // Note that CRLF (\r\n) is required. You can't use multi-line C# strings here because the line breaks on Linux are just LF. + private const string TablesOdataBatchResponse = +"--batchresponse_6040fee7-a2b8-4e78-a674-02086369606a\r\n" + +"Content-Type: multipart/mixed; boundary=changesetresponse_e52cbca8-7e7f-4d91-a719-f99c69686a92\r\n" + +"\r\n" + +"--changesetresponse_e52cbca8-7e7f-4d91-a719-f99c69686a92\r\n" + +"Content-Type: application/http\r\n" + +"Content-Transfer-Encoding: binary\r\n" + +"\r\n" + +"HTTP/1.1 201 Created\r\n" + +"DataServiceVersion: 3.0;\r\n" + +"Content-Type: application/json;odata=fullmetadata;streaming=true;charset=utf-8\r\n" + +"X-Content-Type-Options: nosniff\r\n" + +"Cache-Control: no-cache\r\n" + +"Location: https://mytable.table.core.windows.net/tablename(PartitionKey='somPartition',RowKey='01')\r\n" + +"ETag: W/\"datetime'2020-08-14T22%3A58%3A57.8328323Z'\"\r\n" + +"\r\n" + +"{}\r\n" + +"--changesetresponse_e52cbca8-7e7f-4d91-a719-f99c69686a92\r\n" + +"Content-Type: application/http\r\n" + +"Content-Transfer-Encoding: binary\r\n" + +"\r\n" + +"HTTP/1.1 201 Created\r\n" + +"DataServiceVersion: 3.0;\r\n" + +"Content-Type: application/json;odata=fullmetadata;streaming=true;charset=utf-8\r\n" + +"X-Content-Type-Options: nosniff\r\n" + +"Cache-Control: no-cache\r\n" + +"Location: https://mytable.table.core.windows.net/tablename(PartitionKey='somPartition',RowKey='02')\r\n" + +"ETag: W/\"datetime'2020-08-14T22%3A58%3A57.8328323Z'\"\r\n" + +"\r\n" + +"{}\r\n" + +"--changesetresponse_e52cbca8-7e7f-4d91-a719-f99c69686a92\r\n" + +"Content-Type: application/http\r\n" + +"Content-Transfer-Encoding: binary\r\n" + +"\r\n" + +"HTTP/1.1 201 Created\r\n" + +"DataServiceVersion: 3.0;\r\n" + +"Content-Type: application/json;odata=fullmetadata;streaming=true;charset=utf-8\r\n" + +"X-Content-Type-Options: nosniff\r\n" + +"Cache-Control: no-cache\r\n" + +"Location: https://mytable.table.core.windows.net/tablename(PartitionKey='somPartition',RowKey='03')\r\n" + +"ETag: W/\"datetime'2020-08-14T22%3A58%3A57.8328323Z'\"\r\n" + +"\r\n" + +"{}\r\n" + +"--changesetresponse_e52cbca8-7e7f-4d91-a719-f99c69686a92--\r\n" + +"--batchresponse_6040fee7-a2b8-4e78-a674-02086369606a--\r\n" + +""; + + private const string BlobBatchResponse = +"--batchresponse_6040fee7-a2b8-4e78-a674-02086369606a \r\n" + +"Content-Type: application/http \r\n" + +"Content-ID: 0 \r\n" + +"\r\n" + +"HTTP/1.1 202 Accepted \r\n" + +"x-ms-request-id: 778fdc83-801e-0000-62ff-0334671e284f \r\n" + +"x-ms-version: 2018-11-09 \r\n" + +"\r\n" + +"--batchresponse_6040fee7-a2b8-4e78-a674-02086369606a \r\n" + +"Content-Type: application/http \r\n" + +"Content-ID: 1 \r\n" + +"\r\n" + +"HTTP/1.1 202 Accepted \r\n" + +"x-ms-request-id: 778fdc83-801e-0000-62ff-0334671e2851 \r\n" + +"x-ms-version: 2018-11-09 \r\n" + +"\r\n" + +"--batchresponse_6040fee7-a2b8-4e78-a674-02086369606a \r\n" + +"Content-Type: application/http \r\n" + +"Content-ID: 2 \r\n" + +"\r\n" + +"HTTP/1.1 404 The specified blob does not exist. \r\n" + +"x-ms-error-code: BlobNotFound \r\n" + +"x-ms-request-id: 778fdc83-801e-0000-62ff-0334671e2852 \r\n" + +"x-ms-version: 2018-11-09 \r\n" + +"Content-Length: 216 \r\n" + +"Content-Type: application/xml \r\n" + +"\r\n" + +" \r\n" + +"BlobNotFoundThe specified blob does not exist. \r\n" + +"RequestId:778fdc83-801e-0000-62ff-0334671e2852 \r\n" + +"Time:2018-06-14T16:46:54.6040685Z \r\n" + +"--batchresponse_6040fee7-a2b8-4e78-a674-02086369606a-- \r\n" + +"0"; + + + [Test] + public async Task ParseBatchChangesetResponse() + { + var stream = MakeStream(TablesOdataBatchResponse); + var responses = await Multipart.ParseAsync(stream, ContentType, true, default); + + Assert.That(responses, Is.Not.Null); + Assert.That(responses.Length, Is.EqualTo(3)); + Assert.That(responses.All(r => r.Status == (int)HttpStatusCode.Created)); + + foreach (var response in responses) + { + Assert.That(response.TryGetHeader("DataServiceVersion", out var version)); + Assert.That(version, Is.EqualTo("3.0;")); + + Assert.That(response.TryGetHeader("Content-Type", out var contentType)); + Assert.That(contentType, Is.EqualTo("application/json;odata=fullmetadata;streaming=true;charset=utf-8")); + + var bytes = new byte[response.ContentStream.Length]; + await response.ContentStream.ReadAsync(bytes, 0, bytes.Length); + var content = GetString(bytes, bytes.Length); + + Assert.That(content, Is.EqualTo(Body)); + + } + } + + [Test] + public async Task ParseBatchResponse() + { + var stream = MakeStream(BlobBatchResponse); + var responses = await Multipart.ParseAsync(stream, ContentType, true, default); + + Assert.That(responses, Is.Not.Null); + Assert.That(responses.Length, Is.EqualTo(3)); + + var response = responses[0]; + Assert.That(response.Status, Is.EqualTo( (int)HttpStatusCode.Accepted)); + Assert.That(response.TryGetHeader("x-ms-version", out var version)); + Assert.That(version, Is.EqualTo("2018-11-09")); + Assert.That(response.TryGetHeader("x-ms-request-id", out _)); + + response = responses[1]; + Assert.That(response.Status, Is.EqualTo((int)HttpStatusCode.Accepted)); + Assert.That(response.TryGetHeader("x-ms-version", out version)); + Assert.That(version, Is.EqualTo("2018-11-09")); + Assert.That(response.TryGetHeader("x-ms-request-id", out _)); + + response = responses[2]; + Assert.That(response.Status, Is.EqualTo((int)HttpStatusCode.NotFound)); + Assert.That(response.TryGetHeader("x-ms-version", out version)); + Assert.That(version, Is.EqualTo("2018-11-09")); + Assert.That(response.TryGetHeader("x-ms-request-id", out _)); + var bytes = new byte[response.ContentStream.Length]; + await response.ContentStream.ReadAsync(bytes, 0, bytes.Length); + var content = GetString(bytes, bytes.Length); + Assert.That(content.Contains("BlobNotFoundThe specified blob does not exist.")); + } + + [Test] + public async Task SendMultipartData() + { + const string ApplicationJson = "application/json"; + const string cteHeaderName = "Content-Transfer-Encoding"; + const string Binary = "binary"; + const string Mixed = "mixed"; + const string ApplicationJsonOdata = "application/json; odata=nometadata"; + const string DataServiceVersion = "DataServiceVersion"; + const string Three0 = "3.0"; + const string Host = "myaccount.table.core.windows.net"; + + using Request request = new MockRequest + { + Method = RequestMethod.Put + }; + request.Uri.Reset(new Uri("https://foo")); + + Guid batchGuid = Guid.NewGuid(); + var content = new MultipartContent(Mixed, $"batch_{batchGuid}"); + content.ApplyToRequest(request); + + Guid changesetGuid = Guid.NewGuid(); + var changeset = new MultipartContent(Mixed, $"changeset_{changesetGuid}"); + content.Add(changeset, changeset._headers); + + var postReq1 = new MockRequest + { + Method = RequestMethod.Post + }; + string postUri = $"https://{Host}/Blogs"; + postReq1.Uri.Reset(new Uri(postUri)); + postReq1.Headers.Add(HttpHeader.Names.ContentType, ApplicationJsonOdata); + postReq1.Headers.Add(HttpHeader.Names.Accept, ApplicationJson); + postReq1.Headers.Add(DataServiceVersion, Three0); + const string post1Body = "{ \"PartitionKey\":\"Channel_19\", \"RowKey\":\"1\", \"Rating\":9, \"Text\":\"Azure...\"}"; + postReq1.Content = RequestContent.Create(Encoding.UTF8.GetBytes(post1Body)); + changeset.Add(new RequestRequestContent(postReq1), new Dictionary { { HttpHeader.Names.ContentType, "application/http" }, { cteHeaderName, Binary } }); + + var postReq2 = new MockRequest + { + Method = RequestMethod.Post + }; + postReq2.Uri.Reset(new Uri(postUri)); + postReq2.Headers.Add(HttpHeader.Names.ContentType, ApplicationJsonOdata); + postReq2.Headers.Add(HttpHeader.Names.Accept, ApplicationJson); + postReq2.Headers.Add(DataServiceVersion, Three0); + const string post2Body = "{ \"PartitionKey\":\"Channel_17\", \"RowKey\":\"2\", \"Rating\":9, \"Text\":\"Azure...\"}"; + postReq2.Content = RequestContent.Create(Encoding.UTF8.GetBytes(post2Body)); + changeset.Add(new RequestRequestContent(postReq2), new Dictionary { { HttpHeader.Names.ContentType, "application/http" }, { cteHeaderName, Binary } }); + + var patchReq = new MockRequest + { + Method = RequestMethod.Patch + }; + string mergeUri = $"https://{Host}/Blogs(PartitionKey='Channel_17',%20RowKey='3')"; + patchReq.Uri.Reset(new Uri(mergeUri)); + patchReq.Headers.Add(HttpHeader.Names.ContentType, ApplicationJsonOdata); + patchReq.Headers.Add(HttpHeader.Names.Accept, ApplicationJson); + patchReq.Headers.Add(DataServiceVersion, Three0); + const string patchBody = "{ \"PartitionKey\":\"Channel_19\", \"RowKey\":\"3\", \"Rating\":9, \"Text\":\"Azure Tables...\"}"; + patchReq.Content = RequestContent.Create(Encoding.UTF8.GetBytes(patchBody)); + changeset.Add(new RequestRequestContent(patchReq), new Dictionary { { HttpHeader.Names.ContentType, "application/http" }, { cteHeaderName, Binary } }); + + request.Content = content; + var memStream = new MemoryStream(); + await content.WriteToAsync(memStream, default); + memStream.Position = 0; + using var sr = new StreamReader(memStream, Encoding.UTF8); + string requestBody = sr.ReadToEnd(); + Console.WriteLine(requestBody); + + + Assert.That(requestBody, Is.EqualTo($"--batch_{batchGuid}\r\n" + + $"{HttpHeader.Names.ContentType}: multipart/mixed; boundary=changeset_{changesetGuid}\r\n" + + $"\r\n" + + $"--changeset_{changesetGuid}\r\n" + + $"{HttpHeader.Names.ContentType}: application/http\r\n" + + $"{cteHeaderName}: {Binary}\r\n" + + $"\r\n" + + $"POST {postUri} HTTP/1.1\r\n" + + $"{HttpHeader.Names.Host}: {Host}\r\n" + + $"{HttpHeader.Names.ContentType}: {ApplicationJsonOdata}\r\n" + + $"{HttpHeader.Names.Accept}: {ApplicationJson}\r\n" + + $"{DataServiceVersion}: {Three0}\r\n" + + $"{HttpHeader.Names.ContentLength}: 75\r\n" + + $"\r\n" + + $"{post1Body}\r\n" + + $"--changeset_{changesetGuid}\r\n" + + $"{HttpHeader.Names.ContentType}: application/http\r\n" + + $"{cteHeaderName}: {Binary}\r\n" + + $"\r\n" + + $"POST {postUri} HTTP/1.1\r\n" + + $"{HttpHeader.Names.Host}: {Host}\r\n" + + $"{HttpHeader.Names.ContentType}: {ApplicationJsonOdata}\r\n" + + $"{HttpHeader.Names.Accept}: {ApplicationJson}\r\n" + + $"{DataServiceVersion}: {Three0}\r\n" + + $"{HttpHeader.Names.ContentLength}: 75\r\n" + + $"\r\n" + + $"{post2Body}\r\n" + + $"--changeset_{changesetGuid}\r\n" + + $"{HttpHeader.Names.ContentType}: application/http\r\n" + + $"{cteHeaderName}: {Binary}\r\n" + + $"\r\n" + + $"PATCH {mergeUri} HTTP/1.1\r\n" + + $"{HttpHeader.Names.Host}: {Host}\r\n" + + $"{HttpHeader.Names.ContentType}: {ApplicationJsonOdata}\r\n" + + $"{HttpHeader.Names.Accept}: {ApplicationJson}\r\n" + + $"{DataServiceVersion}: {Three0}\r\n" + + $"{HttpHeader.Names.ContentLength}: 82\r\n" + + $"\r\n" + + $"{patchBody}\r\n" + + $"--changeset_{changesetGuid}--\r\n" + + $"\r\n" + + $"--batch_{batchGuid}--\r\n" + + $"")); + } + + private static MemoryStream MakeStream(string text) + { + return new MemoryStream(Encoding.UTF8.GetBytes(text)); + } + + private static string GetString(byte[] buffer, int count) + { + return Encoding.ASCII.GetString(buffer, 0, count); + } + } +} diff --git a/sdk/core/Azure.Core/tests/TransportFunctionalTests.cs b/sdk/core/Azure.Core/tests/TransportFunctionalTests.cs index 1b8b8c51dc49..8e3cdc2e01e3 100644 --- a/sdk/core/Azure.Core/tests/TransportFunctionalTests.cs +++ b/sdk/core/Azure.Core/tests/TransportFunctionalTests.cs @@ -237,7 +237,7 @@ public async Task CanGetAndSetMethod(RequestMethod method, string expectedMethod [TestCaseSource(nameof(AllHeadersWithValuesAndType))] public async Task CanGetAndAddRequestHeaders(string headerName, string headerValue, bool contentHeader) { - StringValues httpHeaderValues; + StringValues httpHeaderValues; using TestServer testServer = new TestServer( context => diff --git a/sdk/tables/Azure.Data.Tables/src/Azure.Data.Tables.csproj b/sdk/tables/Azure.Data.Tables/src/Azure.Data.Tables.csproj index e9e6e1224089..44dbd456d22a 100644 --- a/sdk/tables/Azure.Data.Tables/src/Azure.Data.Tables.csproj +++ b/sdk/tables/Azure.Data.Tables/src/Azure.Data.Tables.csproj @@ -8,15 +8,19 @@ $(RequiredTargetFrameworks) $(NoWarn);CA1812;CS1591 + true + + + @@ -25,11 +29,12 @@ + + - diff --git a/sdk/tables/Azure.Data.Tables/src/MultipartContentExtensions.cs b/sdk/tables/Azure.Data.Tables/src/MultipartContentExtensions.cs new file mode 100644 index 000000000000..7db3698c60cd --- /dev/null +++ b/sdk/tables/Azure.Data.Tables/src/MultipartContentExtensions.cs @@ -0,0 +1,20 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#nullable disable + +using System; +using Azure.Core; + +namespace Azure.Data.Tables +{ + internal static class MultipartContentExtensions + { + internal static MultipartContent AddChangeset(this MultipartContent batch) + { + var changeset = new MultipartContent("mixed", $"changeset_{Guid.NewGuid()}"); + batch.Add(changeset, changeset._headers); + return changeset; + } + } +} diff --git a/sdk/tables/Azure.Data.Tables/src/TableClient.cs b/sdk/tables/Azure.Data.Tables/src/TableClient.cs index 269a5940d071..042185b0d079 100644 --- a/sdk/tables/Azure.Data.Tables/src/TableClient.cs +++ b/sdk/tables/Azure.Data.Tables/src/TableClient.cs @@ -903,6 +903,76 @@ public virtual Response SetAccessPolicy(IEnumerable tableAcl, } } + /// + /// Placeholder for batch operations. This is just being used for testing. + /// + /// + /// + /// + /// + internal virtual async Task>> BatchTestAsync(IEnumerable entities, CancellationToken cancellationToken = default) where T : class, ITableEntity, new() + { + using DiagnosticScope scope = _diagnostics.CreateScope($"{nameof(TableClient)}.{nameof(BatchTest)}"); + scope.Start(); + try + { + var batch = TableRestClient.CreateBatchContent(); + var changeset = batch.AddChangeset(); + foreach (var entity in entities) + { + _tableOperations.AddInsertEntityRequest( + changeset, + _table, + null, + null, + null, + tableEntityProperties: entity.ToOdataAnnotatedDictionary(), + queryOptions: new QueryOptions() { Format = _format }); + } + return await _tableOperations.SendBatchRequestAsync(_tableOperations.CreateBatchRequest(batch, null, null), cancellationToken).ConfigureAwait(false); + } + catch (Exception ex) + { + scope.Failed(ex); + throw; + } + } + + /// + /// Placeholder for batch operations. This is just being used for testing. + /// + /// + /// + /// + /// + internal virtual Response> BatchTest(IEnumerable entities, CancellationToken cancellationToken = default) where T : class, ITableEntity, new() + { + using DiagnosticScope scope = _diagnostics.CreateScope($"{nameof(TableClient)}.{nameof(BatchTest)}"); + scope.Start(); + try + { + var batch = TableRestClient.CreateBatchContent(); + var changeset = batch.AddChangeset(); + foreach (var entity in entities) + { + _tableOperations.AddInsertEntityRequest( + changeset, + _table, + null, + null, + null, + tableEntityProperties: entity.ToOdataAnnotatedDictionary(), + queryOptions: new QueryOptions() { Format = _format }); + } + return _tableOperations.SendBatchRequest(_tableOperations.CreateBatchRequest(batch, null, null), cancellationToken); + } + catch (Exception ex) + { + scope.Failed(ex); + throw; + } + } + /// /// Creates an Odata filter query string from the provided expression. /// diff --git a/sdk/tables/Azure.Data.Tables/src/TableRestClient.cs b/sdk/tables/Azure.Data.Tables/src/TableRestClient.cs new file mode 100644 index 000000000000..a484f1abcc47 --- /dev/null +++ b/sdk/tables/Azure.Data.Tables/src/TableRestClient.cs @@ -0,0 +1,117 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#nullable disable + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Azure.Core; +using Azure.Core.Pipeline; +using Azure.Data.Tables.Models; + +namespace Azure.Data.Tables +{ + internal partial class TableRestClient + { + private const string CteHeaderName = "Content-Transfer-Encoding"; + private const string Binary = "binary"; + private const string ApplicationHttp = "application/http"; + + internal HttpMessage CreateBatchRequest(MultipartContent content, string requestId, ResponseFormat? responsePreference) + { + var message = _pipeline.CreateMessage(); + var request = message.Request; + request.Method = RequestMethod.Post; + var uri = new RawRequestUriBuilder(); + uri.AppendRaw(url, false); + uri.AppendPath("/$batch", false); + request.Uri = uri; + request.Headers.Add("x-ms-version", version); + if (requestId != null) + { + request.Headers.Add("x-ms-client-request-id", requestId); + } + request.Headers.Add("DataServiceVersion", "3.0"); + if (responsePreference != null) + { + request.Headers.Add("Prefer", responsePreference.Value.ToString()); + } + //request.Headers.Add("Accept", "application/json"); + request.Content = content; + content.ApplyToRequest(request); + return message; + } + + internal static MultipartContent CreateBatchContent() + { + return new MultipartContent("mixed", $"batch_{Guid.NewGuid()}"); + } + + internal void AddInsertEntityRequest(MultipartContent changeset, string table, int? timeout, string requestId, ResponseFormat? responsePreference, IDictionary tableEntityProperties, QueryOptions queryOptions) + { + var message = CreateInsertEntityRequest(table, timeout, requestId, responsePreference, tableEntityProperties, queryOptions); + changeset.Add(new RequestRequestContent(message.Request), new Dictionary { { HttpHeader.Names.ContentType, ApplicationHttp }, { CteHeaderName, Binary } }); + } + + /// Insert entity in a table. + /// + /// The cancellation token to use. + /// is null. + public async Task>> SendBatchRequestAsync(HttpMessage message, CancellationToken cancellationToken = default) + { + if (message == null) + { + throw new ArgumentNullException(nameof(message)); + } + + await _pipeline.SendAsync(message, cancellationToken).ConfigureAwait(false); + switch (message.Response.Status) + { + case 202: + { + var responses = await Multipart.ParseAsync( + message.Response.ContentStream, + message.Response.Headers.ContentType, + true, + cancellationToken).ConfigureAwait(false); + + return Response.FromValue(responses.ToList(), message.Response); + } + default: + throw await _clientDiagnostics.CreateRequestFailedExceptionAsync(message.Response).ConfigureAwait(false); + } + } + + /// Insert entity in a table. + /// + /// The cancellation token to use. + /// is null. + public Response> SendBatchRequest(HttpMessage message, CancellationToken cancellationToken = default) + { + if (message == null) + { + throw new ArgumentNullException(nameof(message)); + } + + _pipeline.Send(message, cancellationToken); + switch (message.Response.Status) + { + case 202: + { + var responses = Multipart.ParseAsync( + message.Response.ContentStream, + message.Response.Headers.ContentType, + false, + cancellationToken).EnsureCompleted(); + + return Response.FromValue(responses.ToList(), message.Response); + } + default: + throw _clientDiagnostics.CreateRequestFailedException(message.Response); + } + } + } +} diff --git a/sdk/tables/Azure.Data.Tables/tests/TableClientLiveTests.cs b/sdk/tables/Azure.Data.Tables/tests/TableClientLiveTests.cs index d90486723ef8..b3a8ed4120aa 100644 --- a/sdk/tables/Azure.Data.Tables/tests/TableClientLiveTests.cs +++ b/sdk/tables/Azure.Data.Tables/tests/TableClientLiveTests.cs @@ -879,7 +879,6 @@ public async Task GetAccessPoliciesReturnsPolicies() await client.SetAccessPolicyAsync(tableAcl: policyToCreate); - // Get the created policy. var policies = await client.GetAccessPolicyAsync(); @@ -910,5 +909,30 @@ public async Task GetEntityReturnsSingleEntity() Assert.That(entityResults, Is.Not.Null, "The entity should not be null."); } + + /// + /// Validates the functionality of the TableClient. + /// + [Test] + [LiveOnly] + public async Task BatchInsert() + { + if (_endpointType == TableEndpointType.CosmosTable) + { + Assert.Ignore("https://github.com/Azure/azure-sdk-for-net/issues/14272"); + } + var entitiesToCreate = CreateCustomTableEntities(PartitionKeyValue, 20); + + // Create the new entities. + + var responses = await client.BatchTestAsync(entitiesToCreate).ConfigureAwait(false); + + foreach (var response in responses.Value) + { + Assert.That(response.Status, Is.EqualTo((int)HttpStatusCode.Created)); + } + Assert.That(responses.Value.Count, Is.EqualTo(entitiesToCreate.Count)); + } } + }