diff --git a/sdk/core/azure-core/src/main/java/com/azure/core/http/policy/RedirectPolicy.java b/sdk/core/azure-core/src/main/java/com/azure/core/http/policy/RedirectPolicy.java index 4fc4b1e03dfe..ff74215558e0 100644 --- a/sdk/core/azure-core/src/main/java/com/azure/core/http/policy/RedirectPolicy.java +++ b/sdk/core/azure-core/src/main/java/com/azure/core/http/policy/RedirectPolicy.java @@ -62,7 +62,13 @@ private Mono attemptRedirect(final HttpPipelineCallContext context if (redirectStrategy.shouldAttemptRedirect(context, httpResponse, redirectAttempt, attemptedRedirectUrls)) { HttpRequest redirectRequestCopy = redirectStrategy.createRedirectRequest(httpResponse); - return httpResponse.getBody() + + // Clear the authorization header to avoid the client to be redirected to an untrusted third party server + // causing it to leak your authorization token to. + httpResponse.getHeaders().remove("Authorization"); + + return httpResponse + .getBody() .ignoreElements() .then(attemptRedirect(context, next, redirectRequestCopy, redirectAttempt + 1, attemptedRedirectUrls)); } else { diff --git a/sdk/core/azure-core/src/test/java/com/azure/core/http/policy/RedirectPolicyTest.java b/sdk/core/azure-core/src/test/java/com/azure/core/http/policy/RedirectPolicyTest.java index 872ed9980436..81a99f854cae 100644 --- a/sdk/core/azure-core/src/test/java/com/azure/core/http/policy/RedirectPolicyTest.java +++ b/sdk/core/azure-core/src/test/java/com/azure/core/http/policy/RedirectPolicyTest.java @@ -13,6 +13,8 @@ import com.azure.core.http.MockHttpResponse; import com.azure.core.http.clients.NoOpHttpClient; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; import reactor.core.publisher.Mono; import java.net.MalformedURLException; @@ -25,6 +27,7 @@ import java.util.function.Function; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; public class RedirectPolicyTest { @@ -53,14 +56,16 @@ public Mono send(HttpRequest request) { assertEquals(308, response.getStatusCode()); } - @Test - public void defaultRedirectWhen308() throws Exception { + @ParameterizedTest + @ValueSource(ints = {308, 307, 301, 302}) + public void defaultRedirectExpectedStatusCodes(int statusCode) throws Exception { RecordingHttpClient httpClient = new RecordingHttpClient(request -> { if (request.getUrl().toString().equals("http://localhost/")) { Map headers = new HashMap<>(); headers.put("Location", "http://redirecthost/"); + headers.put("Authorization", "12345"); HttpHeaders httpHeader = new HttpHeaders(headers); - return Mono.just(new MockHttpResponse(request, 308, httpHeader)); + return Mono.just(new MockHttpResponse(request, statusCode, httpHeader)); } else { return Mono.just(new MockHttpResponse(request, 200)); } @@ -74,8 +79,8 @@ public void defaultRedirectWhen308() throws Exception { HttpResponse response = pipeline.send(new HttpRequest(HttpMethod.GET, new URL("http://localhost/"))).block(); - // assertEquals(2, httpClient.getCount()); assertEquals(200, response.getStatusCode()); + assertNull(response.getHeaders().getValue("Authorization")); } @Test @@ -326,6 +331,32 @@ public Mono send(HttpRequest request) { assertEquals(401, response.getStatusCode()); } + @Test + public void defaultRedirectAuthorizationHeaderCleared() throws Exception { + RecordingHttpClient httpClient = new RecordingHttpClient(request -> { + if (request.getUrl().toString().equals("http://localhost/")) { + Map headers = new HashMap<>(); + headers.put("Location", "http://redirecthost/"); + headers.put("Authorization", "12345"); + HttpHeaders httpHeader = new HttpHeaders(headers); + return Mono.just(new MockHttpResponse(request, 308, httpHeader)); + } else { + return Mono.just(new MockHttpResponse(request, 200)); + } + }); + + HttpPipeline pipeline = new HttpPipelineBuilder() + .httpClient(httpClient) + .policies(new RedirectPolicy()) + .build(); + + HttpResponse response = pipeline.send(new HttpRequest(HttpMethod.GET, + new URL("http://localhost/"))).block(); + + assertEquals(200, response.getStatusCode()); + assertNull(response.getHeaders().getValue("Authorization")); + } + static class RecordingHttpClient implements HttpClient { private final AtomicInteger count = new AtomicInteger();