From 0da833258b5f57c280d6691589b612f67aa8a9a6 Mon Sep 17 00:00:00 2001 From: Alejandro Serrano Date: Fri, 19 May 2023 17:01:07 +0200 Subject: [PATCH] No `Agent`, only `AI` (#78) Co-authored-by: Simon Vergauwen --- .../com/xebia/functional/xef/AIError.kt | 8 +- .../com/xebia/functional/xef/agents/Agent.kt | 68 -------- .../com/xebia/functional/xef/agents/Chains.kt | 37 ---- .../xef/agents/DeserializerLLMAgent.kt | 123 ------------- .../xef/agents/ImageGenerationAgent.kt | 56 ------ .../xebia/functional/xef/agents/LLMAgent.kt | 96 ---------- .../com/xebia/functional/xef/auto/AI.kt | 165 +----------------- .../xef/auto/DeserializerLLMAgent.kt | 154 ++++++++++++++++ .../xef/auto/ImageGenerationAgent.kt | 90 ++++++++++ .../com/xebia/functional/xef/auto/LLMAgent.kt | 91 ++++++++++ .../com/xebia/functional/xef/prompt/Prompt.kt | 23 +++ .../functional/xef/prompt/PromptTemplate.kt | 125 ++++++------- .../com/xebia/functional/xef/prompt/models.kt | 69 -------- .../functional/xef/chains/ChainTestUtils.kt | 119 ------------- .../functional/xef/chains/LLMAgentSpec.kt | 45 ----- .../xebia/functional/xef/prompt/ConfigSpec.kt | 61 ------- .../xef/prompt/PromptTemplateSpec.kt | 93 +++++----- .../xebia/functional/xef/agents/BingSearch.kt | 55 +++--- .../functional/xef/agents/DefaultSearch.kt | 12 +- .../functional/xef/agents/ScrapeUrlContent.kt | 7 +- .../xef/prompt/FilePromptTemplate.kt | 11 +- .../xef/prompt/PromptTemplateSpec.kt | 5 +- .../com/xebia/functional/xef/pdf/PDFLoader.kt | 13 +- .../scala/com/xebia/functional/auto/AI.scala | 11 +- 24 files changed, 517 insertions(+), 1020 deletions(-) delete mode 100644 core/src/commonMain/kotlin/com/xebia/functional/xef/agents/Agent.kt delete mode 100644 core/src/commonMain/kotlin/com/xebia/functional/xef/agents/Chains.kt delete mode 100644 core/src/commonMain/kotlin/com/xebia/functional/xef/agents/DeserializerLLMAgent.kt delete mode 100644 core/src/commonMain/kotlin/com/xebia/functional/xef/agents/ImageGenerationAgent.kt delete mode 100644 core/src/commonMain/kotlin/com/xebia/functional/xef/agents/LLMAgent.kt create mode 100644 core/src/commonMain/kotlin/com/xebia/functional/xef/auto/DeserializerLLMAgent.kt create mode 100644 core/src/commonMain/kotlin/com/xebia/functional/xef/auto/ImageGenerationAgent.kt create mode 100644 core/src/commonMain/kotlin/com/xebia/functional/xef/auto/LLMAgent.kt create mode 100644 core/src/commonMain/kotlin/com/xebia/functional/xef/prompt/Prompt.kt delete mode 100644 core/src/commonTest/kotlin/com/xebia/functional/xef/chains/ChainTestUtils.kt delete mode 100644 core/src/commonTest/kotlin/com/xebia/functional/xef/chains/LLMAgentSpec.kt delete mode 100644 core/src/commonTest/kotlin/com/xebia/functional/xef/prompt/ConfigSpec.kt diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/AIError.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/AIError.kt index 06a834fb5..aad23d4ca 100644 --- a/core/src/commonMain/kotlin/com/xebia/functional/xef/AIError.kt +++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/AIError.kt @@ -10,11 +10,6 @@ sealed interface AIError { get() = "No response from the AI" } - data class Combined(val errors: NonEmptyList) : AIError { - override val reason: String - get() = errors.joinToString { it.reason } - } - data class JsonParsing( val result: String, val maxAttempts: Int, @@ -29,11 +24,10 @@ sealed interface AIError { override val reason: String get() = "OpenAI Environment not found: ${errors.all.joinToString("\n")}" } + data class HuggingFace(val errors: NonEmptyList) : Env { override val reason: String get() = "HuggingFace Environment not found: ${errors.all.joinToString("\n")}" } } - - data class InvalidInputs(override val reason: String) : AIError } diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/agents/Agent.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/agents/Agent.kt deleted file mode 100644 index 73889f046..000000000 --- a/core/src/commonMain/kotlin/com/xebia/functional/xef/agents/Agent.kt +++ /dev/null @@ -1,68 +0,0 @@ -package com.xebia.functional.xef.agents - -import arrow.core.raise.Raise -import com.xebia.functional.xef.AIError - -interface Agent { - val name: String - val description: String - suspend fun Raise.call(input: Input): Output - - fun contramap(pretransform: (A) -> Input): Agent = - Wrapper(pretransform, this) { it } - - fun map(transform: (Output) -> B): Agent = Wrapper({ it }, this, transform) - - /** Record an [input] but don't execute the agent yet. */ - fun with(input: Input): ParameterlessAgent = - ParameterlessAgent(name = this.name, description = this.description) { - with(this@Agent) { call(input) } - } - - class Wrapper( - val pretransform: (A) -> B, - val agent: Agent, - val transform: (C) -> D - ) : Agent { - override val name = agent.name - override val description: String = agent.description - override suspend fun Raise.call(input: A): D { - val i = pretransform(input) - val o = with(agent) { call(i) } - return transform(o) - } - } - - companion object { - operator fun invoke( - name: String, - description: String, - action: suspend Raise.(Input) -> Output - ): Agent = - object : Agent { - override val name: String = name - override val description: String = description - override suspend fun Raise.call(input: Input): Output = action(input) - } - } -} - -interface ParameterlessAgent : Agent { - suspend fun Raise.call(): Output - override suspend fun Raise.call(input: Unit): Output = call() - - companion object { - operator fun invoke( - name: String, - description: String, - action: suspend Raise.() -> Output - ): ParameterlessAgent = - object : ParameterlessAgent { - override val name: String = name - override val description: String = description - override suspend fun Raise.call(): Output = action() - } - } -} - -typealias ContextualAgent = ParameterlessAgent> diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/agents/Chains.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/agents/Chains.kt deleted file mode 100644 index ffb293c59..000000000 --- a/core/src/commonMain/kotlin/com/xebia/functional/xef/agents/Chains.kt +++ /dev/null @@ -1,37 +0,0 @@ -package com.xebia.functional.xef.agents - -import arrow.core.NonEmptyList -import arrow.core.raise.Raise -import arrow.core.raise.mapOrAccumulate -import arrow.core.raise.recover -import com.xebia.functional.xef.AIError - -// from https://docs.langchain.com/docs/components/chains/index_related_chains - -class MapReduceChain(val mapper: Agent, val reducer: Agent, R>) : - Agent, R> { - override val name = "MapReduce [${mapper.name}, ${reducer.name}]" - override val description: String = - "MapReduce [mapper = ${mapper.description}, reducer = ${reducer.description}]" - - override suspend fun Raise.call(input: List): R { - val mapResults = - recover({ mapOrAccumulate(input) { with(mapper) { call(it) } } }) { e: NonEmptyList - -> - raise(AIError.Combined(e)) - } - return with(reducer) { call(mapResults) } - } -} - -class RefineChain(val initial: Agent, val refiner: Agent, B>) : - Agent, B> { - override val name = "Refine [${initial.name}, ${refiner.name}]" - override val description: String = - "Refine [initial = ${initial.description}, refiner = ${refiner.description}]" - - override suspend fun Raise.call(input: NonEmptyList): B { - val initialResult = with(initial) { call(input.head) } - return input.tail.fold(initialResult) { acc, x -> with(refiner) { call(Pair(x, acc)) } } - } -} diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/agents/DeserializerLLMAgent.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/agents/DeserializerLLMAgent.kt deleted file mode 100644 index 17cb4f490..000000000 --- a/core/src/commonMain/kotlin/com/xebia/functional/xef/agents/DeserializerLLMAgent.kt +++ /dev/null @@ -1,123 +0,0 @@ -package com.xebia.functional.xef.agents - -import arrow.core.raise.Raise -import arrow.core.raise.catch -import arrow.core.raise.ensureNotNull -import com.xebia.functional.xef.AIError -import com.xebia.functional.xef.auto.SerializationConfig -import com.xebia.functional.xef.auto.serialization.buildJsonSchema -import com.xebia.functional.xef.llm.openai.LLMModel -import com.xebia.functional.xef.llm.openai.OpenAIClient -import com.xebia.functional.xef.prompt.PromptTemplate -import com.xebia.functional.xef.prompt.append -import com.xebia.functional.xef.vectorstores.VectorStore -import kotlinx.serialization.KSerializer -import kotlinx.serialization.json.Json -import kotlinx.serialization.serializer - -class DeserializerLLMAgent( - serializer: KSerializer, - private val json: Json = Json { - ignoreUnknownKeys = true - isLenient = true - }, - private val maxDeserializationAttempts: Int = 5, - llm: OpenAIClient, - template: PromptTemplate, - model: LLMModel = LLMModel.GPT_3_5_TURBO, - context: VectorStore = VectorStore.EMPTY, - user: String = "testing", - echo: Boolean = false, - n: Int = 1, - temperature: Double = 0.0, - bringFromContext: Int = 10 -) : Agent, A> { - - companion object { - inline operator fun invoke( - llm: OpenAIClient, - template: PromptTemplate, - model: LLMModel = LLMModel.GPT_3_5_TURBO, - context: VectorStore = VectorStore.EMPTY, - user: String = "testing", - echo: Boolean = false, - n: Int = 1, - temperature: Double = 0.0, - bringFromContext: Int = 10, - json: Json = Json { - ignoreUnknownKeys = true - isLenient = true - }, - maxDeserializationAttempts: Int = 5, - ): DeserializerLLMAgent = - DeserializerLLMAgent( - serializer(), - json, - maxDeserializationAttempts, - llm, - template, - model, - context, - user, - echo, - n, - temperature, - bringFromContext - ) - } - - val serializationConfig: SerializationConfig = - SerializationConfig( - jsonSchema = buildJsonSchema(serializer.descriptor, false), - descriptor = serializer.descriptor, - deserializationStrategy = serializer - ) - - val responseInstructions = - """ - | - |Response Instructions: - |1. Return the entire response in a single line with not additional lines or characters. - |2. When returning the response consider values should be accordingly escaped so the json remains valid. - |3. Use the JSON schema to produce the result exclusively in valid JSON format. - |4. Pay attention to required vs non-required fields in the schema. - |JSON Schema: - |${serializationConfig.jsonSchema} - |Response: - """ - .trimMargin() - - val underlying: LLMAgent = - LLMAgent( - llm, - template.append(responseInstructions), - model, - context, - user, - echo, - n, - temperature, - bringFromContext - ) - - override val name: String = "Deserializer LLM Agent" - override val description: String = - "Runs a query through a LLM agent and deserializes the output from a JSON representation" - - override suspend fun Raise.call(input: Map): A { - var currentAttempts = 0 - while (currentAttempts < maxDeserializationAttempts) { - currentAttempts++ - val result = - ensureNotNull(with(underlying) { call(input) }.firstOrNull()) { AIError.NoResponse } - catch({ - return@call json.decodeFromString(serializationConfig.deserializationStrategy, result) - }) { e: IllegalArgumentException -> - if (currentAttempts == maxDeserializationAttempts) - raise(AIError.JsonParsing(result, maxDeserializationAttempts, e)) - // else continue with the next attempt - } - } - raise(AIError.NoResponse) - } -} diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/agents/ImageGenerationAgent.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/agents/ImageGenerationAgent.kt deleted file mode 100644 index 88c39ded3..000000000 --- a/core/src/commonMain/kotlin/com/xebia/functional/xef/agents/ImageGenerationAgent.kt +++ /dev/null @@ -1,56 +0,0 @@ -package com.xebia.functional.xef.agents - -import arrow.core.raise.Raise -import com.xebia.functional.xef.AIError -import com.xebia.functional.xef.llm.openai.ImagesGenerationRequest -import com.xebia.functional.xef.llm.openai.ImagesGenerationResponse -import com.xebia.functional.xef.llm.openai.OpenAIClient -import com.xebia.functional.xef.prompt.PromptTemplate -import com.xebia.functional.xef.vectorstores.VectorStore - -class ImageGenerationAgent( - private val llm: OpenAIClient, - private val template: PromptTemplate, - private val context: VectorStore = VectorStore.EMPTY, - private val user: String = "testing", - private val numberImages: Int, - private val size: String, - private val bringFromContext: Int = 10 -) : Agent, ImagesGenerationResponse> { - - override val name = "Image Generation Agent" - override val description: String = "Generates images" - - override suspend fun Raise.call(input: Map): ImagesGenerationResponse { - val prompt = template.format(checkInput(template, input)) - - val ctxInfo = context.similaritySearch(prompt, bringFromContext) - val promptWithContext = - if (ctxInfo.isNotEmpty()) - """ - |Instructions: Use the [Information] below delimited by 3 backticks to accomplish - |the [Objective] at the end of the prompt. - |Try to match the data returned in the [Objective] with this [Information] as best as you can. - |[Information]: - |``` - |${ctxInfo.joinToString("\n")} - |``` - |$prompt - """ - .trimMargin() - else prompt - - return callImageGenerationEndpoint(promptWithContext) - } - - private suspend fun callImageGenerationEndpoint(prompt: String): ImagesGenerationResponse { - val request = - ImagesGenerationRequest( - prompt = prompt, - numberImages = numberImages, - size = size, - user = user - ) - return llm.createImages(request) - } -} diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/agents/LLMAgent.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/agents/LLMAgent.kt deleted file mode 100644 index 5aab588e7..000000000 --- a/core/src/commonMain/kotlin/com/xebia/functional/xef/agents/LLMAgent.kt +++ /dev/null @@ -1,96 +0,0 @@ -package com.xebia.functional.xef.agents - -import arrow.core.raise.Raise -import arrow.core.raise.ensure -import com.xebia.functional.xef.AIError -import com.xebia.functional.xef.llm.openai.ChatCompletionRequest -import com.xebia.functional.xef.llm.openai.CompletionRequest -import com.xebia.functional.xef.llm.openai.LLMModel -import com.xebia.functional.xef.llm.openai.Message -import com.xebia.functional.xef.llm.openai.OpenAIClient -import com.xebia.functional.xef.llm.openai.Role -import com.xebia.functional.xef.prompt.PromptTemplate -import com.xebia.functional.xef.vectorstores.VectorStore - -class LLMAgent( - private val llm: OpenAIClient, - private val template: PromptTemplate, - private val model: LLMModel, - private val context: VectorStore = VectorStore.EMPTY, - private val user: String = "testing", - private val echo: Boolean = false, - private val n: Int = 1, - private val temperature: Double = 0.0, - private val bringFromContext: Int = 10 -) : Agent, List> { - - override val name = "LLM Agent" - override val description: String = "Runs a query through a LLM agent" - - override suspend fun Raise.call(input: Map): List { - val prompt = template.format(checkInput(template, input)) - - val ctxInfo = context.similaritySearch(prompt, bringFromContext) - val promptWithContext = - if (ctxInfo.isNotEmpty()) - """ - |Instructions: Use the [Information] below delimited by 3 backticks to accomplish - |the [Objective] at the end of the prompt. - |Try to match the data returned in the [Objective] with this [Information] as best as you can. - |[Information]: - |``` - |${ctxInfo.joinToString("\n")} - |``` - |$prompt - """ - .trimMargin() - else prompt - - return when (model.kind) { - LLMModel.Kind.Completion -> callCompletionEndpoint(promptWithContext) - LLMModel.Kind.Chat -> callChatEndpoint(promptWithContext) - } - } - - private suspend fun callCompletionEndpoint(prompt: String): List { - val request = - CompletionRequest( - model = model.name, - user = user, - prompt = prompt, - echo = echo, - n = n, - temperature = temperature, - maxTokens = 1024 - ) - return llm.createCompletion(request).map { it.text } - } - - private suspend fun callChatEndpoint(prompt: String): List { - val request = - ChatCompletionRequest( - model = model.name, - user = user, - messages = listOf(Message(Role.system.name, prompt)), - n = n, - temperature = temperature, - maxTokens = 1024 - ) - return llm.createChatCompletion(request).choices.map { it.message.content } - } -} - -fun Raise.checkInput( - template: PromptTemplate, - input: Map -): Map { - ensure((template.inputKeys subtract input.keys).isEmpty()) { - AIError.InvalidInputs( - "The provided inputs: " + - input.keys.joinToString(", ") { "{$it}" } + - " do not match with chain's inputs: " + - template.inputKeys.joinToString(", ") { "{$it}" } - ) - } - return input -} diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/auto/AI.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/auto/AI.kt index fdff39977..5d584a19f 100644 --- a/core/src/commonMain/kotlin/com/xebia/functional/xef/auto/AI.kt +++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/auto/AI.kt @@ -3,25 +3,17 @@ package com.xebia.functional.xef.auto import arrow.core.Either import arrow.core.left import arrow.core.raise.Raise -import arrow.core.raise.either import arrow.core.raise.recover import arrow.core.right import arrow.fx.coroutines.Resource import arrow.fx.coroutines.ResourceScope import arrow.fx.coroutines.resourceScope import com.xebia.functional.xef.AIError -import com.xebia.functional.xef.agents.ContextualAgent -import com.xebia.functional.xef.agents.DeserializerLLMAgent -import com.xebia.functional.xef.agents.ImageGenerationAgent -import com.xebia.functional.xef.agents.LLMAgent import com.xebia.functional.xef.embeddings.Embeddings import com.xebia.functional.xef.embeddings.OpenAIEmbeddings import com.xebia.functional.xef.env.OpenAIConfig -import com.xebia.functional.xef.llm.openai.ImagesGenerationResponse import com.xebia.functional.xef.llm.openai.KtorOpenAIClient -import com.xebia.functional.xef.llm.openai.LLMModel import com.xebia.functional.xef.llm.openai.OpenAIClient -import com.xebia.functional.xef.prompt.PromptTemplate import com.xebia.functional.xef.vectorstores.CombinedVectorStore import com.xebia.functional.xef.vectorstores.LocalVectorStore import com.xebia.functional.xef.vectorstores.LocalVectorStoreBuilder @@ -31,7 +23,6 @@ import io.github.oshai.KotlinLogging import kotlin.jvm.JvmName import kotlin.time.ExperimentalTime import kotlinx.serialization.DeserializationStrategy -import kotlinx.serialization.SerializationException import kotlinx.serialization.descriptors.SerialDescriptor import kotlinx.serialization.json.JsonObject @@ -115,8 +106,8 @@ suspend inline fun AI.getOrThrow(): A = getOrElse { throw AIExcep * [AI] program, and [Raise] of [AIError] in case you want to compose any [Raise] based actions. */ class AIScope( - @PublishedApi internal val openAIClient: OpenAIClient, - @PublishedApi internal val context: VectorStore, + val openAIClient: OpenAIClient, + val context: VectorStore, internal val embeddings: Embeddings, private val logger: KLogger, resourceScope: ResourceScope, @@ -163,20 +154,6 @@ class AIScope( context.addTexts(docs.toList()) } - @AiDsl - suspend fun extendContext(vararg agents: ContextualAgent) { - agents.forEach { - logger.debug { "[${it.name}] Running" } - val docs = with(it) { call() } - if (docs.isNotEmpty()) { - context.addTexts(docs) - logger.debug { "[${it.name}] Found and memorized ${docs.size} docs" } - } else { - logger.debug { "[${it.name}] Found no docs" } - } - } - } - @AiDsl suspend fun contextScope( store: suspend (Embeddings) -> Resource, @@ -204,142 +181,4 @@ class AIScope( extendContext(*docs.toTypedArray()) scope(this) } - - /** Runs the [agent] to enlarge the [context], and then executes the [scope]. */ - @AiDsl - suspend fun contextScope(agent: ContextualAgent, scope: suspend AIScope.() -> A): A = - contextScope(listOf(agent), scope) - - /** Runs the [agents] to enlarge the [context], and then executes the [scope]. */ - @AiDsl - suspend fun contextScope( - agents: Collection, - scope: suspend AIScope.() -> A - ): A = contextScope { - extendContext(*agents.toTypedArray()) - scope(this) - } - - @AiDsl - suspend fun promptMessage( - question: String, - model: LLMModel = LLMModel.GPT_3_5_TURBO - ): List = promptMessage(PromptTemplate(question), emptyMap(), model) - - @AiDsl - suspend fun promptMessage( - prompt: PromptTemplate, - variables: Map, - model: LLMModel = LLMModel.GPT_3_5_TURBO - ): List = with(LLMAgent(openAIClient, prompt, model, context)) { call(variables) } - - /** - * Run a [question] describes the task you want to solve within the context of [AIScope]. Returns - * a value of [A] where [A] **has to be** annotated with [kotlinx.serialization.Serializable]. - * - * @throws SerializationException if serializer cannot be created (provided [A] or its type - * argument is not serializable). - * @throws IllegalArgumentException if any of [A]'s type arguments contains star projection. - */ - @AiDsl - suspend inline fun prompt( - question: String, - model: LLMModel = LLMModel.GPT_3_5_TURBO - ): A = prompt(PromptTemplate(question), emptyMap(), model) - - /** - * Run a [prompt] describes the task you want to solve within the context of [AIScope]. Returns a - * value of [A] where [A] **has to be** annotated with [kotlinx.serialization.Serializable]. - * - * @throws SerializationException if serializer cannot be created (provided [A] or its type - * argument is not serializable). - * @throws IllegalArgumentException if any of [A]'s type arguments contains star projection. - */ - @AiDsl - suspend inline fun prompt( - prompt: PromptTemplate, - variables: Map, - model: LLMModel = LLMModel.GPT_3_5_TURBO - ): A = with(DeserializerLLMAgent(openAIClient, prompt, model, context)) { call(variables) } - - /** - * Run a [prompt] describes the images you want to generate within the context of [AIScope]. - * Returns a [ImagesGenerationResponse] containing time and urls with images generated. - * - * @param prompt a [PromptTemplate] describing the images you want to generate. - * @param variables a map of variables to be replaced in the [prompt]. - * @param numberImages number of images to generate. - * @param size the size of the images to generate. - */ - @AiDsl - suspend fun images( - prompt: PromptTemplate, - variables: Map, - numberImages: Int = 1, - size: String = "1024x1024" - ): ImagesGenerationResponse = - with( - ImageGenerationAgent( - llm = openAIClient, - template = prompt, - context = context, - numberImages = numberImages, - size = size - ) - ) { - call(variables) - } - - /** - * Run a [prompt] describes the images you want to generate within the context of [AIScope]. - * Returns a [ImagesGenerationResponse] containing time and urls with images generated. - * - * @param prompt a [PromptTemplate] describing the images you want to generate. - * @param numberImages number of images to generate. - * @param size the size of the images to generate. - */ - @AiDsl - suspend fun images( - prompt: String, - numberImages: Int = 1, - size: String = "1024x1024" - ): ImagesGenerationResponse = images(PromptTemplate(prompt), emptyMap(), numberImages, size) - - /** - * Run a [prompt] describes the images you want to generate within the context of [AIScope]. - * Produces a [ImagesGenerationResponse] which then gets serialized to [A] through [prompt]. - * - * @param prompt a [PromptTemplate] describing the images you want to generate. - * @param size the size of the images to generate. - */ - @AiDsl - suspend inline fun Raise.image( - prompt: String, - size: String = "1024x1024", - llmModel: LLMModel = LLMModel.GPT_3_5_TURBO - ): A { - val imageResponse = images(prompt, 1, size) - val url = imageResponse.data.firstOrNull() ?: raise(AIError.NoResponse) - return either { - PromptTemplate( - """|Instructions: Format this [URL] and [PROMPT] information in the desired JSON response format - |specified at the end of the message. - |[URL]: - |``` - |{url} - |``` - |[PROMPT]: - |``` - |{prompt} - |``` - """ - .trimMargin(), - listOf("url", "prompt") - ) - } - .fold( - { raise(AIError.InvalidInputs(it.reason)) }, - { prompt(it, mapOf("url" to url.url, "prompt" to prompt), llmModel) } - ) - } } diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/auto/DeserializerLLMAgent.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/auto/DeserializerLLMAgent.kt new file mode 100644 index 000000000..f2596f922 --- /dev/null +++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/auto/DeserializerLLMAgent.kt @@ -0,0 +1,154 @@ +package com.xebia.functional.xef.auto + +import arrow.core.raise.catch +import arrow.core.raise.ensureNotNull +import com.xebia.functional.xef.AIError +import com.xebia.functional.xef.auto.serialization.buildJsonSchema +import com.xebia.functional.xef.llm.openai.LLMModel +import com.xebia.functional.xef.prompt.Prompt +import com.xebia.functional.xef.prompt.append +import kotlinx.serialization.KSerializer +import kotlinx.serialization.SerializationException +import kotlinx.serialization.json.Json +import kotlinx.serialization.serializer + +/** + * Run a [question] describes the task you want to solve within the context of [AIScope]. Returns a + * value of [A] where [A] **has to be** annotated with [kotlinx.serialization.Serializable]. + * + * @throws SerializationException if serializer cannot be created (provided [A] or its type argument + * is not serializable). + * @throws IllegalArgumentException if any of [A]'s type arguments contains star projection. + */ +@AiDsl +suspend inline fun AIScope.prompt( + question: String, + json: Json = Json { + ignoreUnknownKeys = true + isLenient = true + }, + maxDeserializationAttempts: Int = 5, + model: LLMModel = LLMModel.GPT_3_5_TURBO, + user: String = "testing", + echo: Boolean = false, + n: Int = 1, + temperature: Double = 0.0, + bringFromContext: Int = 10 +): A = + prompt( + Prompt(question), + json, + maxDeserializationAttempts, + model, + user, + echo, + n, + temperature, + bringFromContext + ) + +/** + * Run a [prompt] describes the task you want to solve within the context of [AIScope]. Returns a + * value of [A] where [A] **has to be** annotated with [kotlinx.serialization.Serializable]. + * + * @throws SerializationException if serializer cannot be created (provided [A] or its type argument + * is not serializable). + * @throws IllegalArgumentException if any of [A]'s type arguments contains star projection. + */ +@AiDsl +suspend inline fun AIScope.prompt( + prompt: Prompt, + json: Json = Json { + ignoreUnknownKeys = true + isLenient = true + }, + maxDeserializationAttempts: Int = 5, + model: LLMModel = LLMModel.GPT_3_5_TURBO, + user: String = "testing", + echo: Boolean = false, + n: Int = 1, + temperature: Double = 0.0, + bringFromContext: Int = 10 +): A = + prompt( + prompt, + serializer(), + json, + maxDeserializationAttempts, + model, + user, + echo, + n, + temperature, + bringFromContext + ) + +@AiDsl +suspend fun AIScope.prompt( + prompt: Prompt, + serializer: KSerializer, + json: Json = Json { + ignoreUnknownKeys = true + isLenient = true + }, + maxDeserializationAttempts: Int = 5, + model: LLMModel = LLMModel.GPT_3_5_TURBO, + user: String = "testing", + echo: Boolean = false, + n: Int = 1, + temperature: Double = 0.0, + bringFromContext: Int = 10 +): A { + val serializationConfig: SerializationConfig = + SerializationConfig( + jsonSchema = buildJsonSchema(serializer.descriptor, false), + descriptor = serializer.descriptor, + deserializationStrategy = serializer + ) + + val responseInstructions = + """ + | + |Response Instructions: + |1. Return the entire response in a single line with not additional lines or characters. + |2. When returning the response consider values should be accordingly escaped so the json remains valid. + |3. Use the JSON schema to produce the result exclusively in valid JSON format. + |4. Pay attention to required vs non-required fields in the schema. + |JSON Schema: + |${serializationConfig.jsonSchema} + |Response: + """ + .trimMargin() + + return tryDeserialize(serializationConfig, json, maxDeserializationAttempts) { + promptMessage( + prompt.append(responseInstructions), + model, + user, + echo, + n, + temperature, + bringFromContext + ) + } +} + +suspend fun AIScope.tryDeserialize( + serializationConfig: SerializationConfig, + json: Json, + maxDeserializationAttempts: Int, + agent: AI> +): A { + var currentAttempts = 0 + while (currentAttempts < maxDeserializationAttempts) { + currentAttempts++ + val result = ensureNotNull(agent().firstOrNull()) { AIError.NoResponse } + catch({ json.decodeFromString(serializationConfig.deserializationStrategy, result) }) { + e: IllegalArgumentException -> + if (currentAttempts == maxDeserializationAttempts) + raise(AIError.JsonParsing(result, maxDeserializationAttempts, e)) + // else continue with the next attempt + } + } + raise(AIError.NoResponse) +} diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/auto/ImageGenerationAgent.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/auto/ImageGenerationAgent.kt new file mode 100644 index 000000000..df641b6fe --- /dev/null +++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/auto/ImageGenerationAgent.kt @@ -0,0 +1,90 @@ +package com.xebia.functional.xef.auto + +import com.xebia.functional.xef.AIError +import com.xebia.functional.xef.llm.openai.ImagesGenerationRequest +import com.xebia.functional.xef.llm.openai.ImagesGenerationResponse +import com.xebia.functional.xef.prompt.Prompt + +/** + * Run a [prompt] describes the images you want to generate within the context of [AIScope]. + * Produces a [ImagesGenerationResponse] which then gets serialized to [A] through [prompt]. + * + * @param prompt a [Prompt] describing the images you want to generate. + * @param size the size of the images to generate. + */ +suspend inline fun AIScope.image( + prompt: String, + user: String = "testing", + size: String = "1024x1024", + bringFromContext: Int = 10 +): A { + val imageResponse = images(prompt, user, 1, size, bringFromContext) + val url = imageResponse.data.firstOrNull() ?: raise(AIError.NoResponse) + return prompt( + """|Instructions: Format this [URL] and [PROMPT] information in the desired JSON response format + |specified at the end of the message. + |[URL]: + |``` + |$url + |``` + |[PROMPT]: + |``` + |$prompt + |```""" + .trimMargin() + ) +} + +/** + * Run a [prompt] describes the images you want to generate within the context of [AIScope]. Returns + * a [ImagesGenerationResponse] containing time and urls with images generated. + * + * @param prompt a [Prompt] describing the images you want to generate. + * @param numberImages number of images to generate. + * @param size the size of the images to generate. + */ +suspend fun AIScope.images( + prompt: String, + user: String = "testing", + numberImages: Int = 1, + size: String = "1024x1024", + bringFromContext: Int = 10 +): ImagesGenerationResponse = images(Prompt(prompt), user, numberImages, size, bringFromContext) + +/** + * Run a [prompt] describes the images you want to generate within the context of [AIScope]. Returns + * a [ImagesGenerationResponse] containing time and urls with images generated. + * + * @param prompt a [Prompt] describing the images you want to generate. + * @param numberImages number of images to generate. + * @param size the size of the images to generate. + */ +suspend fun AIScope.images( + prompt: Prompt, + user: String = "testing", + numberImages: Int = 1, + size: String = "1024x1024", + bringFromContext: Int = 10 +): ImagesGenerationResponse { + val ctxInfo = context.similaritySearch(prompt.message, bringFromContext) + val promptWithContext = + if (ctxInfo.isNotEmpty()) { + """|Instructions: Use the [Information] below delimited by 3 backticks to accomplish + |the [Objective] at the end of the prompt. + |Try to match the data returned in the [Objective] with this [Information] as best as you can. + |[Information]: + |``` + |${ctxInfo.joinToString("\n")} + |``` + |$prompt""" + .trimMargin() + } else prompt.message + val request = + ImagesGenerationRequest( + prompt = promptWithContext, + numberImages = numberImages, + size = size, + user = user + ) + return openAIClient.createImages(request) +} diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/auto/LLMAgent.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/auto/LLMAgent.kt new file mode 100644 index 000000000..d5c7fb288 --- /dev/null +++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/auto/LLMAgent.kt @@ -0,0 +1,91 @@ +package com.xebia.functional.xef.auto + +import com.xebia.functional.xef.llm.openai.ChatCompletionRequest +import com.xebia.functional.xef.llm.openai.CompletionRequest +import com.xebia.functional.xef.llm.openai.LLMModel +import com.xebia.functional.xef.llm.openai.Message +import com.xebia.functional.xef.llm.openai.Role +import com.xebia.functional.xef.prompt.Prompt + +@AiDsl +suspend fun AIScope.promptMessage( + question: String, + model: LLMModel = LLMModel.GPT_3_5_TURBO, + user: String = "testing", + echo: Boolean = false, + n: Int = 1, + temperature: Double = 0.0, + bringFromContext: Int = 10 +): List = + promptMessage(Prompt(question), model, user, echo, n, temperature, bringFromContext) + +@AiDsl +suspend fun AIScope.promptMessage( + prompt: Prompt, + model: LLMModel = LLMModel.GPT_3_5_TURBO, + user: String = "testing", + echo: Boolean = false, + n: Int = 1, + temperature: Double = 0.0, + bringFromContext: Int = 10 +): List { + val ctxInfo = context.similaritySearch(prompt.message, bringFromContext) + val promptWithContext = + if (ctxInfo.isNotEmpty()) { + """|Instructions: Use the [Information] below delimited by 3 backticks to accomplish + |the [Objective] at the end of the prompt. + |Try to match the data returned in the [Objective] with this [Information] as best as you can. + |[Information]: + |``` + |${ctxInfo.joinToString("\n")} + |``` + |$prompt""" + .trimMargin() + } else prompt.message + + return when (model.kind) { + LLMModel.Kind.Completion -> + callCompletionEndpoint(promptWithContext, model, user, echo, n, temperature) + LLMModel.Kind.Chat -> callChatEndpoint(promptWithContext, model, user, n, temperature) + } +} + +private suspend fun AIScope.callCompletionEndpoint( + prompt: String, + model: LLMModel, + user: String = "testing", + echo: Boolean = false, + n: Int = 1, + temperature: Double = 0.0 +): List { + val request = + CompletionRequest( + model = model.name, + user = user, + prompt = prompt, + echo = echo, + n = n, + temperature = temperature, + maxTokens = 1024 + ) + return openAIClient.createCompletion(request).map { it.text } +} + +private suspend fun AIScope.callChatEndpoint( + prompt: String, + model: LLMModel, + user: String = "testing", + n: Int = 1, + temperature: Double = 0.0 +): List { + val request = + ChatCompletionRequest( + model = model.name, + user = user, + messages = listOf(Message(Role.system.name, prompt)), + n = n, + temperature = temperature, + maxTokens = 1024 + ) + return openAIClient.createChatCompletion(request).choices.map { it.message.content } +} diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/prompt/Prompt.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/prompt/Prompt.kt new file mode 100644 index 000000000..ccaa9ed4b --- /dev/null +++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/prompt/Prompt.kt @@ -0,0 +1,23 @@ +package com.xebia.functional.xef.prompt + +import kotlin.jvm.JvmInline + +fun Prompt(examples: List, suffix: String, prefix: String): Prompt = + Prompt( + """|$prefix + | + |${examples.joinToString(separator = "\n")} + | + |$suffix""" + .trimMargin() + ) + +@JvmInline value class Prompt(val message: String) + +fun Prompt.prepend(text: String) = Prompt(text + message) + +operator fun Prompt.plus(other: Prompt): Prompt = Prompt(message + other.message) + +operator fun Prompt.plus(text: String): Prompt = Prompt(message + text) + +fun Prompt.append(text: String) = this + text diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/prompt/PromptTemplate.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/prompt/PromptTemplate.kt index e720466d6..dd4ba1091 100644 --- a/core/src/commonMain/kotlin/com/xebia/functional/xef/prompt/PromptTemplate.kt +++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/prompt/PromptTemplate.kt @@ -1,14 +1,23 @@ package com.xebia.functional.xef.prompt +import arrow.core.Either +import arrow.core.NonEmptyList +import arrow.core.fold import arrow.core.raise.Raise import arrow.core.raise.either +import arrow.core.raise.ensure +import arrow.core.raise.recover +import arrow.core.raise.zipOrAccumulate +import kotlin.jvm.JvmInline + +data class InvalidTemplate(val reason: String) fun Raise.PromptTemplate( examples: List, suffix: String, variables: List, prefix: String -): PromptTemplate { +): PromptTemplate { val template = """|$prefix | @@ -16,76 +25,74 @@ fun Raise.PromptTemplate( | |$suffix""" .trimMargin() - return PromptTemplate(Config(template, variables)) + return PromptTemplate(template, variables) } fun Raise.PromptTemplate( template: String, - variables: List -): PromptTemplate = PromptTemplate(Config(template, variables)) - -fun PromptTemplate(template: String): PromptTemplate = - PromptTemplate(either { Config(template, emptyList()) }.getOrNull()!!) - -interface PromptTemplate { - val inputKeys: List + validate: List? = null +): PromptTemplate = PromptTemplate.either(template, validate).bind() - suspend fun format(variables: Map): A - - fun mapK(transform: (A) -> B): PromptTemplate = - object : PromptTemplate { - override val inputKeys: List = this@PromptTemplate.inputKeys - override suspend fun format(variables: Map): B = - transform(this@PromptTemplate.format(variables)) - } +@JvmInline +value class PromptTemplate private constructor(val template: String) { + fun format(variables: Map): Prompt = + Prompt(variables.fold(template) { acc, (key, value) -> acc.replace("{$key}", value) }) companion object { - - operator fun invoke(config: Config): PromptTemplate = - object : PromptTemplate { - override val inputKeys: List = config.inputVariables - - override suspend fun format(variables: Map): String { - val mergedArgs = mergePartialAndUserVariables(variables, config.inputVariables) - return when (config.templateFormat) { - TemplateFormat.FString -> { - val sortedArgs = mergedArgs.toList().sortedBy { it.first } - sortedArgs.fold(config.template) { acc, (k, v) -> acc.replace("{$k}", v) } + fun either( + template: String, + variables: List? = null + ): Either = + either { + val placeholders = placeholderValues(template) + recover, Unit>({ + zipOrAccumulate( + { + variables?.let { + validate(template, variables.toSet() - placeholders.toSet(), "unused") + } + }, + { + variables?.let { + validate(template, placeholders.toSet() - variables.toSet(), "missing") + } + }, + { validateDuplicated(template, placeholders) } + ) { _, _, _ -> } + }) { + raise(InvalidTemplate(it.joinToString(transform = InvalidTemplate::reason))) } - } - - private fun mergePartialAndUserVariables( - variables: Map, - inputVariables: List - ): Map = - inputVariables.fold(variables) { acc, k -> - if (!acc.containsKey(k)) acc + (k to "{$k}") else acc - } - } - - fun human(promptTemplate: PromptTemplate): PromptTemplate = - promptTemplate.mapK(::HumanMessage) - - fun ai(promptTemplate: PromptTemplate): PromptTemplate = - promptTemplate.mapK(::AIMessage) - - fun system(promptTemplate: PromptTemplate): PromptTemplate = - promptTemplate.mapK(::SystemMessage) - - fun chat(promptTemplate: PromptTemplate, role: String): PromptTemplate = - promptTemplate.mapK { ChatMessage(it, role) } + template + } // We need to map otherwise Raise constructor gets precedence + .map { PromptTemplate(template) } } } -fun PromptTemplate.prepend(text: String) = - object : PromptTemplate by this { - override suspend fun format(variables: Map): String = - text + this@prepend.format(variables) +private fun Raise.validate( + template: String, + diffSet: Set, + msg: String +): Unit = + ensure(diffSet.isEmpty()) { + InvalidTemplate( + "Template '$template' has $msg arguments: ${diffSet.joinToString(", ") { "{$it}" }}" + ) } -fun PromptTemplate.append(text: String) = - object : PromptTemplate by this { - override suspend fun format(variables: Map): String = - this@append.format(variables) + text +private fun Raise.validateDuplicated( + template: String, + placeholders: List +) { + val args = placeholders.groupBy { it }.filter { it.value.size > 1 }.keys + ensure(args.isEmpty()) { + InvalidTemplate( + "Template '$template' has duplicate arguments: ${args.joinToString(", ") { "{$it}" }}" + ) } +} + +private fun placeholderValues(template: String): List { + @Suppress("RegExpRedundantEscape") val regex = Regex("""\{([^\{\}]+)\}""") + return regex.findAll(template).toList().mapNotNull { it.groupValues.getOrNull(1) } +} diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/prompt/models.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/prompt/models.kt index 4c1747926..9ade5c4cb 100644 --- a/core/src/commonMain/kotlin/com/xebia/functional/xef/prompt/models.kt +++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/prompt/models.kt @@ -1,11 +1,5 @@ package com.xebia.functional.xef.prompt -import arrow.core.Either -import arrow.core.NonEmptyList -import arrow.core.raise.Raise -import arrow.core.raise.either -import arrow.core.raise.ensure -import arrow.core.raise.zipOrAccumulate import kotlinx.serialization.Serializable enum class Type { @@ -46,69 +40,6 @@ data class ChatMessage(override val content: String, val role: String) : Message override fun format(): String = "$role: $content" } -enum class TemplateFormat { - FString -} - -data class InvalidTemplate(val reason: String) - -fun Raise.Config(template: String, inputVariables: List): Config = - Config.either(template, inputVariables).bind() - -class Config -private constructor( - val inputVariables: List, - val template: String, - val templateFormat: TemplateFormat = TemplateFormat.FString -) { - companion object { - // We cannot define `operator fun invoke` with `Raise` without context receivers, - // so we define an intermediate `Either` based function. - // This is because adding `Raise` results in 2 receivers. - fun either(template: String, variables: List): Either = - either, Config> { - val placeholders = placeholderValues(template) - - zipOrAccumulate( - { validate(template, variables.toSet() - placeholders.toSet(), "unused") }, - { validate(template, placeholders.toSet() - variables.toSet(), "missing") }, - { validateDuplicated(template, placeholders) } - ) { _, _, _ -> - Config(variables, template) - } - } - .mapLeft { InvalidTemplate(it.joinToString(transform = InvalidTemplate::reason)) } - } -} - -private fun Raise.validate( - template: String, - diffSet: Set, - msg: String -): Unit = - ensure(diffSet.isEmpty()) { - InvalidTemplate( - "Template '$template' has $msg arguments: ${diffSet.joinToString(", ") { "{$it}" }}" - ) - } - -private fun Raise.validateDuplicated( - template: String, - placeholders: List -) { - val args = placeholders.groupBy { it }.filter { it.value.size > 1 }.keys - ensure(args.isEmpty()) { - InvalidTemplate( - "Template '$template' has duplicate arguments: ${args.joinToString(", ") { "{$it}" }}" - ) - } -} - -private fun placeholderValues(template: String): List { - @Suppress("RegExpRedundantEscape") val regex = Regex("""\{([^\{\}]+)\}""") - return regex.findAll(template).toList().mapNotNull { it.groupValues.getOrNull(1) } -} - private fun String.capitalized(): String = replaceFirstChar { if (it.isLowerCase()) it.titlecase() else it.toString() } diff --git a/core/src/commonTest/kotlin/com/xebia/functional/xef/chains/ChainTestUtils.kt b/core/src/commonTest/kotlin/com/xebia/functional/xef/chains/ChainTestUtils.kt deleted file mode 100644 index feb8a7977..000000000 --- a/core/src/commonTest/kotlin/com/xebia/functional/xef/chains/ChainTestUtils.kt +++ /dev/null @@ -1,119 +0,0 @@ -package com.xebia.functional.xef.chains - -import com.xebia.functional.xef.llm.openai.ChatCompletionRequest -import com.xebia.functional.xef.llm.openai.ChatCompletionResponse -import com.xebia.functional.xef.llm.openai.Choice -import com.xebia.functional.xef.llm.openai.CompletionChoice -import com.xebia.functional.xef.llm.openai.CompletionRequest -import com.xebia.functional.xef.llm.openai.EmbeddingRequest -import com.xebia.functional.xef.llm.openai.EmbeddingResult -import com.xebia.functional.xef.llm.openai.ImageGenerationUrl -import com.xebia.functional.xef.llm.openai.ImagesGenerationRequest -import com.xebia.functional.xef.llm.openai.ImagesGenerationResponse -import com.xebia.functional.xef.llm.openai.Message -import com.xebia.functional.xef.llm.openai.OpenAIClient -import com.xebia.functional.xef.llm.openai.Role -import com.xebia.functional.xef.llm.openai.Usage - -val testLLM = - object : OpenAIClient { - override suspend fun createCompletion(request: CompletionRequest): List = - when (request.prompt) { - "Tell me a joke." -> - listOf(CompletionChoice("I'm not good at jokes", 1, finishReason = "foo")) - "My name is foo and I'm 28 years old" -> - listOf(CompletionChoice("Hello there! Nice to meet you foo", 1, finishReason = "foo")) - testTemplateFormatted -> listOf(CompletionChoice("I don't know", 1, finishReason = "foo")) - testTemplateInputsFormatted -> - listOf(CompletionChoice("Two inputs, right?", 1, finishReason = "foo")) - testQATemplateFormatted -> listOf(CompletionChoice("I don't know", 1, finishReason = "foo")) - else -> listOf(CompletionChoice("foo", 1, finishReason = "bar")) - } - - override suspend fun createChatCompletion( - request: ChatCompletionRequest - ): ChatCompletionResponse = - when (request.messages.firstOrNull()?.content) { - "Tell me a joke." -> fakeChatCompletion("I'm not good at jokes") - "My name is foo and I'm 28 years old" -> - fakeChatCompletion("Hello there! Nice to meet you foo") - testTemplateFormatted -> fakeChatCompletion("I don't know") - testTemplateInputsFormatted -> fakeChatCompletion("Two inputs, right?") - testQATemplateFormatted -> fakeChatCompletion("I don't know") - else -> fakeChatCompletion("foo") - } - - override suspend fun createImages(request: ImagesGenerationRequest): ImagesGenerationResponse = - ImagesGenerationResponse(1, listOf(ImageGenerationUrl("foo"))) - - private fun fakeChatCompletion(message: String): ChatCompletionResponse = - ChatCompletionResponse( - id = "foo", - `object` = "foo", - created = 1, - model = "foo", - usage = Usage(1, 1, 1), - choices = listOf(Choice(Message(Role.system.name, message), "foo", index = 0)) - ) - - override suspend fun createEmbeddings(request: EmbeddingRequest): EmbeddingResult = TODO() - } - -val testContext = - """foo foo foo - |bar bar bar - |baz baz baz""" - .trimMargin() - -val testContextOutput = mapOf("context" to testContext) - -val testTemplate = - """From the following context: - | - |{context} - | - |try to answer the following question: {question}""" - .trimMargin() - -val testTemplateInputs = - """From the following context: - | - |{context} - | - |I want to say: My name is {name} and I'm {age} years old""" - .trimMargin() - -val testTemplateFormatted = - """From the following context: - | - |foo foo foo - |bar bar bar - |baz baz baz - | - |try to answer the following question: What do you think?""" - .trimMargin() - -val testTemplateInputsFormatted = - """From the following context: - | - |foo foo foo - |bar bar bar - |baz baz baz - | - |I want to say: My name is Scala and I'm 28 years old""" - .trimMargin() - -val testQATemplateFormatted = - """ - |Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer. - | - |foo foo foo - |bar bar bar - |baz baz baz - | - |Question: What do you think? - |Helpful Answer:""" - .trimMargin() - -val testOutputIDK = mapOf("answer" to "I don't know") -val testOutputInputs = mapOf("answer" to "Two inputs, right?") diff --git a/core/src/commonTest/kotlin/com/xebia/functional/xef/chains/LLMAgentSpec.kt b/core/src/commonTest/kotlin/com/xebia/functional/xef/chains/LLMAgentSpec.kt deleted file mode 100644 index 16d53dc44..000000000 --- a/core/src/commonTest/kotlin/com/xebia/functional/xef/chains/LLMAgentSpec.kt +++ /dev/null @@ -1,45 +0,0 @@ -package com.xebia.functional.xef.chains - -import arrow.core.raise.either -import com.xebia.functional.xef.AIError -import com.xebia.functional.xef.agents.LLMAgent -import com.xebia.functional.xef.llm.openai.LLMModel -import com.xebia.functional.xef.prompt.PromptTemplate -import io.kotest.assertions.arrow.core.shouldBeLeft -import io.kotest.assertions.arrow.core.shouldBeRight -import io.kotest.core.spec.style.StringSpec - -class LLMAgentSpec : - StringSpec({ - val model = LLMModel.GPT_3_5_TURBO - - "LLMAgent should return a prediction with a simple template" { - val template = "Tell me {foo}." - either { - val promptTemplate = PromptTemplate(template, listOf("foo")) - val chain = LLMAgent(testLLM, promptTemplate, model) - with(chain) { call(mapOf("foo" to "a joke")) } - } shouldBeRight listOf("I'm not good at jokes") - } - - "LLMAgent should return a prediction with a more complex template" { - val template = "My name is {name} and I'm {age} years old" - either { - val prompt = PromptTemplate(template, listOf("name", "age")) - val chain = LLMAgent(testLLM, prompt, model) - with(chain) { call(mapOf("age" to "28", "name" to "foo")) } - } shouldBeRight listOf("Hello there! Nice to meet you foo") - } - - "LLMAgent should fail when inputs are not the expected ones from the PromptTemplate" { - val template = "My name is {name} and I'm {age} years old" - either { - val prompt = PromptTemplate(template, listOf("name", "age")) - val chain = LLMAgent(testLLM, prompt, model) - with(chain) { call(mapOf("age" to "28", "brand" to "foo")) } - } shouldBeLeft - AIError.InvalidInputs( - "The provided inputs: {age}, {brand} do not match with chain's inputs: {name}, {age}" - ) - } - }) diff --git a/core/src/commonTest/kotlin/com/xebia/functional/xef/prompt/ConfigSpec.kt b/core/src/commonTest/kotlin/com/xebia/functional/xef/prompt/ConfigSpec.kt deleted file mode 100644 index 1afef8b77..000000000 --- a/core/src/commonTest/kotlin/com/xebia/functional/xef/prompt/ConfigSpec.kt +++ /dev/null @@ -1,61 +0,0 @@ -package com.xebia.functional.xef.prompt - -import arrow.core.raise.either -import io.kotest.assertions.arrow.core.shouldBeLeft -import io.kotest.assertions.arrow.core.shouldBeRight -import io.kotest.core.spec.style.StringSpec -import io.kotest.matchers.shouldBe - -class ConfigSpec : - StringSpec({ - "should return a valid Config if the template and input variables are valid" { - val template = "Hello {name}, you are {age} years old." - val variables = listOf("name", "age") - - val config = either { Config(template, variables) }.shouldBeRight() - - config.inputVariables shouldBe variables - config.template shouldBe template - config.templateFormat shouldBe TemplateFormat.FString - } - - "should fail with a InvalidTemplateError if the template has missing arguments" { - val template = "Hello {name}, you are {age} years old." - val variables = listOf("name") - - either { Config(template, variables) } shouldBeLeft - InvalidTemplate( - "Template 'Hello {name}, you are {age} years old.' has missing arguments: {age}" - ) - } - - "should fail with a InvalidTemplateError if the template has unused arguments" { - val template = "Hello {name}, you are {age} years old." - val variables = listOf("name", "age", "unused") - - either { Config(template, variables) } shouldBeLeft - InvalidTemplate( - "Template 'Hello {name}, you are {age} years old.' has unused arguments: {unused}" - ) - } - - "should fail with a InvalidTemplateError if there are duplicate input variables" { - val template = "Hello {name}, you are {name} years old." - val variables = listOf("name") - - either { Config(template, variables) } shouldBeLeft - InvalidTemplate( - "Template 'Hello {name}, you are {name} years old.' has duplicate arguments: {name}" - ) - } - - "should fail with a combination of InvalidTemplateErrors if there are multiple things wrong" { - val template = "Hello {name}, you are {name} years old." - val variables = listOf("name", "age") - val unused = "Template 'Hello {name}, you are {name} years old.' has unused arguments: {age}" - val duplicated = - "Template 'Hello {name}, you are {name} years old.' has duplicate arguments: {name}" - - either { Config(template, variables) } shouldBeLeft InvalidTemplate("$unused, $duplicated") - } - }) diff --git a/core/src/commonTest/kotlin/com/xebia/functional/xef/prompt/PromptTemplateSpec.kt b/core/src/commonTest/kotlin/com/xebia/functional/xef/prompt/PromptTemplateSpec.kt index 48b29e82a..88b3d92ea 100644 --- a/core/src/commonTest/kotlin/com/xebia/functional/xef/prompt/PromptTemplateSpec.kt +++ b/core/src/commonTest/kotlin/com/xebia/functional/xef/prompt/PromptTemplateSpec.kt @@ -10,36 +10,29 @@ class PromptTemplateSpec : "PromptTemplate(template, list) should fail if the template is not valid" { val template = "Tell me {foo}." - either { - val prompt = PromptTemplate(template, emptyList()) - prompt.format(mapOf("foo" to "bar")) - } shouldBeLeft InvalidTemplate("Template 'Tell me {foo}.' has missing arguments: {foo}") + either { PromptTemplate(template, emptyList()).format(mapOf("foo" to "bar")) } shouldBeLeft + InvalidTemplate("Template 'Tell me {foo}.' has missing arguments: {foo}") } "format with no input variables shouldn't have any effect" { val template = "Tell me a joke." - either { - val prompt = PromptTemplate(template, emptyList()) - prompt.format(emptyMap()) - } shouldBeRight "Tell me a joke." + either { PromptTemplate(template).format(emptyMap()) } shouldBeRight Prompt("Tell me a joke.") } "format should return the expected result with a given set of variables" { val template = "My name is {name} and I'm {age} years old" val variables = mapOf("name" to "John", "age" to "47") - either { - val config = Config(template, listOf("name", "age")) - PromptTemplate(config).format(variables) - } shouldBeRight "My name is John and I'm 47 years old" + either { PromptTemplate(template, listOf("name", "age")).format(variables) } shouldBeRight + Prompt("My name is John and I'm 47 years old") } "PromptTemplate(template, list) should return a PromptTemplate instance with the given template and input variables" { val template = "My name is {name} and I'm {age} years old" val variables = mapOf("name" to "Mary", "age" to "25") either { PromptTemplate(template, listOf("name", "age")).format(variables) } shouldBeRight - "My name is Mary and I'm 25 years old" + Prompt("My name is Mary and I'm 25 years old") } " PromptTemplate(examples, suffix, variables, prefix) should return a PromptTemplate instance with the given examples and input variables" { @@ -63,7 +56,8 @@ class PromptTemplateSpec : PromptTemplate(examples, suffix = suffix, variables = listOf("product"), prefix = prefix) .format(variables) } shouldBeRight - """ + Prompt( + """ |I want you to act as a naming consultant for new companies. |Here are some examples of good company names: | @@ -73,7 +67,8 @@ class PromptTemplateSpec : | |The name should be short, catchy and easy to remember. |What is a good name for a company that makes functional programming?""" - .trimMargin() + .trimMargin() + ) } "format should return the expected result for variables with functions" { @@ -82,54 +77,48 @@ class PromptTemplateSpec : val variables = mapOf("name" to "Charles", "age" to getAge()) - either { - val config = Config(template, listOf("name", "age")) - PromptTemplate(config).format(variables) - } shouldBeRight "My name is Charles and I'm ${getAge()} years old" + either { PromptTemplate(template, listOf("name", "age")).format(variables) } shouldBeRight + Prompt("My name is Charles and I'm ${getAge()} years old") } - "format for human should return a HumanMessage" { - val template = "My name is {name} and I'm {age} years old" - val variables: Map = mapOf("name" to "Charles", "age" to "21") + "should fail with a InvalidTemplateError if the template has missing arguments" { + val template = "Hello {name}, you are {age} years old." + val variables = listOf("name") - either { - val prompt: PromptTemplate = PromptTemplate(template, listOf("name", "age")) - val humanPrompt: PromptTemplate = PromptTemplate.human(prompt) - humanPrompt.format(variables) - } shouldBeRight HumanMessage("My name is Charles and I'm 21 years old") + either { PromptTemplate(template, variables) } shouldBeLeft + InvalidTemplate( + "Template 'Hello {name}, you are {age} years old.' has missing arguments: {age}" + ) } - "format for system should return a SystemMessage" { - val template = "{sounds}" - val variables: Map = mapOf("sounds" to "Beep bep") + "should fail with a InvalidTemplateError if the template has unused arguments" { + val template = "Hello {name}, you are {age} years old." + val variables = listOf("name", "age", "unused") - either { - val prompt: PromptTemplate = PromptTemplate(template, listOf("sounds")) - val systemPrompt: PromptTemplate = PromptTemplate.system(prompt) - systemPrompt.format(variables) - } shouldBeRight SystemMessage("Beep bep") + either { PromptTemplate(template, variables) } shouldBeLeft + InvalidTemplate( + "Template 'Hello {name}, you are {age} years old.' has unused arguments: {unused}" + ) } - "format for ai should return a AIMessage" { - val template = "Hi, I'm an {machine}" - val variables: Map = mapOf("machine" to "AI") + "should fail with a InvalidTemplateError if there are duplicate input variables" { + val template = "Hello {name}, you are {name} years old." + val variables = listOf("name") - either { - val prompt: PromptTemplate = PromptTemplate(template, listOf("machine")) - val aiPrompt: PromptTemplate = PromptTemplate.ai(prompt) - aiPrompt.format(variables) - } shouldBeRight AIMessage("Hi, I'm an AI") + either { PromptTemplate(template, variables) } shouldBeLeft + InvalidTemplate( + "Template 'Hello {name}, you are {name} years old.' has duplicate arguments: {name}" + ) } - "format for chat should return a ChatMessage" { - val role = "Yoda" - val template = "Lost a {action}, master {name} has." - val variables: Map = mapOf("action" to "battle", "name" to "Obi-Wan") + "should fail with a combination of InvalidTemplateErrors if there are multiple things wrong" { + val template = "Hello {name}, you are {name} years old." + val variables = listOf("name", "age") + val unused = "Template 'Hello {name}, you are {name} years old.' has unused arguments: {age}" + val duplicated = + "Template 'Hello {name}, you are {name} years old.' has duplicate arguments: {name}" - either { - val prompt: PromptTemplate = PromptTemplate(template, listOf("action", "name")) - val chatPrompt: PromptTemplate = PromptTemplate.chat(prompt, role) - chatPrompt.format(variables) - } shouldBeRight ChatMessage("Lost a battle, master Obi-Wan has.", "Yoda") + either { PromptTemplate(template, variables) } shouldBeLeft + InvalidTemplate("$unused, $duplicated") } }) diff --git a/core/src/jvmMain/kotlin/com/xebia/functional/xef/agents/BingSearch.kt b/core/src/jvmMain/kotlin/com/xebia/functional/xef/agents/BingSearch.kt index 56a91726f..4f60dd16a 100644 --- a/core/src/jvmMain/kotlin/com/xebia/functional/xef/agents/BingSearch.kt +++ b/core/src/jvmMain/kotlin/com/xebia/functional/xef/agents/BingSearch.kt @@ -10,39 +10,34 @@ import java.util.stream.Collectors import kotlin.jvm.optionals.toList import kotlinx.coroutines.Dispatchers -fun bingSearch( +suspend fun bingSearch( search: String, splitter: BaseTextSplitter, url: String = "https://www.bing.com/news/search?q=${search.encodeURLParameter()}&format=rss", maxLinks: Int = 10 -): ParameterlessAgent> = - ParameterlessAgent>( - name = "Bing Search", - description = "Searches Bing for $search", - ) { - val items: List = RssReader().read(url).collect(Collectors.toList()) - val links = items.map { it.link }.flatMap { it.toList() }.take(maxLinks) - val linkedDocs = - links - .parMap(Dispatchers.IO) { link -> - try { - with(scrapeUrlContent(link, splitter)) { call() } - } catch (e: Exception) { - // ignore errors when scrapping nested content due to certificates and other remote - // issues - emptyList() - } +): List { + val items: List = RssReader().read(url).collect(Collectors.toList()) + val links = items.map { it.link }.flatMap { it.toList() }.take(maxLinks) + val linkedDocs = + links + .parMap(Dispatchers.IO) { link -> + try { + scrapeUrlContent(link, splitter) + } catch (e: Exception) { + // ignore errors when scrapping nested content due to certificates and other remote issues + emptyList() } - .flatten() - val docs = - items.map { - """| - |${it.title} - |${it.description} - |${it.link} - |${it.pubDate} - """ - .trimMargin() } - splitter.splitDocuments(linkedDocs + docs) - } + .flatten() + val docs = + items.map { + """| + |${it.title} + |${it.description} + |${it.link} + |${it.pubDate} + """ + .trimMargin() + } + return splitter.splitDocuments(linkedDocs + docs) +} diff --git a/core/src/jvmMain/kotlin/com/xebia/functional/xef/agents/DefaultSearch.kt b/core/src/jvmMain/kotlin/com/xebia/functional/xef/agents/DefaultSearch.kt index 881154296..f8386eef5 100644 --- a/core/src/jvmMain/kotlin/com/xebia/functional/xef/agents/DefaultSearch.kt +++ b/core/src/jvmMain/kotlin/com/xebia/functional/xef/agents/DefaultSearch.kt @@ -3,10 +3,8 @@ package com.xebia.functional.xef.agents import com.xebia.functional.tokenizer.ModelType import com.xebia.functional.xef.textsplitters.TokenTextSplitter -fun search(vararg prompt: String): Collection>> = - prompt.map { - bingSearch( - search = it, - TokenTextSplitter(ModelType.GPT_3_5_TURBO, chunkSize = 100, chunkOverlap = 50) - ) - } +suspend fun search(prompt: String): List = + bingSearch( + search = prompt, + TokenTextSplitter(ModelType.GPT_3_5_TURBO, chunkSize = 100, chunkOverlap = 50) + ) diff --git a/core/src/jvmMain/kotlin/com/xebia/functional/xef/agents/ScrapeUrlContent.kt b/core/src/jvmMain/kotlin/com/xebia/functional/xef/agents/ScrapeUrlContent.kt index 7419094c0..5135ec528 100644 --- a/core/src/jvmMain/kotlin/com/xebia/functional/xef/agents/ScrapeUrlContent.kt +++ b/core/src/jvmMain/kotlin/com/xebia/functional/xef/agents/ScrapeUrlContent.kt @@ -3,8 +3,5 @@ package com.xebia.functional.xef.agents import com.xebia.functional.xef.loaders.ScrapeURLTextLoader import com.xebia.functional.xef.textsplitters.BaseTextSplitter -fun scrapeUrlContent(url: String, splitter: BaseTextSplitter): ParameterlessAgent> = - ParameterlessAgent(name = "Scrape URL content", description = "Scrape the content of $url") { - val loader = ScrapeURLTextLoader(url) - loader.loadAndSplit(splitter) - } +suspend fun scrapeUrlContent(url: String, splitter: BaseTextSplitter): List = + ScrapeURLTextLoader(url).loadAndSplit(splitter) diff --git a/filesystem/src/commonMain/kotlin/com/xebia/functional/xef/prompt/FilePromptTemplate.kt b/filesystem/src/commonMain/kotlin/com/xebia/functional/xef/prompt/FilePromptTemplate.kt index 5eb64cd48..c6554d97c 100644 --- a/filesystem/src/commonMain/kotlin/com/xebia/functional/xef/prompt/FilePromptTemplate.kt +++ b/filesystem/src/commonMain/kotlin/com/xebia/functional/xef/prompt/FilePromptTemplate.kt @@ -8,13 +8,10 @@ import okio.Path /** * Creates a PromptTemplate based on a Path */ -suspend fun Raise.PromptTemplate( +fun Raise.PromptTemplate( path: Path, - variables: List, fileSystem: FileSystem = FileSystem.DEFAULT -): PromptTemplate = +): PromptTemplate = fileSystem.read(path) { - val template = readUtf8() - val config = Config(template, variables) - PromptTemplate(config) - } \ No newline at end of file + PromptTemplate.either(readUtf8()).bind() + } diff --git a/filesystem/src/commonTest/kotlin/com/xebia/functional/xef/prompt/PromptTemplateSpec.kt b/filesystem/src/commonTest/kotlin/com/xebia/functional/xef/prompt/PromptTemplateSpec.kt index a944cbb8a..a348d5c09 100644 --- a/filesystem/src/commonTest/kotlin/com/xebia/functional/xef/prompt/PromptTemplateSpec.kt +++ b/filesystem/src/commonTest/kotlin/com/xebia/functional/xef/prompt/PromptTemplateSpec.kt @@ -14,12 +14,11 @@ class PromptTemplateSpec : StringSpec({ val example = templates / "example.txt" write(example) { writeUtf8("My name is {name} and I'm {age} years old") } } - val inputVariables = listOf("name", "age") val variables = mapOf("name" to "Angela", "age" to "18") either { - val prompt = PromptTemplate("templates/example.txt".toPath(), inputVariables, fileSystem) + val prompt = PromptTemplate("templates/example.txt".toPath(), fileSystem) prompt.format(variables) - } shouldBeRight "My name is Angela and I'm 18 years old" + } shouldBeRight Prompt("My name is Angela and I'm 18 years old") } }) diff --git a/integrations/pdf/src/main/kotlin/com/xebia/functional/xef/pdf/PDFLoader.kt b/integrations/pdf/src/main/kotlin/com/xebia/functional/xef/pdf/PDFLoader.kt index f7d9b8cab..fc0c97ba0 100644 --- a/integrations/pdf/src/main/kotlin/com/xebia/functional/xef/pdf/PDFLoader.kt +++ b/integrations/pdf/src/main/kotlin/com/xebia/functional/xef/pdf/PDFLoader.kt @@ -1,7 +1,6 @@ package com.xebia.functional.xef.pdf import com.xebia.functional.tokenizer.ModelType -import com.xebia.functional.xef.agents.ParameterlessAgent import com.xebia.functional.xef.loaders.BaseLoader import com.xebia.functional.xef.textsplitters.BaseTextSplitter import com.xebia.functional.xef.textsplitters.TokenTextSplitter @@ -9,15 +8,13 @@ import org.apache.pdfbox.pdmodel.PDDocument import org.apache.pdfbox.text.PDFTextStripper import java.io.File -fun pdf( +suspend fun pdf( file: File, splitter: BaseTextSplitter = TokenTextSplitter(modelType = ModelType.GPT_3_5_TURBO, chunkSize = 100, chunkOverlap = 50) -): ParameterlessAgent> = - ParameterlessAgent(name = "Get PDF content", description = "Get PDF Content of $file") { - val loader = PDFLoader(file) - - loader.loadAndSplit(splitter) - } +): List { + val loader = PDFLoader(file) + return loader.loadAndSplit(splitter) +} class PDFLoader(private val file: File) : BaseLoader { override suspend fun load(): List { diff --git a/scala/src/main/scala/com/xebia/functional/auto/AI.scala b/scala/src/main/scala/com/xebia/functional/auto/AI.scala index 7f84ad993..59a0fa29f 100644 --- a/scala/src/main/scala/com/xebia/functional/auto/AI.scala +++ b/scala/src/main/scala/com/xebia/functional/auto/AI.scala @@ -5,8 +5,6 @@ import com.xebia.functional.xef.auto.AIScope as KtAIScope import com.xebia.functional.xef.auto.AIException import com.xebia.functional.xef.auto.AIKt import com.xebia.functional.xef.AIError -import com.xebia.functional.xef.agents.Agent as KtAgent -import com.xebia.functional.xef.agents.ParameterlessAgent import com.xebia.functional.xef.llm.openai.LLMModel //def example(using AIScope): String = @@ -32,9 +30,6 @@ object AI: end AI final case class AIScope(kt: KtAIScope): - def agent[A](agent: ParameterlessAgent[List[String]], scope: AIScope ?=> A): A = ??? - - def agent[A](agents: List[ParameterlessAgent[List[String]]], scope: AIScope ?=> A): A = ??? // TODO: Design signature for Scala3 w/ Json parser (with support for generating Json Schema)? def prompt[A]( @@ -43,5 +38,11 @@ final case class AIScope(kt: KtAIScope): llmMode: LLMModel = LLMModel.getGPT_3_5_TURBO ): A = ??? + def promptMessage( + prompt: String, + maxAttempts: Int = 5, + llmMode: LLMModel = LLMModel.getGPT_3_5_TURBO + ): String = ??? + private object AIScope: def fromCore(coreAIScope: KtAIScope): AIScope = new AIScope(coreAIScope)