diff --git a/spring-webflux/src/test/kotlin/org/springframework/web/reactive/function/client/WebClientExtensionsTests.kt b/spring-webflux/src/test/kotlin/org/springframework/web/reactive/function/client/WebClientExtensionsTests.kt index 7bca596a0109..3b6801e7b1db 100644 --- a/spring-webflux/src/test/kotlin/org/springframework/web/reactive/function/client/WebClientExtensionsTests.kt +++ b/spring-webflux/src/test/kotlin/org/springframework/web/reactive/function/client/WebClientExtensionsTests.kt @@ -1,5 +1,5 @@ /* - * Copyright 2002-2021 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,7 +18,9 @@ package org.springframework.web.reactive.function.client import io.mockk.every import io.mockk.mockk +import io.mockk.slot import io.mockk.verify +import kotlinx.coroutines.currentCoroutineContext import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.flow import kotlinx.coroutines.flow.toList @@ -32,6 +34,8 @@ import reactor.core.publisher.Flux import reactor.core.publisher.Mono import java.util.concurrent.CompletableFuture import java.util.function.Function +import kotlin.coroutines.AbstractCoroutineContextElement +import kotlin.coroutines.CoroutineContext /** * Mock object based tests for [WebClient] Kotlin extensions @@ -103,6 +107,18 @@ class WebClientExtensionsTests { } } + @Test + fun `awaitExchange with coroutines context`() { + val foo = mockk() + val slot = slot>>() + every { requestBodySpec.exchangeToMono(capture(slot)) } answers { + slot.captured.apply(mockk()) + } + runBlocking(FooContextElement(foo)) { + assertThat(requestBodySpec.awaitExchange { currentCoroutineContext()[FooContextElement]!!.foo }).isEqualTo(foo) + } + } + @Test fun `awaitExchangeOrNull returning null`() { val foo = mockk() @@ -121,6 +137,18 @@ class WebClientExtensionsTests { } } + @Test + fun `awaitExchangeOrNull with coroutines context`() { + val foo = mockk() + val slot = slot>>() + every { requestBodySpec.exchangeToMono(capture(slot)) } answers { + slot.captured.apply(mockk()) + } + runBlocking(FooContextElement(foo)) { + assertThat(requestBodySpec.awaitExchangeOrNull { currentCoroutineContext()[FooContextElement]!!.foo }).isEqualTo(foo) + } + } + @Test fun exchangeToFlow() { val foo = mockk() @@ -202,4 +230,8 @@ class WebClientExtensionsTests { } class Foo + + private data class FooContextElement(val foo: Foo) : AbstractCoroutineContextElement(FooContextElement) { + companion object Key : CoroutineContext.Key + } }