Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CombineDocsChain + VectorQAChain #21

Merged
merged 11 commits into from
May 3, 2023
9 changes: 9 additions & 0 deletions src/commonMain/kotlin/com/xebia/functional/chains/Chain.kt
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
package com.xebia.functional.chains

import arrow.core.Either
import arrow.core.raise.Raise
import arrow.core.raise.either
import arrow.core.raise.ensure
import arrow.core.raise.ensureNotNull

interface Chain {

Expand Down Expand Up @@ -68,3 +70,10 @@ interface Chain {
ChainOutput.OnlyOutput -> outputs
}
}

fun Raise<Chain.InvalidInputs>.validateInput(inputs: Map<String, String>, inputKey: String): String =
ensureNotNull(inputs[inputKey]) {
Chain.InvalidInputs("The provided inputs: " +
inputs.keys.joinToString(", ") { "{$it}" } +
" do not match with chain's input: {$inputKey}")
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package com.xebia.functional.chains

import arrow.core.Either
import com.xebia.functional.Document
import com.xebia.functional.llm.openai.OpenAIClient
import com.xebia.functional.prompt.PromptTemplate

interface CombineDocsChain : Chain {
suspend fun combine(documents: List<Document>): Map<String, String>
}

@Suppress("LongParameterList")
suspend fun CombineDocsChain(
llm: OpenAIClient,
promptTemplate: PromptTemplate,
documents: List<Document>,
documentVariableName: String,
outputVariable: String,
chainOutput: Chain.ChainOutput = Chain.ChainOutput.OnlyOutput
): CombineDocsChain = object : CombineDocsChain {

private val inputKeys: Set<String> = promptTemplate.inputKeys.toSet() - setOf(documentVariableName)
private val outputKeys: Set<String> = setOf("answer")

override val config: Chain.Config = Chain.Config(inputKeys, outputKeys, chainOutput)

override suspend fun combine(documents: List<Document>): Map<String, String> {
val mergedDocs = documents.joinToString("\n") { it.content }
return mapOf(documentVariableName to mergedDocs)
}

override suspend fun call(inputs: Map<String, String>): Either<Chain.Error, Map<String, String>> {
val llmChain = LLMChain(
llm,
promptTemplate,
outputVariable = outputVariable,
chainOutput = chainOutput
)

val totalInputs = combine(documents) + inputs
return llmChain.run(totalInputs)
}
}
16 changes: 10 additions & 6 deletions src/commonMain/kotlin/com/xebia/functional/chains/LLMChain.kt
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,19 @@ import com.xebia.functional.prompt.PromptTemplate
suspend fun LLMChain(
llm: OpenAIClient,
promptTemplate: PromptTemplate,
llmModel: String,
user: String,
echo: Boolean,
n: Int,
temperature: Double,
llmModel: String = "text-davinci-003",
user: String = "testing",
echo: Boolean = false,
n: Int = 1,
temperature: Double = 0.0,
outputVariable: String,
chainOutput: Chain.ChainOutput = Chain.ChainOutput.OnlyOutput
): Chain = object : Chain {

override val config: Chain.Config = Chain.Config(promptTemplate.inputKeys.toSet(), setOf("answer"), chainOutput)
private val inputKeys: Set<String> = promptTemplate.inputKeys.toSet()
private val outputKeys: Set<String> = setOf(outputVariable)

override val config: Chain.Config = Chain.Config(inputKeys, outputKeys, chainOutput)

override suspend fun call(inputs: Map<String, String>): Either<Chain.InvalidInputs, Map<String, String>> =
either {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@ package com.xebia.functional.chains
interface SequenceChain : Chain {
data class InvalidOutputs(override val reason: String): Chain.Error(reason)
data class InvalidKeys(override val reason: String): Chain.Error(reason)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import arrow.core.Either
import arrow.core.raise.Raise
import arrow.core.raise.either
import arrow.core.raise.ensure
import arrow.core.raise.ensureNotNull
import arrow.core.raise.recover
import arrow.core.raise.zipOrAccumulate

Expand Down Expand Up @@ -66,10 +65,3 @@ private fun Raise<Chain.InvalidInputs>.validateInputKeys(inputKeys: Set<String>)
Chain.InvalidInputs("The expected inputs are more than one: " +
inputKeys.joinToString(", ") { "{$it}" })
}

private fun Raise<Chain.InvalidInputs>.validateInput(inputs: Map<String, String>, inputKey: String): String =
ensureNotNull(inputs[inputKey]) {
Chain.InvalidInputs("The provided inputs: " +
inputs.keys.joinToString(", ") { "{$it}" } +
" do not match with chain's input: {$inputKey}")
}
66 changes: 66 additions & 0 deletions src/commonMain/kotlin/com/xebia/functional/chains/VectorQAChain.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package com.xebia.functional.chains

import arrow.core.Either
import arrow.core.raise.Raise
import arrow.core.raise.either
import arrow.core.raise.recover
import com.xebia.functional.Document
import com.xebia.functional.llm.openai.OpenAIClient
import com.xebia.functional.prompt.PromptTemplate
import com.xebia.functional.vectorstores.VectorStore

interface VectorQAChain : Chain {
suspend fun getDocs(question: String): List<Document>

data class InvalidTemplate(override val reason: String) : Chain.Error(reason)
}

@Suppress("LongParameterList")
suspend fun VectorQAChain(
llm: OpenAIClient,
vectorStore: VectorStore,
numOfDocs: Int,
outputVariable: String,
chainOutput: Chain.ChainOutput = Chain.ChainOutput.OnlyOutput
): VectorQAChain = object : VectorQAChain {

private val documentVariableName: String = "context"
private val inputVariable: String = "question"

override val config: Chain.Config = Chain.Config(setOf(inputVariable), setOf(outputVariable), chainOutput)

override suspend fun getDocs(question: String): List<Document> =
vectorStore.similaritySearch(question, numOfDocs)

override suspend fun call(inputs: Map<String, String>): Either<Chain.Error, Map<String, String>> =
either {
val promptTemplate = promptTemplate()

val question = validateInput(inputs, inputVariable)
val documents = getDocs(question)

val chain = CombineDocsChain(
llm,
promptTemplate,
documents,
documentVariableName,
outputVariable,
chainOutput
)

chain.run(inputs).bind()
}

private fun Raise<VectorQAChain.InvalidTemplate>.promptTemplate(): PromptTemplate =
recover({
val template = """
|Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.
|
|{context}
|
|Question: {question}
|Helpful Answer:""".trimMargin()

PromptTemplate(template, listOf("context", "question"))
}) { raise(VectorQAChain.InvalidTemplate(it.reason)) }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
package com.xebia.functional.chains

import com.xebia.functional.llm.openai.ChatCompletionRequest
import com.xebia.functional.llm.openai.ChatCompletionResponse
import com.xebia.functional.llm.openai.CompletionChoice
import com.xebia.functional.llm.openai.CompletionRequest
import com.xebia.functional.llm.openai.EmbeddingRequest
import com.xebia.functional.llm.openai.EmbeddingResult
import com.xebia.functional.llm.openai.OpenAIClient

val testLLM = object : OpenAIClient {
override suspend fun createCompletion(request: CompletionRequest): List<CompletionChoice> =
when (request.prompt) {
"Tell me a joke." ->
listOf(CompletionChoice("I'm not good at jokes", 1, "foo"))

"My name is foo and I'm 28 years old" ->
listOf(CompletionChoice("Hello there! Nice to meet you foo", 1, "foo"))

testTemplateFormatted -> listOf(CompletionChoice("I don't know", 1, "foo"))
testTemplateInputsFormatted -> listOf(CompletionChoice("Two inputs, right?", 1, "foo"))
testQATemplateFormatted -> listOf(CompletionChoice("I don't know", 1, "foo"))
else -> listOf(CompletionChoice("foo", 1, "bar"))
}

override suspend fun createChatCompletion(request: ChatCompletionRequest): ChatCompletionResponse =
TODO()

override suspend fun createEmbeddings(request: EmbeddingRequest): EmbeddingResult =
TODO()
}

val testContext = """foo foo foo
|bar bar bar
|baz baz baz""".trimMargin()

val testContextOutput = mapOf("context" to testContext)

val testTemplate = """From the following context:
|
|{context}
|
|try to answer the following question: {question}""".trimMargin()

val testTemplateInputs = """From the following context:
|
|{context}
|
|I want to say: My name is {name} and I'm {age} years old""".trimMargin()

val testTemplateFormatted = """From the following context:
|
|foo foo foo
|bar bar bar
|baz baz baz
|
|try to answer the following question: What do you think?""".trimMargin()

val testTemplateInputsFormatted = """From the following context:
|
|foo foo foo
|bar bar bar
|baz baz baz
|
|I want to say: My name is Scala and I'm 28 years old""".trimMargin()

val testQATemplateFormatted = """
|Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.
|
|foo foo foo
|bar bar bar
|baz baz baz
|
|Question: What do you think?
|Helpful Answer:""".trimMargin()

val testOutputIDK = mapOf("answer" to "I don't know")
val testOutputInputs = mapOf("answer" to "Two inputs, right?")
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package com.xebia.functional.chains

import arrow.core.raise.either
import com.xebia.functional.Document
import com.xebia.functional.prompt.PromptTemplate
import io.kotest.assertions.arrow.core.shouldBeLeft
import io.kotest.assertions.arrow.core.shouldBeRight
import io.kotest.core.spec.style.StringSpec

class CombineDocsChainSpec : StringSpec({
val documentVariableName = "context"
val outputVariable = "answer"

"Combine should return all the documents properly combined" {
either {
val promptTemplate = PromptTemplate(testTemplate, listOf("context", "question"))
val docs = listOf(Document("foo foo foo"), Document("bar bar bar"), Document("baz baz baz"))
val chain = CombineDocsChain(testLLM, promptTemplate, docs, documentVariableName, outputVariable)
chain.combine(docs)
} shouldBeRight testContextOutput
}

"Run should return the proper LLMChain response with one input" {
either {
val promptTemplate = PromptTemplate(testTemplate, listOf("context", "question"))
val docs = listOf(Document("foo foo foo"), Document("bar bar bar"), Document("baz baz baz"))
val chain = CombineDocsChain(testLLM, promptTemplate, docs, documentVariableName, outputVariable)
chain.run("What do you think?").bind()
} shouldBeRight testOutputIDK
}

"Run should return the proper LLMChain response with more than one input" {
either {
val promptTemplate = PromptTemplate(testTemplateInputs, listOf("context", "name", "age"))
val docs = listOf(Document("foo foo foo"), Document("bar bar bar"), Document("baz baz baz"))
val chain = CombineDocsChain(
testLLM, promptTemplate, docs, documentVariableName, outputVariable, Chain.ChainOutput.InputAndOutput)
chain.run(mapOf("name" to "Scala", "age" to "28")).bind()
} shouldBeRight testOutputInputs + mapOf("context" to testContext, "name" to "Scala", "age" to "28")
}

"Run should fail with a InvalidCombineDocumentsChainError if the inputs don't match the expected" {
either {
val promptTemplate = PromptTemplate(testTemplateInputs, listOf("context", "name", "age"))
val docs = listOf(Document("foo foo foo"), Document("bar bar bar"), Document("baz baz baz"))
val chain = CombineDocsChain(
testLLM, promptTemplate, docs, documentVariableName, outputVariable, Chain.ChainOutput.InputAndOutput)
chain.run(mapOf("name" to "Scala", "city" to "Seattle")).bind()
} shouldBeLeft
Chain.InvalidInputs(
"The provided inputs: {name}, {city} do not match with chain's inputs: {name}, {age}")
}
})
38 changes: 13 additions & 25 deletions src/commonTest/kotlin/com/xebia/functional/chains/LLMChainSpec.kt
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package com.xebia.functional.chains

import arrow.core.raise.either
import com.xebia.functional.llm.openai.*
import com.xebia.functional.prompt.PromptTemplate
import io.kotest.assertions.arrow.core.shouldBeLeft
import io.kotest.assertions.arrow.core.shouldBeRight
Expand All @@ -11,8 +10,8 @@ class LLMChainSpec : StringSpec({
"LLMChain should return a prediction with just the output" {
val template = "Tell me {foo}."
either {
val prompt = PromptTemplate(template, listOf("foo"))
val chain = LLMChain(llm, prompt, "davinci", "testing", false, 1, 0.0)
val promptTemplate = PromptTemplate(template, listOf("foo"))
val chain = LLMChain(testLLM, promptTemplate, outputVariable = "answer")
chain.run("a joke").bind()
} shouldBeRight mapOf("answer" to "I'm not good at jokes")
}
Expand All @@ -21,7 +20,8 @@ class LLMChainSpec : StringSpec({
val template = "Tell me {foo}."
either {
val prompt = PromptTemplate(template, listOf("foo"))
val chain = LLMChain(llm, prompt, "davinci", "testing", false, 1, 0.0, Chain.ChainOutput.InputAndOutput)
val chain = LLMChain(testLLM, prompt, outputVariable = "answer",
chainOutput = Chain.ChainOutput.InputAndOutput)
chain.run("a joke").bind()
} shouldBeRight mapOf("foo" to "a joke", "answer" to "I'm not good at jokes")
}
Expand All @@ -30,7 +30,8 @@ class LLMChainSpec : StringSpec({
val template = "My name is {name} and I'm {age} years old"
either {
val prompt = PromptTemplate(template, listOf("name", "age"))
val chain = LLMChain(llm, prompt, "davinci", "testing", false, 1, 0.0, Chain.ChainOutput.InputAndOutput)
val chain = LLMChain(testLLM, prompt, outputVariable = "answer",
chainOutput = Chain.ChainOutput.InputAndOutput)
chain.run(mapOf("age" to "28", "name" to "foo")).bind()
} shouldBeRight mapOf("age" to "28", "name" to "foo", "answer" to "Hello there! Nice to meet you foo")
}
Expand All @@ -39,34 +40,21 @@ class LLMChainSpec : StringSpec({
val template = "My name is {name} and I'm {age} years old"
either {
val prompt = PromptTemplate(template, listOf("name", "age"))
val chain = LLMChain(llm, prompt, "davinci", "testing", false, 1, 0.0)
val chain = LLMChain(testLLM, prompt, outputVariable = "answer",
chainOutput = Chain.ChainOutput.InputAndOutput)
chain.run(mapOf("age" to "28", "brand" to "foo")).bind()
} shouldBeLeft Chain.InvalidInputs("The provided inputs: {age}, {brand} do not match with chain's inputs: {name}, {age}")
} shouldBeLeft
Chain.InvalidInputs(
"The provided inputs: {age}, {brand} do not match with chain's inputs: {name}, {age}")
}

"LLMChain should fail when using just one input but expecting more" {
val template = "My name is {name} and I'm {age} years old"
either {
val prompt = PromptTemplate(template, listOf("name", "age"))
val chain = LLMChain(llm, prompt, "davinci", "testing", false, 1, 0.0)
val chain = LLMChain(testLLM, prompt, outputVariable = "answer",
chainOutput = Chain.ChainOutput.InputAndOutput)
chain.run("foo").bind()
} shouldBeLeft Chain.InvalidInputs("The expected inputs are more than one: {name}, {age}")
}
})

val llm = object : OpenAIClient {
override suspend fun createCompletion(request: CompletionRequest): List<CompletionChoice> =
when(request.prompt) {
"Tell me a joke." ->
listOf(CompletionChoice("I'm not good at jokes", 1, "foo"))
"My name is foo and I'm 28 years old" ->
listOf(CompletionChoice("Hello there! Nice to meet you foo", 1, "foo"))
else -> listOf(CompletionChoice("foo", 1, "bar"))
}

override suspend fun createChatCompletion(request: ChatCompletionRequest): ChatCompletionResponse {
TODO("Not yet implemented")
}

override suspend fun createEmbeddings(request: EmbeddingRequest): EmbeddingResult = TODO()
}
Loading