diff --git a/sdk/core/Azure.Core/api/Azure.Core.net461.cs b/sdk/core/Azure.Core/api/Azure.Core.net461.cs
index c0b60287caf0..9fe50a17e2e9 100644
--- a/sdk/core/Azure.Core/api/Azure.Core.net461.cs
+++ b/sdk/core/Azure.Core/api/Azure.Core.net461.cs
@@ -216,15 +216,18 @@ public static partial class Names
{
public static string Accept { get { throw null; } }
public static string Authorization { get { throw null; } }
+ public static string ContentDisposition { get { throw null; } }
public static string ContentLength { get { throw null; } }
public static string ContentType { get { throw null; } }
public static string Date { get { throw null; } }
public static string ETag { get { throw null; } }
+ public static string Host { get { throw null; } }
public static string IfMatch { get { throw null; } }
public static string IfModifiedSince { get { throw null; } }
public static string IfNoneMatch { get { throw null; } }
public static string IfUnmodifiedSince { get { throw null; } }
public static string Range { get { throw null; } }
+ public static string Referer { get { throw null; } }
public static string UserAgent { get { throw null; } }
public static string XMsDate { get { throw null; } }
public static string XMsRange { get { throw null; } }
diff --git a/sdk/core/Azure.Core/api/Azure.Core.netstandard2.0.cs b/sdk/core/Azure.Core/api/Azure.Core.netstandard2.0.cs
index c0b60287caf0..9fe50a17e2e9 100644
--- a/sdk/core/Azure.Core/api/Azure.Core.netstandard2.0.cs
+++ b/sdk/core/Azure.Core/api/Azure.Core.netstandard2.0.cs
@@ -216,15 +216,18 @@ public static partial class Names
{
public static string Accept { get { throw null; } }
public static string Authorization { get { throw null; } }
+ public static string ContentDisposition { get { throw null; } }
public static string ContentLength { get { throw null; } }
public static string ContentType { get { throw null; } }
public static string Date { get { throw null; } }
public static string ETag { get { throw null; } }
+ public static string Host { get { throw null; } }
public static string IfMatch { get { throw null; } }
public static string IfModifiedSince { get { throw null; } }
public static string IfNoneMatch { get { throw null; } }
public static string IfUnmodifiedSince { get { throw null; } }
public static string Range { get { throw null; } }
+ public static string Referer { get { throw null; } }
public static string UserAgent { get { throw null; } }
public static string XMsDate { get { throw null; } }
public static string XMsRange { get { throw null; } }
diff --git a/sdk/core/Azure.Core/src/HttpHeader.cs b/sdk/core/Azure.Core/src/HttpHeader.cs
index 5f075a98a0d4..cf748bfb3f45 100644
--- a/sdk/core/Azure.Core/src/HttpHeader.cs
+++ b/sdk/core/Azure.Core/src/HttpHeader.cs
@@ -131,9 +131,19 @@ public static class Names
/// Returns. "If-Unmodified-Since"
///
public static string IfUnmodifiedSince => "If-Unmodified-Since";
+ ///
+ /// Returns. "Referer"
+ ///
+ public static string Referer => "Referer";
+ ///
+ /// Returns. "Host"
+ ///
+ public static string Host => "Host";
- internal static string Referer => "Referer";
- internal static string Host => "Host";
+ ///
+ /// Returns "Content-Disposition"
+ ///
+ public static string ContentDisposition => "Content-Disposition";
}
#pragma warning disable CA1034 // Nested types should not be visible
diff --git a/sdk/core/Azure.Core/src/Shared/Multipart/BufferedReadStream.cs b/sdk/core/Azure.Core/src/Shared/Multipart/BufferedReadStream.cs
new file mode 100644
index 000000000000..e2f455765a02
--- /dev/null
+++ b/sdk/core/Azure.Core/src/Shared/Multipart/BufferedReadStream.cs
@@ -0,0 +1,440 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+// Copied from https://github.com/aspnet/AspNetCore/tree/master/src/Http/WebUtilities/src
+
+using System;
+using System.Buffers;
+using System.IO;
+using System.Text;
+using System.Threading;
+using System.Threading.Tasks;
+
+#pragma warning disable CA1305 // ToString Locale
+#pragma warning disable CA1822 // Make member static
+#pragma warning disable IDE0016 // Simplify null check
+#pragma warning disable IDE0036 // Modifiers not ordered
+#pragma warning disable IDE0054 // Use compound assignment
+#pragma warning disable IDE0059 // Unnecessary assignment
+
+namespace Azure.Core
+{
+ ///
+ /// A Stream that wraps another stream and allows reading lines.
+ /// The data is buffered in memory.
+ ///
+ internal class BufferedReadStream : Stream
+ {
+ private const byte CR = (byte)'\r';
+ private const byte LF = (byte)'\n';
+
+ private readonly Stream _inner;
+ private readonly byte[] _buffer;
+ private readonly ArrayPool _bytePool;
+ private int _bufferOffset = 0;
+ private int _bufferCount = 0;
+ private bool _disposed;
+
+ ///
+ /// Creates a new stream.
+ ///
+ /// The stream to wrap.
+ /// Size of buffer in bytes.
+ public BufferedReadStream(Stream inner, int bufferSize)
+ : this(inner, bufferSize, ArrayPool.Shared)
+ {
+ }
+
+ ///
+ /// Creates a new stream.
+ ///
+ /// The stream to wrap.
+ /// Size of buffer in bytes.
+ /// ArrayPool for the buffer.
+ public BufferedReadStream(Stream inner, int bufferSize, ArrayPool bytePool)
+ {
+ if (inner == null)
+ {
+ throw new ArgumentNullException(nameof(inner));
+ }
+
+ _inner = inner;
+ _bytePool = bytePool;
+ _buffer = bytePool.Rent(bufferSize);
+ }
+
+ ///
+ /// The currently buffered data.
+ ///
+ public ArraySegment BufferedData
+ {
+ get { return new ArraySegment(_buffer, _bufferOffset, _bufferCount); }
+ }
+
+ ///
+ public override bool CanRead
+ {
+ get { return _inner.CanRead || _bufferCount > 0; }
+ }
+
+ ///
+ public override bool CanSeek
+ {
+ get { return _inner.CanSeek; }
+ }
+
+ ///
+ public override bool CanTimeout
+ {
+ get { return _inner.CanTimeout; }
+ }
+
+ ///
+ public override bool CanWrite
+ {
+ get { return _inner.CanWrite; }
+ }
+
+ ///
+ public override long Length
+ {
+ get { return _inner.Length; }
+ }
+
+ ///
+ public override long Position
+ {
+ get { return _inner.Position - _bufferCount; }
+ set
+ {
+ if (value < 0)
+ {
+ throw new ArgumentOutOfRangeException(nameof(value), value, "Position must be positive.");
+ }
+ if (value == Position)
+ {
+ return;
+ }
+
+ // Backwards?
+ if (value <= _inner.Position)
+ {
+ // Forward within the buffer?
+ var innerOffset = (int)(_inner.Position - value);
+ if (innerOffset <= _bufferCount)
+ {
+ // Yes, just skip some of the buffered data
+ _bufferOffset += innerOffset;
+ _bufferCount -= innerOffset;
+ }
+ else
+ {
+ // No, reset the buffer
+ _bufferOffset = 0;
+ _bufferCount = 0;
+ _inner.Position = value;
+ }
+ }
+ else
+ {
+ // Forward, reset the buffer
+ _bufferOffset = 0;
+ _bufferCount = 0;
+ _inner.Position = value;
+ }
+ }
+ }
+
+ ///
+ public override long Seek(long offset, SeekOrigin origin)
+ {
+ if (origin == SeekOrigin.Begin)
+ {
+ Position = offset;
+ }
+ else if (origin == SeekOrigin.Current)
+ {
+ Position = Position + offset;
+ }
+ else // if (origin == SeekOrigin.End)
+ {
+ Position = Length + offset;
+ }
+ return Position;
+ }
+
+ ///
+ public override void SetLength(long value)
+ {
+ _inner.SetLength(value);
+ }
+
+ ///
+ protected override void Dispose(bool disposing)
+ {
+ if (!_disposed)
+ {
+ _disposed = true;
+ _bytePool.Return(_buffer);
+
+ if (disposing)
+ {
+ _inner.Dispose();
+ }
+ }
+ }
+
+ ///
+ public override void Flush()
+ {
+ _inner.Flush();
+ }
+
+ ///
+ public override Task FlushAsync(CancellationToken cancellationToken)
+ {
+ return _inner.FlushAsync(cancellationToken);
+ }
+
+ ///
+ public override void Write(byte[] buffer, int offset, int count)
+ {
+ _inner.Write(buffer, offset, count);
+ }
+
+ ///
+ public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
+ {
+ return _inner.WriteAsync(buffer, offset, count, cancellationToken);
+ }
+
+ ///
+ public override int Read(byte[] buffer, int offset, int count)
+ {
+ ValidateBuffer(buffer, offset, count);
+
+ // Drain buffer
+ if (_bufferCount > 0)
+ {
+ int toCopy = Math.Min(_bufferCount, count);
+ Buffer.BlockCopy(_buffer, _bufferOffset, buffer, offset, toCopy);
+ _bufferOffset += toCopy;
+ _bufferCount -= toCopy;
+ return toCopy;
+ }
+
+ return _inner.Read(buffer, offset, count);
+ }
+
+ ///
+ public async override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
+ {
+ ValidateBuffer(buffer, offset, count);
+
+ // Drain buffer
+ if (_bufferCount > 0)
+ {
+ int toCopy = Math.Min(_bufferCount, count);
+ Buffer.BlockCopy(_buffer, _bufferOffset, buffer, offset, toCopy);
+ _bufferOffset += toCopy;
+ _bufferCount -= toCopy;
+ return toCopy;
+ }
+
+ return await _inner.ReadAsync(buffer, offset, count, cancellationToken).ConfigureAwait(false);
+ }
+
+ ///
+ /// Ensures that the buffer is not empty.
+ ///
+ /// Returns true if the buffer is not empty; false otherwise.
+ public bool EnsureBuffered()
+ {
+ if (_bufferCount > 0)
+ {
+ return true;
+ }
+ // Downshift to make room
+ _bufferOffset = 0;
+ _bufferCount = _inner.Read(_buffer, 0, _buffer.Length);
+ return _bufferCount > 0;
+ }
+
+ ///
+ /// Ensures that the buffer is not empty.
+ ///
+ /// Cancellation token.
+ /// Returns true if the buffer is not empty; false otherwise.
+ public async Task EnsureBufferedAsync(CancellationToken cancellationToken)
+ {
+ if (_bufferCount > 0)
+ {
+ return true;
+ }
+ // Downshift to make room
+ _bufferOffset = 0;
+ _bufferCount = await _inner.ReadAsync(_buffer, 0, _buffer.Length, cancellationToken).ConfigureAwait(false);
+ return _bufferCount > 0;
+ }
+
+ ///
+ /// Ensures that a minimum amount of buffered data is available.
+ ///
+ /// Minimum amount of buffered data.
+ /// Returns true if the minimum amount of buffered data is available; false otherwise.
+ public bool EnsureBuffered(int minCount)
+ {
+ if (minCount > _buffer.Length)
+ {
+ throw new ArgumentOutOfRangeException(nameof(minCount), minCount, "The value must be smaller than the buffer size: " + _buffer.Length.ToString());
+ }
+ while (_bufferCount < minCount)
+ {
+ // Downshift to make room
+ if (_bufferOffset > 0)
+ {
+ if (_bufferCount > 0)
+ {
+ Buffer.BlockCopy(_buffer, _bufferOffset, _buffer, 0, _bufferCount);
+ }
+ _bufferOffset = 0;
+ }
+ int read = _inner.Read(_buffer, _bufferOffset + _bufferCount, _buffer.Length - _bufferCount - _bufferOffset);
+ _bufferCount += read;
+ if (read == 0)
+ {
+ return false;
+ }
+ }
+ return true;
+ }
+
+ ///
+ /// Ensures that a minimum amount of buffered data is available.
+ ///
+ /// Minimum amount of buffered data.
+ /// Cancellation token.
+ /// Returns true if the minimum amount of buffered data is available; false otherwise.
+ public async Task EnsureBufferedAsync(int minCount, CancellationToken cancellationToken)
+ {
+ if (minCount > _buffer.Length)
+ {
+ throw new ArgumentOutOfRangeException(nameof(minCount), minCount, "The value must be smaller than the buffer size: " + _buffer.Length.ToString());
+ }
+ while (_bufferCount < minCount)
+ {
+ // Downshift to make room
+ if (_bufferOffset > 0)
+ {
+ if (_bufferCount > 0)
+ {
+ Buffer.BlockCopy(_buffer, _bufferOffset, _buffer, 0, _bufferCount);
+ }
+ _bufferOffset = 0;
+ }
+ int read = await _inner.ReadAsync(_buffer, _bufferOffset + _bufferCount, _buffer.Length - _bufferCount - _bufferOffset, cancellationToken).ConfigureAwait(false);
+ _bufferCount += read;
+ if (read == 0)
+ {
+ return false;
+ }
+ }
+ return true;
+ }
+
+ ///
+ /// Reads a line. A line is defined as a sequence of characters followed by
+ /// a carriage return immediately followed by a line feed. The resulting string does not
+ /// contain the terminating carriage return and line feed.
+ ///
+ /// Maximum allowed line length.
+ /// A line.
+ public string ReadLine(int lengthLimit)
+ {
+ CheckDisposed();
+ using (var builder = new MemoryStream(200))
+ {
+ bool foundCR = false, foundCRLF = false;
+
+ while (!foundCRLF && EnsureBuffered())
+ {
+ if (builder.Length > lengthLimit)
+ {
+ throw new InvalidDataException($"Line length limit {lengthLimit} exceeded.");
+ }
+ ProcessLineChar(builder, ref foundCR, ref foundCRLF);
+ }
+
+ return DecodeLine(builder, foundCRLF);
+ }
+ }
+
+ ///
+ /// Reads a line. A line is defined as a sequence of characters followed by
+ /// a carriage return immediately followed by a line feed. The resulting string does not
+ /// contain the terminating carriage return and line feed.
+ ///
+ /// Maximum allowed line length.
+ /// Cancellation token.
+ /// A line.
+ public async Task ReadLineAsync(int lengthLimit, CancellationToken cancellationToken)
+ {
+ CheckDisposed();
+ using (var builder = new MemoryStream(200))
+ {
+ bool foundCR = false, foundCRLF = false;
+
+ while (!foundCRLF && await EnsureBufferedAsync(cancellationToken).ConfigureAwait(false))
+ {
+ if (builder.Length > lengthLimit)
+ {
+ throw new InvalidDataException($"Line length limit {lengthLimit} exceeded.");
+ }
+
+ ProcessLineChar(builder, ref foundCR, ref foundCRLF);
+ }
+
+ return DecodeLine(builder, foundCRLF);
+ }
+ }
+
+ private void ProcessLineChar(MemoryStream builder, ref bool foundCR, ref bool foundCRLF)
+ {
+ var b = _buffer[_bufferOffset];
+ builder.WriteByte(b);
+ _bufferOffset++;
+ _bufferCount--;
+ if (b == LF && foundCR)
+ {
+ foundCRLF = true;
+ return;
+ }
+ foundCR = b == CR;
+ }
+
+ private string DecodeLine(MemoryStream builder, bool foundCRLF)
+ {
+ // Drop the final CRLF, if any
+ var length = foundCRLF ? builder.Length - 2 : builder.Length;
+ return Encoding.UTF8.GetString(builder.ToArray(), 0, (int)length);
+ }
+
+ private void CheckDisposed()
+ {
+ if (_disposed)
+ {
+ throw new ObjectDisposedException(nameof(BufferedReadStream));
+ }
+ }
+
+ private void ValidateBuffer(byte[] buffer, int offset, int count)
+ {
+ // Delegate most of our validation.
+ var ignored = new ArraySegment(buffer, offset, count);
+ if (count == 0)
+ {
+ throw new ArgumentOutOfRangeException(nameof(count), "The value must be greater than zero.");
+ }
+ }
+ }
+}
diff --git a/sdk/core/Azure.Core/src/Shared/Multipart/KeyValueAccumulator.cs b/sdk/core/Azure.Core/src/Shared/Multipart/KeyValueAccumulator.cs
new file mode 100644
index 000000000000..2db059026d77
--- /dev/null
+++ b/sdk/core/Azure.Core/src/Shared/Multipart/KeyValueAccumulator.cs
@@ -0,0 +1,90 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+// Copied from https://github.com/aspnet/AspNetCore/tree/master/src/Http/WebUtilities/src
+
+using System;
+using System.Collections.Generic;
+
+#pragma warning disable IDE0008 // Use explicit type
+#pragma warning disable IDE0018 // Inline declaration
+#pragma warning disable IDE0034 // default can be simplified
+
+namespace Azure.Core
+{
+ internal struct KeyValueAccumulator
+ {
+ private Dictionary _accumulator;
+ private Dictionary> _expandingAccumulator;
+
+ public void Append(string key, string value)
+ {
+ if (_accumulator == null)
+ {
+ _accumulator = new Dictionary(StringComparer.OrdinalIgnoreCase);
+ }
+
+ string[] values;
+ if (_accumulator.TryGetValue(key, out values))
+ {
+ if (values.Length == 0)
+ {
+ // Marker entry for this key to indicate entry already in expanding list dictionary
+ _expandingAccumulator[key].Add(value);
+ }
+ else if (values.Length == 1)
+ {
+ // Second value for this key
+ _accumulator[key] = new string[] { values[0], value };
+ }
+ else
+ {
+ // Third value for this key
+ // Add zero count entry and move to data to expanding list dictionary
+ _accumulator[key] = null;
+
+ if (_expandingAccumulator == null)
+ {
+ _expandingAccumulator = new Dictionary>(StringComparer.OrdinalIgnoreCase);
+ }
+
+ // Already 3 entries so use starting allocated as 8; then use List's expansion mechanism for more
+ var list = new List(8);
+
+ list.Add(values[0]);
+ list.Add(values[1]);
+ list.Add(value);
+
+ _expandingAccumulator[key] = list;
+ }
+ }
+ else
+ {
+ // First value for this key
+ _accumulator[key] = new[] { value };
+ }
+
+ ValueCount++;
+ }
+
+ public bool HasValues => ValueCount > 0;
+
+ public int KeyCount => _accumulator?.Count ?? 0;
+
+ public int ValueCount { get; private set; }
+
+ public Dictionary GetResults()
+ {
+ if (_expandingAccumulator != null)
+ {
+ // Coalesce count 3+ multi-value entries into _accumulator dictionary
+ foreach (var entry in _expandingAccumulator)
+ {
+ _accumulator[entry.Key] = entry.Value.ToArray();
+ }
+ }
+
+ return _accumulator ?? new Dictionary(0, StringComparer.OrdinalIgnoreCase);
+ }
+ }
+}
diff --git a/sdk/core/Azure.Core/src/Shared/Multipart/MemoryResponse.cs b/sdk/core/Azure.Core/src/Shared/Multipart/MemoryResponse.cs
new file mode 100644
index 000000000000..b63f440ece2e
--- /dev/null
+++ b/sdk/core/Azure.Core/src/Shared/Multipart/MemoryResponse.cs
@@ -0,0 +1,164 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+using System;
+using System.Collections.Generic;
+using System.IO;
+using System.Linq;
+using System.Text;
+
+#nullable disable
+
+namespace Azure.Core
+{
+ ///
+ /// A Response that can be constructed in memory without being tied to a
+ /// live request.
+ ///
+ internal class MemoryResponse : Response
+ {
+ private const int NoStatusCode = 0;
+ private const string XmsClientRequestIdName = "x-ms-client-request-id";
+
+ ///
+ /// The Response .
+ ///
+ private int _status = NoStatusCode;
+
+ ///
+ /// The Response .
+ ///
+ private string _reasonPhrase = null;
+
+ ///
+ /// The .
+ ///
+ private readonly IDictionary> _headers =
+ new Dictionary>(StringComparer.OrdinalIgnoreCase);
+
+ ///
+ public override int Status => _status;
+
+ ///
+ public override string ReasonPhrase => _reasonPhrase;
+
+ ///
+ public override Stream ContentStream { get; set; }
+
+ ///
+ public override string ClientRequestId
+ {
+ get => TryGetHeader(XmsClientRequestIdName, out string id) ? id : null;
+ set => SetHeader(XmsClientRequestIdName, value);
+ }
+
+ ///
+ /// Set the Response .
+ ///
+ /// The Response status.
+ public void SetStatus(int status) => _status = status;
+
+ ///
+ /// Set the Response .
+ ///
+ /// The Response ReasonPhrase.
+ public void SetReasonPhrase(string reasonPhrase) => _reasonPhrase = reasonPhrase;
+
+ ///
+ /// Set the Response .
+ ///
+ /// The response content.
+ public void SetContent(byte[] content) => ContentStream = new MemoryStream(content);
+
+ ///
+ /// Set the Response .
+ ///
+ /// The response content.
+ public void SetContent(string content) => SetContent(Encoding.UTF8.GetBytes(content));
+
+ ///
+ /// Dispose the Response.
+ ///
+ public override void Dispose() => ContentStream?.Dispose();
+
+ ///
+ /// Set the value of a response header (and overwrite any existing
+ /// values).
+ ///
+ /// The name of the response header.
+ /// The response header value.
+ public void SetHeader(string name, string value) =>
+ SetHeader(name, new List { value });
+
+ ///
+ /// Set the values of a response header (and overwrite any existing
+ /// values).
+ ///
+ /// The name of the response header.
+ /// The response header values.
+ public void SetHeader(string name, IEnumerable values) =>
+ _headers[name] = values.ToList();
+
+ ///
+ /// Add a response header value.
+ ///
+ /// The name of the response header.
+ /// The response header value.
+ public void AddHeader(string name, string value)
+ {
+ if (!_headers.TryGetValue(name, out List values))
+ {
+ _headers[name] = values = new List();
+ }
+ values.Add(value);
+ }
+
+ ///
+#if HAS_INTERNALS_VISIBLE_CORE
+ internal
+#endif
+ protected override bool ContainsHeader(string name) =>
+ _headers.ContainsKey(name);
+
+ ///
+#if HAS_INTERNALS_VISIBLE_CORE
+ internal
+#endif
+ protected override IEnumerable EnumerateHeaders() =>
+ _headers.Select(header => new HttpHeader(header.Key, JoinHeaderValues(header.Value)));
+
+ ///
+#if HAS_INTERNALS_VISIBLE_CORE
+ internal
+#endif
+ protected override bool TryGetHeader(string name, out string value)
+ {
+ if (_headers.TryGetValue(name, out List headers))
+ {
+ value = JoinHeaderValues(headers);
+ return true;
+ }
+ value = null;
+ return false;
+ }
+
+ ///
+#if HAS_INTERNALS_VISIBLE_CORE
+ internal
+#endif
+ protected override bool TryGetHeaderValues(string name, out IEnumerable values)
+ {
+ bool found = _headers.TryGetValue(name, out List headers);
+ values = headers;
+ return found;
+ }
+
+ ///
+ /// Join multiple header values together with a comma.
+ ///
+ /// The header values.
+ /// A single joined value.
+ private static string JoinHeaderValues(IEnumerable values) =>
+ string.Join(",", values);
+ }
+}
diff --git a/sdk/core/Azure.Core/src/Shared/Multipart/Multipart.cs b/sdk/core/Azure.Core/src/Shared/Multipart/Multipart.cs
new file mode 100644
index 000000000000..6047d8bb7f13
--- /dev/null
+++ b/sdk/core/Azure.Core/src/Shared/Multipart/Multipart.cs
@@ -0,0 +1,223 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+using System;
+using System.Collections.Generic;
+using System.Globalization;
+using System.IO;
+using System.Threading;
+using System.Threading.Tasks;
+
+#nullable disable
+
+namespace Azure.Core
+{
+ ///
+ /// Provides support for creating and parsing multipart/mixed content.
+ /// This is implementing a couple of layered standards as mentioned at
+ /// https://docs.microsoft.com/en-us/rest/api/storageservices/blob-batch and
+ /// https://docs.microsoft.com/en-us/rest/api/storageservices/performing-entity-group-transactions
+ /// including https://www.odata.org/documentation/odata-version-3-0/batch-processing/
+ /// and https://www.ietf.org/rfc/rfc2046.txt.
+ ///
+ internal static class Multipart
+ {
+ private const int KB = 1024;
+ private const int ResponseLineSize = 4 * KB;
+ private const string MultipartContentTypePrefix = "multipart/mixed; boundary=";
+ private const string ContentIdName = "Content-ID";
+ internal static InvalidOperationException InvalidBatchContentType(string contentType) =>
+ new InvalidOperationException($"Expected {HttpHeader.Names.ContentType} to start with {MultipartContentTypePrefix} but received {contentType}");
+
+ internal static InvalidOperationException InvalidHttpStatusLine(string statusLine) =>
+ new InvalidOperationException($"Expected an HTTP status line, not {statusLine}");
+
+ internal static InvalidOperationException InvalidHttpHeaderLine(string headerLine) =>
+ new InvalidOperationException($"Expected an HTTP header line, not {headerLine}");
+
+ ///
+ /// Parse a multipart/mixed response body into several responses.
+ ///
+ /// The response content.
+ /// The response content type.
+ ///
+ /// Whether to invoke the operation asynchronously.
+ ///
+ ///
+ /// Optional to propagate notifications
+ /// that the operation should be cancelled.
+ ///
+ /// The parsed s.
+ internal static async Task ParseAsync(
+ Stream batchContent,
+ string batchContentType,
+ bool async,
+ CancellationToken cancellationToken)
+ {
+ // Get the batch boundary
+ if (!GetBoundary(batchContentType, out string batchBoundary))
+ {
+ throw InvalidBatchContentType(batchContentType);
+ }
+
+ // Collect the responses in a dictionary (in case the Content-ID
+ // values come back out of order)
+ Dictionary responses = new Dictionary();
+
+ // Collect responses without a Content-ID in a List
+ List responsesWithoutId = new List();
+
+ // Read through the batch body one section at a time until the
+ // reader returns null
+ MultipartReader reader = new MultipartReader(batchBoundary, batchContent);
+ for (MultipartSection section = await reader.GetNextSectionAsync(async, cancellationToken).ConfigureAwait(false);
+ section != null;
+ section = await reader.GetNextSectionAsync(async, cancellationToken).ConfigureAwait(false))
+ {
+ bool contentIdFound = true;
+ if (section.Headers.TryGetValue(HttpHeader.Names.ContentType, out string [] contentTypeValues) &&
+ contentTypeValues.Length == 1 &&
+ GetBoundary(contentTypeValues[0], out string subBoundary))
+ {
+ reader = new MultipartReader(subBoundary, section.Body);
+ continue;
+ }
+ // Get the Content-ID header
+ if (!section.Headers.TryGetValue(ContentIdName, out string [] contentIdValues) ||
+ contentIdValues.Length != 1 ||
+ !int.TryParse(contentIdValues[0], out int contentId))
+ {
+ // If the header wasn't found, this is either:
+ // - a failed request with the details being sent as the first sub-operation
+ // - a tables batch request, which does not utilize Content-ID headers
+ // so we default the Content-ID to 0 and track that no Content-ID was found.
+ contentId = 0;
+ contentIdFound = false;
+ }
+
+ // Build a response
+ MemoryResponse response = new MemoryResponse();
+ if (contentIdFound)
+ {
+ // track responses by Content-ID
+ responses[contentId] = response;
+ }
+ else
+ {
+ // track responses without a Content-ID
+ responsesWithoutId.Add(response);
+ }
+
+ // We're going to read the section's response body line by line
+ using var body = new BufferedReadStream(section.Body, ResponseLineSize);
+
+ // The first line is the status like "HTTP/1.1 202 Accepted"
+ string line = await body.ReadLineAsync(async, cancellationToken).ConfigureAwait(false);
+ string[] status = line.Split(new char[] { ' ' }, 3, StringSplitOptions.RemoveEmptyEntries);
+ if (status.Length != 3)
+ {
+ throw InvalidHttpStatusLine(line);
+ }
+ response.SetStatus(int.Parse(status[1], CultureInfo.InvariantCulture));
+ response.SetReasonPhrase(status[2]);
+
+ // Continue reading headers until we reach a blank line
+ line = await body.ReadLineAsync(async, cancellationToken).ConfigureAwait(false);
+ while (!string.IsNullOrEmpty(line))
+ {
+ // Split the header into the name and value
+ int splitIndex = line.IndexOf(':');
+ if (splitIndex <= 0)
+ {
+ throw InvalidHttpHeaderLine(line);
+ }
+ var name = line.Substring(0, splitIndex);
+ var value = line.Substring(splitIndex + 1, line.Length - splitIndex - 1).Trim();
+ response.AddHeader(name, value);
+
+ line = await body.ReadLineAsync(async, cancellationToken).ConfigureAwait(false);
+ }
+
+ // Copy the rest of the body as the response content
+ var responseContent = new MemoryStream();
+ if (async)
+ {
+ await body.CopyToAsync(responseContent).ConfigureAwait(false);
+ }
+ else
+ {
+ body.CopyTo(responseContent);
+ }
+ responseContent.Seek(0, SeekOrigin.Begin);
+ response.ContentStream = responseContent;
+ }
+
+ // Collect the responses and order by Content-ID, when available.
+ // Otherwise collect them as we received them.
+ Response[] ordered = new Response[responses.Count + responsesWithoutId.Count];
+ for (int i = 0; i < responses.Count; i++)
+ {
+ ordered[i] = responses[i];
+ }
+ for (int i = responses.Count; i < ordered.Length; i++)
+ {
+ ordered[i] = responsesWithoutId[i - responses.Count];
+ }
+ return ordered;
+ }
+
+
+ ///
+ /// Read the next line of text.
+ ///
+ /// The stream to read from.
+ ///
+ /// Whether to invoke the operation asynchronously.
+ ///
+ ///
+ /// Optional to propagate notifications
+ /// that the operation should be cancelled.
+ ///
+ /// The next line of text.
+ internal static async Task ReadLineAsync(
+ this BufferedReadStream stream,
+ bool async,
+ CancellationToken cancellationToken) =>
+ async ?
+ await stream.ReadLineAsync(ResponseLineSize, cancellationToken).ConfigureAwait(false) :
+ stream.ReadLine(ResponseLineSize);
+
+ ///
+ /// Read the next multipart section.
+ ///
+ /// The reader to parse with.
+ ///
+ /// Whether to invoke the operation asynchronously.
+ ///
+ ///
+ /// Optional to propagate notifications
+ /// that the operation should be cancelled.
+ ///
+ /// The next multipart section.
+ internal static async Task GetNextSectionAsync(
+ this MultipartReader reader,
+ bool async,
+ CancellationToken cancellationToken) =>
+ async ?
+ await reader.ReadNextSectionAsync(cancellationToken).ConfigureAwait(false) :
+#pragma warning disable AZC0102 // Do not use GetAwaiter().GetResult(). Use the TaskExtensions.EnsureCompleted() extension method instead.
+ reader.ReadNextSectionAsync(cancellationToken).GetAwaiter().GetResult(); // #7972: Decide if we need a proper sync API here
+#pragma warning restore AZC0102 // Do not use GetAwaiter().GetResult(). Use the TaskExtensions.EnsureCompleted() extension method instead.
+
+ private static bool GetBoundary(string contentType, out string batchBoundary)
+ {
+ if (contentType == null || !contentType.StartsWith(MultipartContentTypePrefix, StringComparison.Ordinal))
+ {
+ batchBoundary = null;
+ return false;
+ }
+ batchBoundary = contentType.Substring(MultipartContentTypePrefix.Length);
+ return true;
+ }
+ }
+}
diff --git a/sdk/core/Azure.Core/src/Shared/Multipart/MultipartBoundary.cs b/sdk/core/Azure.Core/src/Shared/Multipart/MultipartBoundary.cs
new file mode 100644
index 000000000000..bba72d47719b
--- /dev/null
+++ b/sdk/core/Azure.Core/src/Shared/Multipart/MultipartBoundary.cs
@@ -0,0 +1,74 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+// Copied from https://github.com/aspnet/AspNetCore/tree/master/src/Http/WebUtilities/src
+
+using System;
+using System.Text;
+
+namespace Azure.Core
+{
+ internal class MultipartBoundary
+ {
+ private readonly int[] _skipTable = new int[256];
+ private readonly string _boundary;
+ private bool _expectLeadingCrlf;
+
+ public MultipartBoundary(string boundary, bool expectLeadingCrlf = true)
+ {
+ if (boundary == null)
+ {
+ throw new ArgumentNullException(nameof(boundary));
+ }
+
+ _boundary = boundary;
+ _expectLeadingCrlf = expectLeadingCrlf;
+ Initialize(_boundary, _expectLeadingCrlf);
+ }
+
+ private void Initialize(string boundary, bool expectLeadingCrlf)
+ {
+ if (expectLeadingCrlf)
+ {
+ BoundaryBytes = Encoding.UTF8.GetBytes("\r\n--" + boundary);
+ }
+ else
+ {
+ BoundaryBytes = Encoding.UTF8.GetBytes("--" + boundary);
+ }
+ FinalBoundaryLength = BoundaryBytes.Length + 2; // Include the final '--' terminator.
+
+ var length = BoundaryBytes.Length;
+ for (var i = 0; i < _skipTable.Length; ++i)
+ {
+ _skipTable[i] = length;
+ }
+ for (var i = 0; i < length; ++i)
+ {
+ _skipTable[BoundaryBytes[i]] = Math.Max(1, length - 1 - i);
+ }
+ }
+
+ public int GetSkipValue(byte input)
+ {
+ return _skipTable[input];
+ }
+
+ public bool ExpectLeadingCrlf
+ {
+ get { return _expectLeadingCrlf; }
+ set
+ {
+ if (value != _expectLeadingCrlf)
+ {
+ _expectLeadingCrlf = value;
+ Initialize(_boundary, _expectLeadingCrlf);
+ }
+ }
+ }
+
+ public byte[] BoundaryBytes { get; private set; } = default!; // This gets initialized as part of Initialize called from in the ctor.
+
+ public int FinalBoundaryLength { get; private set; }
+ }
+}
diff --git a/sdk/core/Azure.Core/src/MultipartFormDataContent.cs b/sdk/core/Azure.Core/src/Shared/Multipart/MultipartContent.cs
similarity index 80%
rename from sdk/core/Azure.Core/src/MultipartFormDataContent.cs
rename to sdk/core/Azure.Core/src/Shared/Multipart/MultipartContent.cs
index e83bfc2d567b..3604f62cb2a6 100644
--- a/sdk/core/Azure.Core/src/MultipartFormDataContent.cs
+++ b/sdk/core/Azure.Core/src/Shared/Multipart/MultipartContent.cs
@@ -1,4 +1,4 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
+// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
using System;
@@ -8,43 +8,58 @@
using System.Threading;
using System.Threading.Tasks;
+#nullable disable
+
namespace Azure.Core
{
///
/// Provides a container for content encoded using multipart/form-data MIME type.
///
- internal class MultipartFormDataContent : RequestContent
+ internal class MultipartContent : RequestContent
{
#region Fields
private const string CrLf = "\r\n";
- private const string FormData = "form-data";
+ private const string ColonSP = ": ";
private static readonly int s_crlfLength = GetEncodedLength(CrLf);
private static readonly int s_dashDashLength = GetEncodedLength("--");
- private static readonly int s_colonSpaceLength = GetEncodedLength(": ");
+ private static readonly int s_colonSpaceLength = GetEncodedLength(ColonSP);
private readonly List _nestedContent;
+ private readonly string _subtype;
private readonly string _boundary;
+ internal readonly Dictionary _headers;
#endregion Fields
#region Construction
- ///
- /// Initializes a new instance of the class.
- ///
- public MultipartFormDataContent() : this(GetDefaultBoundary())
+ public MultipartContent()
+ : this("mixed", GetDefaultBoundary())
+ { }
+
+ public MultipartContent(string subtype)
+ : this(subtype, GetDefaultBoundary())
{ }
///
- /// Initializes a new instance of the class.
+ /// Initializes a new instance of the class.
///
+ /// The multipart sub type.
/// The boundary string for the multipart form data content.
- public MultipartFormDataContent(string boundary)
+ public MultipartContent(string subtype, string boundary)
{
ValidateBoundary(boundary);
- _boundary = boundary;
+ _subtype = subtype;
+
+ // see https://www.ietf.org/rfc/rfc1521.txt page 29.
+ _boundary = boundary.Contains(":") ? $"\"{boundary}\"" : boundary;
+ _headers = new Dictionary
+ {
+ [HttpHeader.Names.ContentType] = $"multipart/{_subtype}; boundary={_boundary}"
+ };
+
_nestedContent = new List();
}
@@ -97,7 +112,7 @@ private static string GetDefaultBoundary()
/// The request.
public void ApplyToRequest(Request request)
{
- request.Headers.Add("Content-Type", $"multipart/form-data;boundary=\"{_boundary}\"");
+ request.Headers.Add(HttpHeader.Names.ContentType, $"multipart/{_subtype}; boundary={_boundary}");
}
///
@@ -105,10 +120,10 @@ public void ApplyToRequest(Request request)
/// get serialized to multipart/form-data MIME type.
///
/// The Request content to add to the collection.
- public void Add(RequestContent content)
+ public virtual void Add(RequestContent content)
{
Argument.AssertNotNull(content, nameof(content));
- AddInternal(content, null, null, null);
+ AddInternal(content, null);
}
///
@@ -117,69 +132,20 @@ public void Add(RequestContent content)
///
/// The Request content to add to the collection.
/// The headers to add to the collection.
- public void Add(RequestContent content, Dictionary headers)
+ public virtual void Add(RequestContent content, Dictionary headers)
{
Argument.AssertNotNull(content, nameof(content));
Argument.AssertNotNull(headers, nameof(headers));
- AddInternal(content, headers, null, null);
+ AddInternal(content, headers);
}
- ///
- /// Add HTTP content to a collection of RequestContent objects that
- /// get serialized to multipart/form-data MIME type.
- ///
- /// The Request content to add to the collection.
- /// The name for the request content to add.
- /// The headers to add to the collection.
- public void Add(RequestContent content, string name, Dictionary? headers)
- {
- Argument.AssertNotNull(content, nameof(content));
- Argument.AssertNotNullOrWhiteSpace(name, nameof(name));
-
- AddInternal(content, headers, name, null);
- }
-
- ///
- /// Add HTTP content to a collection of RequestContent objects that
- /// get serialized to multipart/form-data MIME type.
- ///
- /// The Request content to add to the collection.
- /// The name for the request content to add.
- /// The file name for the reuest content to add to the collection.
- /// The headers to add to the collection.
- public void Add(RequestContent content, string name, string fileName, Dictionary? headers)
- {
- Argument.AssertNotNull(content, nameof(content));
- Argument.AssertNotNullOrWhiteSpace(name, nameof(name));
- Argument.AssertNotNullOrWhiteSpace(fileName, nameof(fileName));
-
- AddInternal(content, headers, name, fileName);
- }
-
- private void AddInternal(RequestContent content, Dictionary? headers, string? name, string? fileName)
+ private void AddInternal(RequestContent content, Dictionary headers)
{
if (headers == null)
{
headers = new Dictionary();
}
-
- if (!headers.ContainsKey("Content-Disposition"))
- {
- var value = FormData;
-
- if (name != null)
- {
- value = value + "; name=" + name;
- }
- if (fileName != null)
- {
- value = value + "; filename=" + fileName;
- }
-
- headers.Add("Content-Disposition", value);
- }
-
_nestedContent.Add(new MultipartRequestContent(content, headers));
}
@@ -188,7 +154,7 @@ private void AddInternal(RequestContent content, Dictionary? hea
#region Dispose
///
- /// Frees resources held by the object.
+ /// Frees resources held by the object.
///
public override void Dispose()
{
@@ -197,7 +163,6 @@ public override void Dispose()
content.RequestContent.Dispose();
}
_nestedContent.Clear();
-
}
#endregion Dispose
diff --git a/sdk/core/Azure.Core/src/Shared/Multipart/MultipartFormDataContent.cs b/sdk/core/Azure.Core/src/Shared/Multipart/MultipartFormDataContent.cs
new file mode 100644
index 000000000000..b99ee6ef83a6
--- /dev/null
+++ b/sdk/core/Azure.Core/src/Shared/Multipart/MultipartFormDataContent.cs
@@ -0,0 +1,122 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+using System.Collections.Generic;
+
+#nullable disable
+
+namespace Azure.Core
+{
+ ///
+ /// Provides a container for content encoded using multipart/form-data MIME type.
+ ///
+ internal class MultipartFormDataContent : MultipartContent
+ {
+ #region Fields
+
+ private const string FormData = "form-data";
+
+ #endregion Fields
+
+ #region Construction
+
+ ///
+ /// Initializes a new instance of the class.
+ ///
+ public MultipartFormDataContent() : base(FormData)
+ { }
+
+ ///
+ /// Initializes a new instance of the class.
+ ///
+ /// The boundary string for the multipart form data content.
+ public MultipartFormDataContent(string boundary) : base(FormData, boundary)
+ { }
+
+ #endregion Construction
+
+ ///
+ /// Add HTTP content to a collection of RequestContent objects that
+ /// get serialized to multipart/form-data MIME type.
+ ///
+ /// The Request content to add to the collection.
+ public override void Add(RequestContent content)
+ {
+ Argument.AssertNotNull(content, nameof(content));
+ AddInternal(content, null, null, null);
+ }
+
+ ///
+ /// Add HTTP content to a collection of RequestContent objects that
+ /// get serialized to multipart/form-data MIME type.
+ ///
+ /// The Request content to add to the collection.
+ /// The headers to add to the collection.
+ public override void Add(RequestContent content, Dictionary headers)
+ {
+ Argument.AssertNotNull(content, nameof(content));
+ Argument.AssertNotNull(headers, nameof(headers));
+
+ AddInternal(content, headers, null, null);
+ }
+
+ ///
+ /// Add HTTP content to a collection of RequestContent objects that
+ /// get serialized to multipart/form-data MIME type.
+ ///
+ /// The Request content to add to the collection.
+ /// The name for the request content to add.
+ /// The headers to add to the collection.
+ public void Add(RequestContent content, string name, Dictionary headers)
+ {
+ Argument.AssertNotNull(content, nameof(content));
+ Argument.AssertNotNullOrWhiteSpace(name, nameof(name));
+
+ AddInternal(content, headers, name, null);
+ }
+
+ ///
+ /// Add HTTP content to a collection of RequestContent objects that
+ /// get serialized to multipart/form-data MIME type.
+ ///
+ /// The Request content to add to the collection.
+ /// The name for the request content to add.
+ /// The file name for the request content to add to the collection.
+ /// The headers to add to the collection.
+ public void Add(RequestContent content, string name, string fileName, Dictionary headers)
+ {
+ Argument.AssertNotNull(content, nameof(content));
+ Argument.AssertNotNullOrWhiteSpace(name, nameof(name));
+ Argument.AssertNotNullOrWhiteSpace(fileName, nameof(fileName));
+
+ AddInternal(content, headers, name, fileName);
+ }
+
+ private void AddInternal(RequestContent content, Dictionary headers, string name, string fileName)
+ {
+ if (headers == null)
+ {
+ headers = new Dictionary();
+ }
+
+ if (!headers.ContainsKey("Content-Disposition"))
+ {
+ var value = FormData;
+
+ if (name != null)
+ {
+ value = value + "; name=" + name;
+ }
+ if (fileName != null)
+ {
+ value = value + "; filename=" + fileName;
+ }
+
+ headers.Add("Content-Disposition", value);
+ }
+
+ base.Add(content, headers);
+ }
+
+ }
+}
diff --git a/sdk/core/Azure.Core/src/Shared/Multipart/MultipartReader.cs b/sdk/core/Azure.Core/src/Shared/Multipart/MultipartReader.cs
new file mode 100644
index 000000000000..6ac52f6aec03
--- /dev/null
+++ b/sdk/core/Azure.Core/src/Shared/Multipart/MultipartReader.cs
@@ -0,0 +1,122 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+// Copied from https://github.com/aspnet/AspNetCore/tree/master/src/Http/WebUtilities/src
+
+using System;
+using System.Collections.Generic;
+using System.IO;
+using System.Threading;
+using System.Threading.Tasks;
+
+#pragma warning disable CA1001 // disposable fields
+#pragma warning disable IDE0008 // Use explicit type
+#nullable disable
+
+namespace Azure.Core
+{
+ // https://www.ietf.org/rfc/rfc2046.txt
+ internal class MultipartReader
+ {
+ public const int DefaultHeadersCountLimit = 16;
+ public const int DefaultHeadersLengthLimit = 1024 * 16;
+ private const int DefaultBufferSize = 1024 * 4;
+
+ private readonly BufferedReadStream _stream;
+ private readonly MultipartBoundary _boundary;
+ private MultipartReaderStream _currentStream;
+
+ public MultipartReader(string boundary, Stream stream)
+ : this(boundary, stream, DefaultBufferSize)
+ {
+ }
+
+ public MultipartReader(string boundary, Stream stream, int bufferSize)
+ {
+ if (boundary == null)
+ {
+ throw new ArgumentNullException(nameof(boundary));
+ }
+
+ if (stream == null)
+ {
+ throw new ArgumentNullException(nameof(stream));
+ }
+
+ if (bufferSize < boundary.Length + 8) // Size of the boundary + leading and trailing CRLF + leading and trailing '--' markers.
+ {
+ throw new ArgumentOutOfRangeException(nameof(bufferSize), bufferSize, "Insufficient buffer space, the buffer must be larger than the boundary: " + boundary);
+ }
+ _stream = new BufferedReadStream(stream, bufferSize);
+ _boundary = new MultipartBoundary(boundary, false);
+ // This stream will drain any preamble data and remove the first boundary marker.
+ // TODO: HeadersLengthLimit can't be modified until after the constructor.
+ _currentStream = new MultipartReaderStream(_stream, _boundary) { LengthLimit = HeadersLengthLimit };
+ }
+
+ ///
+ /// The limit for the number of headers to read.
+ ///
+ public int HeadersCountLimit { get; set; } = DefaultHeadersCountLimit;
+
+ ///
+ /// The combined size limit for headers per multipart section.
+ ///
+ public int HeadersLengthLimit { get; set; } = DefaultHeadersLengthLimit;
+
+ ///
+ /// The optional limit for the total response body length.
+ ///
+ public long? BodyLengthLimit { get; set; }
+
+ public async Task ReadNextSectionAsync(CancellationToken cancellationToken = new CancellationToken())
+ {
+ // Drain the prior section.
+ await _currentStream.DrainAsync(cancellationToken).ConfigureAwait(false);
+ // If we're at the end return null
+ if (_currentStream.FinalBoundaryFound)
+ {
+ // There may be trailer data after the last boundary.
+ await _stream.DrainAsync(HeadersLengthLimit, cancellationToken).ConfigureAwait(false);
+ return null;
+ }
+ var headers = await ReadHeadersAsync(cancellationToken).ConfigureAwait(false);
+ _boundary.ExpectLeadingCrlf = true;
+ _currentStream = new MultipartReaderStream(_stream, _boundary) { LengthLimit = BodyLengthLimit };
+ long? baseStreamOffset = _stream.CanSeek ? (long?)_stream.Position : null;
+ return new MultipartSection() { Headers = headers, Body = _currentStream, BaseStreamOffset = baseStreamOffset };
+ }
+
+ private async Task> ReadHeadersAsync(CancellationToken cancellationToken)
+ {
+ int totalSize = 0;
+ var accumulator = new KeyValueAccumulator();
+ var line = await _stream.ReadLineAsync(HeadersLengthLimit - totalSize, cancellationToken).ConfigureAwait(false);
+ while (!string.IsNullOrEmpty(line))
+ {
+ if (HeadersLengthLimit - totalSize < line.Length)
+ {
+ throw new InvalidDataException($"Multipart headers length limit {HeadersLengthLimit} exceeded.");
+ }
+ totalSize += line.Length;
+ int splitIndex = line.IndexOf(':');
+ if (splitIndex <= 0)
+ {
+ throw new InvalidDataException($"Invalid header line: {line}");
+ }
+
+ var name = line.Substring(0, splitIndex);
+ var value = line.Substring(splitIndex + 1, line.Length - splitIndex - 1).Trim();
+ accumulator.Append(name, value);
+ if (accumulator.KeyCount > HeadersCountLimit)
+ {
+ throw new InvalidDataException($"Multipart headers count limit {HeadersCountLimit} exceeded.");
+ }
+
+ line = await _stream.ReadLineAsync(HeadersLengthLimit - totalSize, cancellationToken).ConfigureAwait(false);
+ }
+
+ return accumulator.GetResults();
+ }
+ }
+}
diff --git a/sdk/core/Azure.Core/src/Shared/Multipart/MultipartReaderStream.cs b/sdk/core/Azure.Core/src/Shared/Multipart/MultipartReaderStream.cs
new file mode 100644
index 000000000000..157200bc2ed7
--- /dev/null
+++ b/sdk/core/Azure.Core/src/Shared/Multipart/MultipartReaderStream.cs
@@ -0,0 +1,348 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+// Copied from https://github.com/aspnet/AspNetCore/tree/master/src/Http/WebUtilities/src
+
+using System;
+using System.Buffers;
+using System.Diagnostics;
+using System.IO;
+using System.Threading;
+using System.Threading.Tasks;
+
+#pragma warning disable IDE0008 // Use explicit type
+#pragma warning disable IDE0016 // Simplify null check
+#pragma warning disable IDE0018 // Inline declaration
+#pragma warning disable IDE0054 // Use compound assignment
+#pragma warning disable IDE0059 // Unnecessary assignment
+
+namespace Azure.Core
+{
+ internal sealed class MultipartReaderStream : Stream
+ {
+ private readonly MultipartBoundary _boundary;
+ private readonly BufferedReadStream _innerStream;
+ private readonly ArrayPool _bytePool;
+
+ private readonly long _innerOffset;
+ private long _position;
+ private long _observedLength;
+ private bool _finished;
+
+ ///
+ /// Creates a stream that reads until it reaches the given boundary pattern.
+ ///
+ /// The .
+ /// The boundary pattern to use.
+ public MultipartReaderStream(BufferedReadStream stream, MultipartBoundary boundary)
+ : this(stream, boundary, ArrayPool.Shared)
+ {
+ }
+
+ ///
+ /// Creates a stream that reads until it reaches the given boundary pattern.
+ ///
+ /// The .
+ /// The boundary pattern to use.
+ /// The ArrayPool pool to use for temporary byte arrays.
+ public MultipartReaderStream(BufferedReadStream stream, MultipartBoundary boundary, ArrayPool bytePool)
+ {
+ if (stream == null)
+ {
+ throw new ArgumentNullException(nameof(stream));
+ }
+
+ if (boundary == null)
+ {
+ throw new ArgumentNullException(nameof(boundary));
+ }
+
+ _bytePool = bytePool;
+ _innerStream = stream;
+ _innerOffset = _innerStream.CanSeek ? _innerStream.Position : 0;
+ _boundary = boundary;
+ }
+
+ public bool FinalBoundaryFound { get; private set; }
+
+ public long? LengthLimit { get; set; }
+
+ public override bool CanRead
+ {
+ get { return true; }
+ }
+
+ public override bool CanSeek
+ {
+ get { return _innerStream.CanSeek; }
+ }
+
+ public override bool CanWrite
+ {
+ get { return false; }
+ }
+
+ public override long Length
+ {
+ get { return _observedLength; }
+ }
+
+ public override long Position
+ {
+ get { return _position; }
+ set
+ {
+ if (value < 0)
+ {
+ throw new ArgumentOutOfRangeException(nameof(value), value, "The Position must be positive.");
+ }
+ if (value > _observedLength)
+ {
+ throw new ArgumentOutOfRangeException(nameof(value), value, "The Position must be less than length.");
+ }
+ _position = value;
+ if (_position < _observedLength)
+ {
+ _finished = false;
+ }
+ }
+ }
+
+ public override long Seek(long offset, SeekOrigin origin)
+ {
+ if (origin == SeekOrigin.Begin)
+ {
+ Position = offset;
+ }
+ else if (origin == SeekOrigin.Current)
+ {
+ Position = Position + offset;
+ }
+ else // if (origin == SeekOrigin.End)
+ {
+ Position = Length + offset;
+ }
+ return Position;
+ }
+
+ public override void SetLength(long value)
+ {
+ throw new NotSupportedException();
+ }
+
+ public override void Write(byte[] buffer, int offset, int count)
+ {
+ throw new NotSupportedException();
+ }
+
+ public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
+ {
+ throw new NotSupportedException();
+ }
+
+ public override void Flush()
+ {
+ throw new NotSupportedException();
+ }
+
+ private void PositionInnerStream()
+ {
+ if (_innerStream.CanSeek && _innerStream.Position != (_innerOffset + _position))
+ {
+ _innerStream.Position = _innerOffset + _position;
+ }
+ }
+
+ private int UpdatePosition(int read)
+ {
+ _position += read;
+ if (_observedLength < _position)
+ {
+ _observedLength = _position;
+ if (LengthLimit.HasValue && _observedLength > LengthLimit.GetValueOrDefault())
+ {
+ throw new InvalidDataException($"Multipart body length limit {LengthLimit.GetValueOrDefault()} exceeded.");
+ }
+ }
+ return read;
+ }
+
+ public override int Read(byte[] buffer, int offset, int count)
+ {
+ if (_finished)
+ {
+ return 0;
+ }
+
+ PositionInnerStream();
+ if (!_innerStream.EnsureBuffered(_boundary.FinalBoundaryLength))
+ {
+ throw new IOException("Unexpected end of Stream, the content may have already been read by another component. ");
+ }
+ var bufferedData = _innerStream.BufferedData;
+
+ // scan for a boundary match, full or partial.
+ int read;
+ if (SubMatch(bufferedData, _boundary.BoundaryBytes, out var matchOffset, out var matchCount))
+ {
+ // We found a possible match, return any data before it.
+ if (matchOffset > bufferedData.Offset)
+ {
+ read = _innerStream.Read(buffer, offset, Math.Min(count, matchOffset - bufferedData.Offset));
+ return UpdatePosition(read);
+ }
+
+ var length = _boundary.BoundaryBytes.Length;
+ Debug.Assert(matchCount == length);
+
+ // "The boundary may be followed by zero or more characters of
+ // linear whitespace. It is then terminated by either another CRLF"
+ // or -- for the final boundary.
+ var boundary = _bytePool.Rent(length);
+ read = _innerStream.Read(boundary, 0, length);
+ _bytePool.Return(boundary);
+ Debug.Assert(read == length); // It should have all been buffered
+
+ var remainder = _innerStream.ReadLine(lengthLimit: 100); // Whitespace may exceed the buffer.
+ remainder = remainder.Trim();
+ if (string.Equals("--", remainder, StringComparison.Ordinal))
+ {
+ FinalBoundaryFound = true;
+ }
+#pragma warning disable CA1820 // Test for empty strings using string length
+ Debug.Assert(FinalBoundaryFound || string.Equals(string.Empty, remainder, StringComparison.Ordinal), "Un-expected data found on the boundary line: " + remainder);
+#pragma warning restore CA1820 // Test for empty strings using string length
+ _finished = true;
+ return 0;
+ }
+
+ // No possible boundary match within the buffered data, return the data from the buffer.
+ read = _innerStream.Read(buffer, offset, Math.Min(count, bufferedData.Count));
+ return UpdatePosition(read);
+ }
+
+ public override async Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
+ {
+ if (_finished)
+ {
+ return 0;
+ }
+
+ PositionInnerStream();
+ if (!await _innerStream.EnsureBufferedAsync(_boundary.FinalBoundaryLength, cancellationToken).ConfigureAwait(false))
+ {
+ throw new IOException("Unexpected end of Stream, the content may have already been read by another component. ");
+ }
+ var bufferedData = _innerStream.BufferedData;
+
+ // scan for a boundary match, full or partial.
+ int matchOffset;
+ int matchCount;
+ int read;
+ if (SubMatch(bufferedData, _boundary.BoundaryBytes, out matchOffset, out matchCount))
+ {
+ // We found a possible match, return any data before it.
+ if (matchOffset > bufferedData.Offset)
+ {
+ // Sync, it's already buffered
+ read = _innerStream.Read(buffer, offset, Math.Min(count, matchOffset - bufferedData.Offset));
+ return UpdatePosition(read);
+ }
+
+ var length = _boundary.BoundaryBytes!.Length;
+ Debug.Assert(matchCount == length);
+
+ // "The boundary may be followed by zero or more characters of
+ // linear whitespace. It is then terminated by either another CRLF"
+ // or -- for the final boundary.
+ var boundary = _bytePool.Rent(length);
+ read = _innerStream.Read(boundary, 0, length);
+ _bytePool.Return(boundary);
+ Debug.Assert(read == length); // It should have all been buffered
+
+ var remainder = await _innerStream.ReadLineAsync(lengthLimit: 100, cancellationToken: cancellationToken).ConfigureAwait(false); // Whitespace may exceed the buffer.
+ remainder = remainder.Trim();
+ if (string.Equals("--", remainder, StringComparison.Ordinal))
+ {
+ FinalBoundaryFound = true;
+ }
+#pragma warning disable CA1820 // Test for empty strings using string length
+ Debug.Assert(FinalBoundaryFound || string.Equals(string.Empty, remainder, StringComparison.Ordinal), "Un-expected data found on the boundary line: " + remainder);
+#pragma warning restore CA1820 // Test for empty strings using string length
+
+ _finished = true;
+ return 0;
+ }
+
+ // No possible boundary match within the buffered data, return the data from the buffer.
+ read = _innerStream.Read(buffer, offset, Math.Min(count, bufferedData.Count));
+ return UpdatePosition(read);
+ }
+
+ // Does segment1 contain all of matchBytes, or does it end with the start of matchBytes?
+ // 1: AAAAABBBBBCCCCC
+ // 2: BBBBB
+ // Or:
+ // 1: AAAAABBB
+ // 2: BBBBB
+ private bool SubMatch(ArraySegment segment1, byte[] matchBytes, out int matchOffset, out int matchCount)
+ {
+ // clear matchCount to zero
+ matchCount = 0;
+
+ // case 1: does segment1 fully contain matchBytes?
+ {
+ var matchBytesLengthMinusOne = matchBytes.Length - 1;
+ var matchBytesLastByte = matchBytes[matchBytesLengthMinusOne];
+ var segmentEndMinusMatchBytesLength = segment1.Offset + segment1.Count - matchBytes.Length;
+
+ matchOffset = segment1.Offset;
+ while (matchOffset < segmentEndMinusMatchBytesLength)
+ {
+ var lookaheadTailChar = segment1.Array![matchOffset + matchBytesLengthMinusOne];
+ if (lookaheadTailChar == matchBytesLastByte &&
+ CompareBuffers(segment1.Array, matchOffset, matchBytes, 0, matchBytesLengthMinusOne) == 0)
+ {
+ matchCount = matchBytes.Length;
+ return true;
+ }
+ matchOffset += _boundary.GetSkipValue(lookaheadTailChar);
+ }
+ }
+
+ // case 2: does segment1 end with the start of matchBytes?
+ var segmentEnd = segment1.Offset + segment1.Count;
+
+ matchCount = 0;
+ for (; matchOffset < segmentEnd; matchOffset++)
+ {
+ var countLimit = segmentEnd - matchOffset;
+ for (matchCount = 0; matchCount < matchBytes.Length && matchCount < countLimit; matchCount++)
+ {
+ if (matchBytes[matchCount] != segment1.Array![matchOffset + matchCount])
+ {
+ matchCount = 0;
+ break;
+ }
+ }
+ if (matchCount > 0)
+ {
+ break;
+ }
+ }
+ return matchCount > 0;
+ }
+
+ private static int CompareBuffers(byte[] buffer1, int offset1, byte[] buffer2, int offset2, int count)
+ {
+ for (; count-- > 0; offset1++, offset2++)
+ {
+ if (buffer1[offset1] != buffer2[offset2])
+ {
+ return buffer1[offset1] - buffer2[offset2];
+ }
+ }
+ return 0;
+ }
+ }
+}
diff --git a/sdk/core/Azure.Core/src/Shared/Multipart/MultipartSection.cs b/sdk/core/Azure.Core/src/Shared/Multipart/MultipartSection.cs
new file mode 100644
index 000000000000..5849be42c018
--- /dev/null
+++ b/sdk/core/Azure.Core/src/Shared/Multipart/MultipartSection.cs
@@ -0,0 +1,27 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+// Copied from https://github.com/aspnet/AspNetCore/tree/master/src/Http/WebUtilities/src
+
+using System.Collections.Generic;
+using System.IO;
+using System.Linq;
+
+namespace Azure.Core
+{
+ internal class MultipartSection
+ {
+ public Dictionary Headers { get; set; }
+
+ ///
+ /// Gets or sets the body.
+ ///
+ public Stream Body { get; set; } = default!;
+
+ ///
+ /// The position where the body starts in the total multipart body.
+ /// This may not be available if the total multipart body is not seekable.
+ ///
+ public long? BaseStreamOffset { get; set; }
+ }
+}
diff --git a/sdk/core/Azure.Core/src/Shared/Multipart/StreamHelperExtensions.cs b/sdk/core/Azure.Core/src/Shared/Multipart/StreamHelperExtensions.cs
new file mode 100644
index 000000000000..6be8b9cf138c
--- /dev/null
+++ b/sdk/core/Azure.Core/src/Shared/Multipart/StreamHelperExtensions.cs
@@ -0,0 +1,55 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+// Copied from https://github.com/aspnet/AspNetCore/tree/master/src/Http/WebUtilities/src
+
+using System.Buffers;
+using System.IO;
+using System.Threading;
+using System.Threading.Tasks;
+
+#pragma warning disable IDE1006 // Prefix _ unexpected
+
+namespace Azure.Core
+{
+ internal static class StreamHelperExtensions
+ {
+ private const int _maxReadBufferSize = 1024 * 4;
+
+ public static Task DrainAsync(this Stream stream, CancellationToken cancellationToken)
+ {
+ return stream.DrainAsync(ArrayPool.Shared, null, cancellationToken);
+ }
+
+ public static Task DrainAsync(this Stream stream, long? limit, CancellationToken cancellationToken)
+ {
+ return stream.DrainAsync(ArrayPool.Shared, limit, cancellationToken);
+ }
+
+ public static async Task DrainAsync(this Stream stream, ArrayPool bytePool, long? limit, CancellationToken cancellationToken)
+ {
+ cancellationToken.ThrowIfCancellationRequested();
+ var buffer = bytePool.Rent(_maxReadBufferSize);
+ long total = 0;
+ try
+ {
+ var read = await stream.ReadAsync(buffer, 0, buffer.Length, cancellationToken).ConfigureAwait(false);
+ while (read > 0)
+ {
+ // Not all streams support cancellation directly.
+ cancellationToken.ThrowIfCancellationRequested();
+ if (limit.HasValue && limit.GetValueOrDefault() - total < read)
+ {
+ throw new InvalidDataException($"The stream exceeded the data limit {limit.GetValueOrDefault()}.");
+ }
+ total += read;
+ read = await stream.ReadAsync(buffer, 0, buffer.Length, cancellationToken).ConfigureAwait(false);
+ }
+ }
+ finally
+ {
+ bytePool.Return(buffer);
+ }
+ }
+ }
+}
diff --git a/sdk/core/Azure.Core/src/Shared/RequestRequestContent.cs b/sdk/core/Azure.Core/src/Shared/RequestRequestContent.cs
new file mode 100644
index 000000000000..5ab190bd6846
--- /dev/null
+++ b/sdk/core/Azure.Core/src/Shared/RequestRequestContent.cs
@@ -0,0 +1,166 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+using System.Collections.Generic;
+using System.IO;
+using System.Text;
+using System.Threading;
+using System.Threading.Tasks;
+
+#nullable disable
+
+namespace Azure.Core
+{
+ ///
+ /// Provides a container for content encoded using multipart/form-data MIME type.
+ ///
+ internal class RequestRequestContent : RequestContent
+ {
+ #region Fields
+
+ private const string SP = " ";
+ private const string ColonSP = ": ";
+ private const string CRLF = "\r\n";
+ private const int DefaultHeaderAllocation = 2 * 1024;
+
+ private readonly Request _request;
+
+ #endregion Fields
+
+ #region Construction
+
+ ///
+ /// Initializes a new instance of the class.
+ ///
+ /// The instance to encapsulate.
+ public RequestRequestContent(Request request)
+ {
+ Argument.AssertNotNull(request, nameof(request));
+ this._request = request;
+ }
+
+ #endregion Construction
+
+ #region Dispose
+
+ ///
+ /// Frees resources held by the object.
+ ///
+ public override void Dispose()
+ {
+ _request.Dispose();
+ }
+
+ #endregion Dispose
+
+ #region Serialization
+
+ public override async Task WriteToAsync(Stream stream, CancellationToken cancellationToken)
+ {
+ Argument.AssertNotNull(stream, nameof(stream));
+
+ byte[] header = SerializeHeader();
+ await stream.WriteAsync(header, 0, header.Length).ConfigureAwait(false);
+
+ if (_request.Content != null)
+ {
+ await _request.Content.WriteToAsync(stream, cancellationToken).ConfigureAwait(false);
+ }
+ }
+
+ public override void WriteTo(Stream stream, CancellationToken cancellationToken)
+ {
+ Argument.AssertNotNull(stream, nameof(stream));
+
+ byte[] header = SerializeHeader();
+ stream.Write(header, 0, header.Length);
+
+ if (_request.Content != null)
+ {
+ _request.Content.WriteTo(stream, cancellationToken);
+ }
+ }
+
+ ///
+ /// Computes the length of the stream if possible.
+ ///
+ /// The computed length of the stream.
+ /// true if the length has been computed; otherwise false.
+ public override bool TryComputeLength(out long length)
+ {
+ // We have four states we could be in:
+ // 1. We have content and the stream came back as a null or non-seekable
+ // 2. We have content and the stream is seekable, so we know its length
+ // 3. We don't have content
+ //
+ // For #1, we return false.
+ // For #2, we return true & the size of our headers + the content length
+ // For #3, we return true & the size of our headers
+
+ bool hasContent = _request.Content != null;
+ length = 0;
+
+ // Cases #1, #2, #3
+ if (hasContent)
+ {
+ if (!_request!.Content!.TryComputeLength(out length))
+ {
+ length = 0;
+ return false;
+ }
+ }
+
+ // We serialize header to a StringBuilder so that we can determine the length
+ // following the pattern for HttpContent to try and determine the message length.
+ // The perf overhead is no larger than for the other HttpContent implementations.
+ byte[] header = SerializeHeader();
+ length += header.Length;
+ return true;
+ }
+
+ private byte[] SerializeHeader()
+ {
+ StringBuilder message = new StringBuilder(DefaultHeaderAllocation);
+
+ SerializeRequestLine(message, _request);
+ SerializeHeaderFields(message, _request.Headers);
+ message.Append(CRLF);
+ return Encoding.UTF8.GetBytes(message.ToString());
+ }
+
+ ///
+ /// Serializes the HTTP request line.
+ ///
+ /// Where to write the request line.
+ /// The HTTP request.
+ private static void SerializeRequestLine(StringBuilder message, Request request)
+ {
+ Argument.AssertNotNull(message, nameof(message));
+ message.Append(request.Method + SP);
+ message.Append(request.Uri.ToString() + SP);
+ message.Append("HTTP/1.1" + CRLF);
+
+ // Only insert host header if not already present.
+ if (!request.Headers.TryGetValue("Host", out _))
+ {
+ message.Append("Host" + ColonSP + request.Uri.Host + CRLF);
+ }
+ }
+
+ ///
+ /// Serializes the header fields.
+ ///
+ /// Where to write the status line.
+ /// The headers to write.
+ private static void SerializeHeaderFields(StringBuilder message, RequestHeaders headers)
+ {
+ Argument.AssertNotNull(message, nameof(message));
+ foreach (HttpHeader header in headers)
+ {
+ message.Append(header.Name + ColonSP + header.Value + CRLF);
+ }
+ }
+
+ #endregion Serialization
+ }
+}
diff --git a/sdk/core/Azure.Core/tests/Azure.Core.Tests.csproj b/sdk/core/Azure.Core/tests/Azure.Core.Tests.csproj
index a7a10b64fb89..7a6449b58fb6 100644
--- a/sdk/core/Azure.Core/tests/Azure.Core.Tests.csproj
+++ b/sdk/core/Azure.Core/tests/Azure.Core.Tests.csproj
@@ -2,6 +2,7 @@
{84491222-6C36-4FA7-BBAE-1FA804129151}
$(RequiredTargetFrameworks)
+ $(DefineConstants);HAS_INTERNALS_VISIBLE_CORE
true
@@ -22,13 +23,17 @@
+
+
+
-
-
+
+
diff --git a/sdk/core/Azure.Core/tests/HttpPipelineFunctionalTests.cs b/sdk/core/Azure.Core/tests/HttpPipelineFunctionalTests.cs
index 92f490faa4e3..a80324172f01 100644
--- a/sdk/core/Azure.Core/tests/HttpPipelineFunctionalTests.cs
+++ b/sdk/core/Azure.Core/tests/HttpPipelineFunctionalTests.cs
@@ -20,11 +20,11 @@ namespace Azure.Core.Tests
[TestFixture(typeof(HttpWebRequestTransport), true)]
[TestFixture(typeof(HttpWebRequestTransport), false)]
#endif
- public class HttpPipelineFunctionalTests: PipelineTestBase
+ public class HttpPipelineFunctionalTests : PipelineTestBase
{
private readonly Type _transportType;
- public HttpPipelineFunctionalTests(Type transportType, bool isAsync): base(isAsync)
+ public HttpPipelineFunctionalTests(Type transportType, bool isAsync) : base(isAsync)
{
_transportType = transportType;
}
@@ -32,7 +32,7 @@ public HttpPipelineFunctionalTests(Type transportType, bool isAsync): base(isAsy
private TestOptions GetOptions()
{
var options = new TestOptions();
- options.Transport = (HttpPipelineTransport) Activator.CreateInstance(_transportType);
+ options.Transport = (HttpPipelineTransport)Activator.CreateInstance(_transportType);
return options;
}
diff --git a/sdk/core/Azure.Core/tests/MultipartReaderTests.cs b/sdk/core/Azure.Core/tests/MultipartReaderTests.cs
new file mode 100644
index 000000000000..a5386f68e348
--- /dev/null
+++ b/sdk/core/Azure.Core/tests/MultipartReaderTests.cs
@@ -0,0 +1,385 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#nullable disable warnings
+
+using System;
+using System.IO;
+using System.Text;
+using System.Threading.Tasks;
+using NUnit.Framework;
+
+namespace Azure.Core.Tests
+{
+ public class MultipartReaderTests
+ {
+ private const string Boundary = "9051914041544843365972754266";
+ // Note that CRLF (\r\n) is required. You can't use multi-line C# strings here because the line breaks on Linux are just LF.
+ private const string OnePartBody =
+"--9051914041544843365972754266\r\n" +
+"Content-Disposition: form-data; name=\"text\"\r\n" +
+"\r\n" +
+"text default\r\n" +
+"--9051914041544843365972754266--\r\n";
+ private const string OnePartBodyTwoHeaders =
+"--9051914041544843365972754266\r\n" +
+"Content-Disposition: form-data; name=\"text\"\r\n" +
+"Custom-header: custom-value\r\n" +
+"\r\n" +
+"text default\r\n" +
+"--9051914041544843365972754266--\r\n";
+ private const string OnePartBodyWithTrailingWhitespace =
+"--9051914041544843365972754266 \r\n" +
+"Content-Disposition: form-data; name=\"text\"\r\n" +
+"\r\n" +
+"text default\r\n" +
+"--9051914041544843365972754266--\r\n";
+ // It's non-compliant but common to leave off the last CRLF.
+ private const string OnePartBodyWithoutFinalCRLF =
+"--9051914041544843365972754266\r\n" +
+"Content-Disposition: form-data; name=\"text\"\r\n" +
+"\r\n" +
+"text default\r\n" +
+"--9051914041544843365972754266--";
+ private const string TwoPartBody =
+"--9051914041544843365972754266\r\n" +
+"Content-Disposition: form-data; name=\"text\"\r\n" +
+"\r\n" +
+"text default\r\n" +
+"--9051914041544843365972754266\r\n" +
+"Content-Disposition: form-data; name=\"file1\"; filename=\"a.txt\"\r\n" +
+"Content-Type: text/plain\r\n" +
+"\r\n" +
+"Content of a.txt.\r\n" +
+"\r\n" +
+"--9051914041544843365972754266--\r\n";
+ private const string TwoPartBodyWithUnicodeFileName =
+"--9051914041544843365972754266\r\n" +
+"Content-Disposition: form-data; name=\"text\"\r\n" +
+"\r\n" +
+"text default\r\n" +
+"--9051914041544843365972754266\r\n" +
+"Content-Disposition: form-data; name=\"file1\"; filename=\"a色.txt\"\r\n" +
+"Content-Type: text/plain\r\n" +
+"\r\n" +
+"Content of a.txt.\r\n" +
+"\r\n" +
+"--9051914041544843365972754266--\r\n";
+ private const string ThreePartBody =
+"--9051914041544843365972754266\r\n" +
+"Content-Disposition: form-data; name=\"text\"\r\n" +
+"\r\n" +
+"text default\r\n" +
+"--9051914041544843365972754266\r\n" +
+"Content-Disposition: form-data; name=\"file1\"; filename=\"a.txt\"\r\n" +
+"Content-Type: text/plain\r\n" +
+"\r\n" +
+"Content of a.txt.\r\n" +
+"\r\n" +
+"--9051914041544843365972754266\r\n" +
+"Content-Disposition: form-data; name=\"file2\"; filename=\"a.html\"\r\n" +
+"Content-Type: text/html\r\n" +
+"\r\n" +
+"Content of a.html.\r\n" +
+"\r\n" +
+"--9051914041544843365972754266--\r\n";
+
+ private const string TwoPartBodyIncompleteBuffer =
+"--9051914041544843365972754266\r\n" +
+"Content-Disposition: form-data; name=\"text\"\r\n" +
+"\r\n" +
+"text default\r\n" +
+"--9051914041544843365972754266\r\n" +
+"Content-Disposition: form-data; name=\"file1\"; filename=\"a.txt\"\r\n" +
+"Content-Type: text/plain\r\n" +
+"\r\n" +
+"Content of a.txt.\r\n" +
+"\r\n" +
+"--9051914041544843365";
+
+ private static MemoryStream MakeStream(string text)
+ {
+ return new MemoryStream(Encoding.UTF8.GetBytes(text));
+ }
+
+ private static string GetString(byte[] buffer, int count)
+ {
+ return Encoding.ASCII.GetString(buffer, 0, count);
+ }
+
+ [Test]
+ public async Task MultipartReader_ReadSinglePartBody_Success()
+ {
+ var stream = MakeStream(OnePartBody);
+ var reader = new MultipartReader(Boundary, stream);
+
+ var section = await reader.ReadNextSectionAsync();
+ Assert.NotNull(section);
+ Assert.That(section.Headers.Count, Is.EqualTo(1));
+ Assert.That(section.Headers["Content-Disposition"][0], Is.EqualTo("form-data; name=\"text\""));
+ var buffer = new MemoryStream();
+ await section.Body.CopyToAsync(buffer);
+ Assert.That(Encoding.ASCII.GetString(buffer.ToArray()), Is.EqualTo("text default"));
+
+ Assert.Null(await reader.ReadNextSectionAsync());
+ }
+
+ [Test]
+ public void MultipartReader_HeaderCountExceeded_Throws()
+ {
+ var stream = MakeStream(OnePartBodyTwoHeaders);
+ var reader = new MultipartReader(Boundary, stream)
+ {
+ HeadersCountLimit = 1,
+ };
+
+ var exception = Assert.ThrowsAsync(async () => await reader.ReadNextSectionAsync());
+ Assert.That(exception.Message, Is.EqualTo("Multipart headers count limit 1 exceeded."));
+ }
+
+ [Test]
+ public void MultipartReader_HeadersLengthExceeded_Throws()
+ {
+ var stream = MakeStream(OnePartBodyTwoHeaders);
+ var reader = new MultipartReader(Boundary, stream)
+ {
+ HeadersLengthLimit = 60,
+ };
+
+ var exception = Assert.ThrowsAsync(() => reader.ReadNextSectionAsync());
+ Assert.That(exception.Message, Is.EqualTo("Line length limit 17 exceeded."));
+ }
+
+ [Test]
+ public async Task MultipartReader_ReadSinglePartBodyWithTrailingWhitespace_Success()
+ {
+ var stream = MakeStream(OnePartBodyWithTrailingWhitespace);
+ var reader = new MultipartReader(Boundary, stream);
+
+ var section = await reader.ReadNextSectionAsync();
+ Assert.NotNull(section);
+ Assert.That(section.Headers.Count, Is.EqualTo(1));
+ Assert.That(section.Headers["Content-Disposition"][0], Is.EqualTo("form-data; name=\"text\""));
+ var buffer = new MemoryStream();
+ await section.Body.CopyToAsync(buffer);
+ Assert.That(Encoding.ASCII.GetString(buffer.ToArray()), Is.EqualTo("text default"));
+
+ Assert.Null(await reader.ReadNextSectionAsync());
+ }
+
+ [Test]
+ public async Task MultipartReader_ReadSinglePartBodyWithoutLastCRLF_Success()
+ {
+ var stream = MakeStream(OnePartBodyWithoutFinalCRLF);
+ var reader = new MultipartReader(Boundary, stream);
+
+ var section = await reader.ReadNextSectionAsync();
+ Assert.NotNull(section);
+ Assert.That(section.Headers.Count, Is.EqualTo(1));
+ Assert.That(section.Headers["Content-Disposition"][0], Is.EqualTo("form-data; name=\"text\""));
+ var buffer = new MemoryStream();
+ await section.Body.CopyToAsync(buffer);
+ Assert.That(Encoding.ASCII.GetString(buffer.ToArray()), Is.EqualTo("text default"));
+
+ Assert.Null(await reader.ReadNextSectionAsync());
+ }
+
+ [Test]
+ public async Task MultipartReader_ReadTwoPartBody_Success()
+ {
+ var stream = MakeStream(TwoPartBody);
+ var reader = new MultipartReader(Boundary, stream);
+
+ var section = await reader.ReadNextSectionAsync();
+ Assert.NotNull(section);
+ Assert.That(section.Headers.Count, Is.EqualTo(1));
+ Assert.That(section.Headers["Content-Disposition"][0], Is.EqualTo("form-data; name=\"text\""));
+ var buffer = new MemoryStream();
+ await section.Body.CopyToAsync(buffer);
+ Assert.That(Encoding.ASCII.GetString(buffer.ToArray()), Is.EqualTo("text default"));
+
+ section = await reader.ReadNextSectionAsync();
+ Assert.NotNull(section);
+ Assert.That(section.Headers.Count, Is.EqualTo(2));
+ Assert.That(section.Headers["Content-Disposition"][0], Is.EqualTo("form-data; name=\"file1\"; filename=\"a.txt\""));
+ Assert.That(section.Headers["Content-Type"][0], Is.EqualTo("text/plain"));
+ buffer = new MemoryStream();
+ await section.Body.CopyToAsync(buffer);
+ Assert.That(Encoding.ASCII.GetString(buffer.ToArray()), Is.EqualTo("Content of a.txt.\r\n"));
+
+ Assert.Null(await reader.ReadNextSectionAsync());
+ }
+
+ [Test]
+ public async Task MultipartReader_ReadTwoPartBodyWithUnicodeFileName_Success()
+ {
+ var stream = MakeStream(TwoPartBodyWithUnicodeFileName);
+ var reader = new MultipartReader(Boundary, stream);
+
+ var section = await reader.ReadNextSectionAsync();
+ Assert.NotNull(section);
+ Assert.That(section.Headers.Count, Is.EqualTo(1));
+ Assert.That(section.Headers["Content-Disposition"][0], Is.EqualTo("form-data; name=\"text\""));
+ var buffer = new MemoryStream();
+ await section.Body.CopyToAsync(buffer);
+ Assert.That(Encoding.ASCII.GetString(buffer.ToArray()), Is.EqualTo("text default"));
+
+ section = await reader.ReadNextSectionAsync();
+ Assert.NotNull(section);
+ Assert.That(section.Headers.Count, Is.EqualTo(2));
+ Assert.That(section.Headers["Content-Disposition"][0], Is.EqualTo("form-data; name=\"file1\"; filename=\"a色.txt\""));
+ Assert.That(section.Headers["Content-Type"][0], Is.EqualTo("text/plain"));
+ buffer = new MemoryStream();
+ await section.Body.CopyToAsync(buffer);
+ Assert.That(Encoding.ASCII.GetString(buffer.ToArray()), Is.EqualTo("Content of a.txt.\r\n"));
+
+ Assert.Null(await reader.ReadNextSectionAsync());
+ }
+
+ [Test]
+ public async Task MultipartReader_ThreePartBody_Success()
+ {
+ var stream = MakeStream(ThreePartBody);
+ var reader = new MultipartReader(Boundary, stream);
+
+ var section = await reader.ReadNextSectionAsync();
+ Assert.NotNull(section);
+ Assert.That(section.Headers.Count, Is.EqualTo(1));
+ Assert.That(section.Headers["Content-Disposition"][0], Is.EqualTo("form-data; name=\"text\""));
+ var buffer = new MemoryStream();
+ await section.Body.CopyToAsync(buffer);
+ Assert.That(Encoding.ASCII.GetString(buffer.ToArray()), Is.EqualTo("text default"));
+
+ section = await reader.ReadNextSectionAsync();
+ Assert.NotNull(section);
+ Assert.That(section.Headers.Count, Is.EqualTo(2));
+ Assert.That(section.Headers["Content-Disposition"][0], Is.EqualTo("form-data; name=\"file1\"; filename=\"a.txt\""));
+ Assert.That(section.Headers["Content-Type"][0], Is.EqualTo("text/plain"));
+ buffer = new MemoryStream();
+ await section.Body.CopyToAsync(buffer);
+ Assert.That(Encoding.ASCII.GetString(buffer.ToArray()), Is.EqualTo("Content of a.txt.\r\n"));
+
+ section = await reader.ReadNextSectionAsync();
+ Assert.NotNull(section);
+ Assert.That(section.Headers.Count, Is.EqualTo(2));
+ Assert.That(section.Headers["Content-Disposition"][0], Is.EqualTo("form-data; name=\"file2\"; filename=\"a.html\""));
+ Assert.That(section.Headers["Content-Type"][0], Is.EqualTo("text/html"));
+ buffer = new MemoryStream();
+ await section.Body.CopyToAsync(buffer);
+ Assert.That(Encoding.ASCII.GetString(buffer.ToArray()), Is.EqualTo("Content of a.html.\r\n"));
+
+ Assert.Null(await reader.ReadNextSectionAsync());
+ }
+
+ [Test]
+ public void MultipartReader_BufferSizeMustBeLargerThanBoundary_Throws()
+ {
+ var stream = MakeStream(ThreePartBody);
+ Assert.Throws(() =>
+ {
+ var reader = new MultipartReader(Boundary, stream, 5);
+ });
+ }
+
+ [Test]
+ public async Task MultipartReader_TwoPartBodyIncompleteBuffer_TwoSectionsReadSuccessfullyThirdSectionThrows()
+ {
+ var stream = MakeStream(TwoPartBodyIncompleteBuffer);
+ var reader = new MultipartReader(Boundary, stream);
+ var buffer = new byte[128];
+
+ //first section can be read successfully
+ var section = await reader.ReadNextSectionAsync();
+ Assert.NotNull(section);
+ Assert.That(section.Headers.Count, Is.EqualTo(1));
+ Assert.That(section.Headers["Content-Disposition"][0], Is.EqualTo("form-data; name=\"text\""));
+ var read = section.Body.Read(buffer, 0, buffer.Length);
+ Assert.That(GetString(buffer, read), Is.EqualTo("text default"));
+
+ //second section can be read successfully (even though the bottom boundary is truncated)
+ section = await reader.ReadNextSectionAsync();
+ Assert.NotNull(section);
+ Assert.That(section.Headers.Count, Is.EqualTo(2));
+ Assert.That(section.Headers["Content-Disposition"][0], Is.EqualTo("form-data; name=\"file1\"; filename=\"a.txt\""));
+ Assert.That(section.Headers["Content-Type"][0], Is.EqualTo("text/plain"));
+ read = section.Body.Read(buffer, 0, buffer.Length);
+ Assert.That(GetString(buffer, read), Is.EqualTo("Content of a.txt.\r\n"));
+
+ Assert.ThrowsAsync(async () =>
+ {
+ // we'll be unable to ensure enough bytes are buffered to even contain a final boundary
+ section = await reader.ReadNextSectionAsync();
+ });
+ }
+
+ [Test]
+ public async Task MultipartReader_ReadInvalidUtf8Header_ReplacementCharacters()
+ {
+ var body1 =
+"--9051914041544843365972754266\r\n" +
+"Content-Disposition: form-data; name=\"text\" filename=\"a";
+
+ var body2 =
+".txt\"\r\n" +
+"\r\n" +
+"text default\r\n" +
+"--9051914041544843365972754266--\r\n";
+ var stream = new MemoryStream();
+ var bytes = Encoding.UTF8.GetBytes(body1);
+ stream.Write(bytes, 0, bytes.Length);
+
+ // Write an invalid utf-8 segment in the middle
+ stream.Write(new byte[] { 0xC1, 0x21 }, 0, 2);
+
+ bytes = Encoding.UTF8.GetBytes(body2);
+ stream.Write(bytes, 0, bytes.Length);
+ stream.Seek(0, SeekOrigin.Begin);
+ var reader = new MultipartReader(Boundary, stream);
+
+ var section = await reader.ReadNextSectionAsync();
+ Assert.NotNull(section);
+ Assert.That(section.Headers.Count, Is.EqualTo(1));
+ Assert.That(section.Headers["Content-Disposition"][0], Is.EqualTo("form-data; name=\"text\" filename=\"a\uFFFD!.txt\""));
+ var buffer = new MemoryStream();
+ await section.Body.CopyToAsync(buffer);
+ Assert.That(Encoding.ASCII.GetString(buffer.ToArray()), Is.EqualTo("text default"));
+
+ Assert.Null(await reader.ReadNextSectionAsync());
+ }
+
+ [Test]
+ public async Task MultipartReader_ReadInvalidUtf8SurrogateHeader_ReplacementCharacters()
+ {
+ var body1 =
+"--9051914041544843365972754266\r\n" +
+"Content-Disposition: form-data; name=\"text\" filename=\"a";
+
+ var body2 =
+".txt\"\r\n" +
+"\r\n" +
+"text default\r\n" +
+"--9051914041544843365972754266--\r\n";
+ var stream = new MemoryStream();
+ var bytes = Encoding.UTF8.GetBytes(body1);
+ stream.Write(bytes, 0, bytes.Length);
+
+ // Write an invalid utf-8 segment in the middle
+ stream.Write(new byte[] { 0xED, 0xA0, 85 }, 0, 3);
+
+ bytes = Encoding.UTF8.GetBytes(body2);
+ stream.Write(bytes, 0, bytes.Length);
+ stream.Seek(0, SeekOrigin.Begin);
+ var reader = new MultipartReader(Boundary, stream);
+
+ var section = await reader.ReadNextSectionAsync();
+ Assert.NotNull(section);
+ Assert.That(section.Headers.Count, Is.EqualTo(1));
+ Assert.That(section.Headers["Content-Disposition"][0], Is.EqualTo("form-data; name=\"text\" filename=\"a\uFFFDU.txt\""));
+ var buffer = new MemoryStream();
+ await section.Body.CopyToAsync(buffer);
+ Assert.That(Encoding.ASCII.GetString(buffer.ToArray()), Is.EqualTo("text default"));
+
+ Assert.Null(await reader.ReadNextSectionAsync());
+ }
+ }
+}
diff --git a/sdk/core/Azure.Core/tests/MultipartTests.cs b/sdk/core/Azure.Core/tests/MultipartTests.cs
new file mode 100644
index 000000000000..a283bec4469c
--- /dev/null
+++ b/sdk/core/Azure.Core/tests/MultipartTests.cs
@@ -0,0 +1,295 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#nullable disable warnings
+
+using System;
+using System.Collections.Generic;
+using System.IO;
+using System.Linq;
+using System.Net;
+using System.Text;
+using System.Threading.Tasks;
+using Azure.Core.TestFramework;
+using NUnit.Framework;
+
+namespace Azure.Core.Tests
+{
+ public class MultipartTests
+ {
+ private const string Boundary = "batchresponse_6040fee7-a2b8-4e78-a674-02086369606a";
+ private const string ContentType = "multipart/mixed; boundary=" + Boundary;
+ private const string Body = "{}";
+
+ // Note that CRLF (\r\n) is required. You can't use multi-line C# strings here because the line breaks on Linux are just LF.
+ private const string TablesOdataBatchResponse =
+"--batchresponse_6040fee7-a2b8-4e78-a674-02086369606a\r\n" +
+"Content-Type: multipart/mixed; boundary=changesetresponse_e52cbca8-7e7f-4d91-a719-f99c69686a92\r\n" +
+"\r\n" +
+"--changesetresponse_e52cbca8-7e7f-4d91-a719-f99c69686a92\r\n" +
+"Content-Type: application/http\r\n" +
+"Content-Transfer-Encoding: binary\r\n" +
+"\r\n" +
+"HTTP/1.1 201 Created\r\n" +
+"DataServiceVersion: 3.0;\r\n" +
+"Content-Type: application/json;odata=fullmetadata;streaming=true;charset=utf-8\r\n" +
+"X-Content-Type-Options: nosniff\r\n" +
+"Cache-Control: no-cache\r\n" +
+"Location: https://mytable.table.core.windows.net/tablename(PartitionKey='somPartition',RowKey='01')\r\n" +
+"ETag: W/\"datetime'2020-08-14T22%3A58%3A57.8328323Z'\"\r\n" +
+"\r\n" +
+"{}\r\n" +
+"--changesetresponse_e52cbca8-7e7f-4d91-a719-f99c69686a92\r\n" +
+"Content-Type: application/http\r\n" +
+"Content-Transfer-Encoding: binary\r\n" +
+"\r\n" +
+"HTTP/1.1 201 Created\r\n" +
+"DataServiceVersion: 3.0;\r\n" +
+"Content-Type: application/json;odata=fullmetadata;streaming=true;charset=utf-8\r\n" +
+"X-Content-Type-Options: nosniff\r\n" +
+"Cache-Control: no-cache\r\n" +
+"Location: https://mytable.table.core.windows.net/tablename(PartitionKey='somPartition',RowKey='02')\r\n" +
+"ETag: W/\"datetime'2020-08-14T22%3A58%3A57.8328323Z'\"\r\n" +
+"\r\n" +
+"{}\r\n" +
+"--changesetresponse_e52cbca8-7e7f-4d91-a719-f99c69686a92\r\n" +
+"Content-Type: application/http\r\n" +
+"Content-Transfer-Encoding: binary\r\n" +
+"\r\n" +
+"HTTP/1.1 201 Created\r\n" +
+"DataServiceVersion: 3.0;\r\n" +
+"Content-Type: application/json;odata=fullmetadata;streaming=true;charset=utf-8\r\n" +
+"X-Content-Type-Options: nosniff\r\n" +
+"Cache-Control: no-cache\r\n" +
+"Location: https://mytable.table.core.windows.net/tablename(PartitionKey='somPartition',RowKey='03')\r\n" +
+"ETag: W/\"datetime'2020-08-14T22%3A58%3A57.8328323Z'\"\r\n" +
+"\r\n" +
+"{}\r\n" +
+"--changesetresponse_e52cbca8-7e7f-4d91-a719-f99c69686a92--\r\n" +
+"--batchresponse_6040fee7-a2b8-4e78-a674-02086369606a--\r\n" +
+"";
+
+ private const string BlobBatchResponse =
+"--batchresponse_6040fee7-a2b8-4e78-a674-02086369606a \r\n" +
+"Content-Type: application/http \r\n" +
+"Content-ID: 0 \r\n" +
+"\r\n" +
+"HTTP/1.1 202 Accepted \r\n" +
+"x-ms-request-id: 778fdc83-801e-0000-62ff-0334671e284f \r\n" +
+"x-ms-version: 2018-11-09 \r\n" +
+"\r\n" +
+"--batchresponse_6040fee7-a2b8-4e78-a674-02086369606a \r\n" +
+"Content-Type: application/http \r\n" +
+"Content-ID: 1 \r\n" +
+"\r\n" +
+"HTTP/1.1 202 Accepted \r\n" +
+"x-ms-request-id: 778fdc83-801e-0000-62ff-0334671e2851 \r\n" +
+"x-ms-version: 2018-11-09 \r\n" +
+"\r\n" +
+"--batchresponse_6040fee7-a2b8-4e78-a674-02086369606a \r\n" +
+"Content-Type: application/http \r\n" +
+"Content-ID: 2 \r\n" +
+"\r\n" +
+"HTTP/1.1 404 The specified blob does not exist. \r\n" +
+"x-ms-error-code: BlobNotFound \r\n" +
+"x-ms-request-id: 778fdc83-801e-0000-62ff-0334671e2852 \r\n" +
+"x-ms-version: 2018-11-09 \r\n" +
+"Content-Length: 216 \r\n" +
+"Content-Type: application/xml \r\n" +
+"\r\n" +
+" \r\n" +
+"BlobNotFoundThe specified blob does not exist. \r\n" +
+"RequestId:778fdc83-801e-0000-62ff-0334671e2852 \r\n" +
+"Time:2018-06-14T16:46:54.6040685Z \r\n" +
+"--batchresponse_6040fee7-a2b8-4e78-a674-02086369606a-- \r\n" +
+"0";
+
+
+ [Test]
+ public async Task ParseBatchChangesetResponse()
+ {
+ var stream = MakeStream(TablesOdataBatchResponse);
+ var responses = await Multipart.ParseAsync(stream, ContentType, true, default);
+
+ Assert.That(responses, Is.Not.Null);
+ Assert.That(responses.Length, Is.EqualTo(3));
+ Assert.That(responses.All(r => r.Status == (int)HttpStatusCode.Created));
+
+ foreach (var response in responses)
+ {
+ Assert.That(response.TryGetHeader("DataServiceVersion", out var version));
+ Assert.That(version, Is.EqualTo("3.0;"));
+
+ Assert.That(response.TryGetHeader("Content-Type", out var contentType));
+ Assert.That(contentType, Is.EqualTo("application/json;odata=fullmetadata;streaming=true;charset=utf-8"));
+
+ var bytes = new byte[response.ContentStream.Length];
+ await response.ContentStream.ReadAsync(bytes, 0, bytes.Length);
+ var content = GetString(bytes, bytes.Length);
+
+ Assert.That(content, Is.EqualTo(Body));
+
+ }
+ }
+
+ [Test]
+ public async Task ParseBatchResponse()
+ {
+ var stream = MakeStream(BlobBatchResponse);
+ var responses = await Multipart.ParseAsync(stream, ContentType, true, default);
+
+ Assert.That(responses, Is.Not.Null);
+ Assert.That(responses.Length, Is.EqualTo(3));
+
+ var response = responses[0];
+ Assert.That(response.Status, Is.EqualTo( (int)HttpStatusCode.Accepted));
+ Assert.That(response.TryGetHeader("x-ms-version", out var version));
+ Assert.That(version, Is.EqualTo("2018-11-09"));
+ Assert.That(response.TryGetHeader("x-ms-request-id", out _));
+
+ response = responses[1];
+ Assert.That(response.Status, Is.EqualTo((int)HttpStatusCode.Accepted));
+ Assert.That(response.TryGetHeader("x-ms-version", out version));
+ Assert.That(version, Is.EqualTo("2018-11-09"));
+ Assert.That(response.TryGetHeader("x-ms-request-id", out _));
+
+ response = responses[2];
+ Assert.That(response.Status, Is.EqualTo((int)HttpStatusCode.NotFound));
+ Assert.That(response.TryGetHeader("x-ms-version", out version));
+ Assert.That(version, Is.EqualTo("2018-11-09"));
+ Assert.That(response.TryGetHeader("x-ms-request-id", out _));
+ var bytes = new byte[response.ContentStream.Length];
+ await response.ContentStream.ReadAsync(bytes, 0, bytes.Length);
+ var content = GetString(bytes, bytes.Length);
+ Assert.That(content.Contains("BlobNotFoundThe specified blob does not exist."));
+ }
+
+ [Test]
+ public async Task SendMultipartData()
+ {
+ const string ApplicationJson = "application/json";
+ const string cteHeaderName = "Content-Transfer-Encoding";
+ const string Binary = "binary";
+ const string Mixed = "mixed";
+ const string ApplicationJsonOdata = "application/json; odata=nometadata";
+ const string DataServiceVersion = "DataServiceVersion";
+ const string Three0 = "3.0";
+ const string Host = "myaccount.table.core.windows.net";
+
+ using Request request = new MockRequest
+ {
+ Method = RequestMethod.Put
+ };
+ request.Uri.Reset(new Uri("https://foo"));
+
+ Guid batchGuid = Guid.NewGuid();
+ var content = new MultipartContent(Mixed, $"batch_{batchGuid}");
+ content.ApplyToRequest(request);
+
+ Guid changesetGuid = Guid.NewGuid();
+ var changeset = new MultipartContent(Mixed, $"changeset_{changesetGuid}");
+ content.Add(changeset, changeset._headers);
+
+ var postReq1 = new MockRequest
+ {
+ Method = RequestMethod.Post
+ };
+ string postUri = $"https://{Host}/Blogs";
+ postReq1.Uri.Reset(new Uri(postUri));
+ postReq1.Headers.Add(HttpHeader.Names.ContentType, ApplicationJsonOdata);
+ postReq1.Headers.Add(HttpHeader.Names.Accept, ApplicationJson);
+ postReq1.Headers.Add(DataServiceVersion, Three0);
+ const string post1Body = "{ \"PartitionKey\":\"Channel_19\", \"RowKey\":\"1\", \"Rating\":9, \"Text\":\"Azure...\"}";
+ postReq1.Content = RequestContent.Create(Encoding.UTF8.GetBytes(post1Body));
+ changeset.Add(new RequestRequestContent(postReq1), new Dictionary { { HttpHeader.Names.ContentType, "application/http" }, { cteHeaderName, Binary } });
+
+ var postReq2 = new MockRequest
+ {
+ Method = RequestMethod.Post
+ };
+ postReq2.Uri.Reset(new Uri(postUri));
+ postReq2.Headers.Add(HttpHeader.Names.ContentType, ApplicationJsonOdata);
+ postReq2.Headers.Add(HttpHeader.Names.Accept, ApplicationJson);
+ postReq2.Headers.Add(DataServiceVersion, Three0);
+ const string post2Body = "{ \"PartitionKey\":\"Channel_17\", \"RowKey\":\"2\", \"Rating\":9, \"Text\":\"Azure...\"}";
+ postReq2.Content = RequestContent.Create(Encoding.UTF8.GetBytes(post2Body));
+ changeset.Add(new RequestRequestContent(postReq2), new Dictionary { { HttpHeader.Names.ContentType, "application/http" }, { cteHeaderName, Binary } });
+
+ var patchReq = new MockRequest
+ {
+ Method = RequestMethod.Patch
+ };
+ string mergeUri = $"https://{Host}/Blogs(PartitionKey='Channel_17',%20RowKey='3')";
+ patchReq.Uri.Reset(new Uri(mergeUri));
+ patchReq.Headers.Add(HttpHeader.Names.ContentType, ApplicationJsonOdata);
+ patchReq.Headers.Add(HttpHeader.Names.Accept, ApplicationJson);
+ patchReq.Headers.Add(DataServiceVersion, Three0);
+ const string patchBody = "{ \"PartitionKey\":\"Channel_19\", \"RowKey\":\"3\", \"Rating\":9, \"Text\":\"Azure Tables...\"}";
+ patchReq.Content = RequestContent.Create(Encoding.UTF8.GetBytes(patchBody));
+ changeset.Add(new RequestRequestContent(patchReq), new Dictionary { { HttpHeader.Names.ContentType, "application/http" }, { cteHeaderName, Binary } });
+
+ request.Content = content;
+ var memStream = new MemoryStream();
+ await content.WriteToAsync(memStream, default);
+ memStream.Position = 0;
+ using var sr = new StreamReader(memStream, Encoding.UTF8);
+ string requestBody = sr.ReadToEnd();
+ Console.WriteLine(requestBody);
+
+
+ Assert.That(requestBody, Is.EqualTo($"--batch_{batchGuid}\r\n" +
+ $"{HttpHeader.Names.ContentType}: multipart/mixed; boundary=changeset_{changesetGuid}\r\n" +
+ $"\r\n" +
+ $"--changeset_{changesetGuid}\r\n" +
+ $"{HttpHeader.Names.ContentType}: application/http\r\n" +
+ $"{cteHeaderName}: {Binary}\r\n" +
+ $"\r\n" +
+ $"POST {postUri} HTTP/1.1\r\n" +
+ $"{HttpHeader.Names.Host}: {Host}\r\n" +
+ $"{HttpHeader.Names.ContentType}: {ApplicationJsonOdata}\r\n" +
+ $"{HttpHeader.Names.Accept}: {ApplicationJson}\r\n" +
+ $"{DataServiceVersion}: {Three0}\r\n" +
+ $"{HttpHeader.Names.ContentLength}: 75\r\n" +
+ $"\r\n" +
+ $"{post1Body}\r\n" +
+ $"--changeset_{changesetGuid}\r\n" +
+ $"{HttpHeader.Names.ContentType}: application/http\r\n" +
+ $"{cteHeaderName}: {Binary}\r\n" +
+ $"\r\n" +
+ $"POST {postUri} HTTP/1.1\r\n" +
+ $"{HttpHeader.Names.Host}: {Host}\r\n" +
+ $"{HttpHeader.Names.ContentType}: {ApplicationJsonOdata}\r\n" +
+ $"{HttpHeader.Names.Accept}: {ApplicationJson}\r\n" +
+ $"{DataServiceVersion}: {Three0}\r\n" +
+ $"{HttpHeader.Names.ContentLength}: 75\r\n" +
+ $"\r\n" +
+ $"{post2Body}\r\n" +
+ $"--changeset_{changesetGuid}\r\n" +
+ $"{HttpHeader.Names.ContentType}: application/http\r\n" +
+ $"{cteHeaderName}: {Binary}\r\n" +
+ $"\r\n" +
+ $"PATCH {mergeUri} HTTP/1.1\r\n" +
+ $"{HttpHeader.Names.Host}: {Host}\r\n" +
+ $"{HttpHeader.Names.ContentType}: {ApplicationJsonOdata}\r\n" +
+ $"{HttpHeader.Names.Accept}: {ApplicationJson}\r\n" +
+ $"{DataServiceVersion}: {Three0}\r\n" +
+ $"{HttpHeader.Names.ContentLength}: 82\r\n" +
+ $"\r\n" +
+ $"{patchBody}\r\n" +
+ $"--changeset_{changesetGuid}--\r\n" +
+ $"\r\n" +
+ $"--batch_{batchGuid}--\r\n" +
+ $""));
+ }
+
+ private static MemoryStream MakeStream(string text)
+ {
+ return new MemoryStream(Encoding.UTF8.GetBytes(text));
+ }
+
+ private static string GetString(byte[] buffer, int count)
+ {
+ return Encoding.ASCII.GetString(buffer, 0, count);
+ }
+ }
+}
diff --git a/sdk/core/Azure.Core/tests/TransportFunctionalTests.cs b/sdk/core/Azure.Core/tests/TransportFunctionalTests.cs
index 1b8b8c51dc49..8e3cdc2e01e3 100644
--- a/sdk/core/Azure.Core/tests/TransportFunctionalTests.cs
+++ b/sdk/core/Azure.Core/tests/TransportFunctionalTests.cs
@@ -237,7 +237,7 @@ public async Task CanGetAndSetMethod(RequestMethod method, string expectedMethod
[TestCaseSource(nameof(AllHeadersWithValuesAndType))]
public async Task CanGetAndAddRequestHeaders(string headerName, string headerValue, bool contentHeader)
{
- StringValues httpHeaderValues;
+ StringValues httpHeaderValues;
using TestServer testServer = new TestServer(
context =>
diff --git a/sdk/tables/Azure.Data.Tables/src/Azure.Data.Tables.csproj b/sdk/tables/Azure.Data.Tables/src/Azure.Data.Tables.csproj
index e9e6e1224089..44dbd456d22a 100644
--- a/sdk/tables/Azure.Data.Tables/src/Azure.Data.Tables.csproj
+++ b/sdk/tables/Azure.Data.Tables/src/Azure.Data.Tables.csproj
@@ -8,15 +8,19 @@
$(RequiredTargetFrameworks)
$(NoWarn);CA1812;CS1591
+ true
+
+
+
@@ -25,11 +29,12 @@
+
+
-
diff --git a/sdk/tables/Azure.Data.Tables/src/MultipartContentExtensions.cs b/sdk/tables/Azure.Data.Tables/src/MultipartContentExtensions.cs
new file mode 100644
index 000000000000..7db3698c60cd
--- /dev/null
+++ b/sdk/tables/Azure.Data.Tables/src/MultipartContentExtensions.cs
@@ -0,0 +1,20 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#nullable disable
+
+using System;
+using Azure.Core;
+
+namespace Azure.Data.Tables
+{
+ internal static class MultipartContentExtensions
+ {
+ internal static MultipartContent AddChangeset(this MultipartContent batch)
+ {
+ var changeset = new MultipartContent("mixed", $"changeset_{Guid.NewGuid()}");
+ batch.Add(changeset, changeset._headers);
+ return changeset;
+ }
+ }
+}
diff --git a/sdk/tables/Azure.Data.Tables/src/TableClient.cs b/sdk/tables/Azure.Data.Tables/src/TableClient.cs
index 269a5940d071..042185b0d079 100644
--- a/sdk/tables/Azure.Data.Tables/src/TableClient.cs
+++ b/sdk/tables/Azure.Data.Tables/src/TableClient.cs
@@ -903,6 +903,76 @@ public virtual Response SetAccessPolicy(IEnumerable tableAcl,
}
}
+ ///
+ /// Placeholder for batch operations. This is just being used for testing.
+ ///
+ ///
+ ///
+ ///
+ ///
+ internal virtual async Task>> BatchTestAsync(IEnumerable entities, CancellationToken cancellationToken = default) where T : class, ITableEntity, new()
+ {
+ using DiagnosticScope scope = _diagnostics.CreateScope($"{nameof(TableClient)}.{nameof(BatchTest)}");
+ scope.Start();
+ try
+ {
+ var batch = TableRestClient.CreateBatchContent();
+ var changeset = batch.AddChangeset();
+ foreach (var entity in entities)
+ {
+ _tableOperations.AddInsertEntityRequest(
+ changeset,
+ _table,
+ null,
+ null,
+ null,
+ tableEntityProperties: entity.ToOdataAnnotatedDictionary(),
+ queryOptions: new QueryOptions() { Format = _format });
+ }
+ return await _tableOperations.SendBatchRequestAsync(_tableOperations.CreateBatchRequest(batch, null, null), cancellationToken).ConfigureAwait(false);
+ }
+ catch (Exception ex)
+ {
+ scope.Failed(ex);
+ throw;
+ }
+ }
+
+ ///
+ /// Placeholder for batch operations. This is just being used for testing.
+ ///
+ ///
+ ///
+ ///
+ ///
+ internal virtual Response> BatchTest(IEnumerable entities, CancellationToken cancellationToken = default) where T : class, ITableEntity, new()
+ {
+ using DiagnosticScope scope = _diagnostics.CreateScope($"{nameof(TableClient)}.{nameof(BatchTest)}");
+ scope.Start();
+ try
+ {
+ var batch = TableRestClient.CreateBatchContent();
+ var changeset = batch.AddChangeset();
+ foreach (var entity in entities)
+ {
+ _tableOperations.AddInsertEntityRequest(
+ changeset,
+ _table,
+ null,
+ null,
+ null,
+ tableEntityProperties: entity.ToOdataAnnotatedDictionary(),
+ queryOptions: new QueryOptions() { Format = _format });
+ }
+ return _tableOperations.SendBatchRequest(_tableOperations.CreateBatchRequest(batch, null, null), cancellationToken);
+ }
+ catch (Exception ex)
+ {
+ scope.Failed(ex);
+ throw;
+ }
+ }
+
///
/// Creates an Odata filter query string from the provided expression.
///
diff --git a/sdk/tables/Azure.Data.Tables/src/TableRestClient.cs b/sdk/tables/Azure.Data.Tables/src/TableRestClient.cs
new file mode 100644
index 000000000000..a484f1abcc47
--- /dev/null
+++ b/sdk/tables/Azure.Data.Tables/src/TableRestClient.cs
@@ -0,0 +1,117 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#nullable disable
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Threading;
+using System.Threading.Tasks;
+using Azure.Core;
+using Azure.Core.Pipeline;
+using Azure.Data.Tables.Models;
+
+namespace Azure.Data.Tables
+{
+ internal partial class TableRestClient
+ {
+ private const string CteHeaderName = "Content-Transfer-Encoding";
+ private const string Binary = "binary";
+ private const string ApplicationHttp = "application/http";
+
+ internal HttpMessage CreateBatchRequest(MultipartContent content, string requestId, ResponseFormat? responsePreference)
+ {
+ var message = _pipeline.CreateMessage();
+ var request = message.Request;
+ request.Method = RequestMethod.Post;
+ var uri = new RawRequestUriBuilder();
+ uri.AppendRaw(url, false);
+ uri.AppendPath("/$batch", false);
+ request.Uri = uri;
+ request.Headers.Add("x-ms-version", version);
+ if (requestId != null)
+ {
+ request.Headers.Add("x-ms-client-request-id", requestId);
+ }
+ request.Headers.Add("DataServiceVersion", "3.0");
+ if (responsePreference != null)
+ {
+ request.Headers.Add("Prefer", responsePreference.Value.ToString());
+ }
+ //request.Headers.Add("Accept", "application/json");
+ request.Content = content;
+ content.ApplyToRequest(request);
+ return message;
+ }
+
+ internal static MultipartContent CreateBatchContent()
+ {
+ return new MultipartContent("mixed", $"batch_{Guid.NewGuid()}");
+ }
+
+ internal void AddInsertEntityRequest(MultipartContent changeset, string table, int? timeout, string requestId, ResponseFormat? responsePreference, IDictionary tableEntityProperties, QueryOptions queryOptions)
+ {
+ var message = CreateInsertEntityRequest(table, timeout, requestId, responsePreference, tableEntityProperties, queryOptions);
+ changeset.Add(new RequestRequestContent(message.Request), new Dictionary { { HttpHeader.Names.ContentType, ApplicationHttp }, { CteHeaderName, Binary } });
+ }
+
+ /// Insert entity in a table.
+ ///
+ /// The cancellation token to use.
+ /// is null.
+ public async Task>> SendBatchRequestAsync(HttpMessage message, CancellationToken cancellationToken = default)
+ {
+ if (message == null)
+ {
+ throw new ArgumentNullException(nameof(message));
+ }
+
+ await _pipeline.SendAsync(message, cancellationToken).ConfigureAwait(false);
+ switch (message.Response.Status)
+ {
+ case 202:
+ {
+ var responses = await Multipart.ParseAsync(
+ message.Response.ContentStream,
+ message.Response.Headers.ContentType,
+ true,
+ cancellationToken).ConfigureAwait(false);
+
+ return Response.FromValue(responses.ToList(), message.Response);
+ }
+ default:
+ throw await _clientDiagnostics.CreateRequestFailedExceptionAsync(message.Response).ConfigureAwait(false);
+ }
+ }
+
+ /// Insert entity in a table.
+ ///
+ /// The cancellation token to use.
+ /// is null.
+ public Response> SendBatchRequest(HttpMessage message, CancellationToken cancellationToken = default)
+ {
+ if (message == null)
+ {
+ throw new ArgumentNullException(nameof(message));
+ }
+
+ _pipeline.Send(message, cancellationToken);
+ switch (message.Response.Status)
+ {
+ case 202:
+ {
+ var responses = Multipart.ParseAsync(
+ message.Response.ContentStream,
+ message.Response.Headers.ContentType,
+ false,
+ cancellationToken).EnsureCompleted();
+
+ return Response.FromValue(responses.ToList(), message.Response);
+ }
+ default:
+ throw _clientDiagnostics.CreateRequestFailedException(message.Response);
+ }
+ }
+ }
+}
diff --git a/sdk/tables/Azure.Data.Tables/tests/TableClientLiveTests.cs b/sdk/tables/Azure.Data.Tables/tests/TableClientLiveTests.cs
index d90486723ef8..b3a8ed4120aa 100644
--- a/sdk/tables/Azure.Data.Tables/tests/TableClientLiveTests.cs
+++ b/sdk/tables/Azure.Data.Tables/tests/TableClientLiveTests.cs
@@ -879,7 +879,6 @@ public async Task GetAccessPoliciesReturnsPolicies()
await client.SetAccessPolicyAsync(tableAcl: policyToCreate);
-
// Get the created policy.
var policies = await client.GetAccessPolicyAsync();
@@ -910,5 +909,30 @@ public async Task GetEntityReturnsSingleEntity()
Assert.That(entityResults, Is.Not.Null, "The entity should not be null.");
}
+
+ ///
+ /// Validates the functionality of the TableClient.
+ ///
+ [Test]
+ [LiveOnly]
+ public async Task BatchInsert()
+ {
+ if (_endpointType == TableEndpointType.CosmosTable)
+ {
+ Assert.Ignore("https://github.com/Azure/azure-sdk-for-net/issues/14272");
+ }
+ var entitiesToCreate = CreateCustomTableEntities(PartitionKeyValue, 20);
+
+ // Create the new entities.
+
+ var responses = await client.BatchTestAsync(entitiesToCreate).ConfigureAwait(false);
+
+ foreach (var response in responses.Value)
+ {
+ Assert.That(response.Status, Is.EqualTo((int)HttpStatusCode.Created));
+ }
+ Assert.That(responses.Value.Count, Is.EqualTo(entitiesToCreate.Count));
+ }
}
+
}