Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ public abstract class KrpcClient : RpcClient, KrpcEndpoint {

private val serverSupportedPlugins: CompletableDeferred<Set<KrpcPlugin>> = CompletableDeferred()

private val requestChannels = RpcInternalConcurrentHashMap<String, Channel<Any?>>()
private val requestChannels = RpcInternalConcurrentHashMap<String, Channel<Result<Any?>>>()

@InternalRpcApi
final override val supportedPlugins: Set<KrpcPlugin>
Expand Down Expand Up @@ -247,11 +247,11 @@ public abstract class KrpcClient : RpcClient, KrpcEndpoint {

val callId = "$connectionId:${callable.name}:$id"

val channel = Channel<T>()
val channel = Channel<Result<T>>()

try {
@Suppress("UNCHECKED_CAST")
requestChannels[callId] = channel as Channel<Any?>
requestChannels[callId] = channel as Channel<Result<Any?>>

val request = serializeRequest(
callId = callId,
Expand Down Expand Up @@ -308,7 +308,7 @@ public abstract class KrpcClient : RpcClient, KrpcEndpoint {
}
}

private suspend fun <T> FlowCollector<T>.consumeAndEmitServerMessages(channel: Channel<T>) {
private suspend fun <T> FlowCollector<T>.consumeAndEmitServerMessages(channel: Channel<Result<T>>) {
while (true) {
val element = channel.receiveCatching()
if (element.isClosed) {
Expand All @@ -317,14 +317,22 @@ public abstract class KrpcClient : RpcClient, KrpcEndpoint {
}

if (!element.isFailure) {
emit(element.getOrThrow())
val result = element.getOrThrow()
result.fold(
onSuccess = { value ->
emit(value)
},
onFailure = { throwable ->
throw throwable
}
)
}
}
}

private suspend fun <T, @Rpc R : Any> handleServerStreamingMessage(
message: KrpcCallMessage,
channel: Channel<T>,
channel: Channel<Result<T>>,
callable: RpcCallable<R>,
) {
when (message) {
Expand Down Expand Up @@ -355,7 +363,7 @@ public abstract class KrpcClient : RpcClient, KrpcEndpoint {
}

@Suppress("UNCHECKED_CAST")
channel.send(value.getOrNull() as T)
channel.send(value as Result<T>)
}

is KrpcCallMessage.StreamFinished -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ interface KrpcTestService {
suspend fun nullableReturn(returnNull: Boolean): TestClass?
suspend fun variance(arg2: TestList<in TestClass>, arg3: TestList2<TestClass>): TestList<out TestClass>?
suspend fun collectOnce(flow: Flow<String>)
suspend fun returnTestClassThatThrowsWhileDeserialization(value: Int): TestClassThatThrowsWhileDeserialization

suspend fun nonSerializableClass(localDate: LocalDate): LocalDate
suspend fun nonSerializableClassWithSerializer(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,10 @@ class KrpcTestServiceBackend : KrpcTestService {
return arg1
}

override suspend fun returnTestClassThatThrowsWhileDeserialization(value: Int): TestClassThatThrowsWhileDeserialization {
return TestClassThatThrowsWhileDeserialization(value)
}

override suspend fun nullableParam(arg1: String?): String {
return arg1 ?: "null"
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import kotlinx.rpc.krpc.server.KrpcServer
import kotlinx.rpc.registerService
import kotlinx.rpc.withService
import kotlinx.serialization.KSerializer
import kotlinx.serialization.SerializationException
import kotlinx.serialization.descriptors.PrimitiveKind
import kotlinx.serialization.descriptors.PrimitiveSerialDescriptor
import kotlinx.serialization.descriptors.SerialDescriptor
Expand Down Expand Up @@ -460,6 +461,16 @@ abstract class KrpcTransportTestBase {
fun testUnitFlow() = runTest {
assertEquals(Unit, client.unitFlow().toList().single())
}

@Test
fun testPR445() = runTest {
assertFailsWith<SerializationException> {
val result = client.returnTestClassThatThrowsWhileDeserialization(42)
@Suppress("SENSELESS_COMPARISON")
if (result == null)
println("result must not be null")
}
}
}

private val JS_EXTENDED_TIMEOUT = if (isJs) 300.seconds else 60.seconds
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,12 @@

package kotlinx.rpc.krpc.test

import kotlinx.serialization.KSerializer
import kotlinx.serialization.Serializable
import kotlinx.serialization.SerializationException
import kotlinx.serialization.builtins.serializer
import kotlinx.serialization.encoding.Decoder
import kotlinx.serialization.encoding.Encoder

@Suppress("EqualsOrHashCode", "detekt.EqualsWithHashCodeExist")
@Serializable
Expand All @@ -15,6 +20,27 @@ open class TestClass(val value: Int = 0) {
}
}

@Suppress("EqualsOrHashCode", "detekt.EqualsWithHashCodeExist")
@Serializable(with = TestClassThatThrowsWhileDeserialization.Serializer::class)
class TestClassThatThrowsWhileDeserialization(val value: Int = 0) {
object Serializer : KSerializer<TestClassThatThrowsWhileDeserialization> {
override val descriptor = Int.serializer().descriptor

override fun serialize(encoder: Encoder, value: TestClassThatThrowsWhileDeserialization) {
encoder.encodeInt(value.value)
}

override fun deserialize(decoder: Decoder): TestClassThatThrowsWhileDeserialization {
throw SerializationException("Its TestClassThatThrowsWhileDeserialization")
}
}

override fun equals(other: Any?): Boolean {
if (other !is TestClassThatThrowsWhileDeserialization) return false
return value == other.value
}
}

@Serializable
data class TestList<@Suppress("unused") T : TestClass>(val value: Int = 42)

Expand Down