diff --git a/acp-ktor-test/src/commonTest/kotlin/com/agentclientprotocol/SimpleAgentTest.kt b/acp-ktor-test/src/commonTest/kotlin/com/agentclientprotocol/SimpleAgentTest.kt index e5456d1..ebad6c9 100644 --- a/acp-ktor-test/src/commonTest/kotlin/com/agentclientprotocol/SimpleAgentTest.kt +++ b/acp-ktor-test/src/commonTest/kotlin/com/agentclientprotocol/SimpleAgentTest.kt @@ -26,6 +26,7 @@ import com.agentclientprotocol.model.SessionUpdate import com.agentclientprotocol.model.StopReason import com.agentclientprotocol.model.ToolCallId import com.agentclientprotocol.protocol.invoke +import io.github.oshai.kotlinlogging.KotlinLogging.logger import kotlinx.coroutines.CancellationException import kotlinx.coroutines.CompletableDeferred import kotlinx.coroutines.awaitCancellation @@ -140,6 +141,80 @@ abstract class SimpleAgentTest(protocolDriver: ProtocolDriver) : ProtocolDriver assertEquals(result!!.stopReason, StopReason.END_TURN) } + @Test + fun `prompt response and update have proper order`() = testWithProtocols { clientProtocol, agentProtocol -> + val client = Client(protocol = clientProtocol) + val agent = Agent(protocol = agentProtocol, agentSupport = object : AgentSupport { + override suspend fun initialize(clientInfo: ClientInfo): AgentInfo { + return AgentInfo(clientInfo.protocolVersion) + } + + override suspend fun createSession(sessionParameters: SessionCreationParameters): AgentSession { + return object : AgentSession { + override val sessionId: SessionId = SessionId("test-session-id") + + override suspend fun prompt( + content: List, + _meta: JsonElement?, + ): Flow = flow { + emit(Event.SessionUpdateEvent(SessionUpdate.AgentMessageChunk(ContentBlock.Text(sessionParameters.cwd)))) + emit(Event.SessionUpdateEvent(SessionUpdate.AgentMessageChunk(ContentBlock.Text("text 1")))) + emit(Event.SessionUpdateEvent(SessionUpdate.AgentMessageChunk(ContentBlock.Text("text 2")))) + emit(Event.SessionUpdateEvent(SessionUpdate.AgentMessageChunk(ContentBlock.Text("text 3")))) + } + } + } + + override suspend fun loadSession( + sessionId: SessionId, + sessionParameters: SessionCreationParameters, + ): AgentSession { + TODO("Not yet implemented") + } + }) + val testVersion = 10 + val clientInfo = ClientInfo(protocolVersion = testVersion) + val agentInfo = client.initialize(clientInfo) + val cwd = "/test/path" + val newSession = client.newSession(SessionCreationParameters(cwd, emptyList())) { _, _ -> + object : ClientSessionOperations { + override suspend fun requestPermissions( + toolCall: SessionUpdate.ToolCallUpdate, + permissions: List, + _meta: JsonElement?, + ): RequestPermissionResponse { + TODO("Not yet implemented") + } + + override suspend fun notify( + notification: SessionUpdate, + _meta: JsonElement?, + ) { + TODO("Not yet implemented") + } + } + } + val responses = mutableListOf() + var result: PromptResponse? = null + withTimeout(1000) { + newSession.prompt(listOf()).collect { event -> + when (event) { + is Event.PromptResponseEvent -> { + println( "Received prompt response: ${event.response}" ) + result = event.response + responses.add(event.response.stopReason.toString()) + } + is Event.SessionUpdateEvent -> { + println( "Received session update: ${(event.update as SessionUpdate.AgentMessageChunk).content}" ) + responses.add(((event.update as SessionUpdate.AgentMessageChunk).content as ContentBlock.Text).text) + } + } + } + } + assertContentEquals(listOf("/test/path", "text 1", "text 2", "text 3", "END_TURN"), responses) + assertEquals(result!!.stopReason, StopReason.END_TURN) + } + @Test fun `cancel simple prompt from client`() = testWithProtocols { clientProtocol, agentProtocol -> val client = Client(protocol = clientProtocol) diff --git a/acp/src/commonMain/kotlin/com/agentclientprotocol/client/ClientSessionImpl.kt b/acp/src/commonMain/kotlin/com/agentclientprotocol/client/ClientSessionImpl.kt index 7e74bf5..cd3f268 100644 --- a/acp/src/commonMain/kotlin/com/agentclientprotocol/client/ClientSessionImpl.kt +++ b/acp/src/commonMain/kotlin/com/agentclientprotocol/client/ClientSessionImpl.kt @@ -8,6 +8,7 @@ import com.agentclientprotocol.protocol.Protocol import com.agentclientprotocol.protocol.invoke import io.github.oshai.kotlinlogging.KotlinLogging import kotlinx.atomicfu.atomic +import kotlinx.coroutines.DelicateCoroutinesApi import kotlinx.coroutines.channels.Channel import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.MutableStateFlow @@ -62,12 +63,20 @@ internal class ClientSessionImpl( logger.trace { "Sending prompt request: $content" } val promptResponse = AcpMethod.AgentMethods.SessionPrompt(protocol, PromptRequest(sessionId, content, _meta)) logger.trace { "Received prompt response: $promptResponse" } - send(Event.PromptResponseEvent(promptResponse)) - } finally { + + // after receiving prompt response we immediately close the current prompt channel + // and then waiting for draining all the updates that were sent during prompt execution + // only after that we emit the PromptResponseEvent to the outbound flow logger.trace { "Closing prompt channel" } activePrompt.getAndSet(null)?.updateChannel?.close() logger.trace { "Waiting for prompt channel to close" } channelJob.join() + + send(Event.PromptResponseEvent(promptResponse)) + close() + } finally { + activePrompt.getAndSet(null)?.updateChannel?.close() + channelJob.cancel() } } @@ -123,7 +132,9 @@ internal class ClientSessionImpl( // } val promptSession = activePrompt.value - if (promptSession != null) { + @OptIn(DelicateCoroutinesApi::class) + // check for isClosedForSend because the prompt may exist, but the code is waiting for the updates drain + if (promptSession != null && !promptSession.updateChannel.isClosedForSend) { logger.trace { "Sending update to active prompt: $notification" } promptSession.updateChannel.send(notification) }