diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/result/view/ViewResolutionResultHandler.java b/spring-webflux/src/main/java/org/springframework/web/reactive/result/view/ViewResolutionResultHandler.java index 51866ba9dce7..c567a6b6d0d4 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/result/view/ViewResolutionResultHandler.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/result/view/ViewResolutionResultHandler.java @@ -45,6 +45,7 @@ import org.springframework.http.HttpHeaders; import org.springframework.http.HttpStatusCode; import org.springframework.http.MediaType; +import org.springframework.http.codec.ServerSentEvent; import org.springframework.http.server.reactive.ServerHttpRequest; import org.springframework.http.server.reactive.ServerHttpResponse; import org.springframework.http.server.reactive.ServerHttpResponseDecorator; @@ -101,7 +102,7 @@ public class ViewResolutionResultHandler extends HandlerResultHandlerSupport imp private final List defaultViews = new ArrayList<>(4); - private final List streamHandlers = List.of(new SseStreamHandler()); + private final SseStreamHandler sseHandler = new SseStreamHandler(); /** @@ -175,7 +176,7 @@ public boolean supports(HandlerResult result) { returnType = returnType.getNested(2); if (adapter.isMultiValue()) { - return Fragment.class.isAssignableFrom(type); + return (Fragment.class.isAssignableFrom(type) || isSseFragmentStream(returnType)); } } @@ -194,8 +195,13 @@ private boolean hasModelAnnotation(MethodParameter parameter) { } private static boolean isFragmentCollection(ResolvableType returnType) { - Class clazz = returnType.resolve(Object.class); - return (Collection.class.isAssignableFrom(clazz) && Fragment.class.equals(returnType.getNested(2).resolve())); + return (Collection.class.isAssignableFrom(returnType.resolve(Object.class)) && + Fragment.class.equals(returnType.getNested(2).resolve())); + } + + private static boolean isSseFragmentStream(ResolvableType returnType) { + return (ServerSentEvent.class.equals(returnType.resolve()) && + Fragment.class.equals(returnType.getNested(2).resolve())); } @Override @@ -204,9 +210,15 @@ public Mono handleResult(ServerWebExchange exchange, HandlerResult result) Mono valueMono; ResolvableType valueType; ReactiveAdapter adapter = getAdapter(result); + BindingContext bindingContext = result.getBindingContext(); + Locale locale = LocaleContextHolder.getLocale(exchange.getLocaleContext()); if (adapter != null) { if (adapter.isMultiValue()) { + if (isSseFragmentStream(result.getReturnType().getNested(2))) { + return handleSseFragmentStream(exchange, result, adapter, locale, bindingContext); + } + valueMono = (result.getReturnValue() != null ? Mono.just(FragmentsRendering.withPublisher(adapter.toPublisher(result.getReturnValue())).build()) : Mono.empty()); @@ -233,8 +245,6 @@ public Mono handleResult(ServerWebExchange exchange, HandlerResult result) Mono> viewsMono; Model model = result.getModel(); MethodParameter parameter = result.getReturnTypeSource(); - BindingContext bindingContext = result.getBindingContext(); - Locale locale = LocaleContextHolder.getLocale(exchange.getLocaleContext()); Class clazz = valueType.toClass(); if (clazz == Object.class) { @@ -277,13 +287,15 @@ else if (FragmentsRendering.class.isAssignableFrom(clazz)) { response.getHeaders().putAll(render.headers()); bindingContext.updateModel(exchange); - StreamHandler streamHandler = getStreamHandler(exchange); + StreamHandler streamHandler = + (this.sseHandler.supports(exchange.getRequest()) ? this.sseHandler : null); + if (streamHandler != null) { streamHandler.updateResponse(exchange); } Flux> renderFlux = render.fragments() - .concatMap(fragment -> renderFragment(fragment, streamHandler, locale, bindingContext, exchange)) + .concatMap(fragment -> renderFragment(fragment, null, streamHandler, locale, bindingContext, exchange)) .doOnDiscard(DataBuffer.class, DataBufferUtils::release); return response.writeAndFlushWith(renderFlux); @@ -338,9 +350,29 @@ private Mono> resolveViews(String viewName, Locale locale) { }); } + private Mono handleSseFragmentStream( + ServerWebExchange exchange, HandlerResult result, ReactiveAdapter adapter, Locale locale, + BindingContext bindingContext) { + + this.sseHandler.updateResponse(exchange); + + Flux> eventFlux = + Flux.from(adapter.toPublisher(result.getReturnValue())); + + Flux> dataBufferFlux = eventFlux + .concatMap(event -> renderFragment(event.data(), event, this.sseHandler, locale, bindingContext, exchange)) + .doOnDiscard(DataBuffer.class, DataBufferUtils::release); + + return exchange.getResponse().writeAndFlushWith(dataBufferFlux); + } + private Mono> renderFragment( - Fragment fragment, @Nullable StreamHandler streamHandler, Locale locale, - BindingContext bindingContext, ServerWebExchange exchange) { + @Nullable Fragment fragment, @Nullable Object streamingHints, @Nullable StreamHandler streamHandler, + Locale locale, BindingContext bindingContext, ServerWebExchange exchange) { + + if (fragment == null) { + return Mono.empty(); + } // Merge attributes from top-level model fragment.mergeAttributes(bindingContext.getModel()); @@ -355,8 +387,11 @@ private Mono> renderFragment( Map model = fragment.model(); if (streamHandler != null) { - return selectedViews.flatMap(views -> render(views, model, MediaType.TEXT_HTML, bindingContext, mutatedExchange)) - .then(Mono.fromSupplier(() -> streamHandler.format(response.getBodyFlux(), fragment, exchange))); + return selectedViews + .flatMap(views -> + render(views, model, MediaType.TEXT_HTML, bindingContext, mutatedExchange)) + .then(Mono.fromSupplier(() -> streamHandler.format( + response.getBodyFlux(), fragment, streamingHints, exchange))); } else { return selectedViews.flatMap(views -> render(views, model, null, bindingContext, mutatedExchange)) @@ -364,16 +399,6 @@ private Mono> renderFragment( } } - @Nullable - private StreamHandler getStreamHandler(ServerWebExchange exchange) { - for (StreamHandler handler : this.streamHandlers) { - if (handler.supports(exchange.getRequest())) { - return handler; - } - } - return null; - } - private String getNameForReturnValue(MethodParameter returnType) { return Optional.ofNullable(returnType.getMethodAnnotation(ModelAttribute.class)) .filter(ann -> StringUtils.hasText(ann.value())) @@ -499,10 +524,13 @@ private interface StreamHandler { * Format the given fragment. * @param fragmentContent the fragment serialized to data buffers * @param fragment the fragment being rendered + * @param streamingHints extra hints for the stream format (e.g. ServerSentEvent wrapper) * @param exchange the current exchange * @return the formatted fragment */ - Flux format(Flux fragmentContent, Fragment fragment, ServerWebExchange exchange); + Flux format( + Flux fragmentContent, Fragment fragment, @Nullable Object streamingHints, + ServerWebExchange exchange); } @@ -540,16 +568,21 @@ private Charset getCharset(ServerHttpRequest request) { @Override public Flux format( - Flux fragmentFlux, Fragment fragment, ServerWebExchange exchange) { + Flux fragmentFlux, Fragment fragment, @Nullable Object hints, + ServerWebExchange exchange) { MediaType mediaType = exchange.getResponse().getHeaders().getContentType(); Charset charset = (mediaType != null && mediaType.getCharset() != null ? mediaType.getCharset() : StandardCharsets.UTF_8); + Assert.state(hints == null || hints instanceof ServerSentEvent, "Expected ServerSentEvent"); DataBufferFactory bufferFactory = exchange.getResponse().bufferFactory(); - String eventLine = (fragment.viewName() != null ? "event:" + fragment.viewName() + "\n" : ""); - DataBuffer prefix = encodeText(eventLine + "data:", charset, bufferFactory); + ServerSentEvent sse = (ServerSentEvent) hints; + CharSequence eventText = (sse != null ? sse.format() : + (fragment.viewName() != null ? "event:" + fragment.viewName() + "\n" : "") + "data:"); + + DataBuffer prefix = encodeText(eventText.toString(), charset, bufferFactory); DataBuffer suffix = encodeText("\n\n", charset, bufferFactory); Mono content = DataBufferUtils.join(fragmentFlux) diff --git a/spring-webflux/src/test/java/org/springframework/web/reactive/result/view/FragmentViewResolutionResultHandlerTests.java b/spring-webflux/src/test/java/org/springframework/web/reactive/result/view/FragmentViewResolutionResultHandlerTests.java index 40122c3f7a21..236e85e3d704 100644 --- a/spring-webflux/src/test/java/org/springframework/web/reactive/result/view/FragmentViewResolutionResultHandlerTests.java +++ b/spring-webflux/src/test/java/org/springframework/web/reactive/result/view/FragmentViewResolutionResultHandlerTests.java @@ -35,7 +35,9 @@ import org.springframework.context.annotation.Configuration; import org.springframework.context.support.ResourceBundleMessageSource; import org.springframework.core.MethodParameter; +import org.springframework.core.ResolvableType; import org.springframework.http.MediaType; +import org.springframework.http.codec.ServerSentEvent; import org.springframework.web.reactive.BindingContext; import org.springframework.web.reactive.HandlerResult; import org.springframework.web.reactive.accept.HeaderContentTypeResolver; @@ -99,7 +101,51 @@ void render(Object returnValue, MethodParameter parameter) { } @Test - void renderSse() { + void renderFragmentStream() { + + testSse(Flux.just(fragment1, fragment2), + on(Handler.class).resolveReturnType(Flux.class, Fragment.class), + """ + event:fragment1 + data:

+ data: Hello Foo + data:

+ + event:fragment2 + data:

+ data: Hello Bar + data:

+ + """); + } + + @Test + void renderServerSentEventFragmentStream() { + + ServerSentEvent event1 = ServerSentEvent.builder(fragment1).id("id1").event("event1").build(); + ServerSentEvent event2 = ServerSentEvent.builder(fragment2).id("id2").event("event2").build(); + + MethodParameter returnType = on(Handler.class).resolveReturnType( + Flux.class, ResolvableType.forClassWithGenerics(ServerSentEvent.class, Fragment.class)); + + testSse(Flux.just(event1, event2), returnType, + """ + id:id1 + event:event1 + data:

+ data: Hello Foo + data:

+ + id:id2 + event:event2 + data:

+ data: Hello Bar + data:

+ + """); + } + + private void testSse(Flux dataFlux, MethodParameter returnType, String output) { MockServerHttpRequest request = MockServerHttpRequest.get("/") .accept(MediaType.TEXT_EVENT_STREAM) .acceptLanguageAsLocales(Locale.ENGLISH) @@ -110,8 +156,8 @@ void renderSse() { HandlerResult result = new HandlerResult( new Handler(), - Flux.just(fragment1, fragment2).subscribeOn(Schedulers.boundedElastic()), - on(Handler.class).resolveReturnType(Flux.class, Fragment.class), + dataFlux.subscribeOn(Schedulers.boundedElastic()), + returnType, new BindingContext()); String body = initHandler().handleResult(exchange, result) @@ -119,18 +165,7 @@ void renderSse() { .block(Duration.ofSeconds(60)); assertThat(response.getHeaders().getContentType()).isEqualTo(MediaType.TEXT_EVENT_STREAM); - assertThat(body).isEqualTo(""" - event:fragment1 - data:

- data: Hello Foo - data:

- - event:fragment2 - data:

- data: Hello Bar - data:

- - """); + assertThat(body).isEqualTo(output); } private ViewResolutionResultHandler initHandler() { @@ -155,6 +190,8 @@ private static class Handler { Flux renderFlux() { return null; } + Flux> renderSseFlux() { return null; } + List renderList() { return null; } } diff --git a/spring-webflux/src/test/java/org/springframework/web/reactive/result/view/ViewResolutionResultHandlerTests.java b/spring-webflux/src/test/java/org/springframework/web/reactive/result/view/ViewResolutionResultHandlerTests.java index 2ad247317b47..bc4249cb6210 100644 --- a/spring-webflux/src/test/java/org/springframework/web/reactive/result/view/ViewResolutionResultHandlerTests.java +++ b/spring-webflux/src/test/java/org/springframework/web/reactive/result/view/ViewResolutionResultHandlerTests.java @@ -41,6 +41,7 @@ import org.springframework.core.io.buffer.DefaultDataBufferFactory; import org.springframework.http.HttpStatus; import org.springframework.http.MediaType; +import org.springframework.http.codec.ServerSentEvent; import org.springframework.http.server.reactive.ServerHttpResponse; import org.springframework.lang.Nullable; import org.springframework.ui.ConcurrentModel; @@ -84,6 +85,9 @@ void supports() { testSupports(on(Handler.class).resolveReturnType(FragmentsRendering.class)); testSupports(on(Handler.class).resolveReturnType(Flux.class, Fragment.class)); + testSupports(on(Handler.class).resolveReturnType( + Flux.class, ResolvableType.forClassWithGenerics(ServerSentEvent.class, Fragment.class))); + testSupports(on(Handler.class).resolveReturnType(List.class, Fragment.class)); testSupports(on(Handler.class).resolveReturnType( Mono.class, ResolvableType.forClassWithGenerics(List.class, Fragment.class))); @@ -457,6 +461,7 @@ private static class Handler { FragmentsRendering fragmentsRendering() { return null; } Flux fragmentFlux() { return null; } + Flux> fragmentServerSentEventFlux() { return null; } Mono> monoFragmentList() { return null; } List fragmentList() { return null; }