diff --git a/eng/code-quality-reports/src/main/resources/checkstyle/checkstyle-suppressions.xml b/eng/code-quality-reports/src/main/resources/checkstyle/checkstyle-suppressions.xml index 5711e184e352..e9199779c633 100755 --- a/eng/code-quality-reports/src/main/resources/checkstyle/checkstyle-suppressions.xml +++ b/eng/code-quality-reports/src/main/resources/checkstyle/checkstyle-suppressions.xml @@ -605,4 +605,8 @@ the main ServiceBusClientBuilder. --> + + + diff --git a/sdk/communication/azure-communication-callingserver/src/main/java/com/azure/communication/callingserver/CallingServerClientBuilder.java b/sdk/communication/azure-communication-callingserver/src/main/java/com/azure/communication/callingserver/CallingServerClientBuilder.java index 498d493ce4b0..35fc1fe8ba8d 100644 --- a/sdk/communication/azure-communication-callingserver/src/main/java/com/azure/communication/callingserver/CallingServerClientBuilder.java +++ b/sdk/communication/azure-communication-callingserver/src/main/java/com/azure/communication/callingserver/CallingServerClientBuilder.java @@ -5,6 +5,7 @@ import com.azure.communication.callingserver.implementation.AzureCommunicationCallingServerServiceImpl; import com.azure.communication.callingserver.implementation.AzureCommunicationCallingServerServiceImplBuilder; +import com.azure.communication.common.implementation.RedirectPolicy; import com.azure.communication.common.implementation.CommunicationConnectionString; import com.azure.communication.common.implementation.HmacAuthenticationPolicy; import com.azure.core.annotation.ServiceClientBuilder; @@ -317,6 +318,7 @@ private HttpPipeline createHttpPipeline(HttpClient httpClient) { policyList.add(new UserAgentPolicy(applicationId, clientName, clientVersion, configuration)); policyList.add(new RequestIdPolicy()); policyList.add((retryPolicy == null) ? new RetryPolicy() : retryPolicy); + policyList.add(new RedirectPolicy()); policyList.add(createHttpPipelineAuthPolicy()); policyList.add(new CookiePolicy()); diff --git a/sdk/communication/azure-communication-callingserver/src/test/java/com/azure/communication/callingserver/DownloadContentAsyncLiveTests.java b/sdk/communication/azure-communication-callingserver/src/test/java/com/azure/communication/callingserver/DownloadContentAsyncLiveTests.java index 66c78edcb63d..77c0ed58cda1 100644 --- a/sdk/communication/azure-communication-callingserver/src/test/java/com/azure/communication/callingserver/DownloadContentAsyncLiveTests.java +++ b/sdk/communication/azure-communication-callingserver/src/test/java/com/azure/communication/callingserver/DownloadContentAsyncLiveTests.java @@ -6,13 +6,12 @@ import com.azure.communication.callingserver.models.CallingServerErrorException; import com.azure.communication.callingserver.models.ParallelDownloadOptions; import com.azure.core.http.HttpClient; -import com.azure.core.http.rest.Response; -import com.azure.core.util.FluxUtil; import org.junit.jupiter.api.condition.DisabledIfEnvironmentVariable; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.MethodSource; import org.mockito.Mockito; import reactor.core.publisher.Flux; +import reactor.test.StepVerifier; import java.nio.ByteBuffer; import java.nio.channels.AsynchronousFileChannel; @@ -22,9 +21,7 @@ import static org.hamcrest.CoreMatchers.equalTo; import static org.hamcrest.CoreMatchers.is; -import static org.hamcrest.CoreMatchers.notNullValue; import static org.hamcrest.MatcherAssert.assertThat; -import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyLong; @@ -44,11 +41,7 @@ public void downloadMetadataAsync(HttpClient httpClient) { CallingServerAsyncClient conversationAsyncClient = setupAsyncClient(builder, "downloadMetadataAsync"); try { - Flux content = conversationAsyncClient.downloadStream(METADATA_URL); - byte[] contentBytes = FluxUtil.collectBytesInByteBufferStream(content).block(); - assertThat(contentBytes, is(notNullValue())); - String metadata = new String(contentBytes, StandardCharsets.UTF_8); - assertThat(metadata.contains("0-eus-d2-3cca2175891f21c6c9a5975a12c0141c"), is(true)); + validateMetadata(conversationAsyncClient.downloadStream(METADATA_URL)); } catch (Exception e) { System.out.println("Error: " + e.getMessage()); throw e; @@ -66,11 +59,7 @@ public void downloadMetadataRetryingAsync(HttpClient httpClient) { CallingServerAsyncClient conversationAsyncClient = setupAsyncClient(builder, "downloadMetadataRetryingAsync"); try { - Flux content = conversationAsyncClient.downloadStream(METADATA_URL); - byte[] contentBytes = FluxUtil.collectBytesInByteBufferStream(content).block(); - assertThat(contentBytes, is(notNullValue())); - String metadata = new String(contentBytes, StandardCharsets.UTF_8); - assertThat(metadata.contains("0-eus-d2-3cca2175891f21c6c9a5975a12c0141c"), is(true)); + validateMetadata(conversationAsyncClient.downloadStream(METADATA_URL)); } catch (Exception e) { System.out.println("Error: " + e.getMessage()); throw e; @@ -88,11 +77,16 @@ public void downloadVideoAsync(HttpClient httpClient) { CallingServerAsyncClient conversationAsyncClient = setupAsyncClient(builder, "downloadVideoAsync"); try { - Response> response = conversationAsyncClient.downloadStreamWithResponse(VIDEO_URL, null).block(); - assertThat(response, is(notNullValue())); - byte[] contentBytes = FluxUtil.collectBytesInByteBufferStream(response.getValue()).block(); - assertThat(contentBytes, is(notNullValue())); - assertThat(Integer.parseInt(response.getHeaders().getValue("Content-Length")), is(equalTo(contentBytes.length))); + StepVerifier.create(conversationAsyncClient.downloadStreamWithResponse(VIDEO_URL, null)) + .consumeNextWith(response -> { + StepVerifier.create(response.getValue()) + .consumeNextWith(byteBuffer -> { + assertThat(Integer.parseInt(response.getHeaders().getValue("Content-Length")), + is(equalTo(byteBuffer.array().length))); + }) + .verifyComplete(); + }) + .verifyComplete(); } catch (Exception e) { System.out.println("Error: " + e.getMessage()); throw e; @@ -178,14 +172,31 @@ public void downloadToFileRetryingAsync(HttpClient httpClient) { public void downloadContent404Async(HttpClient httpClient) { CallingServerClientBuilder builder = getConversationClientUsingConnectionString(httpClient); CallingServerAsyncClient conversationAsyncClient = setupAsyncClient(builder, "downloadContent404Async"); - Response> response = conversationAsyncClient - .downloadStreamWithResponse(CONTENT_URL_404, null).block(); - assertThat(response, is(notNullValue())); - assertThat(response.getStatusCode(), is(equalTo(404))); - assertThrows(CallingServerErrorException.class, - () -> FluxUtil.collectBytesInByteBufferStream(response.getValue()).block()); + StepVerifier.create(conversationAsyncClient.downloadStreamWithResponse(CONTENT_URL_404, null)) + .consumeNextWith(response -> { + assertThat(response.getStatusCode(), is(equalTo(404))); + StepVerifier.create(response.getValue()).verifyError(CallingServerErrorException.class); + }) + .verifyComplete(); } + @ParameterizedTest + @MethodSource("com.azure.core.test.TestBase#getHttpClients") + @DisabledIfEnvironmentVariable( + named = "SKIP_LIVE_TEST", + matches = "(?i)(true)", + disabledReason = "Requires human intervention") + public void downloadMetadataWithRedirectAsync(HttpClient httpClient) { + CallingServerClientBuilder builder = getConversationClientUsingConnectionString(httpClient); + CallingServerAsyncClient conversationAsyncClient = setupAsyncClient(builder, "downloadMetadataAsync"); + + try { + validateMetadata(conversationAsyncClient.downloadStream(METADATA_URL)); + } catch (Exception e) { + System.out.println("Error: " + e.getMessage()); + throw e; + } + } private CallingServerAsyncClient setupAsyncClient(CallingServerClientBuilder builder, String testName) { return addLoggingPolicy(builder, testName).buildAsyncClient(); @@ -194,4 +205,13 @@ private CallingServerAsyncClient setupAsyncClient(CallingServerClientBuilder bui protected CallingServerClientBuilder addLoggingPolicy(CallingServerClientBuilder builder, String testName) { return builder.addPolicy((context, next) -> logHeaders(testName, next)); } + + private void validateMetadata(Flux metadataByteBuffer) { + StepVerifier.create(metadataByteBuffer) + .consumeNextWith(byteBuffer -> { + String metadata = new String(byteBuffer.array(), StandardCharsets.UTF_8); + assertThat(metadata.contains("0-eus-d2-3cca2175891f21c6c9a5975a12c0141c"), is(true)); + }) + .verifyComplete(); + } } diff --git a/sdk/communication/azure-communication-callingserver/src/test/java/com/azure/communication/callingserver/DownloadContentAsyncUnitTests.java b/sdk/communication/azure-communication-callingserver/src/test/java/com/azure/communication/callingserver/DownloadContentAsyncUnitTests.java index d66736dc646a..d0ffac1090b7 100644 --- a/sdk/communication/azure-communication-callingserver/src/test/java/com/azure/communication/callingserver/DownloadContentAsyncUnitTests.java +++ b/sdk/communication/azure-communication-callingserver/src/test/java/com/azure/communication/callingserver/DownloadContentAsyncUnitTests.java @@ -4,7 +4,6 @@ package com.azure.communication.callingserver; import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; import java.io.BufferedReader; @@ -16,86 +15,102 @@ import java.nio.file.FileSystems; import java.nio.file.Path; import java.util.ArrayList; -import java.util.Arrays; +import java.util.Collections; import java.util.UUID; import java.util.AbstractMap.SimpleEntry; import com.azure.communication.callingserver.models.ParallelDownloadOptions; import com.azure.core.http.HttpRange; -import com.azure.core.http.rest.Response; +import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import reactor.core.publisher.Flux; +import reactor.test.StepVerifier; public class DownloadContentAsyncUnitTests { + + private static final String CONTENTS = "VideoContents"; + private CallingServerAsyncClient callingServerClient; + + @BeforeEach + public void setup() { + callingServerClient = + CallingServerResponseMocker.getCallingServerAsyncClient(new ArrayList<>( + Collections.singletonList( + new SimpleEntry<>(CallingServerResponseMocker.generateDownloadResult(CONTENTS), 200) + ))); + } @Test - public void downloadStream() throws IOException { - String contents = "VideoContents"; - CallingServerAsyncClient callingServerClient = CallingServerResponseMocker.getCallingServerAsyncClient(new ArrayList>( - Arrays.asList( - new SimpleEntry(CallingServerResponseMocker.generateDownloadResult(contents), 200) - ))); - - Flux fluxByteBuffer = callingServerClient.downloadStream("https://url.com", new HttpRange(contents.length())); - - String resultContents = new String(fluxByteBuffer.next().block().array(), StandardCharsets.UTF_8); - assertEquals("VideoContents", resultContents); + public void downloadStream() { + StepVerifier.create( + callingServerClient.downloadStream( + "https://url.com", + new HttpRange(CONTENTS.length())) + ).consumeNextWith(byteBuffer -> { + String resultContents = new String(byteBuffer.array(), StandardCharsets.UTF_8); + assertEquals(CONTENTS, resultContents); + }).verifyComplete(); } @Test - public void downloadStreamWithResponse() throws IOException { - String contents = "VideoContents"; - CallingServerAsyncClient callingServerClient = CallingServerResponseMocker.getCallingServerAsyncClient(new ArrayList>( - Arrays.asList( - new SimpleEntry(CallingServerResponseMocker.generateDownloadResult(contents), 200) - ))); - - Response> fluxByteBufferResponse = callingServerClient.downloadStreamWithResponse("https://url.com", new HttpRange(contents.length())).block(); - assertEquals(200, fluxByteBufferResponse.getStatusCode()); - Flux fluxByteBuffer = fluxByteBufferResponse.getValue(); - String resultContents = new String(fluxByteBuffer.next().block().array(), StandardCharsets.UTF_8); - assertEquals("VideoContents", resultContents); + public void downloadStreamWithResponse() { + StepVerifier.create( + callingServerClient.downloadStreamWithResponse( + "https://url.com", + new HttpRange(CONTENTS.length())) + ).consumeNextWith(response -> { + assertEquals(200, response.getStatusCode()); + verifyContents(response.getValue()); + }).verifyComplete(); } - + @Test - public void downloadStreamWithResponseThrowException() throws IOException { - String contents = "VideoContents"; - CallingServerAsyncClient callingServerClient = CallingServerResponseMocker.getCallingServerAsyncClient(new ArrayList>( - Arrays.asList( - new SimpleEntry("", 416) - ))); - - Response> fluxByteBufferResponse = callingServerClient.downloadStreamWithResponse("https://url.com", new HttpRange(contents.length())).block(); - Flux fluxByteBuffer = fluxByteBufferResponse.getValue(); - assertThrows(NullPointerException.class, () -> fluxByteBuffer.next().block()); + public void downloadStreamWithResponseThrowException() { + callingServerClient = + CallingServerResponseMocker.getCallingServerAsyncClient(new ArrayList<>( + Collections.singletonList( + new SimpleEntry<>("", 416) + ))); + + StepVerifier.create( + callingServerClient.downloadStreamWithResponse("https://url.com", new HttpRange(CONTENTS.length())) + ).consumeNextWith(response -> { + StepVerifier.create(response.getValue()) + .verifyError(NullPointerException.class); + }); } @Test public void downloadToWithResponse() throws IOException { - String contents = "VideoContents"; - CallingServerAsyncClient callingServerClient = CallingServerResponseMocker.getCallingServerAsyncClient(new ArrayList>( - Arrays.asList( - new SimpleEntry(CallingServerResponseMocker.generateDownloadResult(contents), 200) - ))); String fileName = "./" + UUID.randomUUID().toString().replace("-", "") + ".txt"; Path path = FileSystems.getDefault().getPath(fileName); ParallelDownloadOptions options = new ParallelDownloadOptions().setMaxConcurrency(1).setBlockSize(512L); File file = null; try { - Response response = callingServerClient.downloadToWithResponse("https://url.com", path, options, true).block(); - assertEquals(200, response.getStatusCode()); - + StepVerifier.create(callingServerClient.downloadToWithResponse("https://url.com", path, options, true)) + .consumeNextWith(response -> { + assertEquals(200, response.getStatusCode()); + }).verifyComplete(); + file = path.toFile(); assertTrue(file.exists(), "file does not exist"); - BufferedReader br = new BufferedReader(new FileReader(file)); - assertEquals("VideoContents", br.readLine()); + BufferedReader br = new BufferedReader(new FileReader(file)); + assertEquals(CONTENTS, br.readLine()); br.close(); } finally { if (file != null && file.exists()) { - file.delete(); + file.delete(); } } } + + private void verifyContents(Flux response) { + StepVerifier.create(response) + .consumeNextWith(byteBuffer -> { + String resultContents = new String(byteBuffer.array(), StandardCharsets.UTF_8); + assertEquals(CONTENTS, resultContents); + }).verifyComplete(); + } } diff --git a/sdk/communication/azure-communication-callingserver/src/test/resources/session-records/DownloadContentAsyncLiveTests.downloadMetadataWithRedirectAsync[1].json b/sdk/communication/azure-communication-callingserver/src/test/resources/session-records/DownloadContentAsyncLiveTests.downloadMetadataWithRedirectAsync[1].json new file mode 100644 index 000000000000..203dfc1dee25 --- /dev/null +++ b/sdk/communication/azure-communication-callingserver/src/test/resources/session-records/DownloadContentAsyncLiveTests.downloadMetadataWithRedirectAsync[1].json @@ -0,0 +1,41 @@ +{ + "networkCallRecords" : [ { + "Method" : "GET", + "Uri" : "https://REDACTED.asm.skype.com/v1/objects/0-eus-d2-3cca2175891f21c6c9a5975a12c0141c/content/acsmetadata", + "Headers" : { + "User-Agent" : "azsdk-java-azure-communication-callingserver/1.0.0-beta.3 (11.0.11; Windows 10; 10.0)" + }, + "Response" : { + "content-length" : "0", + "Strict-Transport-Security" : "max-age=31536000; includeSubDomains", + "Cache-Control" : "no-cache, max-age=0, s-maxage=0, private", + "Server" : "Microsoft-HTTPAPI/2.0", + "retry-after" : "0", + "StatusCode" : "302", + "Body" : "", + "Date" : "Thu, 03 Jun 2021 00:07:30 GMT", + "Location": "https://REDACTED.as.asm.skype.com/v1/objects/0-eus-d2-3cca2175891f21c6c9a5975a12c0141c/content/acsmetadata" + }, + "Exception" : null + },{ + "Method" : "GET", + "Uri" : "https://REDACTED.as.asm.skype.com/v1/objects/0-eus-d2-3cca2175891f21c6c9a5975a12c0141c/content/acsmetadata", + "Headers" : { + "User-Agent" : "azsdk-java-azure-communication-callingserver/1.0.0-beta.3 (11.0.11; Windows 10; 10.0)" + }, + "Response" : { + "content-length" : "957", + "Strict-Transport-Security" : "max-age=31536000; includeSubDomains", + "Cache-Control" : "no-cache, max-age=0, s-maxage=0, private", + "Server" : "Microsoft-HTTPAPI/2.0", + "Content-Range" : "bytes 0-956/957", + "retry-after" : "0", + "StatusCode" : "206", + "Body" : "ew0KICAicmVzb3VyY2VJZCI6ICI2MzFmYThkOC1hYWI1LTRhYzUtOGUxNS0yNjFhYTI1OTA3NTAiLA0KICAiY2FsbElkIjogImEzMjdhOGU0LTRjMjQtNGM4NC05ZmUyLTA5ZmZlNjIzYzg1OCIsDQogICJjaHVua0RvY3VtZW50SWQiOiAiMC1ldXMtZDItM2NjYTIxNzU4OTFmMjFjNmM5YTU5NzVhMTJjMDE0MWMiLA0KICAiY2h1bmtJbmRleCI6IDAsDQogICJjaHVua1N0YXJ0VGltZSI6ICIyMDIxLTA2LTAyVDIxOjQ1OjQxLjY0OTQyMjRaIiwNCiAgImNodW5rRHVyYXRpb24iOiA1NTgwLjAsDQogICJwYXVzZVJlc3VtZUludGVydmFscyI6IFtdLA0KICAicmVjb3JkaW5nSW5mbyI6IHsNCiAgICAiY29udGVudFR5cGUiOiAibWl4ZWQiLA0KICAgICJjaGFubmVsVHlwZSI6ICJhdWRpb1ZpZGVvIiwNCiAgICAiZm9ybWF0IjogIm1wNCIsDQogICAgImF1ZGlvQ29uZmlndXJhdGlvbiI6IHsNCiAgICAgICJzYW1wbGVSYXRlIjogMTYwMDAsDQogICAgICAiYml0UmF0ZSI6IDEyODAwMCwNCiAgICAgICJjaGFubmVscyI6IDENCiAgICB9LA0KICAgICJ2aWRlb0NvbmZpZ3VyYXRpb24iOiB7DQogICAgICAibG9uZ2VyU2lkZUxlbmd0aCI6IDE5MjAsDQogICAgICAic2hvcnRlclNpZGVMZW5ndGgiOiAxMDgwLA0KICAgICAgImZyYW1lcmF0ZSI6IDgsDQogICAgICAiYml0UmF0ZSI6IDEwMDAwMDANCiAgICB9DQogIH0sDQogICJwYXJ0aWNpcGFudHMiOiBbDQogICAgew0KICAgICAgInBhcnRpY2lwYW50SWQiOiAiODphY3M6NjMxZmE4ZDgtYWFiNS00YWM1LThlMTUtMjYxYWEyNTkwNzUwXzAwMDAwMDBhLTZlOGItYjMzYy1kZWZkLThiM2EwZDAwNTFjYiINCiAgICB9LA0KICAgIHsNCiAgICAgICJwYXJ0aWNpcGFudElkIjogIjg6YWNzOjYzMWZhOGQ4LWFhYjUtNGFjNS04ZTE1LTI2MWFhMjU5MDc1MF8wMDAwMDAwYS02ZThiLWNhMTctZGVmZC04YjNhMGQwMDUxY2QiDQogICAgfQ0KICBdDQp9", + "Date" : "Thu, 03 Jun 2021 00:07:30 GMT", + "Content-Type" : "application/octet-stream" + }, + "Exception" : null + } ], + "variables" : [ ] +} diff --git a/sdk/communication/azure-communication-common/pom.xml b/sdk/communication/azure-communication-common/pom.xml index 40b743bedbb1..2adc6529bf5d 100644 --- a/sdk/communication/azure-communication-common/pom.xml +++ b/sdk/communication/azure-communication-common/pom.xml @@ -84,5 +84,40 @@ 3.4.7 test + + org.mockito + mockito-core + 3.9.0 + test + + + org.hamcrest + hamcrest-all + 1.3 + test + + + + + java-lts + + [11,) + + + + + org.apache.maven.plugins + maven-surefire-plugin + 3.0.0-M3 + + + --add-opens com.azure.communication.common/com.azure.communication.common.implementation=ALL-UNNAMED + + + + + + + diff --git a/sdk/communication/azure-communication-common/src/main/java/com/azure/communication/common/implementation/RedirectPolicy.java b/sdk/communication/azure-communication-common/src/main/java/com/azure/communication/common/implementation/RedirectPolicy.java new file mode 100644 index 000000000000..897851e3f8a8 --- /dev/null +++ b/sdk/communication/azure-communication-common/src/main/java/com/azure/communication/common/implementation/RedirectPolicy.java @@ -0,0 +1,50 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +package com.azure.communication.common.implementation; + +import com.azure.core.http.HttpPipelineCallContext; +import com.azure.core.http.HttpPipelineNextPolicy; +import com.azure.core.http.HttpRequest; +import com.azure.core.http.HttpResponse; +import com.azure.core.http.policy.HttpPipelinePolicy; +import reactor.core.publisher.Mono; + +import java.util.HashSet; +import java.util.Set; + +/** + * HttpPipelinePolicy to redirect requests when 302 message is received to the new location marked by the + * Location header. + */ +public final class RedirectPolicy implements HttpPipelinePolicy { + private static final int MAX_REDIRECTS = 10; + private static final String LOCATION_HEADER_NAME = "Location"; + + @Override + public Mono process(HttpPipelineCallContext context, HttpPipelineNextPolicy next) { + return attemptRedirection(context, next, 0, new HashSet<>()); + } + + private Mono attemptRedirection(HttpPipelineCallContext context, HttpPipelineNextPolicy next, + int redirectNumber, Set locations) { + return next.clone().process().flatMap(httpResponse -> { + if (shouldRedirect(httpResponse, redirectNumber, locations)) { + String newLocation = httpResponse.getHeaderValue(LOCATION_HEADER_NAME); + locations.add(newLocation); + + HttpRequest newRequest = context.getHttpRequest().copy(); + newRequest.setUrl(newLocation); + context.setHttpRequest(newRequest); + + return attemptRedirection(context, next, redirectNumber + 1, locations); + } + return Mono.just(httpResponse); + }); + } + + private boolean shouldRedirect(HttpResponse response, int redirectNumber, Set locations) { + return response.getStatusCode() == 302 + && !locations.contains(response.getHeaderValue(LOCATION_HEADER_NAME)) + && redirectNumber < MAX_REDIRECTS; + } +} diff --git a/sdk/communication/azure-communication-common/src/test/java/com/azure/communication/common/implementation/RedirectPolicyTests.java b/sdk/communication/azure-communication-common/src/test/java/com/azure/communication/common/implementation/RedirectPolicyTests.java new file mode 100644 index 000000000000..8e0e57247d08 --- /dev/null +++ b/sdk/communication/azure-communication-common/src/test/java/com/azure/communication/common/implementation/RedirectPolicyTests.java @@ -0,0 +1,122 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package com.azure.communication.common.implementation; + +import com.azure.core.http.HttpClient; +import com.azure.core.http.HttpMethod; +import com.azure.core.http.HttpPipeline; +import com.azure.core.http.HttpPipelineBuilder; +import com.azure.core.http.HttpRequest; +import com.azure.core.http.HttpResponse; +import com.azure.core.util.Context; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import java.net.MalformedURLException; +import java.net.URL; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.CoreMatchers.equalTo; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.when; + +public class RedirectPolicyTests { + static final String ORIGINAL_LOCATION = "https://localhost.com"; + static final String REDIRECT_LOCATION = "https://localhost-2.com"; + static final RedirectPolicy REDIRECT_POLICY = new RedirectPolicy(); + private HttpRequest request; + private HttpPipeline pipeline; + + @Mock + HttpClient httpClient; + + @Mock + HttpResponse response200; + + @Mock + HttpResponse response302; + + @BeforeEach + public void setup() throws MalformedURLException { + MockitoAnnotations.openMocks(this); + when(response200.getStatusCode()).thenReturn(200); + when(response302.getStatusCode()).thenReturn(302); + when(response302.getHeaderValue("Location")).thenReturn(REDIRECT_LOCATION); + + pipeline = new HttpPipelineBuilder() + .httpClient(httpClient) + .policies(REDIRECT_POLICY) + .build(); + + request = new HttpRequest(HttpMethod.GET, new URL(ORIGINAL_LOCATION)); + } + + @Test + public void noRedirectionPerformedTest() { + setSuccessMockResponse(); + + verifyCorrectness(response200); + } + + @Test + public void redirectionPerformedTest() { + setRedirectSuccessMockResponses(); + + verifyCorrectness(response200); + } + + @Test + public void sameLocationUsedShortCircuitTest() { + setRedirectRedirectMockResponse(); + + verifyCorrectness(response302); + } + + @Test + public void sameLocationUsedInDifferentRequestsSuccessTest() { + for (int i = 0; i < 3; i++) { + setRedirectSuccessMockResponses(); + verifyCorrectness(response200); + } + } + + private void setSuccessMockResponse() { + doAnswer(invocation -> { + HttpRequest request = invocation.getArgument(0); + assertThat(request.getUrl().toString(), is(equalTo(ORIGINAL_LOCATION))); + return Mono.just(response200); + }).when(httpClient).send(any(HttpRequest.class), any(Context.class)); + } + + private void setRedirectSuccessMockResponses() { + doAnswer(invocation -> { + HttpRequest request = invocation.getArgument(0); + assertThat(request.getUrl().toString(), is(equalTo(ORIGINAL_LOCATION))); + return Mono.just(response302); + }) + .doAnswer(invocation -> { + HttpRequest request = invocation.getArgument(0); + assertThat(request.getUrl().toString(), is(equalTo(REDIRECT_LOCATION))); + return Mono.just(response200); + }) + .when(httpClient).send(any(HttpRequest.class), any(Context.class)); + } + + private void setRedirectRedirectMockResponse() { + doAnswer(invocation -> Mono.just(response302)) + .doAnswer(invocation -> Mono.just(response302)) + .when(httpClient).send(any(HttpRequest.class), any(Context.class)); + } + + private void verifyCorrectness(HttpResponse expectedResponse) { + StepVerifier.create(pipeline.send(request)).expectNext(expectedResponse).verifyComplete(); + } +}