From 3b6b7cf03f3da8d7e8e559d34b8f855fc4c5ef74 Mon Sep 17 00:00:00 2001 From: Simon Vergauwen Date: Tue, 9 May 2023 20:06:13 +0200 Subject: [PATCH 1/2] Fix agent scopes, and add documentation --- .../functional/langchain4k/chain/Weather.kt | 11 +- .../kotlin/com/xebia/functional/auto/AI.kt | 153 ++++++++++++------ .../kotlin/com/xebia/functional/auto/Agent.kt | 26 --- .../com/xebia/functional/tools/Agent.kt | 28 ++++ .../kotlin/com/xebia/functional/tools/Tool.kt | 13 -- .../com/xebia/functional/tool/BingSearch.kt | 6 +- .../xebia/functional/tool/DefaultSearch.kt | 4 +- .../xebia/functional/tool/ScrapeUrlContent.kt | 6 +- 8 files changed, 146 insertions(+), 101 deletions(-) delete mode 100644 kotlin/src/commonMain/kotlin/com/xebia/functional/auto/Agent.kt create mode 100644 kotlin/src/commonMain/kotlin/com/xebia/functional/tools/Agent.kt delete mode 100644 kotlin/src/commonMain/kotlin/com/xebia/functional/tools/Tool.kt diff --git a/example/src/main/kotlin/com/xebia/functional/langchain4k/chain/Weather.kt b/example/src/main/kotlin/com/xebia/functional/langchain4k/chain/Weather.kt index a9b28fdfe..56e357006 100644 --- a/example/src/main/kotlin/com/xebia/functional/langchain4k/chain/Weather.kt +++ b/example/src/main/kotlin/com/xebia/functional/langchain4k/chain/Weather.kt @@ -2,23 +2,18 @@ package com.xebia.functional.langchain4k.chain import arrow.core.getOrElse import arrow.core.raise.either -import arrow.core.raise.ensureNotNull import arrow.core.raise.recover import arrow.fx.coroutines.resourceScope -import com.xebia.functional.auto.Agent import com.xebia.functional.chains.VectorQAChain import com.xebia.functional.embeddings.OpenAIEmbeddings import com.xebia.functional.env.OpenAIConfig import com.xebia.functional.llm.openai.KtorOpenAIClient import com.xebia.functional.llm.openai.OpenAIClient import com.xebia.functional.tool.search +import com.xebia.functional.tools.storeResults import com.xebia.functional.vectorstores.LocalVectorStore import io.github.oshai.KLogger import io.github.oshai.KotlinLogging -import okio.Path -import okio.Path.Companion.toPath -import java.io.File -import java.net.URL import kotlin.time.ExperimentalTime data class WeatherExampleError(val reason: String) @@ -47,8 +42,8 @@ private suspend fun getQuestionAnswer( val embeddings = OpenAIEmbeddings(openAIConfig, openAiClient, logger) val vectorStore = LocalVectorStore(embeddings) - val tools = search("Weather in Cádiz, Spain") - Agent(tools).storeResults(vectorStore) + search("Weather in Cádiz, Spain") + .storeResults(vectorStore) val numOfDocs = 10 val outputVariable = "answer" diff --git a/kotlin/src/commonMain/kotlin/com/xebia/functional/auto/AI.kt b/kotlin/src/commonMain/kotlin/com/xebia/functional/auto/AI.kt index e496248ea..d57efadce 100644 --- a/kotlin/src/commonMain/kotlin/com/xebia/functional/auto/AI.kt +++ b/kotlin/src/commonMain/kotlin/com/xebia/functional/auto/AI.kt @@ -6,7 +6,6 @@ import arrow.core.raise.Raise import arrow.core.raise.catch import arrow.core.raise.recover import arrow.core.right -import arrow.fx.coroutines.Atomic import arrow.fx.coroutines.ResourceScope import arrow.fx.coroutines.resourceScope import com.xebia.functional.AIError @@ -22,7 +21,8 @@ import com.xebia.functional.llm.openai.Message import com.xebia.functional.llm.openai.OpenAIClient import com.xebia.functional.llm.openai.Role import com.xebia.functional.logTruncated -import com.xebia.functional.tools.Tool +import com.xebia.functional.tools.Agent +import com.xebia.functional.tools.storeResults import com.xebia.functional.vectorstores.LocalVectorStore import com.xebia.functional.vectorstores.VectorStore import io.github.oshai.KLogger @@ -31,12 +31,13 @@ import kotlin.jvm.JvmName import kotlinx.serialization.serializer import kotlin.time.ExperimentalTime import kotlinx.serialization.DeserializationStrategy +import kotlinx.serialization.SerializationException import kotlinx.serialization.descriptors.SerialDescriptor import kotlinx.serialization.json.Json import kotlinx.serialization.json.JsonObject @DslMarker -annotation class AIDSL +annotation class AiDsl data class SerializationConfig( val jsonSchema: JsonObject, @@ -44,41 +45,74 @@ data class SerializationConfig( val deserializationStrategy: DeserializationStrategy, ) -/* - * With context receivers this can become more generic, - * suspend context(Raise, ResourceScope, AIContext) () -> A +/** + * An [AI] value represents an action relying on artificial intelligence that can be run to produce an [A]. + * This value is _lazy_ and can be combined with other `AI` values using [AIScope.invoke], and thus forms a monadic DSL. + * + * All [AI] actions that are composed together using [AIScope.invoke] share the same [VectorStore], [OpenAIEmbeddings] and [OpenAIClient] instances. */ typealias AI = suspend AIScope.() -> A +/** A DSL block that makes it more convenient to construct [AI] values. */ inline fun ai(noinline block: suspend AIScope.() -> A): AI = block +/** + * Run the [AI] value to produce an [A], + * this method initialises all the dependencies required to run the [AI] value and once it finishes it closes all the resources. + * + * This operator is **terminal** meaning it runs and completes the _chain_ of `AI` actions. + */ @OptIn(ExperimentalTime::class) suspend inline fun AI.getOrElse(crossinline orElse: suspend (AIError) -> A): A = - resourceScope { - recover({ + recover({ + resourceScope { val openAIConfig = OpenAIConfig() val openAiClient: OpenAIClient = KtorOpenAIClient(openAIConfig) val logger = KotlinLogging.logger("AutoAI") val embeddings = OpenAIEmbeddings(openAIConfig, openAiClient, logger) val vectorStore = LocalVectorStore(embeddings) - val scope = AIScope(openAiClient, vectorStore, Atomic(listOf()), logger, this@resourceScope, this) + val scope = AIScope(openAiClient, vectorStore, emptyArray(), logger, this, this@recover) invoke(scope) - }) { orElse(it) } - } + } + }) { orElse(it) } +/** + * Run the [AI] value to produce _either_ an [AIError], or [A]. + * this method initialises all the dependencies required to run the [AI] value and once it finishes it closes all the resources. + * + * This operator is **terminal** meaning it runs and completes the _chain_ of `AI` actions. + * @see getOrElse for an operator that allow directly handling the [AIError] case. + */ suspend inline fun AI.toEither(): Either = ai { invoke().right() }.getOrElse { it.left() } // TODO: Allow traced transformation of Raise errors class AIException(message: String) : RuntimeException(message) +/** + * Run the [AI] value to produce [A]. + * this method initialises all the dependencies required to run the [AI] value and once it finishes it closes all the resources. + * + * This operator is **terminal** meaning it runs and completes the _chain_ of `AI` actions. + * + * @see getOrElse for an operator that allow directly handling the [AIError] case instead of throwing. + * @throws AIException in case something went wrong. + */ suspend inline fun AI.getOrThrow(): A = getOrElse { throw AIException(it.reason) } +/** + * The [AIScope] is the context in which [AI] values are run. + * It encapsulates all the dependencies required to run [AI] values, + * and provides convenient syntax for writing [AI] based programs. + * + * It exposes the [ResourceScope] so you can easily add your own resources with the scope of the [AI] program, + * and [Raise] of [AIError] in case you want to compose any [Raise] based actions. + */ class AIScope( private val openAIClient: OpenAIClient, private val vectorStore: VectorStore, - private val agents: Atomic>, + private val agent: Array, private val logger: KLogger, resourceScope: ResourceScope, raise: Raise, @@ -88,12 +122,70 @@ class AIScope( }, ) : ResourceScope by resourceScope, Raise by raise { - @AIDSL + /** + * Allows invoking [AI] values in the context of this [AIScope]. + * + * ```kotlin + * data class CovidNews(val title: String, val content: String) + * val covidNewsToday = ai { + * val now = LocalDateTime.now() + * agent(search("$now covid-19 News")) { + * prompt("write a paragraph of about 300 words about the latest news on covid-19 on $now") + * } + * } + * + * data class BreakingNews(val title: String, val content: String, val date: String) + * + * fun AIScope.breakingNewsLastWeek(): List = + * agent(search("$date Breaking News")) { + * val now = LocalDateTime.now() + * (0..7).parMap { + * prompt("Summarize all breaking news that happened on ${now.minusDays(it)} in about 300 words") + * } + * } + * + * fun news(): AI> = ai { + * val covidNews = parZip( + * { covidNewsToday() }, + * { breakingNewsLastWeek() } + * ) { covidNews, breakingNews -> listOf(covidNews) + breakingNews } + * } + * ``` + */ + @AiDsl @JvmName("invokeAI") suspend operator fun AI.invoke(): A = invoke(this@AIScope) - @AIDSL + /** Creates a child scope of this [AIScope] with the specified [agent]. */ + @AiDsl + suspend fun agent(agent: Agent, scope: suspend AIScope.() -> A): A = + agent(arrayOf(agent), scope) + + /** Creates a child scope of this [AIScope] with the specified [agents]. */ + @AiDsl + suspend fun agent(agents: Array, scope: suspend AIScope.() -> A): A = + scope(AIScope(openAIClient, vectorStore, agents, logger, this, this)) + + /** + * Run a [prompt] describes the task you want to solve within the context of [AIScope], and any [agent] it contains. + * Returns a value of [A] where [A] **has to be** annotated with [kotlinx.serialization.Serializable]. + * + * @throws SerializationException if serializer cannot be created (provided [A] or its type argument is not serializable). + * @throws IllegalArgumentException if any of [A]'s type arguments contains star projection + */ + @AiDsl + suspend inline fun prompt(prompt: String): A { + val serializer = serializer() + val serializationConfig: SerializationConfig = SerializationConfig( + jsonSchema = buildJsonSchema(serializer.descriptor, false), + descriptor = serializer.descriptor, + deserializationStrategy = serializer + ) + return prompt(prompt, serializationConfig) + } + + @AiDsl suspend fun prompt( prompt: String, serializationConfig: SerializationConfig, @@ -113,44 +205,13 @@ class AIScope( } } - @AIDSL - suspend fun agent(tool: Tool, scope: suspend () -> A): A { - val scopedAgent = Agent(listOf(tool)) - agents.update { it + scopedAgent } - val result = scope() - agents.update { it - scopedAgent } - return result - } - - @AIDSL - suspend fun agent(tool: Array, scope: suspend () -> A): A { - val scopedAgent = Agent(tool) - agents.update { it + scopedAgent } - val result = scope() - agents.update { it - scopedAgent } - return result - } - - @AIDSL - suspend inline fun prompt(prompt: String): A { - val serializer = serializer() - val serializationConfig: SerializationConfig = SerializationConfig( - jsonSchema = buildJsonSchema(serializer.descriptor, false), - descriptor = serializer.descriptor, - deserializationStrategy = serializer - ) - return prompt(prompt, serializationConfig) - } - private suspend fun Raise.openAIChatCall( question: String, promptWithContext: String, serializationConfig: SerializationConfig<*>, ): String { //run the agents so they store context in the database - agents.get().forEach { agent -> - agent.storeResults(vectorStore) - } + agent.storeResults(vectorStore) //run the vectorQAChain to get the answer val numOfDocs = 10 val outputVariable = "answer" diff --git a/kotlin/src/commonMain/kotlin/com/xebia/functional/auto/Agent.kt b/kotlin/src/commonMain/kotlin/com/xebia/functional/auto/Agent.kt deleted file mode 100644 index cea084b2b..000000000 --- a/kotlin/src/commonMain/kotlin/com/xebia/functional/auto/Agent.kt +++ /dev/null @@ -1,26 +0,0 @@ -package com.xebia.functional.auto - -import com.xebia.functional.tools.Tool -import com.xebia.functional.vectorstores.VectorStore -import io.github.oshai.KLogger -import io.github.oshai.KotlinLogging - -class Agent(private val tools: List) { - - constructor(tool: Array) : this(tool.toList()) - - val logger: KLogger = KotlinLogging.logger("Agent") - - suspend fun storeResults(vectorStore: VectorStore) { - tools.forEach { tool -> - logger.debug { "[${tool.name}] Running" } - val docs = tool.action(tool) - if (docs.isNotEmpty()) { - vectorStore.addDocuments(docs) - logger.debug { "[${tool.name}] Found and memorized ${docs.size} docs" } - } else { - logger.debug { "[${tool.name}] Found no docs" } - } - } - } -} diff --git a/kotlin/src/commonMain/kotlin/com/xebia/functional/tools/Agent.kt b/kotlin/src/commonMain/kotlin/com/xebia/functional/tools/Agent.kt new file mode 100644 index 000000000..fa4c3bb33 --- /dev/null +++ b/kotlin/src/commonMain/kotlin/com/xebia/functional/tools/Agent.kt @@ -0,0 +1,28 @@ +package com.xebia.functional.tools + +import com.xebia.functional.Document +import com.xebia.functional.vectorstores.VectorStore +import io.github.oshai.KLogger +import io.github.oshai.KotlinLogging + +class Agent( + val name: String, + val description: String, + val action: suspend Agent.() -> List, +) { + val logger: KLogger by lazy { KotlinLogging.logger(name) } + + suspend fun storeResults(vectorStore: VectorStore) { + logger.debug { "[${name}] Running" } + val docs = action() + if (docs.isNotEmpty()) { + vectorStore.addDocuments(docs) + logger.debug { "[${name}] Found and memorized ${docs.size} docs" } + } else { + logger.debug { "[${name}] Found no docs" } + } + } +} + +suspend fun Array.storeResults(vectorStore: VectorStore) = + forEach { it.storeResults(vectorStore) } diff --git a/kotlin/src/commonMain/kotlin/com/xebia/functional/tools/Tool.kt b/kotlin/src/commonMain/kotlin/com/xebia/functional/tools/Tool.kt deleted file mode 100644 index 514ec0b68..000000000 --- a/kotlin/src/commonMain/kotlin/com/xebia/functional/tools/Tool.kt +++ /dev/null @@ -1,13 +0,0 @@ -package com.xebia.functional.tools - -import com.xebia.functional.Document -import io.github.oshai.KLogger -import io.github.oshai.KotlinLogging - -data class Tool( - val name: String, - val description: String, - val action: suspend Tool.() -> List, -) { - val logger: KLogger by lazy { KotlinLogging.logger(name) } -} diff --git a/kotlin/src/jvmMain/kotlin/com/xebia/functional/tool/BingSearch.kt b/kotlin/src/jvmMain/kotlin/com/xebia/functional/tool/BingSearch.kt index 5bc2fe330..0856b2681 100644 --- a/kotlin/src/jvmMain/kotlin/com/xebia/functional/tool/BingSearch.kt +++ b/kotlin/src/jvmMain/kotlin/com/xebia/functional/tool/BingSearch.kt @@ -6,7 +6,7 @@ import com.apptasticsoftware.rssreader.Item import com.apptasticsoftware.rssreader.RssReader import com.xebia.functional.Document import com.xebia.functional.textsplitters.BaseTextSplitter -import com.xebia.functional.tools.Tool +import com.xebia.functional.tools.Agent import io.ktor.http.* import kotlinx.coroutines.Dispatchers import java.util.stream.Collectors @@ -17,8 +17,8 @@ fun bingSearch( splitter: BaseTextSplitter, url: String = "https://www.bing.com/news/search?q=${search.encodeURLParameter()}&format=rss", maxLinks: Int = 10 -): Tool = - Tool( +): Agent = + Agent( name = "Bing Search", description = "Searches Bing for $search", ) { diff --git a/kotlin/src/jvmMain/kotlin/com/xebia/functional/tool/DefaultSearch.kt b/kotlin/src/jvmMain/kotlin/com/xebia/functional/tool/DefaultSearch.kt index 180636cf7..91a8105f5 100644 --- a/kotlin/src/jvmMain/kotlin/com/xebia/functional/tool/DefaultSearch.kt +++ b/kotlin/src/jvmMain/kotlin/com/xebia/functional/tool/DefaultSearch.kt @@ -1,9 +1,9 @@ package com.xebia.functional.tool import com.xebia.functional.textsplitters.TokenTextSplitter -import com.xebia.functional.tools.Tool +import com.xebia.functional.tools.Agent -suspend fun search(vararg prompt: String): Array = +suspend fun search(vararg prompt: String): Array = prompt.map { bingSearch( search = it, diff --git a/kotlin/src/jvmMain/kotlin/com/xebia/functional/tool/ScrapeUrlContent.kt b/kotlin/src/jvmMain/kotlin/com/xebia/functional/tool/ScrapeUrlContent.kt index 15f8d1bac..e34221b0a 100644 --- a/kotlin/src/jvmMain/kotlin/com/xebia/functional/tool/ScrapeUrlContent.kt +++ b/kotlin/src/jvmMain/kotlin/com/xebia/functional/tool/ScrapeUrlContent.kt @@ -2,13 +2,13 @@ package com.xebia.functional.tool import com.xebia.functional.loaders.ScrapeURLTextLoader import com.xebia.functional.textsplitters.BaseTextSplitter -import com.xebia.functional.tools.Tool +import com.xebia.functional.tools.Agent fun scrapeUrlContent( url: String, splitter: BaseTextSplitter -): Tool = - Tool( +): Agent = + Agent( name = "Scrape URL content", description = "Scrape the content of $url" ) { From 40003bae98f36e276e20f5c4ddcf8e263eeb64fb Mon Sep 17 00:00:00 2001 From: Simon Vergauwen Date: Tue, 9 May 2023 20:26:02 +0200 Subject: [PATCH 2/2] Fix example --- .../kotlin/com/xebia/functional/auto/AI.kt | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/kotlin/src/commonMain/kotlin/com/xebia/functional/auto/AI.kt b/kotlin/src/commonMain/kotlin/com/xebia/functional/auto/AI.kt index d57efadce..7adb98427 100644 --- a/kotlin/src/commonMain/kotlin/com/xebia/functional/auto/AI.kt +++ b/kotlin/src/commonMain/kotlin/com/xebia/functional/auto/AI.kt @@ -136,13 +136,16 @@ class AIScope( * * data class BreakingNews(val title: String, val content: String, val date: String) * - * fun AIScope.breakingNewsLastWeek(): List = + * fun breakingNews(date: LocalDateTime): AI = ai { * agent(search("$date Breaking News")) { - * val now = LocalDateTime.now() - * (0..7).parMap { - * prompt("Summarize all breaking news that happened on ${now.minusDays(it)} in about 300 words") - * } + * prompt("Summarize all breaking news that happened on ${now.minusDays(it)} in about 300 words") * } + * } + * + * suspend fun AIScope.breakingNewsLastWeek(): List { + * val now = LocalDateTime.now() + * return (0..7).parMap { breakingNews(now.minusDays(it)).invoke() } + * } * * fun news(): AI> = ai { * val covidNews = parZip(