Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
193 changes: 193 additions & 0 deletions sdk/core/Azure.Core/tests/HttpPipelineTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
// Licensed under the MIT License.

using System;
using System.ClientModel.Primitives;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
Expand Down Expand Up @@ -334,6 +336,197 @@ public async Task RequestContextDefault_IsErrorIsSet(int code, bool isError)
Assert.AreEqual(isError, response.IsError);
}

[Test]
public async Task AzurePolicyInClientModelPipeline()
{
ClientPipelineOptions options = new()
{
Transport = new MockPipelineTransport(),
};

ClientPipeline pipeline = ClientPipeline.Create(options,
perCallPolicies: new[] { new DoNothingPolicy() },
perTryPolicies: ReadOnlySpan<PipelinePolicy>.Empty,
beforeTransportPolicies: ReadOnlySpan<PipelinePolicy>.Empty);

using PipelineMessage message = new HttpMessage(new MemoryRequest(), ResponseClassifier.Shared);
await pipeline.SendAsync(message);
}

private class DoNothingPolicy : HttpPipelinePolicy
{
public override ValueTask ProcessAsync(HttpMessage message, ReadOnlyMemory<HttpPipelinePolicy> pipeline)
=> pipeline.Span[0].ProcessAsync(message, pipeline.Slice(1));

public override void Process(HttpMessage message, ReadOnlyMemory<HttpPipelinePolicy> pipeline)
=> pipeline.Span[0].Process(message, pipeline.Slice(1));
}

private class MockPipelineTransport : PipelineTransport
{
protected override PipelineMessage CreateMessageCore()
=> new HttpMessage(new MemoryRequest(), ResponseClassifier.Shared);

protected override void ProcessCore(PipelineMessage message)
{
var response = new MemoryResponse();
response.SetStatus(200);
((HttpMessage)message).Response = response;
}

protected override ValueTask ProcessCoreAsync(PipelineMessage message)
{
ProcessCore(message);
return new ValueTask();
}
}

[Test]
public async Task ClientModelPolicyInAzurePipeline()
{
var pipeline = HttpPipelineBuilder.Build(new TestClientOptions { Transport = new MockTransport(new MockResponse(404)) });

var context = new RequestContext();
context.AddPolicy(new ReplaceResponseClassifierPipelinePolicy(), PipelinePosition.PerCall);

using HttpMessage message = pipeline.CreateMessage(context);
await pipeline.SendAsync(message, message.CancellationToken);
Assert.IsFalse(message.Response.IsError);
}

[Test]
public async Task ClientModelPolicyWrappedForAzurePipeline()
{
var options = new TestClientOptions { Transport = new MockTransport(new MockResponse(404)) };
options.AddPolicy(new ClientPolicyWrapper(new ReplaceResponseClassifierPipelinePolicy()), HttpPipelinePosition.PerCall);
var pipeline = HttpPipelineBuilder.Build(options);

var context = new RequestContext();
using HttpMessage message = pipeline.CreateMessage(context);
await pipeline.SendAsync(message, message.CancellationToken);
Assert.IsFalse(message.Response.IsError);
}

[Test]
public async Task ClientModelPolicyWrappedForAzurePipelineV2()
{
var options = new TestClientOptions { Transport = new MockTransport(new MockResponse(404)) };
options.AddPolicy(new AdvancedClientPolicyWrapper(new ReplaceResponseClassifierPipelinePolicy()), HttpPipelinePosition.PerCall);
var pipeline = HttpPipelineBuilder.Build(options);

var context = new RequestContext();
using HttpMessage message = pipeline.CreateMessage(context);
await pipeline.SendAsync(message, message.CancellationToken);
Assert.IsFalse(message.Response.IsError);
}

[Test]
public async Task ClientModelPolicyWrappedForAzurePipelineV3()
{
var options = new TestClientOptions { Transport = new MockTransport(new MockResponse(400)) };
options.AddPolicy(new AdvancedClientPolicyWrapper(new ReplaceResponseClassifierPipelinePolicy()), HttpPipelinePosition.PerCall);
var pipeline = HttpPipelineBuilder.Build(options);

var context = new RequestContext();
using HttpMessage message = pipeline.CreateMessage(context);
await pipeline.SendAsync(message, message.CancellationToken);
Assert.IsFalse(message.Response.IsError);
}

private class CustomPipelineMessageClassifier : PipelineMessageClassifier
{
public override bool TryClassify(PipelineMessage message, out bool isError)
{
isError = !message.Response!.Status.Equals(404);
return !isError;
}

public override bool TryClassify(PipelineMessage message, Exception exception, out bool isRetriable)
{
isRetriable = exception == null && message.Response != null && message.Response.Status.Equals(404);
return isRetriable;
}
}

private class ReplaceResponseClassifierPipelinePolicy : PipelinePolicy
{
public override void Process(PipelineMessage message, IReadOnlyList<PipelinePolicy> pipeline, int currentIndex)
{
message.ResponseClassifier = new CustomPipelineMessageClassifier();
ProcessNext(message, pipeline, currentIndex);
}

public override async ValueTask ProcessAsync(PipelineMessage message, IReadOnlyList<PipelinePolicy> pipeline, int currentIndex)
{
message.ResponseClassifier = new CustomPipelineMessageClassifier();
await ProcessNextAsync(message, pipeline, currentIndex).ConfigureAwait(false);
}
}

private class ClientPolicyWrapper : HttpPipelinePolicy
{
private readonly PipelinePolicy _policy;

public ClientPolicyWrapper(PipelinePolicy policy)
{
_policy = policy;
}

public override async ValueTask ProcessAsync(HttpMessage message, ReadOnlyMemory<HttpPipelinePolicy> pipeline)
{
await _policy.ProcessAsync(message, pipeline.Slice(1).ToArray(), 0).ConfigureAwait(false);
}

public override void Process(HttpMessage message, ReadOnlyMemory<HttpPipelinePolicy> pipeline)
{
_policy.Process(message, pipeline.Slice(1).ToArray(), 0);
}
}

private class AdvancedClientPolicyWrapper : HttpPipelinePolicy
{
private readonly PipelinePolicy _policy;

public AdvancedClientPolicyWrapper(PipelinePolicy policy)
{
_policy = policy;
}

public override async ValueTask ProcessAsync(HttpMessage message, ReadOnlyMemory<HttpPipelinePolicy> pipeline)
{
await _policy.ProcessAsync(message, new[] { null, new Shim(message, pipeline) }, 0).ConfigureAwait(false);
}

public override void Process(HttpMessage message, ReadOnlyMemory<HttpPipelinePolicy> pipeline)
{
_policy.Process(message, new[] { null, new Shim(message, pipeline) }, 0);
}

private class Shim : PipelinePolicy
{
private readonly HttpMessage _message;
private readonly ReadOnlyMemory<HttpPipelinePolicy> _pipeline;

public Shim(HttpMessage message, ReadOnlyMemory<HttpPipelinePolicy> pipeline)
{
_message = message;
_pipeline = pipeline;
}

public override ValueTask ProcessAsync(PipelineMessage message, IReadOnlyList<PipelinePolicy> pipeline, int currentIndex)
{
return _pipeline.Span[0].ProcessAsync(_message, _pipeline.Slice(1));
}

public override void Process(PipelineMessage message, IReadOnlyList<PipelinePolicy> pipeline, int currentIndex)
{
_pipeline.Span[0].Process(_message, _pipeline.Slice(1));
}
}
}

private class TestClientOptions : ClientOptions { }

#region Helpers
public class AddHeaderPolicy : HttpPipelineSynchronousPolicy
{
Expand Down