Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Port Simple Sequence Chain + Unit tests #12

Merged
merged 9 commits into from
May 2, 2023
5 changes: 2 additions & 3 deletions build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,9 @@ kotlin {
sourceSets {
commonMain {
dependencies {
implementation(libs.arrow.fx)
implementation(libs.arrow.resilience)
implementation(libs.kotlinx.serialization.json)
implementation(libs.bundles.arrow)
implementation(libs.bundles.ktor.client)
implementation(libs.kotlinx.serialization.json)
implementation(libs.okio)
implementation(libs.uuid)
implementation(libs.klogging)
Expand Down
8 changes: 7 additions & 1 deletion gradle/libs.versions.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ testcontainers = "1.17.6"
hikari = "5.0.1"

[libraries]
arrow-fx = { module = "io.arrow-kt:arrow-fx-coroutines", version.ref = "arrow" }
arrow-core = { module = "io.arrow-kt:arrow-core", version.ref = "arrow" }
arrow-fx-coroutines = { module = "io.arrow-kt:arrow-fx-coroutines", version.ref = "arrow" }
arrow-resilience = { module = "io.arrow-kt:arrow-resilience", version.ref = "arrow" }
open-ai = { module = "com.theokanning.openai-gpt3-java:service", version.ref = "openai" }
kotlinx-serialization-json = { module = "org.jetbrains.kotlinx:kotlinx-serialization-json", version.ref = "kotlinx-json" }
Expand All @@ -38,6 +39,11 @@ postgresql = { module = "org.postgresql:postgresql", version.ref = "postgresql"
testcontainers-postgresql = { module = "org.testcontainers:postgresql", version.ref = "testcontainers" }

[bundles]
arrow = [
"arrow-core",
"arrow-fx-coroutines",
"arrow-resilience"
]
ktor-client = [
"ktor-client",
"ktor-client-content-negotiation",
Expand Down
26 changes: 17 additions & 9 deletions src/commonMain/kotlin/com/xebia/functional/chains/Chain.kt
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,17 @@ import arrow.core.raise.either
import arrow.core.raise.ensure

interface Chain {
data class InvalidInputs(val reason: String)

enum class ChainOutput { InputAndOutput, OnlyOutput }

sealed class Error(open val reason: String)

data class InvalidInputs(override val reason: String): Error(reason)

data class Config(
val inputKeys: Set<String>,
val outputKeys: Set<String>,
val onlyOutputs: Boolean
val chainOutput: ChainOutput = ChainOutput.OnlyOutput
) {
fun createInputs(
inputs: String
Expand All @@ -28,7 +33,7 @@ interface Chain {
): Either<InvalidInputs, Map<String, String>> =
either {
ensure((inputKeys subtract inputs.keys).isEmpty()) {
InvalidInputs("The provided inputs: " +
InvalidInputs("The provided inputs: " +
inputs.keys.joinToString(", ") { "{$it}" } +
" do not match with chain's inputs: " +
inputKeys.joinToString(", ") { "{$it}" })
Expand All @@ -39,24 +44,27 @@ interface Chain {

val config: Config

suspend fun call(inputs: Map<String, String>): Map<String, String>
suspend fun call(inputs: Map<String, String>): Either<Error, Map<String, String>>

suspend fun run(input: String): Either<InvalidInputs, Map<String, String>> =
suspend fun run(input: String): Either<Error, Map<String, String>> =
either {
val preparedInputs = config.createInputs(input).bind()
val result = call(preparedInputs)
val result = call(preparedInputs).bind()
prepareOutputs(preparedInputs, result)
}

suspend fun run(inputs: Map<String, String>): Either<InvalidInputs, Map<String, String>> =
suspend fun run(inputs: Map<String, String>): Either<Error, Map<String, String>> =
either {
val preparedInputs = config.createInputs(inputs).bind()
val result = call(preparedInputs)
val result = call(preparedInputs).bind()
prepareOutputs(preparedInputs, result)
}

private fun prepareOutputs(
inputs: Map<String, String>, outputs: Map<String, String>
): Map<String, String> =
if (config.onlyOutputs) outputs else inputs + outputs
when (config.chainOutput) {
ChainOutput.InputAndOutput -> inputs + outputs
ChainOutput.OnlyOutput -> outputs
}
}
33 changes: 18 additions & 15 deletions src/commonMain/kotlin/com/xebia/functional/chains/LLMChain.kt
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package com.xebia.functional.chains

import arrow.core.Either
import arrow.core.raise.either
import com.xebia.functional.llm.openai.CompletionChoice
import com.xebia.functional.llm.openai.CompletionRequest
import com.xebia.functional.llm.openai.OpenAIClient
Expand All @@ -14,26 +16,27 @@ suspend fun LLMChain(
echo: Boolean,
n: Int,
temperature: Double,
onlyOutputs: Boolean
chainOutput: Chain.ChainOutput = Chain.ChainOutput.OnlyOutput
): Chain = object : Chain {

override val config: Chain.Config = Chain.Config(promptTemplate.inputKeys.toSet(), setOf("answer"), onlyOutputs)
override val config: Chain.Config = Chain.Config(promptTemplate.inputKeys.toSet(), setOf("answer"), chainOutput)

override suspend fun call(inputs: Map<String, String>): Map<String, String> {
val prompt = promptTemplate.format(inputs)
override suspend fun call(inputs: Map<String, String>): Either<Chain.InvalidInputs, Map<String, String>> =
either {
val prompt = promptTemplate.format(inputs)

val request = CompletionRequest(
model = llmModel,
user = user,
prompt = prompt,
echo = echo,
n = n,
temperature = temperature,
)
val request = CompletionRequest(
model = llmModel,
user = user,
prompt = prompt,
echo = echo,
n = n,
temperature = temperature,
)

val completions = llm.createCompletion(request)
return formatOutput(completions)
}
val completions = llm.createCompletion(request)
formatOutput(completions)
}

private fun formatOutput(completions: List<CompletionChoice>): Map<String, String> =
config.outputKeys.associateWith {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
package com.xebia.functional.chains

interface SequenceChain : Chain {
data class InvalidOutputs(override val reason: String): Chain.Error(reason)
data class InvalidKeys(override val reason: String): Chain.Error(reason)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
package com.xebia.functional.chains

import arrow.core.Either
import arrow.core.raise.Raise
import arrow.core.raise.either
import arrow.core.raise.ensure
import arrow.core.raise.ensureNotNull
import arrow.core.raise.recover
import arrow.core.raise.zipOrAccumulate

fun Raise<Chain.Error>.SimpleSequenceChain(
chains: List<Chain>,
inputKey: String = "input",
outputKey: String = "output",
chainOutput: Chain.ChainOutput = Chain.ChainOutput.OnlyOutput
): SimpleSequenceChain =
SimpleSequenceChain.either(chains, inputKey, outputKey, chainOutput).bind()

class SimpleSequenceChain private constructor(
private val chains: List<Chain>,
private val inputKey: String,
private val outputKey: String,
chainOutput: Chain.ChainOutput
) : SequenceChain {

override val config = Chain.Config(setOf(inputKey), setOf(outputKey), chainOutput)

override suspend fun call(inputs: Map<String, String>): Either<Chain.Error, Map<String, String>> =
either {
val input = validateInput(inputs, inputKey)
val firstRes = chains.first().run(input).bind()
val chainRes = chains.drop(1).fold(firstRes) { acc, chain ->
chain.run(acc).bind()
}.values.first()
mapOf(outputKey to chainRes)
}

companion object {
fun either(
chains: List<Chain>,
inputKey: String,
outputKey: String,
chainOutput: Chain.ChainOutput
): Either<SequenceChain.InvalidKeys, SimpleSequenceChain> =
either {
val mappedChains: List<Chain> = chains.map { chain ->
recover({
zipOrAccumulate(
{ validateInputKeys(chain.config.inputKeys) },
{ validateOutputKeys(chain.config.outputKeys) }) { _, _ -> chain }
}) { raise(SequenceChain.InvalidKeys(it.joinToString(transform = Chain.Error::reason))) }
}
SimpleSequenceChain(mappedChains, inputKey, outputKey, chainOutput)
}
}
}

private fun Raise<SequenceChain.InvalidOutputs>.validateOutputKeys(outputKeys: Set<String>): Unit =
ensure(outputKeys.size == 1) {
SequenceChain.InvalidOutputs("The expected outputs are more than one: " +
outputKeys.joinToString(", ") { "{$it}" })
}

private fun Raise<Chain.InvalidInputs>.validateInputKeys(inputKeys: Set<String>): Unit =
ensure(inputKeys.size == 1) {
Chain.InvalidInputs("The expected inputs are more than one: " +
inputKeys.joinToString(", ") { "{$it}" })
}

private fun Raise<Chain.InvalidInputs>.validateInput(inputs: Map<String, String>, inputKey: String): String =
ensureNotNull(inputs[inputKey]) {
Chain.InvalidInputs("The provided inputs: " +
inputs.keys.joinToString(", ") { "{$it}" } +
" do not match with chain's input: {$inputKey}")
}
10 changes: 5 additions & 5 deletions src/commonTest/kotlin/com/xebia/functional/chains/ConfigSpec.kt
Original file line number Diff line number Diff line change
Expand Up @@ -7,35 +7,35 @@ import io.kotest.matchers.shouldBe
class ConfigSpec : StringSpec({

"Chain Config should return the inputs properly" {
val config = Chain.Config(setOf("name", "age"), setOf("text"), false)
val config = Chain.Config(setOf("name", "age"), setOf("text"))
val result = config.createInputs(mapOf("name" to "foo", "age" to "bar"))
result shouldBe Either.Right(mapOf("name" to "foo", "age" to "bar"))
}

"Chain Config should return the input as a Map" {
val config = Chain.Config(setOf("input"), setOf("text"), false)
val config = Chain.Config(setOf("input"), setOf("text"))
val result = config.createInputs("foo")
result shouldBe Either.Right(mapOf("input" to "foo"))
}

"Chain Config should fail when inputs set doesn't contain all inputKeys" {
val config = Chain.Config(setOf("name", "age"), setOf("text"), false)
val config = Chain.Config(setOf("name", "age"), setOf("text"))
val result = config.createInputs(mapOf("name" to "foo"))
result shouldBe Either.Left(
Chain.InvalidInputs("The provided inputs: {name} do not match with chain's inputs: {name}, {age}")
)
}

"Chain Config should fail when inputs set has different inputKeys" {
val config = Chain.Config(setOf("name", "age"), setOf("text"), false)
val config = Chain.Config(setOf("name", "age"), setOf("text"))
val result = config.createInputs(mapOf("name" to "foo", "city" to "NY"))
result shouldBe Either.Left(
Chain.InvalidInputs("The provided inputs: {name}, {city} do not match with chain's inputs: {name}, {age}")
)
}

"Chain Config should fail when input is just one and expects more" {
val config = Chain.Config(setOf("name", "age"), setOf("text"), false)
val config = Chain.Config(setOf("name", "age"), setOf("text"))
val result = config.createInputs("foo")
result shouldBe Either.Left(
Chain.InvalidInputs("The expected inputs are more than one: {name}, {age}")
Expand Down
34 changes: 12 additions & 22 deletions src/commonTest/kotlin/com/xebia/functional/chains/LLMChainSpec.kt
Original file line number Diff line number Diff line change
@@ -1,70 +1,60 @@
package com.xebia.functional.chains

import arrow.core.Either
import arrow.core.raise.either
import com.xebia.functional.llm.openai.CompletionChoice
import com.xebia.functional.llm.openai.CompletionRequest
import com.xebia.functional.llm.openai.EmbeddingRequest
import com.xebia.functional.llm.openai.EmbeddingResult
import com.xebia.functional.llm.openai.OpenAIClient
import com.xebia.functional.prompt.PromptTemplate
import io.kotest.assertions.arrow.core.shouldBeLeft
import io.kotest.assertions.arrow.core.shouldBeRight
import io.kotest.core.spec.style.StringSpec
import io.kotest.matchers.shouldBe

class LLMChainSpec : StringSpec({
"LLMChain should return a prediction with just the output" {
val template = "Tell me {foo}."
either {
val prompt = PromptTemplate(template, listOf("foo"))
val chain = LLMChain(llm, prompt, "davinci", "testing", false, 1, 0.0, true)
val chain = LLMChain(llm, prompt, "davinci", "testing", false, 1, 0.0)
chain.run("a joke").bind()
} shouldBe Either.Right(
mapOf("answer" to "I'm not good at jokes")
)
} shouldBeRight mapOf("answer" to "I'm not good at jokes")
}

"LLMChain should return a prediction with both output and inputs" {
val template = "Tell me {foo}."
either {
val prompt = PromptTemplate(template, listOf("foo"))
val chain = LLMChain(llm, prompt, "davinci", "testing", false, 1, 0.0, false)
val chain = LLMChain(llm, prompt, "davinci", "testing", false, 1, 0.0, Chain.ChainOutput.InputAndOutput)
chain.run("a joke").bind()
} shouldBe Either.Right(
mapOf("foo" to "a joke", "answer" to "I'm not good at jokes")
)
} shouldBeRight mapOf("foo" to "a joke", "answer" to "I'm not good at jokes")
}

"LLMChain should return a prediction with a more complex template" {
val template = "My name is {name} and I'm {age} years old"
either {
val prompt = PromptTemplate(template, listOf("name", "age"))
val chain = LLMChain(llm, prompt, "davinci", "testing", false, 1, 0.0, false)
val chain = LLMChain(llm, prompt, "davinci", "testing", false, 1, 0.0, Chain.ChainOutput.InputAndOutput)
chain.run(mapOf("age" to "28", "name" to "foo")).bind()
} shouldBe Either.Right(
mapOf("age" to "28", "name" to "foo", "answer" to "Hello there! Nice to meet you foo")
)
} shouldBeRight mapOf("age" to "28", "name" to "foo", "answer" to "Hello there! Nice to meet you foo")
}

"LLMChain should fail when inputs are not the expected ones from the PromptTemplate" {
val template = "My name is {name} and I'm {age} years old"
either {
val prompt = PromptTemplate(template, listOf("name", "age"))
val chain = LLMChain(llm, prompt, "davinci", "testing", false, 1, 0.0, false)
val chain = LLMChain(llm, prompt, "davinci", "testing", false, 1, 0.0)
chain.run(mapOf("age" to "28", "brand" to "foo")).bind()
} shouldBe Either.Left(
Chain.InvalidInputs("The provided inputs: {age}, {brand} do not match with chain's inputs: {name}, {age}")
)
} shouldBeLeft Chain.InvalidInputs("The provided inputs: {age}, {brand} do not match with chain's inputs: {name}, {age}")
}

"LLMChain should fail when using just one input but expecting more" {
val template = "My name is {name} and I'm {age} years old"
either {
val prompt = PromptTemplate(template, listOf("name", "age"))
val chain = LLMChain(llm, prompt, "davinci", "testing", false, 1, 0.0, false)
val chain = LLMChain(llm, prompt, "davinci", "testing", false, 1, 0.0)
chain.run("foo").bind()
} shouldBe Either.Left(
Chain.InvalidInputs("The expected inputs are more than one: {name}, {age}")
)
} shouldBeLeft Chain.InvalidInputs("The expected inputs are more than one: {name}, {age}")
}
})

Expand Down
Loading