Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

KTOR-7327 Support conversion between byte channel interfaces and kotlinx-io primitives #4594

Merged
merged 10 commits into from
Jan 14, 2025
13 changes: 13 additions & 0 deletions ktor-io/api/ktor-io.api
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,10 @@ public final class io/ktor/utils/io/ByteReadChannelOperations_jvmKt {
public static final fun skipDelimiter (Lio/ktor/utils/io/ByteReadChannel;Lkotlinx/io/bytestring/ByteString;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
}

public final class io/ktor/utils/io/ByteReadChannelSourceKt {
public static final fun asSource (Lio/ktor/utils/io/ByteReadChannel;)Lkotlinx/io/RawSource;
}

public abstract interface class io/ktor/utils/io/ByteWriteChannel {
public abstract fun cancel (Ljava/lang/Throwable;)V
public abstract fun flush (Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
Expand Down Expand Up @@ -173,6 +177,10 @@ public final class io/ktor/utils/io/ByteWriteChannelOperations_jvmKt {
public static final fun writeFully (Lio/ktor/utils/io/ByteWriteChannel;Ljava/nio/ByteBuffer;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
}

public final class io/ktor/utils/io/ByteWriteChannelSinkKt {
public static final fun asSink (Lio/ktor/utils/io/ByteWriteChannel;)Lkotlinx/io/RawSink;
}

public abstract interface class io/ktor/utils/io/ChannelJob {
public abstract fun getJob ()Lkotlinx/coroutines/Job;
}
Expand Down Expand Up @@ -279,6 +287,10 @@ public final class io/ktor/utils/io/ReaderScope : kotlinx/coroutines/CoroutineSc
public fun getCoroutineContext ()Lkotlin/coroutines/CoroutineContext;
}

public final class io/ktor/utils/io/SinkByteWriteChannelKt {
public static final fun asByteWriteChannel (Lkotlinx/io/RawSink;)Lio/ktor/utils/io/ByteWriteChannel;
}

public final class io/ktor/utils/io/WriterJob : io/ktor/utils/io/ChannelJob {
public final fun getChannel ()Lio/ktor/utils/io/ByteReadChannel;
public fun getJob ()Lkotlinx/coroutines/Job;
Expand Down Expand Up @@ -594,6 +606,7 @@ public abstract class io/ktor/utils/io/pool/SingleInstancePool : io/ktor/utils/i
}

public final class io/ktor/utils/io/streams/StreamsKt {
public static final fun asByteWriteChannel (Ljava/io/OutputStream;)Lio/ktor/utils/io/ByteWriteChannel;
public static final fun asInput (Ljava/io/InputStream;)Lkotlinx/io/Source;
public static final fun inputStream (Lkotlinx/io/Source;)Ljava/io/InputStream;
public static final fun readPacketAtLeast (Ljava/io/InputStream;I)Lkotlinx/io/Source;
Expand Down
7 changes: 7 additions & 0 deletions ktor-io/api/ktor-io.klib.api
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,7 @@ final fun (kotlinx.coroutines/CoroutineScope).io.ktor.utils.io/writer(kotlin.cor
final fun (kotlinx.coroutines/CoroutineScope).io.ktor.utils.io/writer(kotlin.coroutines/CoroutineContext = ..., kotlin/Boolean = ..., kotlin.coroutines/SuspendFunction1<io.ktor.utils.io/WriterScope, kotlin/Unit>): io.ktor.utils.io/WriterJob // io.ktor.utils.io/writer|[email protected](kotlin.coroutines.CoroutineContext;kotlin.Boolean;kotlin.coroutines.SuspendFunction1<io.ktor.utils.io.WriterScope,kotlin.Unit>){}[0]
final fun (kotlinx.io/Buffer).io.ktor.utils.io.core/canRead(): kotlin/Boolean // io.ktor.utils.io.core/canRead|[email protected](){}[0]
final fun (kotlinx.io/Buffer).io.ktor.utils.io.core/readBytes(kotlin/Int = ...): kotlin/ByteArray // io.ktor.utils.io.core/readBytes|[email protected](kotlin.Int){}[0]
final fun (kotlinx.io/RawSink).io.ktor.utils.io/asByteWriteChannel(): io.ktor.utils.io/ByteWriteChannel // io.ktor.utils.io/asByteWriteChannel|[email protected](){}[0]
final fun (kotlinx.io/Sink).io.ktor.utils.io.core/append(kotlin/CharSequence, kotlin/Int = ..., kotlin/Int = ...) // io.ktor.utils.io.core/append|[email protected](kotlin.CharSequence;kotlin.Int;kotlin.Int){}[0]
final fun (kotlinx.io/Sink).io.ktor.utils.io.core/build(): kotlinx.io/Source // io.ktor.utils.io.core/build|[email protected](){}[0]
final fun (kotlinx.io/Sink).io.ktor.utils.io.core/writeFully(kotlin/ByteArray, kotlin/Int = ..., kotlin/Int = ...) // io.ktor.utils.io.core/writeFully|[email protected](kotlin.ByteArray;kotlin.Int;kotlin.Int){}[0]
Expand Down Expand Up @@ -604,6 +605,12 @@ sealed class io.ktor.utils.io.errors/PosixException : kotlin/Exception { // io.k
}
}

// Targets: [native]
final fun (io.ktor.utils.io/ByteReadChannel).io.ktor.utils.io/asSource(): kotlinx.io/RawSource // io.ktor.utils.io/asSource|[email protected](){}[0]

// Targets: [native]
final fun (io.ktor.utils.io/ByteWriteChannel).io.ktor.utils.io/asSink(): kotlinx.io/RawSink // io.ktor.utils.io/asSink|[email protected](){}[0]

// Targets: [native]
final fun (kotlinx.io/Sink).io.ktor.utils.io.core/write(kotlin/Function3<kotlinx.cinterop/CPointer<kotlinx.cinterop/ByteVarOf<kotlin/Byte>>, kotlin/Long, kotlin/Long, kotlin/Long>): kotlin/Long // io.ktor.utils.io.core/write|[email protected](kotlin.Function3<kotlinx.cinterop.CPointer<kotlinx.cinterop.ByteVarOf<kotlin.Byte>>,kotlin.Long,kotlin.Long,kotlin.Long>){}[0]

Expand Down
4 changes: 2 additions & 2 deletions ktor-io/common/src/io/ktor/utils/io/CloseToken.kt
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@

package io.ktor.utils.io

import io.ktor.utils.io.errors.*
import kotlinx.coroutines.*
import kotlinx.coroutines.CancellationException
import kotlinx.coroutines.CopyableThrowable
import kotlinx.io.IOException

internal val CLOSED = CloseToken(null)
Expand Down
64 changes: 64 additions & 0 deletions ktor-io/common/src/io/ktor/utils/io/SinkByteWriteChannel.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
/*
* Copyright 2014-2025 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license.
*/

package io.ktor.utils.io

import kotlinx.atomicfu.AtomicRef
import kotlinx.atomicfu.atomic
import kotlinx.io.*

/**
* Creates a [ByteWriteChannel] that writes to this [Sink].
*
* Example usage:
* ```kotlin
* suspend fun writeMessage(raw: RawSink) {
* val channel = raw.asByteWriteChannel()
* channel.writeByte(42)
* channel.flushAndClose()
* }
*
* val buffer = Buffer()
* writeMessage(buffer)
* buffer.readByte() // 42
* ```
*
* Please note that the channel will be buffered even if the sink is not.
*/
public fun RawSink.asByteWriteChannel(): ByteWriteChannel = SinkByteWriteChannel(this)

internal class SinkByteWriteChannel(origin: RawSink) : ByteWriteChannel {
val closed: AtomicRef<CloseToken?> = atomic(null)
private val buffer = origin.buffered()

override val isClosedForWrite: Boolean
get() = closed.value != null

override val closedCause: Throwable?
get() = closed.value?.cause

@InternalAPI
override val writeBuffer: Sink
get() {
if (isClosedForWrite) throw closedCause ?: IOException("Channel is closed for write")
return buffer
}

@OptIn(InternalAPI::class)
override suspend fun flush() {
writeBuffer.flush()
}

@OptIn(InternalAPI::class)
override suspend fun flushAndClose() {
writeBuffer.flush()
e5l marked this conversation as resolved.
Show resolved Hide resolved
if (!closed.compareAndSet(expect = null, update = CLOSED)) return
}

@OptIn(InternalAPI::class)
override fun cancel(cause: Throwable?) {
val token = if (cause == null) CLOSED else CloseToken(cause)
if (!closed.compareAndSet(expect = null, update = token)) return
}
}
54 changes: 25 additions & 29 deletions ktor-io/common/test/ByteChannelTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -2,72 +2,79 @@
* Copyright 2014-2024 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license.
*/

import io.ktor.test.dispatcher.*
import io.ktor.utils.io.*
import io.ktor.utils.io.core.*
import kotlinx.coroutines.*
import kotlinx.io.*
import kotlinx.coroutines.test.runTest
import kotlinx.io.EOFException
import kotlinx.io.IOException
import kotlin.test.*

class ByteChannelTest {

@Test
fun testReadFromEmpty() = testSuspend {
fun `write after close should fail`() = runTest {
val channel = ByteChannel()
channel.flushAndClose()
channel.close()
assertFailsWith<IOException> {
channel.writeByte(1)
}
}

@Test
fun testReadFromEmpty() = runTest {
val channel = ByteChannel()
channel.flushAndClose()
assertFailsWith<EOFException> {
channel.readByte()
}
}

@Test
fun testWriteReadByte() = testSuspend {
fun testWriteReadByte() = runTest {
val channel = ByteChannel()
channel.writeByte(42)
channel.flushAndClose()
assertEquals(42, channel.readByte())
}

@Test
fun testCancel() = testSuspend {
fun testCancel() = runTest {
val channel = ByteChannel()
channel.cancel()

assertFailsWith<IOException> {
channel.readByte()
}
}

@Test
fun testWriteInClosedChannel() = testSuspend {
fun testWriteInClosedChannel() = runTest {
val channel = ByteChannel()
channel.flushAndClose()

assertTrue(channel.isClosedForWrite)
assertFailsWith<ClosedWriteChannelException> {
channel.writeByte(42)
}
}

@Test
fun testCreateFromArray() = testSuspend {
fun testCreateFromArray() = runTest {
val array = byteArrayOf(1, 2, 3, 4, 5)
val channel = ByteReadChannel(array)
val result = channel.toByteArray()
assertTrue(array.contentEquals(result))
}

@Test
fun testChannelFromString() = testSuspend {
fun testChannelFromString() = runTest {
val string = "Hello, world!"
val channel = ByteReadChannel(string)
val result = channel.readRemaining().readText()
assertEquals(string, result)
}

@Test
fun testCancelByteReadChannel() = testSuspend {
fun testCancelByteReadChannel() = runTest {
val channel = ByteReadChannel(byteArrayOf(1, 2, 3, 4, 5))
channel.cancel()
assertFailsWith<IOException> {
Expand All @@ -76,78 +83,67 @@ class ByteChannelTest {
}

@Test
fun testCloseAfterAwait() = testSuspend {
fun testCloseAfterAwait() = runTest {
val channel = ByteChannel()
val job = launch(start = CoroutineStart.UNDISPATCHED) {
channel.awaitContent()
}

channel.flushAndClose()
job.join()
}

@Test
fun testChannelMaxSize() = testSuspend(timeoutMillis = 1000) {
fun testChannelMaxSize() = runTest {
val channel = ByteChannel()
val job = launch(Dispatchers.Unconfined) {
channel.writeFully(ByteArray(CHANNEL_MAX_SIZE))
}

delay(100)
assertFalse(job.isCompleted)

channel.readByte()
job.join()
}

@Test
fun testChannelMaxSizeWithException() = testSuspend {
fun testChannelMaxSizeWithException() = runTest {
val channel = ByteChannel()
var writerThrows = false
val deferred = async(Dispatchers.Unconfined) {
try {
channel.writeFully(ByteArray(CHANNEL_MAX_SIZE))
} catch (cause: IOException) {
} catch (_: IOException) {
writerThrows = true
}
}

assertFalse(deferred.isCompleted)

channel.cancel()
deferred.await()

assertTrue(writerThrows)
}

@Test
fun testIsCloseForReadAfterCancel() = testSuspend {
fun testIsCloseForReadAfterCancel() = runTest {
val packet = buildPacket {
writeInt(1)
writeInt(2)
writeInt(3)
}

val channel = ByteChannel()
channel.writePacket(packet)
channel.flush()
channel.cancel()

assertTrue(channel.isClosedForRead)
}

@Test
fun testWriteAndFlushResumeReader() = testSuspend {
fun testWriteAndFlushResumeReader() = runTest {
val channel = ByteChannel()
val reader = async {
channel.readByte()
}

yield()

channel.writeByte(42)
channel.flush()

assertEquals(42, reader.await())
}
}
Loading
Loading