Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import com.agentclientprotocol.model.SessionUpdate
import com.agentclientprotocol.model.StopReason
import com.agentclientprotocol.model.ToolCallId
import com.agentclientprotocol.protocol.invoke
import io.github.oshai.kotlinlogging.KotlinLogging.logger
import kotlinx.coroutines.CancellationException
import kotlinx.coroutines.CompletableDeferred
import kotlinx.coroutines.awaitCancellation
Expand Down Expand Up @@ -140,6 +141,80 @@ abstract class SimpleAgentTest(protocolDriver: ProtocolDriver) : ProtocolDriver
assertEquals(result!!.stopReason, StopReason.END_TURN)
}

@Test
fun `prompt response and update have proper order`() = testWithProtocols { clientProtocol, agentProtocol ->
val client = Client(protocol = clientProtocol)
val agent = Agent(protocol = agentProtocol, agentSupport = object : AgentSupport {
override suspend fun initialize(clientInfo: ClientInfo): AgentInfo {
return AgentInfo(clientInfo.protocolVersion)
}

override suspend fun createSession(sessionParameters: SessionCreationParameters): AgentSession {
return object : AgentSession {
override val sessionId: SessionId = SessionId("test-session-id")

override suspend fun prompt(
content: List<ContentBlock>,
_meta: JsonElement?,
): Flow<Event> = flow {
emit(Event.SessionUpdateEvent(SessionUpdate.AgentMessageChunk(ContentBlock.Text(sessionParameters.cwd))))
emit(Event.SessionUpdateEvent(SessionUpdate.AgentMessageChunk(ContentBlock.Text("text 1"))))
emit(Event.SessionUpdateEvent(SessionUpdate.AgentMessageChunk(ContentBlock.Text("text 2"))))
emit(Event.SessionUpdateEvent(SessionUpdate.AgentMessageChunk(ContentBlock.Text("text 3"))))
}
}
}

override suspend fun loadSession(
sessionId: SessionId,
sessionParameters: SessionCreationParameters,
): AgentSession {
TODO("Not yet implemented")
}
})
val testVersion = 10
val clientInfo = ClientInfo(protocolVersion = testVersion)
val agentInfo = client.initialize(clientInfo)
val cwd = "/test/path"
val newSession = client.newSession(SessionCreationParameters(cwd, emptyList())) { _, _ ->
object : ClientSessionOperations {
override suspend fun requestPermissions(
toolCall: SessionUpdate.ToolCallUpdate,
permissions: List<PermissionOption>,
_meta: JsonElement?,
): RequestPermissionResponse {
TODO("Not yet implemented")
}

override suspend fun notify(
notification: SessionUpdate,
_meta: JsonElement?,
) {
TODO("Not yet implemented")
}
}
}
val responses = mutableListOf<String>()
var result: PromptResponse? = null
withTimeout(1000) {
newSession.prompt(listOf()).collect { event ->
when (event) {
is Event.PromptResponseEvent -> {
println( "Received prompt response: ${event.response}" )
result = event.response
responses.add(event.response.stopReason.toString())
}
is Event.SessionUpdateEvent -> {
println( "Received session update: ${(event.update as SessionUpdate.AgentMessageChunk).content}" )
responses.add(((event.update as SessionUpdate.AgentMessageChunk).content as ContentBlock.Text).text)
}
}
}
}
assertContentEquals(listOf("/test/path", "text 1", "text 2", "text 3", "END_TURN"), responses)
assertEquals(result!!.stopReason, StopReason.END_TURN)
}

@Test
fun `cancel simple prompt from client`() = testWithProtocols { clientProtocol, agentProtocol ->
val client = Client(protocol = clientProtocol)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import com.agentclientprotocol.protocol.Protocol
import com.agentclientprotocol.protocol.invoke
import io.github.oshai.kotlinlogging.KotlinLogging
import kotlinx.atomicfu.atomic
import kotlinx.coroutines.DelicateCoroutinesApi
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.MutableStateFlow
Expand Down Expand Up @@ -62,12 +63,20 @@ internal class ClientSessionImpl(
logger.trace { "Sending prompt request: $content" }
val promptResponse = AcpMethod.AgentMethods.SessionPrompt(protocol, PromptRequest(sessionId, content, _meta))
logger.trace { "Received prompt response: $promptResponse" }
send(Event.PromptResponseEvent(promptResponse))
} finally {

// after receiving prompt response we immediately close the current prompt channel
// and then waiting for draining all the updates that were sent during prompt execution
// only after that we emit the PromptResponseEvent to the outbound flow
logger.trace { "Closing prompt channel" }
activePrompt.getAndSet(null)?.updateChannel?.close()
logger.trace { "Waiting for prompt channel to close" }
channelJob.join()

send(Event.PromptResponseEvent(promptResponse))
close()
} finally {
activePrompt.getAndSet(null)?.updateChannel?.close()
channelJob.cancel()
}
}

Expand Down Expand Up @@ -123,7 +132,9 @@ internal class ClientSessionImpl(
// }

val promptSession = activePrompt.value
if (promptSession != null) {
@OptIn(DelicateCoroutinesApi::class)
// check for isClosedForSend because the prompt may exist, but the code is waiting for the updates drain
if (promptSession != null && !promptSession.updateChannel.isClosedForSend) {
logger.trace { "Sending update to active prompt: $notification" }
promptSession.updateChannel.send(notification)
}
Expand Down
Loading