diff --git a/independent-projects/resteasy-reactive/server/runtime/src/main/java/org/jboss/resteasy/reactive/server/providers/serialisers/StreamingOutputMessageBodyWriter.java b/independent-projects/resteasy-reactive/server/runtime/src/main/java/org/jboss/resteasy/reactive/server/providers/serialisers/StreamingOutputMessageBodyWriter.java index 9fa833eccaf5d..ed3ad8e72ee0f 100644 --- a/independent-projects/resteasy-reactive/server/runtime/src/main/java/org/jboss/resteasy/reactive/server/providers/serialisers/StreamingOutputMessageBodyWriter.java +++ b/independent-projects/resteasy-reactive/server/runtime/src/main/java/org/jboss/resteasy/reactive/server/providers/serialisers/StreamingOutputMessageBodyWriter.java @@ -10,11 +10,13 @@ import jakarta.ws.rs.core.MultivaluedMap; import jakarta.ws.rs.core.StreamingOutput; +import org.jboss.resteasy.reactive.server.core.ResteasyReactiveRequestContext; import org.jboss.resteasy.reactive.server.spi.ResteasyReactiveResourceInfo; import org.jboss.resteasy.reactive.server.spi.ServerMessageBodyWriter; import org.jboss.resteasy.reactive.server.spi.ServerRequestContext; public class StreamingOutputMessageBodyWriter implements ServerMessageBodyWriter { + @Override public boolean isWriteable(Class type, Type genericType, Annotation[] annotations, MediaType mediaType) { return doIsWriteable(type); @@ -45,7 +47,25 @@ public void writeTo(StreamingOutput streamingOutput, Class type, Type generic @Override public void writeResponse(StreamingOutput o, Type genericType, ServerRequestContext context) - throws WebApplicationException, IOException { - o.write(context.getOrCreateOutputStream()); + throws WebApplicationException { + ResteasyReactiveRequestContext rrContext = (ResteasyReactiveRequestContext) context; + try { + o.write(context.getOrCreateOutputStream()); + } catch (Throwable t) { + if (context.serverResponse().headWritten()) { + context.serverResponse().reset(); + rrContext.resume(t); + } else { + if (t instanceof WebApplicationException) { + throw (WebApplicationException) t; + } else if (t instanceof IOException) { + throw new WebApplicationException(t); + } else if (t instanceof RuntimeException) { + throw new WebApplicationException(t); + } else { + throw new WebApplicationException(t); + } + } + } } } diff --git a/integration-tests/elytron-resteasy-reactive/src/main/java/io/quarkus/it/resteasy/reactive/elytron/StreamingOutputResource.java b/integration-tests/elytron-resteasy-reactive/src/main/java/io/quarkus/it/resteasy/reactive/elytron/StreamingOutputResource.java new file mode 100644 index 0000000000000..779eef10425e8 --- /dev/null +++ b/integration-tests/elytron-resteasy-reactive/src/main/java/io/quarkus/it/resteasy/reactive/elytron/StreamingOutputResource.java @@ -0,0 +1,47 @@ +package io.quarkus.it.resteasy.reactive.elytron; + +import java.io.IOException; +import java.io.OutputStream; +import java.nio.charset.StandardCharsets; + +import jakarta.ws.rs.DefaultValue; +import jakarta.ws.rs.GET; +import jakarta.ws.rs.Path; +import jakarta.ws.rs.Produces; +import jakarta.ws.rs.QueryParam; +import jakarta.ws.rs.core.MediaType; +import jakarta.ws.rs.core.StreamingOutput; + +import io.smallrye.common.annotation.Blocking; + +@Path("/streaming-output-error") +public class StreamingOutputResource { + private static final int ITEMS_PER_EMIT = 100; + + private static final byte[] CHUNK = "This is one chunk of data.\n".getBytes(StandardCharsets.UTF_8); + + @GET + @Path("/output") + @Produces(MediaType.TEXT_PLAIN) + @Blocking + public StreamingOutput streamOutput(@QueryParam("fail") @DefaultValue("false") boolean fail) { + return outputStream -> { + try { + writeData(outputStream); + if (fail) { + throw new IOException("dummy failure"); + } + writeData(outputStream); + } catch (IOException e) { + throw new RuntimeException(e); + } + }; + } + + private void writeData(OutputStream out) throws IOException { + for (int i = 0; i < ITEMS_PER_EMIT; i++) { + out.write(CHUNK); + out.flush(); + } + } +} diff --git a/integration-tests/elytron-resteasy-reactive/src/test/java/io/quarkus/it/resteasy/reactive/elytron/StreamingOutputErrorHandlingTest.java b/integration-tests/elytron-resteasy-reactive/src/test/java/io/quarkus/it/resteasy/reactive/elytron/StreamingOutputErrorHandlingTest.java new file mode 100644 index 0000000000000..07dc900644b22 --- /dev/null +++ b/integration-tests/elytron-resteasy-reactive/src/test/java/io/quarkus/it/resteasy/reactive/elytron/StreamingOutputErrorHandlingTest.java @@ -0,0 +1,115 @@ +package io.quarkus.it.resteasy.reactive.elytron; + +import static io.restassured.RestAssured.port; + +import java.nio.charset.StandardCharsets; +import java.time.Duration; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicLong; +import java.util.function.Consumer; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import io.quarkus.test.junit.QuarkusTest; +import io.vertx.core.Handler; +import io.vertx.core.Vertx; +import io.vertx.core.buffer.Buffer; +import io.vertx.core.http.HttpClient; +import io.vertx.core.http.HttpClosedException; +import io.vertx.core.http.HttpMethod; + +@QuarkusTest +public class StreamingOutputErrorHandlingTest { + private static final Duration TIMEOUT = Duration.ofSeconds(10); + private static final int ITEMS_PER_BATCH = 100; + private static final int BYTES_PER_CHUNK = "This is one chunk of data.\n" + .getBytes(StandardCharsets.UTF_8).length; + private static final long EXPECTED_BYTES_FIRST_BATCH = (long) ITEMS_PER_BATCH * BYTES_PER_CHUNK; + private static final long EXPECTED_BYTES_COMPLETE = (long) ITEMS_PER_BATCH * 2 * BYTES_PER_CHUNK; + + private Vertx vertx; + private HttpClient client; + + @BeforeEach + public void setup() { + vertx = Vertx.vertx(); + client = vertx.createHttpClient(); + } + + @AfterEach + public void cleanup() throws Exception { + if (client != null) { + client.close().toCompletionStage().toCompletableFuture().get(5, TimeUnit.SECONDS); + } + if (vertx != null) { + vertx.close().toCompletionStage().toCompletableFuture().get(5, TimeUnit.SECONDS); + } + } + + @Test + public void testStreamingOutputFailureMidStream() { + AtomicLong byteCount = new AtomicLong(); + CompletableFuture latch = new CompletableFuture<>(); + + sendRequest("/streaming-output-error/output?fail=true", latch, + b -> byteCount.addAndGet(b.length())); + + Assertions.assertTimeoutPreemptively(TIMEOUT, () -> { + ExecutionException ex = Assertions.assertThrows(ExecutionException.class, + latch::get, + "Client should have detected that the server reset the connection"); + + Assertions.assertInstanceOf(HttpClosedException.class, ex.getCause(), + "Expected HttpClosedException when connection is reset mid-stream"); + }); + + Assertions.assertEquals(EXPECTED_BYTES_FIRST_BATCH, byteCount.get(), + "Should have received only the first batch of data before failure"); + } + + @Test + public void testStreamingOutputSuccess() { + AtomicLong byteCount = new AtomicLong(); + CompletableFuture latch = new CompletableFuture<>(); + + sendRequest("/streaming-output-error/output?fail=false", latch, + b -> byteCount.addAndGet(b.length())); + + Assertions.assertTimeoutPreemptively(TIMEOUT, + () -> latch.get(), + "StreamingOutput should complete successfully without errors"); + + Assertions.assertEquals(EXPECTED_BYTES_COMPLETE, byteCount.get(), + "Should have received all data when no errors occur"); + } + + private void sendRequest(String uri, CompletableFuture latch, Consumer bodyConsumer) { + Handler failureHandler = latch::completeExceptionally; + + client.request(HttpMethod.GET, port, "localhost", uri) + .onFailure(failureHandler) + .onSuccess(request -> { + request.end(); + request.connect() + .onFailure(failureHandler) + .onSuccess(response -> { + response.request().connection().closeHandler(v -> { + failureHandler.handle(new HttpClosedException("Connection was closed")); + }); + + response.handler(buffer -> { + if (buffer.length() > 0) { + bodyConsumer.accept(buffer); + } + }); + response.exceptionHandler(failureHandler); + response.endHandler(v -> latch.complete(null)); + }); + }); + } +}