-
Notifications
You must be signed in to change notification settings - Fork 15
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* feat: create BaseChain and Config model * feat: add first implementation of LLMChain * test: add ConfigSpec and adjust Config * test: add LLMChainSpec and adjust LLMChain * refactor: apply suggestion on LLMChain Co-authored-by: Simon Vergauwen <[email protected]> * refactor: rename BaseChain to Chain and change interface structure
- Loading branch information
1 parent
6ee648a
commit e8a4eca
Showing
4 changed files
with
230 additions
and
0 deletions.
There are no files selected for viewing
62 changes: 62 additions & 0 deletions
62
src/commonMain/kotlin/com/xebia/functional/chains/Chain.kt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
package com.xebia.functional.chains | ||
|
||
import arrow.core.Either | ||
import arrow.core.raise.either | ||
import arrow.core.raise.ensure | ||
|
||
interface Chain { | ||
data class InvalidInputs(val reason: String) | ||
|
||
data class Config( | ||
val inputKeys: Set<String>, | ||
val outputKeys: Set<String>, | ||
val onlyOutputs: Boolean | ||
) { | ||
fun createInputs( | ||
inputs: String | ||
): Either<InvalidInputs, Map<String, String>> = | ||
either { | ||
ensure(inputKeys.size == 1) { | ||
InvalidInputs("The expected inputs are more than one: " + | ||
inputKeys.joinToString(", ") { "{$it}" }) | ||
} | ||
inputKeys.associateWith { inputs } | ||
} | ||
|
||
fun createInputs( | ||
inputs: Map<String, String> | ||
): Either<InvalidInputs, Map<String, String>> = | ||
either { | ||
ensure((inputKeys subtract inputs.keys).isEmpty()) { | ||
InvalidInputs("The provided inputs: " + | ||
inputs.keys.joinToString(", ") { "{$it}" } + | ||
" do not match with chain's inputs: " + | ||
inputKeys.joinToString(", ") { "{$it}" }) | ||
} | ||
inputs | ||
} | ||
} | ||
|
||
val config: Config | ||
|
||
suspend fun call(inputs: Map<String, String>): Map<String, String> | ||
|
||
suspend fun run(input: String): Either<InvalidInputs, Map<String, String>> = | ||
either { | ||
val preparedInputs = config.createInputs(input).bind() | ||
val result = call(preparedInputs) | ||
prepareOutputs(preparedInputs, result) | ||
} | ||
|
||
suspend fun run(inputs: Map<String, String>): Either<InvalidInputs, Map<String, String>> = | ||
either { | ||
val preparedInputs = config.createInputs(inputs).bind() | ||
val result = call(preparedInputs) | ||
prepareOutputs(preparedInputs, result) | ||
} | ||
|
||
private fun prepareOutputs( | ||
inputs: Map<String, String>, outputs: Map<String, String> | ||
): Map<String, String> = | ||
if (config.onlyOutputs) outputs else inputs + outputs | ||
} |
42 changes: 42 additions & 0 deletions
42
src/commonMain/kotlin/com/xebia/functional/chains/LLMChain.kt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
package com.xebia.functional.chains | ||
|
||
import com.xebia.functional.llm.openai.CompletionChoice | ||
import com.xebia.functional.llm.openai.CompletionRequest | ||
import com.xebia.functional.llm.openai.OpenAIClient | ||
import com.xebia.functional.prompt.PromptTemplate | ||
|
||
@Suppress("LongParameterList") | ||
suspend fun LLMChain( | ||
llm: OpenAIClient, | ||
promptTemplate: PromptTemplate, | ||
llmModel: String, | ||
user: String, | ||
echo: Boolean, | ||
n: Int, | ||
temperature: Double, | ||
onlyOutputs: Boolean | ||
): Chain = object : Chain { | ||
|
||
override val config: Chain.Config = Chain.Config(promptTemplate.inputKeys.toSet(), setOf("answer"), onlyOutputs) | ||
|
||
override suspend fun call(inputs: Map<String, String>): Map<String, String> { | ||
val prompt = promptTemplate.format(inputs) | ||
|
||
val request = CompletionRequest( | ||
model = llmModel, | ||
user = user, | ||
prompt = prompt, | ||
echo = echo, | ||
n = n, | ||
temperature = temperature, | ||
) | ||
|
||
val completions = llm.createCompletion(request) | ||
return formatOutput(completions) | ||
} | ||
|
||
private fun formatOutput(completions: List<CompletionChoice>): Map<String, String> = | ||
config.outputKeys.associateWith { | ||
completions.joinToString(", ") { it.text } | ||
} | ||
} |
44 changes: 44 additions & 0 deletions
44
src/commonTest/kotlin/com/xebia/functional/chains/ConfigSpec.kt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
package com.xebia.functional.chains | ||
|
||
import arrow.core.Either | ||
import io.kotest.core.spec.style.StringSpec | ||
import io.kotest.matchers.shouldBe | ||
|
||
class ConfigSpec : StringSpec({ | ||
|
||
"Chain Config should return the inputs properly" { | ||
val config = Chain.Config(setOf("name", "age"), setOf("text"), false) | ||
val result = config.createInputs(mapOf("name" to "foo", "age" to "bar")) | ||
result shouldBe Either.Right(mapOf("name" to "foo", "age" to "bar")) | ||
} | ||
|
||
"Chain Config should return the input as a Map" { | ||
val config = Chain.Config(setOf("input"), setOf("text"), false) | ||
val result = config.createInputs("foo") | ||
result shouldBe Either.Right(mapOf("input" to "foo")) | ||
} | ||
|
||
"Chain Config should fail when inputs set doesn't contain all inputKeys" { | ||
val config = Chain.Config(setOf("name", "age"), setOf("text"), false) | ||
val result = config.createInputs(mapOf("name" to "foo")) | ||
result shouldBe Either.Left( | ||
Chain.InvalidInputs("The provided inputs: {name} do not match with chain's inputs: {name}, {age}") | ||
) | ||
} | ||
|
||
"Chain Config should fail when inputs set has different inputKeys" { | ||
val config = Chain.Config(setOf("name", "age"), setOf("text"), false) | ||
val result = config.createInputs(mapOf("name" to "foo", "city" to "NY")) | ||
result shouldBe Either.Left( | ||
Chain.InvalidInputs("The provided inputs: {name}, {city} do not match with chain's inputs: {name}, {age}") | ||
) | ||
} | ||
|
||
"Chain Config should fail when input is just one and expects more" { | ||
val config = Chain.Config(setOf("name", "age"), setOf("text"), false) | ||
val result = config.createInputs("foo") | ||
result shouldBe Either.Left( | ||
Chain.InvalidInputs("The expected inputs are more than one: {name}, {age}") | ||
) | ||
} | ||
}) |
82 changes: 82 additions & 0 deletions
82
src/commonTest/kotlin/com/xebia/functional/chains/LLMChainSpec.kt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
package com.xebia.functional.chains | ||
|
||
import arrow.core.Either | ||
import arrow.core.raise.either | ||
import com.xebia.functional.llm.openai.CompletionChoice | ||
import com.xebia.functional.llm.openai.CompletionRequest | ||
import com.xebia.functional.llm.openai.EmbeddingRequest | ||
import com.xebia.functional.llm.openai.EmbeddingResult | ||
import com.xebia.functional.llm.openai.OpenAIClient | ||
import com.xebia.functional.prompt.PromptTemplate | ||
import io.kotest.core.spec.style.StringSpec | ||
import io.kotest.matchers.shouldBe | ||
|
||
class LLMChainSpec : StringSpec({ | ||
"LLMChain should return a prediction with just the output" { | ||
val template = "Tell me {foo}." | ||
either { | ||
val prompt = PromptTemplate(template, listOf("foo")) | ||
val chain = LLMChain(llm, prompt, "davinci", "testing", false, 1, 0.0, true) | ||
chain.run("a joke").bind() | ||
} shouldBe Either.Right( | ||
mapOf("answer" to "I'm not good at jokes") | ||
) | ||
} | ||
|
||
"LLMChain should return a prediction with both output and inputs" { | ||
val template = "Tell me {foo}." | ||
either { | ||
val prompt = PromptTemplate(template, listOf("foo")) | ||
val chain = LLMChain(llm, prompt, "davinci", "testing", false, 1, 0.0, false) | ||
chain.run("a joke").bind() | ||
} shouldBe Either.Right( | ||
mapOf("foo" to "a joke", "answer" to "I'm not good at jokes") | ||
) | ||
} | ||
|
||
"LLMChain should return a prediction with a more complex template" { | ||
val template = "My name is {name} and I'm {age} years old" | ||
either { | ||
val prompt = PromptTemplate(template, listOf("name", "age")) | ||
val chain = LLMChain(llm, prompt, "davinci", "testing", false, 1, 0.0, false) | ||
chain.run(mapOf("age" to "28", "name" to "foo")).bind() | ||
} shouldBe Either.Right( | ||
mapOf("age" to "28", "name" to "foo", "answer" to "Hello there! Nice to meet you foo") | ||
) | ||
} | ||
|
||
"LLMChain should fail when inputs are not the expected ones from the PromptTemplate" { | ||
val template = "My name is {name} and I'm {age} years old" | ||
either { | ||
val prompt = PromptTemplate(template, listOf("name", "age")) | ||
val chain = LLMChain(llm, prompt, "davinci", "testing", false, 1, 0.0, false) | ||
chain.run(mapOf("age" to "28", "brand" to "foo")).bind() | ||
} shouldBe Either.Left( | ||
Chain.InvalidInputs("The provided inputs: {age}, {brand} do not match with chain's inputs: {name}, {age}") | ||
) | ||
} | ||
|
||
"LLMChain should fail when using just one input but expecting more" { | ||
val template = "My name is {name} and I'm {age} years old" | ||
either { | ||
val prompt = PromptTemplate(template, listOf("name", "age")) | ||
val chain = LLMChain(llm, prompt, "davinci", "testing", false, 1, 0.0, false) | ||
chain.run("foo").bind() | ||
} shouldBe Either.Left( | ||
Chain.InvalidInputs("The expected inputs are more than one: {name}, {age}") | ||
) | ||
} | ||
}) | ||
|
||
val llm = object : OpenAIClient { | ||
override suspend fun createCompletion(request: CompletionRequest): List<CompletionChoice> = | ||
when(request.prompt) { | ||
"Tell me a joke." -> | ||
listOf(CompletionChoice("I'm not good at jokes", 1, "foo")) | ||
"My name is foo and I'm 28 years old" -> | ||
listOf(CompletionChoice("Hello there! Nice to meet you foo", 1, "foo")) | ||
else -> listOf(CompletionChoice("foo", 1, "bar")) | ||
} | ||
|
||
override suspend fun createEmbeddings(request: EmbeddingRequest): EmbeddingResult = TODO() | ||
} |