diff --git a/sdk/core/Azure.Core/src/Shared/ChangeTrackingDictionary.cs b/sdk/core/Azure.Core/src/Shared/ChangeTrackingDictionary.cs new file mode 100644 index 000000000000..e08339f311ed --- /dev/null +++ b/sdk/core/Azure.Core/src/Shared/ChangeTrackingDictionary.cs @@ -0,0 +1,215 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Collections; +using System.Collections.Generic; + +#nullable enable + +namespace Azure.Core +{ + internal class ChangeTrackingDictionary : IDictionary, IReadOnlyDictionary where TKey: notnull + { + private IDictionary? _innerDictionary; + + public ChangeTrackingDictionary() + { + } + + public ChangeTrackingDictionary(Optional> optionalDictionary) : this(optionalDictionary.Value) + { + } + + public ChangeTrackingDictionary(Optional> optionalDictionary) : this(optionalDictionary.Value) + { + } + + private ChangeTrackingDictionary(IDictionary dictionary) + { + if (dictionary == null) return; + + _innerDictionary = new Dictionary(dictionary); + } + + private ChangeTrackingDictionary(IReadOnlyDictionary dictionary) + { + if (dictionary == null) return; + + _innerDictionary = new Dictionary(); + foreach (KeyValuePair pair in dictionary) + { + _innerDictionary.Add(pair); + } + } + + public bool IsUndefined => _innerDictionary == null; + + public IEnumerator> GetEnumerator() + { + if (IsUndefined) + { + IEnumerator> GetEmptyEnumerator() + { + yield break; + } + return GetEmptyEnumerator(); + } + return EnsureDictionary().GetEnumerator(); + } + + IEnumerator IEnumerable.GetEnumerator() + { + return GetEnumerator(); + } + + public void Add(KeyValuePair item) + { + EnsureDictionary().Add(item); + } + + public void Clear() + { + EnsureDictionary().Clear(); + } + + public bool Contains(KeyValuePair item) + { + if (IsUndefined) + { + return false; + } + + return EnsureDictionary().Contains(item); + } + + public void CopyTo(KeyValuePair[] array, int arrayIndex) + { + if (IsUndefined) + { + return; + } + + EnsureDictionary().CopyTo(array, arrayIndex); + } + + public bool Remove(KeyValuePair item) + { + if (IsUndefined) + { + return false; + } + + return EnsureDictionary().Remove(item); + } + + public int Count + { + get + { + if (IsUndefined) + { + return 0; + } + + return EnsureDictionary().Count; + } + } + + public bool IsReadOnly + { + get + { + if (IsUndefined) + { + return false; + } + return EnsureDictionary().IsReadOnly; + } + } + + public void Add(TKey key, TValue value) + { + EnsureDictionary().Add(key, value); + } + + public bool ContainsKey(TKey key) + { + if (IsUndefined) + { + return false; + } + + return EnsureDictionary().ContainsKey(key); + } + + public bool Remove(TKey key) + { + if (IsUndefined) + { + return false; + } + + return EnsureDictionary().Remove(key); + } + + public bool TryGetValue(TKey key, out TValue value) + { + if (IsUndefined) + { + value = default!; + return false; + } + return EnsureDictionary().TryGetValue(key, out value!); + } + + public TValue this[TKey key] + { + get + { + if (IsUndefined) + { + throw new KeyNotFoundException(nameof(key)); + } + + return EnsureDictionary()[key]; + } + set => EnsureDictionary()[key] = value; + } + + IEnumerable IReadOnlyDictionary.Keys => Keys; + + IEnumerable IReadOnlyDictionary.Values => Values; + + public ICollection Keys + { + get + { + if (IsUndefined) + { + return Array.Empty(); + } + + return EnsureDictionary().Keys; + } + } + + public ICollection Values + { + get + { + if (IsUndefined) + { + return Array.Empty(); + } + + return EnsureDictionary().Values; + } + } + + private IDictionary EnsureDictionary() + { + return _innerDictionary ??= new Dictionary(); + } + } +} diff --git a/sdk/core/Azure.Core/src/Shared/ChangeTrackingList.cs b/sdk/core/Azure.Core/src/Shared/ChangeTrackingList.cs new file mode 100644 index 000000000000..49e97efd18a4 --- /dev/null +++ b/sdk/core/Azure.Core/src/Shared/ChangeTrackingList.cs @@ -0,0 +1,192 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Collections; +using System.Collections.Generic; +using System.Linq; + +#nullable enable + +namespace Azure.Core +{ + internal class ChangeTrackingList: IList, IReadOnlyList + { + private IList? _innerList; + + public ChangeTrackingList() + { + } + + public ChangeTrackingList(Optional> optionalList) : this(optionalList.Value) + { + } + + public ChangeTrackingList(Optional> optionalList) : this(optionalList.Value) + { + } + + private ChangeTrackingList(IEnumerable innerList) + { + if (innerList == null) + { + return; + } + + _innerList = innerList.ToList(); + } + + private ChangeTrackingList(IList innerList) + { + if (innerList == null) + { + return; + } + + _innerList = innerList; + } + + public bool IsUndefined => _innerList == null; + + public void Reset() + { + _innerList = null; + } + + public IEnumerator GetEnumerator() + { + if (IsUndefined) + { + IEnumerator EnumerateEmpty() + { + yield break; + } + + return EnumerateEmpty(); + } + return EnsureList().GetEnumerator(); + } + + IEnumerator IEnumerable.GetEnumerator() + { + return GetEnumerator(); + } + + public void Add(T item) + { + EnsureList().Add(item); + } + + public void Clear() + { + EnsureList().Clear(); + } + + public bool Contains(T item) + { + if (IsUndefined) + { + return false; + } + + return EnsureList().Contains(item); + } + + public void CopyTo(T[] array, int arrayIndex) + { + if (IsUndefined) + { + return; + } + + EnsureList().CopyTo(array, arrayIndex); + } + + public bool Remove(T item) + { + if (IsUndefined) + { + return false; + } + + return EnsureList().Remove(item); + } + + public int Count + { + get + { + if (IsUndefined) + { + return 0; + } + return EnsureList().Count; + } + } + + public bool IsReadOnly + { + get + { + if (IsUndefined) + { + return false; + } + + return EnsureList().IsReadOnly; + } + } + + public int IndexOf(T item) + { + if (IsUndefined) + { + return -1; + } + + return EnsureList().IndexOf(item); + } + + public void Insert(int index, T item) + { + EnsureList().Insert(index, item); + } + + public void RemoveAt(int index) + { + if (IsUndefined) + { + throw new ArgumentOutOfRangeException(nameof(index)); + } + + EnsureList().RemoveAt(index); + } + + public T this[int index] + { + get + { + if (IsUndefined) + { + throw new ArgumentOutOfRangeException(nameof(index)); + } + + return EnsureList()[index]; + } + set + { + if (IsUndefined) + { + throw new ArgumentOutOfRangeException(nameof(index)); + } + + EnsureList()[index] = value; + } + } + + private IList EnsureList() + { + return _innerList ??= new List(); + } + } +} diff --git a/sdk/core/Azure.Core/src/Shared/ErrorResponse.cs b/sdk/core/Azure.Core/src/Shared/ErrorResponse.cs new file mode 100644 index 000000000000..6fad0180f596 --- /dev/null +++ b/sdk/core/Azure.Core/src/Shared/ErrorResponse.cs @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#nullable enable + +namespace Azure.Core +{ + internal class ErrorResponse : Response + { + private readonly Response _response; + private readonly RequestFailedException _exception; + + public ErrorResponse(Response response, RequestFailedException exception) + { + _response = response; + _exception = exception; + } + + public override T Value { get => throw _exception; } + + public override Response GetRawResponse() => _response; + } +} diff --git a/sdk/core/Azure.Core/src/Shared/FormUrlEncodedContent.cs b/sdk/core/Azure.Core/src/Shared/FormUrlEncodedContent.cs new file mode 100644 index 000000000000..91490f81a119 --- /dev/null +++ b/sdk/core/Azure.Core/src/Shared/FormUrlEncodedContent.cs @@ -0,0 +1,100 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.IO; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Azure.Core; + +namespace Azure.Core +{ + internal class FormUrlEncodedContent : RequestContent + { + private List> _values = new List>(); + private Encoding Latin1 = Encoding.GetEncoding("iso-8859-1"); + private byte[] _bytes = Array.Empty(); + + public void Add (string parameter, string value) + { + _values.Add(new KeyValuePair (parameter, value)); + } + + private void BuildIfNeeded () + { + if (_bytes.Length == 0) + { + _bytes = GetContentByteArray(_values); + _values.Clear(); + } + } + + public override async Task WriteToAsync(Stream stream, CancellationToken cancellation) + { + BuildIfNeeded (); +#if NET6_0_OR_GREATER + await stream.WriteAsync(_bytes.AsMemory(), cancellation).ConfigureAwait(false); +#else + await stream.WriteAsync(_bytes, 0, _bytes.Length, cancellation).ConfigureAwait(false); +#endif + } + + public override void WriteTo(Stream stream, CancellationToken cancellation) + { + BuildIfNeeded (); +#if NET6_0_OR_GREATER + stream.Write(_bytes.AsSpan()); +#else + stream.Write(_bytes, 0, _bytes.Length); +#endif + } + + public override bool TryComputeLength(out long length) + { + BuildIfNeeded (); + length = _bytes.Length; + return true; + } + + public override void Dispose() + { + } + + // Taken with love from https://github.com/dotnet/runtime/blob/master/src/libraries/System.Net.Http/src/System/Net/Http/FormUrlEncodedContent.cs#L21-L53 + private byte[] GetContentByteArray(IEnumerable> nameValueCollection) + { + if (nameValueCollection == null) + { + throw new ArgumentNullException(nameof(nameValueCollection)); + } + + // Encode and concatenate data + StringBuilder builder = new StringBuilder(); + foreach (KeyValuePair pair in nameValueCollection) + { + if (builder.Length > 0) + { + builder.Append('&'); + } + + builder.Append(Encode(pair.Key)); + builder.Append('='); + builder.Append(Encode(pair.Value)); + } + + return Latin1.GetBytes(builder.ToString()); + } + + private static string Encode(string data) + { + if (string.IsNullOrEmpty(data)) + { + return string.Empty; + } + // Escape spaces as '+'. + return Uri.EscapeDataString(data).Replace("%20", "+"); + } + } +} diff --git a/sdk/core/Azure.Core/src/Shared/HttpPipelineExtensions.cs b/sdk/core/Azure.Core/src/Shared/HttpPipelineExtensions.cs new file mode 100644 index 000000000000..1be1252193dc --- /dev/null +++ b/sdk/core/Azure.Core/src/Shared/HttpPipelineExtensions.cs @@ -0,0 +1,94 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#nullable enable + +using System.Threading; +using System.Threading.Tasks; +using Azure.Core.Pipeline; + +namespace Azure.Core +{ + internal static class HttpPipelineExtensions + { + public static async ValueTask ProcessMessageAsync(this HttpPipeline pipeline, HttpMessage message, RequestContext? requestContext, CancellationToken cancellationToken = default) + { + var (userCt, statusOption) = ApplyRequestContext(requestContext); + if (!userCt.CanBeCanceled || !cancellationToken.CanBeCanceled) + { + await pipeline.SendAsync(message, cancellationToken.CanBeCanceled ? cancellationToken : userCt).ConfigureAwait(false); + } + else + { + using var cts = CancellationTokenSource.CreateLinkedTokenSource(userCt, cancellationToken); + await pipeline.SendAsync(message, cts.Token).ConfigureAwait(false); + } + + if (!message.Response.IsError || statusOption == ErrorOptions.NoThrow) + { + return message.Response; + } + + throw new RequestFailedException(message.Response); + } + + public static Response ProcessMessage(this HttpPipeline pipeline, HttpMessage message, RequestContext? requestContext, CancellationToken cancellationToken = default) + { + var (userCt, statusOption) = ApplyRequestContext(requestContext); + if (!userCt.CanBeCanceled || !cancellationToken.CanBeCanceled) + { + pipeline.Send(message, cancellationToken.CanBeCanceled ? cancellationToken : userCt); + } + else + { + using var cts = CancellationTokenSource.CreateLinkedTokenSource(userCt, cancellationToken); + pipeline.Send(message, cts.Token); + } + + if (!message.Response.IsError || statusOption == ErrorOptions.NoThrow) + { + return message.Response; + } + + throw new RequestFailedException(message.Response); + } + + public static async ValueTask> ProcessHeadAsBoolMessageAsync(this HttpPipeline pipeline, HttpMessage message, ClientDiagnostics clientDiagnostics, RequestContext? requestContext) + { + var response = await pipeline.ProcessMessageAsync(message, requestContext).ConfigureAwait(false); + switch (response.Status) + { + case >= 200 and < 300: + return Response.FromValue(true, response); + case >= 400 and < 500: + return Response.FromValue(false, response); + default: + return new ErrorResponse(response, new RequestFailedException(response)); + } + } + + public static Response ProcessHeadAsBoolMessage(this HttpPipeline pipeline, HttpMessage message, ClientDiagnostics clientDiagnostics, RequestContext? requestContext) + { + var response = pipeline.ProcessMessage(message, requestContext); + switch (response.Status) + { + case >= 200 and < 300: + return Response.FromValue(true, response); + case >= 400 and < 500: + return Response.FromValue(false, response); + default: + return new ErrorResponse(response, new RequestFailedException(response)); + } + } + + private static (CancellationToken CancellationToken, ErrorOptions ErrorOptions) ApplyRequestContext(RequestContext? requestContext) + { + if (requestContext == null) + { + return (CancellationToken.None, ErrorOptions.Default); + } + + return (requestContext.CancellationToken, requestContext.ErrorOptions); + } + } +} diff --git a/sdk/core/Azure.Core/src/Shared/IOperationSource.cs b/sdk/core/Azure.Core/src/Shared/IOperationSource.cs new file mode 100644 index 000000000000..1be2f9b733d0 --- /dev/null +++ b/sdk/core/Azure.Core/src/Shared/IOperationSource.cs @@ -0,0 +1,14 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.Threading; +using System.Threading.Tasks; + +namespace Azure.Core +{ + internal interface IOperationSource + { + T CreateResult(Response response, CancellationToken cancellationToken); + ValueTask CreateResultAsync(Response response, CancellationToken cancellationToken); + } +} diff --git a/sdk/core/Azure.Core/src/Shared/IUtf8JsonSerializable.cs b/sdk/core/Azure.Core/src/Shared/IUtf8JsonSerializable.cs new file mode 100644 index 000000000000..5653e4609313 --- /dev/null +++ b/sdk/core/Azure.Core/src/Shared/IUtf8JsonSerializable.cs @@ -0,0 +1,14 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#nullable enable + +using System.Text.Json; + +namespace Azure.Core +{ + internal interface IUtf8JsonSerializable + { + void Write(Utf8JsonWriter writer); + } +} diff --git a/sdk/core/Azure.Core/src/Shared/IXmlSerializable.cs b/sdk/core/Azure.Core/src/Shared/IXmlSerializable.cs new file mode 100644 index 000000000000..343b127384d2 --- /dev/null +++ b/sdk/core/Azure.Core/src/Shared/IXmlSerializable.cs @@ -0,0 +1,14 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#nullable enable + +using System.Xml; + +namespace Azure.Core +{ + internal interface IXmlSerializable + { + void Write(XmlWriter writer, string? nameHint); + } +} diff --git a/sdk/core/Azure.Core/src/Shared/JsonElementExtensions.cs b/sdk/core/Azure.Core/src/Shared/JsonElementExtensions.cs new file mode 100644 index 000000000000..fea6b2560c4f --- /dev/null +++ b/sdk/core/Azure.Core/src/Shared/JsonElementExtensions.cs @@ -0,0 +1,118 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#nullable enable + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Globalization; +using System.Text.Json; +using System.Xml; + +namespace Azure.Core +{ + internal static class JsonElementExtensions + { + public static object? GetObject(in this JsonElement element) + { + switch (element.ValueKind) + { + case JsonValueKind.String: + return element.GetString(); + case JsonValueKind.Number: + if (element.TryGetInt32(out int intValue)) + { + return intValue; + } + if (element.TryGetInt64(out long longValue)) + { + return longValue; + } + return element.GetDouble(); + case JsonValueKind.True: + return true; + case JsonValueKind.False: + return false; + case JsonValueKind.Undefined: + case JsonValueKind.Null: + return null; + case JsonValueKind.Object: + var dictionary = new Dictionary(); + foreach (JsonProperty jsonProperty in element.EnumerateObject()) + { + dictionary.Add(jsonProperty.Name, jsonProperty.Value.GetObject()); + } + return dictionary; + case JsonValueKind.Array: + var list = new List(); + foreach (JsonElement item in element.EnumerateArray()) + { + list.Add(item.GetObject()); + } + return list.ToArray(); + default: + throw new NotSupportedException("Not supported value kind " + element.ValueKind); + } + } + + public static byte[]? GetBytesFromBase64(in this JsonElement element, string format) + { + if (element.ValueKind== JsonValueKind.Null) + { + return null; + } + + return format switch + { + "U" => TypeFormatters.FromBase64UrlString(element.GetRequiredString()), + "D" => element.GetBytesFromBase64(), + _ => throw new ArgumentException($"Format is not supported: '{format}'", nameof(format)) + }; + } + + public static DateTimeOffset GetDateTimeOffset(in this JsonElement element, string format) => format switch + { + "U" when element.ValueKind == JsonValueKind.Number => DateTimeOffset.FromUnixTimeSeconds(element.GetInt64()), + // relying on the param check of the inner call to throw ArgumentNullException if GetString() returns null + _ => TypeFormatters.ParseDateTimeOffset(element.GetString()!, format) + }; + + public static TimeSpan GetTimeSpan(in this JsonElement element, string format) => + // relying on the param check of the inner call to throw ArgumentNullException if GetString() returns null + TypeFormatters.ParseTimeSpan(element.GetString()!, format); + + public static char GetChar(this in JsonElement element) + { + if (element.ValueKind == JsonValueKind.String) + { + var text = element.GetString(); + if (text == null || text.Length != 1) + { + throw new NotSupportedException($"Cannot convert \"{text}\" to a Char"); + } + return text[0]; + } + else + { + throw new NotSupportedException($"Cannot convert {element.ValueKind} to a Char"); + } + } + + [Conditional("DEBUG")] + public static void ThrowNonNullablePropertyIsNull(this JsonProperty property) + { + throw new JsonException($"A property '{property.Name}' defined as non-nullable but received as null from the service. " + + $"This exception only happens in DEBUG builds of the library and would be ignored in the release build"); + } + + public static string GetRequiredString(in this JsonElement element) + { + var value = element.GetString(); + if (value == null) + throw new InvalidOperationException($"The requested operation requires an element of type 'String', but the target element has type '{element.ValueKind}'."); + + return value; + } + } +} diff --git a/sdk/core/Azure.Core/src/Shared/LowLevelPageableHelpers.cs b/sdk/core/Azure.Core/src/Shared/LowLevelPageableHelpers.cs new file mode 100644 index 000000000000..17f06dc282c9 --- /dev/null +++ b/sdk/core/Azure.Core/src/Shared/LowLevelPageableHelpers.cs @@ -0,0 +1,11 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#nullable enable + +namespace Azure.Core +{ + internal static class LowLevelPageableHelpers + { + } +} diff --git a/sdk/core/Azure.Core/src/Shared/NextLinkOperationImplementation.cs b/sdk/core/Azure.Core/src/Shared/NextLinkOperationImplementation.cs new file mode 100644 index 000000000000..120c2d63a7d3 --- /dev/null +++ b/sdk/core/Azure.Core/src/Shared/NextLinkOperationImplementation.cs @@ -0,0 +1,449 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#nullable enable + +using System; +using System.Linq; +using System.IO; +using System.Text.Json; +using System.Threading; +using System.Threading.Tasks; +using Azure.Core.Pipeline; + +namespace Azure.Core +{ + internal class NextLinkOperationImplementation : IOperation + { + private const string ApiVersionParam = "api-version"; + private static readonly string[] FailureStates = { "failed", "canceled" }; + private static readonly string[] SuccessStates = { "succeeded" }; + + private readonly HeaderSource _headerSource; + private readonly bool _originalResponseHasLocation; + private readonly Uri _startRequestUri; + private readonly OperationFinalStateVia _finalStateVia; + private readonly RequestMethod _requestMethod; + private readonly HttpPipeline _pipeline; + private readonly string? _apiVersion; + + private string? _lastKnownLocation; + private string _nextRequestUri; + + public static IOperation Create( + HttpPipeline pipeline, + RequestMethod requestMethod, + Uri startRequestUri, + Response response, + OperationFinalStateVia finalStateVia, + bool skipApiVersionOverride = false, + string? apiVersionOverrideValue = null) + { + string? apiVersionStr = null; + if (apiVersionOverrideValue is not null) + { + apiVersionStr = apiVersionOverrideValue; + } + else + { + apiVersionStr = !skipApiVersionOverride && TryGetApiVersion(startRequestUri, out ReadOnlySpan apiVersion) ? apiVersion.ToString() : null; + } + var headerSource = GetHeaderSource(requestMethod, startRequestUri, response, apiVersionStr, out var nextRequestUri); + if (headerSource == HeaderSource.None && IsFinalState(response, headerSource, out var failureState, out _)) + { + return new CompletedOperation(failureState ?? GetOperationStateFromFinalResponse(requestMethod, response)); + } + + var originalResponseHasLocation = response.Headers.TryGetValue("Location", out var lastKnownLocation); + return new NextLinkOperationImplementation(pipeline, requestMethod, startRequestUri, nextRequestUri, headerSource, originalResponseHasLocation, lastKnownLocation, finalStateVia, apiVersionStr); + } + + public static IOperation Create( + IOperationSource operationSource, + HttpPipeline pipeline, + RequestMethod requestMethod, + Uri startRequestUri, + Response response, + OperationFinalStateVia finalStateVia, + bool skipApiVersionOverride = false, + string? apiVersionOverrideValue = null) + { + var operation = Create(pipeline, requestMethod, startRequestUri, response, finalStateVia, skipApiVersionOverride, apiVersionOverrideValue); + return new OperationToOperationOfT(operationSource, operation); + } + + private NextLinkOperationImplementation( + HttpPipeline pipeline, + RequestMethod requestMethod, + Uri startRequestUri, + string nextRequestUri, + HeaderSource headerSource, + bool originalResponseHasLocation, + string? lastKnownLocation, + OperationFinalStateVia finalStateVia, + string? apiVersion) + { + _requestMethod = requestMethod; + _headerSource = headerSource; + _startRequestUri = startRequestUri; + _nextRequestUri = nextRequestUri; + _originalResponseHasLocation = originalResponseHasLocation; + _lastKnownLocation = lastKnownLocation; + _finalStateVia = finalStateVia; + _pipeline = pipeline; + _apiVersion = apiVersion; + } + + public async ValueTask UpdateStateAsync(bool async, CancellationToken cancellationToken) + { + Response response = await GetResponseAsync(async, _nextRequestUri, cancellationToken).ConfigureAwait(false); + + var hasCompleted = IsFinalState(response, _headerSource, out var failureState, out var resourceLocation); + if (failureState != null) + { + return failureState.Value; + } + + if (hasCompleted) + { + string? finalUri = GetFinalUri(resourceLocation); + var finalResponse = finalUri != null + ? await GetResponseAsync(async, finalUri, cancellationToken).ConfigureAwait(false) + : response; + + return GetOperationStateFromFinalResponse(_requestMethod, finalResponse); + } + + UpdateNextRequestUri(response.Headers); + return OperationState.Pending(response); + } + + private static OperationState GetOperationStateFromFinalResponse(RequestMethod requestMethod, Response response) + { + switch (response.Status) + { + case 200: + case 201 when requestMethod == RequestMethod.Put: + case 204 when requestMethod != RequestMethod.Put && requestMethod != RequestMethod.Patch: + return OperationState.Success(response); + default: + return OperationState.Failure(response); + } + } + + private void UpdateNextRequestUri(ResponseHeaders headers) + { + var hasLocation = headers.TryGetValue("Location", out string? location); + if (hasLocation) + { + _lastKnownLocation = location; + } + + switch (_headerSource) + { + case HeaderSource.OperationLocation when headers.TryGetValue("Operation-Location", out string? operationLocation): + _nextRequestUri = AppendOrReplaceApiVersion(operationLocation, _apiVersion); + return; + case HeaderSource.AzureAsyncOperation when headers.TryGetValue("Azure-AsyncOperation", out string? azureAsyncOperation): + _nextRequestUri = AppendOrReplaceApiVersion(azureAsyncOperation, _apiVersion); + return; + case HeaderSource.Location when hasLocation: + _nextRequestUri = AppendOrReplaceApiVersion(location!, _apiVersion); + return; + } + } + + internal static string AppendOrReplaceApiVersion(string uri, string? apiVersion) + { + if (!string.IsNullOrEmpty(apiVersion)) + { + var uriSpan = uri.AsSpan(); + var apiVersionParamSpan = ApiVersionParam.AsSpan(); + var apiVersionIndex = uriSpan.IndexOf(apiVersionParamSpan); + if (apiVersionIndex == -1) + { + var concatSymbol = uriSpan.IndexOf('?') > -1 ? "&" : "?"; + return $"{uri}{concatSymbol}api-version={apiVersion}"; + } + else + { + var lengthToEndOfApiVersionParam = apiVersionIndex + ApiVersionParam.Length; + ReadOnlySpan remaining = uriSpan.Slice(lengthToEndOfApiVersionParam); + bool apiVersionHasEqualSign = false; + if (remaining.IndexOf('=') == 0) + { + remaining = remaining.Slice(1); + lengthToEndOfApiVersionParam += 1; + apiVersionHasEqualSign = true; + } + var indexOfFirstSignAfterApiVersion = remaining.IndexOf('&'); + ReadOnlySpan uriBeforeApiVersion = uriSpan.Slice(0, lengthToEndOfApiVersionParam); + if (indexOfFirstSignAfterApiVersion == -1) + { + return string.Concat(uriBeforeApiVersion.ToString(), apiVersionHasEqualSign ? string.Empty : "=", apiVersion); + } + else + { + ReadOnlySpan uriAfterApiVersion = uriSpan.Slice(indexOfFirstSignAfterApiVersion + lengthToEndOfApiVersionParam); + return string.Concat(uriBeforeApiVersion.ToString(), apiVersionHasEqualSign ? string.Empty : "=", apiVersion, uriAfterApiVersion.ToString()); + } + } + } + return uri; + } + + internal static bool TryGetApiVersion(Uri startRequestUri, out ReadOnlySpan apiVersion) + { + apiVersion = default; + ReadOnlySpan uriSpan = startRequestUri.Query.AsSpan(); + int startIndex = uriSpan.IndexOf(ApiVersionParam.AsSpan()); + if (startIndex == -1) + { + return false; + } + startIndex += ApiVersionParam.Length; + ReadOnlySpan remaining = uriSpan.Slice(startIndex); + if (remaining.IndexOf('=') == 0) + { + remaining = remaining.Slice(1); + startIndex += 1; + } + else + { + return false; + } + int endIndex = remaining.IndexOf('&'); + int length = endIndex == -1 ? uriSpan.Length - startIndex : endIndex; + apiVersion = uriSpan.Slice(startIndex, length); + return true; + } + + /// + /// This function is used to get the final request uri after the lro has completed. + /// + private string? GetFinalUri(string? resourceLocation) + { + // Set final uri as null if the response for initial request doesn't contain header "Operation-Location" or "Azure-AsyncOperation". + if (_headerSource is not (HeaderSource.OperationLocation or HeaderSource.AzureAsyncOperation)) + { + return null; + } + + // Set final uri as null if initial request is a delete method. + if (_requestMethod == RequestMethod.Delete) + { + return null; + } + + // Handle final-state-via options: https://github.com/Azure/autorest/blob/main/docs/extensions/readme.md#x-ms-long-running-operation-options + switch (_finalStateVia) + { + case OperationFinalStateVia.LocationOverride when _originalResponseHasLocation: + return _lastKnownLocation; + case OperationFinalStateVia.OperationLocation or OperationFinalStateVia.AzureAsyncOperation when _requestMethod == RequestMethod.Post: + return null; + case OperationFinalStateVia.OriginalUri: + return _startRequestUri.AbsoluteUri; + } + + // If response body contains resourceLocation, use it: https://github.com/microsoft/api-guidelines/blob/vNext/Guidelines.md#target-resource-location + if (resourceLocation != null) + { + return resourceLocation; + } + + // If initial request is PUT or PATCH, return initial request Uri + if (_requestMethod == RequestMethod.Put || _requestMethod == RequestMethod.Patch) + { + return _startRequestUri.AbsoluteUri; + } + + // If response for initial request contains header "Location", return last known location + if (_originalResponseHasLocation) + { + return _lastKnownLocation; + } + + return null; + } + + private async ValueTask GetResponseAsync(bool async, string uri, CancellationToken cancellationToken) + { + using HttpMessage message = CreateRequest(uri); + if (async) + { + await _pipeline.SendAsync(message, cancellationToken).ConfigureAwait(false); + } + else + { + _pipeline.Send(message, cancellationToken); + } + return message.Response; + } + + private HttpMessage CreateRequest(string uri) + { + HttpMessage message = _pipeline.CreateMessage(); + Request request = message.Request; + request.Method = RequestMethod.Get; + + if (Uri.TryCreate(uri, UriKind.Absolute, out var nextLink) && nextLink.Scheme != "file") + { + request.Uri.Reset(nextLink); + } + else + { + request.Uri.Reset(new Uri(_startRequestUri, uri)); + } + + return message; + } + + private static bool IsFinalState(Response response, HeaderSource headerSource, out OperationState? failureState, out string? resourceLocation) + { + failureState = null; + resourceLocation = null; + + if (headerSource == HeaderSource.Location) + { + return response.Status != 202; + } + + if (response.Status is >= 200 and <= 204) + { + if (response.ContentStream is {Length: > 0}) + { + try + { + using JsonDocument document = JsonDocument.Parse(response.ContentStream); + var root = document.RootElement; + switch (headerSource) + { + case HeaderSource.None when root.TryGetProperty("properties", out var properties) && properties.TryGetProperty("provisioningState", out JsonElement property): + case HeaderSource.OperationLocation when root.TryGetProperty("status", out property): + case HeaderSource.AzureAsyncOperation when root.TryGetProperty("status", out property): + var state = property.GetRequiredString().ToLowerInvariant(); + if (FailureStates.Contains(state)) + { + failureState = OperationState.Failure(response); + return true; + } + else if (!SuccessStates.Contains(state)) + { + return false; + } + else + { + if (headerSource is HeaderSource.OperationLocation or HeaderSource.AzureAsyncOperation && root.TryGetProperty("resourceLocation", out var resourceLocationProperty)) + { + resourceLocation = resourceLocationProperty.GetString(); + } + return true; + } + } + } + finally + { + // It is required to reset the position of the content after reading as this response may be used for deserialization. + response.ContentStream.Position = 0; + } + } + + // If headerSource is None and provisioningState was not found, it defaults to Succeeded. + if (headerSource == HeaderSource.None) + { + return true; + } + } + + failureState = OperationState.Failure(response); + return true; + } + + private static bool ShouldIgnoreHeader(RequestMethod method, Response response) + => method.Method == RequestMethod.Patch.Method && response.Status == 200; + + private static HeaderSource GetHeaderSource(RequestMethod requestMethod, Uri requestUri, Response response, string? apiVersion, out string nextRequestUri) + { + if (ShouldIgnoreHeader(requestMethod, response)) + { + nextRequestUri = requestUri.AbsoluteUri; + return HeaderSource.None; + } + + var headers = response.Headers; + if (headers.TryGetValue("Operation-Location", out var operationLocationUri)) + { + nextRequestUri = AppendOrReplaceApiVersion(operationLocationUri, apiVersion); + return HeaderSource.OperationLocation; + } + + if (headers.TryGetValue("Azure-AsyncOperation", out var azureAsyncOperationUri)) + { + nextRequestUri = AppendOrReplaceApiVersion(azureAsyncOperationUri, apiVersion); + return HeaderSource.AzureAsyncOperation; + } + + if (headers.TryGetValue("Location", out var locationUri)) + { + nextRequestUri = AppendOrReplaceApiVersion(locationUri, apiVersion); + return HeaderSource.Location; + } + + nextRequestUri = requestUri.AbsoluteUri; + return HeaderSource.None; + } + + private enum HeaderSource + { + None, + OperationLocation, + AzureAsyncOperation, + Location + } + + private class CompletedOperation : IOperation + { + private readonly OperationState _operationState; + + public CompletedOperation(OperationState operationState) + { + _operationState = operationState; + } + + public ValueTask UpdateStateAsync(bool async, CancellationToken cancellationToken) => new(_operationState); + } + + private sealed class OperationToOperationOfT : IOperation + { + private readonly IOperationSource _operationSource; + private readonly IOperation _operation; + + public OperationToOperationOfT(IOperationSource operationSource, IOperation operation) + { + _operationSource = operationSource; + _operation = operation; + } + + public async ValueTask> UpdateStateAsync(bool async, CancellationToken cancellationToken) + { + var state = await _operation.UpdateStateAsync(async, cancellationToken).ConfigureAwait(false); + if (state.HasSucceeded) + { + var result = async + ? await _operationSource.CreateResultAsync(state.RawResponse, cancellationToken).ConfigureAwait(false) + : _operationSource.CreateResult(state.RawResponse, cancellationToken); + + return OperationState.Success(state.RawResponse, result); + } + + if (state.HasCompleted) + { + return OperationState.Failure(state.RawResponse, state.OperationFailedException); + } + + return OperationState.Pending(state.RawResponse); + } + } + } +} diff --git a/sdk/core/Azure.Core/src/Shared/OperationFinalStateVia.cs b/sdk/core/Azure.Core/src/Shared/OperationFinalStateVia.cs new file mode 100644 index 000000000000..8ad2396db952 --- /dev/null +++ b/sdk/core/Azure.Core/src/Shared/OperationFinalStateVia.cs @@ -0,0 +1,16 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#nullable enable + +namespace Azure.Core +{ + internal enum OperationFinalStateVia + { + AzureAsyncOperation, + Location, + OriginalUri, + OperationLocation, + LocationOverride, + } +} diff --git a/sdk/core/Azure.Core/src/Shared/Optional.cs b/sdk/core/Azure.Core/src/Shared/Optional.cs new file mode 100644 index 000000000000..4e1dade9d9e0 --- /dev/null +++ b/sdk/core/Azure.Core/src/Shared/Optional.cs @@ -0,0 +1,111 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#nullable disable + +using System.Collections.Generic; +using System.Linq; +using System.Text.Json; + +namespace Azure.Core +{ + internal static class Optional + { + public static bool IsCollectionDefined(IEnumerable collection) + { + return !(collection is ChangeTrackingList changeTrackingList && changeTrackingList.IsUndefined); + } + + public static bool IsCollectionDefined(IReadOnlyDictionary collection) + { + return !(collection is ChangeTrackingDictionary changeTrackingList && changeTrackingList.IsUndefined); + } + + public static bool IsCollectionDefined(IDictionary collection) + { + return !(collection is ChangeTrackingDictionary changeTrackingList && changeTrackingList.IsUndefined); + } + + public static bool IsDefined(T? value) where T: struct + { + return value.HasValue; + } + public static bool IsDefined(object value) + { + return value != null; + } + public static bool IsDefined(string value) + { + return value != null; + } + + public static bool IsDefined(JsonElement value) + { + return value.ValueKind != JsonValueKind.Undefined; + } + + public static IReadOnlyDictionary ToDictionary(Optional> optional) + { + if (optional.HasValue) + { + return optional.Value; + } + return new ChangeTrackingDictionary(optional); + } + + public static IDictionary ToDictionary(Optional> optional) + { + if (optional.HasValue) + { + return optional.Value; + } + return new ChangeTrackingDictionary(optional); + } + public static IReadOnlyList ToList(Optional> optional) + { + if (optional.HasValue) + { + return optional.Value; + } + return new ChangeTrackingList(optional); + } + + public static IList ToList(Optional> optional) + { + if (optional.HasValue) + { + return optional.Value; + } + return new ChangeTrackingList(optional); + } + + public static T? ToNullable(Optional optional) where T: struct + { + if (optional.HasValue) + { + return optional.Value; + } + return default; + } + + public static T? ToNullable(Optional optional) where T: struct + { + return optional.Value; + } + } + + internal readonly partial struct Optional + { + public Optional(T value) : this() + { + Value = value; + HasValue = true; + } + + public T Value { get; } + public bool HasValue { get; } + + public static implicit operator Optional(T value) => new Optional(value); + public static implicit operator T(Optional optional) => optional.Value; + } +} diff --git a/sdk/core/Azure.Core/src/Shared/Page.cs b/sdk/core/Azure.Core/src/Shared/Page.cs new file mode 100644 index 000000000000..1438fdfd5f04 --- /dev/null +++ b/sdk/core/Azure.Core/src/Shared/Page.cs @@ -0,0 +1,16 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#nullable enable + +using System.Collections.Generic; +using System.Linq; + +namespace Azure.Core +{ + internal static class Page + { + public static Page FromValues(IEnumerable values, string continuationToken, Response response) => + Page.FromValues(values.ToList(), continuationToken, response); + } +} diff --git a/sdk/core/Azure.Core/src/Shared/PageableHelpers.cs b/sdk/core/Azure.Core/src/Shared/PageableHelpers.cs new file mode 100644 index 000000000000..0e4d456aff04 --- /dev/null +++ b/sdk/core/Azure.Core/src/Shared/PageableHelpers.cs @@ -0,0 +1,556 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#nullable enable + +using System; +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using System.Text; +using System.Text.Json; +using System.Threading; +using System.Threading.Tasks; +using Azure.Core.Pipeline; + +namespace Azure.Core +{ + internal static class PageableHelpers + { + private static readonly byte[] DefaultItemPropertyName = Encoding.UTF8.GetBytes("value"); + private static readonly byte[] DefaultNextLinkPropertyName = Encoding.UTF8.GetBytes("nextLink"); + + public static AsyncPageable CreateAsyncPageable(Func? createFirstPageRequest, Func? createNextPageRequest, Func? Values, string? NextLink)> responseParser, ClientDiagnostics clientDiagnostics, HttpPipeline pipeline, string scopeName, RequestContext? requestContext = null) where T : notnull + { + return new AsyncPageableWrapper(new PageableImplementation(createFirstPageRequest, createNextPageRequest, responseParser, pipeline, clientDiagnostics, scopeName, null, requestContext)); + } + + public static AsyncPageable CreateAsyncPageable(Func? createFirstPageRequest, Func? createNextPageRequest, Func valueFactory, ClientDiagnostics clientDiagnostics, HttpPipeline pipeline, string scopeName, string? itemPropertyName, string? nextLinkPropertyName, CancellationToken cancellationToken) where T : notnull + { + return new AsyncPageableWrapper(new PageableImplementation(null, createFirstPageRequest, createNextPageRequest, valueFactory, pipeline, clientDiagnostics, scopeName, itemPropertyName, nextLinkPropertyName, null, cancellationToken, null)); + } + + public static AsyncPageable CreateAsyncPageable(Func? createFirstPageRequest, Func? createNextPageRequest, Func valueFactory, ClientDiagnostics clientDiagnostics, HttpPipeline pipeline, string scopeName, string? itemPropertyName, string? nextLinkPropertyName, RequestContext? requestContext = null) where T : notnull + { + return new AsyncPageableWrapper(new PageableImplementation(null, createFirstPageRequest, createNextPageRequest, valueFactory, pipeline, clientDiagnostics, scopeName, itemPropertyName, nextLinkPropertyName, null, requestContext?.CancellationToken, requestContext?.ErrorOptions)); + } + + public static AsyncPageable CreateAsyncPageable(Response initialResponse, Func? createNextPageRequest, Func valueFactory, ClientDiagnostics clientDiagnostics, HttpPipeline pipeline, string scopeName, string? itemPropertyName, string? nextLinkPropertyName, CancellationToken cancellationToken) where T : notnull + { + return new AsyncPageableWrapper(new PageableImplementation(initialResponse, null, createNextPageRequest, valueFactory, pipeline, clientDiagnostics, scopeName, itemPropertyName, nextLinkPropertyName, null, cancellationToken, null)); + } + + public static Pageable CreatePageable(Func? createFirstPageRequest, Func? createNextPageRequest, Func? Values, string? NextLink)> responseParser, ClientDiagnostics clientDiagnostics, HttpPipeline pipeline, string scopeName, RequestContext? requestContext = null) where T : notnull + { + return new PageableWrapper(new PageableImplementation(createFirstPageRequest, createNextPageRequest, responseParser, pipeline, clientDiagnostics, scopeName, null, requestContext)); + } + + public static Pageable CreatePageable(Func? createFirstPageRequest, Func? createNextPageRequest, Func valueFactory, ClientDiagnostics clientDiagnostics, HttpPipeline pipeline, string scopeName, string? itemPropertyName, string? nextLinkPropertyName, CancellationToken cancellationToken) where T : notnull + { + return new PageableWrapper(new PageableImplementation(null, createFirstPageRequest, createNextPageRequest, valueFactory, pipeline, clientDiagnostics, scopeName, itemPropertyName, nextLinkPropertyName, null, cancellationToken, null)); + } + + public static Pageable CreatePageable(Func? createFirstPageRequest, Func? createNextPageRequest, Func valueFactory, ClientDiagnostics clientDiagnostics, HttpPipeline pipeline, string scopeName, string? itemPropertyName, string? nextLinkPropertyName, RequestContext? requestContext = null) where T : notnull + { + return new PageableWrapper(new PageableImplementation(null, createFirstPageRequest, createNextPageRequest, valueFactory, pipeline, clientDiagnostics, scopeName, itemPropertyName, nextLinkPropertyName, null, requestContext?.CancellationToken, requestContext?.ErrorOptions)); + } + + public static Pageable CreatePageable(Response initialResponse, Func? createNextPageRequest, Func valueFactory, ClientDiagnostics clientDiagnostics, HttpPipeline pipeline, string scopeName, string? itemPropertyName, string? nextLinkPropertyName, CancellationToken cancellationToken) where T : notnull + { + return new PageableWrapper(new PageableImplementation(initialResponse, null, createNextPageRequest, valueFactory, pipeline, clientDiagnostics, scopeName, itemPropertyName, nextLinkPropertyName, null, cancellationToken, null)); + } + + public static async ValueTask>> CreateAsyncPageable(WaitUntil waitUntil, HttpMessage message, Func? createNextPageMethod, Func valueFactory, ClientDiagnostics clientDiagnostics, HttpPipeline pipeline, OperationFinalStateVia finalStateVia, string scopeName, string? itemPropertyName, string? nextLinkPropertyName, RequestContext? requestContext = null) where T : notnull + { + AsyncPageable ResultSelector(Response r) => new AsyncPageableWrapper(new PageableImplementation(r, null, createNextPageMethod, valueFactory, pipeline, clientDiagnostics, scopeName, itemPropertyName, nextLinkPropertyName, null, requestContext?.CancellationToken, requestContext?.ErrorOptions)); + + var response = await pipeline.ProcessMessageAsync(message, requestContext).ConfigureAwait(false); + var operation = new ProtocolOperation>(clientDiagnostics, pipeline, message.Request, response, finalStateVia, scopeName, ResultSelector); + if (waitUntil == WaitUntil.Completed) + { + await operation.WaitForCompletionAsync(requestContext?.CancellationToken ?? default).ConfigureAwait(false); + } + return operation; + } + + public static Operation> CreatePageable(WaitUntil waitUntil, HttpMessage message, Func? createNextPageMethod, Func valueFactory, ClientDiagnostics clientDiagnostics, HttpPipeline pipeline, OperationFinalStateVia finalStateVia, string scopeName, string? itemPropertyName, string? nextLinkPropertyName, RequestContext? requestContext = null) where T : notnull + { + Pageable ResultSelector(Response r) => new PageableWrapper(new PageableImplementation(r, null, createNextPageMethod, valueFactory, pipeline, clientDiagnostics, scopeName, itemPropertyName, nextLinkPropertyName, null, requestContext?.CancellationToken, requestContext?.ErrorOptions)); + + var response = pipeline.ProcessMessage(message, requestContext); + var operation = new ProtocolOperation>(clientDiagnostics, pipeline, message.Request, response, finalStateVia, scopeName, ResultSelector); + if (waitUntil == WaitUntil.Completed) + { + operation.WaitForCompletion(requestContext?.CancellationToken ?? default); + } + return operation; + } + + public static Pageable CreateEnumerable(Func> firstPageFunc, Func>? nextPageFunc, int? pageSize = default) where T : notnull + { + Func> first = (_, pageSizeHint) => firstPageFunc(pageSizeHint); + return new FuncPageable(first, nextPageFunc, pageSize); + } + + public static AsyncPageable CreateAsyncEnumerable(Func>> firstPageFunc, Func>>? nextPageFunc, int? pageSize = default) where T : notnull + { + Func>> first = (_, pageSizeHint) => firstPageFunc(pageSizeHint); + return new FuncAsyncPageable(first, nextPageFunc, pageSize); + } + + internal class FuncAsyncPageable : AsyncPageable where T : notnull + { + private readonly Func>> _firstPageFunc; + private readonly Func>>? _nextPageFunc; + private readonly int? _defaultPageSize; + + public FuncAsyncPageable(Func>> firstPageFunc, Func>>? nextPageFunc, int? defaultPageSize = default) + { + _firstPageFunc = firstPageFunc; + _nextPageFunc = nextPageFunc; + _defaultPageSize = defaultPageSize; + } + + public override async IAsyncEnumerable> AsPages(string? continuationToken = default, int? pageSizeHint = default) + { + Func>>? pageFunc = string.IsNullOrEmpty(continuationToken) ? _firstPageFunc : _nextPageFunc; + + if (pageFunc == null) + { + yield break; + } + + int? pageSize = pageSizeHint ?? _defaultPageSize; + do + { + Page pageResponse = await pageFunc(continuationToken, pageSize).ConfigureAwait(false); + yield return pageResponse; + continuationToken = pageResponse.ContinuationToken; + pageFunc = _nextPageFunc; + } while (!string.IsNullOrEmpty(continuationToken) && pageFunc != null); + } + } + + internal class FuncPageable : Pageable where T : notnull + { + private readonly Func> _firstPageFunc; + private readonly Func>? _nextPageFunc; + private readonly int? _defaultPageSize; + + public FuncPageable(Func> firstPageFunc, Func>? nextPageFunc, int? defaultPageSize = default) + { + _firstPageFunc = firstPageFunc; + _nextPageFunc = nextPageFunc; + _defaultPageSize = defaultPageSize; + } + + public override IEnumerable> AsPages(string? continuationToken = default, int? pageSizeHint = default) + { + Func>? pageFunc = string.IsNullOrEmpty(continuationToken) ? _firstPageFunc : _nextPageFunc; + + if (pageFunc == null) + { + yield break; + } + + int? pageSize = pageSizeHint ?? _defaultPageSize; + do + { + Page pageResponse = pageFunc(continuationToken, pageSize); + yield return pageResponse; + continuationToken = pageResponse.ContinuationToken; + pageFunc = _nextPageFunc; + } while (!string.IsNullOrEmpty(continuationToken) && pageFunc != null); + } + } + + private class AsyncPageableWrapper : AsyncPageable where T : notnull + { + private readonly PageableImplementation _implementation; + + public AsyncPageableWrapper(PageableImplementation implementation) + { + _implementation = implementation; + } + + public override IAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) => _implementation.GetAsyncEnumerator(cancellationToken); + public override IAsyncEnumerable> AsPages(string? continuationToken = null, int? pageSizeHint = null) => _implementation.AsPagesAsync(continuationToken, pageSizeHint, default); + } + + private class PageableWrapper : Pageable where T : notnull + { + private readonly PageableImplementation _implementation; + + public PageableWrapper(PageableImplementation implementation) + { + _implementation = implementation; + } + + public override IEnumerator GetEnumerator() => _implementation.GetEnumerator(); + public override IEnumerable> AsPages(string? continuationToken = null, int? pageSizeHint = null) => _implementation.AsPages(continuationToken, pageSizeHint); + } + + private class PageableImplementation + { + private readonly Response? _initialResponse; + private readonly Func? _createFirstPageRequest; + private readonly Func? _createNextPageRequest; + private readonly HttpPipeline _pipeline; + private readonly ClientDiagnostics _clientDiagnostics; + private readonly Func? _valueFactory; + private readonly Func? Values, string? NextLink)>? _responseParser; + private readonly string _scopeName; + private readonly byte[] _itemPropertyName; + private readonly byte[] _nextLinkPropertyName; + private readonly int? _defaultPageSize; + private readonly CancellationToken _cancellationToken; + private readonly ErrorOptions? _errorOptions; + + public PageableImplementation( + Response? initialResponse, + Func? createFirstPageRequest, + Func? createNextPageRequest, + Func valueFactory, + HttpPipeline pipeline, + ClientDiagnostics clientDiagnostics, + string scopeName, + string? itemPropertyName, + string? nextLinkPropertyName, + int? defaultPageSize, + CancellationToken? cancellationToken, + ErrorOptions? errorOptions) + { + _initialResponse = initialResponse; + _createFirstPageRequest = createFirstPageRequest; + _createNextPageRequest = createNextPageRequest; + _valueFactory = typeof(T) == typeof(BinaryData) ? null : valueFactory; + _responseParser = null; + _pipeline = pipeline; + _clientDiagnostics = clientDiagnostics; + _scopeName = scopeName; + _itemPropertyName = itemPropertyName != null ? Encoding.UTF8.GetBytes(itemPropertyName) : DefaultItemPropertyName; + _nextLinkPropertyName = nextLinkPropertyName != null ? Encoding.UTF8.GetBytes(nextLinkPropertyName) : DefaultNextLinkPropertyName; + _defaultPageSize = defaultPageSize; + _cancellationToken = cancellationToken ?? default; + _errorOptions = errorOptions ?? ErrorOptions.Default; + } + + public PageableImplementation(Func? createFirstPageRequest, Func? createNextPageRequest, Func? Values, string? NextLink)> responseParser, HttpPipeline pipeline, ClientDiagnostics clientDiagnostics, string scopeName, int? defaultPageSize, RequestContext? requestContext) + { + _createFirstPageRequest = createFirstPageRequest; + _createNextPageRequest = createNextPageRequest; + _valueFactory = null; + _responseParser = responseParser; + _pipeline = pipeline; + _clientDiagnostics = clientDiagnostics; + _scopeName = scopeName; + _itemPropertyName = Array.Empty(); + _nextLinkPropertyName = Array.Empty(); + _defaultPageSize = defaultPageSize; + _cancellationToken = requestContext?.CancellationToken ?? default; + _errorOptions = requestContext?.ErrorOptions ?? ErrorOptions.Default; + } + + public async IAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) + { + string? nextLink = null; + do + { + var response = await GetNextResponseAsync(null, nextLink, cancellationToken).ConfigureAwait(false); + if (!TryGetItemsFromResponse(response, out nextLink, out var jsonArray, out var items)) + { + continue; + } + + if (_valueFactory != null) + { + foreach (var jsonItem in jsonArray) + { + yield return _valueFactory(jsonItem); + } + } + else + { + foreach (var item in items!) + { + yield return item; + } + } + } while (!string.IsNullOrEmpty(nextLink)); + } + + public async IAsyncEnumerable> AsPagesAsync(string? continuationToken, int? pageSizeHint, [EnumeratorCancellation] CancellationToken cancellationToken) + { + string? nextLink = continuationToken; + do + { + var response = await GetNextResponseAsync(pageSizeHint, nextLink, cancellationToken).ConfigureAwait(false); + if (response is null) + { + yield break; + } + yield return CreatePage(response, out nextLink); + } while (!string.IsNullOrEmpty(nextLink)); + } + + public IEnumerator GetEnumerator() + { + string? nextLink = null; + do + { + var response = GetNextResponse(null, nextLink); + if (!TryGetItemsFromResponse(response, out nextLink, out var jsonArray, out var items)) + { + continue; + } + + if (_valueFactory != null) + { + foreach (var jsonItem in jsonArray) + { + yield return _valueFactory(jsonItem); + } + } + else + { + foreach (var item in items!) + { + yield return item; + } + } + } while (!string.IsNullOrEmpty(nextLink)); + } + + public IEnumerable> AsPages(string? continuationToken, int? pageSizeHint) + { + string? nextLink = continuationToken; + do + { + var response = GetNextResponse(pageSizeHint, nextLink); + if (response is null) + { + yield break; + } + yield return CreatePage(response, out nextLink); + } while (!string.IsNullOrEmpty(nextLink)); + } + + private Response? GetNextResponse(int? pageSizeHint, string? nextLink) + { + var message = CreateMessage(pageSizeHint, nextLink, out var response); + if (message == null) + { + return response; + } + + using DiagnosticScope scope = _clientDiagnostics.CreateScope(_scopeName); + scope.Start(); + try + { + _pipeline.Send(message, _cancellationToken); + return GetResponse(message); + } + catch (Exception e) + { + scope.Failed(e); + throw; + } + } + + private async ValueTask GetNextResponseAsync(int? pageSizeHint, string? nextLink, CancellationToken cancellationToken) + { + var message = CreateMessage(pageSizeHint, nextLink, out var response); + if (message == null) + { + return response; + } + + using DiagnosticScope scope = _clientDiagnostics.CreateScope(_scopeName); + scope.Start(); + try + { + if (cancellationToken.CanBeCanceled && _cancellationToken.CanBeCanceled) + { + using var cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, _cancellationToken); + await _pipeline.SendAsync(message, cts.Token).ConfigureAwait(false); + } + else + { + var ct = cancellationToken.CanBeCanceled ? cancellationToken : _cancellationToken; + await _pipeline.SendAsync(message, ct).ConfigureAwait(false); + } + + return GetResponse(message); + } + catch (Exception e) + { + scope.Failed(e); + throw; + } + } + + private HttpMessage? CreateMessage(int? pageSizeHint, string? nextLink, out Response? response) + { + if (!string.IsNullOrEmpty(nextLink)) + { + response = null; + return _createNextPageRequest?.Invoke(pageSizeHint ?? _defaultPageSize, nextLink!); + } + + if (_createFirstPageRequest == null) + { + response = _initialResponse; + return null; + } + + response = null; + return _createFirstPageRequest(pageSizeHint ?? _defaultPageSize); + } + + private Response GetResponse(HttpMessage message) + { + if (message.Response.IsError && _errorOptions != ErrorOptions.NoThrow) + { + throw new RequestFailedException(message.Response); + } + + return message.Response; + } + + // Tries to parse response either using default logic or by using custom parser + // Returns true when either jsonArrayEnumerator is not default or items is not null + private bool TryGetItemsFromResponse(Response? response, out string? nextLink, out JsonElement.ArrayEnumerator jsonArrayEnumerator, out List? items) + { + if (response is null) + { + nextLink = default; + jsonArrayEnumerator = default; + items = default; + return false; + } + + if (_valueFactory is not null) + { + items = default; + var document = response.ContentStream != null ? JsonDocument.Parse(response.ContentStream) : JsonDocument.Parse(response.Content); + if (_createNextPageRequest is null && _itemPropertyName.Length == 0) // Pageable is a simple collection of elements + { + nextLink = null; + jsonArrayEnumerator = document.RootElement.EnumerateArray(); + return true; + } + + nextLink = document.RootElement.TryGetProperty(_nextLinkPropertyName, out var nextLinkValue) ? nextLinkValue.GetString() : null; + if (document.RootElement.TryGetProperty(_itemPropertyName, out var itemsValue)) + { + jsonArrayEnumerator = itemsValue.EnumerateArray(); + return true; + } + + jsonArrayEnumerator = default; + return false; + } + + jsonArrayEnumerator = default; + // _responseParser will be null when T is BinaryData + var parsedResponse = _responseParser?.Invoke(response) ?? ParseResponseForBinaryData(response, _itemPropertyName, _nextLinkPropertyName); + items = parsedResponse.Values; + nextLink = parsedResponse.NextLink; + return items is not null; + } + + private Page CreatePage(Response response, out string? nextLink) + { + if (!TryGetItemsFromResponse(response, out nextLink, out var jsonArray, out var items)) + { + return Page.FromValues(Array.Empty(), nextLink, response); + } + + if (_valueFactory == null) + { + return Page.FromValues(items!, nextLink, response); + } + + var values = new List(); + foreach (var jsonItem in jsonArray) + { + values.Add(_valueFactory(jsonItem)); + } + + return Page.FromValues(values, nextLink, response); + } + } + + // This method is used to avoid calling _valueFactory for BinaryData cause it requires instantiation of strings. + // Remove it when `JsonElement` provides access to its UTF8 buffer. + // See also PageableMethodsWriterExtensions.GetValueFactory + private static (List? Values, string? NextLink) ParseResponseForBinaryData(Response response, byte[] itemPropertyName, byte[] nextLinkPropertyName) + { + var content = response.Content.ToMemory(); + var r = new Utf8JsonReader(content.Span); + + List? items = null; + string? nextLink = null; + + if (!r.Read() || r.TokenType != JsonTokenType.StartObject) + { + throw new InvalidOperationException("Expected response to be JSON object"); + } + + while (r.Read()) + { + switch (r.TokenType) + { + case JsonTokenType.PropertyName: + if (r.ValueTextEquals(nextLinkPropertyName)) + { + r.Read(); + nextLink = r.GetString(); + } + else if (r.ValueTextEquals(itemPropertyName)) + { + if (!r.Read() || r.TokenType != JsonTokenType.StartArray) + { + throw new InvalidOperationException($"Expected {Encoding.UTF8.GetString(itemPropertyName)} to be an array"); + } + + while (r.Read() && r.TokenType != JsonTokenType.EndArray) + { + var element = ReadBinaryData(ref r, content); + items ??= new List(); + items.Add((T)element); + } + } + else + { + r.Skip(); + } + break; + case JsonTokenType.EndObject: + break; + + default: + throw new Exception("Unexpected token"); + } + } + + return (items, nextLink); + + static object ReadBinaryData(ref Utf8JsonReader r, in ReadOnlyMemory content) + { + switch (r.TokenType) + { + case JsonTokenType.StartObject or JsonTokenType.StartArray: + int elementStart = (int)r.TokenStartIndex; + r.Skip(); + int elementEnd = (int)r.TokenStartIndex; + int length = elementEnd - elementStart + 1; + return new BinaryData(content.Slice(elementStart, length)); + case JsonTokenType.String: + return new BinaryData(content.Slice((int)r.TokenStartIndex, r.ValueSpan.Length + 2 /* open and closing quotes are not captured in the value span */)); + default: + return new BinaryData(content.Slice((int)r.TokenStartIndex, r.ValueSpan.Length)); + } + } + } + } +} diff --git a/sdk/core/Azure.Core/src/Shared/ProtocolOperation.cs b/sdk/core/Azure.Core/src/Shared/ProtocolOperation.cs new file mode 100644 index 000000000000..5673b36fa4cc --- /dev/null +++ b/sdk/core/Azure.Core/src/Shared/ProtocolOperation.cs @@ -0,0 +1,73 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#nullable enable + +using System; +using System.Threading; +using System.Threading.Tasks; +using Azure.Core.Pipeline; + +namespace Azure.Core +{ + internal class ProtocolOperation : Operation, IOperation where T : notnull + { + private readonly Func _resultSelector; + private readonly OperationInternal _operation; + private readonly IOperation _nextLinkOperation; + + internal ProtocolOperation(ClientDiagnostics clientDiagnostics, HttpPipeline pipeline, Request request, Response response, OperationFinalStateVia finalStateVia, string scopeName, Func resultSelector) + { + _resultSelector = resultSelector; + _nextLinkOperation = NextLinkOperationImplementation.Create(pipeline, request.Method, request.Uri.ToUri(), response, finalStateVia); + _operation = new OperationInternal(this, clientDiagnostics, response, scopeName); + } + +#pragma warning disable CA1822 + // This scenario is currently unsupported. + // See: https://github.com/Azure/autorest.csharp/issues/2158. + /// + public override string Id => throw new NotSupportedException(); +#pragma warning restore CA1822 + + /// + public override T Value => _operation.Value; + + /// + public override bool HasCompleted => _operation.HasCompleted; + + /// + public override bool HasValue => _operation.HasValue; + + /// + public override Response GetRawResponse() => _operation.RawResponse; + + /// + public override Response UpdateStatus(CancellationToken cancellationToken = default) => _operation.UpdateStatus(cancellationToken); + + /// + public override ValueTask UpdateStatusAsync(CancellationToken cancellationToken = default) => _operation.UpdateStatusAsync(cancellationToken); + + /// + public override ValueTask> WaitForCompletionAsync(CancellationToken cancellationToken = default) => _operation.WaitForCompletionAsync(cancellationToken); + + /// + public override ValueTask> WaitForCompletionAsync(TimeSpan pollingInterval, CancellationToken cancellationToken = default) => _operation.WaitForCompletionAsync(pollingInterval, cancellationToken); + + async ValueTask> IOperation.UpdateStateAsync(bool async, CancellationToken cancellationToken) + { + var state = await _nextLinkOperation.UpdateStateAsync(async, cancellationToken).ConfigureAwait(false); + if (state.HasSucceeded) + { + return OperationState.Success(state.RawResponse, _resultSelector(state.RawResponse)); + } + + if (state.HasCompleted) + { + return OperationState.Failure(state.RawResponse, state.OperationFailedException); + } + + return OperationState.Pending(state.RawResponse); + } + } +} diff --git a/sdk/core/Azure.Core/src/Shared/ProtocolOperationHelpers.cs b/sdk/core/Azure.Core/src/Shared/ProtocolOperationHelpers.cs new file mode 100644 index 000000000000..8e19fa6a25b6 --- /dev/null +++ b/sdk/core/Azure.Core/src/Shared/ProtocolOperationHelpers.cs @@ -0,0 +1,221 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#nullable enable + +using System; +using System.Threading; +using System.Threading.Tasks; +using Azure.Core.Pipeline; + +namespace Azure.Core +{ + internal static class ProtocolOperationHelpers + { + public static Operation Convert(Operation operation, Func convertFunc, ClientDiagnostics diagnostics, string scopeName) + where TFrom : notnull + where TTo : notnull + => new ConvertOperation(operation, diagnostics, scopeName, convertFunc); + + public static ValueTask> ProcessMessageWithoutResponseValueAsync(HttpPipeline pipeline, HttpMessage message, ClientDiagnostics clientDiagnostics, string scopeName, OperationFinalStateVia finalStateVia, RequestContext? requestContext, WaitUntil waitUntil) + => ProcessMessageAsync(pipeline, message, clientDiagnostics, scopeName, finalStateVia, requestContext, waitUntil, _ => new VoidValue()); + + public static Operation ProcessMessageWithoutResponseValue(HttpPipeline pipeline, HttpMessage message, ClientDiagnostics clientDiagnostics, string scopeName, OperationFinalStateVia finalStateVia, RequestContext? requestContext, WaitUntil waitUntil) + => ProcessMessage(pipeline, message, clientDiagnostics, scopeName, finalStateVia, requestContext, waitUntil, _ => new VoidValue()); + + public static ValueTask> ProcessMessageAsync(HttpPipeline pipeline, HttpMessage message, ClientDiagnostics clientDiagnostics, string scopeName, OperationFinalStateVia finalStateVia, RequestContext? requestContext, WaitUntil waitUntil) + => ProcessMessageAsync(pipeline, message, clientDiagnostics, scopeName, finalStateVia, requestContext, waitUntil, r => r.Content); + + public static Operation ProcessMessage(HttpPipeline pipeline, HttpMessage message, ClientDiagnostics clientDiagnostics, string scopeName, OperationFinalStateVia finalStateVia, RequestContext? requestContext, WaitUntil waitUntil) + => ProcessMessage(pipeline, message, clientDiagnostics, scopeName, finalStateVia, requestContext, waitUntil, r => r.Content); + + public static async ValueTask> ProcessMessageAsync(HttpPipeline pipeline, HttpMessage message, ClientDiagnostics clientDiagnostics, string scopeName, OperationFinalStateVia finalStateVia, RequestContext? requestContext, WaitUntil waitUntil, Func resultSelector) where T: notnull + { + var response = await pipeline.ProcessMessageAsync(message, requestContext).ConfigureAwait(false); + var operation = new ProtocolOperation(clientDiagnostics, pipeline, message.Request, response, finalStateVia, scopeName, resultSelector); + if (waitUntil == WaitUntil.Completed) + { + await operation.WaitForCompletionAsync(requestContext?.CancellationToken ?? default).ConfigureAwait(false); + } + return operation; + } + + public static Operation ProcessMessage(HttpPipeline pipeline, HttpMessage message, ClientDiagnostics clientDiagnostics, string scopeName, OperationFinalStateVia finalStateVia, RequestContext? requestContext, WaitUntil waitUntil, Func resultSelector) where T : notnull + { + var response = pipeline.ProcessMessage(message, requestContext); + var operation = new ProtocolOperation(clientDiagnostics, pipeline, message.Request, response, finalStateVia, scopeName, resultSelector); + if (waitUntil == WaitUntil.Completed) + { + operation.WaitForCompletion(requestContext?.CancellationToken ?? default); + } + return operation; + } + + private class ConvertOperation : Operation + where TFrom : notnull + where TTo : notnull + { + private readonly Operation _operation; + private readonly ClientDiagnostics _diagnostics; + private readonly string _waitForCompletionScopeName; + private readonly string _updateStatusScopeName; + private readonly Func _convertFunc; + private Response? _response; + + public ConvertOperation(Operation operation, ClientDiagnostics diagnostics, string operationName, Func convertFunc) + { + _operation = operation; + _diagnostics = diagnostics; + _waitForCompletionScopeName = $"{operationName}.{nameof(WaitForCompletion)}"; + _updateStatusScopeName = $"{operationName}.{nameof(UpdateStatus)}"; + _convertFunc = convertFunc; + _response = null; + } + + public override string Id => _operation.Id; + public override TTo Value => GetOrCreateValue(); + public override bool HasValue => _operation.HasValue; + public override bool HasCompleted => _operation.HasCompleted; + public override Response GetRawResponse() => _operation.GetRawResponse(); + + public override Response UpdateStatus(CancellationToken cancellationToken = default) + { + if (HasCompleted) + { + return GetRawResponse(); + } + + using var scope = CreateScope(_updateStatusScopeName); + try + { + return _operation.UpdateStatus(cancellationToken); + } + catch (Exception e) + { + scope.Failed(e); + throw; + } + } + + public override async ValueTask UpdateStatusAsync(CancellationToken cancellationToken = default) + { + if (HasCompleted) + { + return GetRawResponse(); + } + + using var scope = CreateScope(_updateStatusScopeName); + try + { + return await _operation.UpdateStatusAsync(cancellationToken).ConfigureAwait(false); + } + catch (Exception e) + { + scope.Failed(e); + throw; + } + } + + public override Response WaitForCompletion(CancellationToken cancellationToken = default) + { + if (_response != null) + { + return _response; + } + + using var scope = CreateScope(_waitForCompletionScopeName); + try + { + var result = _operation.WaitForCompletion(cancellationToken); + return CreateResponseOfTTo(result); + } + catch (Exception e) + { + scope.Failed(e); + throw; + } + } + + public override Response WaitForCompletion(TimeSpan pollingInterval, CancellationToken cancellationToken) + { + if (_response != null) + { + return _response; + } + + using var scope = CreateScope(_waitForCompletionScopeName); + try + { + var result = _operation.WaitForCompletion(pollingInterval, cancellationToken); + return CreateResponseOfTTo(result); + } + catch (Exception e) + { + scope.Failed(e); + throw; + } + } + + public override async ValueTask> WaitForCompletionAsync(CancellationToken cancellationToken = default) + { + if (_response != null) + { + return _response; + } + + using var scope = CreateScope(_waitForCompletionScopeName); + try + { + var result = await _operation.WaitForCompletionAsync(cancellationToken).ConfigureAwait(false); + return CreateResponseOfTTo(result); + } + catch (Exception e) + { + scope.Failed(e); + throw; + } + } + + public override async ValueTask> WaitForCompletionAsync(TimeSpan pollingInterval, CancellationToken cancellationToken) + { + if (_response != null) + { + return _response; + } + + using var scope = CreateScope(_waitForCompletionScopeName); + try + { + var result = await _operation.WaitForCompletionAsync(pollingInterval, cancellationToken).ConfigureAwait(false); + return CreateResponseOfTTo(result); + } + catch (Exception e) + { + scope.Failed(e); + throw; + } + } + + private TTo GetOrCreateValue() + => _response != null ? _response.Value : CreateResponseOfTTo(GetRawResponse()).Value; + + private Response CreateResponseOfTTo(Response responseTFrom) + => CreateResponseOfTTo(responseTFrom.GetRawResponse()); + + private Response CreateResponseOfTTo(Response rawResponse) + { + var value = _convertFunc(rawResponse); + var response = Response.FromValue(value, rawResponse); + Interlocked.CompareExchange(ref _response, response, null); + return _response; + } + + private DiagnosticScope CreateScope(string name) + { + var scope = _diagnostics.CreateScope(name); + scope.Start(); + return scope; + } + } + } +} diff --git a/sdk/core/Azure.Core/src/Shared/RawRequestUriBuilder.cs b/sdk/core/Azure.Core/src/Shared/RawRequestUriBuilder.cs new file mode 100644 index 000000000000..d9478175a728 --- /dev/null +++ b/sdk/core/Azure.Core/src/Shared/RawRequestUriBuilder.cs @@ -0,0 +1,192 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#nullable enable + +using System; +using System.Globalization; +using System.IO; + +namespace Azure.Core +{ + internal class RawRequestUriBuilder: RequestUriBuilder + { + private const string SchemeSeparator = "://"; + private const char HostSeparator = '/'; + private const char PortSeparator = ':'; + private static readonly char[] HostOrPort = { HostSeparator, PortSeparator }; + private const char QueryBeginSeparator = '?'; + private const char QueryContinueSeparator = '&'; + private const char QueryValueSeparator = '='; + + private RawWritingPosition? _position; + + private static void GetQueryParts(ReadOnlySpan queryUnparsed, out ReadOnlySpan name, out ReadOnlySpan value) + { + int separatorIndex = queryUnparsed.IndexOf(QueryValueSeparator); + if (separatorIndex == -1) + { + name = queryUnparsed; + value = ReadOnlySpan.Empty; + } + else + { + name = queryUnparsed.Slice(0, separatorIndex); + value = queryUnparsed.Slice(separatorIndex + 1); + } + } + + public void AppendRaw(string value, bool escape) + { + AppendRaw(value.AsSpan(), escape); + } + + private void AppendRaw(ReadOnlySpan value, bool escape) + { + if (_position == null) + { + if (HasQuery) + { + _position = RawWritingPosition.Query; + } + else if (HasPath) + { + _position = RawWritingPosition.Path; + } + else if (!string.IsNullOrEmpty(Host)) + { + _position = RawWritingPosition.Host; + } + else + { + _position = RawWritingPosition.Scheme; + } + } + + while (!value.IsEmpty) + { + if (_position == RawWritingPosition.Scheme) + { + int separator = value.IndexOf(SchemeSeparator.AsSpan(), StringComparison.InvariantCultureIgnoreCase); + if (separator == -1) + { + Scheme += value.ToString(); + value = ReadOnlySpan.Empty; + } + else + { + Scheme += value.Slice(0, separator).ToString(); + // TODO: Find a better way to map schemes to default ports + Port = string.Equals(Scheme, "https", StringComparison.OrdinalIgnoreCase) ? 443 : 80; + value = value.Slice(separator + SchemeSeparator.Length); + _position = RawWritingPosition.Host; + } + } + else if (_position == RawWritingPosition.Host) + { + int separator = value.IndexOfAny(HostOrPort); + if (separator == -1) + { + if (!HasPath) + { + Host += value.ToString(); + value = ReadOnlySpan.Empty; + } + else + { + // All Host information must be written before Path information + // If Path already has information, we transition to writing Path + _position = RawWritingPosition.Path; + } + } + else + { + Host += value.Slice(0, separator).ToString(); + _position = value[separator] == HostSeparator ? RawWritingPosition.Path : RawWritingPosition.Port; + value = value.Slice(separator + 1); + } + } + else if (_position == RawWritingPosition.Port) + { + int separator = value.IndexOf(HostSeparator); + if (separator == -1) + { +#if NETCOREAPP2_1_OR_GREATER + Port = int.Parse(value, NumberStyles.Integer, CultureInfo.InvariantCulture); +#else + Port = int.Parse(value.ToString(), CultureInfo.InvariantCulture); +#endif + value = ReadOnlySpan.Empty; + } + else + { +#if NETCOREAPP2_1_OR_GREATER + Port = int.Parse(value.Slice(0, separator), NumberStyles.Integer, CultureInfo.InvariantCulture); +#else + Port = int.Parse(value.Slice(0, separator).ToString(), CultureInfo.InvariantCulture); +#endif + value = value.Slice(separator + 1); + } + // Port cannot be split (like Host), so always transition to Path when Port is parsed + _position = RawWritingPosition.Path; + } + else if (_position == RawWritingPosition.Path) + { + int separator = value.IndexOf(QueryBeginSeparator); + if (separator == -1) + { + AppendPath(value, escape); + value = ReadOnlySpan.Empty; + } + else + { + AppendPath(value.Slice(0, separator), escape); + value = value.Slice(separator + 1); + _position = RawWritingPosition.Query; + } + } + else if (_position == RawWritingPosition.Query) + { + int separator = value.IndexOf(QueryContinueSeparator); + if (separator == 0) + { + value = value.Slice(1); + } + else if (separator == -1) + { + GetQueryParts(value, out var queryName, out var queryValue); + AppendQuery(queryName, queryValue, escape); + value = ReadOnlySpan.Empty; + } + else + { + GetQueryParts(value.Slice(0, separator), out var queryName, out var queryValue); + AppendQuery(queryName, queryValue, escape); + value = value.Slice(separator + 1); + } + } + } + } + + private enum RawWritingPosition + { + Scheme, + Host, + Port, + Path, + Query + } + + public void AppendRawNextLink(string nextLink, bool escape) + { + // If it is an absolute link, we use the nextLink as the entire url + if (nextLink.StartsWith(Uri.UriSchemeHttp, StringComparison.InvariantCultureIgnoreCase)) + { + Reset(new Uri(nextLink)); + return; + } + + AppendRaw(nextLink, escape); + } + } +} diff --git a/sdk/core/Azure.Core/src/Shared/RequestContentHelper.cs b/sdk/core/Azure.Core/src/Shared/RequestContentHelper.cs new file mode 100644 index 000000000000..2c4eb7961783 --- /dev/null +++ b/sdk/core/Azure.Core/src/Shared/RequestContentHelper.cs @@ -0,0 +1,91 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#nullable enable + +using System; +using System.Collections.Generic; +using System.Text.Json; + +namespace Azure.Core +{ + internal static class RequestContentHelper + { + public static RequestContent FromEnumerable(IEnumerable enumerable) where T: notnull + { + var content = new Utf8JsonRequestContent(); + content.JsonWriter.WriteStartArray(); + foreach (var item in enumerable) + { + content.JsonWriter.WriteObjectValue(item); + } + content.JsonWriter.WriteEndArray(); + + return content; + } + + public static RequestContent FromEnumerable(IEnumerable enumerable) + { + var content = new Utf8JsonRequestContent(); + content.JsonWriter.WriteStartArray(); + foreach (var item in enumerable) + { + if (item == null) + { + content.JsonWriter.WriteNullValue(); + } + else + { +#if NET6_0_OR_GREATER + content.JsonWriter.WriteRawValue(item); +#else + JsonSerializer.Serialize(content.JsonWriter, JsonDocument.Parse(item.ToString()).RootElement); +#endif + } + } + content.JsonWriter.WriteEndArray(); + + return content; + } + + public static RequestContent FromDictionary(IDictionary dictionary) where T : notnull + { + var content = new Utf8JsonRequestContent(); + content.JsonWriter.WriteStartObject(); + foreach (var item in dictionary) + { + content.JsonWriter.WritePropertyName(item.Key); + content.JsonWriter.WriteObjectValue(item.Value); + } + content.JsonWriter.WriteEndObject(); + + return content; + } + + public static RequestContent FromDictionary(IDictionary dictionary) + { + var content = new Utf8JsonRequestContent(); + content.JsonWriter.WriteStartObject(); + foreach (var item in dictionary) + { + content.JsonWriter.WritePropertyName(item.Key); + + if (item.Value == null) + { + content.JsonWriter.WriteNullValue(); + } + else + { +#if NET6_0_OR_GREATER + content.JsonWriter.WriteRawValue(item.Value); +#else + JsonSerializer.Serialize(content.JsonWriter, JsonDocument.Parse(item.Value.ToString()).RootElement); +#endif + } + } + content.JsonWriter.WriteEndObject(); + + return content; + } + } +} diff --git a/sdk/core/Azure.Core/src/Shared/RequestHeaderExtensions.cs b/sdk/core/Azure.Core/src/Shared/RequestHeaderExtensions.cs new file mode 100644 index 000000000000..012156c93687 --- /dev/null +++ b/sdk/core/Azure.Core/src/Shared/RequestHeaderExtensions.cs @@ -0,0 +1,126 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#nullable enable + +using System; +using System.Collections.Generic; +using System.Globalization; +using System.Linq; +using System.Xml; + +namespace Azure.Core +{ + internal static class RequestHeaderExtensions + { + public static void Add(this RequestHeaders headers, string name, bool value) + { + headers.Add(name, TypeFormatters.ToString(value)); + } + + public static void Add(this RequestHeaders headers, string name, float value) + { + headers.Add(name, value.ToString(TypeFormatters.DefaultNumberFormat, CultureInfo.InvariantCulture)); + } + + public static void Add(this RequestHeaders headers, string name, double value) + { + headers.Add(name, value.ToString(TypeFormatters.DefaultNumberFormat, CultureInfo.InvariantCulture)); + } + + public static void Add(this RequestHeaders headers, string name, int value) + { + headers.Add(name, value.ToString(TypeFormatters.DefaultNumberFormat, CultureInfo.InvariantCulture)); + } + + public static void Add(this RequestHeaders headers, string name, long value) + { + headers.Add(name, value.ToString(TypeFormatters.DefaultNumberFormat, CultureInfo.InvariantCulture)); + } + + public static void Add(this RequestHeaders headers, string name, DateTimeOffset value, string format) + { + headers.Add(name, TypeFormatters.ToString(value, format)); + } + + public static void Add(this RequestHeaders headers, string name, TimeSpan value, string format) + { + headers.Add(name, TypeFormatters.ToString(value, format)); + } + + public static void Add(this RequestHeaders headers, string name, Guid value) + { + headers.Add(name, value.ToString()); + } + + public static void Add(this RequestHeaders headers, string name, byte[] value, string format) + { + headers.Add(name, TypeFormatters.ToString(value, format)); + } + + public static void Add(this RequestHeaders headers, string name, BinaryData value, string format) + { + headers.Add(name, TypeFormatters.ToString(value.ToArray(), format)); + } + + public static void Add(this RequestHeaders headers, string prefix, IDictionary headersToAdd) + { + foreach (var header in headersToAdd) + { + headers.Add(prefix + header.Key, header.Value); + } + } + + public static void Add(this RequestHeaders headers, string name, ETag value) + { + headers.Add(name, value.ToString("H")); + } + + public static void Add(this RequestHeaders headers, MatchConditions conditions) + { + if (conditions.IfMatch != null) + { + headers.Add("If-Match", conditions.IfMatch.Value); + } + + if (conditions.IfNoneMatch != null) + { + headers.Add("If-None-Match", conditions.IfNoneMatch.Value); + } + } + + public static void Add(this RequestHeaders headers, RequestConditions conditions, string format) + { + if (conditions.IfMatch != null) + { + headers.Add("If-Match", conditions.IfMatch.Value); + } + + if (conditions.IfNoneMatch != null) + { + headers.Add("If-None-Match", conditions.IfNoneMatch.Value); + } + + if (conditions.IfModifiedSince != null) + { + headers.Add("If-Modified-Since", conditions.IfModifiedSince.Value, format); + } + + if (conditions.IfUnmodifiedSince != null) + { + headers.Add("If-Unmodified-Since", conditions.IfUnmodifiedSince.Value, format); + } + } + + public static void AddDelimited(this RequestHeaders headers, string name, IEnumerable value, string delimiter) + { + headers.Add(name, string.Join(delimiter, value)); + } + + public static void AddDelimited(this RequestHeaders headers, string name, IEnumerable value, string delimiter, string format) + { + var stringValues = value.Select(v => TypeFormatters.ConvertToString(v, format)); + headers.Add(name, string.Join(delimiter, stringValues)); + } + } +} diff --git a/sdk/core/Azure.Core/src/Shared/RequestUriBuilderExtensions.cs b/sdk/core/Azure.Core/src/Shared/RequestUriBuilderExtensions.cs new file mode 100644 index 000000000000..2c85e815ea90 --- /dev/null +++ b/sdk/core/Azure.Core/src/Shared/RequestUriBuilderExtensions.cs @@ -0,0 +1,133 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#nullable enable + +using System; +using System.Collections.Generic; +using System.Globalization; +using System.Linq; +using System.Xml; + +namespace Azure.Core +{ + internal static class RequestUriBuilderExtensions + { + public static void AppendPath(this RequestUriBuilder builder, bool value, bool escape = false) + { + builder.AppendPath(TypeFormatters.ConvertToString(value), escape); + } + + public static void AppendPath(this RequestUriBuilder builder, float value, bool escape = true) + { + builder.AppendPath(TypeFormatters.ConvertToString(value), escape); + } + + public static void AppendPath(this RequestUriBuilder builder, double value, bool escape = true) + { + builder.AppendPath(TypeFormatters.ConvertToString(value), escape); + } + + public static void AppendPath(this RequestUriBuilder builder, int value, bool escape = true) + { + builder.AppendPath(TypeFormatters.ConvertToString(value), escape); + } + + public static void AppendPath(this RequestUriBuilder builder, byte[] value, string format, bool escape = true) + { + builder.AppendPath(TypeFormatters.ConvertToString(value, format), escape); + } + + public static void AppendPath(this RequestUriBuilder builder, IEnumerable value, bool escape = true) + { + builder.AppendPath(TypeFormatters.ConvertToString(value), escape); + } + + public static void AppendPath(this RequestUriBuilder builder, DateTimeOffset value, string format, bool escape = true) + { + builder.AppendPath(TypeFormatters.ConvertToString(value, format), escape); + } + + public static void AppendPath(this RequestUriBuilder builder, TimeSpan value, string format, bool escape = true) + { + builder.AppendPath(TypeFormatters.ConvertToString(value, format), escape); + } + + public static void AppendPath(this RequestUriBuilder builder, Guid value, bool escape = true) + { + builder.AppendPath(TypeFormatters.ConvertToString(value), escape); + } + + public static void AppendPath(this RequestUriBuilder builder, long value, bool escape = true) + { + builder.AppendPath(TypeFormatters.ConvertToString(value), escape); + } + + public static void AppendQuery(this RequestUriBuilder builder, string name, bool value, bool escape = false) + { + builder.AppendQuery(name, TypeFormatters.ConvertToString(value), escape); + } + + public static void AppendQuery(this RequestUriBuilder builder, string name, float value, bool escape = true) + { + builder.AppendQuery(name, TypeFormatters.ConvertToString(value), escape); + } + + public static void AppendQuery(this RequestUriBuilder builder, string name, DateTimeOffset value, string format, bool escape = true) + { + builder.AppendQuery(name, TypeFormatters.ConvertToString(value, format), escape); + } + + public static void AppendQuery(this RequestUriBuilder builder, string name, TimeSpan value, string format, bool escape = true) + { + builder.AppendQuery(name, TypeFormatters.ConvertToString(value, format), escape); + } + + public static void AppendQuery(this RequestUriBuilder builder, string name, double value, bool escape = true) + { + builder.AppendQuery(name, TypeFormatters.ConvertToString(value), escape); + } + + public static void AppendQuery(this RequestUriBuilder builder, string name, decimal value, bool escape = true) + { + builder.AppendQuery(name, TypeFormatters.ConvertToString(value), escape); + } + + public static void AppendQuery(this RequestUriBuilder builder, string name, int value, bool escape = true) + { + builder.AppendQuery(name, TypeFormatters.ConvertToString(value), escape); + } + + public static void AppendQuery(this RequestUriBuilder builder, string name, long value, bool escape = true) + { + builder.AppendQuery(name, TypeFormatters.ConvertToString(value), escape); + } + + public static void AppendQuery(this RequestUriBuilder builder, string name, TimeSpan value, bool escape = true) + { + builder.AppendQuery(name, TypeFormatters.ConvertToString(value), escape); + } + + public static void AppendQuery(this RequestUriBuilder builder, string name, byte[] value, string format, bool escape = true) + { + builder.AppendQuery(name, TypeFormatters.ConvertToString(value, format), escape); + } + + public static void AppendQuery(this RequestUriBuilder builder, string name, Guid value, bool escape = true) + { + builder.AppendQuery(name, TypeFormatters.ConvertToString(value), escape); + } + + public static void AppendQueryDelimited(this RequestUriBuilder builder, string name, IEnumerable value, string delimiter, bool escape = true) + { + var stringValues = value.Select(v => TypeFormatters.ConvertToString(v)); + builder.AppendQuery(name, string.Join(delimiter, stringValues), escape); + } + + public static void AppendQueryDelimited(this RequestUriBuilder builder, string name, IEnumerable value, string delimiter, string format, bool escape = true) + { + var stringValues = value.Select(v => TypeFormatters.ConvertToString(v, format)); + builder.AppendQuery(name, string.Join(delimiter, stringValues), escape); + } + } +} diff --git a/sdk/core/Azure.Core/src/Shared/ResponseHeadersExtensions.cs b/sdk/core/Azure.Core/src/Shared/ResponseHeadersExtensions.cs new file mode 100644 index 000000000000..106a668961a8 --- /dev/null +++ b/sdk/core/Azure.Core/src/Shared/ResponseHeadersExtensions.cs @@ -0,0 +1,127 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#nullable enable + +using System; +using System.Collections.Generic; +using System.Globalization; +using System.Xml; + +namespace Azure.Core +{ + internal static class ResponseHeadersExtensions + { + private static readonly string[] KnownFormats = + { + // "r", // RFC 1123, required output format but too strict for input + "ddd, d MMM yyyy H:m:s 'GMT'", // RFC 1123 (r, except it allows both 1 and 01 for date and time) + "ddd, d MMM yyyy H:m:s 'UTC'", // RFC 1123, UTC + "ddd, d MMM yyyy H:m:s", // RFC 1123, no zone - assume GMT + "d MMM yyyy H:m:s 'GMT'", // RFC 1123, no day-of-week + "d MMM yyyy H:m:s 'UTC'", // RFC 1123, UTC, no day-of-week + "d MMM yyyy H:m:s", // RFC 1123, no day-of-week, no zone + "ddd, d MMM yy H:m:s 'GMT'", // RFC 1123, short year + "ddd, d MMM yy H:m:s 'UTC'", // RFC 1123, UTC, short year + "ddd, d MMM yy H:m:s", // RFC 1123, short year, no zone + "d MMM yy H:m:s 'GMT'", // RFC 1123, no day-of-week, short year + "d MMM yy H:m:s 'UTC'", // RFC 1123, UTC, no day-of-week, short year + "d MMM yy H:m:s", // RFC 1123, no day-of-week, short year, no zone + + "dddd, d'-'MMM'-'yy H:m:s 'GMT'", // RFC 850 + "dddd, d'-'MMM'-'yy H:m:s 'UTC'", // RFC 850, UTC + "dddd, d'-'MMM'-'yy H:m:s zzz", // RFC 850, offset + "dddd, d'-'MMM'-'yy H:m:s", // RFC 850 no zone + "ddd MMM d H:m:s yyyy", // ANSI C's asctime() format + + "ddd, d MMM yyyy H:m:s zzz", // RFC 5322 + "ddd, d MMM yyyy H:m:s", // RFC 5322 no zone + "d MMM yyyy H:m:s zzz", // RFC 5322 no day-of-week + "d MMM yyyy H:m:s", // RFC 5322 no day-of-week, no zone + }; + + public static bool TryGetValue(this ResponseHeaders headers, string name, out byte[]? value) + { + if (headers.TryGetValue(name, out string? stringValue)) + { + value = Convert.FromBase64String(stringValue); + return true; + } + + value = null; + return false; + } + + public static bool TryGetValue(this ResponseHeaders headers, string name, out TimeSpan? value) + { + if (headers.TryGetValue(name, out string? stringValue)) + { + value = XmlConvert.ToTimeSpan(stringValue); + return true; + } + + value = null; + return false; + } + + public static bool TryGetValue(this ResponseHeaders headers, string name, out DateTimeOffset? value) + { + if (headers.TryGetValue(name, out string? stringValue)) + { + if (DateTimeOffset.TryParseExact(stringValue, "r", DateTimeFormatInfo.InvariantInfo, DateTimeStyles.None, out var dto) || + DateTimeOffset.TryParseExact(stringValue, KnownFormats, DateTimeFormatInfo.InvariantInfo, DateTimeStyles.AllowInnerWhite | DateTimeStyles.AssumeUniversal, out dto)) + { + value = dto; + } + else + { + value = TypeFormatters.ParseDateTimeOffset(stringValue, ""); + } + + return true; + } + + value = null; + return false; + } + + public static bool TryGetValue(this ResponseHeaders headers, string name, out T? value) where T : struct + { + if (headers.TryGetValue(name, out string? stringValue)) + { + value = (T)Convert.ChangeType(stringValue, typeof(T), CultureInfo.InvariantCulture); + return true; + } + + value = null; + return false; + } + + public static bool TryGetValue(this ResponseHeaders headers, string name, out T? value) where T : class + { + if (headers.TryGetValue(name, out string? stringValue)) + { + value = (T)Convert.ChangeType(stringValue, typeof(T), CultureInfo.InvariantCulture); + return true; + } + + value = null; + return false; + } + + public static bool TryGetValue(this ResponseHeaders headers, string prefix, out IDictionary value) + { + value = new Dictionary(StringComparer.OrdinalIgnoreCase); + + foreach (HttpHeader item in headers) + { + if (item.Name.StartsWith(prefix, StringComparison.OrdinalIgnoreCase)) + { + value.Add(item.Name.Substring(prefix.Length), item.Value); + } + } + + return true; + } + } +} diff --git a/sdk/core/Azure.Core/src/Shared/ResponseWithHeaders.cs b/sdk/core/Azure.Core/src/Shared/ResponseWithHeaders.cs new file mode 100644 index 000000000000..2a48fb0ebbfb --- /dev/null +++ b/sdk/core/Azure.Core/src/Shared/ResponseWithHeaders.cs @@ -0,0 +1,20 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#nullable enable + +namespace Azure.Core +{ + internal static class ResponseWithHeaders + { + public static ResponseWithHeaders FromValue(T value, THeaders headers, Response rawResponse) + { + return new ResponseWithHeaders(value, headers, rawResponse); + } + + public static ResponseWithHeaders FromValue(THeaders headers, Response rawResponse) + { + return new ResponseWithHeaders(headers, rawResponse); + } + } +} diff --git a/sdk/core/Azure.Core/src/Shared/ResponseWithHeadersOfTHeaders.cs b/sdk/core/Azure.Core/src/Shared/ResponseWithHeadersOfTHeaders.cs new file mode 100644 index 000000000000..9ca15f42d695 --- /dev/null +++ b/sdk/core/Azure.Core/src/Shared/ResponseWithHeadersOfTHeaders.cs @@ -0,0 +1,26 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#nullable enable + +namespace Azure.Core +{ +#pragma warning disable SA1649 // File name should match first type name + internal class ResponseWithHeaders +#pragma warning restore SA1649 // File name should match first type name + { + private readonly Response _rawResponse; + + public ResponseWithHeaders(THeaders headers, Response rawResponse) + { + _rawResponse = rawResponse; + Headers = headers; + } + + public Response GetRawResponse() => _rawResponse; + + public THeaders Headers { get; } + + public static implicit operator Response(ResponseWithHeaders self) => self.GetRawResponse(); + } +} diff --git a/sdk/core/Azure.Core/src/Shared/ResponseWithHeadersOfTOfTHeaders.cs b/sdk/core/Azure.Core/src/Shared/ResponseWithHeadersOfTOfTHeaders.cs new file mode 100644 index 000000000000..c150be52b524 --- /dev/null +++ b/sdk/core/Azure.Core/src/Shared/ResponseWithHeadersOfTOfTHeaders.cs @@ -0,0 +1,29 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#nullable enable + +namespace Azure.Core +{ +#pragma warning disable SA1649 // File name should match first type name + internal class ResponseWithHeaders : Response +#pragma warning restore SA1649 // File name should match first type name + { + private readonly Response _rawResponse; + + public ResponseWithHeaders(T value, THeaders headers, Response rawResponse) + { + _rawResponse = rawResponse; + Value = value; + Headers = headers; + } + + public override Response GetRawResponse() => _rawResponse; + + public override T Value { get; } + + public THeaders Headers { get; } + + public static implicit operator Response(ResponseWithHeaders self) => self.GetRawResponse(); + } +} diff --git a/sdk/core/Azure.Core/src/Shared/StringRequestContent.cs b/sdk/core/Azure.Core/src/Shared/StringRequestContent.cs new file mode 100644 index 000000000000..58a40c466612 --- /dev/null +++ b/sdk/core/Azure.Core/src/Shared/StringRequestContent.cs @@ -0,0 +1,49 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.IO; +using System.Text; +using System.Threading; +using System.Threading.Tasks; + +namespace Azure.Core +{ + internal class StringRequestContent : RequestContent + { + private readonly byte[] _bytes; + + public StringRequestContent(string value) + { + _bytes = Encoding.UTF8.GetBytes(value); + } + + public override async Task WriteToAsync(Stream stream, CancellationToken cancellation) + { +#if NET6_0_OR_GREATER + await stream.WriteAsync(_bytes.AsMemory(), cancellation).ConfigureAwait(false); +#else + await stream.WriteAsync(_bytes, 0, _bytes.Length, cancellation).ConfigureAwait(false); +#endif + } + + public override void WriteTo(Stream stream, CancellationToken cancellation) + { +#if NET6_0_OR_GREATER + stream.Write(_bytes.AsSpan()); +#else + stream.Write(_bytes, 0, _bytes.Length); +#endif + } + + public override bool TryComputeLength(out long length) + { + length = _bytes.Length; + return true; + } + + public override void Dispose() + { + } + } +} diff --git a/sdk/core/Azure.Core/src/Shared/TypeFormatters.cs b/sdk/core/Azure.Core/src/Shared/TypeFormatters.cs new file mode 100644 index 000000000000..6bb51d611c1d --- /dev/null +++ b/sdk/core/Azure.Core/src/Shared/TypeFormatters.cs @@ -0,0 +1,159 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#nullable enable + +using System; +using System.Collections.Generic; +using System.Globalization; +using System.Xml; + +namespace Azure.Core +{ + internal class TypeFormatters + { + private const string RoundtripZFormat = "yyyy-MM-ddTHH:mm:ss.fffffffZ"; + public static string DefaultNumberFormat { get; } = "G"; + + public static string ToString(bool value) => value ? "true" : "false"; + + public static string ToString(DateTime value, string format) => value.Kind switch + { + DateTimeKind.Utc => ToString((DateTimeOffset)value, format), + _ => throw new NotSupportedException($"DateTime {value} has a Kind of {value.Kind}. Azure SDK requires it to be UTC. You can call DateTime.SpecifyKind to change Kind property value to DateTimeKind.Utc.") + }; + + public static string ToString(DateTimeOffset value, string format) => format switch + { + "D" => value.ToString("yyyy-MM-dd", CultureInfo.InvariantCulture), + "U" => value.ToUnixTimeSeconds().ToString(CultureInfo.InvariantCulture), + "O" => value.ToUniversalTime().ToString(RoundtripZFormat, CultureInfo.InvariantCulture), + "o" => value.ToUniversalTime().ToString(RoundtripZFormat, CultureInfo.InvariantCulture), + "R" => value.ToString("r", CultureInfo.InvariantCulture), + _ => value.ToString(format, CultureInfo.InvariantCulture) + }; + + public static string ToString(TimeSpan value, string format) => format switch + { + "P" => XmlConvert.ToString(value), + _ => value.ToString(format, CultureInfo.InvariantCulture) + }; + + public static string ToString(byte[] value, string format) => format switch + { + "U" => ToBase64UrlString(value), + "D" => Convert.ToBase64String(value), + _ => throw new ArgumentException($"Format is not supported: '{format}'", nameof(format)) + }; + + public static string ToBase64UrlString(byte[] value) + { + var numWholeOrPartialInputBlocks = checked(value.Length + 2) / 3; + var size = checked(numWholeOrPartialInputBlocks * 4); + var output = new char[size]; + + var numBase64Chars = Convert.ToBase64CharArray(value, 0, value.Length, output, 0); + + // Fix up '+' -> '-' and '/' -> '_'. Drop padding characters. + int i = 0; + for (; i < numBase64Chars; i++) + { + var ch = output[i]; + if (ch == '+') + { + output[i] = '-'; + } + else if (ch == '/') + { + output[i] = '_'; + } + else if (ch == '=') + { + // We've reached a padding character; truncate the remainder. + break; + } + } + + return new string(output, 0, i); + } + + public static byte[] FromBase64UrlString(string value) + { + var paddingCharsToAdd = GetNumBase64PaddingCharsToAddForDecode(value.Length); + + var output = new char[value.Length + paddingCharsToAdd]; + + int i; + for (i = 0; i < value.Length; i++) + { + var ch = value[i]; + if (ch == '-') + { + output[i] = '+'; + } + else if (ch == '_') + { + output[i] = '/'; + } + else + { + output[i] = ch; + } + } + + for (; i < output.Length; i++) + { + output[i] = '='; + } + + return Convert.FromBase64CharArray(output, 0, output.Length); + } + + private static int GetNumBase64PaddingCharsToAddForDecode(int inputLength) + { + switch (inputLength % 4) + { + case 0: + return 0; + case 2: + return 2; + case 3: + return 1; + default: + throw new InvalidOperationException("Malformed input"); + } + } + + public static DateTimeOffset ParseDateTimeOffset(string value, string format) + { + return format switch + { + "U" => DateTimeOffset.FromUnixTimeSeconds(long.Parse(value, CultureInfo.InvariantCulture)), + _ => DateTimeOffset.Parse(value, CultureInfo.InvariantCulture, DateTimeStyles.AssumeUniversal) + }; + } + + public static TimeSpan ParseTimeSpan(string value, string format) => format switch + { + "P" => XmlConvert.ToTimeSpan(value), + _ => TimeSpan.ParseExact(value, format, CultureInfo.InvariantCulture) + }; + + public static string ConvertToString(object? value, string? format = null) + => value switch + { + null => "null", + string s => s, + bool b => ToString(b), + int or float or double or long or decimal => ((IFormattable)value).ToString(DefaultNumberFormat, CultureInfo.InvariantCulture), + byte[] b when format != null => ToString(b, format), + IEnumerable s => string.Join(",", s), + DateTimeOffset dateTime when format != null => ToString(dateTime, format), + TimeSpan timeSpan when format != null => ToString(timeSpan, format), + TimeSpan timeSpan => XmlConvert.ToString(timeSpan), + Guid guid => guid.ToString(), + BinaryData binaryData => TypeFormatters.ConvertToString(binaryData.ToArray(), format), + _ => value.ToString()! + }; + } +} diff --git a/sdk/core/Azure.Core/src/Shared/Utf8JsonRequestContent.cs b/sdk/core/Azure.Core/src/Shared/Utf8JsonRequestContent.cs new file mode 100644 index 000000000000..401393083af4 --- /dev/null +++ b/sdk/core/Azure.Core/src/Shared/Utf8JsonRequestContent.cs @@ -0,0 +1,52 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#nullable enable + +using System.IO; +using System.Text.Json; +using System.Threading; +using System.Threading.Tasks; + +namespace Azure.Core +{ + internal class Utf8JsonRequestContent: RequestContent + { + private readonly MemoryStream _stream; + private readonly RequestContent _content; + + public Utf8JsonWriter JsonWriter { get; } + + public Utf8JsonRequestContent() + { + _stream = new MemoryStream(); + _content = Create(_stream); + JsonWriter = new Utf8JsonWriter(_stream); + } + + public override async Task WriteToAsync(Stream stream, CancellationToken cancellation) + { + await JsonWriter.FlushAsync(cancellation).ConfigureAwait(false); + await _content.WriteToAsync(stream, cancellation).ConfigureAwait(false); + } + + public override void WriteTo(Stream stream, CancellationToken cancellation) + { + JsonWriter.Flush(); + _content.WriteTo(stream, cancellation); + } + + public override bool TryComputeLength(out long length) + { + length = JsonWriter.BytesCommitted + JsonWriter.BytesPending; + return true; + } + + public override void Dispose() + { + JsonWriter.Dispose(); + _content.Dispose(); + _stream.Dispose(); + } + } +} diff --git a/sdk/core/Azure.Core/src/Shared/Utf8JsonWriterExtensions.cs b/sdk/core/Azure.Core/src/Shared/Utf8JsonWriterExtensions.cs new file mode 100644 index 000000000000..2ed696535efa --- /dev/null +++ b/sdk/core/Azure.Core/src/Shared/Utf8JsonWriterExtensions.cs @@ -0,0 +1,153 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#nullable enable + +using System; +using System.Collections.Generic; +using System.Globalization; +using System.Linq; +using System.Runtime.InteropServices; +using System.Text.Json; + +namespace Azure.Core +{ + internal static class Utf8JsonWriterExtensions + { + public static void WriteStringValue(this Utf8JsonWriter writer, DateTimeOffset value, string format) => + writer.WriteStringValue(TypeFormatters.ToString(value, format)); + + public static void WriteStringValue(this Utf8JsonWriter writer, DateTime value, string format) => + writer.WriteStringValue(TypeFormatters.ToString(value, format)); + + public static void WriteStringValue(this Utf8JsonWriter writer, TimeSpan value, string format) => + writer.WriteStringValue(TypeFormatters.ToString(value, format)); + + public static void WriteStringValue(this Utf8JsonWriter writer, char value) => + writer.WriteStringValue(value.ToString(CultureInfo.InvariantCulture)); + + public static void WriteNonEmptyArray(this Utf8JsonWriter writer, string name, IReadOnlyList values) + { + if (values.Any()) + { + writer.WriteStartArray(name); + foreach (var s in values) + { + writer.WriteStringValue(s); + } + + writer.WriteEndArray(); + } + } + + public static void WriteBase64StringValue(this Utf8JsonWriter writer, byte[] value, string format) + { + if (value == null) + { + writer.WriteNullValue(); + return; + } + + switch (format) + { + case "U": + writer.WriteStringValue(TypeFormatters.ToBase64UrlString(value)); + break; + case "D": + writer.WriteBase64StringValue(value); + break; + default: + throw new ArgumentException($"Format is not supported: '{format}'", nameof(format)); + } + } + + public static void WriteNumberValue(this Utf8JsonWriter writer, DateTimeOffset value, string format) + { + if (format != "U") throw new ArgumentOutOfRangeException(format, "Only 'U' format is supported when writing a DateTimeOffset as a Number."); + + writer.WriteNumberValue(value.ToUnixTimeSeconds()); + } + + public static void WriteObjectValue(this Utf8JsonWriter writer, object value) + { + switch (value) + { + case null: + writer.WriteNullValue(); + break; + case IUtf8JsonSerializable serializable: + serializable.Write(writer); + break; + case byte[] bytes: + writer.WriteBase64StringValue(bytes); + break; + case BinaryData bytes: + writer.WriteBase64StringValue(bytes); + break; + case System.Text.Json.JsonElement json: + json.WriteTo(writer); + break; + case int i: + writer.WriteNumberValue(i); + break; + case decimal d: + writer.WriteNumberValue(d); + break; + case double d: + if (double.IsNaN(d)) + { + writer.WriteStringValue("NaN"); + } + else + { + writer.WriteNumberValue(d); + } + break; + case float f: + writer.WriteNumberValue(f); + break; + case long l: + writer.WriteNumberValue(l); + break; + case string s: + writer.WriteStringValue(s); + break; + case bool b: + writer.WriteBooleanValue(b); + break; + case Guid g: + writer.WriteStringValue(g); + break; + case DateTimeOffset dateTimeOffset: + writer.WriteStringValue(dateTimeOffset, "O"); + break; + case DateTime dateTime: + writer.WriteStringValue(dateTime, "O"); + break; + case IEnumerable> enumerable: + writer.WriteStartObject(); + foreach (KeyValuePair pair in enumerable) + { + writer.WritePropertyName(pair.Key); + writer.WriteObjectValue(pair.Value); + } + writer.WriteEndObject(); + break; + case IEnumerable objectEnumerable: + writer.WriteStartArray(); + foreach (object item in objectEnumerable) + { + writer.WriteObjectValue(item); + } + writer.WriteEndArray(); + break; + case TimeSpan timeSpan: + writer.WriteStringValue(timeSpan, "P"); + break; + + default: + throw new NotSupportedException("Not supported type " + value.GetType()); + } + } + } +} diff --git a/sdk/core/Azure.Core/src/Shared/XElementExtensions.cs b/sdk/core/Azure.Core/src/Shared/XElementExtensions.cs new file mode 100644 index 000000000000..20feb9db64f7 --- /dev/null +++ b/sdk/core/Azure.Core/src/Shared/XElementExtensions.cs @@ -0,0 +1,34 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#nullable enable + +using System; +using System.Xml.Linq; + +namespace Azure.Core +{ + internal static class XElementExtensions + { + public static byte[] GetBytesFromBase64Value(this XElement element, string format) => format switch + { + "U" => TypeFormatters.FromBase64UrlString(element.Value), + "D" => Convert.FromBase64String(element.Value), + _ => throw new ArgumentException($"Format is not supported: '{format}'", nameof(format)) + }; + + public static DateTimeOffset GetDateTimeOffsetValue(this XElement element, string format) => format switch + { + "U" => DateTimeOffset.FromUnixTimeSeconds((long)element), + _ => TypeFormatters.ParseDateTimeOffset(element.Value, format) + }; + + public static TimeSpan GetTimeSpanValue(this XElement element, string format) => TypeFormatters.ParseTimeSpan(element.Value, format); + #pragma warning disable CA1801 //Parameter format of method GetObjectValue is never used. Remove the parameter or use it in the method body. + public static object GetObjectValue(this XElement element, string format) + #pragma warning restore + { + return element.Value; + } + } +} diff --git a/sdk/core/Azure.Core/src/Shared/XmlWriterContent.cs b/sdk/core/Azure.Core/src/Shared/XmlWriterContent.cs new file mode 100644 index 000000000000..145b557f336d --- /dev/null +++ b/sdk/core/Azure.Core/src/Shared/XmlWriterContent.cs @@ -0,0 +1,54 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#nullable enable + +using System.IO; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using System.Xml; + +namespace Azure.Core +{ + internal class XmlWriterContent : RequestContent + { + private readonly MemoryStream _stream; + private readonly RequestContent _content; + + public XmlWriterContent() + { + _stream = new MemoryStream(); + _content = Create(_stream); + XmlWriter = new XmlTextWriter(_stream, Encoding.UTF8); + } + + public XmlWriter XmlWriter { get; } + + public override async Task WriteToAsync(Stream stream, CancellationToken cancellation) + { + XmlWriter.Flush(); + await _content.WriteToAsync(stream, cancellation).ConfigureAwait(false); + } + + public override void WriteTo(Stream stream, CancellationToken cancellation) + { + XmlWriter.Flush(); + _content.WriteTo(stream, cancellation); + } + + public override bool TryComputeLength(out long length) + { + XmlWriter.Flush(); + length = _stream.Length; + return true; + } + + public override void Dispose() + { + XmlWriter.Dispose(); + _content.Dispose(); + _stream.Dispose(); + } + } +} diff --git a/sdk/core/Azure.Core/src/Shared/XmlWriterExtensions.cs b/sdk/core/Azure.Core/src/Shared/XmlWriterExtensions.cs new file mode 100644 index 000000000000..31f5996e608f --- /dev/null +++ b/sdk/core/Azure.Core/src/Shared/XmlWriterExtensions.cs @@ -0,0 +1,36 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#nullable enable + +using System; +using System.Xml; + +namespace Azure.Core +{ + internal static class XmlWriterExtensions + { + public static void WriteObjectValue(this XmlWriter writer, object value, string? nameHint) + { + switch (value) + { + case IXmlSerializable serializable: + serializable.Write(writer, nameHint); + return; + default: + throw new NotImplementedException(); + } + } + + public static void WriteValue(this XmlWriter writer, DateTimeOffset value, string format) => + writer.WriteValue(TypeFormatters.ToString(value, format)); + + public static void WriteValue(this XmlWriter writer, TimeSpan value, string format) => + writer.WriteValue(TypeFormatters.ToString(value, format)); + + public static void WriteValue(this XmlWriter writer, byte[] value, string format) + { + writer.WriteValue(TypeFormatters.ToString(value, format)); + } + } +} diff --git a/sdk/core/Azure.Core/tests/AppendOrReplaceApiVersionTests.cs b/sdk/core/Azure.Core/tests/AppendOrReplaceApiVersionTests.cs new file mode 100644 index 000000000000..f180d48d5bca --- /dev/null +++ b/sdk/core/Azure.Core/tests/AppendOrReplaceApiVersionTests.cs @@ -0,0 +1,66 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.IO; +using System.Text.Json; +using NUnit.Framework; + +namespace Azure.Core.Tests +{ + public class AppendOrReplaceApiVersionTests + { + [TestCase("/subscriptions/00000000-0000-0000-0000-000000000000/resourcegroups/testRG-416/providers/Microsoft.ManagedIdentity/userAssignedIdentities/testRi-6086", + "https://management.azure.com/subscriptions/00000000-0000-0000-0000-000000000000/resourceGroups/testRG-416/providers/Microsoft.ManagedIdentity/userAssignedIdentities/testRi-6086?api-version=2021-09-30-PREVIEW", + "/subscriptions/00000000-0000-0000-0000-000000000000/resourcegroups/testRG-416/providers/Microsoft.ManagedIdentity/userAssignedIdentities/testRi-6086?api-version=2021-09-30-PREVIEW")] + [TestCase("/subscriptions/00000000-0000-0000-0000-000000000000/resourcegroups/testRG-416/providers/Microsoft.ManagedIdentity/userAssignedIdentities/testRi-6086?api-version=2021-10-30-PREVIEW", + "https://management.azure.com/subscriptions/00000000-0000-0000-0000-000000000000/resourceGroups/testRG-416/providers/Microsoft.ManagedIdentity/userAssignedIdentities/testRi-6086?api-version=2021-09-30-PREVIEW", + "/subscriptions/00000000-0000-0000-0000-000000000000/resourcegroups/testRG-416/providers/Microsoft.ManagedIdentity/userAssignedIdentities/testRi-6086?api-version=2021-09-30-PREVIEW")] + [TestCase("https://management.azure.com/subscriptions/00000000-0000-0000-0000-000000000000/resourcegroups/testRG-416/providers/Microsoft.ManagedIdentity/userAssignedIdentities/testRi-6086", + "https://management.azure.com/subscriptions/00000000-0000-0000-0000-000000000000/resourceGroups/testRG-416/providers/Microsoft.ManagedIdentity/userAssignedIdentities/testRi-6086?api-version=2021-09-30-PREVIEW", + "https://management.azure.com/subscriptions/00000000-0000-0000-0000-000000000000/resourcegroups/testRG-416/providers/Microsoft.ManagedIdentity/userAssignedIdentities/testRi-6086?api-version=2021-09-30-PREVIEW")] + [TestCase("https://management.azure.com/subscriptions/00000000-0000-0000-0000-000000000000/resourcegroups/testRG-416/providers/Microsoft.ManagedIdentity/userAssignedIdentities/testRi-6086?api-version=2021-10-30-PREVIEW", + "https://management.azure.com/subscriptions/00000000-0000-0000-0000-000000000000/resourceGroups/testRG-416/providers/Microsoft.ManagedIdentity/userAssignedIdentities/testRi-6086?api-version=2021-09-30-PREVIEW", + "https://management.azure.com/subscriptions/00000000-0000-0000-0000-000000000000/resourcegroups/testRG-416/providers/Microsoft.ManagedIdentity/userAssignedIdentities/testRi-6086?api-version=2021-09-30-PREVIEW")] + [TestCase("/operations/2608e164-7bb4-4012-98f6-9a862b2218a6", + "https://qnamaker-resource-name.api.cognitiveservices.azure.com/qnamaker/v5.0-preview.1/knowledgebases/create", + "/operations/2608e164-7bb4-4012-98f6-9a862b2218a6")] + [TestCase("xyz.com?api-version=2021-10-01&x=1", + "https://abc.com?api-version=2021-11-01&a=1", + "xyz.com?api-version=2021-11-01&x=1")] + [TestCase("https://xyz.com?x=1&api-version=2021-10-01&y=2", + "https://abc.com?a=1&api-version=2021-11-01&b=2", + "https://xyz.com?x=1&api-version=2021-11-01&y=2")] + [TestCase("https://xyz.com?x=1&api-version=2021-10-01", + "https://abc.com?a=1&api-version=2021-11-01", + "https://xyz.com?x=1&api-version=2021-11-01")] + [TestCase("https://xyz.com?x=1&api-version=2021-10-01", + "https://abc.com?a=1", + "https://xyz.com?x=1&api-version=2021-10-01")] + [TestCase("https://xyz.com?x=1", + "https://abc.com", + "https://xyz.com?x=1")] + [TestCase("https://xyz.com?API-VERSION=2021-10-01", + "https://abc.com?api-version=2021-11-01", + "https://xyz.com?API-VERSION=2021-10-01&api-version=2021-11-01")] + [TestCase("https://xyz.com?x=1", + "https://abc.com?api-version", + "https://xyz.com?x=1")] + [TestCase("https://xyz.com?x=1", + "https://abc.com?api-version&x=1", + "https://xyz.com?x=1")] + [TestCase("https://xyz.com?api-version", + "https://abc.com?api-version=2021-11-01", + "https://xyz.com?api-version=2021-11-01")] + [TestCase("https://xyz.com?api-version&x=1", + "https://abc.com?api-version=2021-11-01", + "https://xyz.com?api-version=2021-11-01&x=1")] + public void TestAppendOrReplaceApiVersion(string uriToProcess, string startRequestUriStr, string expectedUri) + { + Uri startRequestUri = new Uri(startRequestUriStr); + NextLinkOperationImplementation.TryGetApiVersion(startRequestUri, out ReadOnlySpan apiVersion); + string resultUri = NextLinkOperationImplementation.AppendOrReplaceApiVersion(uriToProcess, apiVersion == null ? null : apiVersion.ToString()); + Assert.AreEqual(resultUri, expectedUri); + } + } +} diff --git a/sdk/core/Azure.Core/tests/ChangeTrackingDictionaryTest.cs b/sdk/core/Azure.Core/tests/ChangeTrackingDictionaryTest.cs new file mode 100644 index 000000000000..12b2323a6c04 --- /dev/null +++ b/sdk/core/Azure.Core/tests/ChangeTrackingDictionaryTest.cs @@ -0,0 +1,75 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.Collections.Generic; +using NUnit.Framework; + +namespace Azure.Core.Tests +{ + public class ChangeTrackingDictionaryTest + { + [Test] + public void UndefinedByDefault() + { + var dictionary = new ChangeTrackingDictionary(); + Assert.True(dictionary.IsUndefined); + } + + [Test] + public void ReadOperationsDontChange() + { + var dictionary = new ChangeTrackingDictionary(); + _ = dictionary.Count; + _ = dictionary.IsReadOnly; + _ = dictionary.Keys; + _ = dictionary.Values; + _ = dictionary.Contains(new KeyValuePair("c", "d")); + _ = dictionary.ContainsKey("a"); + _ = dictionary.TryGetValue("a", out _); + _ = dictionary.Remove("a"); + + foreach (var kvp in dictionary) + { + } + + dictionary.CopyTo(new KeyValuePair[5], 0); + + Assert.Throws(() => _ = dictionary["a"]); + + Assert.True(dictionary.IsUndefined); + } + + [Test] + public void CanAddElement() + { + var dictionary = new ChangeTrackingDictionary(); + dictionary.Add("a", "b"); + + Assert.AreEqual("b", dictionary["a"]); + Assert.False(dictionary.IsUndefined); + } + + [Test] + public void RemoveElement() + { + var dictionary = new ChangeTrackingDictionary(); + dictionary.Add("a", "b"); + dictionary.Add(new KeyValuePair("c", "d")); + dictionary.Remove("a"); + dictionary.Remove(new KeyValuePair("c", "d")); + + Assert.AreEqual(0, dictionary.Count); + Assert.False(dictionary.IsUndefined); + } + + [Test] + public void ClearResetsUndefined() + { + var dictionary = new ChangeTrackingDictionary(); + dictionary.Clear(); + + Assert.AreEqual(0, dictionary.Count); + Assert.False(dictionary.IsUndefined); + } + } +} diff --git a/sdk/core/Azure.Core/tests/ChangeTrackingListTest.cs b/sdk/core/Azure.Core/tests/ChangeTrackingListTest.cs new file mode 100644 index 000000000000..ea2782d9c301 --- /dev/null +++ b/sdk/core/Azure.Core/tests/ChangeTrackingListTest.cs @@ -0,0 +1,99 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Linq; +using NUnit.Framework; + +namespace Azure.Core.Tests +{ + public class ChangeTrackingListTest + { + [Test] + public void UndefinedByDefault() + { + var list = new ChangeTrackingList(); + Assert.True(list.IsUndefined); + } + + [Test] + public void ReadOperationsDontChange() + { + var list = new ChangeTrackingList(); + _ = list.Count; + _ = list.IsReadOnly; + _ = list.Contains("a"); + _ = list.IndexOf("a"); + _ = list.Remove("a"); + + foreach (var kvp in list) + { + } + + list.CopyTo(new string[5], 0); + + Assert.Throws(() => _ = list[0]); + + Assert.True(list.IsUndefined); + } + + [Test] + public void CanAddElement() + { + var list = new ChangeTrackingList(); + list.Add("a"); + + Assert.AreEqual("a", list[0]); + Assert.False(list.IsUndefined); + } + + [Test] + public void CanInsertElement() + { + var list = new ChangeTrackingList(); + list.Insert(0, "a"); + + Assert.AreEqual("a", list[0]); + Assert.False(list.IsUndefined); + } + + [Test] + public void ContainsWorks() + { + var list = new ChangeTrackingList(); + list.Add("a"); + + Assert.True(list.Contains("a")); + } + + [Test] + public void CanEnumerateItems() + { + var list = new ChangeTrackingList(); + list.Add("a"); + + Assert.AreEqual(new[] { "a" },list.ToArray()); + } + + [Test] + public void RemoveElement() + { + var list = new ChangeTrackingList(); + list.Add("a"); + list.Remove("a"); + + Assert.AreEqual(0, list.Count); + Assert.False(list.IsUndefined); + } + + [Test] + public void ClearResetsUndefined() + { + var list = new ChangeTrackingList(); + list.Clear(); + + Assert.AreEqual(0, list.Count); + Assert.False(list.IsUndefined); + } + } +} diff --git a/sdk/core/Azure.Core/tests/OptionalTests.cs b/sdk/core/Azure.Core/tests/OptionalTests.cs new file mode 100644 index 000000000000..5a04263e4573 --- /dev/null +++ b/sdk/core/Azure.Core/tests/OptionalTests.cs @@ -0,0 +1,56 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Linq; +using System.Net; +using NUnit.Framework; + +namespace Azure.Core.Tests +{ + public class OptionalTests + { + [TestCase(int.MinValue)] + [TestCase(float.MinValue)] + [TestCase(false)] + [TestCase('a')] + public void DefaultPrimitiveValueType(T value) where T: struct + { + Optional optional = default; + Assert.False(Optional.ToNullable(optional).HasValue); + + optional = value; + Assert.True(Optional.ToNullable(optional).HasValue); + } + + [Test] + public void DefaultStructValueType() + { + Optional optional = default; + Assert.False(Optional.ToNullable(optional).HasValue); + + optional = DateTimeOffset.Now; + Assert.True(Optional.ToNullable(optional).HasValue); + } + + [Test] + public void DefaultEnumValueType() + { + Optional optional = default; + Assert.False(Optional.ToNullable(optional).HasValue); + + optional = HttpStatusCode.Accepted; + Assert.True(Optional.ToNullable(optional).HasValue); + } + + [TestCase("")] + public void DefaultReferenceType(T value) + { + Optional optional = default; + Assert.False(optional.HasValue); + + optional = value; + Assert.True(optional.HasValue); + } + } +} diff --git a/sdk/core/Azure.Core/tests/RawRequestUriBuilderTest.cs b/sdk/core/Azure.Core/tests/RawRequestUriBuilderTest.cs new file mode 100644 index 000000000000..d597d4853660 --- /dev/null +++ b/sdk/core/Azure.Core/tests/RawRequestUriBuilderTest.cs @@ -0,0 +1,141 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using NUnit.Framework; + +namespace Azure.Core.Tests +{ + public class RawRequestUriBuilderTest + { + [Theory] + [TestCase("ht|tp|://", "localhost", null, "http://localhost/")] + [TestCase("http://|local|host", null, null, "http://localhost/")] + [TestCase("https://|local|host", null, null, "https://localhost/")] + [TestCase("http://|local|host", "override", null, "http://override/")] + [TestCase("http://|local|host|:12345", null, null, "http://localhost:12345/")] + [TestCase("http://|local|host:12345", null, null, "http://localhost:12345/")] + [TestCase("http://localhost/with|some|path", null, "/andMore", "http://localhost/withsomepath/andMore")] + [TestCase("http://localhost:12345/with|some|path", null, "/andMore", "http://localhost:12345/withsomepath/andMore")] + [TestCase("http://localhost|:12345/with|some|path", null, "/andMore", "http://localhost:12345/withsomepath/andMore")] + [TestCase("http://|localhost:12345|/with|some|path", null, "/andMore", "http://localhost:12345/withsomepath/andMore")] + [TestCase("http://|localhost:12345|withsomepath", null, "/andMore/", "http://localhost:12345/withsomepath/andMore/")] + [TestCase("http://|localhost:12345/|withsomepath", null, "/andMore/", "http://localhost:12345/withsomepath/andMore/")] + [TestCase("http://localhost|/more|/andMore|/evenMore", null, null, "http://localhost/more/andMore/evenMore")] + [TestCase("http://localhost|/more/andMore|/evenMore/", null, null, "http://localhost/more/andMore/evenMore/")] + [TestCase("http://|localhost:12345|/more/andMore|/evenMore/", null, null, "http://localhost:12345/more/andMore/evenMore/")] + [TestCase("http://localhost|/more|/andMore|?one=1", null, null, "http://localhost/more/andMore?one=1")] + [TestCase("http://localhost|/more|/andMore?one=1", null, null, "http://localhost/more/andMore?one=1")] + [TestCase("http://localhost/|more?one=1", null, null, "http://localhost/more?one=1")] + [TestCase("http://localhost|/more?one=1", null, "/andMore", "http://localhost/more/andMore?one=1")] + [TestCase("http://localhost|/more?one=1|&two=2|&three=3", null, null, "http://localhost/more?one=1&two=2&three=3")] + [TestCase("http://localhost/more?one=1&two=2|&three=3", null, null, "http://localhost/more?one=1&two=2&three=3")] + [TestCase("http://localhost/more|?one=1&two=2|&three=3", null, null, "http://localhost/more?one=1&two=2&three=3")] + [TestCase("http://localhost/more|?one=1|&two=2&three=3", null, null, "http://localhost/more?one=1&two=2&three=3")] + [TestCase("http://localhost|/more?one=1&two=2&three=3", null, null, "http://localhost/more?one=1&two=2&three=3")] + [TestCase("http://|localhost:12345|/more|?one=1|&two=2|&three=3", null, null, "http://localhost:12345/more?one=1&two=2&three=3")] + [TestCase("http://|localhost|:12345|/more|?one=1|&two=2|&three=3", null, null, "http://localhost:12345/more?one=1&two=2&three=3")] + [TestCase("http://localhost|/more?one=&two=&three=", null, null, "http://localhost/more?one=&two=&three=")] + [TestCase("http://localhost|/more?one&two&three", null, null, "http://localhost/more?one=&two=&three=")] + [TestCase("http://localhost|/more?one&&&&&&&three", null, null, "http://localhost/more?one=&three=")] + public void AppendRawWorks(string rawPart, string host, string pathPart, string expected) + { + var builder = new RawRequestUriBuilder(); + foreach (var c in rawPart.Split('|')) + { + builder.AppendRaw(c, false); + } + + if (host != null) + { + builder.Host = host; + } + + if (pathPart != null) + { + builder.AppendPath(pathPart, false); + } + + Assert.AreEqual(expected, builder.ToUri().ToString()); + } + + [Theory] + [TestCase("http://localhost:12345", "/more/", "andMore", "http://localhost:12345/more/andMore")] + [TestCase("http://localhost", "/more/", "andMore", "http://localhost/more/andMore")] + [TestCase("http://localhost", "/more/", "andMore?one=1", "http://localhost/more/andMore?one=1")] + [TestCase("http://localhost", "/more/", "andMore|?one=1", "http://localhost/more/andMore?one=1")] + [TestCase("http://localhost", "/more/", "andMore|?one=1&two=2", "http://localhost/more/andMore?one=1&two=2")] + [TestCase("http://localhost", "/more/", "andMore|?one=1|&two=2", "http://localhost/more/andMore?one=1&two=2")] + [TestCase("http://localhost", "/more/", "andMore?one=1|&two=2", "http://localhost/more/andMore?one=1&two=2")] + [TestCase("http://localhost/", "more/|andMore", "?one=1", "http://localhost/more/andMore?one=1")] + [TestCase("http://local|host/", "more/|andMore", "?one=1", "http://localhost/more/andMore?one=1")] + [TestCase("http://localhost|/", "more/|andMore", "?one=1", "http://localhost/more/andMore?one=1")] + [TestCase("http://localhost", "more|/|andMore|/", "evenMore?one=1", "http://localhost/more/andMore/evenMore?one=1")] + [TestCase("http://localhost/|more|/", "andMore|/", "evenMore|/|most", "http://localhost/more/andMore/evenMore/most")] + public void AppendRawThenPathThenRaw(string rawPartBefore, string pathPart, string rawPartAfter, string expected) + { + var builder = new RawRequestUriBuilder(); + foreach (var c in rawPartBefore.Split('|')) + { + builder.AppendRaw(c, false); + } + + foreach (var c in pathPart.Split('|')) + { + builder.AppendPath(c, false); + } + + foreach (var c in rawPartAfter.Split('|')) + { + builder.AppendRaw(c, false); + } + + Assert.AreEqual(expected, builder.ToUri().ToString()); + } + + [Theory] + [TestCase("http://localhost/", "one", "1", "more?two=2", "http://localhost/more?one=1&two=2")] + [TestCase("http://localhost:12345/", "one", "1", "more?two=2", "http://localhost:12345/more?one=1&two=2")] + [TestCase("http://localhost/more/", "one", "1", "andMore?two=2", "http://localhost/more/andMore?one=1&two=2")] + [TestCase("http://localhost/more", "one", "1", "/andMore?two=2", "http://localhost/more/andMore?one=1&two=2")] + [TestCase("http://localhost", "one", "1", "/more?two=2", "http://localhost/more?one=1&two=2")] + [TestCase("http://localhost:12345", "one", "1", "/more?two=2", "http://localhost:12345/more?one=1&two=2")] + [TestCase("http://localhost", "one", "1", "/more|/andMore?two=2", "http://localhost/more/andMore?one=1&two=2")] + [TestCase("http://localhost|/more", "one", "1", "/andMore?two=2", "http://localhost/more/andMore?one=1&two=2")] + [TestCase("http://localhost|/|more", "one", "1", "/|andMore?two=2", "http://localhost/more/andMore?one=1&two=2")] + [TestCase("http://localhost|/", "one", "1", "more|/|andMore?two=2", "http://localhost/more/andMore?one=1&two=2")] + [TestCase("http://localhost/", "one", "1", "more/|andMore?two=2|&three=3", "http://localhost/more/andMore?one=1&two=2&three=3")] + [TestCase("http://localhost:12345/", "one", "1", "more/|andMore?two=2|&three=3", "http://localhost:12345/more/andMore?one=1&two=2&three=3")] + public void AppendRawThenQueryThenRaw(string rawPartBefore, string queryName, string queryValue, string rawPartAfter, string expected) + { + var builder = new RawRequestUriBuilder(); + foreach (var c in rawPartBefore.Split('|')) + { + builder.AppendRaw(c, false); + } + + builder.AppendQuery(queryName, queryValue); + + foreach (var c in rawPartAfter.Split('|')) + { + builder.AppendRaw(c, false); + } + + Assert.AreEqual(expected, builder.ToUri().ToString()); + } + + [Theory] + [TestCase(long.MinValue)] + [TestCase(0L)] + [TestCase(long.MaxValue)] + public void AppendPathTypeLong(long longPathPart) + { + const string Endpoint = "http://localhost:12345/getByLong/"; + + var builder = new RawRequestUriBuilder(); + builder.AppendRaw(Endpoint, false); + builder.AppendPath(longPathPart, true); + + Assert.AreEqual($"{Endpoint}{longPathPart:G}", builder.ToUri().ToString()); + } + } +} diff --git a/sdk/core/Azure.Core/tests/RequestContentHelperTests.cs b/sdk/core/Azure.Core/tests/RequestContentHelperTests.cs new file mode 100644 index 000000000000..6f07089b2d64 --- /dev/null +++ b/sdk/core/Azure.Core/tests/RequestContentHelperTests.cs @@ -0,0 +1,152 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.IO; +using System.Text.Json; +using System.Xml; +using System.Xml.Linq; +using NUnit.Framework; + +namespace Azure.Core.Tests +{ + public class RequestContentHelperTests + { + public static IEnumerable GetTimeSpanData() + { + yield return new TestCaseData(XmlConvert.ToTimeSpan("P123DT22H14M12.011S"), XmlConvert.ToTimeSpan("P163DT22H14M12.011S")); + } + + public static IEnumerable GetDateTimeData() + { + yield return new TestCaseData(DateTimeOffset.Parse("2022-08-26T18:38:00Z"), DateTimeOffset.Parse("2022-09-26T18:38:00Z")); + } + + [TestCase(1, 2)] + [TestCase("a", "b")] + [TestCase(true, false)] + [TestCaseSource("GetTimeSpanData")] + [TestCaseSource("GetDateTimeData")] + public void TestGenericFromEnumerable(T expectedValue1, T expectedValue2) + { + var expectedList = new List { expectedValue1, expectedValue2 }; + var content = RequestContentHelper.FromEnumerable(expectedList); + + var stream = new MemoryStream(); + content.WriteTo(stream, default); + stream.Position = 0; + + var document = JsonDocument.Parse(stream); + int count = 0; + foreach (var property in document.RootElement.EnumerateArray()) + { + if (typeof(T) == typeof(int)) + { + Assert.AreEqual(expectedList[count++], property.GetInt32()); + } + else if (typeof(T) == typeof(string)) + { + Assert.AreEqual(expectedList[count++], property.GetString()); + } + else if (typeof(T) == typeof(bool)) + { + Assert.AreEqual(expectedList[count++], property.GetBoolean()); + } + } + } + + [Test] + public void TestBinaryDataFromEnumerable() + { + var expectedList = new List { new BinaryData(1), new BinaryData("\"hello\""), null }; + var content = RequestContentHelper.FromEnumerable(expectedList); + + var stream = new MemoryStream(); + content.WriteTo(stream, default); + stream.Position = 0; + + var document = JsonDocument.Parse(stream); + int count = 0; + foreach (var property in document.RootElement.EnumerateArray()) + { + if (property.ValueKind == JsonValueKind.Null) + { + Assert.IsNull(expectedList[count++]); + } + else + { + Assert.AreEqual(expectedList[count++].ToObjectFromJson(), BinaryData.FromString(property.GetRawText()).ToObjectFromJson()); + } + } + } + + [TestCase(1, 2)] + [TestCase("a", "b")] + [TestCase(true, false)] + [TestCaseSource("GetTimeSpanData")] + [TestCaseSource("GetDateTimeData")] + public void TestGenericFromDictionary(T expectedValue1, T expectedValue2) + { + var expectedDictionary = new Dictionary() + { + {"k1", expectedValue1 }, + {"k2", expectedValue2 } + }; + var content = RequestContentHelper.FromDictionary(expectedDictionary); + + var stream = new MemoryStream(); + content.WriteTo(stream, default); + stream.Position = 0; + + var document = JsonDocument.Parse(stream); + int count = 1; + foreach (var property in document.RootElement.EnumerateObject()) + { + if (typeof(T) == typeof(int)) + { + Assert.AreEqual(expectedDictionary["k" + count++], property.Value.GetInt32()); + } + else if (typeof(T) == typeof(string)) + { + Assert.AreEqual(expectedDictionary["k" + count++], property.Value.GetString()); + } + else if (typeof(T) == typeof(bool)) + { + Assert.AreEqual(expectedDictionary["k" + count++], property.Value.GetBoolean()); + } + } + } + + [Test] + public void TestBinaryDataFromDictionary() + { + var expectedDictionary = new Dictionary() + { + {"k1", new BinaryData(1) }, + {"k2", new BinaryData("\"hello\"") }, + {"k3", null } + }; + + var content = RequestContentHelper.FromDictionary(expectedDictionary); + + var stream = new MemoryStream(); + content.WriteTo(stream, default); + stream.Position = 0; + + var document = JsonDocument.Parse(stream); + int count = 1; + foreach (var property in document.RootElement.EnumerateObject()) + { + if (property.Value.ValueKind == JsonValueKind.Null) + { + Assert.IsNull(expectedDictionary["k" + count++]); + } + else + { + Assert.AreEqual(expectedDictionary["k" + count++].ToObjectFromJson(), BinaryData.FromString(property.Value.GetRawText()).ToObjectFromJson()); + } + } + } + } +} diff --git a/sdk/core/Azure.Core/tests/RequestHeaderExtensionsTests.cs b/sdk/core/Azure.Core/tests/RequestHeaderExtensionsTests.cs new file mode 100644 index 000000000000..5d0936a83815 --- /dev/null +++ b/sdk/core/Azure.Core/tests/RequestHeaderExtensionsTests.cs @@ -0,0 +1,62 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using Azure.Core.TestFramework; +using NUnit.Framework; + +namespace Azure.Core.Tests +{ + public class RequestHeaderExtensionsTests + { + [TestCaseSource(nameof(ETagRequestHeaderCases))] + public void AddETagRequestHeader(string original, string expected) + { + var request = new MockRequest(); + request.Headers.Add("If-Match", new ETag(original)); + Assert.IsTrue(request.Headers.TryGetValue("If-Match", out string value)); + Assert.AreEqual(expected, value); + } + + [TestCaseSource(nameof(ETagRequestHeaderCases))] + public void AddETagMatchConditionsRequestHeader(string original, string expected) + { + var request = new MockRequest(); + request.Headers.Add(new MatchConditions + { + IfMatch = new ETag(original), + IfNoneMatch = new ETag(original) + }); + Assert.IsTrue(request.Headers.TryGetValue("If-Match", out string ifMatchValue)); + Assert.IsTrue(request.Headers.TryGetValue("If-None-Match", out string ifNoneMatchValue)); + Assert.AreEqual(expected, ifMatchValue); + Assert.AreEqual(expected, ifNoneMatchValue); + } + + [TestCaseSource(nameof(ETagRequestHeaderCases))] + public void AddETagRequestConditionsRequestHeader(string original, string expected) + { + var request = new MockRequest(); + request.Headers.Add(new RequestConditions + { + IfMatch = new ETag(original), + IfNoneMatch = new ETag(original) + }); + Assert.IsTrue(request.Headers.TryGetValue("If-Match", out string ifMatchValue)); + Assert.IsTrue(request.Headers.TryGetValue("If-None-Match", out string ifNoneMatchValue)); + Assert.AreEqual(expected, ifMatchValue); + Assert.AreEqual(expected, ifNoneMatchValue); + } + + private static readonly object[] ETagRequestHeaderCases = + { + new string[] { "*", "*" }, + new string[] { "\"\"", "\"\"" }, + new string[] { "\"abcedfg\"", "\"abcedfg\"" }, + new string[] { "W/\"weakETag\"", "W/\"weakETag\"" }, + new string[] { "abcedfg", "\"abcedfg\"" }, + new string[] { "abcedfg\"", "\"abcedfg\"\""}, + new string[] { "\"abcedfg", "\"\"abcedfg\""}, + new string[] { "W/weakETag\"", "\"W/weakETag\"\"" }, + }; + } +} diff --git a/sdk/core/Azure.Core/tests/TypeFormatterTests.cs b/sdk/core/Azure.Core/tests/TypeFormatterTests.cs new file mode 100644 index 000000000000..d427a7408d9b --- /dev/null +++ b/sdk/core/Azure.Core/tests/TypeFormatterTests.cs @@ -0,0 +1,143 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.IO; +using System.Text.Json; +using NUnit.Framework; + +namespace Azure.Core.Tests +{ + public class TypeFormatterTests + { + public static object[] DateTimeOffsetCases = + { + new object[] { "O", new DateTimeOffset(2020, 05, 04, 03, 02, 01, 123, default), "2020-05-04T03:02:01.1230000Z" }, + new object[] { "O", new DateTimeOffset(2020, 05, 04, 03, 02, 01, 123, new TimeSpan(1, 0, 0)), "2020-05-04T02:02:01.1230000Z" }, + new object[] { "O", new DateTimeOffset(3155378975999999999, default), "9999-12-31T23:59:59.9999999Z" }, + new object[] { "O", new DateTimeOffset(3155378975999999999, new TimeSpan(1, 0, 0)), "9999-12-31T22:59:59.9999999Z" }, + + new object[] { "o", new DateTimeOffset(2020, 05, 04, 03, 02, 01, 123, default), "2020-05-04T03:02:01.1230000Z" }, + + new object[] { "D", new DateTimeOffset(2020, 05, 04, 0,0,0,0, default), "2020-05-04" }, + + new object[] { "U", new DateTimeOffset(2020, 05, 04, 03, 02, 01, 0, default), "1588561321" }, + + new object[] { "R", new DateTimeOffset(2020, 05, 04, 03, 02, 01, 0, default), "Mon, 04 May 2020 03:02:01 GMT" }, + new object[] { "R", new DateTimeOffset(2020, 05, 04, 03, 02, 01, 0, new TimeSpan(1, 0, 0)), "Mon, 04 May 2020 02:02:01 GMT" }, + }; + + public static object[] TimeSpanCases = + { + new object[] { "P", new TimeSpan(1, 2, 59, 59), "P1DT2H59M59S" }, + new object[] { "c", new TimeSpan(1, 2, 59, 59, 500), "1.02:59:59.5000000" } + }; + + public static object[] BinaryDataCases = + { + new object[] { "D", BinaryData.FromString("test"), "dGVzdA==" }, + new object[] { "U", BinaryData.FromString("test"), "dGVzdA" } + }; + + public static object[] TimeSpanWithoutFormatCases = + { + new object[] { null, new TimeSpan(1, 2, 59, 59), "P1DT2H59M59S" }, + }; + + private static readonly object[] GuidCases = new object[] + { + new object[] { null, Guid.Parse("11111111-1111-1111-1111-111112111111"), "11111111-1111-1111-1111-111112111111" } + }; + + [TestCaseSource(nameof(DateTimeOffsetCases))] + public void FormatsDatesAsString(string format, DateTimeOffset date, string expected) + { + var formatted = TypeFormatters.ToString(date, format); + Assert.AreEqual(expected, formatted); + Assert.AreEqual(date, TypeFormatters.ParseDateTimeOffset(formatted, format)); + } + + [TestCaseSource(nameof(DateTimeOffsetCases))] + public void FormatsDatesAsJson(string format, DateTimeOffset date, string expected) + { + using MemoryStream memoryStream = new MemoryStream(); + using (var writer = new Utf8JsonWriter(memoryStream)) + { + if (format == "U") + { + writer.WriteNumberValue(date, format); + } + else + { + writer.WriteStringValue(date, format); + } + } + + var formatted = JsonDocument.Parse(memoryStream.ToArray()).RootElement; + Assert.AreEqual(expected, formatted.ToString()); + Assert.AreEqual(date, formatted.GetDateTimeOffset(format)); + } + + [TestCase("2020-05-04T03:02:01.1230000+08:00")] + [TestCase("2020-05-04T03:02:01.1230000-08:00")] + [TestCase("2020-05-04T03:02:01.1230000+00:00")] + [TestCase("2020-05-04T03:02:01.1230000")] + [TestCase("Mon, 04 May 2020 03:02:01 GMT")] + [TestCase("Mon, 04 May 2020 03:02:01")] + public void TestEqualAfterConvertingToUtc(string dateString) + { + string[] formats = { "O", "o" }; + + foreach (string format in formats) + { + var originalDate = DateTimeOffset.Parse(dateString); + var originalTimeMillis = originalDate.ToUnixTimeMilliseconds(); + + var formatted = TypeFormatters.ToString(originalDate, format); + var utcDate = DateTimeOffset.Parse(formatted); + Assert.AreEqual(originalTimeMillis, utcDate.ToUnixTimeMilliseconds()); + } + } + + [TestCaseSource(nameof(TimeSpanCases))] + public void FormatsTimeSpanAsJson(string format, TimeSpan duration, string expected) + { + using MemoryStream memoryStream = new MemoryStream(); + using (var writer = new Utf8JsonWriter(memoryStream)) + { + writer.WriteStringValue(duration, format); + } + + var formatted = JsonDocument.Parse(memoryStream.ToArray()).RootElement; + Assert.AreEqual(expected, formatted.ToString()); + Assert.AreEqual(duration, formatted.GetTimeSpan(format)); + } + + [TestCase(null, null, "null")] + [TestCase(null, "str", "str")] + [TestCase(null, true, "true")] + [TestCase(null, false, "false")] + [TestCase(null, 42, "42")] + [TestCase(null, -42, "-42")] + [TestCase(null, 3.14f, "3.14")] + [TestCase(null, -3.14f, "-3.14")] + [TestCase(null, 3.14, "3.14")] + [TestCase(null, -3.14, "-3.14")] + [TestCase(null, 299792458L, "299792458")] + [TestCase(null, -299792458L, "-299792458")] + [TestCase("D", new byte[] { 1, 2, 3 }, "AQID")] + [TestCase("U", new byte[] { 4, 5, 6 }, "BAUG")] + [TestCase(null, new string[] { "a", "b" }, "a,b")] + [TestCaseSource(nameof(DateTimeOffsetCases))] + [TestCaseSource(nameof(TimeSpanWithoutFormatCases))] + [TestCaseSource(nameof(TimeSpanCases))] + [TestCaseSource(nameof(GuidCases))] + [TestCaseSource(nameof(BinaryDataCases))] + public void ValidateConvertToString(string format, object value, string expected) + { + var result = TypeFormatters.ConvertToString(value, format); + + Assert.AreEqual(expected, result); + } + } +} diff --git a/sdk/core/Azure.Core/tests/WriterExtensionTests.cs b/sdk/core/Azure.Core/tests/WriterExtensionTests.cs new file mode 100644 index 000000000000..d2ee646346cf --- /dev/null +++ b/sdk/core/Azure.Core/tests/WriterExtensionTests.cs @@ -0,0 +1,100 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.IO; +using System.Text.Json; +using NUnit.Framework; + +namespace Azure.Core.Tests +{ + public class WriterExtensionTests + { + private static object[] ObjectCases = + { + new object[] { null, "null" }, + new object[] { new byte[] { 1, 2, 3, 4}, @"""AQIDBA==""" }, + new object[] { 42, "42" }, + new object[] { 42.0m, "42.0" }, + new object[] { 42.0d, "42" }, + new object[] { 42.0f, "42" }, + new object[] { 42L, "42" }, + new object[] { "asdf", @"""asdf""" }, + new object[] { false, "false" }, + new object[] { new Guid ("c6125705-61d7-4cd6-8d6c-f7f247a7a5fa"), @"""c6125705-61d7-4cd6-8d6c-f7f247a7a5fa""" }, + new object[] { new BinaryData (new byte[] { 1, 2, 3, 4}), @"""AQIDBA==""" }, + new object[] { new DateTimeOffset (2001, 1, 1, 1, 1, 1, 1, new TimeSpan ()), @"""2001-01-01T01:01:01.0010000Z""" }, + new object[] { new DateTime (2001, 1, 1, 1, 1, 1, DateTimeKind.Utc), @"""2001-01-01T01:01:01.0000000Z""" }, + new object[] { (IEnumerable)new List() { 1, 2, 3, 4 }, "[1,2,3,4]" }, + new object[] { (IEnumerable>)new List>() { + new KeyValuePair ("a", (object)1), + new KeyValuePair ("b", (object)2), + new KeyValuePair ("c", (object)3), + new KeyValuePair ("d", (object)4), + }, + @"{""a"":1,""b"":2,""c"":3,""d"":4}" + }, + }; + + [Test, TestCaseSource("ObjectCases")] + public static void WriteObjectValue (object value, string expectedJson) + { + using MemoryStream stream = new MemoryStream (); + using Utf8JsonWriter writer = new Utf8JsonWriter (stream); + writer.WriteObjectValue (value); + + writer.Flush (); + Assert.AreEqual (expectedJson, System.Text.Encoding.UTF8.GetString(stream.ToArray())); + } + + [Test] + public static void WriteObjectValueJsonElement () + { + using MemoryStream stream = new MemoryStream (); + using Utf8JsonWriter writer = new Utf8JsonWriter (stream); + + string json = @"{""TablesToMove"": [{""TableName"":""TestTable""}]}"; + Dictionary content = JsonSerializer.Deserialize>(json); + JsonElement element = (JsonElement)content["TablesToMove"]; + writer.WriteObjectValue (element); + writer.Flush (); + + Assert.AreEqual (@"[{""TableName"":""TestTable""}]", System.Text.Encoding.UTF8.GetString(stream.ToArray())); + } + + [Test] + public static void WriteObjectValueIUtf8JsonSerializable () + { + using MemoryStream stream = new MemoryStream (); + using Utf8JsonWriter writer = new Utf8JsonWriter (stream); + + var content = new TestSerialize (); + writer.WriteObjectValue (content); + Assert.True (content.didWrite); + } + + [Test] + public static void WriteObjectValueNullIUtf8JsonSerializable () + { + using MemoryStream stream = new MemoryStream (); + using Utf8JsonWriter writer = new Utf8JsonWriter (stream); + + TestSerialize content = null; + writer.WriteObjectValue(content); + + writer.Flush(); + Assert.AreEqual("null", System.Text.Encoding.UTF8.GetString(stream.ToArray())); + } + + internal class TestSerialize : IUtf8JsonSerializable + { + internal bool didWrite = false; + + public void Write(Utf8JsonWriter writer) + { + didWrite = true; + } + } + } +} diff --git a/sdk/core/Azure.Core/tests/XMLWriterContentTests.cs b/sdk/core/Azure.Core/tests/XMLWriterContentTests.cs new file mode 100644 index 000000000000..199a03e739e5 --- /dev/null +++ b/sdk/core/Azure.Core/tests/XMLWriterContentTests.cs @@ -0,0 +1,21 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.IO; +using System.Text.Json; +using NUnit.Framework; + +namespace Azure.Core.Tests +{ + public class XMLWriterContentTests + { + [Test] + public void DisposeDoesNotThrow() + { + var writer = new XmlWriterContent(); + writer.Dispose(); + } + } +}