Skip to content

Commit c0c170c

Browse files
Port Simple Sequence Chain + Unit tests (#12)
1 parent ce16f04 commit c0c170c

File tree

9 files changed

+210
-55
lines changed

9 files changed

+210
-55
lines changed

build.gradle.kts

+2-3
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,9 @@ kotlin {
5151
sourceSets {
5252
commonMain {
5353
dependencies {
54-
implementation(libs.arrow.fx)
55-
implementation(libs.arrow.resilience)
56-
implementation(libs.kotlinx.serialization.json)
54+
implementation(libs.bundles.arrow)
5755
implementation(libs.bundles.ktor.client)
56+
implementation(libs.kotlinx.serialization.json)
5857
implementation(libs.okio)
5958
implementation(libs.uuid)
6059
implementation(libs.klogging)

gradle/libs.versions.toml

+7-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@ hikari = "5.0.1"
1818
dokka = "1.8.10"
1919

2020
[libraries]
21-
arrow-fx = { module = "io.arrow-kt:arrow-fx-coroutines", version.ref = "arrow" }
21+
arrow-core = { module = "io.arrow-kt:arrow-core", version.ref = "arrow" }
22+
arrow-fx-coroutines = { module = "io.arrow-kt:arrow-fx-coroutines", version.ref = "arrow" }
2223
arrow-resilience = { module = "io.arrow-kt:arrow-resilience", version.ref = "arrow" }
2324
open-ai = { module = "com.theokanning.openai-gpt3-java:service", version.ref = "openai" }
2425
kotlinx-serialization-json = { module = "org.jetbrains.kotlinx:kotlinx-serialization-json", version.ref = "kotlinx-json" }
@@ -40,6 +41,11 @@ postgresql = { module = "org.postgresql:postgresql", version.ref = "postgresql"
4041
testcontainers-postgresql = { module = "org.testcontainers:postgresql", version.ref = "testcontainers" }
4142

4243
[bundles]
44+
arrow = [
45+
"arrow-core",
46+
"arrow-fx-coroutines",
47+
"arrow-resilience"
48+
]
4349
ktor-client = [
4450
"ktor-client",
4551
"ktor-client-content-negotiation",

src/commonMain/kotlin/com/xebia/functional/chains/Chain.kt

+17-9
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,17 @@ import arrow.core.raise.either
55
import arrow.core.raise.ensure
66

77
interface Chain {
8-
data class InvalidInputs(val reason: String)
8+
9+
enum class ChainOutput { InputAndOutput, OnlyOutput }
10+
11+
sealed class Error(open val reason: String)
12+
13+
data class InvalidInputs(override val reason: String): Error(reason)
914

1015
data class Config(
1116
val inputKeys: Set<String>,
1217
val outputKeys: Set<String>,
13-
val onlyOutputs: Boolean
18+
val chainOutput: ChainOutput = ChainOutput.OnlyOutput
1419
) {
1520
fun createInputs(
1621
inputs: String
@@ -28,7 +33,7 @@ interface Chain {
2833
): Either<InvalidInputs, Map<String, String>> =
2934
either {
3035
ensure((inputKeys subtract inputs.keys).isEmpty()) {
31-
InvalidInputs("The provided inputs: " +
36+
InvalidInputs("The provided inputs: " +
3237
inputs.keys.joinToString(", ") { "{$it}" } +
3338
" do not match with chain's inputs: " +
3439
inputKeys.joinToString(", ") { "{$it}" })
@@ -39,24 +44,27 @@ interface Chain {
3944

4045
val config: Config
4146

42-
suspend fun call(inputs: Map<String, String>): Map<String, String>
47+
suspend fun call(inputs: Map<String, String>): Either<Error, Map<String, String>>
4348

44-
suspend fun run(input: String): Either<InvalidInputs, Map<String, String>> =
49+
suspend fun run(input: String): Either<Error, Map<String, String>> =
4550
either {
4651
val preparedInputs = config.createInputs(input).bind()
47-
val result = call(preparedInputs)
52+
val result = call(preparedInputs).bind()
4853
prepareOutputs(preparedInputs, result)
4954
}
5055

51-
suspend fun run(inputs: Map<String, String>): Either<InvalidInputs, Map<String, String>> =
56+
suspend fun run(inputs: Map<String, String>): Either<Error, Map<String, String>> =
5257
either {
5358
val preparedInputs = config.createInputs(inputs).bind()
54-
val result = call(preparedInputs)
59+
val result = call(preparedInputs).bind()
5560
prepareOutputs(preparedInputs, result)
5661
}
5762

5863
private fun prepareOutputs(
5964
inputs: Map<String, String>, outputs: Map<String, String>
6065
): Map<String, String> =
61-
if (config.onlyOutputs) outputs else inputs + outputs
66+
when (config.chainOutput) {
67+
ChainOutput.InputAndOutput -> inputs + outputs
68+
ChainOutput.OnlyOutput -> outputs
69+
}
6270
}

src/commonMain/kotlin/com/xebia/functional/chains/LLMChain.kt

+18-15
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
package com.xebia.functional.chains
22

3+
import arrow.core.Either
4+
import arrow.core.raise.either
35
import com.xebia.functional.llm.openai.CompletionChoice
46
import com.xebia.functional.llm.openai.CompletionRequest
57
import com.xebia.functional.llm.openai.OpenAIClient
@@ -14,26 +16,27 @@ suspend fun LLMChain(
1416
echo: Boolean,
1517
n: Int,
1618
temperature: Double,
17-
onlyOutputs: Boolean
19+
chainOutput: Chain.ChainOutput = Chain.ChainOutput.OnlyOutput
1820
): Chain = object : Chain {
1921

20-
override val config: Chain.Config = Chain.Config(promptTemplate.inputKeys.toSet(), setOf("answer"), onlyOutputs)
22+
override val config: Chain.Config = Chain.Config(promptTemplate.inputKeys.toSet(), setOf("answer"), chainOutput)
2123

22-
override suspend fun call(inputs: Map<String, String>): Map<String, String> {
23-
val prompt = promptTemplate.format(inputs)
24+
override suspend fun call(inputs: Map<String, String>): Either<Chain.InvalidInputs, Map<String, String>> =
25+
either {
26+
val prompt = promptTemplate.format(inputs)
2427

25-
val request = CompletionRequest(
26-
model = llmModel,
27-
user = user,
28-
prompt = prompt,
29-
echo = echo,
30-
n = n,
31-
temperature = temperature,
32-
)
28+
val request = CompletionRequest(
29+
model = llmModel,
30+
user = user,
31+
prompt = prompt,
32+
echo = echo,
33+
n = n,
34+
temperature = temperature,
35+
)
3336

34-
val completions = llm.createCompletion(request)
35-
return formatOutput(completions)
36-
}
37+
val completions = llm.createCompletion(request)
38+
formatOutput(completions)
39+
}
3740

3841
private fun formatOutput(completions: List<CompletionChoice>): Map<String, String> =
3942
config.outputKeys.associateWith {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
package com.xebia.functional.chains
2+
3+
interface SequenceChain : Chain {
4+
data class InvalidOutputs(override val reason: String): Chain.Error(reason)
5+
data class InvalidKeys(override val reason: String): Chain.Error(reason)
6+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
package com.xebia.functional.chains
2+
3+
import arrow.core.Either
4+
import arrow.core.raise.Raise
5+
import arrow.core.raise.either
6+
import arrow.core.raise.ensure
7+
import arrow.core.raise.ensureNotNull
8+
import arrow.core.raise.recover
9+
import arrow.core.raise.zipOrAccumulate
10+
11+
fun Raise<Chain.Error>.SimpleSequenceChain(
12+
chains: List<Chain>,
13+
inputKey: String = "input",
14+
outputKey: String = "output",
15+
chainOutput: Chain.ChainOutput = Chain.ChainOutput.OnlyOutput
16+
): SimpleSequenceChain =
17+
SimpleSequenceChain.either(chains, inputKey, outputKey, chainOutput).bind()
18+
19+
class SimpleSequenceChain private constructor(
20+
private val chains: List<Chain>,
21+
private val inputKey: String,
22+
private val outputKey: String,
23+
chainOutput: Chain.ChainOutput
24+
) : SequenceChain {
25+
26+
override val config = Chain.Config(setOf(inputKey), setOf(outputKey), chainOutput)
27+
28+
override suspend fun call(inputs: Map<String, String>): Either<Chain.Error, Map<String, String>> =
29+
either {
30+
val input = validateInput(inputs, inputKey)
31+
val firstRes = chains.first().run(input).bind()
32+
val chainRes = chains.drop(1).fold(firstRes) { acc, chain ->
33+
chain.run(acc).bind()
34+
}.values.first()
35+
mapOf(outputKey to chainRes)
36+
}
37+
38+
companion object {
39+
fun either(
40+
chains: List<Chain>,
41+
inputKey: String,
42+
outputKey: String,
43+
chainOutput: Chain.ChainOutput
44+
): Either<SequenceChain.InvalidKeys, SimpleSequenceChain> =
45+
either {
46+
val mappedChains: List<Chain> = chains.map { chain ->
47+
recover({
48+
zipOrAccumulate(
49+
{ validateInputKeys(chain.config.inputKeys) },
50+
{ validateOutputKeys(chain.config.outputKeys) }) { _, _ -> chain }
51+
}) { raise(SequenceChain.InvalidKeys(it.joinToString(transform = Chain.Error::reason))) }
52+
}
53+
SimpleSequenceChain(mappedChains, inputKey, outputKey, chainOutput)
54+
}
55+
}
56+
}
57+
58+
private fun Raise<SequenceChain.InvalidOutputs>.validateOutputKeys(outputKeys: Set<String>): Unit =
59+
ensure(outputKeys.size == 1) {
60+
SequenceChain.InvalidOutputs("The expected outputs are more than one: " +
61+
outputKeys.joinToString(", ") { "{$it}" })
62+
}
63+
64+
private fun Raise<Chain.InvalidInputs>.validateInputKeys(inputKeys: Set<String>): Unit =
65+
ensure(inputKeys.size == 1) {
66+
Chain.InvalidInputs("The expected inputs are more than one: " +
67+
inputKeys.joinToString(", ") { "{$it}" })
68+
}
69+
70+
private fun Raise<Chain.InvalidInputs>.validateInput(inputs: Map<String, String>, inputKey: String): String =
71+
ensureNotNull(inputs[inputKey]) {
72+
Chain.InvalidInputs("The provided inputs: " +
73+
inputs.keys.joinToString(", ") { "{$it}" } +
74+
" do not match with chain's input: {$inputKey}")
75+
}

src/commonTest/kotlin/com/xebia/functional/chains/ConfigSpec.kt

+5-5
Original file line numberDiff line numberDiff line change
@@ -7,35 +7,35 @@ import io.kotest.matchers.shouldBe
77
class ConfigSpec : StringSpec({
88

99
"Chain Config should return the inputs properly" {
10-
val config = Chain.Config(setOf("name", "age"), setOf("text"), false)
10+
val config = Chain.Config(setOf("name", "age"), setOf("text"))
1111
val result = config.createInputs(mapOf("name" to "foo", "age" to "bar"))
1212
result shouldBe Either.Right(mapOf("name" to "foo", "age" to "bar"))
1313
}
1414

1515
"Chain Config should return the input as a Map" {
16-
val config = Chain.Config(setOf("input"), setOf("text"), false)
16+
val config = Chain.Config(setOf("input"), setOf("text"))
1717
val result = config.createInputs("foo")
1818
result shouldBe Either.Right(mapOf("input" to "foo"))
1919
}
2020

2121
"Chain Config should fail when inputs set doesn't contain all inputKeys" {
22-
val config = Chain.Config(setOf("name", "age"), setOf("text"), false)
22+
val config = Chain.Config(setOf("name", "age"), setOf("text"))
2323
val result = config.createInputs(mapOf("name" to "foo"))
2424
result shouldBe Either.Left(
2525
Chain.InvalidInputs("The provided inputs: {name} do not match with chain's inputs: {name}, {age}")
2626
)
2727
}
2828

2929
"Chain Config should fail when inputs set has different inputKeys" {
30-
val config = Chain.Config(setOf("name", "age"), setOf("text"), false)
30+
val config = Chain.Config(setOf("name", "age"), setOf("text"))
3131
val result = config.createInputs(mapOf("name" to "foo", "city" to "NY"))
3232
result shouldBe Either.Left(
3333
Chain.InvalidInputs("The provided inputs: {name}, {city} do not match with chain's inputs: {name}, {age}")
3434
)
3535
}
3636

3737
"Chain Config should fail when input is just one and expects more" {
38-
val config = Chain.Config(setOf("name", "age"), setOf("text"), false)
38+
val config = Chain.Config(setOf("name", "age"), setOf("text"))
3939
val result = config.createInputs("foo")
4040
result shouldBe Either.Left(
4141
Chain.InvalidInputs("The expected inputs are more than one: {name}, {age}")

src/commonTest/kotlin/com/xebia/functional/chains/LLMChainSpec.kt

+12-22
Original file line numberDiff line numberDiff line change
@@ -1,70 +1,60 @@
11
package com.xebia.functional.chains
22

3-
import arrow.core.Either
43
import arrow.core.raise.either
54
import com.xebia.functional.llm.openai.CompletionChoice
65
import com.xebia.functional.llm.openai.CompletionRequest
76
import com.xebia.functional.llm.openai.EmbeddingRequest
87
import com.xebia.functional.llm.openai.EmbeddingResult
98
import com.xebia.functional.llm.openai.OpenAIClient
109
import com.xebia.functional.prompt.PromptTemplate
10+
import io.kotest.assertions.arrow.core.shouldBeLeft
11+
import io.kotest.assertions.arrow.core.shouldBeRight
1112
import io.kotest.core.spec.style.StringSpec
12-
import io.kotest.matchers.shouldBe
1313

1414
class LLMChainSpec : StringSpec({
1515
"LLMChain should return a prediction with just the output" {
1616
val template = "Tell me {foo}."
1717
either {
1818
val prompt = PromptTemplate(template, listOf("foo"))
19-
val chain = LLMChain(llm, prompt, "davinci", "testing", false, 1, 0.0, true)
19+
val chain = LLMChain(llm, prompt, "davinci", "testing", false, 1, 0.0)
2020
chain.run("a joke").bind()
21-
} shouldBe Either.Right(
22-
mapOf("answer" to "I'm not good at jokes")
23-
)
21+
} shouldBeRight mapOf("answer" to "I'm not good at jokes")
2422
}
2523

2624
"LLMChain should return a prediction with both output and inputs" {
2725
val template = "Tell me {foo}."
2826
either {
2927
val prompt = PromptTemplate(template, listOf("foo"))
30-
val chain = LLMChain(llm, prompt, "davinci", "testing", false, 1, 0.0, false)
28+
val chain = LLMChain(llm, prompt, "davinci", "testing", false, 1, 0.0, Chain.ChainOutput.InputAndOutput)
3129
chain.run("a joke").bind()
32-
} shouldBe Either.Right(
33-
mapOf("foo" to "a joke", "answer" to "I'm not good at jokes")
34-
)
30+
} shouldBeRight mapOf("foo" to "a joke", "answer" to "I'm not good at jokes")
3531
}
3632

3733
"LLMChain should return a prediction with a more complex template" {
3834
val template = "My name is {name} and I'm {age} years old"
3935
either {
4036
val prompt = PromptTemplate(template, listOf("name", "age"))
41-
val chain = LLMChain(llm, prompt, "davinci", "testing", false, 1, 0.0, false)
37+
val chain = LLMChain(llm, prompt, "davinci", "testing", false, 1, 0.0, Chain.ChainOutput.InputAndOutput)
4238
chain.run(mapOf("age" to "28", "name" to "foo")).bind()
43-
} shouldBe Either.Right(
44-
mapOf("age" to "28", "name" to "foo", "answer" to "Hello there! Nice to meet you foo")
45-
)
39+
} shouldBeRight mapOf("age" to "28", "name" to "foo", "answer" to "Hello there! Nice to meet you foo")
4640
}
4741

4842
"LLMChain should fail when inputs are not the expected ones from the PromptTemplate" {
4943
val template = "My name is {name} and I'm {age} years old"
5044
either {
5145
val prompt = PromptTemplate(template, listOf("name", "age"))
52-
val chain = LLMChain(llm, prompt, "davinci", "testing", false, 1, 0.0, false)
46+
val chain = LLMChain(llm, prompt, "davinci", "testing", false, 1, 0.0)
5347
chain.run(mapOf("age" to "28", "brand" to "foo")).bind()
54-
} shouldBe Either.Left(
55-
Chain.InvalidInputs("The provided inputs: {age}, {brand} do not match with chain's inputs: {name}, {age}")
56-
)
48+
} shouldBeLeft Chain.InvalidInputs("The provided inputs: {age}, {brand} do not match with chain's inputs: {name}, {age}")
5749
}
5850

5951
"LLMChain should fail when using just one input but expecting more" {
6052
val template = "My name is {name} and I'm {age} years old"
6153
either {
6254
val prompt = PromptTemplate(template, listOf("name", "age"))
63-
val chain = LLMChain(llm, prompt, "davinci", "testing", false, 1, 0.0, false)
55+
val chain = LLMChain(llm, prompt, "davinci", "testing", false, 1, 0.0)
6456
chain.run("foo").bind()
65-
} shouldBe Either.Left(
66-
Chain.InvalidInputs("The expected inputs are more than one: {name}, {age}")
67-
)
57+
} shouldBeLeft Chain.InvalidInputs("The expected inputs are more than one: {name}, {age}")
6858
}
6959
})
7060

0 commit comments

Comments
 (0)