diff --git a/sdk/core/Azure.Core.TestFramework/src/MockResponse.cs b/sdk/core/Azure.Core.TestFramework/src/MockResponse.cs index 509b4a24b5b7..77e1e211fde5 100644 --- a/sdk/core/Azure.Core.TestFramework/src/MockResponse.cs +++ b/sdk/core/Azure.Core.TestFramework/src/MockResponse.cs @@ -34,12 +34,18 @@ public void SetContent(byte[] content) ContentStream = new MemoryStream(content, 0, content.Length, false, true); } - public void SetContent(string content) + public MockResponse SetContent(string content) { SetContent(Encoding.UTF8.GetBytes(content)); + return this; } - public void AddHeader(HttpHeader header) + public MockResponse AddHeader(string name, string value) + { + return AddHeader(new HttpHeader(name, value)); + } + + public MockResponse AddHeader(HttpHeader header) { if (!_headers.TryGetValue(header.Name, out List values)) { @@ -47,6 +53,7 @@ public void AddHeader(HttpHeader header) } values.Add(header.Value); + return this; } #if HAS_INTERNALS_VISIBLE_CORE diff --git a/sdk/core/Azure.Core/src/Diagnostics/AzureCoreEventSource.cs b/sdk/core/Azure.Core/src/Diagnostics/AzureCoreEventSource.cs index 151530be3855..09ea6fbd9c3e 100644 --- a/sdk/core/Azure.Core/src/Diagnostics/AzureCoreEventSource.cs +++ b/sdk/core/Azure.Core/src/Diagnostics/AzureCoreEventSource.cs @@ -14,6 +14,9 @@ internal sealed class AzureCoreEventSource : AzureEventSource private const int RequestEvent = 1; private const int RequestContentEvent = 2; private const int RequestContentTextEvent = 17; + private const int RequestRedirectEvent = 20; + private const int RequestRedirectBlockedEvent = 21; + private const int RequestRedirectCountExceededEvent = 22; private const int ResponseEvent = 5; private const int ResponseContentEvent = 6; private const int ResponseDelayEvent = 7; @@ -133,5 +136,23 @@ public void ExceptionResponse(string requestId, string exception) { WriteEvent(ExceptionResponseEvent, requestId, exception); } + + [Event(RequestRedirectEvent, Level = EventLevel.Verbose, Message = "Request [{0}] Redirecting from {1} to {2} in response to status code {3}")] + public void RequestRedirect(string requestId, string from, string to, int status) + { + WriteEvent(RequestRedirectEvent, requestId, from, to, status); + } + + [Event(RequestRedirectBlockedEvent, Level = EventLevel.Warning, Message = "Request [{0}] Insecure HTTPS to HTTP redirect from {1} to {2} was blocked.")] + public void RequestRedirectBlocked(string requestId, string from, string to) + { + WriteEvent(RequestRedirectBlockedEvent, requestId, from, to); + } + + [Event(RequestRedirectCountExceededEvent, Level = EventLevel.Warning, Message = "Request [{0}] Exceeded max number of redirects. Redirect from {1} to {2} blocked.")] + public void RequestRedirectCountExceeded(string requestId, string from, string to) + { + WriteEvent(RequestRedirectCountExceededEvent, requestId, from, to); + } } } diff --git a/sdk/core/Azure.Core/src/HttpMessage.cs b/sdk/core/Azure.Core/src/HttpMessage.cs index 75576665e880..21567d1310e5 100644 --- a/sdk/core/Azure.Core/src/HttpMessage.cs +++ b/sdk/core/Azure.Core/src/HttpMessage.cs @@ -4,6 +4,7 @@ using System; using System.Collections.Generic; using System.IO; +using System.Net.Http; using System.Threading; using Azure.Core.Pipeline; diff --git a/sdk/core/Azure.Core/src/Pipeline/HttpClientTransport.cs b/sdk/core/Azure.Core/src/Pipeline/HttpClientTransport.cs index d530f4431ca8..e51dd0eb5132 100644 --- a/sdk/core/Azure.Core/src/Pipeline/HttpClientTransport.cs +++ b/sdk/core/Azure.Core/src/Pipeline/HttpClientTransport.cs @@ -141,7 +141,7 @@ private static HttpClient CreateDefaultClient() return new HttpClient(httpMessageHandler) { // Timeouts are handled by the pipeline - Timeout = Timeout.InfiniteTimeSpan + Timeout = Timeout.InfiniteTimeSpan, }; } @@ -153,9 +153,15 @@ private static HttpMessageHandler CreateDefaultHandler() } #if NETCOREAPP - return new SocketsHttpHandler(); + return new SocketsHttpHandler() + { + AllowAutoRedirect = false + }; #else - return new HttpClientHandler(); + return new HttpClientHandler() + { + AllowAutoRedirect = false + }; #endif } diff --git a/sdk/core/Azure.Core/src/Pipeline/HttpPipelineBuilder.cs b/sdk/core/Azure.Core/src/Pipeline/HttpPipelineBuilder.cs index a1232da4275a..7d6399da0b9d 100644 --- a/sdk/core/Azure.Core/src/Pipeline/HttpPipelineBuilder.cs +++ b/sdk/core/Azure.Core/src/Pipeline/HttpPipelineBuilder.cs @@ -65,6 +65,8 @@ public static HttpPipeline Build(ClientOptions options, HttpPipelinePolicy[] per RetryOptions retryOptions = options.Retry; policies.Add(new RetryPolicy(retryOptions.Mode, retryOptions.Delay, retryOptions.MaxDelay, retryOptions.MaxRetries)); + policies.Add(RedirectPolicy.Shared); + policies.AddRange(perRetryPolicies); policies.AddRange(options.PerRetryPolicies); diff --git a/sdk/core/Azure.Core/src/Pipeline/HttpWebRequestTransport.cs b/sdk/core/Azure.Core/src/Pipeline/HttpWebRequestTransport.cs index b794f0da6ceb..17cb3728e0c0 100644 --- a/sdk/core/Azure.Core/src/Pipeline/HttpWebRequestTransport.cs +++ b/sdk/core/Azure.Core/src/Pipeline/HttpWebRequestTransport.cs @@ -113,6 +113,9 @@ private HttpWebRequest CreateRequest(Request messageRequest) request.Timeout = Timeout.Infinite; request.ReadWriteTimeout = Timeout.Infinite; + // Redirect is handled by the pipeline + request.AllowAutoRedirect = false; + // Don't disable the default proxy when there is no environment proxy configured if (_environmentProxy != null) { diff --git a/sdk/core/Azure.Core/src/Pipeline/Internal/RedirectPolicy.cs b/sdk/core/Azure.Core/src/Pipeline/Internal/RedirectPolicy.cs new file mode 100644 index 000000000000..405690333dce --- /dev/null +++ b/sdk/core/Azure.Core/src/Pipeline/Internal/RedirectPolicy.cs @@ -0,0 +1,172 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Threading.Tasks; +using Azure.Core.Diagnostics; + +namespace Azure.Core.Pipeline +{ + internal sealed class RedirectPolicy : HttpPipelinePolicy + { + private readonly int _maxAutomaticRedirections; + + public static RedirectPolicy Shared { get; } = new RedirectPolicy(); + + private RedirectPolicy() + { + _maxAutomaticRedirections = 50; + } + + internal async ValueTask ProcessAsync(HttpMessage message, ReadOnlyMemory pipeline, bool async) + { + if (async) + { + await ProcessNextAsync(message, pipeline).ConfigureAwait(false); + } + else + { + ProcessNext(message, pipeline); + } + + uint redirectCount = 0; + Uri? redirectUri; + + Request request = message.Request; + Response response = message.Response; + + while ((redirectUri = GetUriForRedirect(request, message.Response)) != null) + { + redirectCount++; + + if (redirectCount > _maxAutomaticRedirections) + { + // If we exceed the maximum number of redirects + // then just return the 3xx response. + if (AzureCoreEventSource.Singleton.IsEnabled()) + { + AzureCoreEventSource.Singleton.RequestRedirectCountExceeded(request.ClientRequestId, request.Uri.ToString(), redirectUri.ToString()); + } + + break; + } + + response.Dispose(); + + // Clear the authorization header. + request.Headers.Remove(HttpHeader.Names.Authorization); + + if (AzureCoreEventSource.Singleton.IsEnabled()) + { + AzureCoreEventSource.Singleton.RequestRedirect(request.ClientRequestId, request.Uri.ToString(), redirectUri.ToString(), response.Status); + } + + // Set up for the redirect + request.Uri.Reset(redirectUri); + if (RequestRequiresForceGet(response.Status, request.Method)) + { + request.Method = RequestMethod.Get; + request.Content = null; + } + + // Issue the redirected request. + if (async) + { + await ProcessNextAsync(message, pipeline).ConfigureAwait(false); + } + else + { + ProcessNext(message, pipeline); + } + + response = message.Response; + } + } + + private static Uri? GetUriForRedirect(Request request, Response response) + { + switch (response.Status) + { + case 301: + case 302: + case 303: + case 307: + case 300: + case 308: + break; + + default: + return null; + } + + if (!response.Headers.TryGetValue("Location", out string? locationString)) + { + return null; + } + + Uri location = new Uri(locationString); + Uri requestUri = request.Uri.ToUri(); + // Ensure the redirect location is an absolute URI. + if (!location.IsAbsoluteUri) + { + location = new Uri(requestUri, location); + } + + // Per https://tools.ietf.org/html/rfc7231#section-7.1.2, a redirect location without a + // fragment should inherit the fragment from the original URI. + string requestFragment = requestUri.Fragment; + if (!string.IsNullOrEmpty(requestFragment)) + { + string redirectFragment = location.Fragment; + if (string.IsNullOrEmpty(redirectFragment)) + { + location = new UriBuilder(location) { Fragment = requestFragment }.Uri; + } + } + + // Disallow automatic redirection from secure to non-secure schemes + if (IsSupportedSecureScheme(requestUri.Scheme) && !IsSupportedSecureScheme(location.Scheme)) + { + if (AzureCoreEventSource.Singleton.IsEnabled()) + { + AzureCoreEventSource.Singleton.RequestRedirectBlocked(request.ClientRequestId, requestUri.ToString(), location.ToString()); + } + + return null; + } + + return location; + } + + private static bool RequestRequiresForceGet(int statusCode, RequestMethod requestMethod) + { + switch (statusCode) + { + case 301: + case 302: + case 300: + return requestMethod == RequestMethod.Post; + case 303: + return requestMethod != RequestMethod.Get && requestMethod != RequestMethod.Head; + default: + return false; + } + } + + internal static bool IsSupportedSecureScheme(string scheme) => + string.Equals(scheme, "https", StringComparison.OrdinalIgnoreCase) || IsSecureWebSocketScheme(scheme); + + internal static bool IsSecureWebSocketScheme(string scheme) => + string.Equals(scheme, "wss", StringComparison.OrdinalIgnoreCase); + + public override ValueTask ProcessAsync(HttpMessage message, ReadOnlyMemory pipeline) + { + return ProcessAsync(message, pipeline, true); + } + + public override void Process(HttpMessage message, ReadOnlyMemory pipeline) + { + ProcessAsync(message, pipeline, false).EnsureCompleted(); + } + } +} \ No newline at end of file diff --git a/sdk/core/Azure.Core/src/Request.cs b/sdk/core/Azure.Core/src/Request.cs index 7cec3b30cbf6..7c9f13d99420 100644 --- a/sdk/core/Azure.Core/src/Request.cs +++ b/sdk/core/Azure.Core/src/Request.cs @@ -4,6 +4,7 @@ using System; using System.Collections.Generic; using System.Diagnostics.CodeAnalysis; +using System.Net.Http; using Azure.Core.Pipeline; namespace Azure.Core diff --git a/sdk/core/Azure.Core/tests/HttpPipelineFunctionalTests.cs b/sdk/core/Azure.Core/tests/HttpPipelineFunctionalTests.cs index bce88402cd9e..4a78b9e7e0d9 100644 --- a/sdk/core/Azure.Core/tests/HttpPipelineFunctionalTests.cs +++ b/sdk/core/Azure.Core/tests/HttpPipelineFunctionalTests.cs @@ -4,6 +4,7 @@ using System; using System.Collections.Generic; using System.IO; +using System.Linq; using System.Text; using System.Threading; using System.Threading.Tasks; @@ -624,6 +625,112 @@ public async Task SendMultipartformData() Assert.AreEqual(formData.Current.ContentDisposition, "form-data; name=LastName; filename=file_name.txt"); } + [Test] + public async Task HandlesRedirects() + { + HttpPipeline httpPipeline = HttpPipelineBuilder.Build(GetOptions()); + Uri testServerAddress = null; + using TestServer testServer = new TestServer( + context => + { + if (context.Request.Path.ToString().Contains("/redirected")) + { + context.Response.StatusCode = 200; + } + else + { + context.Response.StatusCode = 300; + context.Response.Headers.Add("Location", testServerAddress + "/redirected"); + } + return Task.CompletedTask; + }); + + testServerAddress = testServer.Address; + + using Request request = httpPipeline.CreateRequest(); + request.Method = RequestMethod.Get; + request.Uri.Reset(testServer.Address); + + using Response response = await ExecuteRequest(request, httpPipeline); + Assert.AreEqual(response.Status, 200); + } + + [Test] + public async Task PerRetryPolicyObservesRedirect() + { + List uris = new List(); + var options = GetOptions(); + var perRetryPolicy = new CallbackPolicy(message => uris.Add(message.Request.Uri.ToString())); + options.AddPolicy(perRetryPolicy, HttpPipelinePosition.PerRetry); + HttpPipeline httpPipeline = HttpPipelineBuilder.Build(options); + Uri testServerAddress = null; + using TestServer testServer = new TestServer( + context => + { + if (context.Request.Path.ToString().Contains("/redirected")) + { + context.Response.StatusCode = 200; + } + else + { + context.Response.StatusCode = 300; + context.Response.Headers.Add("Location", testServerAddress + "/redirected"); + } + return Task.CompletedTask; + }); + + testServerAddress = testServer.Address; + + using Request request = httpPipeline.CreateRequest(); + request.Method = RequestMethod.Get; + request.Uri.Reset(testServer.Address); + + using Response response = await ExecuteRequest(request, httpPipeline); + Assert.AreEqual(response.Status, 200); + Assert.AreEqual(2, uris.Count); + Assert.AreEqual(1, uris.Count(u => u.Contains("/redirected"))); + } + + [Test] + public async Task StopsOnMaxRedirects() + { + HttpPipeline httpPipeline = HttpPipelineBuilder.Build(GetOptions()); + Uri testServerAddress = null; + int count = 0; + using TestServer testServer = new TestServer( + context => + { + Interlocked.Increment(ref count); + context.Response.StatusCode = 300; + context.Response.Headers.Add("Location", testServerAddress + "/redirected"); + }); + + testServerAddress = testServer.Address; + + using Request request = httpPipeline.CreateRequest(); + request.Method = RequestMethod.Get; + request.Uri.Reset(testServer.Address); + + using Response response = await ExecuteRequest(request, httpPipeline); + Assert.AreEqual(300, response.Status); + Assert.AreEqual(51, count); + } + + private class CallbackPolicy : HttpPipelineSynchronousPolicy + { + private readonly Action _callback; + + public CallbackPolicy(Action callback) + { + _callback = callback; + } + + public override void OnSendingRequest(HttpMessage message) + { + base.OnSendingRequest(message); + _callback(message); + } + } private class TestOptions : ClientOptions { } diff --git a/sdk/core/Azure.Core/tests/RedirectPolicyTests.cs b/sdk/core/Azure.Core/tests/RedirectPolicyTests.cs new file mode 100644 index 000000000000..2c7855f44eed --- /dev/null +++ b/sdk/core/Azure.Core/tests/RedirectPolicyTests.cs @@ -0,0 +1,171 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.Diagnostics.Tracing; +using System.Threading.Tasks; +using Azure.Core.Diagnostics; +using Azure.Core.Pipeline; +using Azure.Core.TestFramework; +using NUnit.Framework; + +namespace Azure.Core.Tests +{ + public class RedirectPolicyTests : SyncAsyncPolicyTestBase + { + public RedirectPolicyTests(bool isAsync) : base(isAsync) + { + } + + [TestCaseSource(nameof(RedirectStatusCodes))] + [NonParallelizable] + public async Task UsesLocationResponseHeaderAsNewRequestUri(int code) + { + using var testListener = new TestEventListener(); + testListener.EnableEvents(AzureCoreEventSource.Singleton, EventLevel.Verbose); + + var fistResponse = new MockResponse(code).AddHeader("Location", "http://new.host/"); + var mockTransport = new MockTransport( + fistResponse, + new MockResponse(200)); + + var response = await SendGetRequest(mockTransport, RedirectPolicy.Shared); + + Assert.AreEqual(200, response.Status); + Assert.AreEqual(2, mockTransport.Requests.Count); + Assert.AreEqual("http://new.host/", mockTransport.Requests[1].Uri.ToString()); + Assert.True(fistResponse.IsDisposed); + + var e = testListener.SingleEventById(20); + + Assert.AreEqual(EventLevel.Verbose, e.Level); + Assert.AreEqual("RequestRedirect", e.EventName); + Assert.AreEqual(mockTransport.Requests[0].ClientRequestId, e.GetProperty("requestId")); + Assert.AreEqual("http://example.com/", e.GetProperty("from")); + Assert.AreEqual("http://new.host/", e.GetProperty("to")); + Assert.AreEqual(code, e.GetProperty("status")); + } + + [TestCaseSource(nameof(RedirectStatusCodesOldMethodsNewMethods))] + public async Task ChangesMethodWhenRequired(int code, string oldMethod, string newMethod) + { + var mockTransport = new MockTransport( + new MockResponse(code).AddHeader("Location", "http://new.host/"), + new MockResponse(200)); + + await SendRequestAsync(mockTransport, request => + { + request.Uri.Reset(new Uri("http://example.com/")); + request.Method = new RequestMethod(oldMethod); + }, RedirectPolicy.Shared); + + Assert.AreEqual(2, mockTransport.Requests.Count); + Assert.AreEqual("http://new.host/", mockTransport.Requests[1].Uri.ToString()); + Assert.AreEqual(newMethod, mockTransport.Requests[1].Method.ToString()); + } + + [Test] + [NonParallelizable] + public async Task ReturnsOnMaxRedirects() + { + using var testListener = new TestEventListener(); + testListener.EnableEvents(AzureCoreEventSource.Singleton, EventLevel.Verbose); + + var mockTransport = new MockTransport(_ => + new MockResponse(300).AddHeader("Location", "http://new.host/")); + + var response = await SendGetRequest(mockTransport, RedirectPolicy.Shared); + + Assert.AreEqual(300, response.Status); + Assert.AreEqual(51, mockTransport.Requests.Count); + Assert.AreEqual("http://new.host/", mockTransport.Requests[1].Uri.ToString()); + + var e = testListener.SingleEventById(22); + + Assert.AreEqual(EventLevel.Warning, e.Level); + Assert.AreEqual("RequestRedirectCountExceeded", e.EventName); + Assert.AreEqual(mockTransport.Requests[0].ClientRequestId, e.GetProperty("requestId")); + Assert.AreEqual("http://new.host/", e.GetProperty("from")); + Assert.AreEqual("http://new.host/", e.GetProperty("to")); + } + + [Test] + [NonParallelizable] + public async Task BlocksUnsafeRedirect() + { + using var testListener = new TestEventListener(); + testListener.EnableEvents(AzureCoreEventSource.Singleton, EventLevel.Verbose); + + var mockTransport = new MockTransport(_ => + new MockResponse(300).AddHeader("Location", "http://new.host/")); + + var response = await SendRequestAsync(mockTransport, request => + { + request.Uri.Reset(new Uri("https://example.com/")); + }, RedirectPolicy.Shared); + + Assert.AreEqual(300, response.Status); + Assert.AreEqual(1, mockTransport.Requests.Count); + + var e = testListener.SingleEventById(21); + + Assert.AreEqual(EventLevel.Warning, e.Level); + Assert.AreEqual("RequestRedirectBlocked", e.EventName); + Assert.AreEqual(mockTransport.Requests[0].ClientRequestId, e.GetProperty("requestId")); + Assert.AreEqual("https://example.com/", e.GetProperty("from")); + Assert.AreEqual("http://new.host/", e.GetProperty("to")); + } + + [Test] + public async Task RemovesAuthHeader() + { + var mockTransport = new MockTransport( + new MockResponse(300) + .AddHeader("Location", "http://new.host/") + .AddHeader("Authorization", "secret value"), + new MockResponse(200)); + + var response = await SendGetRequest(mockTransport, RedirectPolicy.Shared); + + Assert.AreEqual(200, response.Status); + Assert.AreEqual(2, mockTransport.Requests.Count); + Assert.False(mockTransport.Requests[1].Headers.Contains("Authorization")); + } + + public static readonly object[][] RedirectStatusCodes = { + new object[] { 300 }, + new object[] { 301 }, + new object[] { 302 }, + new object[] { 303 }, + new object[] { 307 }, + new object[] { 308 } + }; + + public static readonly object[][] RedirectStatusCodesOldMethodsNewMethods = { + new object[] { 300, "GET", "GET" }, + new object[] { 300, "POST", "GET" }, + new object[] { 300, "HEAD", "HEAD" }, + + new object[] { 301, "GET", "GET" }, + new object[] { 301, "POST", "GET" }, + new object[] { 301, "HEAD", "HEAD" }, + + new object[] { 302, "GET", "GET" }, + new object[] { 302, "POST", "GET" }, + new object[] { 302, "HEAD", "HEAD" }, + + new object[] { 303, "GET", "GET" }, + new object[] { 303, "POST", "GET" }, + new object[] { 303, "HEAD", "HEAD" }, + + new object[] { 307, "GET", "GET" }, + new object[] { 307, "POST", "POST" }, + new object[] { 307, "HEAD", "HEAD" }, + + new object[] { 308, "GET", "GET" }, + new object[] { 308, "POST", "POST" }, + new object[] { 308, "HEAD", "HEAD" }, + }; + } +} \ No newline at end of file