diff --git a/sdk/core/Azure.Core/tests/HttpPipelineTests.cs b/sdk/core/Azure.Core/tests/HttpPipelineTests.cs index 4c297eefd861..63e39cfeb8f4 100644 --- a/sdk/core/Azure.Core/tests/HttpPipelineTests.cs +++ b/sdk/core/Azure.Core/tests/HttpPipelineTests.cs @@ -2,6 +2,8 @@ // Licensed under the MIT License. using System; +using System.ClientModel.Primitives; +using System.Collections.Generic; using System.Linq; using System.Threading; using System.Threading.Tasks; @@ -334,6 +336,197 @@ public async Task RequestContextDefault_IsErrorIsSet(int code, bool isError) Assert.AreEqual(isError, response.IsError); } + [Test] + public async Task AzurePolicyInClientModelPipeline() + { + ClientPipelineOptions options = new() + { + Transport = new MockPipelineTransport(), + }; + + ClientPipeline pipeline = ClientPipeline.Create(options, + perCallPolicies: new[] { new DoNothingPolicy() }, + perTryPolicies: ReadOnlySpan.Empty, + beforeTransportPolicies: ReadOnlySpan.Empty); + + using PipelineMessage message = new HttpMessage(new MemoryRequest(), ResponseClassifier.Shared); + await pipeline.SendAsync(message); + } + + private class DoNothingPolicy : HttpPipelinePolicy + { + public override ValueTask ProcessAsync(HttpMessage message, ReadOnlyMemory pipeline) + => pipeline.Span[0].ProcessAsync(message, pipeline.Slice(1)); + + public override void Process(HttpMessage message, ReadOnlyMemory pipeline) + => pipeline.Span[0].Process(message, pipeline.Slice(1)); + } + + private class MockPipelineTransport : PipelineTransport + { + protected override PipelineMessage CreateMessageCore() + => new HttpMessage(new MemoryRequest(), ResponseClassifier.Shared); + + protected override void ProcessCore(PipelineMessage message) + { + var response = new MemoryResponse(); + response.SetStatus(200); + ((HttpMessage)message).Response = response; + } + + protected override ValueTask ProcessCoreAsync(PipelineMessage message) + { + ProcessCore(message); + return new ValueTask(); + } + } + + [Test] + public async Task ClientModelPolicyInAzurePipeline() + { + var pipeline = HttpPipelineBuilder.Build(new TestClientOptions { Transport = new MockTransport(new MockResponse(404)) }); + + var context = new RequestContext(); + context.AddPolicy(new ReplaceResponseClassifierPipelinePolicy(), PipelinePosition.PerCall); + + using HttpMessage message = pipeline.CreateMessage(context); + await pipeline.SendAsync(message, message.CancellationToken); + Assert.IsFalse(message.Response.IsError); + } + + [Test] + public async Task ClientModelPolicyWrappedForAzurePipeline() + { + var options = new TestClientOptions { Transport = new MockTransport(new MockResponse(404)) }; + options.AddPolicy(new ClientPolicyWrapper(new ReplaceResponseClassifierPipelinePolicy()), HttpPipelinePosition.PerCall); + var pipeline = HttpPipelineBuilder.Build(options); + + var context = new RequestContext(); + using HttpMessage message = pipeline.CreateMessage(context); + await pipeline.SendAsync(message, message.CancellationToken); + Assert.IsFalse(message.Response.IsError); + } + + [Test] + public async Task ClientModelPolicyWrappedForAzurePipelineV2() + { + var options = new TestClientOptions { Transport = new MockTransport(new MockResponse(404)) }; + options.AddPolicy(new AdvancedClientPolicyWrapper(new ReplaceResponseClassifierPipelinePolicy()), HttpPipelinePosition.PerCall); + var pipeline = HttpPipelineBuilder.Build(options); + + var context = new RequestContext(); + using HttpMessage message = pipeline.CreateMessage(context); + await pipeline.SendAsync(message, message.CancellationToken); + Assert.IsFalse(message.Response.IsError); + } + + [Test] + public async Task ClientModelPolicyWrappedForAzurePipelineV3() + { + var options = new TestClientOptions { Transport = new MockTransport(new MockResponse(400)) }; + options.AddPolicy(new AdvancedClientPolicyWrapper(new ReplaceResponseClassifierPipelinePolicy()), HttpPipelinePosition.PerCall); + var pipeline = HttpPipelineBuilder.Build(options); + + var context = new RequestContext(); + using HttpMessage message = pipeline.CreateMessage(context); + await pipeline.SendAsync(message, message.CancellationToken); + Assert.IsFalse(message.Response.IsError); + } + + private class CustomPipelineMessageClassifier : PipelineMessageClassifier + { + public override bool TryClassify(PipelineMessage message, out bool isError) + { + isError = !message.Response!.Status.Equals(404); + return !isError; + } + + public override bool TryClassify(PipelineMessage message, Exception exception, out bool isRetriable) + { + isRetriable = exception == null && message.Response != null && message.Response.Status.Equals(404); + return isRetriable; + } + } + + private class ReplaceResponseClassifierPipelinePolicy : PipelinePolicy + { + public override void Process(PipelineMessage message, IReadOnlyList pipeline, int currentIndex) + { + message.ResponseClassifier = new CustomPipelineMessageClassifier(); + ProcessNext(message, pipeline, currentIndex); + } + + public override async ValueTask ProcessAsync(PipelineMessage message, IReadOnlyList pipeline, int currentIndex) + { + message.ResponseClassifier = new CustomPipelineMessageClassifier(); + await ProcessNextAsync(message, pipeline, currentIndex).ConfigureAwait(false); + } + } + + private class ClientPolicyWrapper : HttpPipelinePolicy + { + private readonly PipelinePolicy _policy; + + public ClientPolicyWrapper(PipelinePolicy policy) + { + _policy = policy; + } + + public override async ValueTask ProcessAsync(HttpMessage message, ReadOnlyMemory pipeline) + { + await _policy.ProcessAsync(message, pipeline.Slice(1).ToArray(), 0).ConfigureAwait(false); + } + + public override void Process(HttpMessage message, ReadOnlyMemory pipeline) + { + _policy.Process(message, pipeline.Slice(1).ToArray(), 0); + } + } + + private class AdvancedClientPolicyWrapper : HttpPipelinePolicy + { + private readonly PipelinePolicy _policy; + + public AdvancedClientPolicyWrapper(PipelinePolicy policy) + { + _policy = policy; + } + + public override async ValueTask ProcessAsync(HttpMessage message, ReadOnlyMemory pipeline) + { + await _policy.ProcessAsync(message, new[] { null, new Shim(message, pipeline) }, 0).ConfigureAwait(false); + } + + public override void Process(HttpMessage message, ReadOnlyMemory pipeline) + { + _policy.Process(message, new[] { null, new Shim(message, pipeline) }, 0); + } + + private class Shim : PipelinePolicy + { + private readonly HttpMessage _message; + private readonly ReadOnlyMemory _pipeline; + + public Shim(HttpMessage message, ReadOnlyMemory pipeline) + { + _message = message; + _pipeline = pipeline; + } + + public override ValueTask ProcessAsync(PipelineMessage message, IReadOnlyList pipeline, int currentIndex) + { + return _pipeline.Span[0].ProcessAsync(_message, _pipeline.Slice(1)); + } + + public override void Process(PipelineMessage message, IReadOnlyList pipeline, int currentIndex) + { + _pipeline.Span[0].Process(_message, _pipeline.Slice(1)); + } + } + } + + private class TestClientOptions : ClientOptions { } + #region Helpers public class AddHeaderPolicy : HttpPipelineSynchronousPolicy {