diff --git a/sdk/ai/azure-ai-agents/src/test/java/com/azure/ai/agents/implementation/http/PolicyDecoratingHttpClientTests.java b/sdk/ai/azure-ai-agents/src/test/java/com/azure/ai/agents/implementation/http/PolicyDecoratingHttpClientTests.java index e014322e99e4..81052b047516 100644 --- a/sdk/ai/azure-ai-agents/src/test/java/com/azure/ai/agents/implementation/http/PolicyDecoratingHttpClientTests.java +++ b/sdk/ai/azure-ai-agents/src/test/java/com/azure/ai/agents/implementation/http/PolicyDecoratingHttpClientTests.java @@ -27,6 +27,7 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; class PolicyDecoratingHttpClientTests { @@ -80,6 +81,114 @@ void sendDelegatesToUnderlyingClient() throws Exception { assertEquals(request, recordingClient.getLastRequest()); } + @Test + void asyncSendAppliesPolicies() throws Exception { + RecordingHttpClient recordingClient = new RecordingHttpClient(); + HttpPipelinePolicy perCallPolicy + = new HeaderAppendingPolicy(PER_CALL_HEADER, "async-one", HttpPipelinePosition.PER_CALL); + HttpPipelinePolicy perRetryPolicy + = new HeaderAppendingPolicy(PER_RETRY_HEADER, "async-two", HttpPipelinePosition.PER_RETRY); + + PolicyDecoratingHttpClient client + = new PolicyDecoratingHttpClient(recordingClient, Arrays.asList(perCallPolicy, perRetryPolicy)); + + HttpRequest request = new HttpRequest(HttpMethod.GET, URI.create("https://example.com").toURL()); + + HttpResponse response = client.send(request).block(); + + assertNotNull(response); + HttpRequest sentRequest = recordingClient.getLastRequest(); + assertNotNull(sentRequest); + HttpHeaders headers = sentRequest.getHeaders(); + assertEquals("async-one", headers.getValue(PER_CALL_HEADER)); + assertEquals("async-two", headers.getValue(PER_RETRY_HEADER)); + assertEquals(1, recordingClient.getSendCount()); + } + + @Test + void asyncSendWithContextAppliesPolicies() throws Exception { + RecordingHttpClient recordingClient = new RecordingHttpClient(); + HttpPipelinePolicy perCallPolicy + = new HeaderAppendingPolicy(PER_CALL_HEADER, "context-one", HttpPipelinePosition.PER_CALL); + HttpPipelinePolicy perRetryPolicy + = new HeaderAppendingPolicy(PER_RETRY_HEADER, "context-two", HttpPipelinePosition.PER_RETRY); + + PolicyDecoratingHttpClient client + = new PolicyDecoratingHttpClient(recordingClient, Arrays.asList(perCallPolicy, perRetryPolicy)); + + HttpRequest request = new HttpRequest(HttpMethod.GET, URI.create("https://example.com").toURL()); + + HttpResponse response = client.send(request, Context.NONE).block(); + + assertNotNull(response); + HttpRequest sentRequest = recordingClient.getLastRequest(); + assertNotNull(sentRequest); + HttpHeaders headers = sentRequest.getHeaders(); + assertEquals("context-one", headers.getValue(PER_CALL_HEADER)); + assertEquals("context-two", headers.getValue(PER_RETRY_HEADER)); + assertEquals(1, recordingClient.getSendCount()); + } + + @Test + void policyErrorPropagatesInAsyncSend() throws Exception { + RecordingHttpClient recordingClient = new RecordingHttpClient(); + RuntimeException policyException = new RuntimeException("Policy error"); + HttpPipelinePolicy failingPolicy = new FailingPolicy(policyException); + + PolicyDecoratingHttpClient client + = new PolicyDecoratingHttpClient(recordingClient, Collections.singletonList(failingPolicy)); + + HttpRequest request = new HttpRequest(HttpMethod.GET, URI.create("https://example.com").toURL()); + + RuntimeException thrown = assertThrows(RuntimeException.class, () -> client.send(request).block()); + assertEquals("Policy error", thrown.getMessage()); + assertEquals(0, recordingClient.getSendCount()); + } + + @Test + void underlyingClientErrorPropagatesInAsyncSend() throws Exception { + RuntimeException clientException = new RuntimeException("Client error"); + FailingHttpClient failingClient = new FailingHttpClient(clientException); + + PolicyDecoratingHttpClient client = new PolicyDecoratingHttpClient(failingClient, Collections.emptyList()); + + HttpRequest request = new HttpRequest(HttpMethod.GET, URI.create("https://example.com").toURL()); + + RuntimeException thrown = assertThrows(RuntimeException.class, () -> client.send(request).block()); + assertEquals("Client error", thrown.getMessage()); + assertTrue(failingClient.wasCalled()); + } + + @Test + void policyErrorPropagatesInSyncSend() throws Exception { + RecordingHttpClient recordingClient = new RecordingHttpClient(); + RuntimeException policyException = new RuntimeException("Sync policy error"); + HttpPipelinePolicy failingPolicy = new FailingPolicy(policyException); + + PolicyDecoratingHttpClient client + = new PolicyDecoratingHttpClient(recordingClient, Collections.singletonList(failingPolicy)); + + HttpRequest request = new HttpRequest(HttpMethod.GET, URI.create("https://example.com").toURL()); + + RuntimeException thrown = assertThrows(RuntimeException.class, () -> client.sendSync(request, Context.NONE)); + assertEquals("Sync policy error", thrown.getMessage()); + assertEquals(0, recordingClient.getSendCount()); + } + + @Test + void underlyingClientErrorPropagatesInSyncSend() throws Exception { + RuntimeException clientException = new RuntimeException("Sync client error"); + FailingHttpClient failingClient = new FailingHttpClient(clientException); + + PolicyDecoratingHttpClient client = new PolicyDecoratingHttpClient(failingClient, Collections.emptyList()); + + HttpRequest request = new HttpRequest(HttpMethod.GET, URI.create("https://example.com").toURL()); + + RuntimeException thrown = assertThrows(RuntimeException.class, () -> client.sendSync(request, Context.NONE)); + assertEquals("Sync client error", thrown.getMessage()); + assertTrue(failingClient.wasCalled()); + } + private static final class RecordingHttpClient implements HttpClient { private HttpRequest lastRequest; private final AtomicInteger sendCount = new AtomicInteger(); @@ -127,4 +236,46 @@ public Mono process(HttpPipelineCallContext context, HttpPipelineN return next.process(); } } + + private static final class FailingPolicy implements HttpPipelinePolicy { + private final RuntimeException exception; + + private FailingPolicy(RuntimeException exception) { + this.exception = exception; + } + + @Override + public Mono process(HttpPipelineCallContext context, HttpPipelineNextPolicy next) { + return Mono.error(exception); + } + + @Override + public HttpPipelinePosition getPipelinePosition() { + return HttpPipelinePosition.PER_CALL; + } + } + + private static final class FailingHttpClient implements HttpClient { + private final RuntimeException exception; + private boolean called = false; + + private FailingHttpClient(RuntimeException exception) { + this.exception = exception; + } + + @Override + public Mono send(HttpRequest request) { + this.called = true; + return Mono.error(exception); + } + + @Override + public Mono send(HttpRequest request, Context context) { + return send(request); + } + + boolean wasCalled() { + return called; + } + } }