diff --git a/sdk/core/System.ClientModel/CHANGELOG.md b/sdk/core/System.ClientModel/CHANGELOG.md index 622aa9d4d41b..7eab2e617278 100644 --- a/sdk/core/System.ClientModel/CHANGELOG.md +++ b/sdk/core/System.ClientModel/CHANGELOG.md @@ -6,6 +6,7 @@ - Added `BufferResponse` property to `RequestOptions` so protocol method callers can turn off response buffering if desired. - Added `AsyncResultCollection` and `ResultCollection` for clients to return from service methods where the service response contains a collection of values. +- Added `AsyncPageableCollection`, `PageableCollection` and `ResultPage` for clients to return from service methods where collection values are delivered to the client over one or more service responses. - Added `SetRawResponse` method to `ClientResult` to allow the response held by the result to be changed, for example by derived types that obtain multiple responses from polling the service. ### Breaking Changes diff --git a/sdk/core/System.ClientModel/api/System.ClientModel.net6.0.cs b/sdk/core/System.ClientModel/api/System.ClientModel.net6.0.cs index 767ef0f2dfae..53cdd65eb9a9 100644 --- a/sdk/core/System.ClientModel/api/System.ClientModel.net6.0.cs +++ b/sdk/core/System.ClientModel/api/System.ClientModel.net6.0.cs @@ -7,6 +7,12 @@ public ApiKeyCredential(string key) { } public static implicit operator System.ClientModel.ApiKeyCredential (string key) { throw null; } public void Update(string key) { } } + public abstract partial class AsyncPageableCollection : System.ClientModel.AsyncResultCollection + { + protected AsyncPageableCollection() { } + public abstract System.Collections.Generic.IAsyncEnumerable> AsPages(string? continuationToken = null, int? pageSizeHint = default(int?)); + public override System.Collections.Generic.IAsyncEnumerator GetAsyncEnumerator(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + } public abstract partial class AsyncResultCollection : System.ClientModel.ClientResult, System.Collections.Generic.IAsyncEnumerable { protected internal AsyncResultCollection() { } @@ -48,6 +54,12 @@ protected internal ClientResult(T value, System.ClientModel.Primitives.PipelineR public virtual T Value { get { throw null; } } public static implicit operator T (System.ClientModel.ClientResult result) { throw null; } } + public abstract partial class PageableCollection : System.ClientModel.ResultCollection + { + protected PageableCollection() { } + public abstract System.Collections.Generic.IEnumerable> AsPages(string? continuationToken = null, int? pageSizeHint = default(int?)); + public override System.Collections.Generic.IEnumerator GetEnumerator() { throw null; } + } public abstract partial class ResultCollection : System.ClientModel.ClientResult, System.Collections.Generic.IEnumerable, System.Collections.IEnumerable { protected internal ResultCollection() { } @@ -55,6 +67,13 @@ protected internal ResultCollection(System.ClientModel.Primitives.PipelineRespon public abstract System.Collections.Generic.IEnumerator GetEnumerator(); System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator() { throw null; } } + public partial class ResultPage : System.ClientModel.ResultCollection + { + internal ResultPage() { } + public string? ContinuationToken { get { throw null; } } + public static System.ClientModel.ResultPage Create(System.Collections.Generic.IEnumerable values, string? continuationToken, System.ClientModel.Primitives.PipelineResponse response) { throw null; } + public override System.Collections.Generic.IEnumerator GetEnumerator() { throw null; } + } } namespace System.ClientModel.Primitives { diff --git a/sdk/core/System.ClientModel/api/System.ClientModel.netstandard2.0.cs b/sdk/core/System.ClientModel/api/System.ClientModel.netstandard2.0.cs index c3303e8be4d7..e43b75c72c1e 100644 --- a/sdk/core/System.ClientModel/api/System.ClientModel.netstandard2.0.cs +++ b/sdk/core/System.ClientModel/api/System.ClientModel.netstandard2.0.cs @@ -7,6 +7,12 @@ public ApiKeyCredential(string key) { } public static implicit operator System.ClientModel.ApiKeyCredential (string key) { throw null; } public void Update(string key) { } } + public abstract partial class AsyncPageableCollection : System.ClientModel.AsyncResultCollection + { + protected AsyncPageableCollection() { } + public abstract System.Collections.Generic.IAsyncEnumerable> AsPages(string? continuationToken = null, int? pageSizeHint = default(int?)); + public override System.Collections.Generic.IAsyncEnumerator GetAsyncEnumerator(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + } public abstract partial class AsyncResultCollection : System.ClientModel.ClientResult, System.Collections.Generic.IAsyncEnumerable { protected internal AsyncResultCollection() { } @@ -48,6 +54,12 @@ protected internal ClientResult(T value, System.ClientModel.Primitives.PipelineR public virtual T Value { get { throw null; } } public static implicit operator T (System.ClientModel.ClientResult result) { throw null; } } + public abstract partial class PageableCollection : System.ClientModel.ResultCollection + { + protected PageableCollection() { } + public abstract System.Collections.Generic.IEnumerable> AsPages(string? continuationToken = null, int? pageSizeHint = default(int?)); + public override System.Collections.Generic.IEnumerator GetEnumerator() { throw null; } + } public abstract partial class ResultCollection : System.ClientModel.ClientResult, System.Collections.Generic.IEnumerable, System.Collections.IEnumerable { protected internal ResultCollection() { } @@ -55,6 +67,13 @@ protected internal ResultCollection(System.ClientModel.Primitives.PipelineRespon public abstract System.Collections.Generic.IEnumerator GetEnumerator(); System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator() { throw null; } } + public partial class ResultPage : System.ClientModel.ResultCollection + { + internal ResultPage() { } + public string? ContinuationToken { get { throw null; } } + public static System.ClientModel.ResultPage Create(System.Collections.Generic.IEnumerable values, string? continuationToken, System.ClientModel.Primitives.PipelineResponse response) { throw null; } + public override System.Collections.Generic.IEnumerator GetEnumerator() { throw null; } + } } namespace System.ClientModel.Primitives { diff --git a/sdk/core/System.ClientModel/src/Convenience/AsyncPageableCollectionOfT.cs b/sdk/core/System.ClientModel/src/Convenience/AsyncPageableCollectionOfT.cs new file mode 100644 index 000000000000..1a36bab43a6b --- /dev/null +++ b/sdk/core/System.ClientModel/src/Convenience/AsyncPageableCollectionOfT.cs @@ -0,0 +1,62 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.ClientModel.Primitives; +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; + +namespace System.ClientModel; + +/// +/// Represents a collection of results returned from a cloud service operation +/// sequentially over one or more calls to the service. +/// +public abstract class AsyncPageableCollection : AsyncResultCollection +{ + /// + /// Create a new instance of . + /// + /// This constructor does not take a + /// because derived types are expected to defer the first service call + /// until the collection is enumerated using await foreach. + /// + protected AsyncPageableCollection() : base() + { + } + + /// + /// Return an enumerable of that aynchronously + /// enumerates the collection's pages instead of the collection's individual + /// values. This may make multiple service requests. + /// + /// A token indicating where the collection + /// of results returned from the service should begin. Passing null + /// will start the collection at the first page of values. + /// The number of items to request that the + /// service return in a , if the service supports + /// such requests. + /// An async sequence of , each holding + /// the subset of collection values contained in a given service response. + /// + public abstract IAsyncEnumerable> AsPages(string? continuationToken = default, int? pageSizeHint = default); + + /// + /// Return an enumerator that iterates asynchronously through the collection + /// values. This may make multiple service requests. + /// + /// The used + /// with requests made while enumerating asynchronously. + /// An that can iterate + /// asynchronously through the collection values. + public override async IAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) + { + await foreach (ResultPage page in AsPages().ConfigureAwait(false).WithCancellation(cancellationToken)) + { + foreach (T value in page) + { + yield return value; + } + } + } +} diff --git a/sdk/core/System.ClientModel/src/Convenience/PageableCollectionOfT.cs b/sdk/core/System.ClientModel/src/Convenience/PageableCollectionOfT.cs new file mode 100644 index 000000000000..eff5d4f5e0c8 --- /dev/null +++ b/sdk/core/System.ClientModel/src/Convenience/PageableCollectionOfT.cs @@ -0,0 +1,57 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.ClientModel.Primitives; +using System.Collections.Generic; + +namespace System.ClientModel; + +/// +/// Represents a collection of results returned from a cloud service operation +/// sequentially over one or more calls to the service. +/// +public abstract class PageableCollection : ResultCollection +{ + /// + /// Create a new instance of . + /// + /// This constructor does not take a + /// because derived types are expected to defer the first service call + /// until the collection is enumerated using foreach. + protected PageableCollection() : base() + { + } + + /// + /// Return an enumerable of that enumerates the + /// collection's pages instead of the collection's individual values. This + /// may make multiple service requests. + /// + /// A token indicating where the collection + /// of results returned from the service should begin. Passing null + /// will start the collection at the first page of values. + /// The number of items to request that the + /// service return in a , if the service supports + /// such requests. + /// A sequence of , each holding the + /// subset of collection values contained in a given service response. + /// + public abstract IEnumerable> AsPages(string? continuationToken = default, int? pageSizeHint = default); + + /// + /// Return an enumerator that iterates through the collection values. This + /// may make multiple service requests. + /// + /// An that can iterate through the + /// collection values. + public override IEnumerator GetEnumerator() + { + foreach (ResultPage page in AsPages()) + { + foreach (T value in page) + { + yield return value; + } + } + } +} diff --git a/sdk/core/System.ClientModel/src/Convenience/ResultPageOfT.cs b/sdk/core/System.ClientModel/src/Convenience/ResultPageOfT.cs new file mode 100644 index 000000000000..9542295127ca --- /dev/null +++ b/sdk/core/System.ClientModel/src/Convenience/ResultPageOfT.cs @@ -0,0 +1,50 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.ClientModel.Primitives; +using System.Collections.Generic; + +namespace System.ClientModel; + +/// +/// Represents the subset (or page) of results contained in a single response +/// from a cloud service returning a collection of results sequentially over +/// one or more calls to the service (i.e. a paged collection). +/// +public class ResultPage : ResultCollection +{ + private readonly IEnumerable _values; + + private ResultPage(IEnumerable values, string? continuationToken, PipelineResponse response) + : base(response) + { + _values = values; + ContinuationToken = continuationToken; + } + + /// + /// Creates a new . + /// + /// The values contained in . + /// + /// The token that can be used to request + /// the next page of results from the service, or null if this page + /// holds the final subset of values. + /// The holding the + /// collection values returned by the service. + /// An instance of holding the provided + /// values. + public static ResultPage Create(IEnumerable values, string? continuationToken, PipelineResponse response) + => new(values, continuationToken, response); + + /// + /// Gets the continuation token used to request the next + /// . May be null or empty when no values + /// remain to be returned from the collection. + /// + public string? ContinuationToken { get; } + + /// + public override IEnumerator GetEnumerator() + => _values.GetEnumerator(); +} diff --git a/sdk/core/System.ClientModel/tests/Convenience/ClientResultTests.cs b/sdk/core/System.ClientModel/tests/Convenience/ClientResultTests.cs index 629f1ec27f2b..2c592dce0d98 100644 --- a/sdk/core/System.ClientModel/tests/Convenience/ClientResultTests.cs +++ b/sdk/core/System.ClientModel/tests/Convenience/ClientResultTests.cs @@ -8,7 +8,7 @@ namespace System.ClientModel.Tests.Results; -public class PipelineResponseTests +public class ClientResultTests { #region ClientResult diff --git a/sdk/core/System.ClientModel/tests/Convenience/PageableCollectionTests.cs b/sdk/core/System.ClientModel/tests/Convenience/PageableCollectionTests.cs new file mode 100644 index 000000000000..48092dd8b27d --- /dev/null +++ b/sdk/core/System.ClientModel/tests/Convenience/PageableCollectionTests.cs @@ -0,0 +1,271 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.ClientModel.Primitives; +using System.Threading.Tasks; +using Azure.Core.TestFramework; +using ClientModel.Tests.Mocks; +using NUnit.Framework; + +namespace System.ClientModel.Tests.Results; + +public class PageableCollectionTests +{ + private static readonly string[] MockPageContents = { """ + [ + { "intValue" : 0, "stringValue" : "0" }, + { "intValue" : 1, "stringValue" : "1" }, + { "intValue" : 2, "stringValue" : "2" } + ] + """,""" + [ + { "intValue" : 3, "stringValue" : "3" }, + { "intValue" : 4, "stringValue" : "4" }, + { "intValue" : 5, "stringValue" : "5" } + ] + """,""" + [ + { "intValue" : 6, "stringValue" : "6" }, + { "intValue" : 7, "stringValue" : "7" }, + { "intValue" : 8, "stringValue" : "8" } + ] + """, + }; + + private static readonly int PageCount = MockPageContents.Length; + private static readonly int ItemCount = 9; + + [Test] + public void CanEnumerateValues() + { + MockPageableClient client = new(); + PageableCollection models = client.GetModels(MockPageContents); + + int i = 0; + foreach (MockJsonModel model in models) + { + Assert.AreEqual(i, model.IntValue); + Assert.AreEqual(i.ToString(), model.StringValue); + + i++; + } + + Assert.AreEqual(ItemCount, i); + } + + [Test] + public void CanEnumeratePages() + { + MockPageableClient client = new(); + PageableCollection models = client.GetModels(MockPageContents); + + int pageCount = 0; + int itemCount = 0; + foreach (ResultPage page in models.AsPages()) + { + foreach (MockJsonModel model in page) + { + Assert.AreEqual(itemCount, model.IntValue); + Assert.AreEqual(itemCount.ToString(), model.StringValue); + + itemCount++; + } + + pageCount++; + } + + Assert.AreEqual(ItemCount, itemCount); + Assert.AreEqual(PageCount, pageCount); + } + + [Test] + public void CanStartPageEnumerationMidwayThrough() + { + MockPageableClient client = new(); + PageableCollection models = client.GetModels(MockPageContents); + + int pageCount = 0; + int i = 6; + + // Request just the last page by starting at the last seen value + // on the prior page -- i.e. item 5. + foreach (ResultPage page in models.AsPages(continuationToken: "5")) + { + foreach (MockJsonModel model in page) + { + Assert.AreEqual(i, model.IntValue); + Assert.AreEqual(i.ToString(), model.StringValue); + + i++; + } + + pageCount++; + } + + Assert.AreEqual(ItemCount, i); + Assert.AreEqual(1, pageCount); + } + + [Test] + public void CanSetPageSizeHint() + { + MockPageableClient client = new(); + PageableCollection models = client.GetModels(MockPageContents); + var pages = models.AsPages(pageSizeHint: 10); + foreach (var _ in pages) + { + // page size hint is ignored in this mock + } + + Assert.AreEqual(10, client.RequestedPageSize); + } + + [Test] + public void CanGetRawResponses() + { + MockPageableClient client = new(); + PageableCollection models = client.GetModels(MockPageContents); + + int pageCount = 0; + int itemCount = 0; + foreach (ResultPage page in models.AsPages()) + { + foreach (MockJsonModel model in page) + { + Assert.AreEqual(itemCount, model.IntValue); + Assert.AreEqual(itemCount.ToString(), model.StringValue); + + itemCount++; + } + + PipelineResponse collectionResponse = models.GetRawResponse(); + PipelineResponse pageResponse = page.GetRawResponse(); + + Assert.AreEqual(pageResponse, collectionResponse); + Assert.AreEqual(MockPageContents[pageCount], pageResponse.Content.ToString()); + Assert.AreEqual(MockPageContents[pageCount], collectionResponse.Content.ToString()); + + pageCount++; + } + + Assert.AreEqual(ItemCount, itemCount); + Assert.AreEqual(PageCount, pageCount); + } + + [Test] + public async Task CanEnumerateValuesAsync() + { + MockPageableClient client = new(); + AsyncPageableCollection models = client.GetModelsAsync(MockPageContents); + + int i = 0; + await foreach (MockJsonModel model in models) + { + Assert.AreEqual(i, model.IntValue); + Assert.AreEqual(i.ToString(), model.StringValue); + + i++; + } + + Assert.AreEqual(ItemCount, i); + } + + [Test] + public async Task CanEnumeratePagesAsync() + { + MockPageableClient client = new(); + AsyncPageableCollection models = client.GetModelsAsync(MockPageContents); + + int pageCount = 0; + int itemCount = 0; + await foreach (ResultPage page in models.AsPages()) + { + foreach (MockJsonModel model in page) + { + Assert.AreEqual(itemCount, model.IntValue); + Assert.AreEqual(itemCount.ToString(), model.StringValue); + + itemCount++; + } + + pageCount++; + } + + Assert.AreEqual(ItemCount, itemCount); + Assert.AreEqual(PageCount, pageCount); + } + + [Test] + public async Task CanStartPageEnumerationMidwayThroughAsync() + { + MockPageableClient client = new(); + AsyncPageableCollection models = client.GetModelsAsync(MockPageContents); + + int pageCount = 0; + int i = 6; + + // Request just the last page by starting at the last seen value + // on the prior page -- i.e. item 5. + await foreach (ResultPage page in models.AsPages(continuationToken: "5")) + { + foreach (MockJsonModel model in page) + { + Assert.AreEqual(i, model.IntValue); + Assert.AreEqual(i.ToString(), model.StringValue); + + i++; + } + + pageCount++; + } + + Assert.AreEqual(ItemCount, i); + Assert.AreEqual(1, pageCount); + } + + [Test] + public async Task CanSetPageSizeHintAsync() + { + MockPageableClient client = new(); + AsyncPageableCollection models = client.GetModelsAsync(MockPageContents); + var pages = models.AsPages(pageSizeHint: 10); + await foreach (var _ in pages) + { + // page size hint is ignored in this mock + } + + Assert.AreEqual(10, client.RequestedPageSize); + } + + [Test] + public async Task CanGetRawResponsesAsync() + { + MockPageableClient client = new(); + AsyncPageableCollection models = client.GetModelsAsync(MockPageContents); + + int pageCount = 0; + int itemCount = 0; + await foreach (ResultPage page in models.AsPages()) + { + foreach (MockJsonModel model in page) + { + Assert.AreEqual(itemCount, model.IntValue); + Assert.AreEqual(itemCount.ToString(), model.StringValue); + + itemCount++; + } + + PipelineResponse collectionResponse = models.GetRawResponse(); + PipelineResponse pageResponse = page.GetRawResponse(); + + Assert.AreEqual(pageResponse, collectionResponse); + Assert.AreEqual(MockPageContents[pageCount], pageResponse.Content.ToString()); + Assert.AreEqual(MockPageContents[pageCount], collectionResponse.Content.ToString()); + + pageCount++; + } + + Assert.AreEqual(ItemCount, itemCount); + Assert.AreEqual(PageCount, pageCount); + } +} diff --git a/sdk/core/System.ClientModel/tests/TestFramework/JsonModelList.cs b/sdk/core/System.ClientModel/tests/TestFramework/JsonModelList.cs new file mode 100644 index 000000000000..4ce5fd591436 --- /dev/null +++ b/sdk/core/System.ClientModel/tests/TestFramework/JsonModelList.cs @@ -0,0 +1,94 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.ClientModel.Primitives; +using System.Collections.Generic; +using System.Text.Json; + +namespace ClientModel.Tests.Internal; + +internal class JsonModelList : List, IJsonModel> + where TModel : IJsonModel +{ + public JsonModelList Create(ref Utf8JsonReader reader, ModelReaderWriterOptions options) + { + var format = options.Format == "W" ? ((IPersistableModel>)this).GetFormatFromOptions(options) : options.Format; + if (format != "J") + { + throw new FormatException($"The model {nameof(JsonModelList)} does not support reading '{format}' format."); + } + + using JsonDocument document = JsonDocument.ParseValue(ref reader); + return DeserializeJsonModelList(document.RootElement, options); + } + + public JsonModelList Create(BinaryData data, ModelReaderWriterOptions options) + { + var format = options.Format == "W" ? ((IPersistableModel>)this).GetFormatFromOptions(options) : options.Format; + + switch (format) + { + case "J": + { + using JsonDocument document = JsonDocument.Parse(data); + return DeserializeJsonModelList(document.RootElement, options); + } + default: + throw new FormatException($"The model {nameof(JsonModelList)} does not support reading '{options.Format}' format."); + } + } + + internal static JsonModelList DeserializeJsonModelList(JsonElement element, ModelReaderWriterOptions? options = null) + { + options ??= new ModelReaderWriterOptions("W"); + + if (element.ValueKind != JsonValueKind.Array) + { + throw new InvalidOperationException("Cannot deserialize JsonModelList from JSON that is not an array."); + } + + JsonModelList list = new(); + + foreach (JsonElement item in element.EnumerateArray()) + { + // TODO: Make efficient + TModel? value = ModelReaderWriter.Read(BinaryData.FromString(item.ToString()), options) ?? + throw new InvalidOperationException("Failed to deserialized array element."); + list.Add(value); + } + + return list; + } + + public string GetFormatFromOptions(ModelReaderWriterOptions options) => "J"; + + public void Write(Utf8JsonWriter writer, ModelReaderWriterOptions options) + { + var format = options.Format == "W" ? ((IPersistableModel>)this).GetFormatFromOptions(options) : options.Format; + if (format != "J") + { + throw new FormatException($"The model {nameof(JsonModelList)} does not support writing '{format}' format."); + } + + writer.WriteStartArray(); + + foreach (IJsonModel item in this) + { + item.Write(writer, options); + } + + writer.WriteEndArray(); + } + + public BinaryData Write(ModelReaderWriterOptions options) + { + var format = options.Format == "W" ? ((IPersistableModel>)this).GetFormatFromOptions(options) : options.Format; + + return format switch + { + "J" => ModelReaderWriter.Write(this, options), + _ => throw new FormatException($"The model {nameof(JsonModelList)} does not support writing '{options.Format}' format."), + }; + } +} diff --git a/sdk/core/System.ClientModel/tests/TestFramework/Mocks/MockPageableClient.cs b/sdk/core/System.ClientModel/tests/TestFramework/Mocks/MockPageableClient.cs new file mode 100644 index 000000000000..0ad116f70c94 --- /dev/null +++ b/sdk/core/System.ClientModel/tests/TestFramework/Mocks/MockPageableClient.cs @@ -0,0 +1,133 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.ClientModel; +using System.ClientModel.Primitives; +using System.Diagnostics; +using System.Linq; +using System.Threading.Tasks; +using Azure.Core.TestFramework; +using ClientModel.Tests.Internal; + +namespace ClientModel.Tests.Mocks; + +public class MockPageableClient +{ + public bool ProtocolMethodCalled { get; private set; } + public int? RequestedPageSize { get; private set; } + + // mock convenience method - async + public virtual AsyncPageableCollection GetModelsAsync(string[] pageContents) + { + PipelineResponse? lastResponse = default; + + // The contract for this pageable implementation is that the last seen + // value id (where the id is StringValue) provides the continuation token + // for the page. + + int pageNumber = 0; + JsonModelList values = new(); + + async Task> firstPageFuncAsync(int? pageSize) + { + ClientResult result = await GetModelsAsync(pageContents[pageNumber++], options: null).ConfigureAwait(false); + lastResponse = result.GetRawResponse(); + values = ModelReaderWriter.Read>(lastResponse.Content)!; + string? continuationToken = pageNumber < pageContents.Length ? values[values.Count - 1].StringValue : null; + return ResultPage.Create(values, continuationToken, lastResponse); + } + + async Task> nextPageFuncAsync(string? continuationToken, int? pageSize) + { + RequestedPageSize = pageSize; + + bool atRequestedPage = values.Count > 0 && values.Last().StringValue == continuationToken; + while (!atRequestedPage && pageNumber < pageContents.Length) + { + BinaryData content = BinaryData.FromString(pageContents[pageNumber++]); + JsonModelList pageValues = ModelReaderWriter.Read>(content)!; + atRequestedPage = pageValues[pageValues.Count - 1].StringValue == continuationToken; + } + + Debug.Assert(atRequestedPage is true); + + ClientResult result = await GetModelsAsync(pageContents[pageNumber++], options: null).ConfigureAwait(false); + lastResponse = result.GetRawResponse(); + values = ModelReaderWriter.Read>(lastResponse.Content)!; + continuationToken = pageNumber < pageContents.Length ? values[values.Count - 1].StringValue : null; + return ResultPage.Create(values, continuationToken, lastResponse); + } + + return PageableResultHelpers.Create(firstPageFuncAsync, nextPageFuncAsync); + } + + // mock convenience method - sync + public virtual PageableCollection GetModels(string[] pageContents) + { + PipelineResponse? lastResponse = default; + + // The contract for this pageable implementation is that the last seen + // value id (where the id is StringValue) provides the continuation token + // for the page. + + int pageNumber = 0; + JsonModelList values = new(); + + ResultPage firstPageFunc(int? pageSize) + { + ClientResult result = GetModels(pageContents[pageNumber++], options: null); + lastResponse = result.GetRawResponse(); + values = ModelReaderWriter.Read>(lastResponse.Content)!; + string? continuationToken = pageNumber < pageContents.Length ? values[values.Count - 1].StringValue : null; + return ResultPage.Create(values, continuationToken, lastResponse); + } + + ResultPage nextPageFunc(string? continuationToken, int? pageSize) + { + RequestedPageSize = pageSize; + + bool atRequestedPage = values.Count > 0 && values.Last().StringValue == continuationToken; + while (!atRequestedPage && pageNumber < pageContents.Length) + { + BinaryData content = BinaryData.FromString(pageContents[pageNumber++]); + JsonModelList pageValues = ModelReaderWriter.Read>(content)!; + atRequestedPage = pageValues[pageValues.Count - 1].StringValue == continuationToken; + } + + Debug.Assert(atRequestedPage is true); + + ClientResult result = GetModels(pageContents[pageNumber++], options: null); + lastResponse = result.GetRawResponse(); + values = ModelReaderWriter.Read>(lastResponse.Content)!; + continuationToken = pageNumber < pageContents.Length ? values[values.Count - 1].StringValue : null; + return ResultPage.Create(values, continuationToken, lastResponse); + } + + return PageableResultHelpers.Create(firstPageFunc, nextPageFunc); + } + + // mock protocol method - async + public virtual async Task GetModelsAsync(string pageContent, RequestOptions? options = default) + { + await Task.Delay(0); + + MockPipelineResponse response = new(200); + response.SetContent(pageContent); + + ProtocolMethodCalled = true; + + return ClientResult.FromResponse(response); + } + + // mock protocol method - sync + public virtual ClientResult GetModels(string pageContent, RequestOptions? options = default) + { + MockPipelineResponse response = new(200); + response.SetContent(pageContent); + + ProtocolMethodCalled = true; + + return ClientResult.FromResponse(response); + } +} diff --git a/sdk/core/System.ClientModel/tests/TestFramework/PageableResultHelpers.cs b/sdk/core/System.ClientModel/tests/TestFramework/PageableResultHelpers.cs new file mode 100644 index 000000000000..6013ede13cdc --- /dev/null +++ b/sdk/core/System.ClientModel/tests/TestFramework/PageableResultHelpers.cs @@ -0,0 +1,94 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.ClientModel; +using System.Collections.Generic; +using System.Threading.Tasks; + +namespace ClientModel.Tests.Internal; + +internal class PageableResultHelpers +{ + public static PageableCollection Create(Func> firstPageFunc, Func>? nextPageFunc, int? pageSize = default) where T : notnull + { + ResultPage first(string? _, int? pageSizeHint) => firstPageFunc(pageSizeHint); + return new FuncPageable(first, nextPageFunc, pageSize); + } + + public static AsyncPageableCollection Create(Func>> firstPageFunc, Func>>? nextPageFunc, int? pageSize = default) where T : notnull + { + Task> first(string? _, int? pageSizeHint) => firstPageFunc(pageSizeHint); + return new FuncAsyncPageable(first, nextPageFunc, pageSize); + } + + private class FuncAsyncPageable : AsyncPageableCollection where T : notnull + { + private readonly Func>> _firstPageFunc; + private readonly Func>>? _nextPageFunc; + private readonly int? _defaultPageSize; + + public FuncAsyncPageable(Func>> firstPageFunc, Func>>? nextPageFunc, int? defaultPageSize = default) + { + _firstPageFunc = firstPageFunc; + _nextPageFunc = nextPageFunc; + _defaultPageSize = defaultPageSize; + } + + public override async IAsyncEnumerable> AsPages(string? continuationToken = default, int? pageSizeHint = default) + { + Func>>? pageFunc = string.IsNullOrEmpty(continuationToken) ? _firstPageFunc : _nextPageFunc; + + if (pageFunc == null) + { + yield break; + } + + int? pageSize = pageSizeHint ?? _defaultPageSize; + do + { + ResultPage page = await pageFunc(continuationToken, pageSize).ConfigureAwait(false); + SetRawResponse(page.GetRawResponse()); + yield return page; + continuationToken = page.ContinuationToken; + pageFunc = _nextPageFunc; + } + while (!string.IsNullOrEmpty(continuationToken) && pageFunc != null); + } + } + + private class FuncPageable : PageableCollection where T : notnull + { + private readonly Func> _firstPageFunc; + private readonly Func>? _nextPageFunc; + private readonly int? _defaultPageSize; + + public FuncPageable(Func> firstPageFunc, Func>? nextPageFunc, int? defaultPageSize = default) + { + _firstPageFunc = firstPageFunc; + _nextPageFunc = nextPageFunc; + _defaultPageSize = defaultPageSize; + } + + public override IEnumerable> AsPages(string? continuationToken = default, int? pageSizeHint = default) + { + Func>? pageFunc = string.IsNullOrEmpty(continuationToken) ? _firstPageFunc : _nextPageFunc; + + if (pageFunc == null) + { + yield break; + } + + int? pageSize = pageSizeHint ?? _defaultPageSize; + do + { + ResultPage page = pageFunc(continuationToken, pageSize); + SetRawResponse(page.GetRawResponse()); + yield return page; + continuationToken = page.ContinuationToken; + pageFunc = _nextPageFunc; + } + while (!string.IsNullOrEmpty(continuationToken) && pageFunc != null); + } + } +}