Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix agent scopes, and add documentation #45

Merged
merged 2 commits into from
May 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,18 @@ package com.xebia.functional.langchain4k.chain

import arrow.core.getOrElse
import arrow.core.raise.either
import arrow.core.raise.ensureNotNull
import arrow.core.raise.recover
import arrow.fx.coroutines.resourceScope
import com.xebia.functional.auto.Agent
import com.xebia.functional.chains.VectorQAChain
import com.xebia.functional.embeddings.OpenAIEmbeddings
import com.xebia.functional.env.OpenAIConfig
import com.xebia.functional.llm.openai.KtorOpenAIClient
import com.xebia.functional.llm.openai.OpenAIClient
import com.xebia.functional.tool.search
import com.xebia.functional.tools.storeResults
import com.xebia.functional.vectorstores.LocalVectorStore
import io.github.oshai.KLogger
import io.github.oshai.KotlinLogging
import okio.Path
import okio.Path.Companion.toPath
import java.io.File
import java.net.URL
import kotlin.time.ExperimentalTime

data class WeatherExampleError(val reason: String)
Expand Down Expand Up @@ -47,8 +42,8 @@ private suspend fun getQuestionAnswer(
val embeddings = OpenAIEmbeddings(openAIConfig, openAiClient, logger)
val vectorStore = LocalVectorStore(embeddings)

val tools = search("Weather in Cádiz, Spain")
Agent(tools).storeResults(vectorStore)
search("Weather in Cádiz, Spain")
.storeResults(vectorStore)

val numOfDocs = 10
val outputVariable = "answer"
Expand Down
156 changes: 110 additions & 46 deletions kotlin/src/commonMain/kotlin/com/xebia/functional/auto/AI.kt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import arrow.core.raise.Raise
import arrow.core.raise.catch
import arrow.core.raise.recover
import arrow.core.right
import arrow.fx.coroutines.Atomic
import arrow.fx.coroutines.ResourceScope
import arrow.fx.coroutines.resourceScope
import com.xebia.functional.AIError
Expand All @@ -22,7 +21,8 @@ import com.xebia.functional.llm.openai.Message
import com.xebia.functional.llm.openai.OpenAIClient
import com.xebia.functional.llm.openai.Role
import com.xebia.functional.logTruncated
import com.xebia.functional.tools.Tool
import com.xebia.functional.tools.Agent
import com.xebia.functional.tools.storeResults
import com.xebia.functional.vectorstores.LocalVectorStore
import com.xebia.functional.vectorstores.VectorStore
import io.github.oshai.KLogger
Expand All @@ -31,54 +31,88 @@ import kotlin.jvm.JvmName
import kotlinx.serialization.serializer
import kotlin.time.ExperimentalTime
import kotlinx.serialization.DeserializationStrategy
import kotlinx.serialization.SerializationException
import kotlinx.serialization.descriptors.SerialDescriptor
import kotlinx.serialization.json.Json
import kotlinx.serialization.json.JsonObject

@DslMarker
annotation class AIDSL
annotation class AiDsl

data class SerializationConfig<A>(
val jsonSchema: JsonObject,
val descriptor: SerialDescriptor,
val deserializationStrategy: DeserializationStrategy<A>,
)

/*
* With context receivers this can become more generic,
* suspend context(Raise<AIError>, ResourceScope, AIContext) () -> A
/**
* An [AI] value represents an action relying on artificial intelligence that can be run to produce an [A].
* This value is _lazy_ and can be combined with other `AI` values using [AIScope.invoke], and thus forms a monadic DSL.
*
* All [AI] actions that are composed together using [AIScope.invoke] share the same [VectorStore], [OpenAIEmbeddings] and [OpenAIClient] instances.
*/
typealias AI<A> = suspend AIScope.() -> A

/** A DSL block that makes it more convenient to construct [AI] values. */
inline fun <A> ai(noinline block: suspend AIScope.() -> A): AI<A> = block

/**
* Run the [AI] value to produce an [A],
* this method initialises all the dependencies required to run the [AI] value and once it finishes it closes all the resources.
*
* This operator is **terminal** meaning it runs and completes the _chain_ of `AI` actions.
*/
@OptIn(ExperimentalTime::class)
suspend inline fun <reified A> AI<A>.getOrElse(crossinline orElse: suspend (AIError) -> A): A =
resourceScope {
recover({
recover({
resourceScope {
val openAIConfig = OpenAIConfig()
val openAiClient: OpenAIClient = KtorOpenAIClient(openAIConfig)
val logger = KotlinLogging.logger("AutoAI")
val embeddings = OpenAIEmbeddings(openAIConfig, openAiClient, logger)
val vectorStore = LocalVectorStore(embeddings)
val scope = AIScope(openAiClient, vectorStore, Atomic(listOf()), logger, this@resourceScope, this)
val scope = AIScope(openAiClient, vectorStore, emptyArray(), logger, this, this@recover)
invoke(scope)
}) { orElse(it) }
}
}
}) { orElse(it) }

/**
* Run the [AI] value to produce _either_ an [AIError], or [A].
* this method initialises all the dependencies required to run the [AI] value and once it finishes it closes all the resources.
*
* This operator is **terminal** meaning it runs and completes the _chain_ of `AI` actions.
* @see getOrElse for an operator that allow directly handling the [AIError] case.
*/
suspend inline fun <reified A> AI<A>.toEither(): Either<AIError, A> =
ai { invoke().right() }.getOrElse { it.left() }

// TODO: Allow traced transformation of Raise errors
class AIException(message: String) : RuntimeException(message)

/**
* Run the [AI] value to produce [A].
* this method initialises all the dependencies required to run the [AI] value and once it finishes it closes all the resources.
*
* This operator is **terminal** meaning it runs and completes the _chain_ of `AI` actions.
*
* @see getOrElse for an operator that allow directly handling the [AIError] case instead of throwing.
* @throws AIException in case something went wrong.
*/
suspend inline fun <reified A> AI<A>.getOrThrow(): A =
getOrElse { throw AIException(it.reason) }

/**
* The [AIScope] is the context in which [AI] values are run.
* It encapsulates all the dependencies required to run [AI] values,
* and provides convenient syntax for writing [AI] based programs.
*
* It exposes the [ResourceScope] so you can easily add your own resources with the scope of the [AI] program,
* and [Raise] of [AIError] in case you want to compose any [Raise] based actions.
*/
class AIScope(
private val openAIClient: OpenAIClient,
private val vectorStore: VectorStore,
private val agents: Atomic<List<Agent>>,
private val agent: Array<out Agent>,
private val logger: KLogger,
resourceScope: ResourceScope,
raise: Raise<AIError>,
Expand All @@ -88,12 +122,73 @@ class AIScope(
},
) : ResourceScope by resourceScope, Raise<AIError> by raise {

@AIDSL
/**
* Allows invoking [AI] values in the context of this [AIScope].
*
* ```kotlin
* data class CovidNews(val title: String, val content: String)
* val covidNewsToday = ai {
* val now = LocalDateTime.now()
* agent(search("$now covid-19 News")) {
* prompt<CovidNews>("write a paragraph of about 300 words about the latest news on covid-19 on $now")
* }
* }
*
* data class BreakingNews(val title: String, val content: String, val date: String)
*
* fun breakingNews(date: LocalDateTime): AI<BreakingNews> = ai {
* agent(search("$date Breaking News")) {
* prompt("Summarize all breaking news that happened on ${now.minusDays(it)} in about 300 words")
* }
* }
*
* suspend fun AIScope.breakingNewsLastWeek(): List<BreakingNews> {
* val now = LocalDateTime.now()
* return (0..7).parMap { breakingNews(now.minusDays(it)).invoke() }
* }
*
* fun news(): AI<List<News>> = ai {
* val covidNews = parZip(
* { covidNewsToday() },
* { breakingNewsLastWeek() }
* ) { covidNews, breakingNews -> listOf(covidNews) + breakingNews }
* }
* ```
*/
@AiDsl
@JvmName("invokeAI")
suspend operator fun <A> AI<A>.invoke(): A =
invoke(this@AIScope)

@AIDSL
/** Creates a child scope of this [AIScope] with the specified [agent]. */
@AiDsl
suspend fun <A> agent(agent: Agent, scope: suspend AIScope.() -> A): A =
agent(arrayOf(agent), scope)

/** Creates a child scope of this [AIScope] with the specified [agents]. */
@AiDsl
suspend fun <A> agent(agents: Array<out Agent>, scope: suspend AIScope.() -> A): A =
scope(AIScope(openAIClient, vectorStore, agents, logger, this, this))

/**
* Run a [prompt] describes the task you want to solve within the context of [AIScope], and any [agent] it contains.
* Returns a value of [A] where [A] **has to be** annotated with [kotlinx.serialization.Serializable].
*
* @throws SerializationException if serializer cannot be created (provided [A] or its type argument is not serializable).
* @throws IllegalArgumentException if any of [A]'s type arguments contains star projection
*/
@AiDsl
suspend inline fun <reified A> prompt(prompt: String): A {
val serializer = serializer<A>()
val serializationConfig: SerializationConfig<A> = SerializationConfig(
jsonSchema = buildJsonSchema(serializer.descriptor, false),
descriptor = serializer.descriptor,
deserializationStrategy = serializer
)
return prompt(prompt, serializationConfig)
}

@AiDsl
suspend fun <A> prompt(
prompt: String,
serializationConfig: SerializationConfig<A>,
Expand All @@ -113,44 +208,13 @@ class AIScope(
}
}

@AIDSL
suspend fun <A> agent(tool: Tool, scope: suspend () -> A): A {
val scopedAgent = Agent(listOf(tool))
agents.update { it + scopedAgent }
val result = scope()
agents.update { it - scopedAgent }
return result
}

@AIDSL
suspend fun <A> agent(tool: Array<out Tool>, scope: suspend () -> A): A {
val scopedAgent = Agent(tool)
agents.update { it + scopedAgent }
val result = scope()
agents.update { it - scopedAgent }
return result
}

@AIDSL
suspend inline fun <reified A> prompt(prompt: String): A {
val serializer = serializer<A>()
val serializationConfig: SerializationConfig<A> = SerializationConfig(
jsonSchema = buildJsonSchema(serializer.descriptor, false),
descriptor = serializer.descriptor,
deserializationStrategy = serializer
)
return prompt(prompt, serializationConfig)
}

private suspend fun Raise<AIError>.openAIChatCall(
question: String,
promptWithContext: String,
serializationConfig: SerializationConfig<*>,
): String {
//run the agents so they store context in the database
agents.get().forEach { agent ->
agent.storeResults(vectorStore)
}
agent.storeResults(vectorStore)
//run the vectorQAChain to get the answer
val numOfDocs = 10
val outputVariable = "answer"
Expand Down
26 changes: 0 additions & 26 deletions kotlin/src/commonMain/kotlin/com/xebia/functional/auto/Agent.kt

This file was deleted.

28 changes: 28 additions & 0 deletions kotlin/src/commonMain/kotlin/com/xebia/functional/tools/Agent.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package com.xebia.functional.tools

import com.xebia.functional.Document
import com.xebia.functional.vectorstores.VectorStore
import io.github.oshai.KLogger
import io.github.oshai.KotlinLogging

class Agent(
val name: String,
val description: String,
val action: suspend Agent.() -> List<Document>,
) {
val logger: KLogger by lazy { KotlinLogging.logger(name) }

suspend fun storeResults(vectorStore: VectorStore) {
logger.debug { "[${name}] Running" }
val docs = action()
if (docs.isNotEmpty()) {
vectorStore.addDocuments(docs)
logger.debug { "[${name}] Found and memorized ${docs.size} docs" }
} else {
logger.debug { "[${name}] Found no docs" }
}
}
}

suspend fun Array<out Agent>.storeResults(vectorStore: VectorStore) =
forEach { it.storeResults(vectorStore) }
13 changes: 0 additions & 13 deletions kotlin/src/commonMain/kotlin/com/xebia/functional/tools/Tool.kt

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import com.apptasticsoftware.rssreader.Item
import com.apptasticsoftware.rssreader.RssReader
import com.xebia.functional.Document
import com.xebia.functional.textsplitters.BaseTextSplitter
import com.xebia.functional.tools.Tool
import com.xebia.functional.tools.Agent
import io.ktor.http.*
import kotlinx.coroutines.Dispatchers
import java.util.stream.Collectors
Expand All @@ -17,8 +17,8 @@ fun bingSearch(
splitter: BaseTextSplitter,
url: String = "https://www.bing.com/news/search?q=${search.encodeURLParameter()}&format=rss",
maxLinks: Int = 10
): Tool =
Tool(
): Agent =
Agent(
name = "Bing Search",
description = "Searches Bing for $search",
) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
package com.xebia.functional.tool

import com.xebia.functional.textsplitters.TokenTextSplitter
import com.xebia.functional.tools.Tool
import com.xebia.functional.tools.Agent

suspend fun search(vararg prompt: String): Array<out Tool> =
suspend fun search(vararg prompt: String): Array<out Agent> =
prompt.map {
bingSearch(
search = it,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@ package com.xebia.functional.tool

import com.xebia.functional.loaders.ScrapeURLTextLoader
import com.xebia.functional.textsplitters.BaseTextSplitter
import com.xebia.functional.tools.Tool
import com.xebia.functional.tools.Agent

fun scrapeUrlContent(
url: String,
splitter: BaseTextSplitter
): Tool =
Tool(
): Agent =
Agent(
name = "Scrape URL content",
description = "Scrape the content of $url"
) {
Expand Down