Skip to content

Commit

Permalink
Chain + LLMChain (#9)
Browse files Browse the repository at this point in the history
* feat: create BaseChain and Config model
* feat: add first implementation of LLMChain
* test: add ConfigSpec and adjust Config
* test: add LLMChainSpec and adjust LLMChain
* refactor: apply suggestion on LLMChain

Co-authored-by: Simon Vergauwen <[email protected]>
* refactor: rename BaseChain to Chain and change interface structure
  • Loading branch information
realdavidvega authored Apr 25, 2023
1 parent 6ee648a commit e8a4eca
Show file tree
Hide file tree
Showing 4 changed files with 230 additions and 0 deletions.
62 changes: 62 additions & 0 deletions src/commonMain/kotlin/com/xebia/functional/chains/Chain.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
package com.xebia.functional.chains

import arrow.core.Either
import arrow.core.raise.either
import arrow.core.raise.ensure

interface Chain {
data class InvalidInputs(val reason: String)

data class Config(
val inputKeys: Set<String>,
val outputKeys: Set<String>,
val onlyOutputs: Boolean
) {
fun createInputs(
inputs: String
): Either<InvalidInputs, Map<String, String>> =
either {
ensure(inputKeys.size == 1) {
InvalidInputs("The expected inputs are more than one: " +
inputKeys.joinToString(", ") { "{$it}" })
}
inputKeys.associateWith { inputs }
}

fun createInputs(
inputs: Map<String, String>
): Either<InvalidInputs, Map<String, String>> =
either {
ensure((inputKeys subtract inputs.keys).isEmpty()) {
InvalidInputs("The provided inputs: " +
inputs.keys.joinToString(", ") { "{$it}" } +
" do not match with chain's inputs: " +
inputKeys.joinToString(", ") { "{$it}" })
}
inputs
}
}

val config: Config

suspend fun call(inputs: Map<String, String>): Map<String, String>

suspend fun run(input: String): Either<InvalidInputs, Map<String, String>> =
either {
val preparedInputs = config.createInputs(input).bind()
val result = call(preparedInputs)
prepareOutputs(preparedInputs, result)
}

suspend fun run(inputs: Map<String, String>): Either<InvalidInputs, Map<String, String>> =
either {
val preparedInputs = config.createInputs(inputs).bind()
val result = call(preparedInputs)
prepareOutputs(preparedInputs, result)
}

private fun prepareOutputs(
inputs: Map<String, String>, outputs: Map<String, String>
): Map<String, String> =
if (config.onlyOutputs) outputs else inputs + outputs
}
42 changes: 42 additions & 0 deletions src/commonMain/kotlin/com/xebia/functional/chains/LLMChain.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package com.xebia.functional.chains

import com.xebia.functional.llm.openai.CompletionChoice
import com.xebia.functional.llm.openai.CompletionRequest
import com.xebia.functional.llm.openai.OpenAIClient
import com.xebia.functional.prompt.PromptTemplate

@Suppress("LongParameterList")
suspend fun LLMChain(
llm: OpenAIClient,
promptTemplate: PromptTemplate,
llmModel: String,
user: String,
echo: Boolean,
n: Int,
temperature: Double,
onlyOutputs: Boolean
): Chain = object : Chain {

override val config: Chain.Config = Chain.Config(promptTemplate.inputKeys.toSet(), setOf("answer"), onlyOutputs)

override suspend fun call(inputs: Map<String, String>): Map<String, String> {
val prompt = promptTemplate.format(inputs)

val request = CompletionRequest(
model = llmModel,
user = user,
prompt = prompt,
echo = echo,
n = n,
temperature = temperature,
)

val completions = llm.createCompletion(request)
return formatOutput(completions)
}

private fun formatOutput(completions: List<CompletionChoice>): Map<String, String> =
config.outputKeys.associateWith {
completions.joinToString(", ") { it.text }
}
}
44 changes: 44 additions & 0 deletions src/commonTest/kotlin/com/xebia/functional/chains/ConfigSpec.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package com.xebia.functional.chains

import arrow.core.Either
import io.kotest.core.spec.style.StringSpec
import io.kotest.matchers.shouldBe

class ConfigSpec : StringSpec({

"Chain Config should return the inputs properly" {
val config = Chain.Config(setOf("name", "age"), setOf("text"), false)
val result = config.createInputs(mapOf("name" to "foo", "age" to "bar"))
result shouldBe Either.Right(mapOf("name" to "foo", "age" to "bar"))
}

"Chain Config should return the input as a Map" {
val config = Chain.Config(setOf("input"), setOf("text"), false)
val result = config.createInputs("foo")
result shouldBe Either.Right(mapOf("input" to "foo"))
}

"Chain Config should fail when inputs set doesn't contain all inputKeys" {
val config = Chain.Config(setOf("name", "age"), setOf("text"), false)
val result = config.createInputs(mapOf("name" to "foo"))
result shouldBe Either.Left(
Chain.InvalidInputs("The provided inputs: {name} do not match with chain's inputs: {name}, {age}")
)
}

"Chain Config should fail when inputs set has different inputKeys" {
val config = Chain.Config(setOf("name", "age"), setOf("text"), false)
val result = config.createInputs(mapOf("name" to "foo", "city" to "NY"))
result shouldBe Either.Left(
Chain.InvalidInputs("The provided inputs: {name}, {city} do not match with chain's inputs: {name}, {age}")
)
}

"Chain Config should fail when input is just one and expects more" {
val config = Chain.Config(setOf("name", "age"), setOf("text"), false)
val result = config.createInputs("foo")
result shouldBe Either.Left(
Chain.InvalidInputs("The expected inputs are more than one: {name}, {age}")
)
}
})
82 changes: 82 additions & 0 deletions src/commonTest/kotlin/com/xebia/functional/chains/LLMChainSpec.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
package com.xebia.functional.chains

import arrow.core.Either
import arrow.core.raise.either
import com.xebia.functional.llm.openai.CompletionChoice
import com.xebia.functional.llm.openai.CompletionRequest
import com.xebia.functional.llm.openai.EmbeddingRequest
import com.xebia.functional.llm.openai.EmbeddingResult
import com.xebia.functional.llm.openai.OpenAIClient
import com.xebia.functional.prompt.PromptTemplate
import io.kotest.core.spec.style.StringSpec
import io.kotest.matchers.shouldBe

class LLMChainSpec : StringSpec({
"LLMChain should return a prediction with just the output" {
val template = "Tell me {foo}."
either {
val prompt = PromptTemplate(template, listOf("foo"))
val chain = LLMChain(llm, prompt, "davinci", "testing", false, 1, 0.0, true)
chain.run("a joke").bind()
} shouldBe Either.Right(
mapOf("answer" to "I'm not good at jokes")
)
}

"LLMChain should return a prediction with both output and inputs" {
val template = "Tell me {foo}."
either {
val prompt = PromptTemplate(template, listOf("foo"))
val chain = LLMChain(llm, prompt, "davinci", "testing", false, 1, 0.0, false)
chain.run("a joke").bind()
} shouldBe Either.Right(
mapOf("foo" to "a joke", "answer" to "I'm not good at jokes")
)
}

"LLMChain should return a prediction with a more complex template" {
val template = "My name is {name} and I'm {age} years old"
either {
val prompt = PromptTemplate(template, listOf("name", "age"))
val chain = LLMChain(llm, prompt, "davinci", "testing", false, 1, 0.0, false)
chain.run(mapOf("age" to "28", "name" to "foo")).bind()
} shouldBe Either.Right(
mapOf("age" to "28", "name" to "foo", "answer" to "Hello there! Nice to meet you foo")
)
}

"LLMChain should fail when inputs are not the expected ones from the PromptTemplate" {
val template = "My name is {name} and I'm {age} years old"
either {
val prompt = PromptTemplate(template, listOf("name", "age"))
val chain = LLMChain(llm, prompt, "davinci", "testing", false, 1, 0.0, false)
chain.run(mapOf("age" to "28", "brand" to "foo")).bind()
} shouldBe Either.Left(
Chain.InvalidInputs("The provided inputs: {age}, {brand} do not match with chain's inputs: {name}, {age}")
)
}

"LLMChain should fail when using just one input but expecting more" {
val template = "My name is {name} and I'm {age} years old"
either {
val prompt = PromptTemplate(template, listOf("name", "age"))
val chain = LLMChain(llm, prompt, "davinci", "testing", false, 1, 0.0, false)
chain.run("foo").bind()
} shouldBe Either.Left(
Chain.InvalidInputs("The expected inputs are more than one: {name}, {age}")
)
}
})

val llm = object : OpenAIClient {
override suspend fun createCompletion(request: CompletionRequest): List<CompletionChoice> =
when(request.prompt) {
"Tell me a joke." ->
listOf(CompletionChoice("I'm not good at jokes", 1, "foo"))
"My name is foo and I'm 28 years old" ->
listOf(CompletionChoice("Hello there! Nice to meet you foo", 1, "foo"))
else -> listOf(CompletionChoice("foo", 1, "bar"))
}

override suspend fun createEmbeddings(request: EmbeddingRequest): EmbeddingResult = TODO()
}

0 comments on commit e8a4eca

Please sign in to comment.