diff --git a/sdk/core/System.ClientModel/CHANGELOG.md b/sdk/core/System.ClientModel/CHANGELOG.md index f80d1796f748..8419e820c207 100644 --- a/sdk/core/System.ClientModel/CHANGELOG.md +++ b/sdk/core/System.ClientModel/CHANGELOG.md @@ -1,11 +1,12 @@ # Release History -## 1.1.0 (2024-09-12) +## 1.1.0 (2024-09-17) ### Other Changes - Removed implicit cast from `string` to `ApiKeyCredential` ([#45554](https://github.com/Azure/azure-sdk-for-net/pull/45554)). - Upgraded `System.Text.Json` package dependency to 6.0.9 ([#45416](https://github.com/Azure/azure-sdk-for-net/pull/45416)). +- Removed `PageCollection` and related types in favor of using `CollectionResult` and related types as the return values from paginated service endpoints ([#45961](https://github.com/Azure/azure-sdk-for-net/pull/45961)). ## 1.1.0-beta.7 (2024-08-14) diff --git a/sdk/core/System.ClientModel/api/System.ClientModel.net6.0.cs b/sdk/core/System.ClientModel/api/System.ClientModel.net6.0.cs index a510e31d07d7..675d7c513786 100644 --- a/sdk/core/System.ClientModel/api/System.ClientModel.net6.0.cs +++ b/sdk/core/System.ClientModel/api/System.ClientModel.net6.0.cs @@ -6,20 +6,11 @@ public ApiKeyCredential(string key) { } public void Deconstruct(out string key) { throw null; } public void Update(string key) { } } - public abstract partial class AsyncCollectionResult : System.ClientModel.ClientResult, System.Collections.Generic.IAsyncEnumerable + public abstract partial class AsyncCollectionResult : System.ClientModel.Primitives.AsyncCollectionResult, System.Collections.Generic.IAsyncEnumerable { protected internal AsyncCollectionResult() { } - protected internal AsyncCollectionResult(System.ClientModel.Primitives.PipelineResponse response) { } - public abstract System.Collections.Generic.IAsyncEnumerator GetAsyncEnumerator(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)); - } - public abstract partial class AsyncPageCollection : System.Collections.Generic.IAsyncEnumerable> - { - protected AsyncPageCollection() { } - public System.Collections.Generic.IAsyncEnumerable GetAllValuesAsync([System.Runtime.CompilerServices.EnumeratorCancellationAttribute] System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } - protected abstract System.Collections.Generic.IAsyncEnumerator> GetAsyncEnumeratorCore(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)); - public System.Threading.Tasks.Task> GetCurrentPageAsync() { throw null; } - protected abstract System.Threading.Tasks.Task> GetCurrentPageAsyncCore(); - System.Collections.Generic.IAsyncEnumerator> System.Collections.Generic.IAsyncEnumerable>.GetAsyncEnumerator(System.Threading.CancellationToken cancellationToken) { throw null; } + public System.Collections.Generic.IAsyncEnumerator GetAsyncEnumerator(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + protected abstract System.Collections.Generic.IAsyncEnumerable GetValuesFromPageAsync(System.ClientModel.ClientResult page); } public abstract partial class BinaryContent : System.IDisposable { @@ -34,13 +25,11 @@ protected BinaryContent() { } } public partial class ClientResult { - protected ClientResult() { } protected ClientResult(System.ClientModel.Primitives.PipelineResponse response) { } public static System.ClientModel.ClientResult FromOptionalValue(T? value, System.ClientModel.Primitives.PipelineResponse response) { throw null; } public static System.ClientModel.ClientResult FromResponse(System.ClientModel.Primitives.PipelineResponse response) { throw null; } public static System.ClientModel.ClientResult FromValue(T value, System.ClientModel.Primitives.PipelineResponse response) { throw null; } public System.ClientModel.Primitives.PipelineResponse GetRawResponse() { throw null; } - protected void SetRawResponse(System.ClientModel.Primitives.PipelineResponse response) { } } public partial class ClientResultException : System.Exception { @@ -52,15 +41,15 @@ public ClientResultException(string message, System.ClientModel.Primitives.Pipel } public partial class ClientResult : System.ClientModel.ClientResult { - protected internal ClientResult(T value, System.ClientModel.Primitives.PipelineResponse response) { } + protected internal ClientResult(T value, System.ClientModel.Primitives.PipelineResponse response) : base (default(System.ClientModel.Primitives.PipelineResponse)) { } public virtual T Value { get { throw null; } } public static implicit operator T (System.ClientModel.ClientResult result) { throw null; } } - public abstract partial class CollectionResult : System.ClientModel.ClientResult, System.Collections.Generic.IEnumerable, System.Collections.IEnumerable + public abstract partial class CollectionResult : System.ClientModel.Primitives.CollectionResult, System.Collections.Generic.IEnumerable, System.Collections.IEnumerable { protected internal CollectionResult() { } - protected internal CollectionResult(System.ClientModel.Primitives.PipelineResponse response) { } - public abstract System.Collections.Generic.IEnumerator GetEnumerator(); + public System.Collections.Generic.IEnumerator GetEnumerator() { throw null; } + protected abstract System.Collections.Generic.IEnumerable GetValuesFromPage(System.ClientModel.ClientResult page); System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator() { throw null; } } public partial class ContinuationToken @@ -70,24 +59,6 @@ protected ContinuationToken(System.BinaryData bytes) { } public static System.ClientModel.ContinuationToken FromBytes(System.BinaryData bytes) { throw null; } public virtual System.BinaryData ToBytes() { throw null; } } - public abstract partial class PageCollection : System.Collections.Generic.IEnumerable>, System.Collections.IEnumerable - { - protected PageCollection() { } - public System.Collections.Generic.IEnumerable GetAllValues() { throw null; } - public System.ClientModel.PageResult GetCurrentPage() { throw null; } - protected abstract System.ClientModel.PageResult GetCurrentPageCore(); - protected abstract System.Collections.Generic.IEnumerator> GetEnumeratorCore(); - System.Collections.Generic.IEnumerator> System.Collections.Generic.IEnumerable>.GetEnumerator() { throw null; } - System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator() { throw null; } - } - public partial class PageResult : System.ClientModel.ClientResult - { - internal PageResult() { } - public System.ClientModel.ContinuationToken? NextPageToken { get { throw null; } } - public System.ClientModel.ContinuationToken PageToken { get { throw null; } } - public System.Collections.Generic.IReadOnlyList Values { get { throw null; } } - public static System.ClientModel.PageResult Create(System.Collections.Generic.IReadOnlyList values, System.ClientModel.ContinuationToken pageToken, System.ClientModel.ContinuationToken? nextPageToken, System.ClientModel.Primitives.PipelineResponse response) { throw null; } - } } namespace System.ClientModel.Primitives { @@ -100,6 +71,12 @@ internal ApiKeyAuthenticationPolicy() { } public sealed override void Process(System.ClientModel.Primitives.PipelineMessage message, System.Collections.Generic.IReadOnlyList pipeline, int currentIndex) { } public sealed override System.Threading.Tasks.ValueTask ProcessAsync(System.ClientModel.Primitives.PipelineMessage message, System.Collections.Generic.IReadOnlyList pipeline, int currentIndex) { throw null; } } + public abstract partial class AsyncCollectionResult + { + protected AsyncCollectionResult() { } + public abstract System.ClientModel.ContinuationToken? GetContinuationToken(System.ClientModel.ClientResult page); + public abstract System.Collections.Generic.IAsyncEnumerable GetRawPagesAsync(); + } [System.FlagsAttribute] public enum ClientErrorBehaviors { @@ -142,6 +119,12 @@ public sealed override void Process(System.ClientModel.Primitives.PipelineMessag protected virtual void Wait(System.TimeSpan time, System.Threading.CancellationToken cancellationToken) { } protected virtual System.Threading.Tasks.Task WaitAsync(System.TimeSpan time, System.Threading.CancellationToken cancellationToken) { throw null; } } + public abstract partial class CollectionResult + { + protected CollectionResult() { } + public abstract System.ClientModel.ContinuationToken? GetContinuationToken(System.ClientModel.ClientResult page); + public abstract System.Collections.Generic.IEnumerable GetRawPages(); + } public partial class HttpClientPipelineTransport : System.ClientModel.Primitives.PipelineTransport, System.IDisposable { public HttpClientPipelineTransport() { } @@ -189,11 +172,13 @@ public ModelReaderWriterOptions(string format) { } public static System.ClientModel.Primitives.ModelReaderWriterOptions Json { get { throw null; } } public static System.ClientModel.Primitives.ModelReaderWriterOptions Xml { get { throw null; } } } - public abstract partial class OperationResult : System.ClientModel.ClientResult + public abstract partial class OperationResult { protected OperationResult(System.ClientModel.Primitives.PipelineResponse response) { } public bool HasCompleted { get { throw null; } protected set { } } public abstract System.ClientModel.ContinuationToken? RehydrationToken { get; protected set; } + public System.ClientModel.Primitives.PipelineResponse GetRawResponse() { throw null; } + protected void SetRawResponse(System.ClientModel.Primitives.PipelineResponse response) { } public abstract System.ClientModel.ClientResult UpdateStatus(System.ClientModel.Primitives.RequestOptions? options = null); public abstract System.Threading.Tasks.ValueTask UpdateStatusAsync(System.ClientModel.Primitives.RequestOptions? options = null); public virtual void WaitForCompletion(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { } diff --git a/sdk/core/System.ClientModel/api/System.ClientModel.netstandard2.0.cs b/sdk/core/System.ClientModel/api/System.ClientModel.netstandard2.0.cs index ad3a7a378fc0..ce85c5d8b83d 100644 --- a/sdk/core/System.ClientModel/api/System.ClientModel.netstandard2.0.cs +++ b/sdk/core/System.ClientModel/api/System.ClientModel.netstandard2.0.cs @@ -6,20 +6,11 @@ public ApiKeyCredential(string key) { } public void Deconstruct(out string key) { throw null; } public void Update(string key) { } } - public abstract partial class AsyncCollectionResult : System.ClientModel.ClientResult, System.Collections.Generic.IAsyncEnumerable + public abstract partial class AsyncCollectionResult : System.ClientModel.Primitives.AsyncCollectionResult, System.Collections.Generic.IAsyncEnumerable { protected internal AsyncCollectionResult() { } - protected internal AsyncCollectionResult(System.ClientModel.Primitives.PipelineResponse response) { } - public abstract System.Collections.Generic.IAsyncEnumerator GetAsyncEnumerator(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)); - } - public abstract partial class AsyncPageCollection : System.Collections.Generic.IAsyncEnumerable> - { - protected AsyncPageCollection() { } - public System.Collections.Generic.IAsyncEnumerable GetAllValuesAsync([System.Runtime.CompilerServices.EnumeratorCancellationAttribute] System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } - protected abstract System.Collections.Generic.IAsyncEnumerator> GetAsyncEnumeratorCore(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)); - public System.Threading.Tasks.Task> GetCurrentPageAsync() { throw null; } - protected abstract System.Threading.Tasks.Task> GetCurrentPageAsyncCore(); - System.Collections.Generic.IAsyncEnumerator> System.Collections.Generic.IAsyncEnumerable>.GetAsyncEnumerator(System.Threading.CancellationToken cancellationToken) { throw null; } + public System.Collections.Generic.IAsyncEnumerator GetAsyncEnumerator(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + protected abstract System.Collections.Generic.IAsyncEnumerable GetValuesFromPageAsync(System.ClientModel.ClientResult page); } public abstract partial class BinaryContent : System.IDisposable { @@ -34,13 +25,11 @@ protected BinaryContent() { } } public partial class ClientResult { - protected ClientResult() { } protected ClientResult(System.ClientModel.Primitives.PipelineResponse response) { } public static System.ClientModel.ClientResult FromOptionalValue(T? value, System.ClientModel.Primitives.PipelineResponse response) { throw null; } public static System.ClientModel.ClientResult FromResponse(System.ClientModel.Primitives.PipelineResponse response) { throw null; } public static System.ClientModel.ClientResult FromValue(T value, System.ClientModel.Primitives.PipelineResponse response) { throw null; } public System.ClientModel.Primitives.PipelineResponse GetRawResponse() { throw null; } - protected void SetRawResponse(System.ClientModel.Primitives.PipelineResponse response) { } } public partial class ClientResultException : System.Exception { @@ -52,15 +41,15 @@ public ClientResultException(string message, System.ClientModel.Primitives.Pipel } public partial class ClientResult : System.ClientModel.ClientResult { - protected internal ClientResult(T value, System.ClientModel.Primitives.PipelineResponse response) { } + protected internal ClientResult(T value, System.ClientModel.Primitives.PipelineResponse response) : base (default(System.ClientModel.Primitives.PipelineResponse)) { } public virtual T Value { get { throw null; } } public static implicit operator T (System.ClientModel.ClientResult result) { throw null; } } - public abstract partial class CollectionResult : System.ClientModel.ClientResult, System.Collections.Generic.IEnumerable, System.Collections.IEnumerable + public abstract partial class CollectionResult : System.ClientModel.Primitives.CollectionResult, System.Collections.Generic.IEnumerable, System.Collections.IEnumerable { protected internal CollectionResult() { } - protected internal CollectionResult(System.ClientModel.Primitives.PipelineResponse response) { } - public abstract System.Collections.Generic.IEnumerator GetEnumerator(); + public System.Collections.Generic.IEnumerator GetEnumerator() { throw null; } + protected abstract System.Collections.Generic.IEnumerable GetValuesFromPage(System.ClientModel.ClientResult page); System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator() { throw null; } } public partial class ContinuationToken @@ -70,24 +59,6 @@ protected ContinuationToken(System.BinaryData bytes) { } public static System.ClientModel.ContinuationToken FromBytes(System.BinaryData bytes) { throw null; } public virtual System.BinaryData ToBytes() { throw null; } } - public abstract partial class PageCollection : System.Collections.Generic.IEnumerable>, System.Collections.IEnumerable - { - protected PageCollection() { } - public System.Collections.Generic.IEnumerable GetAllValues() { throw null; } - public System.ClientModel.PageResult GetCurrentPage() { throw null; } - protected abstract System.ClientModel.PageResult GetCurrentPageCore(); - protected abstract System.Collections.Generic.IEnumerator> GetEnumeratorCore(); - System.Collections.Generic.IEnumerator> System.Collections.Generic.IEnumerable>.GetEnumerator() { throw null; } - System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator() { throw null; } - } - public partial class PageResult : System.ClientModel.ClientResult - { - internal PageResult() { } - public System.ClientModel.ContinuationToken? NextPageToken { get { throw null; } } - public System.ClientModel.ContinuationToken PageToken { get { throw null; } } - public System.Collections.Generic.IReadOnlyList Values { get { throw null; } } - public static System.ClientModel.PageResult Create(System.Collections.Generic.IReadOnlyList values, System.ClientModel.ContinuationToken pageToken, System.ClientModel.ContinuationToken? nextPageToken, System.ClientModel.Primitives.PipelineResponse response) { throw null; } - } } namespace System.ClientModel.Primitives { @@ -100,6 +71,12 @@ internal ApiKeyAuthenticationPolicy() { } public sealed override void Process(System.ClientModel.Primitives.PipelineMessage message, System.Collections.Generic.IReadOnlyList pipeline, int currentIndex) { } public sealed override System.Threading.Tasks.ValueTask ProcessAsync(System.ClientModel.Primitives.PipelineMessage message, System.Collections.Generic.IReadOnlyList pipeline, int currentIndex) { throw null; } } + public abstract partial class AsyncCollectionResult + { + protected AsyncCollectionResult() { } + public abstract System.ClientModel.ContinuationToken? GetContinuationToken(System.ClientModel.ClientResult page); + public abstract System.Collections.Generic.IAsyncEnumerable GetRawPagesAsync(); + } [System.FlagsAttribute] public enum ClientErrorBehaviors { @@ -142,6 +119,12 @@ public sealed override void Process(System.ClientModel.Primitives.PipelineMessag protected virtual void Wait(System.TimeSpan time, System.Threading.CancellationToken cancellationToken) { } protected virtual System.Threading.Tasks.Task WaitAsync(System.TimeSpan time, System.Threading.CancellationToken cancellationToken) { throw null; } } + public abstract partial class CollectionResult + { + protected CollectionResult() { } + public abstract System.ClientModel.ContinuationToken? GetContinuationToken(System.ClientModel.ClientResult page); + public abstract System.Collections.Generic.IEnumerable GetRawPages(); + } public partial class HttpClientPipelineTransport : System.ClientModel.Primitives.PipelineTransport, System.IDisposable { public HttpClientPipelineTransport() { } @@ -188,11 +171,13 @@ public ModelReaderWriterOptions(string format) { } public static System.ClientModel.Primitives.ModelReaderWriterOptions Json { get { throw null; } } public static System.ClientModel.Primitives.ModelReaderWriterOptions Xml { get { throw null; } } } - public abstract partial class OperationResult : System.ClientModel.ClientResult + public abstract partial class OperationResult { protected OperationResult(System.ClientModel.Primitives.PipelineResponse response) { } public bool HasCompleted { get { throw null; } protected set { } } public abstract System.ClientModel.ContinuationToken? RehydrationToken { get; protected set; } + public System.ClientModel.Primitives.PipelineResponse GetRawResponse() { throw null; } + protected void SetRawResponse(System.ClientModel.Primitives.PipelineResponse response) { } public abstract System.ClientModel.ClientResult UpdateStatus(System.ClientModel.Primitives.RequestOptions? options = null); public abstract System.Threading.Tasks.ValueTask UpdateStatusAsync(System.ClientModel.Primitives.RequestOptions? options = null); public virtual void WaitForCompletion(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { } diff --git a/sdk/core/System.ClientModel/src/Convenience/AsyncCollectionResult.cs b/sdk/core/System.ClientModel/src/Convenience/AsyncCollectionResult.cs new file mode 100644 index 000000000000..602d27ea3806 --- /dev/null +++ b/sdk/core/System.ClientModel/src/Convenience/AsyncCollectionResult.cs @@ -0,0 +1,52 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.Collections.Generic; +using System.Threading; + +namespace System.ClientModel.Primitives; + +/// +/// Represents a collection of values returned from a cloud service operation. +/// The collection values may be delivered over one or more service responses. +/// +public abstract class AsyncCollectionResult +{ + /// + /// Creates a new instance of . + /// + protected AsyncCollectionResult() + { + } + + /// + /// Gets the collection of page responses that contain the items in this + /// collection. + /// + /// A collection of service responses where each + /// holds a subset of items in the full + /// collection. + /// + /// This method does not take a + /// parameter. implementations must + /// store the passed to the service method + /// that creates them and pass that token to any async methods + /// called from this method. For protocol methods, this + /// will come from the + /// property. + public abstract IAsyncEnumerable GetRawPagesAsync(); + + /// + /// Gets a that can be passed to a client + /// method to obtain a collection holding the items remaining in this + /// . + /// + /// The raw page to obtain a continuation token for. + /// + /// A that a client can use to + /// obtain an whose items start at the + /// first item after the last item in , or + /// null if is the last page in the sequence + /// of page responses delivering the items in the collection. + public abstract ContinuationToken? GetContinuationToken(ClientResult page); +} diff --git a/sdk/core/System.ClientModel/src/Convenience/AsyncCollectionResultOfT.cs b/sdk/core/System.ClientModel/src/Convenience/AsyncCollectionResultOfT.cs index cada18b25ba4..9e84b2dd772e 100644 --- a/sdk/core/System.ClientModel/src/Convenience/AsyncCollectionResultOfT.cs +++ b/sdk/core/System.ClientModel/src/Convenience/AsyncCollectionResultOfT.cs @@ -4,41 +4,46 @@ using System.ClientModel.Primitives; using System.Collections.Generic; using System.Threading; +using System.Threading.Tasks; namespace System.ClientModel; /// /// Represents a collection of values returned from a cloud service operation. -/// The collection values may be returned by one or more service responses. +/// The collection values may be delivered over one or more service responses. /// -public abstract class AsyncCollectionResult : ClientResult, IAsyncEnumerable +public abstract class AsyncCollectionResult : AsyncCollectionResult, IAsyncEnumerable { /// - /// Create a new instance of . + /// Creates a new instance of . /// - /// If no is provided when the - /// instance is created, it is expected that - /// a derived type will call - /// prior to a user calling . - /// This constructor is indended for use by collection implementations that - /// postpone sending a request until - /// is called. Such implementations will typically be returned from client - /// convenience methods so that callers of the methods don't need to - /// dispose the return value. - protected internal AsyncCollectionResult() : base() + protected internal AsyncCollectionResult() { } - /// - /// Create a new instance of . - /// - /// The holding the - /// items in the collection, or the first set of the items in the collection. - /// - protected internal AsyncCollectionResult(PipelineResponse response) : base(response) + /// + public async IAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) { + await foreach (ClientResult page in GetRawPagesAsync().ConfigureAwait(false).WithCancellation(cancellationToken)) + { + await foreach (T value in GetValuesFromPageAsync(page).ConfigureAwait(false).WithCancellation(cancellationToken)) + { + yield return value; + } + } } - /// - public abstract IAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default); + /// + /// Gets a collection of the values returned in a page response. + /// + /// The service response to obtain the values from. + /// + /// A collection of values read from the + ///response content in . + /// This method does not take a + /// parameter. implementations must + /// store the passed to the service method + /// that creates them and pass that token to any async methods + /// called from this method. + protected abstract IAsyncEnumerable GetValuesFromPageAsync(ClientResult page); } diff --git a/sdk/core/System.ClientModel/src/Convenience/AsyncPageCollectionOfT.cs b/sdk/core/System.ClientModel/src/Convenience/AsyncPageCollectionOfT.cs deleted file mode 100644 index 7786b60d70f6..000000000000 --- a/sdk/core/System.ClientModel/src/Convenience/AsyncPageCollectionOfT.cs +++ /dev/null @@ -1,75 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -using System.Collections.Generic; -using System.Runtime.CompilerServices; -using System.Threading; -using System.Threading.Tasks; - -namespace System.ClientModel; - -/// -/// An asynchronous collection of page results returned by a cloud service. -/// Cloud services use pagination to return a collection of items over multiple -/// responses. Each response from the service returns a page of items in the -/// collection, as well as the information needed to obtain the next page of -/// items, until all the items in the requested collection have been returned. -/// To enumerate the items in the collection, instead of the pages in the -/// collection, call . To get the current -/// collection page, call . -/// -public abstract class AsyncPageCollection : IAsyncEnumerable> -{ - /// - /// Create a new instance of . - /// - protected AsyncPageCollection() : base() - { - // Note that page collections delay making a first request until either - // GetCurrentPageAsync is called or the collection returned by - // GetAllValuesAsync is enumerated, so this constructor calls the base - // class constructor that does not take a PipelineResponse. - } - - /// - /// Get the current page of the collection. - /// - /// The current page in the collection. - public async Task> GetCurrentPageAsync() - => await GetCurrentPageAsyncCore().ConfigureAwait(false); - - /// - /// Get a collection of all the values in the collection requested from the - /// cloud service, rather than the pages of values. - /// - /// The values requested from the cloud service. - public async IAsyncEnumerable GetAllValuesAsync([EnumeratorCancellation] CancellationToken cancellationToken = default) - { - await foreach (PageResult page in this.WithCancellation(cancellationToken).ConfigureAwait(false)) - { - foreach (T value in page.Values) - { - cancellationToken.ThrowIfCancellationRequested(); - - yield return value; - } - } - } - - /// - /// Get the current page of the collection. - /// - /// The current page in the collection. - protected abstract Task> GetCurrentPageAsyncCore(); - - /// - /// Get an async enumerator that can enumerate the pages of values returned - /// by the cloud service. - /// - /// An async enumerator of pages holding the items in the value - /// collection. - protected abstract IAsyncEnumerator> GetAsyncEnumeratorCore(CancellationToken cancellationToken = default); - - IAsyncEnumerator> IAsyncEnumerable>.GetAsyncEnumerator(CancellationToken cancellationToken) - => GetAsyncEnumeratorCore(cancellationToken); -} diff --git a/sdk/core/System.ClientModel/src/Convenience/ClientResult.cs b/sdk/core/System.ClientModel/src/Convenience/ClientResult.cs index 77e7b3cbf94d..c31c9175f8e1 100644 --- a/sdk/core/System.ClientModel/src/Convenience/ClientResult.cs +++ b/sdk/core/System.ClientModel/src/Convenience/ClientResult.cs @@ -11,21 +11,10 @@ namespace System.ClientModel; /// public class ClientResult { - private PipelineResponse? _response; + private readonly PipelineResponse _response; /// - /// Create a new instance of . - /// - /// If no is provided when the - /// instance is created, it is expected that - /// a derived type will call - /// prior to a user calling . - protected ClientResult() - { - } - - /// - /// Create a new instance of from a service + /// Creates a new instance of from a service /// response. /// /// The received @@ -40,41 +29,9 @@ protected ClientResult(PipelineResponse response) /// /// Gets the received from the service. /// - /// the received from the service. + /// The received from the service. /// - /// No - /// value is currently available for this - /// instance. This can happen when the instance - /// is a collection type like - /// that has not yet been enumerated. - public PipelineResponse GetRawResponse() - { - if (_response is null) - { - throw new InvalidOperationException("No response is associated " + - "with this result. If the result is a collection result " + - "type, this may be because no request has been sent to the " + - "server yet."); - } - - return _response; - } - - /// - /// Update the value returned from . - /// - /// This method may be called from types derived from - /// that poll the service for status updates - /// or to retrieve additional collection values to update the raw response - /// to the response most recently returned from the service. - /// The to return - /// from . - protected void SetRawResponse(PipelineResponse response) - { - Argument.AssertNotNull(response, nameof(response)); - - _response = response; - } + public PipelineResponse GetRawResponse() => _response; #region Factory methods for ClientResult and subtypes @@ -112,7 +69,7 @@ public static ClientResult FromValue(T value, PipelineResponse response) if (value is null) { string message = "ClientResult contract guarantees that ClientResult.Value is non-null. " + - "If you need to return a ClientResult where the Value is null, please use call ClientResult.FromOptionalValue instead."; + "If you need to return a ClientResult where the Value is null, please use ClientResult.FromOptionalValue instead."; throw new ArgumentNullException(nameof(value), message); } diff --git a/sdk/core/System.ClientModel/src/Convenience/CollectionResult.cs b/sdk/core/System.ClientModel/src/Convenience/CollectionResult.cs new file mode 100644 index 000000000000..3eb1aa3e2629 --- /dev/null +++ b/sdk/core/System.ClientModel/src/Convenience/CollectionResult.cs @@ -0,0 +1,50 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.Collections.Generic; +using System.Threading; + +namespace System.ClientModel.Primitives; + +/// +/// Represents a collection of values returned from a cloud service operation. +/// The collection values may be delivered over one or more service responses. +/// +public abstract class CollectionResult +{ /// + /// Creates a new instance of . + /// + protected CollectionResult() + { + } + + /// + /// Gets the collection of page responses that contain the items in this + /// collection. + /// + /// A collection of service responses where each + /// holds a subset of items in the full + /// collection. + /// + /// implementations are expected + /// to store the passed to the service + /// method that creates them and pass that token to any methods making + /// service calls that are called from this method. For protocol methods, + /// this will come from the + /// property. + public abstract IEnumerable GetRawPages(); + + /// + /// Gets a that can be passed to a client + /// method to obtain a collection holding the remaining items in this + /// . + /// + /// The raw page to obtain a continuation token for. + /// + /// A that a client can use to + /// obtain an whose items start at the + /// first item after the last item in , or + /// null if is the last page in the sequence + /// of page responses delivering the items in the collection. + public abstract ContinuationToken? GetContinuationToken(ClientResult page); +} diff --git a/sdk/core/System.ClientModel/src/Convenience/CollectionResultOfT.cs b/sdk/core/System.ClientModel/src/Convenience/CollectionResultOfT.cs index bf51575ef9d2..9289985159aa 100644 --- a/sdk/core/System.ClientModel/src/Convenience/CollectionResultOfT.cs +++ b/sdk/core/System.ClientModel/src/Convenience/CollectionResultOfT.cs @@ -4,43 +4,48 @@ using System.ClientModel.Primitives; using System.Collections; using System.Collections.Generic; +using System.Threading; namespace System.ClientModel; /// /// Represents a collection of values returned from a cloud service operation. -/// The collection values may be returned by one or more service responses. +/// The collection values may be delivered over one or more service responses. /// -public abstract class CollectionResult : ClientResult, IEnumerable +public abstract class CollectionResult : CollectionResult, IEnumerable { /// - /// Create a new instance of . + /// Creates a new instance of . /// - /// If no is provided when the - /// instance is created, it is expected that - /// a derived type will call - /// prior to a user calling . - /// This constructor is indended for use by collection implementations that - /// postpone sending a request until - /// is called. Such implementations will typically be returned from client - /// convenience methods so that callers of the methods don't need to - /// dispose the return value. - protected internal CollectionResult() : base() + protected internal CollectionResult() { } - /// - /// Create a new instance of . - /// - /// The holding the - /// items in the collection, or the first set of the items in the collection. - /// - protected internal CollectionResult(PipelineResponse response) : base(response) + /// + public IEnumerator GetEnumerator() { + foreach (ClientResult page in GetRawPages()) + { + foreach (T value in GetValuesFromPage(page)) + { + yield return value; + } + } } - /// - public abstract IEnumerator GetEnumerator(); + /// + /// Gets a collection of the values returned in a page response. + /// + /// The service response to obtain the values from. + /// + /// A collection of values read from the + ///response content in . + /// implementations are expected + /// to store the passed to the service + /// method that creates them and pass that token to any methods making + /// service calls that are called from this method. + protected abstract IEnumerable GetValuesFromPage(ClientResult page); IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); } +#pragma warning restore CS1591 // public XML comments diff --git a/sdk/core/System.ClientModel/src/Convenience/OperationResult.cs b/sdk/core/System.ClientModel/src/Convenience/OperationResult.cs index facf771dc803..93388d3c35fd 100644 --- a/sdk/core/System.ClientModel/src/Convenience/OperationResult.cs +++ b/sdk/core/System.ClientModel/src/Convenience/OperationResult.cs @@ -16,19 +16,18 @@ namespace System.ClientModel.Primitives; /// properties such as Value or Status as applicable /// for a given service operation. /// -public abstract class OperationResult : ClientResult +public abstract class OperationResult { + private PipelineResponse _response; + /// - /// Create a new instance of . + /// Creates a new instance of . /// /// The received from /// the service in response to the request that started the operation. - /// Derived types will call - /// when a new - /// response is received that updates the status of the operation. protected OperationResult(PipelineResponse response) - : base(response) { + _response = response; } /// @@ -54,6 +53,9 @@ protected OperationResult(PipelineResponse response) /// A token that can be used to rehydrate the operation, for example /// to monitor its progress or to obtain its final result, from a process /// different than the one that started the operation. + /// This property is abstract so that derived types that do not + /// support rehydration can return null without using a backing field for + /// an unused . public abstract ContinuationToken? RehydrationToken { get; protected set; } /// @@ -159,4 +161,25 @@ public virtual void WaitForCompletion(CancellationToken cancellationToken = defa SetRawResponse(result.GetRawResponse()); } } + + /// + /// Gets the corresponding to the most + /// recent update received from the service. + /// + /// The most recent received + /// from the service. + /// + public PipelineResponse GetRawResponse() => _response; + + /// + /// Update the value returned from . + /// + /// The to return + /// from . + protected void SetRawResponse(PipelineResponse response) + { + Argument.AssertNotNull(response, nameof(response)); + + _response = response; + } } diff --git a/sdk/core/System.ClientModel/src/Convenience/PageCollectionOfT.cs b/sdk/core/System.ClientModel/src/Convenience/PageCollectionOfT.cs deleted file mode 100644 index 6ee7e1720162..000000000000 --- a/sdk/core/System.ClientModel/src/Convenience/PageCollectionOfT.cs +++ /dev/null @@ -1,73 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -using System.Collections; -using System.Collections.Generic; - -namespace System.ClientModel; - -/// -/// A collection of page results returned by a cloud service. Cloud services -/// use pagination to return a collection of items over multiple responses. -/// Each response from the service returns a page of items in the collection, -/// as well as the information needed to obtain the next page of items, until -/// all the items in the requested collection have been returned. To enumerate -/// the items in the collection, instead of the pages in the collection, call -/// . To get the current collection page, call -/// . -/// -public abstract class PageCollection : IEnumerable> -{ - /// - /// Create a new instance of . - /// - protected PageCollection() : base() - { - // Note that page collections delay making a first request until either - // GetCurrentPage is called or the collection returned by GetAllValues - // is enumerated, so this constructor calls the base class constructor - // that does not take a PipelineResponse. - } - - /// - /// Get the current page of the collection. - /// - /// The current page in the collection. - public PageResult GetCurrentPage() - => GetCurrentPageCore(); - - /// - /// Get a collection of all the values in the collection requested from the - /// cloud service, rather than the pages of values. - /// - /// The values requested from the cloud service. - public IEnumerable GetAllValues() - { - foreach (PageResult page in this) - { - foreach (T value in page.Values) - { - yield return value; - } - } - } - - /// - /// Get the current page of the collection. - /// - /// The current page in the collection. - protected abstract PageResult GetCurrentPageCore(); - - /// - /// Get an enumerator that can enumerate the pages of values returned by - /// the cloud service. - /// - /// An enumerator of pages holding the items in the value - /// collection. - protected abstract IEnumerator> GetEnumeratorCore(); - - IEnumerator> IEnumerable>.GetEnumerator() - => GetEnumeratorCore(); - - IEnumerator IEnumerable.GetEnumerator() => ((IEnumerable>)this).GetEnumerator(); -} diff --git a/sdk/core/System.ClientModel/src/Convenience/PageResultOfT.cs b/sdk/core/System.ClientModel/src/Convenience/PageResultOfT.cs deleted file mode 100644 index c8aa4ca3b77d..000000000000 --- a/sdk/core/System.ClientModel/src/Convenience/PageResultOfT.cs +++ /dev/null @@ -1,69 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -using System.ClientModel.Internal; -using System.ClientModel.Primitives; -using System.Collections.Generic; - -namespace System.ClientModel; - -/// -/// A page of values returned from a cloud service in response to a request for -/// a collection from a paginated endpoint. When used with -/// or , -/// returns a subset of a complete collection of values returned from a service. -/// Each represents the values in a single service -/// response. -/// -public class PageResult : ClientResult -{ - private PageResult(IReadOnlyList values, - ContinuationToken pageToken, - ContinuationToken? nextPageToken, - PipelineResponse response) : base(response) - { - Argument.AssertNotNull(values, nameof(values)); - Argument.AssertNotNull(pageToken, nameof(pageToken)); - - Values = values; - PageToken = pageToken; - NextPageToken = nextPageToken; - } - - /// - /// Gets the values in this . - /// - public IReadOnlyList Values { get; } - - /// - /// Gets a token that can be passed to a client method to obtain a page - /// collection that begins with this page of values. - /// - /// for more details. - public ContinuationToken PageToken { get; } - - /// - /// Gets a token that can be passed to a client method to obtain a page - /// collection that begins with the page of values after this page. If - /// is null, the current page is the last page - /// in the page collection. - /// - /// for more details. - public ContinuationToken? NextPageToken { get; } - - /// - /// Create a from the provided parameters. - /// - /// The values in the . - /// - /// A token that can be used to request a collection - /// beginning with this page of values. - /// A token that can be used to request a - /// collection beginning with the next page of values. - /// The response that returned the values in the - /// page. - /// A holding the provided values. - /// - public static PageResult Create(IReadOnlyList values, ContinuationToken pageToken, ContinuationToken? nextPageToken, PipelineResponse response) - => new(values, pageToken, nextPageToken, response); -} diff --git a/sdk/core/System.ClientModel/src/Internal/SSE/AsyncServerSentEventEnumerable.cs b/sdk/core/System.ClientModel/src/Internal/SSE/AsyncServerSentEventEnumerable.cs deleted file mode 100644 index e7efca1805b1..000000000000 --- a/sdk/core/System.ClientModel/src/Internal/SSE/AsyncServerSentEventEnumerable.cs +++ /dev/null @@ -1,82 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -using System.Collections.Generic; -using System.IO; -using System.Threading; -using System.Threading.Tasks; - -namespace System.ClientModel.Internal; - -/// -/// Represents a collection of SSE events that can be enumerated as a C# async stream. -/// -internal class AsyncServerSentEventEnumerable : IAsyncEnumerable -{ - private readonly Stream _contentStream; - - public AsyncServerSentEventEnumerable(Stream contentStream) - { - Argument.AssertNotNull(contentStream, nameof(contentStream)); - - _contentStream = contentStream; - - LastEventId = string.Empty; - ReconnectionInterval = Timeout.InfiniteTimeSpan; - } - - public string LastEventId { get; private set; } - - public TimeSpan ReconnectionInterval { get; private set; } - - public IAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) - { - return new AsyncServerSentEventEnumerator(_contentStream, this, cancellationToken); - } - - private sealed class AsyncServerSentEventEnumerator : IAsyncEnumerator - { - private readonly ServerSentEventReader _reader; - private readonly AsyncServerSentEventEnumerable _enumerable; - private readonly CancellationToken _cancellationToken; - - public ServerSentEvent Current { get; private set; } - - public AsyncServerSentEventEnumerator(Stream contentStream, - AsyncServerSentEventEnumerable enumerable, - CancellationToken cancellationToken = default) - { - _reader = new(contentStream); - _enumerable = enumerable; - _cancellationToken = cancellationToken; - } - - public async ValueTask MoveNextAsync() - { - ServerSentEvent? nextEvent = await _reader.TryGetNextEventAsync(_cancellationToken).ConfigureAwait(false); - _enumerable.LastEventId = _reader.LastEventId; - _enumerable.ReconnectionInterval = _reader.ReconnectionInterval; - - if (nextEvent.HasValue) - { - Current = nextEvent.Value; - return true; - } - - Current = default; - return false; - } - - public ValueTask DisposeAsync() - { - // The creator of the enumerable has responsibility for disposing - // the content stream passed to the enumerable constructor. - -#if NET6_0_OR_GREATER - return ValueTask.CompletedTask; -#else - return new ValueTask(); -#endif - } - } -} diff --git a/sdk/core/System.ClientModel/src/Internal/SSE/ServerSentEvent.cs b/sdk/core/System.ClientModel/src/Internal/SSE/ServerSentEvent.cs deleted file mode 100644 index f962dd2bac4b..000000000000 --- a/sdk/core/System.ClientModel/src/Internal/SSE/ServerSentEvent.cs +++ /dev/null @@ -1,25 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -namespace System.ClientModel.Internal; - -/// -/// Represents an SSE event. -/// See SSE specification: https://html.spec.whatwg.org/multipage/server-sent-events.html -/// -internal readonly struct ServerSentEvent -{ - // Gets the value of the SSE "event type" buffer, used to distinguish - // between event kinds. - public string EventType { get; } - - // Gets the value of the SSE "data" buffer, which holds the payload of the - // server-sent event. - public string Data { get; } - - public ServerSentEvent(string type, string data) - { - EventType = type; - Data = data; - } -} diff --git a/sdk/core/System.ClientModel/src/Internal/SSE/ServerSentEventEnumerable.cs b/sdk/core/System.ClientModel/src/Internal/SSE/ServerSentEventEnumerable.cs deleted file mode 100644 index 8c0ebca65681..000000000000 --- a/sdk/core/System.ClientModel/src/Internal/SSE/ServerSentEventEnumerable.cs +++ /dev/null @@ -1,81 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -using System.Collections; -using System.Collections.Generic; -using System.IO; -using System.Threading; - -namespace System.ClientModel.Internal; - -/// -/// Represents a collection of SSE events that can be enumerated as a C# collection. -/// -internal class ServerSentEventEnumerable : IEnumerable -{ - private readonly Stream _contentStream; - - public ServerSentEventEnumerable(Stream contentStream) - { - Argument.AssertNotNull(contentStream, nameof(contentStream)); - - _contentStream = contentStream; - - LastEventId = string.Empty; - ReconnectionInterval = Timeout.InfiniteTimeSpan; - } - - public string LastEventId { get; private set; } - - public TimeSpan ReconnectionInterval { get; private set; } - - public IEnumerator GetEnumerator() - { - return new ServerSentEventEnumerator(_contentStream, this); - } - - IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); - - private sealed class ServerSentEventEnumerator : IEnumerator - { - private readonly ServerSentEventReader _reader; - private readonly ServerSentEventEnumerable _enumerable; - - public ServerSentEventEnumerator(Stream contentStream, ServerSentEventEnumerable enumerable) - { - _reader = new(contentStream); - _enumerable = enumerable; - } - - public ServerSentEvent Current { get; private set; } - - object IEnumerator.Current => Current; - - public bool MoveNext() - { - ServerSentEvent? nextEvent = _reader.TryGetNextEvent(); - _enumerable.LastEventId = _reader.LastEventId; - _enumerable.ReconnectionInterval= _reader.ReconnectionInterval; - - if (nextEvent.HasValue) - { - Current = nextEvent.Value; - return true; - } - - Current = default; - return false; - } - - public void Reset() - { - throw new NotSupportedException("Cannot seek back in an SSE stream."); - } - - public void Dispose() - { - // The creator of the enumerable has responsibility for disposing - // the content stream passed to the enumerable constructor. - } - } -} diff --git a/sdk/core/System.ClientModel/src/Internal/SSE/ServerSentEventField.cs b/sdk/core/System.ClientModel/src/Internal/SSE/ServerSentEventField.cs deleted file mode 100644 index eaf72fb5121a..000000000000 --- a/sdk/core/System.ClientModel/src/Internal/SSE/ServerSentEventField.cs +++ /dev/null @@ -1,54 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -namespace System.ClientModel.Internal; - -/// -/// Represents a field that can be composed into an SSE event. -/// See SSE specification: https://html.spec.whatwg.org/multipage/server-sent-events.html -/// -internal readonly struct ServerSentEventField -{ - private static readonly ReadOnlyMemory s_eventFieldName = "event".AsMemory(); - private static readonly ReadOnlyMemory s_dataFieldName = "data".AsMemory(); - private static readonly ReadOnlyMemory s_lastEventIdFieldName = "id".AsMemory(); - private static readonly ReadOnlyMemory s_retryFieldName = "retry".AsMemory(); - - public ServerSentEventFieldKind FieldType { get; } - - // Note: we don't plan to expose UTF16 publicly - public ReadOnlyMemory Value { get; } - - internal ServerSentEventField(string line) - { - int colonIndex = line.AsSpan().IndexOf(':'); - - ReadOnlyMemory fieldName = colonIndex < 0 ? - line.AsMemory() : - line.AsMemory(0, colonIndex); - - FieldType = fieldName.Span switch - { - var x when x.SequenceEqual(s_eventFieldName.Span) => ServerSentEventFieldKind.Event, - var x when x.SequenceEqual(s_dataFieldName.Span) => ServerSentEventFieldKind.Data, - var x when x.SequenceEqual(s_lastEventIdFieldName.Span) => ServerSentEventFieldKind.Id, - var x when x.SequenceEqual(s_retryFieldName.Span) => ServerSentEventFieldKind.Retry, - _ => ServerSentEventFieldKind.Ignore, - }; - - if (colonIndex < 0) - { - Value = ReadOnlyMemory.Empty; - } - else - { - Value = line.AsMemory(colonIndex + 1); - - // Per spec, remove a leading space if present. - if (Value.Length > 0 && Value.Span[0] == ' ') - { - Value = Value.Slice(1); - } - } - } -} diff --git a/sdk/core/System.ClientModel/src/Internal/SSE/ServerSentEventFieldKind.cs b/sdk/core/System.ClientModel/src/Internal/SSE/ServerSentEventFieldKind.cs deleted file mode 100644 index 3ddc00aff270..000000000000 --- a/sdk/core/System.ClientModel/src/Internal/SSE/ServerSentEventFieldKind.cs +++ /dev/null @@ -1,17 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -namespace System.ClientModel.Internal; - -/// -/// The kind of line or field received over an SSE stream. -/// See SSE specification: https://html.spec.whatwg.org/multipage/server-sent-events.html -/// -internal enum ServerSentEventFieldKind -{ - Ignore, - Event, - Data, - Id, - Retry, -} diff --git a/sdk/core/System.ClientModel/src/Internal/SSE/ServerSentEventReader.cs b/sdk/core/System.ClientModel/src/Internal/SSE/ServerSentEventReader.cs deleted file mode 100644 index e1e881743533..000000000000 --- a/sdk/core/System.ClientModel/src/Internal/SSE/ServerSentEventReader.cs +++ /dev/null @@ -1,192 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -using System.Collections.Generic; -using System.Diagnostics; -using System.IO; -using System.Threading; -using System.Threading.Tasks; - -namespace System.ClientModel.Internal; - -/// -/// An SSE event reader that reads lines from an SSE stream and composes them -/// into SSE events. -/// See SSE specification: https://html.spec.whatwg.org/multipage/server-sent-events.html -/// -internal sealed class ServerSentEventReader -{ - private readonly StreamReader _reader; - - public ServerSentEventReader(Stream stream) - { - Argument.AssertNotNull(stream, nameof(stream)); - - // The creator of the reader has responsibility for disposing the - // stream passed to the reader's constructor. - _reader = new StreamReader(stream); - - LastEventId = string.Empty; - ReconnectionInterval = Timeout.InfiniteTimeSpan; - } - - public string LastEventId { get; private set; } - - public TimeSpan ReconnectionInterval { get; private set; } - - /// - /// Synchronously retrieves the next server-sent event from the underlying stream, blocking until a new event is - /// available and returning null once no further data is present on the stream. - /// - /// An optional cancellation token that can abort subsequent reads. - /// - /// The next in the stream, or null once no more data can be read from the stream. - /// - public ServerSentEvent? TryGetNextEvent(CancellationToken cancellationToken = default) - { - PendingEvent pending = default; - while (true) - { - cancellationToken.ThrowIfCancellationRequested(); - - // Note: would be nice to have polyfill that takes cancellation token, - // but may become moot if we shift to all UTF-8. - string? line = _reader.ReadLine(); - - if (line is null) - { - // A null line indicates end of input - return null; - } - - ProcessLine(line, ref pending, out bool dispatch); - - if (dispatch) - { - return pending.ToEvent(); - } - } - } - - /// - /// Asynchronously retrieves the next server-sent event from the underlying stream, blocking until a new event is - /// available and returning null once no further data is present on the stream. - /// - /// An optional cancellation token that can abort subsequent reads. - /// - /// The next in the stream, or null once no more data can be read from the stream. - /// - public async Task TryGetNextEventAsync(CancellationToken cancellationToken = default) - { - PendingEvent pending = default; - while (true) - { - cancellationToken.ThrowIfCancellationRequested(); - - // Note: would be nice to have polyfill that takes cancellation token, - // but may become moot if we shift to all UTF-8. - string? line = await _reader.ReadLineAsync().ConfigureAwait(false); - - if (line is null) - { - // A null line indicates end of input - return null; - } - - ProcessLine(line, ref pending, out bool dispatch); - - if (dispatch) - { - return pending.ToEvent(); - } - } - } - - private void ProcessLine(string line, ref PendingEvent pending, out bool dispatch) - { - dispatch = false; - - if (line.Length == 0) - { - if (pending.DataLength == 0) - { - // Per spec, if there's no data, don't dispatch an event. - pending = default; - } - else - { - dispatch = true; - } - } - else if (line[0] != ':') - { - // Per spec, ignore comment lines (i.e. that begin with ':'). - // If we got this far, process the field + value and accumulate - // it for the next dispatched event. - ServerSentEventField field = new(line); - switch (field.FieldType) - { - case ServerSentEventFieldKind.Event: - pending.EventTypeField = field; - break; - case ServerSentEventFieldKind.Data: - // Per spec, we'll append \n when we concatenate the data lines. - pending.DataLength += field.Value.Length + 1; - pending.DataFields.Add(field); - break; - case ServerSentEventFieldKind.Id: - LastEventId = field.Value.ToString(); - break; - case ServerSentEventFieldKind.Retry: - if (field.Value.Length > 0 && int.TryParse(field.Value.ToString(), out int retry)) - { - ReconnectionInterval = TimeSpan.FromMilliseconds(retry); - } - break; - default: - // Ignore - break; - } - } - } - - private struct PendingEvent - { - private const char LF = '\n'; - - private List? _dataFields; - - public int DataLength { get; set; } - public List DataFields => _dataFields ??= new(); - public ServerSentEventField? EventTypeField { get; set; } - - public ServerSentEvent ToEvent() - { - Debug.Assert(DataLength > 0); - - // Per spec, if event type buffer is empty, set event.type to "message". - string type = EventTypeField.HasValue ? - EventTypeField.Value.Value.ToString() : - "message"; - - Memory buffer = new(new char[DataLength]); - - int curr = 0; - foreach (ServerSentEventField field in DataFields) - { - Debug.Assert(field.FieldType == ServerSentEventFieldKind.Data); - - field.Value.Span.CopyTo(buffer.Span.Slice(curr)); - - // Per spec, append trailing LF to each data field value. - buffer.Span[curr + field.Value.Length] = LF; - curr += field.Value.Length + 1; - } - - // Per spec, remove trailing LF from concatenated data fields. - string data = buffer.Slice(0, buffer.Length - 1).ToString(); - - return new ServerSentEvent(type, data); - } - } -} diff --git a/sdk/core/System.ClientModel/tests/Convenience/PageCollectionScenarioTests.cs b/sdk/core/System.ClientModel/tests/Convenience/PageCollectionScenarioTests.cs deleted file mode 100644 index 1846346ed408..000000000000 --- a/sdk/core/System.ClientModel/tests/Convenience/PageCollectionScenarioTests.cs +++ /dev/null @@ -1,472 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -using System.ClientModel.Primitives; -using System.Collections.Generic; -using System.Linq; -using System.Threading.Tasks; -using ClientModel.Tests.Mocks; -using ClientModel.Tests.Paging; -using NUnit.Framework; - -namespace System.ClientModel.Tests.Results; - -/// -/// Scenario tests for sync and async page collections. -/// -public class PageScenarioCollectionTests -{ - [Test] - public void CanRehydratePageCollection() - { - PagingClientOptions options = new() - { - Transport = new MockPipelineTransport("Mock", i => 200) - }; - - PagingClient client = new PagingClient(options); - PageCollection pages = client.GetValues(); - PageResult page = pages.GetCurrentPage(); - - ContinuationToken pageToken = page.PageToken; - - PageCollection rehydratedPages = client.GetValues(pageToken); - PageResult rehydratedPage = rehydratedPages.GetCurrentPage(); - - Assert.AreEqual(page.Values.Count, rehydratedPage.Values.Count); - - List allValues = pages.GetAllValues().ToList(); - List allRehydratedValues = rehydratedPages.GetAllValues().ToList(); - - for (int i = 0; i < allValues.Count; i++) - { - Assert.AreEqual(allValues[i].Id, allRehydratedValues[i].Id); - } - } - - [Test] - public async Task CanRehydratePageCollectionAsync() - { - PagingClientOptions options = new() - { - Transport = new MockPipelineTransport("Mock", i => 200) - }; - - PagingClient client = new PagingClient(options); - AsyncPageCollection pages = client.GetValuesAsync(); - PageResult page = await pages.GetCurrentPageAsync(); - - ContinuationToken pageToken = page.PageToken; - - AsyncPageCollection rehydratedPages = client.GetValuesAsync(pageToken); - PageResult rehydratedPage = await rehydratedPages.GetCurrentPageAsync(); - - Assert.AreEqual(page.Values.Count, rehydratedPage.Values.Count); - - List allValues = await pages.GetAllValuesAsync().ToListAsync(); - List allRehydratedValues = await rehydratedPages.GetAllValuesAsync().ToListAsync(); - - for (int i = 0; i < allValues.Count; i++) - { - Assert.AreEqual(allValues[i].Id, allRehydratedValues[i].Id); - } - } - - [Test] - public void CanReorderItemsAndRehydrate() - { - PagingClientOptions options = new() - { - Transport = new MockPipelineTransport("Mock", i => 200) - }; - - string order = "desc"; - Assert.AreNotEqual(MockPagingData.DefaultOrder, order); - - PagingClient client = new PagingClient(options); - PageCollection pages = client.GetValues(order: order); - PageResult page = pages.GetCurrentPage(); - - ContinuationToken pageToken = page.PageToken; - - PageCollection rehydratedPages = client.GetValues(pageToken); - PageResult rehydratedPage = rehydratedPages.GetCurrentPage(); - - Assert.AreEqual(page.Values.Count, rehydratedPage.Values.Count); - - // We got the last one first from both pages - Assert.AreEqual(MockPagingData.Count - 1, page.Values[0].Id); - Assert.AreEqual(MockPagingData.Count - 1, rehydratedPage.Values[0].Id); - } - - [Test] - public async Task CanReorderItemsAndRehydrateAsync() - { - PagingClientOptions options = new() - { - Transport = new MockPipelineTransport("Mock", i => 200) - }; - - string order = "desc"; - Assert.AreNotEqual(MockPagingData.DefaultOrder, order); - - PagingClient client = new PagingClient(options); - AsyncPageCollection pages = client.GetValuesAsync(order: order); - PageResult page = await pages.GetCurrentPageAsync(); - - ContinuationToken pageToken = page.PageToken; - - AsyncPageCollection rehydratedPages = client.GetValuesAsync(pageToken); - PageResult rehydratedPage = await rehydratedPages.GetCurrentPageAsync(); - - Assert.AreEqual(page.Values.Count, rehydratedPage.Values.Count); - - // We got the last one first from both pages - Assert.AreEqual(MockPagingData.Count - 1, page.Values[0].Id); - Assert.AreEqual(MockPagingData.Count - 1, rehydratedPage.Values[0].Id); - } - - [Test] - public void CanChangePageSizeAndRehydrate() - { - PagingClientOptions options = new() - { - Transport = new MockPipelineTransport("Mock", i => 200) - }; - - int pageSize = 4; - Assert.AreNotEqual(MockPagingData.DefaultPageSize, pageSize); - - PagingClient client = new PagingClient(options); - PageCollection pages = client.GetValues(pageSize: pageSize); - PageResult page = pages.GetCurrentPage(); - - ContinuationToken pageToken = page.PageToken; - - PageCollection rehydratedPages = client.GetValues(pageToken); - PageResult rehydratedPage = rehydratedPages.GetCurrentPage(); - - // Both pages have same non-default page size - Assert.AreEqual(pageSize, page.Values.Count); - Assert.AreEqual(pageSize, rehydratedPage.Values.Count); - } - - [Test] - public async Task CanChangePageSizeAndRehydrateAsync() - { - PagingClientOptions options = new() - { - Transport = new MockPipelineTransport("Mock", i => 200) - }; - - int pageSize = 4; - Assert.AreNotEqual(MockPagingData.DefaultPageSize, pageSize); - - PagingClient client = new PagingClient(options); - AsyncPageCollection pages = client.GetValuesAsync(pageSize: pageSize); - PageResult page = await pages.GetCurrentPageAsync(); - - ContinuationToken pageToken = page.PageToken; - - AsyncPageCollection rehydratedPages = client.GetValuesAsync(pageToken); - PageResult rehydratedPage = await rehydratedPages.GetCurrentPageAsync(); - - // Both pages have same non-default page size - Assert.AreEqual(pageSize, page.Values.Count); - Assert.AreEqual(pageSize, rehydratedPage.Values.Count); - } - - [Test] - public void CanSkipItemsAndRehydrate() - { - PagingClientOptions options = new() - { - Transport = new MockPipelineTransport("Mock", i => 200) - }; - - int offset = 4; - Assert.AreNotEqual(MockPagingData.DefaultOffset, offset); - - PagingClient client = new PagingClient(options); - PageCollection pages = client.GetValues(offset: offset); - PageResult page = pages.GetCurrentPage(); - - ContinuationToken pageToken = page.PageToken; - - PageCollection rehydratedPages = client.GetValues(pageToken); - PageResult rehydratedPage = rehydratedPages.GetCurrentPage(); - - Assert.AreEqual(page.Values.Count, rehydratedPage.Values.Count); - - // Both pages have the same non-default offset value - Assert.AreEqual(offset, page.Values[0].Id); - Assert.AreEqual(offset, rehydratedPage.Values[0].Id); - } - - [Test] - public async Task CanSkipItemsAndRehydrateAsync() - { - PagingClientOptions options = new() - { - Transport = new MockPipelineTransport("Mock", i => 200) - }; - - int offset = 4; - Assert.AreNotEqual(MockPagingData.DefaultOffset, offset); - - PagingClient client = new PagingClient(options); - AsyncPageCollection pages = client.GetValuesAsync(offset: offset); - PageResult page = await pages.GetCurrentPageAsync(); - - ContinuationToken pageToken = page.PageToken; - - AsyncPageCollection rehydratedPages = client.GetValuesAsync(pageToken); - PageResult rehydratedPage = await rehydratedPages.GetCurrentPageAsync(); - - Assert.AreEqual(page.Values.Count, rehydratedPage.Values.Count); - - // Both pages have the same non-default offset value - Assert.AreEqual(offset, page.Values[0].Id); - Assert.AreEqual(offset, rehydratedPage.Values[0].Id); - } - - [Test] - public void CanChangeAllCollectionParametersAndRehydrate() - { - PagingClientOptions options = new() - { - Transport = new MockPipelineTransport("Mock", i => 200) - }; - - string order = "desc"; - Assert.AreNotEqual(MockPagingData.DefaultOrder, order); - - int pageSize = 4; - Assert.AreNotEqual(MockPagingData.DefaultPageSize, pageSize); - - int offset = 4; - Assert.AreNotEqual(MockPagingData.DefaultOffset, offset); - - PagingClient client = new PagingClient(options); - PageCollection pages = client.GetValues(order, pageSize, offset); - PageResult page = pages.GetCurrentPage(); - - ContinuationToken pageToken = page.PageToken; - - PageCollection rehydratedPages = client.GetValues(pageToken); - PageResult rehydratedPage = rehydratedPages.GetCurrentPage(); - - // Both page collections and first pages are the same on each dimension - - // Collections have same non-default number of pages. - Assert.AreEqual(3, pages.Count()); - Assert.AreEqual(3, rehydratedPages.Count()); - - // Last one first and same items skipped - Assert.AreEqual(11, page.Values[0].Id); - Assert.AreEqual(11, rehydratedPage.Values[0].Id); - - // Equal page size - Assert.AreEqual(pageSize, page.Values.Count); - Assert.AreEqual(pageSize, rehydratedPage.Values.Count); - } - - [Test] - public async Task CanChangeAllCollectionParametersAndRehydrateAsync() - { - PagingClientOptions options = new() - { - Transport = new MockPipelineTransport("Mock", i => 200) - }; - - string order = "desc"; - Assert.AreNotEqual(MockPagingData.DefaultOrder, order); - - int pageSize = 4; - Assert.AreNotEqual(MockPagingData.DefaultPageSize, pageSize); - - int offset = 4; - Assert.AreNotEqual(MockPagingData.DefaultOffset, offset); - - PagingClient client = new PagingClient(options); - AsyncPageCollection pages = client.GetValuesAsync(order, pageSize, offset); - PageResult page = await pages.GetCurrentPageAsync(); - - ContinuationToken pageToken = page.PageToken; - - AsyncPageCollection rehydratedPages = client.GetValuesAsync(pageToken); - PageResult rehydratedPage = await rehydratedPages.GetCurrentPageAsync(); - - // Both page collections and first pages are the same on each dimension - - // Collections have same non-default number of pages. - Assert.AreEqual(3, await pages.CountAsync()); - Assert.AreEqual(3, await rehydratedPages.CountAsync()); - - // Last one first and same items skipped - Assert.AreEqual(11, page.Values[0].Id); - Assert.AreEqual(11, rehydratedPage.Values[0].Id); - - // Equal page size - Assert.AreEqual(pageSize, page.Values.Count); - Assert.AreEqual(pageSize, rehydratedPage.Values.Count); - } - - [Test] - public void CanCastToConvenienceFromProtocol() - { - PagingClientOptions options = new() - { - Transport = new MockPipelineTransport("Mock", i => 200) - }; - - PagingClient client = new PagingClient(options); - - // Call the protocol method on the convenience client. - IEnumerable pageResults = client.GetValues( - order: default, - pageSize: default, - offset: default, - new RequestOptions()); - - // Cast to convience type from protocol return value. - PageCollection pages = (PageCollection)pageResults; - - IEnumerable values = pages.GetAllValues(); - - int count = 0; - foreach (ValueItem value in values) - { - Assert.AreEqual(count, value.Id); - count++; - } - - Assert.AreEqual(MockPagingData.Count, count); - } - - [Test] - public async Task CanCastToConvenienceFromProtocolAsync() - { - PagingClientOptions options = new() - { - Transport = new MockPipelineTransport("Mock", i => 200) - }; - - PagingClient client = new PagingClient(options); - - // Call the protocol method on the convenience client. - IAsyncEnumerable pageResults = client.GetValuesAsync( - order: default, - pageSize: default, - offset: default, - new RequestOptions()); - - // Cast to convience type from protocol return value. - AsyncPageCollection pages = (AsyncPageCollection)pageResults; - - IAsyncEnumerable values = pages.GetAllValuesAsync(); - - int count = 0; - await foreach (ValueItem value in values) - { - Assert.AreEqual(count, value.Id); - count++; - } - - Assert.AreEqual(MockPagingData.Count, count); - } - - [Test] - public void CanEvolveFromProtocol() - { - // This scenario tests validates that user code doesn't break when - // convenience methods are added. We show this by illustrating that - // exactly the same code works the same way when using a client that - // has only protocol methods and a client that has the same protocol - // methods and also convenience methods. - - PagingClientOptions options = new() - { - Transport = new MockPipelineTransport("Mock", i => 200) - }; - - static void Validate(IEnumerable results) - { - int pageCount = 0; - foreach (ClientResult result in results) - { - Assert.AreEqual(200, result.GetRawResponse().Status); - pageCount++; - } - - Assert.AreEqual(MockPagingData.Count / MockPagingData.DefaultPageSize, pageCount); - } - - // Protocol code - PagingProtocolClient protocolClient = new PagingProtocolClient(options); - IEnumerable pageResults = protocolClient.GetValues( - order: default, - pageSize: default, - offset: default, - new RequestOptions()); - - Validate(pageResults); - - // Convenience code - PagingClient convenienceClient = new PagingClient(options); - IEnumerable pages = convenienceClient.GetValues( - order: default, - pageSize: default, - offset: default, - new RequestOptions()); - - Validate(pages); - } - - [Test] - public async Task CanEvolveFromProtocolAsync() - { - // This scenario tests validates that user code doesn't break when - // convenience methods are added. We show this by illustrating that - // exactly the same code works the same way when using a client that - // has only protocol methods and a client that has the same protocol - // methods and also convenience methods. - - PagingClientOptions options = new() - { - Transport = new MockPipelineTransport("Mock", i => 200) - }; - - static async Task ValidateAsync(IAsyncEnumerable results) - { - int pageCount = 0; - await foreach (ClientResult result in results) - { - Assert.AreEqual(200, result.GetRawResponse().Status); - pageCount++; - } - - Assert.AreEqual(MockPagingData.Count / MockPagingData.DefaultPageSize, pageCount); - } - - // Protocol code - PagingProtocolClient protocolClient = new PagingProtocolClient(options); - IAsyncEnumerable pageResults = protocolClient.GetValuesAsync( - order: default, - pageSize: default, - offset: default, - new RequestOptions()); - - await ValidateAsync(pageResults); - - // Convenience code - PagingClient convenienceClient = new PagingClient(options); - IAsyncEnumerable pages = convenienceClient.GetValuesAsync( - order: default, - pageSize: default, - offset: default, - new RequestOptions()); - - await ValidateAsync(pages); - } -} diff --git a/sdk/core/System.ClientModel/tests/Convenience/PageCollectionTests.cs b/sdk/core/System.ClientModel/tests/Convenience/PageCollectionTests.cs deleted file mode 100644 index c29dd035d8e5..000000000000 --- a/sdk/core/System.ClientModel/tests/Convenience/PageCollectionTests.cs +++ /dev/null @@ -1,199 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -using System.Collections.Generic; -using System.Linq; -using System.Threading.Tasks; -using ClientModel.Tests.Mocks; -using ClientModel.Tests.Paging; -using NUnit.Framework; - -namespace System.ClientModel.Tests.Results; - -/// -/// Unit tests for sync and async page collections. -/// -public class PageCollectionTests -{ - private const int Count = 16; - private const int DefaultPageSize = 8; - private static readonly List MockValues = GetMockValues(Count).ToList(); - - private static IEnumerable GetMockValues(int count) - { - for (int i = 0; i < count; i++) - { - yield return i; - } - } - - [Test] - public void CanGetAllValues() - { - PageCollection pages = new MockPageCollection(MockValues, DefaultPageSize); - IEnumerable values = pages.GetAllValues(); - - int count = 0; - foreach (int value in values) - { - Assert.AreEqual(count, value); - count++; - } - - Assert.AreEqual(Count, count); - } - - [Test] - public async Task CanGetAllValuesAsync() - { - AsyncPageCollection pages = new MockAsyncPageCollection(MockValues, DefaultPageSize); - IAsyncEnumerable values = pages.GetAllValuesAsync(); - - int count = 0; - await foreach (int value in values) - { - Assert.AreEqual(count, value); - count++; - } - - Assert.AreEqual(Count, count); - } - - [Test] - public void CanGetCurrentPage() - { - PageCollection pages = new MockPageCollection(MockValues, DefaultPageSize); - PageResult page = pages.GetCurrentPage(); - - Assert.AreEqual(MockPagingData.DefaultPageSize, page.Values.Count); - Assert.AreEqual(0, page.Values[0]); - } - - [Test] - public async Task CanGetCurrentPageAsync() - { - AsyncPageCollection pages = new MockAsyncPageCollection(MockValues, DefaultPageSize); - PageResult page = await pages.GetCurrentPageAsync(); - - Assert.AreEqual(MockPagingData.DefaultPageSize, page.Values.Count); - Assert.AreEqual(0, page.Values[0]); - } - - [Test] - public void CanGetCurrentPageThenGetAllItems() - { - PageCollection pages = new MockPageCollection(MockValues, DefaultPageSize); - PageResult page = pages.GetCurrentPage(); - - Assert.AreEqual(DefaultPageSize, page.Values.Count); - Assert.AreEqual(0, page.Values[0]); - - IEnumerable values = pages.GetAllValues(); - - int count = 0; - foreach (int value in values) - { - Assert.AreEqual(count, value); - count++; - } - - Assert.AreEqual(Count, count); - } - - [Test] - public async Task CanGetCurrentPageThenGetAllItemsAsync() - { - AsyncPageCollection pages = new MockAsyncPageCollection(MockValues, DefaultPageSize); - PageResult page = await pages.GetCurrentPageAsync(); - - Assert.AreEqual(DefaultPageSize, page.Values.Count); - Assert.AreEqual(0, page.Values[0]); - - IAsyncEnumerable values = pages.GetAllValuesAsync(); - - int count = 0; - await foreach (int value in values) - { - Assert.AreEqual(count, value); - count++; - } - - Assert.AreEqual(Count, count); - } - - [Test] - public void CanGetCurrentPageWhileEnumeratingItems() - { - PageCollection pages = new MockPageCollection(MockValues, DefaultPageSize); - IEnumerable values = pages.GetAllValues(); - - int count = 0; - foreach (int value in values) - { - Assert.AreEqual(count, value); - count++; - - PageResult page = pages.GetCurrentPage(); - - // Validate that the current item is in range of the page values - Assert.GreaterOrEqual(value, page.Values[0]); - Assert.LessOrEqual(value, page.Values[page.Values.Count - 1]); - } - - Assert.AreEqual(MockPagingData.Count, count); - } - - [Test] - public async Task CanGetCurrentPageWhileEnumeratingItemsAsync() - { - AsyncPageCollection pages = new MockAsyncPageCollection(MockValues, DefaultPageSize); - IAsyncEnumerable values = pages.GetAllValuesAsync(); - - int count = 0; - await foreach (int value in values) - { - Assert.AreEqual(count, value); - count++; - - PageResult page = await pages.GetCurrentPageAsync(); - - // Validate that the current item is in range of the page values - Assert.GreaterOrEqual(value, page.Values[0]); - Assert.LessOrEqual(value, page.Values[page.Values.Count - 1]); - } - - Assert.AreEqual(MockPagingData.Count, count); - } - - [Test] - public void CanEnumerateClientResults() - { - PageCollection pages = new MockPageCollection(MockValues, DefaultPageSize); - IEnumerable pageResults = pages; - - int pageCount = 0; - foreach (ClientResult result in pageResults) - { - Assert.AreEqual(200, result.GetRawResponse().Status); - pageCount++; - } - - Assert.AreEqual(2, pageCount); - } - - [Test] - public async Task CanEnumerateClientResultsAsync() - { - AsyncPageCollection pages = new MockAsyncPageCollection(MockValues, DefaultPageSize); - IAsyncEnumerable pageResults = pages; - - int pageCount = 0; - await foreach (ClientResult result in pageResults) - { - Assert.AreEqual(200, result.GetRawResponse().Status); - pageCount++; - } - - Assert.AreEqual(2, pageCount); - } -} diff --git a/sdk/core/System.ClientModel/tests/Convenience/PaginatedCollectionTests.cs b/sdk/core/System.ClientModel/tests/Convenience/PaginatedCollectionTests.cs new file mode 100644 index 000000000000..36c6a4d18f2f --- /dev/null +++ b/sdk/core/System.ClientModel/tests/Convenience/PaginatedCollectionTests.cs @@ -0,0 +1,400 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.ClientModel.Primitives; +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Azure.Core.TestFramework; +using ClientModel.Tests.Collections; +using NUnit.Framework; + +namespace System.ClientModel.Tests.Results; + +/// +/// Scenario tests for sync and async paginated collections. +/// These tests use a reference implementation of a client that calls paginated +/// service endpoints. +/// +public class PaginatedCollectionTests +{ + // Tests: + // 1. Protocol/Sync + // a. Can enumerate pages + // b. Can rehydrate from token + // 2. Protocol/Async + // a. Can enumerate pages + // b. Can rehydrate from token + // 3. Convenience/Sync + // a. Can enumerate Ts + // b. Can cancel with single cancellation token + // c. Can evolve from protocol + // 4. Convenience/Async + // a. Can enumerate Ts + // b. Can cancel with either of two cancellation tokens + // c. Can evolve from protocol + + [Test] + public void CanEnumerateRawPages() + { + ProtocolPaginatedCollectionClient client = new(); + + CollectionResult valueCollection = client.GetValues(); + IEnumerable pages = valueCollection.GetRawPages(); + + int expectedValueId = 0; + int pageCount = 0; + foreach (ClientResult page in pages) + { + PipelineResponse response = page.GetRawResponse(); + ValueItemPage conveniencePage = ValueItemPage.FromJson(response.Content); + + Assert.AreEqual(MockPageResponseData.DefaultPageSize, conveniencePage.Values.Count); + Assert.AreEqual(expectedValueId, conveniencePage.Values[0].Id); + + pageCount++; + expectedValueId += MockPageResponseData.DefaultPageSize; + } + + Assert.AreEqual(MockPageResponseData.TotalItemCount / MockPageResponseData.DefaultPageSize, pageCount); + } + + [Test] + public void CanRehydrateCollection() + { + ProtocolPaginatedCollectionClient client = new(); + + CollectionResult valueCollection = client.GetValues(); + List pages = valueCollection.GetRawPages().ToList(); + ClientResult firstPage = pages[0]; + + ContinuationToken? nextPageToken = valueCollection.GetContinuationToken(firstPage); + CollectionResult rehydratedCollection = client.GetValues(nextPageToken!); + + List rehydratedPages = rehydratedCollection.GetRawPages().ToList(); + + int totalPageCount = MockPageResponseData.TotalItemCount / MockPageResponseData.DefaultPageSize; + int rehydratedPageCount = 0; + for (int i = 1; i < totalPageCount; i++) + { + ClientResult originalPageResult = pages[i]; + ClientResult rehydratedPageResult = rehydratedPages[i - 1]; + + ValueItemPage originalPage = ValueItemPage.FromJson(originalPageResult.GetRawResponse().Content); + ValueItemPage rehydratedPage = ValueItemPage.FromJson(rehydratedPageResult.GetRawResponse().Content); + + Assert.AreEqual(originalPage.Values.Count, rehydratedPage.Values.Count); + Assert.AreEqual(originalPage.Values[0].Id, rehydratedPage.Values[0].Id); + + rehydratedPageCount++; + } + + Assert.AreEqual(totalPageCount - 1, rehydratedPageCount); + } + + [Test] + public async Task CanEnumerateRawPagesAsync() + { + ProtocolPaginatedCollectionClient client = new(); + + AsyncCollectionResult valueCollection = client.GetValuesAsync(); + IAsyncEnumerable pages = valueCollection.GetRawPagesAsync(); + + int expectedValueId = 0; + int pageCount = 0; + await foreach (ClientResult page in pages) + { + PipelineResponse response = page.GetRawResponse(); + ValueItemPage conveniencePage = ValueItemPage.FromJson(response.Content); + + Assert.AreEqual(MockPageResponseData.DefaultPageSize, conveniencePage.Values.Count); + Assert.AreEqual(expectedValueId, conveniencePage.Values[0].Id); + + pageCount++; + expectedValueId += MockPageResponseData.DefaultPageSize; + } + + Assert.AreEqual(MockPageResponseData.TotalItemCount / MockPageResponseData.DefaultPageSize, pageCount); + } + + [Test] + public async Task CanRehydrateCollectionAsync() + { + ProtocolPaginatedCollectionClient client = new(); + + AsyncCollectionResult valueCollection = client.GetValuesAsync(); + List pages = await valueCollection.GetRawPagesAsync().ToListAsync(); + ClientResult firstPage = pages[0]; + + ContinuationToken? nextPageToken = valueCollection.GetContinuationToken(firstPage); + AsyncCollectionResult rehydratedCollection = client.GetValuesAsync(nextPageToken!); + + List rehydratedPages = await rehydratedCollection.GetRawPagesAsync().ToListAsync(); + + int totalPageCount = MockPageResponseData.TotalItemCount / MockPageResponseData.DefaultPageSize; + int rehydratedPageCount = 0; + for (int i = 1; i < totalPageCount; i++) + { + ClientResult originalPageResult = pages[i]; + ClientResult rehydratedPageResult = rehydratedPages[i - 1]; + + ValueItemPage originalPage = ValueItemPage.FromJson(originalPageResult.GetRawResponse().Content); + ValueItemPage rehydratedPage = ValueItemPage.FromJson(rehydratedPageResult.GetRawResponse().Content); + + Assert.AreEqual(originalPage.Values.Count, rehydratedPage.Values.Count); + Assert.AreEqual(originalPage.Values[0].Id, rehydratedPage.Values[0].Id); + + rehydratedPageCount++; + } + + Assert.AreEqual(totalPageCount - 1, rehydratedPageCount); + } + + [Test] + public void CanEnumerateValues() + { + PaginatedCollectionClient client = new(); + CollectionResult values = client.GetValues(); + + int count = 0; + foreach (ValueItem value in values) + { + Assert.AreEqual(count, value.Id); + count++; + } + + Assert.AreEqual(MockPageResponseData.TotalItemCount, count); + } + + [Test] + public void CanCancelViaServiceMethodCancellationToken() + { + using CancellationTokenSource cts = new(); + cts.Cancel(); + + PaginatedCollectionClient client = new(); + CollectionResult values = client.GetValues(cancellationToken: cts.Token); + + Assert.Throws(() => values.First()); + } + + [Test] + public void CanEvolveFromProtocolLayer() + { + // This tests validates that user code doesn't break when convenience + // methods are added. We show this by illustrating that code written + // at the protocol layer continues to work the same way when using a + // client that has only protocol methods and when using client that has + // both convenience and protocol methods. + + static bool Validate(CollectionResult valueCollection) + { + IEnumerable pages = valueCollection.GetRawPages(); + + int expectedValueId = 0; + int pageCount = 0; + foreach (ClientResult page in pages) + { + PipelineResponse response = page.GetRawResponse(); + ValueItemPage conveniencePage = ValueItemPage.FromJson(response.Content); + + Assert.AreEqual(MockPageResponseData.DefaultPageSize, conveniencePage.Values.Count); + Assert.AreEqual(expectedValueId, conveniencePage.Values[0].Id); + + pageCount++; + expectedValueId += MockPageResponseData.DefaultPageSize; + } + + Assert.AreEqual(MockPageResponseData.TotalItemCount / MockPageResponseData.DefaultPageSize, pageCount); + return true; + } + + // Protocol client (v1) code + ProtocolPaginatedCollectionClient protocolClient = new(); + CollectionResult protocolCollection = protocolClient.GetValues(); + + // Convenience client (v2) code + PaginatedCollectionClient convenienceClient = new(); + CollectionResult convenienceCollection = convenienceClient.GetValues(); + + Assert.IsTrue(Validate(protocolCollection)); + Assert.IsTrue(Validate(convenienceCollection)); + } + + [Test] + public void CanCastFromProtocolToConvenienceReturnType() + { + PaginatedCollectionClient client = new(); + + // Call protocol method + CollectionResult protocolCollection = client.GetValues(pageSize: default, new RequestOptions()); + + // Cast to convenience method + CollectionResult convenienceCollection = (CollectionResult)protocolCollection; + + int count = 0; + foreach (ValueItem value in convenienceCollection) + { + Assert.AreEqual(count, value.Id); + count++; + } + + Assert.AreEqual(MockPageResponseData.TotalItemCount, count); + } + + [Test] + public async Task CanEnumerateValuesAsync() + { + PaginatedCollectionClient client = new(); + AsyncCollectionResult values = client.GetValuesAsync(); + + int count = 0; + await foreach (ValueItem value in values) + { + Assert.AreEqual(count, value.Id); + count++; + } + + Assert.AreEqual(MockPageResponseData.TotalItemCount, count); + } + + [Test] + public void CanCancelViaServiceMethodCancellationTokenAsync() + { + using CancellationTokenSource cts = new(); + cts.Cancel(); + + PaginatedCollectionClient client = new(); + AsyncCollectionResult values = client.GetValuesAsync(cancellationToken: cts.Token); + + Assert.ThrowsAsync(async () => await values.FirstAsync()); + } + + [Test] + public async Task CanCancelViaAsyncEnumerableCancellationTokenAsync() + { + using CancellationTokenSource cts = new(); + cts.Cancel(); + + PaginatedCollectionClient client = new(); + AsyncCollectionResult values = client.GetValuesAsync(); + + bool threwException = false; + try + { + await foreach (ValueItem value in values.WithCancellation(cts.Token)) + { + } + } + catch (OperationCanceledException) + { + threwException = true; + } + + Assert.IsTrue(threwException); + } + + [Test] + public async Task CanEvolveFromProtocolLayerAsync() + { + // This tests validates that user code doesn't break when convenience + // methods are added. We show this by illustrating that code written + // at the protocol layer continues to work the same way when using a + // client that has only protocol methods and when using client that has + // both convenience and protocol methods. + + static async Task ValidateAsync(AsyncCollectionResult valueCollection) + { + IAsyncEnumerable pages = valueCollection.GetRawPagesAsync(); + + int expectedValueId = 0; + int pageCount = 0; + await foreach (ClientResult page in pages) + { + PipelineResponse response = page.GetRawResponse(); + ValueItemPage conveniencePage = ValueItemPage.FromJson(response.Content); + + Assert.AreEqual(MockPageResponseData.DefaultPageSize, conveniencePage.Values.Count); + Assert.AreEqual(expectedValueId, conveniencePage.Values[0].Id); + + pageCount++; + expectedValueId += MockPageResponseData.DefaultPageSize; + } + + Assert.AreEqual(MockPageResponseData.TotalItemCount / MockPageResponseData.DefaultPageSize, pageCount); + return true; + } + + // Protocol client (v1) code + ProtocolPaginatedCollectionClient protocolClient = new(); + AsyncCollectionResult protocolCollection = protocolClient.GetValuesAsync(); + + // Convenience client (v2) code + PaginatedCollectionClient convenienceClient = new(); + AsyncCollectionResult convenienceCollection = convenienceClient.GetValuesAsync(); + + Assert.IsTrue(await ValidateAsync(protocolCollection)); + Assert.IsTrue(await ValidateAsync(convenienceCollection)); + } + + [Test] + public async Task CanCastFromProtocolToConvenienceReturnTypeAsync() + { + PaginatedCollectionClient client = new(); + + // Call protocol method + AsyncCollectionResult protocolCollection = client.GetValuesAsync(pageSize: default, new RequestOptions()); + + // Cast to convenience method + AsyncCollectionResult convenienceCollection = (AsyncCollectionResult)protocolCollection; + + int count = 0; + await foreach (ValueItem value in convenienceCollection) + { + Assert.AreEqual(count, value.Id); + count++; + } + + Assert.AreEqual(MockPageResponseData.TotalItemCount, count); + } + + [Test] + public async Task CanGetDataInPagesFromTestService() + { + List pages = MockPageResponseData.GetPages().ToList(); + int pageIndex = 0; + + using TestServer testServer = new( + async context => + { + ValueItemPage page = pages[pageIndex++]; + byte[] content = page.ToJson().ToArray(); + + context.Response.StatusCode = 200; + await context.Response.Body.WriteAsync(content, 0, content.Length); + }); + + ClientPipeline pipeline = ClientPipeline.Create(); + + int pageCount = 0; + ValueItemPage valuePage = default!; + do + { + using PipelineMessage message = pipeline.CreateMessage(); + message.Request.Uri = testServer.Address; + + await pipeline.SendAsync(message); + + PipelineResponse response = message.Response!; + valuePage = ValueItemPage.FromJson(response.Content); + + Assert.AreEqual(MockPageResponseData.DefaultPageSize, valuePage.Values.Count); + + pageCount++; + } + while (valuePage.HasMore); + + Assert.AreEqual(MockPageResponseData.TotalItemCount / MockPageResponseData.DefaultPageSize, pageCount); + } +} diff --git a/sdk/core/System.ClientModel/tests/Convenience/StreamedCollectionTests.cs b/sdk/core/System.ClientModel/tests/Convenience/StreamedCollectionTests.cs new file mode 100644 index 000000000000..00743797c06f --- /dev/null +++ b/sdk/core/System.ClientModel/tests/Convenience/StreamedCollectionTests.cs @@ -0,0 +1,153 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.ClientModel.Primitives; +using System.Collections.Generic; +using System.Linq; +using System.Threading.Tasks; +using ClientModel.Tests.Collections; +using NUnit.Framework; + +namespace System.ClientModel.Tests.Results; + +/// +/// Scenario tests for sync and async streamed collections. +/// These tests use a reference implementation of a client that calls streaming +/// service endpoints. +/// +public class StreamedCollectionTests +{ + // Tests: + // 1. Protocol/Sync + // a. Can enumerate pages (only one response for now) + // 2. Protocol/Async + // a. Can enumerate pages + // 3. Convenience/Sync + // a. Can get values from response stream + // b. Response stream is disposed + // 4. Convenience/Async + // a. Can get values from response stream + // b. Response stream is disposed + + [Test] + public void CanEnumerateRawPages() + { + StreamedCollectionClient client = new(); + CollectionResult collection = client.GetValues(); + IEnumerable pages = collection.GetRawPages(); + + int pageCount = 0; + foreach (ClientResult page in pages) + { + PipelineResponse response = page.GetRawResponse(); + + Assert.AreEqual(200, response.Status); + Assert.IsTrue(response.Content.ToString().StartsWith("event")); + + pageCount++; + } + + Assert.AreEqual(1, pageCount); + } + + [Test] + public async Task CanEnumerateRawPagesAsync() + { + StreamedCollectionClient client = new(); + AsyncCollectionResult collection = client.GetValuesAsync(); + IAsyncEnumerable pages = collection.GetRawPagesAsync(); + + int pageCount = 0; + await foreach (ClientResult page in pages) + { + PipelineResponse response = page.GetRawResponse(); + + Assert.AreEqual(200, response.Status); + Assert.IsTrue(response.Content.ToString().StartsWith("event")); + + pageCount++; + } + + Assert.AreEqual(1, pageCount); + } + + [Test] + public void CanEnumerateValues() + { + StreamedCollectionClient client = new(); + CollectionResult values = client.GetValues(); + + int count = 0; + foreach (StreamedValue value in values) + { + Assert.AreEqual(count, value.Id); + count++; + } + + Assert.AreEqual(MockStreamedData.TotalItemCount, count); + } + + [Test] + public void ResponseStreamIsDisposed() + { + StreamedCollectionClient client = new(); + StreamedValueCollectionResult? values = client.GetValues() as StreamedValueCollectionResult; + + Assert.IsNotNull(values); + + ClientResult page = values!.GetRawPages().First(); + MockStreamedResponse? response = page.GetRawResponse() as MockStreamedResponse; + + Assert.IsNotNull(response); + Assert.IsFalse(response?.IsDisposed); + + int count = 0; + foreach (StreamedValue value in values!.GetPageValues(page)) + { + Assert.AreEqual(count, value.Id); + count++; + } + + Assert.IsTrue(response?.IsDisposed); + } + + [Test] + public async Task CanEnumerateValuesAsync() + { + StreamedCollectionClient client = new(); + AsyncCollectionResult values = client.GetValuesAsync(); + + int count = 0; + await foreach (StreamedValue value in values) + { + Assert.AreEqual(count, value.Id); + count++; + } + + Assert.AreEqual(MockStreamedData.TotalItemCount, count); + } + + [Test] + public async Task ResponseStreamIsDisposedAsync() + { + StreamedCollectionClient client = new(); + AsyncStreamedValueCollectionResult? values = client.GetValuesAsync() as AsyncStreamedValueCollectionResult; + + Assert.IsNotNull(values); + + ClientResult page = await values!.GetRawPagesAsync().FirstAsync(); + MockStreamedResponse? response = page.GetRawResponse() as MockStreamedResponse; + + Assert.IsNotNull(response); + Assert.IsFalse(response?.IsDisposed); + + int count = 0; + await foreach (StreamedValue value in values!.GetPageValuesAsync(page)) + { + Assert.AreEqual(count, value.Id); + count++; + } + + Assert.IsTrue(response?.IsDisposed); + } +} diff --git a/sdk/core/System.ClientModel/tests/TestFramework/Mocks/MockAsyncPageCollection.cs b/sdk/core/System.ClientModel/tests/TestFramework/Mocks/MockAsyncPageCollection.cs deleted file mode 100644 index 4d2b9a7d5376..000000000000 --- a/sdk/core/System.ClientModel/tests/TestFramework/Mocks/MockAsyncPageCollection.cs +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -using System; -using System.ClientModel; -using System.Collections.Generic; -using System.Threading; -using System.Threading.Tasks; - -namespace ClientModel.Tests.Mocks; - -public class MockAsyncPageCollection : AsyncPageCollection -{ - private readonly List _values; - private readonly int _pageSize; - - private int _current; - - public MockAsyncPageCollection(List values, int pageSize) - { - _values = values; - _pageSize = pageSize; - } - - protected override async Task> GetCurrentPageAsyncCore() - => await GetPageFromCurrentStateAsync().ConfigureAwait(false); - - protected override async IAsyncEnumerator> GetAsyncEnumeratorCore(CancellationToken cancellationToken) - { - while (_current < _values.Count) - { - yield return await GetPageFromCurrentStateAsync().ConfigureAwait(false); - - _current += _pageSize; - } - } - - private async Task> GetPageFromCurrentStateAsync() - { - await Task.Delay(0); - - int pageSize = Math.Min(_pageSize, _values.Count - _current); - List pageValues = _values.GetRange(_current, pageSize); - - // Make page tokens not useful for mocks. - ContinuationToken mockPageToken = ContinuationToken.FromBytes(BinaryData.FromString("{}")); - return PageResult.Create(pageValues, mockPageToken, null, new MockPipelineResponse(200)); - } -} diff --git a/sdk/core/System.ClientModel/tests/TestFramework/Mocks/MockPageCollection.cs b/sdk/core/System.ClientModel/tests/TestFramework/Mocks/MockPageCollection.cs deleted file mode 100644 index 47fa94c12c16..000000000000 --- a/sdk/core/System.ClientModel/tests/TestFramework/Mocks/MockPageCollection.cs +++ /dev/null @@ -1,45 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -using System; -using System.ClientModel; -using System.Collections.Generic; - -namespace ClientModel.Tests.Mocks; - -public class MockPageCollection : PageCollection -{ - private readonly List _values; - private readonly int _pageSize; - - private int _current; - - public MockPageCollection(List values, int pageSize) - { - _values = values; - _pageSize = pageSize; - } - - protected override PageResult GetCurrentPageCore() - => GetPageFromCurrentState(); - - protected override IEnumerator> GetEnumeratorCore() - { - while (_current < _values.Count) - { - yield return GetPageFromCurrentState(); - - _current += _pageSize; - } - } - - private PageResult GetPageFromCurrentState() - { - int pageSize = Math.Min(_pageSize, _values.Count - _current); - List pageValues = _values.GetRange(_current, pageSize); - - // Make page tokens not useful for mocks. - ContinuationToken mockPageToken = ContinuationToken.FromBytes(BinaryData.FromString("{}")); - return PageResult.Create(pageValues, mockPageToken, null, new MockPipelineResponse(200)); - } -} diff --git a/sdk/core/System.ClientModel/tests/client/TestClients/PaginatedCollectionClient/AsyncProtocolValueCollectionResult.cs b/sdk/core/System.ClientModel/tests/client/TestClients/PaginatedCollectionClient/AsyncProtocolValueCollectionResult.cs new file mode 100644 index 000000000000..e9f9d32b27cb --- /dev/null +++ b/sdk/core/System.ClientModel/tests/client/TestClients/PaginatedCollectionClient/AsyncProtocolValueCollectionResult.cs @@ -0,0 +1,44 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.ClientModel; +using System.ClientModel.Primitives; +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; + +namespace ClientModel.Tests.Collections; + +internal class AsyncProtocolValueCollectionResult : AsyncCollectionResult +{ + private readonly IEnumerable _mockPagesData; + + private readonly int? _pageSize; + private readonly int? _offset; + private readonly RequestOptions? _options; + private readonly CancellationToken _cancellationToken; + + public AsyncProtocolValueCollectionResult(int? pageSize, int? offset, RequestOptions? options) + { + _pageSize = pageSize; + _offset = offset; + _options = options; + _cancellationToken = _options?.CancellationToken ?? default; + + _mockPagesData = MockPageResponseData.GetPages(pageSize, offset); + } + + public override ContinuationToken? GetContinuationToken(ClientResult page) + => ValueCollectionPageToken.FromResponse(page, _pageSize); + + public override async IAsyncEnumerable GetRawPagesAsync() + { + foreach (ValueItemPage page in _mockPagesData) + { + await Task.Delay(0, _cancellationToken).ConfigureAwait(false); + + PipelineResponse response = new MockPageResponse(page); + yield return ClientResult.FromResponse(response); + } + } +} diff --git a/sdk/core/System.ClientModel/tests/client/TestClients/PaginatedCollectionClient/AsyncValueCollectionResult.cs b/sdk/core/System.ClientModel/tests/client/TestClients/PaginatedCollectionClient/AsyncValueCollectionResult.cs new file mode 100644 index 000000000000..8268b45bea56 --- /dev/null +++ b/sdk/core/System.ClientModel/tests/client/TestClients/PaginatedCollectionClient/AsyncValueCollectionResult.cs @@ -0,0 +1,51 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.ClientModel; +using System.ClientModel.Primitives; +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; + +namespace ClientModel.Tests.Collections; + +internal class AsyncValueCollectionResult : AsyncCollectionResult +{ + private readonly IEnumerable _mockPagesData; + + private readonly int? _pageSize; + private readonly int? _offset; + private readonly RequestOptions? _options; + private readonly CancellationToken _cancellationToken; + + public AsyncValueCollectionResult(int? pageSize, int? offset, RequestOptions? options) + { + _pageSize = pageSize; + _offset = offset; + _options = options; + _cancellationToken = _options?.CancellationToken ?? default; + + _mockPagesData = MockPageResponseData.GetPages(pageSize, offset); + } + + public override ContinuationToken? GetContinuationToken(ClientResult page) + => ValueCollectionPageToken.FromResponse(page, _pageSize); + + public override async IAsyncEnumerable GetRawPagesAsync() + { + foreach (ValueItemPage page in _mockPagesData) + { + await Task.Delay(0, _cancellationToken).ConfigureAwait(false); + + PipelineResponse response = new MockPageResponse(page); + yield return ClientResult.FromResponse(response); + } + } + + protected override IAsyncEnumerable GetValuesFromPageAsync(ClientResult page) + { + PipelineResponse response = page.GetRawResponse(); + ValueItemPage valuePage = ValueItemPage.FromJson(response.Content); + return valuePage.Values.ToAsyncEnumerable(_cancellationToken); + } +} diff --git a/sdk/core/System.ClientModel/tests/client/TestClients/PagingClient/Emitted/Argument.cs b/sdk/core/System.ClientModel/tests/client/TestClients/PaginatedCollectionClient/Emitted/Argument.cs similarity index 76% rename from sdk/core/System.ClientModel/tests/client/TestClients/PagingClient/Emitted/Argument.cs rename to sdk/core/System.ClientModel/tests/client/TestClients/PaginatedCollectionClient/Emitted/Argument.cs index 79804fd51ac1..98c095c564dd 100644 --- a/sdk/core/System.ClientModel/tests/client/TestClients/PagingClient/Emitted/Argument.cs +++ b/sdk/core/System.ClientModel/tests/client/TestClients/PaginatedCollectionClient/Emitted/Argument.cs @@ -2,10 +2,8 @@ // Licensed under the MIT License. using System; -using System.Collections; -using System.Collections.Generic; -namespace ClientModel.Tests.Paging; +namespace ClientModel.Tests.Collections; internal static class Argument { diff --git a/sdk/core/System.ClientModel/tests/client/TestClients/PagingClient/Emitted/CancellationTokenExtensions.cs b/sdk/core/System.ClientModel/tests/client/TestClients/PaginatedCollectionClient/Emitted/CancellationTokenExtensions.cs similarity index 91% rename from sdk/core/System.ClientModel/tests/client/TestClients/PagingClient/Emitted/CancellationTokenExtensions.cs rename to sdk/core/System.ClientModel/tests/client/TestClients/PaginatedCollectionClient/Emitted/CancellationTokenExtensions.cs index 78058b617a06..f58e676473b0 100644 --- a/sdk/core/System.ClientModel/tests/client/TestClients/PagingClient/Emitted/CancellationTokenExtensions.cs +++ b/sdk/core/System.ClientModel/tests/client/TestClients/PaginatedCollectionClient/Emitted/CancellationTokenExtensions.cs @@ -4,7 +4,7 @@ using System.ClientModel.Primitives; using System.Threading; -namespace ClientModel.Tests.Paging; +namespace ClientModel.Tests.Collections; internal static class CancellationTokenExtensions { diff --git a/sdk/core/System.ClientModel/tests/client/TestClients/PaginatedCollectionClient/Emitted/IEnumerableExtensions.cs b/sdk/core/System.ClientModel/tests/client/TestClients/PaginatedCollectionClient/Emitted/IEnumerableExtensions.cs new file mode 100644 index 000000000000..134741e83970 --- /dev/null +++ b/sdk/core/System.ClientModel/tests/client/TestClients/PaginatedCollectionClient/Emitted/IEnumerableExtensions.cs @@ -0,0 +1,21 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; + +namespace ClientModel.Tests.Collections; + +internal static class IEnumerableExtensions +{ + public static async IAsyncEnumerable ToAsyncEnumerable(this IEnumerable enumerable, [EnumeratorCancellation] CancellationToken cancellationToken) + { + foreach (T item in enumerable) + { + await Task.Delay(0, cancellationToken).ConfigureAwait(false); + yield return item; + } + } +} diff --git a/sdk/core/System.ClientModel/tests/client/TestClients/PagingClient/MockData/MockValueItemPageResponse.cs b/sdk/core/System.ClientModel/tests/client/TestClients/PaginatedCollectionClient/MockData/MockPageResponse.cs similarity index 51% rename from sdk/core/System.ClientModel/tests/client/TestClients/PagingClient/MockData/MockValueItemPageResponse.cs rename to sdk/core/System.ClientModel/tests/client/TestClients/PaginatedCollectionClient/MockData/MockPageResponse.cs index f21cb0720d26..425fdd4ce560 100644 --- a/sdk/core/System.ClientModel/tests/client/TestClients/PagingClient/MockData/MockValueItemPageResponse.cs +++ b/sdk/core/System.ClientModel/tests/client/TestClients/PaginatedCollectionClient/MockData/MockPageResponse.cs @@ -3,35 +3,17 @@ using System; using System.ClientModel.Primitives; -using System.Collections.Generic; using System.IO; -using System.Linq; -using System.Text; using System.Threading; using System.Threading.Tasks; -namespace ClientModel.Tests.Paging; +namespace ClientModel.Tests.Collections; -internal class MockValueItemPageResponse : PipelineResponse +internal class MockPageResponse : PipelineResponse { - public MockValueItemPageResponse(IEnumerable values) + public MockPageResponse(ValueItemPage page) { - StringBuilder sb = new StringBuilder(); - sb.AppendLine("["); - - int count = 0; - foreach (ValueItem value in values) - { - sb.AppendLine(value.ToJson()); - - if (++count != values.Count()) - { - sb.AppendLine(","); - } - } - sb.AppendLine("]"); - - Content = BinaryData.FromString(sb.ToString()); + Content = page.ToJson(); } public override int Status => 200; @@ -46,7 +28,8 @@ public override Stream? ContentStream public override BinaryData Content { get; } - protected override PipelineResponseHeaders HeadersCore => throw new NotImplementedException(); + protected override PipelineResponseHeaders HeadersCore + => throw new NotImplementedException(); public override BinaryData BufferContent(CancellationToken cancellationToken = default) => Content; diff --git a/sdk/core/System.ClientModel/tests/client/TestClients/PaginatedCollectionClient/MockData/MockPageResponseData.cs b/sdk/core/System.ClientModel/tests/client/TestClients/PaginatedCollectionClient/MockData/MockPageResponseData.cs new file mode 100644 index 000000000000..1b071cdeaeeb --- /dev/null +++ b/sdk/core/System.ClientModel/tests/client/TestClients/PaginatedCollectionClient/MockData/MockPageResponseData.cs @@ -0,0 +1,40 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.Collections.Generic; +using System.Linq; + +namespace ClientModel.Tests.Collections; + +public class MockPageResponseData +{ + public const int TotalItemCount = 16; + + public const int DefaultPageSize = 4; + public const int DefaultOffset = 0; + + // Source of all the data + internal static IEnumerable GetAllValues() + { + for (int i = 0; i < TotalItemCount; i++) + { + yield return new ValueItem(i, $"{i}"); + } + } + + public static IEnumerable GetPages(int? pageSize = default, + int? offset = default) + { + pageSize ??= DefaultPageSize; + offset ??= DefaultOffset; + + IEnumerable valueSource = GetAllValues(); + + for (int i = offset.Value; i < TotalItemCount;) + { + IEnumerable pageItems = valueSource.Skip(i).Take(pageSize.Value); + i += pageSize.Value; + yield return new ValueItemPage(pageItems, i < TotalItemCount); + } + } +} diff --git a/sdk/core/System.ClientModel/tests/client/TestClients/PaginatedCollectionClient/MockData/ValueItemPage.cs b/sdk/core/System.ClientModel/tests/client/TestClients/PaginatedCollectionClient/MockData/ValueItemPage.cs new file mode 100644 index 000000000000..650a87803609 --- /dev/null +++ b/sdk/core/System.ClientModel/tests/client/TestClients/PaginatedCollectionClient/MockData/ValueItemPage.cs @@ -0,0 +1,64 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Text.Json; + +namespace ClientModel.Tests.Collections; + +// public as a convience for tests. +public class ValueItemPage +{ + public ValueItemPage(IEnumerable values, bool hasMore) + { + Values = values.ToList().AsReadOnly(); + HasMore = hasMore; + } + + public IReadOnlyList Values { get; } + + public bool HasMore { get; } + + public BinaryData ToJson() + { + StringBuilder sb = new StringBuilder(); + sb.AppendLine("{"); + sb.AppendLine("\"data\":"); + sb.AppendLine("["); + + int count = 0; + foreach (ValueItem value in Values) + { + sb.AppendLine(value.ToJson()); + + if (++count != Values.Count) + { + sb.AppendLine(","); + } + } + sb.AppendLine("],"); + sb.AppendLine($"\"has_more\": {HasMore.ToString().ToLower()}"); + sb.AppendLine("}"); + + return BinaryData.FromString(sb.ToString()); + } + + public static ValueItemPage FromJson(BinaryData json) + { + List items = new(); + + using JsonDocument doc = JsonDocument.Parse(json); + JsonElement data = doc.RootElement.GetProperty("data"); + foreach (JsonElement item in data.EnumerateArray()) + { + items.Add(ValueItem.FromJson(item)); + } + JsonElement more = doc.RootElement.GetProperty("has_more"); + bool hasMore = more.GetBoolean(); + + return new ValueItemPage(items, hasMore); + } +} diff --git a/sdk/core/System.ClientModel/tests/client/TestClients/PaginatedCollectionClient/PaginatedCollectionClient.cs b/sdk/core/System.ClientModel/tests/client/TestClients/PaginatedCollectionClient/PaginatedCollectionClient.cs new file mode 100644 index 000000000000..dc81f8acc588 --- /dev/null +++ b/sdk/core/System.ClientModel/tests/client/TestClients/PaginatedCollectionClient/PaginatedCollectionClient.cs @@ -0,0 +1,78 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.ClientModel; +using System.ClientModel.Primitives; +using System.Threading; + +namespace ClientModel.Tests.Collections; + +// A reference implementation that illustrates client patterns for paginated +// service endpoints for clients that have both convenience and protocol methods. +public class PaginatedCollectionClient +{ + public PaginatedCollectionClient(PaginatedCollectionClientOptions? options = default) + { + } + + public virtual AsyncCollectionResult GetValuesAsync( + int? pageSize = default, + CancellationToken cancellationToken = default) + { + return new AsyncValueCollectionResult(pageSize, offset: default, cancellationToken.ToRequestOptions()); + } + + public virtual AsyncCollectionResult GetValuesAsync( + ContinuationToken continuationToken, + CancellationToken cancellationToken = default) + { + ValueCollectionPageToken token = ValueCollectionPageToken.FromToken(continuationToken); + + return new AsyncValueCollectionResult(token.PageSize, token.Offset, cancellationToken.ToRequestOptions()); + } + public virtual CollectionResult GetValues( + int? pageSize = default, + CancellationToken cancellationToken = default) + { + return new ValueCollectionResult(pageSize, offset: default, cancellationToken.ToRequestOptions()); + } + + public virtual CollectionResult GetValues( + ContinuationToken continuationToken, + CancellationToken cancellationToken = default) + { + ValueCollectionPageToken token = ValueCollectionPageToken.FromToken(continuationToken); + + return new ValueCollectionResult(token.PageSize, token.Offset, cancellationToken.ToRequestOptions()); + } + public virtual AsyncCollectionResult GetValuesAsync( + int? pageSize, + RequestOptions? options) + { + return new AsyncValueCollectionResult(pageSize, offset: default, options); + } + + public virtual AsyncCollectionResult GetValuesAsync( + ContinuationToken continuationToken, + RequestOptions? options) + { + ValueCollectionPageToken token = ValueCollectionPageToken.FromToken(continuationToken); + + return new AsyncValueCollectionResult(token.PageSize, token.Offset, options); + } + public virtual CollectionResult GetValues( + int? pageSize, + RequestOptions? options) + { + return new ValueCollectionResult(pageSize, offset: default, options); + } + + public virtual CollectionResult GetValues( + ContinuationToken continuationToken, + RequestOptions? options) + { + ValueCollectionPageToken token = ValueCollectionPageToken.FromToken(continuationToken); + + return new ValueCollectionResult(token.PageSize, token.Offset, options); + } +} diff --git a/sdk/core/System.ClientModel/tests/client/TestClients/PaginatedCollectionClient/PaginatedCollectionClientOptions.cs b/sdk/core/System.ClientModel/tests/client/TestClients/PaginatedCollectionClient/PaginatedCollectionClientOptions.cs new file mode 100644 index 000000000000..e3d8a68c9939 --- /dev/null +++ b/sdk/core/System.ClientModel/tests/client/TestClients/PaginatedCollectionClient/PaginatedCollectionClientOptions.cs @@ -0,0 +1,10 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.ClientModel.Primitives; + +namespace ClientModel.Tests.Collections; + +public class PaginatedCollectionClientOptions : ClientPipelineOptions +{ +} diff --git a/sdk/core/System.ClientModel/tests/client/TestClients/PaginatedCollectionClient/ProtocolPaginatedCollectionClient.cs b/sdk/core/System.ClientModel/tests/client/TestClients/PaginatedCollectionClient/ProtocolPaginatedCollectionClient.cs new file mode 100644 index 000000000000..b1a372211948 --- /dev/null +++ b/sdk/core/System.ClientModel/tests/client/TestClients/PaginatedCollectionClient/ProtocolPaginatedCollectionClient.cs @@ -0,0 +1,45 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.ClientModel; +using System.ClientModel.Primitives; + +namespace ClientModel.Tests.Collections; + +// A reference implementation that illustrates client patterns for paginated +// service endpoints for clients that have only protocol methods. +public class ProtocolPaginatedCollectionClient +{ + public ProtocolPaginatedCollectionClient(PaginatedCollectionClientOptions? options = default) + { + } + + public virtual AsyncCollectionResult GetValuesAsync( + int? pageSize = default, + RequestOptions? options = default) + { + return new AsyncProtocolValueCollectionResult(pageSize, offset: default, options); + } + public virtual AsyncCollectionResult GetValuesAsync( + ContinuationToken continuationToken, + RequestOptions? options = default) + { + ValueCollectionPageToken token = ValueCollectionPageToken.FromToken(continuationToken); + + return new AsyncProtocolValueCollectionResult(token.PageSize, token.Offset, options); + } + public virtual CollectionResult GetValues( + int? pageSize = default, + RequestOptions? options = default) + { + return new ProtocolValueCollectionResult(pageSize, offset: default, options); + } + public virtual CollectionResult GetValues( + ContinuationToken continuationToken, + RequestOptions? options = default) + { + ValueCollectionPageToken token = ValueCollectionPageToken.FromToken(continuationToken); + + return new ProtocolValueCollectionResult(token.PageSize, token.Offset, options); + } +} diff --git a/sdk/core/System.ClientModel/tests/client/TestClients/PaginatedCollectionClient/ProtocolValueCollectionResult.cs b/sdk/core/System.ClientModel/tests/client/TestClients/PaginatedCollectionClient/ProtocolValueCollectionResult.cs new file mode 100644 index 000000000000..1a6221e0efc8 --- /dev/null +++ b/sdk/core/System.ClientModel/tests/client/TestClients/PaginatedCollectionClient/ProtocolValueCollectionResult.cs @@ -0,0 +1,48 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.ClientModel; +using System.ClientModel.Primitives; +using System.Collections.Generic; +using System.Threading; + +namespace ClientModel.Tests.Collections; + +/// +/// Protocol-layer paginated collection +/// +internal class ProtocolValueCollectionResult : CollectionResult +{ + private readonly IEnumerable _mockPagesData; + + private readonly int? _pageSize; + private readonly int? _offset; + private readonly RequestOptions? _options; + private readonly CancellationToken _cancellationToken; + + public ProtocolValueCollectionResult(int? pageSize, int? offset, RequestOptions? options) + { + _pageSize = pageSize; + _offset = offset; + _options = options; + _cancellationToken = _options?.CancellationToken ?? default; + + _mockPagesData = MockPageResponseData.GetPages(pageSize, offset); + } + + public override ContinuationToken? GetContinuationToken(ClientResult page) + => ValueCollectionPageToken.FromResponse(page, _pageSize); + + public override IEnumerable GetRawPages() + { + foreach (ValueItemPage page in _mockPagesData) + { + // Simulate the pipeline checking for cancellation, + // which happens in the transport + _cancellationToken.ThrowIfCancellationRequested(); + + PipelineResponse response = new MockPageResponse(page); + yield return ClientResult.FromResponse(response); + } + } +} diff --git a/sdk/core/System.ClientModel/tests/client/TestClients/PagingClient/ValuesPageToken.cs b/sdk/core/System.ClientModel/tests/client/TestClients/PaginatedCollectionClient/ValueCollectionPageToken.cs similarity index 62% rename from sdk/core/System.ClientModel/tests/client/TestClients/PagingClient/ValuesPageToken.cs rename to sdk/core/System.ClientModel/tests/client/TestClients/PaginatedCollectionClient/ValueCollectionPageToken.cs index 3a205541eadc..fca3e4ad7281 100644 --- a/sdk/core/System.ClientModel/tests/client/TestClients/PagingClient/ValuesPageToken.cs +++ b/sdk/core/System.ClientModel/tests/client/TestClients/PaginatedCollectionClient/ValueCollectionPageToken.cs @@ -3,23 +3,24 @@ using System; using System.ClientModel; +using System.ClientModel.Primitives; using System.Diagnostics; using System.IO; +using System.Linq; using System.Text.Json; -namespace ClientModel.Tests.Paging; +namespace ClientModel.Tests.Collections; -internal class ValuesPageToken : ContinuationToken +internal class ValueCollectionPageToken : ContinuationToken { - protected ValuesPageToken(string? order, int? pageSize, int? offset) + protected ValueCollectionPageToken(int? pageSize, int? offset) { - Order = order; PageSize = pageSize; Offset = offset; } - public string? Order { get; } public int? PageSize { get; } + public int? Offset { get; } public override BinaryData ToBytes() @@ -29,11 +30,6 @@ public override BinaryData ToBytes() writer.WriteStartObject(); - if (Order is not null) - { - writer.WriteString("order", Order); - } - if (PageSize.HasValue) { writer.WriteNumber("pageSize", PageSize.Value); @@ -52,19 +48,19 @@ public override BinaryData ToBytes() return BinaryData.FromStream(stream); } - public ValuesPageToken? GetNextPageToken(int offset, int count) + public ValueCollectionPageToken? GetNextPageToken(int offset, int count) { if (offset >= count) { return null; } - return new ValuesPageToken(Order, PageSize, offset); + return new ValueCollectionPageToken(PageSize, offset); } - public static ValuesPageToken FromToken(ContinuationToken pageToken) + public static ValueCollectionPageToken FromToken(ContinuationToken pageToken) { - if (pageToken is ValuesPageToken token) + if (pageToken is ValueCollectionPageToken token) { return token; } @@ -73,12 +69,11 @@ public static ValuesPageToken FromToken(ContinuationToken pageToken) if (data.ToMemory().Length == 0) { - throw new ArgumentException("Failed to create ValuesPageToken from provided pageToken.", nameof(pageToken)); + throw new ArgumentException("Failed to create ValueCollectionPageToken from provided pageToken.", nameof(pageToken)); } Utf8JsonReader reader = new(data); - string? order = null; int? pageSize = null; int? offset = null; @@ -99,12 +94,6 @@ public static ValuesPageToken FromToken(ContinuationToken pageToken) switch (propertyName) { - case "order": - reader.Read(); - Debug.Assert(reader.TokenType == JsonTokenType.String); - order = reader.GetString(); - break; - case "pageSize": reader.Read(); Debug.Assert(reader.TokenType == JsonTokenType.Number); @@ -121,9 +110,27 @@ public static ValuesPageToken FromToken(ContinuationToken pageToken) } } - return new(order, pageSize, offset); + return new(pageSize, offset); } - public static ValuesPageToken FromOptions(string? order, int? pageSize, int? offset) - => new(order, pageSize, offset); + public static ValueCollectionPageToken FromOptions(int? pageSize, int? offset) + => new(pageSize, offset); + + public static ValueCollectionPageToken? FromResponse(ClientResult page, int? pageSize) + { + PipelineResponse response = page.GetRawResponse(); + + using JsonDocument doc = JsonDocument.Parse(response.Content); + + JsonElement data = doc.RootElement.GetProperty("data"); + int lastId = data.EnumerateArray().LastOrDefault().GetProperty("id").GetInt32(); + bool hasMore = doc.RootElement.GetProperty("has_more"u8).GetBoolean(); + + if (!hasMore) + { + return null; + } + + return new(pageSize, lastId + 1); + } } diff --git a/sdk/core/System.ClientModel/tests/client/TestClients/PaginatedCollectionClient/ValueCollectionResult.cs b/sdk/core/System.ClientModel/tests/client/TestClients/PaginatedCollectionClient/ValueCollectionResult.cs new file mode 100644 index 000000000000..69777cd0bd24 --- /dev/null +++ b/sdk/core/System.ClientModel/tests/client/TestClients/PaginatedCollectionClient/ValueCollectionResult.cs @@ -0,0 +1,52 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.ClientModel; +using System.ClientModel.Primitives; +using System.Collections.Generic; +using System.Threading; + +namespace ClientModel.Tests.Collections; + +internal class ValueCollectionResult : CollectionResult +{ + private readonly IEnumerable _mockPagesData; + + private readonly int? _pageSize; + private readonly int? _offset; + private readonly RequestOptions? _options; + private readonly CancellationToken _cancellationToken; + + public ValueCollectionResult(int? pageSize, int? offset, RequestOptions? options) + { + _pageSize = pageSize; + _offset = offset; + _options = options; + _cancellationToken = _options?.CancellationToken ?? default; + + _mockPagesData = MockPageResponseData.GetPages(pageSize, offset); + } + + public override ContinuationToken? GetContinuationToken(ClientResult page) + => ValueCollectionPageToken.FromResponse(page, _pageSize); + + public override IEnumerable GetRawPages() + { + foreach (ValueItemPage page in _mockPagesData) + { + // Simulate the pipeline checking for cancellation, + // which happens in the transport + _cancellationToken.ThrowIfCancellationRequested(); + + PipelineResponse response = new MockPageResponse(page); + yield return ClientResult.FromResponse(response); + } + } + + protected override IEnumerable GetValuesFromPage(ClientResult page) + { + PipelineResponse response = page.GetRawResponse(); + ValueItemPage valuePage = ValueItemPage.FromJson(response.Content); + return valuePage.Values; + } +} diff --git a/sdk/core/System.ClientModel/tests/client/TestClients/PagingClient/ValueItem.cs b/sdk/core/System.ClientModel/tests/client/TestClients/PaginatedCollectionClient/ValueItem.cs similarity index 68% rename from sdk/core/System.ClientModel/tests/client/TestClients/PagingClient/ValueItem.cs rename to sdk/core/System.ClientModel/tests/client/TestClients/PaginatedCollectionClient/ValueItem.cs index 5265fc52d974..b748bc52ed3c 100644 --- a/sdk/core/System.ClientModel/tests/client/TestClients/PagingClient/ValueItem.cs +++ b/sdk/core/System.ClientModel/tests/client/TestClients/PaginatedCollectionClient/ValueItem.cs @@ -1,11 +1,12 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +using System; using System.Text.Json; -namespace ClientModel.Tests.Paging; +namespace ClientModel.Tests.Collections; -// A mock model that illustrate values that can be returned in a page collection +// A mock model that illustrate value that can be returned in a page collection public class ValueItem { public ValueItem(int id, string value) @@ -15,6 +16,7 @@ public ValueItem(int id, string value) } public int Id { get; } + public string Value { get; } public string ToJson() => $"{{ \"id\" : {Id}, \"value\" : \"{Value}\" }}"; @@ -26,5 +28,11 @@ public static ValueItem FromJson(JsonElement element) return new ValueItem(id, value); } + public static ValueItem FromJson(BinaryData data) + { + using JsonDocument doc = JsonDocument.Parse(data); + return FromJson(doc.RootElement); + } + public override string ToString() => ToJson(); } diff --git a/sdk/core/System.ClientModel/tests/client/TestClients/PagingClient/Emitted/PageCollectionHelpers.cs b/sdk/core/System.ClientModel/tests/client/TestClients/PagingClient/Emitted/PageCollectionHelpers.cs deleted file mode 100644 index 21b5fba1a4b2..000000000000 --- a/sdk/core/System.ClientModel/tests/client/TestClients/PagingClient/Emitted/PageCollectionHelpers.cs +++ /dev/null @@ -1,66 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -using System.ClientModel; -using System.Collections.Generic; -using System.Threading; -using System.Threading.Tasks; - -namespace ClientModel.Tests.Paging; - -internal class PageCollectionHelpers -{ - public static PageCollection Create(PageEnumerator enumerator) - => new EnumeratorPageCollection(enumerator); - - public static AsyncPageCollection CreateAsync(PageEnumerator enumerator) - => new AsyncEnumeratorPageCollection(enumerator); - - public static IEnumerable Create(PageResultEnumerator enumerator) - { - while (enumerator.MoveNext()) - { - yield return enumerator.Current; - } - } - - public static async IAsyncEnumerable CreateAsync(PageResultEnumerator enumerator) - { - while (await enumerator.MoveNextAsync().ConfigureAwait(false)) - { - yield return enumerator.Current; - } - } - - private class EnumeratorPageCollection : PageCollection - { - private readonly PageEnumerator _enumerator; - - public EnumeratorPageCollection(PageEnumerator enumerator) - { - _enumerator = enumerator; - } - - protected override PageResult GetCurrentPageCore() - => _enumerator.GetCurrentPage(); - - protected override IEnumerator> GetEnumeratorCore() - => _enumerator; - } - - private class AsyncEnumeratorPageCollection : AsyncPageCollection - { - private readonly PageEnumerator _enumerator; - - public AsyncEnumeratorPageCollection(PageEnumerator enumerator) - { - _enumerator = enumerator; - } - - protected override async Task> GetCurrentPageAsyncCore() - => await _enumerator.GetCurrentPageAsync().ConfigureAwait(false); - - protected override IAsyncEnumerator> GetAsyncEnumeratorCore(CancellationToken cancellationToken = default) - => _enumerator; - } -} diff --git a/sdk/core/System.ClientModel/tests/client/TestClients/PagingClient/Emitted/PageEnumerator.cs b/sdk/core/System.ClientModel/tests/client/TestClients/PagingClient/Emitted/PageEnumerator.cs deleted file mode 100644 index 53527a1e1a18..000000000000 --- a/sdk/core/System.ClientModel/tests/client/TestClients/PagingClient/Emitted/PageEnumerator.cs +++ /dev/null @@ -1,61 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -using System.ClientModel; -using System.Collections.Generic; -using System.Threading.Tasks; - -namespace ClientModel.Tests.Paging; - -internal abstract class PageEnumerator : PageResultEnumerator, - IAsyncEnumerator>, - IEnumerator> -{ - public abstract PageResult GetPageFromResult(ClientResult result); - - public PageResult GetCurrentPage() - { - if (Current is null) - { - return GetPageFromResult(GetFirst()); - } - - return ((IEnumerator>)this).Current; - } - - public async Task> GetCurrentPageAsync() - { - if (Current is null) - { - return GetPageFromResult(await GetFirstAsync().ConfigureAwait(false)); - } - - return ((IEnumerator>)this).Current; - } - - PageResult IEnumerator>.Current - { - get - { - if (Current is null) - { - return default!; - } - - return GetPageFromResult(Current); - } - } - - PageResult IAsyncEnumerator>.Current - { - get - { - if (Current is null) - { - return default!; - } - - return GetPageFromResult(Current); - } - } -} diff --git a/sdk/core/System.ClientModel/tests/client/TestClients/PagingClient/Emitted/PageResultEnumerator.cs b/sdk/core/System.ClientModel/tests/client/TestClients/PagingClient/Emitted/PageResultEnumerator.cs deleted file mode 100644 index 64d2550f97db..000000000000 --- a/sdk/core/System.ClientModel/tests/client/TestClients/PagingClient/Emitted/PageResultEnumerator.cs +++ /dev/null @@ -1,76 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -using System; -using System.ClientModel; -using System.Collections; -using System.Collections.Generic; -using System.Threading.Tasks; - -namespace ClientModel.Tests.Paging; - -internal abstract class PageResultEnumerator : IAsyncEnumerator, IEnumerator -{ - private ClientResult? _current; - private bool _hasNext = true; - - public ClientResult Current => _current!; - - public abstract Task GetFirstAsync(); - - public abstract ClientResult GetFirst(); - - public abstract Task GetNextAsync(ClientResult result); - - public abstract ClientResult GetNext(ClientResult result); - - public abstract bool HasNext(ClientResult result); - - object IEnumerator.Current => ((IEnumerator)this).Current; - - public bool MoveNext() - { - if (!_hasNext) - { - return false; - } - - if (_current == null) - { - _current = GetFirst(); - } - else - { - _current = GetNext(_current); - } - - _hasNext = HasNext(_current); - return true; - } - - void IEnumerator.Reset() => _current = null; - - void IDisposable.Dispose() { } - - public async ValueTask MoveNextAsync() - { - if (!_hasNext) - { - return false; - } - - if (_current == null) - { - _current = await GetFirstAsync().ConfigureAwait(false); - } - else - { - _current = await GetNextAsync(_current).ConfigureAwait(false); - } - - _hasNext = HasNext(_current); - return true; - } - - ValueTask IAsyncDisposable.DisposeAsync() => default; -} diff --git a/sdk/core/System.ClientModel/tests/client/TestClients/PagingClient/MockData/MockPagingData.cs b/sdk/core/System.ClientModel/tests/client/TestClients/PagingClient/MockData/MockPagingData.cs deleted file mode 100644 index 9a004c51eec0..000000000000 --- a/sdk/core/System.ClientModel/tests/client/TestClients/PagingClient/MockData/MockPagingData.cs +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -using System.ClientModel; -using System.Collections.Generic; -using System.Linq; - -namespace ClientModel.Tests.Paging; - -public class MockPagingData -{ - public const int Count = 16; - - public const string DefaultOrder = "asc"; - public const int DefaultPageSize = 8; - public const int DefaultOffset = 0; - - // Source of all the data - public static IEnumerable GetValues() - { - for (int i = 0; i < Count; i++) - { - yield return new ValueItem(i, $"{i}"); - } - } - - // Filters on top of data source - public static IEnumerable GetValues( - string? order, - int? pageSize, - int? offset) - { - order ??= DefaultOrder; - pageSize ??= DefaultPageSize; - offset ??= DefaultOffset; - - IEnumerable ordered = order == "asc" ? - GetValues() : - GetValues().Reverse(); - IEnumerable skipped = ordered.Skip(offset.Value); - IEnumerable page = skipped.Take(pageSize.Value); - - return page; - } - - // Turn data into a page result for protocol layer - public static ClientResult GetPageResult(IEnumerable values) - => ClientResult.FromResponse(new MockValueItemPageResponse(values)); -} diff --git a/sdk/core/System.ClientModel/tests/client/TestClients/PagingClient/PagingClient.cs b/sdk/core/System.ClientModel/tests/client/TestClients/PagingClient/PagingClient.cs deleted file mode 100644 index c44c08c57ac3..000000000000 --- a/sdk/core/System.ClientModel/tests/client/TestClients/PagingClient/PagingClient.cs +++ /dev/null @@ -1,122 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -using System; -using System.ClientModel; -using System.ClientModel.Primitives; -using System.Collections.Generic; -using System.Threading; - -namespace ClientModel.Tests.Paging; - -// A mock client implementation that illustrates paging patterns for client -// endpoints that have both convenience and protocol methods. -public class PagingClient -{ - private readonly ClientPipeline _pipeline; - private readonly Uri _endpoint; - - public PagingClient(PagingClientOptions options) - { - _pipeline = ClientPipeline.Create(options); - _endpoint = new Uri("https://www.paging.com"); - } - - public virtual AsyncPageCollection GetValuesAsync( - string? order = default, - int? pageSize = default, - int? offset = default, - CancellationToken cancellationToken = default) - { - ValuesPageEnumerator enumerator = new ValuesPageEnumerator( - _pipeline, - _endpoint, - order: order, - pageSize: pageSize, - offset: offset, - cancellationToken.ToRequestOptions()); - return PageCollectionHelpers.CreateAsync(enumerator); - } - - public virtual AsyncPageCollection GetValuesAsync( - ContinuationToken firstPageToken, - CancellationToken cancellationToken = default) - { - Argument.AssertNotNull(firstPageToken, nameof(firstPageToken)); - - ValuesPageToken token = ValuesPageToken.FromToken(firstPageToken); - ValuesPageEnumerator enumerator = new ValuesPageEnumerator( - _pipeline, - _endpoint, - token.Order, - token.PageSize, - token.Offset, - cancellationToken.ToRequestOptions()); - return PageCollectionHelpers.CreateAsync(enumerator); - } - - public virtual PageCollection GetValues( - string? order = default, - int? pageSize = default, - int? offset = default, - CancellationToken cancellationToken = default) - { - ValuesPageEnumerator enumerator = new ValuesPageEnumerator( - _pipeline, - _endpoint, - order: order, - pageSize: pageSize, - offset: offset, - cancellationToken.ToRequestOptions()); - return PageCollectionHelpers.Create(enumerator); - } - - public virtual PageCollection GetValues( - ContinuationToken firstPageToken, - CancellationToken cancellationToken = default) - { - Argument.AssertNotNull(firstPageToken, nameof(firstPageToken)); - - ValuesPageToken token = ValuesPageToken.FromToken(firstPageToken); - ValuesPageEnumerator enumerator = new ValuesPageEnumerator( - _pipeline, - _endpoint, - token.Order, - token.PageSize, - token.Offset, - cancellationToken.ToRequestOptions()); - return PageCollectionHelpers.Create(enumerator); - } - - public virtual IAsyncEnumerable GetValuesAsync( - string? order, - int? pageSize, - int? offset, - RequestOptions options) - { - ValuesPageEnumerator enumerator = new ValuesPageEnumerator( - _pipeline, - _endpoint, - order: order, - pageSize: pageSize, - offset: offset, - options); - return PageCollectionHelpers.CreateAsync(enumerator); - } - - public virtual IEnumerable GetValues( - string? order, - int? pageSize, - int? offset, - RequestOptions options) - { - ValuesPageEnumerator enumerator = new ValuesPageEnumerator( - _pipeline, - _endpoint, - order: order, - pageSize: pageSize, - offset: offset, - options); - return PageCollectionHelpers.Create(enumerator); - } -} diff --git a/sdk/core/System.ClientModel/tests/client/TestClients/PagingClient/PagingClientOptions.cs b/sdk/core/System.ClientModel/tests/client/TestClients/PagingClient/PagingClientOptions.cs deleted file mode 100644 index 0813e7730122..000000000000 --- a/sdk/core/System.ClientModel/tests/client/TestClients/PagingClient/PagingClientOptions.cs +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -using System.ClientModel; -using System.ClientModel.Primitives; -using System.Collections.Generic; - -namespace ClientModel.Tests.Paging; - -public class PagingClientOptions : ClientPipelineOptions -{ -} diff --git a/sdk/core/System.ClientModel/tests/client/TestClients/PagingClient/PagingProtocolClient.cs b/sdk/core/System.ClientModel/tests/client/TestClients/PagingClient/PagingProtocolClient.cs deleted file mode 100644 index 05c4779bb792..000000000000 --- a/sdk/core/System.ClientModel/tests/client/TestClients/PagingClient/PagingProtocolClient.cs +++ /dev/null @@ -1,55 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -using System; -using System.ClientModel; -using System.ClientModel.Primitives; -using System.Collections.Generic; - -namespace ClientModel.Tests.Paging; - -// A mock client implementation that illustrates paging patterns for client -// endpoints that only have protocol methods. -public class PagingProtocolClient -{ - private readonly ClientPipeline _pipeline; - private readonly Uri _endpoint; - - public PagingProtocolClient(PagingClientOptions options) - { - _pipeline = ClientPipeline.Create(options); - _endpoint = new Uri("https://www.paging.com"); - } - - public virtual IAsyncEnumerable GetValuesAsync( - string? order, - int? pageSize, - int? offset, - RequestOptions? options = default) - { - PageResultEnumerator enumerator = new ValuesPageResultEnumerator( - _pipeline, - _endpoint, - order, - pageSize, - offset, - options); - return PageCollectionHelpers.CreateAsync(enumerator); - } - - public virtual IEnumerable GetValues( - string? order, - int? pageSize, - int? offset, - RequestOptions? options = default) - { - PageResultEnumerator enumerator = new ValuesPageResultEnumerator( - _pipeline, - _endpoint, - order, - pageSize, - offset, - options); - return PageCollectionHelpers.Create(enumerator); - } -} diff --git a/sdk/core/System.ClientModel/tests/client/TestClients/PagingClient/ValueItemPage.cs b/sdk/core/System.ClientModel/tests/client/TestClients/PagingClient/ValueItemPage.cs deleted file mode 100644 index 8dbde03dacaa..000000000000 --- a/sdk/core/System.ClientModel/tests/client/TestClients/PagingClient/ValueItemPage.cs +++ /dev/null @@ -1,34 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -using System; -using System.Collections.Generic; -using System.Text.Json; - -namespace ClientModel.Tests.Paging; - -// In a real client, this type would be generated but would be made internal. -// It corresponds to the REST API definition of the response that comes back -// with a list of items in a page. -internal class ValueItemPage -{ - protected ValueItemPage(List values) - { - Values = values; - } - - public IReadOnlyList Values { get; set; } - - public static ValueItemPage FromJson(BinaryData json) - { - List items = new(); - - using JsonDocument doc = JsonDocument.Parse(json); - foreach (JsonElement element in doc.RootElement.EnumerateArray()) - { - items.Add(ValueItem.FromJson(element)); - } - - return new ValueItemPage(items); - } -} diff --git a/sdk/core/System.ClientModel/tests/client/TestClients/PagingClient/ValuesPageEnumerator.cs b/sdk/core/System.ClientModel/tests/client/TestClients/PagingClient/ValuesPageEnumerator.cs deleted file mode 100644 index 67d32ba833db..000000000000 --- a/sdk/core/System.ClientModel/tests/client/TestClients/PagingClient/ValuesPageEnumerator.cs +++ /dev/null @@ -1,134 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -using System; -using System.ClientModel; -using System.ClientModel.Primitives; -using System.Collections.Generic; -using System.Threading.Tasks; - -namespace ClientModel.Tests.Paging; - -// Mocks a page enumerator a client would evolve to for paged endpoints when -// the client adds convenience methods. -internal class ValuesPageEnumerator : PageEnumerator -{ - private readonly ClientPipeline _pipeline; - private readonly Uri _endpoint; - - private readonly string? _order; - private readonly int? _pageSize; - - // This one is special - it keeps track of which page we're on. - private int? _offset; - - // We need two offsets to be able to create both page tokens. - private int _nextOffset; - - private readonly RequestOptions? _options; - - public ValuesPageEnumerator( - ClientPipeline pipeline, - Uri endpoint, - string? order, - int? pageSize, - int? offset, - RequestOptions? options) - { - _pipeline = pipeline; - _endpoint = endpoint; - - _order = order; - _pageSize = pageSize; - _offset = offset; - - _options = options; - } - - public override PageResult GetPageFromResult(ClientResult result) - { - PipelineResponse response = result.GetRawResponse(); - ValueItemPage pageModel = ValueItemPage.FromJson(response.Content); - - ValuesPageToken pageToken = ValuesPageToken.FromOptions(_order, _pageSize, _offset); - ValuesPageToken? nextPageToken = pageToken.GetNextPageToken(_nextOffset, MockPagingData.Count); - - return PageResult.Create(pageModel.Values, pageToken, nextPageToken, response); - } - - public override ClientResult GetFirst() - { - ClientResult result = GetValuesPage(_order, _pageSize, _offset); - - _nextOffset = GetNextOffset(_offset, _pageSize); - - return result; - } - - public override async Task GetFirstAsync() - { - ClientResult result = await GetValuesPageAsync(_order, _pageSize, _offset).ConfigureAwait(false); - - _nextOffset = GetNextOffset(_offset, _pageSize); - - return result; - } - - public override ClientResult GetNext(ClientResult result) - { - _offset = _nextOffset; - - ClientResult pageResult = GetValuesPage(_order, _pageSize, _offset); - - _nextOffset = GetNextOffset(_offset, _pageSize); - - return pageResult; - } - - public override async Task GetNextAsync(ClientResult result) - { - _offset = _nextOffset; - - ClientResult pageResult = await GetValuesPageAsync(_order, _pageSize, _offset).ConfigureAwait(false); - - _nextOffset = GetNextOffset(_offset, _pageSize); - - return pageResult; - } - - public override bool HasNext(ClientResult result) - { - return _nextOffset < MockPagingData.Count; - } - - // In a real client implementation, these would be the generated protocol - // method used to obtain a page of items. - internal virtual async Task GetValuesPageAsync( - string? order, - int? pageSize, - int? offset, - RequestOptions? options = default) - { - await Task.Delay(0); - IEnumerable values = MockPagingData.GetValues(order, pageSize, offset); - return MockPagingData.GetPageResult(values); - } - - internal virtual ClientResult GetValuesPage( - string? order, - int? pageSize, - int? offset, - RequestOptions? options = default) - { - IEnumerable values = MockPagingData.GetValues(order, pageSize, offset); - return MockPagingData.GetPageResult(values); - } - - // This helper method is specific to this mock enumerator implementation - private static int GetNextOffset(int? offset, int? pageSize) - { - offset ??= MockPagingData.DefaultOffset; - pageSize ??= MockPagingData.DefaultPageSize; - return offset.Value + pageSize.Value; - } -} diff --git a/sdk/core/System.ClientModel/tests/client/TestClients/PagingClient/ValuesPageResultEnumerator.cs b/sdk/core/System.ClientModel/tests/client/TestClients/PagingClient/ValuesPageResultEnumerator.cs deleted file mode 100644 index 655d49f4bfe9..000000000000 --- a/sdk/core/System.ClientModel/tests/client/TestClients/PagingClient/ValuesPageResultEnumerator.cs +++ /dev/null @@ -1,123 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -using System; -using System.ClientModel; -using System.ClientModel.Primitives; -using System.Collections.Generic; -using System.Threading.Tasks; - -namespace ClientModel.Tests.Paging; - -// Mocks a page result enumerator a client would have for paged endpoints when -// those endpoints only have protocol methods on the client. -internal class ValuesPageResultEnumerator : PageResultEnumerator -{ - private readonly ClientPipeline _pipeline; - private readonly Uri _endpoint; - - private readonly string? _order; - private readonly int? _pageSize; - - // This one is special - it keeps track of which page we're on. - private int? _offset; - - // We need two offsets to be able to create both page tokens. - private int _nextOffset; - - private readonly RequestOptions? _options; - - public ValuesPageResultEnumerator( - ClientPipeline pipeline, - Uri endpoint, - string? order, - int? pageSize, - int? offset, - RequestOptions? options) - { - _pipeline = pipeline; - _endpoint = endpoint; - - _order = order; - _pageSize = pageSize; - _offset = offset; - - _options = options; - } - - public override ClientResult GetFirst() - { - ClientResult result = GetValuesPage(_order, _pageSize, _offset); - - _nextOffset = GetNextOffset(_offset, _pageSize); - - return result; - } - - public override async Task GetFirstAsync() - { - ClientResult result = await GetValuesPageAsync(_order, _pageSize, _offset).ConfigureAwait(false); - - _nextOffset = GetNextOffset(_offset, _pageSize); - - return result; - } - - public override ClientResult GetNext(ClientResult result) - { - _offset = _nextOffset; - - ClientResult pageResult = GetValuesPage(_order, _pageSize, _offset); - - _nextOffset = GetNextOffset(_offset, _pageSize); - - return pageResult; - } - - public override async Task GetNextAsync(ClientResult result) - { - _offset = _nextOffset; - - ClientResult pageResult = await GetValuesPageAsync(_order, _pageSize, _offset).ConfigureAwait(false); - - _nextOffset = GetNextOffset(_offset, _pageSize); - - return pageResult; - } - - public override bool HasNext(ClientResult result) - { - return _nextOffset < MockPagingData.Count; - } - - // In a real client implementation, these would be the generated protocol - // method used to obtain a page of items. - internal virtual async Task GetValuesPageAsync( - string? order, - int? pageSize, - int? offset, - RequestOptions? options = default) - { - await Task.Delay(0); - IEnumerable values = MockPagingData.GetValues(order, pageSize, offset); - return MockPagingData.GetPageResult(values); - } - - internal virtual ClientResult GetValuesPage( - string? order, - int? pageSize, - int? offset, - RequestOptions? options = default) - { - IEnumerable values = MockPagingData.GetValues(order, pageSize, offset); - return MockPagingData.GetPageResult(values); - } - - // This helper method is specific to this mock enumerator implementation - private static int GetNextOffset(int? offset, int? pageSize) - { - offset ??= MockPagingData.DefaultOffset; - pageSize ??= MockPagingData.DefaultPageSize; - return offset.Value + pageSize.Value; - } -} diff --git a/sdk/core/System.ClientModel/tests/client/TestClients/StreamedCollectionClient/AsyncStreamedValueCollectionResult.cs b/sdk/core/System.ClientModel/tests/client/TestClients/StreamedCollectionClient/AsyncStreamedValueCollectionResult.cs new file mode 100644 index 000000000000..af6412060643 --- /dev/null +++ b/sdk/core/System.ClientModel/tests/client/TestClients/StreamedCollectionClient/AsyncStreamedValueCollectionResult.cs @@ -0,0 +1,59 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.ClientModel; +using System.ClientModel.Primitives; +using System.Collections.Generic; +using System.IO; +using System.Net.ServerSentEvents; +using System.Threading.Tasks; + +namespace ClientModel.Tests.Collections; + +// This type is public to enable test scenarios - in a real client it would be +// an internal type. +public class AsyncStreamedValueCollectionResult : AsyncCollectionResult +{ + private readonly RequestOptions? _options; + + public AsyncStreamedValueCollectionResult(RequestOptions? options) + { + _options = options; + } + + public override ContinuationToken? GetContinuationToken(ClientResult page) + // continuation not supported in this mock implentation + => null; + + public override async IAsyncEnumerable GetRawPagesAsync() + { + await Task.Delay(0, _options?.CancellationToken ?? default); + + // Only one response holds all the streamed data in this mock implementation + PipelineResponse response = new MockStreamedResponse(MockStreamedData.DefaultMockContent); + yield return ClientResult.FromResponse(response); + } + + // The following method is added for observability of the response + // in ResponseStreamisDisposedTest + public IAsyncEnumerable GetPageValuesAsync(ClientResult page) + { + return GetValuesFromPageAsync(page); + } + + protected override async IAsyncEnumerable GetValuesFromPageAsync(ClientResult page) + { + using PipelineResponse response = page.GetRawResponse(); + Stream contentStream = response.ContentStream ?? response.Content.ToStream(); + + SseParser parser = SseParser.Create(contentStream, (_, bytes) => bytes.ToArray()); + IAsyncEnumerable> enumerable = parser.EnumerateAsync(); + await foreach (SseItem item in enumerable) + { + if (!MockStreamedData.IsTerminalEvent(item.Data)) + { + yield return StreamedValue.FromJson(item.Data); + } + } + } +} diff --git a/sdk/core/System.ClientModel/tests/client/TestClients/StreamedCollectionClient/Internal/System.Net.ServerSentEvents.cs b/sdk/core/System.ClientModel/tests/client/TestClients/StreamedCollectionClient/Internal/System.Net.ServerSentEvents.cs new file mode 100644 index 000000000000..c2801f19c102 --- /dev/null +++ b/sdk/core/System.ClientModel/tests/client/TestClients/StreamedCollectionClient/Internal/System.Net.ServerSentEvents.cs @@ -0,0 +1,619 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// This file contains a source copy of: +// https://github.com/dotnet/runtime/tree/2bd15868f12ace7cee9999af61d5c130b2603f04/src/libraries/System.Net.ServerSentEvents/src/System/Net/ServerSentEvents +// Once the System.Net.ServerSentEvents package is available, this file should be removed and replaced with a package reference. +// +// The only changes made to this code from the original are: +// - Enabled nullable reference types at file scope, and use a few null suppression operators to work around the lack of [NotNull] +// - Put into a single file for ease of management (it should not be edited in this repo). +// - Changed public types to be internal. +// - Removed a use of a [NotNull] attribute to assist in netstandard2.0 compilation. +// - Replaced a reference to a .resx string with an inline constant. + +#nullable enable + +using System.Buffers; +using System.Collections.Generic; +using System.Diagnostics; +using System.Globalization; +using System.IO; +using System.Runtime.CompilerServices; +using System.Text; +using System.Threading.Tasks; +using System.Threading; + +namespace System.Net.ServerSentEvents; + +/// Represents a server-sent event. +/// Specifies the type of data payload in the event. +internal readonly struct SseItem +{ + /// Initializes the server-sent event. + /// The event's payload. + /// The event's type. + public SseItem(T data, string eventType) + { + Data = data; + EventType = eventType; + } + + /// Gets the event's payload. + public T Data { get; } + + /// Gets the event's type. + public string EventType { get; } +} + +/// Encapsulates a method for parsing the bytes payload of a server-sent event. +/// Specifies the type of the return value of the parser. +/// The event's type. +/// The event's payload bytes. +/// The parsed . +internal delegate T SseItemParser(string eventType, ReadOnlySpan data); + +/// Provides a parser for parsing server-sent events. +internal static class SseParser +{ + /// The default ("message") for an event that did not explicitly specify a type. + public const string EventTypeDefault = "message"; + + /// Creates a parser for parsing a of server-sent events into a sequence of values. + /// The stream containing the data to parse. + /// + /// The enumerable of strings, which may be enumerated synchronously or asynchronously. The strings + /// are decoded from the UTF8-encoded bytes of the payload of each event. + /// + /// is null. + /// + /// This overload has behavior equivalent to calling with a delegate + /// that decodes the data of each event using 's GetString method. + /// + public static SseParser Create(Stream sseStream) => + Create(sseStream, static (_, bytes) => Utf8GetString(bytes)); + + /// Creates a parser for parsing a of server-sent events into a sequence of values. + /// Specifies the type of data in each event. + /// The stream containing the data to parse. + /// The parser to use to transform each payload of bytes into a data element. + /// The enumerable, which may be enumerated synchronously or asynchronously. + /// is null. + /// is null. + public static SseParser Create(Stream sseStream, SseItemParser itemParser) => + new SseParser( + sseStream ?? throw new ArgumentNullException(nameof(sseStream)), + itemParser ?? throw new ArgumentNullException(nameof(itemParser))); + + /// Encoding.UTF8.GetString(bytes) + internal static string Utf8GetString(ReadOnlySpan bytes) + { +#if NET + return Encoding.UTF8.GetString(bytes); +#else + byte[] array = bytes.ToArray(); + return array.Length == 0 ? + string.Empty : + Encoding.UTF8.GetString(array); +#endif + } +} + +/// Provides a parser for server-sent events information. +/// Specifies the type of data parsed from an event. +internal sealed class SseParser +{ + // For reference: + // Specification: https://html.spec.whatwg.org/multipage/server-sent-events.html#server-sent-events + + /// Carriage Return. + private const byte CR = (byte)'\r'; + /// Line Feed. + private const byte LF = (byte)'\n'; + /// Carriage Return Line Feed. + private static ReadOnlySpan CRLF => "\r\n"u8; + + /// The default size of an ArrayPool buffer to rent. + /// Larger size used by default to minimize number of reads. Smaller size used in debug to stress growth/shifting logic. + private const int DefaultArrayPoolRentSize = +#if DEBUG + 16; +#else + 1024; +#endif + + /// The stream to be parsed. + private readonly Stream _stream; + /// The parser delegate used to transform bytes into a . + private readonly SseItemParser _itemParser; + + /// Indicates whether the enumerable has already been used for enumeration. + private int _used; + + /// Buffer, either empty or rented, containing the data being read from the stream while looking for the next line. + private byte[] _lineBuffer = new byte[0]; + /// The starting offset of valid data in . + private int _lineOffset; + /// The length of valid data in , starting from . + private int _lineLength; + /// The index in where a newline ('\r', '\n', or "\r\n") was found. + private int _newlineIndex; + /// The index in of characters already checked for newlines. + /// + /// This is to avoid O(LineLength^2) behavior in the rare case where we have long lines that are built-up over multiple reads. + /// We want to avoid re-checking the same characters we've already checked over and over again. + /// + private int _lastSearchedForNewline; + /// Set when eof has been reached in the stream. + private bool _eof; + + /// Rented buffer containing buffered data for the next event. + private byte[]? _dataBuffer; + /// The length of valid data in , starting from index 0. + private int _dataLength; + /// Whether data has been appended to . + /// This can be different than != 0 if empty data was appended. + private bool _dataAppended; + + /// The event type for the next event. + private string _eventType = SseParser.EventTypeDefault; + + /// Initialize the enumerable. + /// The stream to parse. + /// The function to use to parse payload bytes into a . + internal SseParser(Stream stream, SseItemParser itemParser) + { + _stream = stream; + _itemParser = itemParser; + } + + /// Gets an enumerable of the server-sent events from this parser. + /// The parser has already been enumerated. Such an exception may propagate out of a call to . + public IEnumerable> Enumerate() + { + // Validate that the parser is only used for one enumeration. + ThrowIfNotFirstEnumeration(); + + // Rent a line buffer. This will grow as needed. The line buffer is what's passed to the stream, + // so we want it to be large enough to reduce the number of reads we need to do when data is + // arriving quickly. (In debug, we use a smaller buffer to stress the growth and shifting logic.) + _lineBuffer = ArrayPool.Shared.Rent(DefaultArrayPoolRentSize); + try + { + // Spec: "Event streams in this format must always be encoded as UTF-8". + // Skip a UTF8 BOM if it exists at the beginning of the stream. (The BOM is defined as optional in the SSE grammar.) + while (FillLineBuffer() != 0 && _lineLength < Utf8Bom.Length) + ; + SkipBomIfPresent(); + + // Process all events in the stream. + while (true) + { + // See if there's a complete line in data already read from the stream. Lines are permitted to + // end with CR, LF, or CRLF. Look for all of them and if we find one, process the line. However, + // if we only find a CR and it's at the end of the read data, don't process it now, as we want + // to process it together with an LF that might immediately follow, rather than treating them + // as two separate characters, in which case we'd incorrectly process the CR as a line by itself. + GetNextSearchOffsetAndLength(out int searchOffset, out int searchLength); + _newlineIndex = _lineBuffer.AsSpan(searchOffset, searchLength).IndexOfAny(CR, LF); + if (_newlineIndex >= 0) + { + _lastSearchedForNewline = -1; + _newlineIndex += searchOffset; + if (_lineBuffer[_newlineIndex] is LF || // the newline is LF + _newlineIndex - _lineOffset + 1 < _lineLength || // we must have CR and we have whatever comes after it + _eof) // if we get here, we know we have a CR at the end of the buffer, so it's definitely the whole newline if we've hit EOF + { + // Process the line. + if (ProcessLine(out SseItem sseItem, out int advance)) + { + yield return sseItem; + } + + // Move past the line. + _lineOffset += advance; + _lineLength -= advance; + continue; + } + } + else + { + // Record the last position searched for a newline. The next time we search, + // we'll search from here rather than from _lineOffset, in order to avoid searching + // the same characters again. + _lastSearchedForNewline = _lineOffset + _lineLength; + } + + // We've processed everything in the buffer we currently can, so if we've already read EOF, we're done. + if (_eof) + { + // Spec: "Once the end of the file is reached, any pending data must be discarded. (If the file ends in the middle of an + // event, before the final empty line, the incomplete event is not dispatched.)" + break; + } + + // Read more data into the buffer. + FillLineBuffer(); + } + } + finally + { + ArrayPool.Shared.Return(_lineBuffer); + if (_dataBuffer is not null) + { + ArrayPool.Shared.Return(_dataBuffer); + } + } + } + + /// Gets an asynchronous enumerable of the server-sent events from this parser. + /// The cancellation token to use to cancel the enumeration. + /// The parser has already been enumerated. Such an exception may propagate out of a call to . + /// The enumeration was canceled. Such an exception may propagate out of a call to . + public async IAsyncEnumerable> EnumerateAsync([EnumeratorCancellation] CancellationToken cancellationToken = default) + { + // Validate that the parser is only used for one enumeration. + ThrowIfNotFirstEnumeration(); + + // Rent a line buffer. This will grow as needed. The line buffer is what's passed to the stream, + // so we want it to be large enough to reduce the number of reads we need to do when data is + // arriving quickly. (In debug, we use a smaller buffer to stress the growth and shifting logic.) + _lineBuffer = ArrayPool.Shared.Rent(DefaultArrayPoolRentSize); + try + { + // Spec: "Event streams in this format must always be encoded as UTF-8". + // Skip a UTF8 BOM if it exists at the beginning of the stream. (The BOM is defined as optional in the SSE grammar.) + while (await FillLineBufferAsync(cancellationToken).ConfigureAwait(false) != 0 && _lineLength < Utf8Bom.Length) + ; + SkipBomIfPresent(); + + // Process all events in the stream. + while (true) + { + // See if there's a complete line in data already read from the stream. Lines are permitted to + // end with CR, LF, or CRLF. Look for all of them and if we find one, process the line. However, + // if we only find a CR and it's at the end of the read data, don't process it now, as we want + // to process it together with an LF that might immediately follow, rather than treating them + // as two separate characters, in which case we'd incorrectly process the CR as a line by itself. + GetNextSearchOffsetAndLength(out int searchOffset, out int searchLength); + _newlineIndex = _lineBuffer.AsSpan(searchOffset, searchLength).IndexOfAny(CR, LF); + if (_newlineIndex >= 0) + { + _lastSearchedForNewline = -1; + _newlineIndex += searchOffset; + if (_lineBuffer[_newlineIndex] is LF || // newline is LF + _newlineIndex - _lineOffset + 1 < _lineLength || // newline is CR, and we have whatever comes after it + _eof) // if we get here, we know we have a CR at the end of the buffer, so it's definitely the whole newline if we've hit EOF + { + // Process the line. + if (ProcessLine(out SseItem sseItem, out int advance)) + { + yield return sseItem; + } + + // Move past the line. + _lineOffset += advance; + _lineLength -= advance; + continue; + } + } + else + { + // Record the last position searched for a newline. The next time we search, + // we'll search from here rather than from _lineOffset, in order to avoid searching + // the same characters again. + _lastSearchedForNewline = searchOffset + searchLength; + } + + // We've processed everything in the buffer we currently can, so if we've already read EOF, we're done. + if (_eof) + { + // Spec: "Once the end of the file is reached, any pending data must be discarded. (If the file ends in the middle of an + // event, before the final empty line, the incomplete event is not dispatched.)" + break; + } + + // Read more data into the buffer. + await FillLineBufferAsync(cancellationToken).ConfigureAwait(false); + } + } + finally + { + ArrayPool.Shared.Return(_lineBuffer); + if (_dataBuffer is not null) + { + ArrayPool.Shared.Return(_dataBuffer); + } + } + } + + /// Gets the next index and length with which to perform a newline search. + private void GetNextSearchOffsetAndLength(out int searchOffset, out int searchLength) + { + if (_lastSearchedForNewline > _lineOffset) + { + searchOffset = _lastSearchedForNewline; + searchLength = _lineLength - (_lastSearchedForNewline - _lineOffset); + } + else + { + searchOffset = _lineOffset; + searchLength = _lineLength; + } + + Debug.Assert(searchOffset >= _lineOffset, $"{searchOffset}, {_lineLength}"); + Debug.Assert(searchOffset <= _lineOffset + _lineLength, $"{searchOffset}, {_lineOffset}, {_lineLength}"); + Debug.Assert(searchOffset <= _lineBuffer.Length, $"{searchOffset}, {_lineBuffer.Length}"); + + Debug.Assert(searchLength >= 0, $"{searchLength}"); + Debug.Assert(searchLength <= _lineLength, $"{searchLength}, {_lineLength}"); + } + + private int GetNewLineLength() + { + Debug.Assert(_newlineIndex - _lineOffset < _lineLength, "Expected to be positioned at a non-empty newline"); + return _lineBuffer.AsSpan(_newlineIndex, _lineLength - (_newlineIndex - _lineOffset)).StartsWith(CRLF) ? 2 : 1; + } + + /// + /// If there's no room remaining in the line buffer, either shifts the contents + /// left or grows the buffer in order to make room for the next read. + /// + private void ShiftOrGrowLineBufferIfNecessary() + { + // If data we've read is butting up against the end of the buffer and + // it's not taking up the entire buffer, slide what's there down to + // the beginning, making room to read more data into the buffer (since + // there's no newline in the data that's there). Otherwise, if the whole + // buffer is full, grow the buffer to accommodate more data, since, again, + // what's there doesn't contain a newline and thus a line is longer than + // the current buffer accommodates. + if (_lineOffset + _lineLength == _lineBuffer.Length) + { + if (_lineOffset != 0) + { + _lineBuffer.AsSpan(_lineOffset, _lineLength).CopyTo(_lineBuffer); + if (_lastSearchedForNewline >= 0) + { + _lastSearchedForNewline -= _lineOffset; + } + _lineOffset = 0; + } + else if (_lineLength == _lineBuffer.Length) + { + GrowBuffer(ref _lineBuffer!, _lineBuffer.Length * 2); + } + } + } + + /// Processes a complete line from the SSE stream. + /// The parsed item if the method returns true. + /// How many characters to advance in the line buffer. + /// true if an SSE item was successfully parsed; otherwise, false. + private bool ProcessLine(out SseItem sseItem, out int advance) + { + ReadOnlySpan line = _lineBuffer.AsSpan(_lineOffset, _newlineIndex - _lineOffset); + + // Spec: "If the line is empty (a blank line) Dispatch the event" + if (line.IsEmpty) + { + advance = GetNewLineLength(); + + if (_dataAppended) + { + sseItem = new SseItem(_itemParser(_eventType, _dataBuffer.AsSpan(0, _dataLength)), _eventType); + _eventType = SseParser.EventTypeDefault; + _dataLength = 0; + _dataAppended = false; + return true; + } + + sseItem = default; + return false; + } + + // Find the colon separating the field name and value. + int colonPos = line.IndexOf((byte)':'); + ReadOnlySpan fieldName; + ReadOnlySpan fieldValue; + if (colonPos >= 0) + { + // Spec: "Collect the characters on the line before the first U+003A COLON character (:), and let field be that string." + fieldName = line.Slice(0, colonPos); + + // Spec: "Collect the characters on the line after the first U+003A COLON character (:), and let value be that string. + // If value starts with a U+0020 SPACE character, remove it from value." + fieldValue = line.Slice(colonPos + 1); + if (!fieldValue.IsEmpty && fieldValue[0] == (byte)' ') + { + fieldValue = fieldValue.Slice(1); + } + } + else + { + // Spec: "using the whole line as the field name, and the empty string as the field value." + fieldName = line; + fieldValue = new(); + } + + if (fieldName.SequenceEqual("data"u8)) + { + // Spec: "Append the field value to the data buffer, then append a single U+000A LINE FEED (LF) character to the data buffer." + // Spec: "If the data buffer's last character is a U+000A LINE FEED (LF) character, then remove the last character from the data buffer." + + // If there's nothing currently in the data buffer and we can easily detect that this line is immediately followed by + // an empty line, we can optimize it to just handle the data directly from the line buffer, rather than first copying + // into the data buffer and dispatching from there. + if (!_dataAppended) + { + int newlineLength = GetNewLineLength(); + ReadOnlySpan remainder = _lineBuffer.AsSpan(_newlineIndex + newlineLength, _lineLength - line.Length - newlineLength); + if (!remainder.IsEmpty && + (remainder[0] is LF || (remainder[0] is CR && remainder.Length > 1))) + { + advance = line.Length + newlineLength + (remainder.StartsWith(CRLF) ? 2 : 1); + sseItem = new SseItem(_itemParser(_eventType, fieldValue), _eventType); + _eventType = SseParser.EventTypeDefault; + return true; + } + } + + // We need to copy the data from the data buffer to the line buffer. Make sure there's enough room. + if (_dataBuffer is null || _dataLength + _lineLength + 1 > _dataBuffer.Length) + { + GrowBuffer(ref _dataBuffer, _dataLength + _lineLength + 1); + } + + // Append a newline if there's already content in the buffer. + // Then copy the field value to the data buffer + if (_dataAppended) + { + _dataBuffer![_dataLength++] = LF; + } + fieldValue.CopyTo(_dataBuffer.AsSpan(_dataLength)); + _dataLength += fieldValue.Length; + _dataAppended = true; + } + else if (fieldName.SequenceEqual("event"u8)) + { + // Spec: "Set the event type buffer to field value." + _eventType = SseParser.Utf8GetString(fieldValue); + } + else if (fieldName.SequenceEqual("id"u8)) + { + // Spec: "If the field value does not contain U+0000 NULL, then set the last event ID buffer to the field value. Otherwise, ignore the field." + if (fieldValue.IndexOf((byte)'\0') < 0) + { + // Note that fieldValue might be empty, in which case LastEventId will naturally be reset to the empty string. This is per spec. + LastEventId = SseParser.Utf8GetString(fieldValue); + } + } + else if (fieldName.SequenceEqual("retry"u8)) + { + // Spec: "If the field value consists of only ASCII digits, then interpret the field value as an integer in base ten, + // and set the event stream's reconnection time to that integer. Otherwise, ignore the field." + if (long.TryParse( +#if NET7_0_OR_GREATER + fieldValue, +#else + SseParser.Utf8GetString(fieldValue), +#endif + NumberStyles.None, CultureInfo.InvariantCulture, out long milliseconds)) + { + ReconnectionInterval = TimeSpan.FromMilliseconds(milliseconds); + } + } + else + { + // We'll end up here if the line starts with a colon, producing an empty field name, or if the field name is otherwise unrecognized. + // Spec: "If the line starts with a U+003A COLON character (:) Ignore the line." + // Spec: "Otherwise, The field is ignored" + } + + advance = line.Length + GetNewLineLength(); + sseItem = default; + return false; + } + + /// Gets the last event ID. + /// This value is updated any time a new last event ID is parsed. It is not reset between SSE items. + public string LastEventId { get; private set; } = string.Empty; // Spec: "must be initialized to the empty string" + + /// Gets the reconnection interval. + /// + /// If no retry event was received, this defaults to , and it will only + /// ever be in that situation. If a client wishes to retry, the server-sent + /// events specification states that the interval may then be decided by the client implementation and should be a + /// few seconds. + /// + public TimeSpan ReconnectionInterval { get; private set; } = Timeout.InfiniteTimeSpan; + + /// Transitions the object to a used state, throwing if it's already been used. + private void ThrowIfNotFirstEnumeration() + { + if (Interlocked.Exchange(ref _used, 1) != 0) + { + throw new InvalidOperationException("The enumerable may be enumerated only once."); + } + } + + /// Reads data from the stream into the line buffer. + private int FillLineBuffer() + { + ShiftOrGrowLineBufferIfNecessary(); + + int offset = _lineOffset + _lineLength; + int bytesRead = _stream.Read( +#if NET + _lineBuffer.AsSpan(offset)); +#else + _lineBuffer, offset, _lineBuffer.Length - offset); +#endif + + if (bytesRead > 0) + { + _lineLength += bytesRead; + } + else + { + _eof = true; + bytesRead = 0; + } + + return bytesRead; + } + + /// Reads data asynchronously from the stream into the line buffer. + private async ValueTask FillLineBufferAsync(CancellationToken cancellationToken) + { + ShiftOrGrowLineBufferIfNecessary(); + + int offset = _lineOffset + _lineLength; + int bytesRead = await +#if NET + _stream.ReadAsync(_lineBuffer.AsMemory(offset), cancellationToken) +#else + new ValueTask(_stream.ReadAsync(_lineBuffer, offset, _lineBuffer.Length - offset, cancellationToken)) +#endif + .ConfigureAwait(false); + + if (bytesRead > 0) + { + _lineLength += bytesRead; + } + else + { + _eof = true; + bytesRead = 0; + } + + return bytesRead; + } + + /// Gets the UTF8 BOM. + private static ReadOnlySpan Utf8Bom => new byte[] { 0xEF, 0xBB, 0xBF }; + + /// Called at the beginning of processing to skip over an optional UTF8 byte order mark. + private void SkipBomIfPresent() + { + Debug.Assert(_lineOffset == 0, $"Expected _lineOffset == 0, got {_lineOffset}"); + + if (_lineBuffer.AsSpan(0, _lineLength).StartsWith(Utf8Bom)) + { + _lineOffset += 3; + _lineLength -= 3; + } + } + + /// Grows the buffer, returning the existing one to the ArrayPool and renting an ArrayPool replacement. + private static void GrowBuffer(ref byte[]? buffer, int minimumLength) + { + byte[]? toReturn = buffer; + buffer = ArrayPool.Shared.Rent(Math.Max(minimumLength, DefaultArrayPoolRentSize)); + if (toReturn is not null) + { + Array.Copy(toReturn, buffer, toReturn.Length); + ArrayPool.Shared.Return(toReturn); + } + } +} diff --git a/sdk/core/System.ClientModel/tests/client/TestClients/StreamedCollectionClient/MockData/MockStreamedData.cs b/sdk/core/System.ClientModel/tests/client/TestClients/StreamedCollectionClient/MockData/MockStreamedData.cs new file mode 100644 index 000000000000..20441c92df8d --- /dev/null +++ b/sdk/core/System.ClientModel/tests/client/TestClients/StreamedCollectionClient/MockData/MockStreamedData.cs @@ -0,0 +1,35 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; + +namespace ClientModel.Tests.Collections; + +public class MockStreamedData +{ + public const int TotalItemCount = 3; + + private static ReadOnlySpan TerminalData => "[DONE]"u8; + + // Note: need extra line because raw string literal removes \n from final line. + internal const string DefaultMockContent = """ + event: event.0 + data: { "id": 0, "value": "0" } + + event: event.1 + data: { "id": 1, "value": "1" } + + event: event.2 + data: { "id": 2, "value": "2" } + + event: done + data: [DONE] + + + """; + + public static bool IsTerminalEvent(byte[] data) + { + return data.AsSpan().SequenceEqual(TerminalData); + } +} diff --git a/sdk/core/System.ClientModel/tests/client/TestClients/StreamedCollectionClient/MockData/MockStreamedResponse.cs b/sdk/core/System.ClientModel/tests/client/TestClients/StreamedCollectionClient/MockData/MockStreamedResponse.cs new file mode 100644 index 000000000000..ea0ac6f895f3 --- /dev/null +++ b/sdk/core/System.ClientModel/tests/client/TestClients/StreamedCollectionClient/MockData/MockStreamedResponse.cs @@ -0,0 +1,46 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.ClientModel.Primitives; +using System.IO; +using System.Threading; +using System.Threading.Tasks; + +namespace ClientModel.Tests.Collections; + +public class MockStreamedResponse : PipelineResponse +{ + public MockStreamedResponse(string content) + { + Content = BinaryData.FromString(content); + } + + public override int Status => 200; + + public override string ReasonPhrase => "OK"; + + public override Stream? ContentStream + { + get => null; + set => throw new NotImplementedException(); + } + + public override BinaryData Content { get; } + + public bool IsDisposed { get; private set; } + + protected override PipelineResponseHeaders HeadersCore + => throw new NotImplementedException(); + + public override BinaryData BufferContent(CancellationToken cancellationToken = default) + => Content; + + public override ValueTask BufferContentAsync(CancellationToken cancellationToken = default) + => new(Content); + + public override void Dispose() + { + IsDisposed = true; + } +} diff --git a/sdk/core/System.ClientModel/tests/client/TestClients/StreamedCollectionClient/StreamedCollectionClient.cs b/sdk/core/System.ClientModel/tests/client/TestClients/StreamedCollectionClient/StreamedCollectionClient.cs new file mode 100644 index 000000000000..2e6a4d308f1b --- /dev/null +++ b/sdk/core/System.ClientModel/tests/client/TestClients/StreamedCollectionClient/StreamedCollectionClient.cs @@ -0,0 +1,37 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.ClientModel; +using System.ClientModel.Primitives; +using System.Threading; + +namespace ClientModel.Tests.Collections; + +// A reference implementation that illustrates client patterns for streaming +// service endpoints for clients that have both convenience and protocol methods. +public class StreamedCollectionClient +{ + public StreamedCollectionClient(StreamedCollectionClientOptions? options = default) + { + } + + public virtual AsyncCollectionResult GetValuesAsync(CancellationToken cancellationToken = default) + { + return new AsyncStreamedValueCollectionResult(cancellationToken.ToRequestOptions()); + } + + public virtual CollectionResult GetValues(CancellationToken cancellationToken = default) + { + return new StreamedValueCollectionResult(cancellationToken.ToRequestOptions()); + } + + public virtual AsyncCollectionResult GetValuesAsync(RequestOptions? options) + { + return new AsyncStreamedValueCollectionResult(options); + } + + public virtual CollectionResult GetValues(RequestOptions? options) + { + return new StreamedValueCollectionResult(options); + } +} diff --git a/sdk/core/System.ClientModel/tests/client/TestClients/StreamedCollectionClient/StreamedCollectionClientOptions.cs b/sdk/core/System.ClientModel/tests/client/TestClients/StreamedCollectionClient/StreamedCollectionClientOptions.cs new file mode 100644 index 000000000000..1e9627ddf185 --- /dev/null +++ b/sdk/core/System.ClientModel/tests/client/TestClients/StreamedCollectionClient/StreamedCollectionClientOptions.cs @@ -0,0 +1,10 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.ClientModel.Primitives; + +namespace ClientModel.Tests.Collections; + +public class StreamedCollectionClientOptions : ClientPipelineOptions +{ +} diff --git a/sdk/core/System.ClientModel/tests/client/TestClients/StreamedCollectionClient/StreamedValue.cs b/sdk/core/System.ClientModel/tests/client/TestClients/StreamedCollectionClient/StreamedValue.cs new file mode 100644 index 000000000000..cdd8c632783d --- /dev/null +++ b/sdk/core/System.ClientModel/tests/client/TestClients/StreamedCollectionClient/StreamedValue.cs @@ -0,0 +1,38 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Text.Json; + +namespace ClientModel.Tests.Collections; + +// A mock model that illustrates values that can be returned in a streamed collection +public class StreamedValue +{ + public StreamedValue(int id, string value) + { + Id = id; + Value = value; + } + + public int Id { get; } + + public string Value { get; } + + public string ToJson() => $"{{ \"id\" : {Id}, \"value\" : \"{Value}\" }}"; + + public static StreamedValue FromJson(JsonElement element) + { + int id = element.GetProperty("id").GetInt32(); + string value = element.GetProperty("value").GetString()!; + return new StreamedValue(id, value); + } + + public static StreamedValue FromJson(byte[] data) + { + using JsonDocument doc = JsonDocument.Parse(data); + return FromJson(doc.RootElement); + } + + public override string ToString() => ToJson(); +} diff --git a/sdk/core/System.ClientModel/tests/client/TestClients/StreamedCollectionClient/StreamedValueCollectionResult.cs b/sdk/core/System.ClientModel/tests/client/TestClients/StreamedCollectionClient/StreamedValueCollectionResult.cs new file mode 100644 index 000000000000..ac5fdee4b60a --- /dev/null +++ b/sdk/core/System.ClientModel/tests/client/TestClients/StreamedCollectionClient/StreamedValueCollectionResult.cs @@ -0,0 +1,57 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.ClientModel; +using System.ClientModel.Primitives; +using System.Collections.Generic; +using System.IO; +using System.Net.ServerSentEvents; + +namespace ClientModel.Tests.Collections; + +// This type is public to enable test scenarios - in a real client it would be +// an internal type. +public class StreamedValueCollectionResult : CollectionResult +{ + private readonly RequestOptions? _options; + + public StreamedValueCollectionResult(RequestOptions? options) + { + _options = options; + } + + public override ContinuationToken? GetContinuationToken(ClientResult page) + // continuation not supported in this mock implentation + => null; + + public override IEnumerable GetRawPages() + { + // Only one response holds all the streamed data in this mock implementation + PipelineResponse response = new MockStreamedResponse(MockStreamedData.DefaultMockContent); + yield return ClientResult.FromResponse(response); + } + + // The following method is added for observability of the response + // in ResponseStreamisDisposedTest + public IEnumerable GetPageValues(ClientResult page) + { + return GetValuesFromPage(page); + } + + protected override IEnumerable GetValuesFromPage(ClientResult page) + { + using PipelineResponse response = page.GetRawResponse(); + Stream contentStream = response.ContentStream ?? response.Content.ToStream(); + + SseParser parser = SseParser.Create(contentStream, (_, bytes) => bytes.ToArray()); + IEnumerable> enumerable = parser.Enumerate(); + + foreach (SseItem item in enumerable) + { + if (!MockStreamedData.IsTerminalEvent(item.Data)) + { + yield return StreamedValue.FromJson(item.Data); + } + } + } +} diff --git a/sdk/core/System.ClientModel/tests/internal/Convenience/SSE/AsyncServerSentEventEnumerableTests.cs b/sdk/core/System.ClientModel/tests/internal/Convenience/SSE/AsyncServerSentEventEnumerableTests.cs deleted file mode 100644 index fe511057338c..000000000000 --- a/sdk/core/System.ClientModel/tests/internal/Convenience/SSE/AsyncServerSentEventEnumerableTests.cs +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -using System.ClientModel.Internal; -using System.Collections.Generic; -using System.IO; -using System.Threading; -using System.Threading.Tasks; -using ClientModel.Tests.Internal.Mocks; -using NUnit.Framework; - -namespace System.ClientModel.Tests.Convenience; - -public class AsyncServerSentEventEnumerableTests -{ - [Test] - public async Task EnumeratesEvents() - { - using Stream contentStream = BinaryData.FromString(MockSseClient.DefaultMockContent).ToStream(); - AsyncServerSentEventEnumerable enumerable = new(contentStream); - - List events = new(); - - await foreach (ServerSentEvent sse in enumerable) - { - events.Add(sse); - } - - Assert.AreEqual(4, events.Count); - - for (int i = 0; i < 3; i++) - { - Assert.AreEqual($"event.{i}", events[i].EventType); - Assert.AreEqual($"{{ \"IntValue\": {i}, \"StringValue\": \"{i}\" }}", events[i].Data); - } - } - - [Test] - public void ThrowsIfCancelled() - { - CancellationToken token = new(true); - - using Stream contentStream = BinaryData.FromString(MockSseClient.DefaultMockContent).ToStream(); - AsyncServerSentEventEnumerable enumerable = new(contentStream); - IAsyncEnumerator enumerator = enumerable.GetAsyncEnumerator(token); - - Assert.ThrowsAsync(async () => await enumerator.MoveNextAsync()); - } -} diff --git a/sdk/core/System.ClientModel/tests/internal/Convenience/SSE/ClientResultCollectionTests.cs b/sdk/core/System.ClientModel/tests/internal/Convenience/SSE/ClientResultCollectionTests.cs deleted file mode 100644 index 03a28295bb76..000000000000 --- a/sdk/core/System.ClientModel/tests/internal/Convenience/SSE/ClientResultCollectionTests.cs +++ /dev/null @@ -1,164 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -using System.ClientModel.Primitives; -using System.Threading; -using System.Threading.Tasks; -using Azure.Core.TestFramework; -using ClientModel.Tests.Internal.Mocks; -using NUnit.Framework; -using SyncAsyncTestBase = ClientModel.Tests.SyncAsyncTestBase; - -namespace System.ClientModel.Tests.Convenience; - -public class ClientResultCollectionTests : SyncAsyncTestBase -{ - public ClientResultCollectionTests(bool isAsync) : base(isAsync) - { - } - - [Test] - public async Task EnumeratesModelValues() - { - MockSseClient client = new(); - AsyncCollectionResult models = client.GetModelsStreamingAsync(); - - int i = 0; - await foreach (MockJsonModel model in models) - { - Assert.AreEqual(i, model.IntValue); - Assert.AreEqual(i.ToString(), model.StringValue); - - i++; - } - - Assert.AreEqual(i, 3); - } - - [Test] - public async Task ModelCollectionDelaysSendingRequest() - { - MockSseClient client = new(); - AsyncCollectionResult models = client.GetModelsStreamingAsync(); - - Assert.IsFalse(client.ProtocolMethodCalled); - - int i = 0; - await foreach (MockJsonModel model in models) - { - Assert.AreEqual(i, model.IntValue); - Assert.AreEqual(i.ToString(), model.StringValue); - - i++; - } - - Assert.AreEqual(3, i); - Assert.IsTrue(client.ProtocolMethodCalled); - } - - [Test] - public void ModelCollectionThrowsIfCancelled() - { - MockSseClient client = new(); - AsyncCollectionResult models = client.GetModelsStreamingAsync(); - - // Set it to `cancelled: true` to validate functionality. - CancellationToken token = new(true); - - Assert.ThrowsAsync(async () => - { - await foreach (MockJsonModel model in models.WithCancellation(token)) - { - } - }); - } - - [Test] - public async Task ModelCollectionDisposesStream() - { - MockSseClient client = new(); - AsyncCollectionResult models = client.GetModelsStreamingAsync(); - - await foreach (MockJsonModel model in models) - { - } - - PipelineResponse response = models.GetRawResponse(); - Assert.Throws(() => { var p = response.ContentStream!.Position; }); - } - - [Test] - public void ModelCollectionGetRawResponseThrowsBeforeEnumerated() - { - MockSseClient client = new(); - AsyncCollectionResult models = client.GetModelsStreamingAsync(); - Assert.Throws(() => { PipelineResponse response = models.GetRawResponse(); }); - } - - [Test] - public async Task StopsOnStringBasedTerminalEvent() - { - MockSseClient client = new(); - AsyncCollectionResult models = client.GetModelsStreamingAsync("[DONE]"); - - bool empty = true; - await foreach (MockJsonModel model in models) - { - empty = false; - } - - Assert.IsNotNull(models); - Assert.AreEqual("[DONE]", models.GetRawResponse().Content.ToString()); - Assert.IsTrue(empty); - } - - [Test] - public async Task EnumeratesDataValues() - { - MockSseClient client = new(); - ClientResult result = client.GetModelsStreamingAsync(MockSseClient.DefaultMockContent, new RequestOptions()); - - int i = 0; - await foreach (BinaryData data in result.GetRawResponse().EnumerateDataEvents()) - { - MockJsonModel? model = data.ToObjectFromJson(); - - Assert.AreEqual(i, model?.IntValue); - Assert.AreEqual(i.ToString(), model?.StringValue); - - i++; - } - - Assert.AreEqual(3, i); - } - - [Test] - public void DataCollectionThrowsIfCancelled() - { - MockSseClient client = new(); - ClientResult result = client.GetModelsStreamingAsync(MockSseClient.DefaultMockContent, new RequestOptions()); - - // Set it to `cancelled: true` to validate functionality. - CancellationToken token = new(true); - - Assert.ThrowsAsync(async () => - { - await foreach (BinaryData data in result.GetRawResponse().EnumerateDataEvents().WithCancellation(token)) - { - } - }); - } - - [Test] - public async Task DataCollectionDoesNotDisposeStream() - { - MockSseClient client = new(); - ClientResult result = client.GetModelsStreamingAsync(MockSseClient.DefaultMockContent, new RequestOptions()); - - await foreach (BinaryData data in result.GetRawResponse().EnumerateDataEvents()) - { - } - - Assert.DoesNotThrow(() => { var p = result.GetRawResponse().ContentStream!.Position; }); - } -} diff --git a/sdk/core/System.ClientModel/tests/internal/Convenience/SSE/ServerSentEventEnumerableTests.cs b/sdk/core/System.ClientModel/tests/internal/Convenience/SSE/ServerSentEventEnumerableTests.cs deleted file mode 100644 index d32835fa7486..000000000000 --- a/sdk/core/System.ClientModel/tests/internal/Convenience/SSE/ServerSentEventEnumerableTests.cs +++ /dev/null @@ -1,35 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -using System.ClientModel.Internal; -using System.Collections.Generic; -using System.IO; -using ClientModel.Tests.Internal.Mocks; -using NUnit.Framework; - -namespace System.ClientModel.Tests.Convenience; - -public class ServerSentEventEnumerableTests -{ - [Test] - public void EnumeratesEvents() - { - using Stream contentStream = BinaryData.FromString(MockSseClient.DefaultMockContent).ToStream(); - ServerSentEventEnumerable enumerable = new(contentStream); - - List events = new(); - - foreach (ServerSentEvent sse in enumerable) - { - events.Add(sse); - } - - Assert.AreEqual(4, events.Count); - - for (int i = 0; i < 3; i++) - { - Assert.AreEqual($"event.{i}", events[i].EventType); - Assert.AreEqual($"{{ \"IntValue\": {i}, \"StringValue\": \"{i}\" }}", events[i].Data); - } - } -} diff --git a/sdk/core/System.ClientModel/tests/internal/Convenience/SSE/ServerSentEventFieldTests.cs b/sdk/core/System.ClientModel/tests/internal/Convenience/SSE/ServerSentEventFieldTests.cs deleted file mode 100644 index 8807aaa805b3..000000000000 --- a/sdk/core/System.ClientModel/tests/internal/Convenience/SSE/ServerSentEventFieldTests.cs +++ /dev/null @@ -1,20 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -using System.ClientModel.Internal; -using NUnit.Framework; - -namespace System.ClientModel.Tests.Convenience; - -public class ServerSentEventFieldTests -{ - [Test] - public void ParsesEventField() - { - string line = "event: event.name"; - ServerSentEventField field = new(line); - - Assert.AreEqual(ServerSentEventFieldKind.Event, field.FieldType); - Assert.IsTrue("event.name".AsSpan().SequenceEqual(field.Value.Span)); - } -} diff --git a/sdk/core/System.ClientModel/tests/internal/Convenience/SSE/ServerSentEventReaderTests.cs b/sdk/core/System.ClientModel/tests/internal/Convenience/SSE/ServerSentEventReaderTests.cs deleted file mode 100644 index 67365117e576..000000000000 --- a/sdk/core/System.ClientModel/tests/internal/Convenience/SSE/ServerSentEventReaderTests.cs +++ /dev/null @@ -1,290 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -using System.ClientModel.Internal; -using System.Collections.Generic; -using System.IO; -using System.Threading; -using System.Threading.Tasks; -using ClientModel.Tests.Internal.Mocks; -using NUnit.Framework; -using SyncAsyncTestBase = ClientModel.Tests.SyncAsyncTestBase; - -namespace System.ClientModel.Tests.Convenience; - -public class ServerSentEventReaderTests : SyncAsyncTestBase -{ - public ServerSentEventReaderTests(bool isAsync) : base(isAsync) - { - } - - [Test] - public async Task GetsEventsFromStream() - { - Stream contentStream = BinaryData.FromString(MockSseClient.DefaultMockContent).ToStream(); - ServerSentEventReader reader = new(contentStream); - - List events = new(); - ServerSentEvent? ssEvent = await reader.TryGetNextEventSyncOrAsync(IsAsync); - while (ssEvent is not null) - { - events.Add(ssEvent.Value); - ssEvent = await reader.TryGetNextEventSyncOrAsync(IsAsync); - } - - Assert.AreEqual(events.Count, 4); - - for (int i = 0; i < 3; i++) - { - ServerSentEvent sse = events[i]; - Assert.AreEqual($"event.{i}", sse.EventType); - Assert.AreEqual($"{{ \"IntValue\": {i}, \"StringValue\": \"{i}\" }}", sse.Data); - } - - Assert.AreEqual("done", events[3].EventType); - Assert.AreEqual("[DONE]", events[3].Data); - } - - [Test] - public async Task HandlesNullLine() - { - Stream contentStream = BinaryData.FromString(string.Empty).ToStream(); - ServerSentEventReader reader = new(contentStream); - - ServerSentEvent? ssEvent = await reader.TryGetNextEventSyncOrAsync(IsAsync); - Assert.IsNull(ssEvent); - } - - [Test] - public async Task DiscardsCommentLine() - { - Stream contentStream = BinaryData.FromString(": comment").ToStream(); - ServerSentEventReader reader = new(contentStream); - - ServerSentEvent? ssEvent = await reader.TryGetNextEventSyncOrAsync(IsAsync); - Assert.IsNull(ssEvent); - } - - [Test] - public async Task HandlesIgnoreLine() - { - Stream contentStream = BinaryData.FromString(""" - ignore: noop - - - """).ToStream(); - ServerSentEventReader reader = new(contentStream); - - ServerSentEvent? sse = await reader.TryGetNextEventSyncOrAsync(IsAsync); - - Assert.IsNull(sse); - } - - [Test] - public async Task HandlesDoneEvent() - { - Stream contentStream = BinaryData.FromString("event: stop\ndata: ~stop~\n\n").ToStream(); - ServerSentEventReader reader = new(contentStream); - - ServerSentEvent? sse = await reader.TryGetNextEventSyncOrAsync(IsAsync); - - Assert.IsNotNull(sse); - - Assert.AreEqual("stop", sse.Value.EventType); - Assert.AreEqual("~stop~", sse.Value.Data); - - Assert.AreEqual(string.Empty, reader.LastEventId); - Assert.AreEqual(Timeout.InfiniteTimeSpan, reader.ReconnectionInterval); - } - - [Test] - public async Task ConcatenatesDataLines() - { - Stream contentStream = BinaryData.FromString(""" - data: YHOO - data: +2 - data: 10 - - - """).ToStream(); - ServerSentEventReader reader = new(contentStream); - - ServerSentEvent? sse = await reader.TryGetNextEventSyncOrAsync(IsAsync); - - Assert.IsNotNull(sse); - - Assert.AreEqual("YHOO\n+2\n10", sse.Value.Data); - - Assert.AreEqual(string.Empty, reader.LastEventId); - Assert.AreEqual(Timeout.InfiniteTimeSpan, reader.ReconnectionInterval); - } - - [Test] - public async Task DefaultsEventTypeToMessage() - { - Stream contentStream = BinaryData.FromString(""" - data: data - - - """).ToStream(); - ServerSentEventReader reader = new(contentStream); - - ServerSentEvent? sse = await reader.TryGetNextEventSyncOrAsync(IsAsync); - - Assert.IsNotNull(sse); - - Assert.AreEqual("message", sse.Value.EventType); - Assert.AreEqual("data", sse.Value.Data); - } - - [Test] - public async Task SecondTestCaseFromSpec() - { - // See: https://html.spec.whatwg.org/multipage/server-sent-events.html#event-stream-interpretation - Stream contentStream = BinaryData.FromString(""" - : test stream - - data: first event - id: 1 - - data:second event - id - - data: third event - - - """).ToStream(); - ServerSentEventReader reader = new(contentStream); - - List events = new(); - List ids = new(); - - ServerSentEvent? sse = await reader.TryGetNextEventSyncOrAsync(IsAsync); - while (sse is not null) - { - events.Add(sse.Value); - ids.Add(reader.LastEventId.ToString()); - - sse = await reader.TryGetNextEventSyncOrAsync(IsAsync); - } - - Assert.AreEqual(3, events.Count); - - Assert.AreEqual("first event", events[0].Data); - Assert.AreEqual("1", ids[0]); - - Assert.AreEqual("second event", events[1].Data); - Assert.AreEqual(string.Empty, ids[1]); - - Assert.AreEqual(" third event", events[2].Data); - Assert.AreEqual(string.Empty, ids[2]); - } - - [Test] - public async Task ThirdSpecTestCase() - { - // See: https://html.spec.whatwg.org/multipage/server-sent-events.html#event-stream-interpretation - Stream contentStream = BinaryData.FromString(""" - data - - data - data - - data: - """).ToStream(); - ServerSentEventReader reader = new(contentStream); - - List events = new(); - - ServerSentEvent? sse = await reader.TryGetNextEventSyncOrAsync(IsAsync); - while (sse is not null) - { - events.Add(sse.Value); - sse = await reader.TryGetNextEventSyncOrAsync(IsAsync); - } - - Assert.AreEqual(2, events.Count); - Assert.AreEqual(0, events[0].Data.Length); - Assert.AreEqual("\n", events[1].Data); - } - - [Test] - public async Task FourthSpecTestCase() - { - // See: https://html.spec.whatwg.org/multipage/server-sent-events.html#event-stream-interpretation - Stream contentStream = BinaryData.FromString(""" - data:test - - data: test - - - """).ToStream(); - ServerSentEventReader reader = new(contentStream); - - List events = new(); - - ServerSentEvent? sse = await reader.TryGetNextEventSyncOrAsync(IsAsync); - while (sse is not null) - { - events.Add(sse.Value); - sse = await reader.TryGetNextEventSyncOrAsync(IsAsync); - } - - Assert.AreEqual(2, events.Count); - Assert.AreEqual(events[0].Data, events[1].Data); - } - - [Test] - public async Task SetsReconnectionInterval() - { - Stream contentStream = BinaryData.FromString(""" - data: test - - data: test - retry: 2500 - - data: test - retry: - - - """).ToStream(); - ServerSentEventReader reader = new(contentStream); - - List events = new(); - List retryValues = new(); - - ServerSentEvent? sse = await reader.TryGetNextEventSyncOrAsync(IsAsync); - while (sse is not null) - { - events.Add(sse.Value); - retryValues.Add(reader.ReconnectionInterval); - - sse = await reader.TryGetNextEventSyncOrAsync(IsAsync); - } - - Assert.AreEqual(3, events.Count); - - // Defaults to infinite timespan - Assert.AreEqual("test", events[0].Data); - Assert.AreEqual(Timeout.InfiniteTimeSpan, retryValues[0]); - - Assert.AreEqual("test", events[1].Data); - Assert.AreEqual(new TimeSpan(0, 0, 0, 2, 500), retryValues[1]); - - // Ignores invalid values - Assert.AreEqual("test", events[2].Data); - Assert.AreEqual(new TimeSpan(0, 0, 0, 2, 500), retryValues[2]); - } - - [Test] - public void ThrowsIfCancelled() - { - CancellationToken token = new(true); - - using Stream contentStream = BinaryData.FromString(MockSseClient.DefaultMockContent).ToStream(); - ServerSentEventReader reader = new(contentStream); - - Assert.ThrowsAsync(async () - => await reader.TryGetNextEventAsync(token)); - } -} diff --git a/sdk/core/System.ClientModel/tests/internal/TestFramework/Mocks/MockSseClient.cs b/sdk/core/System.ClientModel/tests/internal/TestFramework/Mocks/MockSseClient.cs deleted file mode 100644 index 34d156fc486b..000000000000 --- a/sdk/core/System.ClientModel/tests/internal/TestFramework/Mocks/MockSseClient.cs +++ /dev/null @@ -1,191 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -using System; -using System.ClientModel; -using System.ClientModel.Internal; -using System.ClientModel.Primitives; -using System.Collections.Generic; -using System.Diagnostics; -using System.Text.Json; -using System.Threading; -using System.Threading.Tasks; -using Azure.Core.TestFramework; -using ClientModel.Tests.Mocks; - -namespace ClientModel.Tests.Internal.Mocks; - -// Note: keeping this mock client used to illustrate SSE usage patterns in -// Tests.Internal for now as it needs access to internal types. Once we are -// able to port to a solution that uses the public BCL SseParser type, this -// will no longer be needed. -public class MockSseClient -{ - // Note: raw string literal removes \n from final line. - internal const string DefaultMockContent = """ - event: event.0 - data: { "IntValue": 0, "StringValue": "0" } - - event: event.1 - data: { "IntValue": 1, "StringValue": "1" } - - event: event.2 - data: { "IntValue": 2, "StringValue": "2" } - - event: done - data: [DONE] - - - """; - - public bool ProtocolMethodCalled { get; private set; } - - // mock convenience method - public virtual AsyncCollectionResult GetModelsStreamingAsync(string content = DefaultMockContent) - { - return new AsyncMockJsonModelCollection(content, GetModelsStreamingAsync); - } - - // mock protocol method - public virtual ClientResult GetModelsStreamingAsync(string content, RequestOptions? options = default) - { - // This mocks sending a request and returns a respose containing - // the passed-in content in the content stream. - - MockPipelineResponse response = new(); - response.SetContent(content); - - ProtocolMethodCalled = true; - - return ClientResult.FromResponse(response); - } - - // Internal client implementation of convenience-layer AsyncResultCollection. - // This currently layers over an internal AsyncResultCollection - // representing the event.data values, but does not strictly have to. - private class AsyncMockJsonModelCollection : AsyncCollectionResult - { - private readonly string _content; - private readonly Func _protocolMethod; - - public AsyncMockJsonModelCollection(string content, Func protocolMethod) - { - _content = content; - _protocolMethod = protocolMethod; - } - - public override IAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) - { - async Task getResultAsync() - { - await Task.Delay(0, cancellationToken); - return _protocolMethod(_content, /*options:*/ default); - } - - return new AsyncMockJsonModelEnumerator(getResultAsync, this, cancellationToken); - } - - private sealed class AsyncMockJsonModelEnumerator : IAsyncEnumerator - { - private const string _terminalData = "[DONE]"; - - private readonly Func> _getResultAsync; - private readonly AsyncMockJsonModelCollection _enumerable; - private readonly CancellationToken _cancellationToken; - - private IAsyncEnumerator? _events; - private MockJsonModel? _current; - - private bool _started; - - public AsyncMockJsonModelEnumerator(Func> getResultAsync, AsyncMockJsonModelCollection enumerable, CancellationToken cancellationToken) - { - Debug.Assert(getResultAsync is not null); - Debug.Assert(enumerable is not null); - - _getResultAsync = getResultAsync!; - _enumerable = enumerable!; - _cancellationToken = cancellationToken; - } - - MockJsonModel IAsyncEnumerator.Current - => _current!; - - async ValueTask IAsyncEnumerator.MoveNextAsync() - { - if (_events is null && _started) - { - throw new ObjectDisposedException(nameof(AsyncMockJsonModelEnumerator)); - } - - _cancellationToken.ThrowIfCancellationRequested(); - _events ??= await CreateEventEnumeratorAsync().ConfigureAwait(false); - _started = true; - - if (await _events.MoveNextAsync().ConfigureAwait(false)) - { - if (_events.Current.Data == _terminalData) - { - _current = default; - return false; - } - - BinaryData data = BinaryData.FromString(_events.Current.Data); - MockJsonModel model = ModelReaderWriter.Read(data) ?? - throw new JsonException($"Failed to deserialize expected type MockJsonModel from sse data payload '{_events.Current.Data}'."); - - _current = model; - return true; - } - - _current = default; - return false; - } - - private async Task> CreateEventEnumeratorAsync() - { - ClientResult result = await _getResultAsync().ConfigureAwait(false); - PipelineResponse response = result.GetRawResponse(); - _enumerable.SetRawResponse(response); - - if (response.ContentStream is null) - { - throw new ArgumentException("Unable to create result from response with null ContentStream", nameof(response)); - } - - AsyncServerSentEventEnumerable enumerable = new(response.ContentStream); - return enumerable.GetAsyncEnumerator(_cancellationToken); - } - - public async ValueTask DisposeAsync() - { - await DisposeAsyncCore().ConfigureAwait(false); - - GC.SuppressFinalize(this); - } - - private async ValueTask DisposeAsyncCore() - { - if (_events is not null) - { - // Disposing the sse enumerator should be a no-op. - await _events.DisposeAsync().ConfigureAwait(false); - _events = null; - - // But we also need to dispose the response content stream - // so we don't leave the unbuffered network stream open. - PipelineResponse response = _enumerable.GetRawResponse(); - - if (response.ContentStream is IAsyncDisposable asyncDisposable) - { - await asyncDisposable.DisposeAsync().ConfigureAwait(false); - } - else if (response.ContentStream is IDisposable disposable) - { - disposable.Dispose(); - } - } - } - } - } -} diff --git a/sdk/core/System.ClientModel/tests/internal/TestFramework/Mocks/MockSseClientExtensions.cs b/sdk/core/System.ClientModel/tests/internal/TestFramework/Mocks/MockSseClientExtensions.cs deleted file mode 100644 index ac493b7abd64..000000000000 --- a/sdk/core/System.ClientModel/tests/internal/TestFramework/Mocks/MockSseClientExtensions.cs +++ /dev/null @@ -1,108 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -using System; -using System.ClientModel; -using System.ClientModel.Internal; -using System.ClientModel.Primitives; -using System.Collections.Generic; -using System.Diagnostics; -using System.IO; -using System.Threading; -using System.Threading.Tasks; - -namespace ClientModel.Tests.Internal.Mocks; - -public static class MockSseClientExtensions -{ - public static AsyncCollectionResult EnumerateDataEvents(this PipelineResponse response) - { - if (response.ContentStream is null) - { - throw new ArgumentException("Unable to create result collection from PipelineResponse with null ContentStream", nameof(response)); - } - - return new AsyncSseDataEventCollection(response, "[DONE]"); - } - - private class AsyncSseDataEventCollection : AsyncCollectionResult - { - private readonly string _terminalData; - - public AsyncSseDataEventCollection(PipelineResponse response, string terminalData) : base(response) - { - Argument.AssertNotNull(response, nameof(response)); - - _terminalData = terminalData; - } - - public override IAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) - { - PipelineResponse response = GetRawResponse(); - - // We validate that response.ContentStream is non-null in outer extension method. - Debug.Assert(response.ContentStream is not null); - - return new AsyncSseDataEventEnumerator(response.ContentStream!, _terminalData, cancellationToken); - } - - private sealed class AsyncSseDataEventEnumerator : IAsyncEnumerator - { - private readonly string _terminalData; - - private IAsyncEnumerator? _events; - private BinaryData? _current; - - public BinaryData Current { get => _current!; } - - public AsyncSseDataEventEnumerator(Stream contentStream, string terminalData, CancellationToken cancellationToken) - { - Debug.Assert(contentStream is not null); - - AsyncServerSentEventEnumerable enumerable = new(contentStream!); - _events = enumerable.GetAsyncEnumerator(cancellationToken); - - _terminalData = terminalData; - } - - public async ValueTask MoveNextAsync() - { - if (_events is null) - { - throw new ObjectDisposedException(nameof(AsyncSseDataEventEnumerator)); - } - - if (await _events.MoveNextAsync().ConfigureAwait(false)) - { - if (_events.Current.Data == _terminalData) - { - _current = default; - return false; - } - - _current = BinaryData.FromString(_events.Current.Data); - return true; - } - - _current = default; - return false; - } - - public async ValueTask DisposeAsync() - { - await DisposeAsyncCore().ConfigureAwait(false); - - GC.SuppressFinalize(this); - } - - private async ValueTask DisposeAsyncCore() - { - if (_events is not null) - { - await _events.DisposeAsync().ConfigureAwait(false); - _events = null; - } - } - } - } -} diff --git a/sdk/core/System.ClientModel/tests/internal/TestFramework/Mocks/MockSyncAsyncInternalExtensions.cs b/sdk/core/System.ClientModel/tests/internal/TestFramework/Mocks/MockSyncAsyncInternalExtensions.cs deleted file mode 100644 index 8ed090f78a22..000000000000 --- a/sdk/core/System.ClientModel/tests/internal/TestFramework/Mocks/MockSyncAsyncInternalExtensions.cs +++ /dev/null @@ -1,22 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -using System.ClientModel.Internal; -using System.Threading.Tasks; - -namespace ClientModel.Tests.Internal.Mocks; - -internal static class MockSyncAsyncInternalExtensions -{ - public static async Task TryGetNextEventSyncOrAsync(this ServerSentEventReader reader, bool isAsync) - { - if (isAsync) - { - return await reader.TryGetNextEventAsync().ConfigureAwait(false); - } - else - { - return reader.TryGetNextEvent(); - } - } -} diff --git a/sdk/openai/Azure.AI.OpenAI/src/Azure.AI.OpenAI.csproj b/sdk/openai/Azure.AI.OpenAI/src/Azure.AI.OpenAI.csproj index b1440d99b464..b8070ffcf523 100644 --- a/sdk/openai/Azure.AI.OpenAI/src/Azure.AI.OpenAI.csproj +++ b/sdk/openai/Azure.AI.OpenAI/src/Azure.AI.OpenAI.csproj @@ -37,7 +37,10 @@ - + + + + @@ -53,6 +56,8 @@ + + diff --git a/sdk/openai/tools/TestFramework/src/OpenAI.TestFramework.csproj b/sdk/openai/tools/TestFramework/src/OpenAI.TestFramework.csproj index dc92fd798ff7..7c3d3057d282 100644 --- a/sdk/openai/tools/TestFramework/src/OpenAI.TestFramework.csproj +++ b/sdk/openai/tools/TestFramework/src/OpenAI.TestFramework.csproj @@ -19,6 +19,8 @@ + + diff --git a/sdk/openai/tools/TestFramework/tests/OpenAI.TestFramework.Tests.csproj b/sdk/openai/tools/TestFramework/tests/OpenAI.TestFramework.Tests.csproj index e6934e2923ce..fd0abb7300a3 100644 --- a/sdk/openai/tools/TestFramework/tests/OpenAI.TestFramework.Tests.csproj +++ b/sdk/openai/tools/TestFramework/tests/OpenAI.TestFramework.Tests.csproj @@ -16,6 +16,8 @@ + +