Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve logging, and setup example module #17

Merged
merged 2 commits into from
May 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,15 @@ kotlin {
implementation(libs.kotest.assertions.arrow)
}
}

val jvmMain by getting {
dependencies {
implementation(libs.hikari)
implementation(libs.postgresql)
implementation(libs.ktor.client.cio)
api(libs.ktor.client.cio)
}
}

val jvmTest by getting {
dependencies {
implementation(libs.kotest.junit5)
Expand All @@ -87,24 +89,25 @@ kotlin {

val jsMain by getting {
dependencies {
implementation(libs.ktor.client.js)
api(libs.ktor.client.js)
}
}

val linuxX64Main by getting {
dependencies {
implementation(libs.ktor.client.cio)
api(libs.ktor.client.cio)
}
}

val macosX64Main by getting {
dependencies {
implementation(libs.ktor.client.cio)
api(libs.ktor.client.cio)
}
}

val mingwX64Main by getting {
dependencies {
implementation(libs.ktor.client.winhttp)
api(libs.ktor.client.winhttp)
}
}

Expand Down
22 changes: 22 additions & 0 deletions example/build.gradle.kts
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
plugins {
id(libs.plugins.kotlin.jvm.get().pluginId)
id(libs.plugins.kotlinx.serialization.get().pluginId)
}

repositories {
mavenCentral()
}

java {
sourceCompatibility = JavaVersion.VERSION_11
targetCompatibility = JavaVersion.VERSION_11
toolchain {
languageVersion = JavaLanguageVersion.of(11)
}
}

dependencies {
implementation(rootProject)
implementation(libs.kotlinx.serialization.json)
implementation(libs.logback)
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package com.xebia.functional.examples.auto
package com.xebia.functional.langchain4k.auto

import com.xebia.functional.auto.ai
import kotlinx.serialization.Serializable
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package com.xebia.functional.examples.auto
package com.xebia.functional.langchain4k.auto

import com.xebia.functional.auto.ai
import kotlinx.serialization.Serializable
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package com.xebia.functional.examples.auto
package com.xebia.functional.langchain4k.auto

import com.xebia.functional.auto.ai
import kotlinx.serialization.Serializable
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package com.xebia.functional.examples.auto
package com.xebia.functional.langchain4k.auto

import com.xebia.functional.auto.ai
import kotlinx.serialization.Serializable
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package com.xebia.functional.examples.auto
package com.xebia.functional.langchain4k.auto

import com.xebia.functional.auto.ai
import kotlinx.serialization.Serializable
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package com.xebia.functional.examples.auto
package com.xebia.functional.langchain4k.auto

import com.xebia.functional.auto.ai
import kotlinx.serialization.Serializable
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package com.xebia.functional.examples.auto
package com.xebia.functional.langchain4k.auto

import com.xebia.functional.auto.ai
import kotlinx.serialization.Serializable
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package com.xebia.functional.examples.auto
package com.xebia.functional.langchain4k.auto

import com.xebia.functional.auto.ai

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package com.xebia.functional.examples.auto
package com.xebia.functional.langchain4k.auto

import com.xebia.functional.auto.ai
import kotlinx.serialization.Serializable
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package com.xebia.functional.examples.auto
package com.xebia.functional.langchain4k.auto

import com.xebia.functional.auto.ai
import kotlinx.serialization.Serializable
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package com.xebia.functional.examples.auto
package com.xebia.functional.langchain4k.auto

import com.xebia.functional.auto.ai
import kotlinx.serialization.Serializable
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package com.xebia.functional.examples.auto
package com.xebia.functional.langchain4k.auto

import com.xebia.functional.auto.ai
import kotlinx.serialization.Serializable
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package com.xebia.functional.examples.auto
package com.xebia.functional.langchain4k.auto

import com.xebia.functional.auto.ai
import kotlinx.serialization.Serializable
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package com.xebia.functional.examples.auto
package com.xebia.functional.langchain4k.auto

import com.xebia.functional.auto.ai
import kotlinx.serialization.Serializable
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package com.xebia.functional.examples.auto
package com.xebia.functional.langchain4k.auto

import com.xebia.functional.auto.ai
import kotlinx.serialization.Serializable
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package com.xebia.functional.examples.auto
package com.xebia.functional.langchain4k.auto

import com.xebia.functional.auto.ai
import kotlinx.serialization.Serializable
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package com.xebia.functional.examples.auto
package com.xebia.functional.langchain4k.auto

import com.xebia.functional.auto.ai
import kotlinx.serialization.Serializable
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package com.xebia.functional.examples.auto
package com.xebia.functional.langchain4k.auto

import com.xebia.functional.auto.ai
import kotlinx.serialization.Serializable
Expand Down
11 changes: 11 additions & 0 deletions example/src/main/resources/logback.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
<configuration>
<appender name="STDOUT" class="ch.qos.logback.core.ConsoleAppender">
<encoder>
<pattern>%d{HH:mm:ss.SSS} [%thread] %-5level %logger{36} - %msg%n</pattern>
</encoder>
</appender>

<root level="DEBUG">
<appender-ref ref="STDOUT" />
</root>
</configuration>
2 changes: 2 additions & 0 deletions gradle/libs.versions.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ klogging = { module = "io.github.oshai:kotlin-logging", version.ref = "klogging"
hikari = { module = "com.zaxxer:HikariCP", version.ref = "hikari" }
postgresql = { module = "org.postgresql:postgresql", version.ref = "postgresql" }
testcontainers-postgresql = { module = "org.testcontainers:postgresql", version.ref = "testcontainers" }
logback = { module = "ch.qos.logback:logback-classic", version = "1.4.6" }

[bundles]
ktor-client = [
Expand All @@ -50,6 +51,7 @@ ktor-client = [
]

[plugins]
kotlin-jvm = { id = "org.jetbrains.kotlin.jvm", version.ref = "kotlin" }
kotlin-multiplatform = { id = "org.jetbrains.kotlin.multiplatform", version.ref = "kotlin" }
kotlinx-serialization = { id = "org.jetbrains.kotlin.plugin.serialization", version.ref = "kotlin" }
spotless = { id = "com.diffplug.spotless", version.ref = "spotless" }
Expand Down
2 changes: 1 addition & 1 deletion settings.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@ pluginManagement {

}
rootProject.name = "langchain4k"

include("example")
70 changes: 24 additions & 46 deletions src/commonMain/kotlin/com/xebia/functional/auto/AutoAI.kt
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ import com.xebia.functional.llm.openai.OpenAIClient
import com.xebia.functional.llm.openai.Role
import com.xebia.functional.vectorstores.VectorStore

private const val COMPLETED = "%COMPLETED%"
private const val FAILED = "%FAILED%"
const val COMPLETED = "%COMPLETED%"
const val FAILED = "%FAILED%"

class AutoAI(
private val model: LLM,
Expand All @@ -39,8 +39,7 @@ class AutoAI(
|#. Second task""".trimMargin()

val response = chatCompletionResponse(prompt)
val firstMessage = getFirstMessage(response)
val newTasks = messageToTaskAsStrings(firstMessage)
val newTasks = messageToTaskAsStrings(response.firstChoiceOrNull())

return newTasks.mapNotNull {
val taskParts = it.trim().split(".", limit = 2)
Expand All @@ -53,18 +52,12 @@ class AutoAI(
}

private fun List<TaskWithResult>.print(): String =
joinToString("; ") { "${it.task.id.id}. ${it.task.objective.value} -> result: ${it.result.value}" }
joinToString("; ") { "${it.task.id.id}. ${it.task.objective.value} -> result: ${it.result.value()}" }

/**
* The execution agent is the AI that performs the task
*/
private suspend fun executionAgent(objective: Objective, task: Task): ChatCompletionResponse {
logger.debug {
"""
|Objective: ${objective.value}
|Task: $task
""".trimMargin()
}
val context = vectorStore.similaritySearch(objective.value, 5).map { TaskWithResult.fromJson(it.content) }
val prompt = """
|You are an AI who performs one task based on the following objective:
Expand All @@ -80,25 +73,34 @@ class AutoAI(
return chatCompletionResponse(prompt)
}

/**
* Call the remote Open AI API to complete the task
*/
/** Call the remote Open AI API to complete the task */
private suspend fun chatCompletionResponse(prompt: String): ChatCompletionResponse {
val completionRequest = ChatCompletionRequest(
model = model.value,
messages = listOf(Message(Role.system.name, prompt, user.name)),
user = user.name
)
return openAIClient.createChatCompletion(completionRequest)
return openAIClient.createChatCompletion(completionRequest).also {
logger.debug {
when {
it.isCompleted() -> "CreateChatCompletion SUCCESS ${it.firstChoiceOrNull()}"
it.isFailed() -> "CreateChatCompletion FAILED ${it.firstChoiceOrNull()}"
else -> "CreateChatCompletion No choices $it"
}
}
}
}

suspend operator fun invoke(objective: Objective): TaskResult? =
invoke(objective, nonEmptyListOf(Task(TaskId(1), objective)))

tailrec suspend operator fun invoke(objective: Objective, tasks: NonEmptyList<Task>): TaskResult? {
logger.debug { tasks.joinToString(separator = "\n") { "${it.id.id}. ${it.objective.value}" } }
val (resultMessage, task) = executeAndStoreTask(objective, tasks.head)
return if (taskHasCompleted(resultMessage)) task.result
val result = executionAgent(objective = objective, task = tasks.head)
val taskResult = requireNotNull(result.toTaskResultOrNull()) { "No message returned" }
val task = TaskWithResult(tasks.head, taskResult)
vectorStore.addText(task.toJson())
return if (task.isCompleted()) task.result
// Otherwise, send the result to the task creation agent to create new tasks
else when (val newPrompts = getNewTasksOrComplete(objective, tasks.tail, task)) {
// If the task creation agent determines the objective has been completed, return the results
Expand All @@ -107,8 +109,8 @@ class AutoAI(
is Either.Right -> {
var taskCounter = tasks.last().id.id
val newTasks = newPrompts.value.map { content -> Task(TaskId(taskCounter++), Objective(content)) }
val nel = prioritizationAgent(objective, newTasks).toNonEmptyListOrNull()
if (nel != null) invoke(objective, nel) else null
val tasksOrNull = prioritizationAgent(objective, newTasks).toNonEmptyListOrNull()
if (tasksOrNull != null) invoke(objective, tasksOrNull) else null
}
}
}
Expand All @@ -124,39 +126,15 @@ class AutoAI(
): Either<TaskCompleted, List<String>> = either {
val prompt = """
You are a task creation AI that uses the result of an execution agent to create new tasks with the following objective: ${objective.value},
The last completed task has the result: ${taskWithResult.result.value}.
The last completed task has the result: ${taskWithResult.value()}.
This result was based on this task description: ${taskWithResult.task.objective}.
These are incomplete tasks:
${tasks.joinToString(separator = "\n")}.
Based on the result, create new tasks to be completed by the AI system that do not overlap with incomplete tasks.
Return the tasks as an array.
IMPORTANT!!! : If there are no new tasks to complete and you determine the original objective:[${objective.value}] has been accomplished simply return:$COMPLETED""".trimIndent()
val response = chatCompletionResponse(prompt)
val resultMessage = getFirstMessage(response)
ensure(!taskHasCompleted(resultMessage)) { TaskCompleted }
messageToTaskAsStrings(resultMessage)
ensure(!response.isCompleted()) { TaskCompleted }
messageToTaskAsStrings(response.firstChoiceOrNull())
}

/** Check if the task has been completed */
private fun taskHasCompleted(result: String?): Boolean =
result?.endsWith(COMPLETED) == true

/** Execute the task and store the result */
private suspend fun executeAndStoreTask(objective: Objective, task: Task): Pair<String, TaskWithResult> {
val result = executionAgent(objective = objective, task = task)
val firstMessage = requireNotNull(getFirstMessage(result)) { "No message returned" }
val cleanedMessage = cleanResultMessage(firstMessage)
val taskWithResult = TaskWithResult(task, TaskResult(cleanedMessage))
vectorStore.addText(taskWithResult.toJson())
return Pair(firstMessage, taskWithResult)
}

private fun getFirstMessage(response: ChatCompletionResponse): String? =
response.choices.firstOrNull()?.message?.content

/** Clean the result message */
private fun cleanResultMessage(firstMessage: String): String = firstMessage
.replace(COMPLETED, "")
.replace(FAILED, "")
.trim()
}
2 changes: 1 addition & 1 deletion src/commonMain/kotlin/com/xebia/functional/auto/DSL.kt
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ suspend fun <A> ai(
vectorStore
).invoke(Objective(augmentedPrompt))
require(result != null) { "No result found" }
return catch({ json.decodeFromString(deserializationStrategy, result.value) }) { e ->
return catch({ json.decodeFromString(deserializationStrategy, result.value()) }) { e ->
val fixJsonPrompt = """
|RESULT:
|$result
Expand Down
14 changes: 13 additions & 1 deletion src/commonMain/kotlin/com/xebia/functional/auto/Model.kt
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ data class Task(val id: TaskId, val objective: Objective)
@Serializable
data class TaskWithResult(val task: Task, val result: TaskResult) {
fun toJson(): String = Json.encodeToString(this)
fun value(): String = result.value()
fun isCompleted(): Boolean = result.isCompleted()
fun isFailed(): Boolean = result.isFailed()

companion object {
fun fromJson(json: String): TaskWithResult =
Expand All @@ -35,4 +38,13 @@ data class TaskWithResult(val task: Task, val result: TaskResult) {

@JvmInline
@Serializable
value class TaskResult(val value: String)
value class TaskResult(private val value: String) {
fun value(): String =
value.replace(COMPLETED, "").replace(FAILED, "").trim()

fun isCompleted(): Boolean =
value.endsWith(COMPLETED)

fun isFailed(): Boolean =
value.endsWith(FAILED)
}
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ private class KtorOpenAIClient(
configure(config.token, request)
}
}
// TODO error body fails to parse into ChatCompletionResponse
return response.body()
}

Expand Down
Loading