From d85896d43e66a67efd7699d7b85c78b6f3e658e5 Mon Sep 17 00:00:00 2001 From: David Vega Lichacz <7826728+realdavidvega@users.noreply.github.com> Date: Thu, 27 Apr 2023 13:28:11 +0200 Subject: [PATCH 1/9] refactor: onlyOutputs -> returnAll --- .../com/xebia/functional/chains/Chain.kt | 11 ++++--- .../com/xebia/functional/chains/LLMChain.kt | 33 ++++++++++--------- .../functional/chains/SimpleSequenceChain.kt | 2 ++ 3 files changed, 26 insertions(+), 20 deletions(-) create mode 100644 src/commonMain/kotlin/com/xebia/functional/chains/SimpleSequenceChain.kt diff --git a/src/commonMain/kotlin/com/xebia/functional/chains/Chain.kt b/src/commonMain/kotlin/com/xebia/functional/chains/Chain.kt index bed973de2..84482d2a2 100644 --- a/src/commonMain/kotlin/com/xebia/functional/chains/Chain.kt +++ b/src/commonMain/kotlin/com/xebia/functional/chains/Chain.kt @@ -1,6 +1,7 @@ package com.xebia.functional.chains import arrow.core.Either +import arrow.core.NonEmptySet import arrow.core.raise.either import arrow.core.raise.ensure @@ -10,7 +11,7 @@ interface Chain { data class Config( val inputKeys: Set, val outputKeys: Set, - val onlyOutputs: Boolean + val returnAll: Boolean = false ) { fun createInputs( inputs: String @@ -39,24 +40,24 @@ interface Chain { val config: Config - suspend fun call(inputs: Map): Map + suspend fun call(inputs: Map): Either> suspend fun run(input: String): Either> = either { val preparedInputs = config.createInputs(input).bind() - val result = call(preparedInputs) + val result = call(preparedInputs).bind() prepareOutputs(preparedInputs, result) } suspend fun run(inputs: Map): Either> = either { val preparedInputs = config.createInputs(inputs).bind() - val result = call(preparedInputs) + val result = call(preparedInputs).bind() prepareOutputs(preparedInputs, result) } private fun prepareOutputs( inputs: Map, outputs: Map ): Map = - if (config.onlyOutputs) outputs else inputs + outputs + if (config.returnAll) inputs + outputs else outputs } diff --git a/src/commonMain/kotlin/com/xebia/functional/chains/LLMChain.kt b/src/commonMain/kotlin/com/xebia/functional/chains/LLMChain.kt index cb6a685e7..abdb88b13 100644 --- a/src/commonMain/kotlin/com/xebia/functional/chains/LLMChain.kt +++ b/src/commonMain/kotlin/com/xebia/functional/chains/LLMChain.kt @@ -1,5 +1,7 @@ package com.xebia.functional.chains +import arrow.core.Either +import arrow.core.raise.either import com.xebia.functional.llm.openai.CompletionChoice import com.xebia.functional.llm.openai.CompletionRequest import com.xebia.functional.llm.openai.OpenAIClient @@ -14,26 +16,27 @@ suspend fun LLMChain( echo: Boolean, n: Int, temperature: Double, - onlyOutputs: Boolean + returnAll: Boolean = false ): Chain = object : Chain { - override val config: Chain.Config = Chain.Config(promptTemplate.inputKeys.toSet(), setOf("answer"), onlyOutputs) + override val config: Chain.Config = Chain.Config(promptTemplate.inputKeys.toSet(), setOf("answer"), returnAll) - override suspend fun call(inputs: Map): Map { - val prompt = promptTemplate.format(inputs) + override suspend fun call(inputs: Map): Either> = + either { + val prompt = promptTemplate.format(inputs) - val request = CompletionRequest( - model = llmModel, - user = user, - prompt = prompt, - echo = echo, - n = n, - temperature = temperature, - ) + val request = CompletionRequest( + model = llmModel, + user = user, + prompt = prompt, + echo = echo, + n = n, + temperature = temperature, + ) - val completions = llm.createCompletion(request) - return formatOutput(completions) - } + val completions = llm.createCompletion(request) + formatOutput(completions) + } private fun formatOutput(completions: List): Map = config.outputKeys.associateWith { diff --git a/src/commonMain/kotlin/com/xebia/functional/chains/SimpleSequenceChain.kt b/src/commonMain/kotlin/com/xebia/functional/chains/SimpleSequenceChain.kt new file mode 100644 index 000000000..ed647b5ca --- /dev/null +++ b/src/commonMain/kotlin/com/xebia/functional/chains/SimpleSequenceChain.kt @@ -0,0 +1,2 @@ +package com.xebia.functional.chains + From 2e92d6c0e169b66639586f1c2b81f7caba043188 Mon Sep 17 00:00:00 2001 From: David Vega Lichacz <7826728+realdavidvega@users.noreply.github.com> Date: Thu, 27 Apr 2023 13:28:32 +0200 Subject: [PATCH 2/9] config: add arrow bundle in gradle --- build.gradle.kts | 5 ++--- gradle/libs.versions.toml | 8 +++++++- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/build.gradle.kts b/build.gradle.kts index 51a00b7bd..00f8259e5 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -46,10 +46,9 @@ kotlin { sourceSets { commonMain { dependencies { - implementation(libs.arrow.fx) - implementation(libs.arrow.resilience) - implementation(libs.kotlinx.serialization.json) + implementation(libs.bundles.arrow) implementation(libs.bundles.ktor.client) + implementation(libs.kotlinx.serialization.json) implementation(libs.okio) implementation(libs.uuid) implementation(libs.klogging) diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index 9c0b3e511..9265c216e 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -16,7 +16,8 @@ testcontainers = "1.17.6" hikari = "5.0.1" [libraries] -arrow-fx = { module = "io.arrow-kt:arrow-fx-coroutines", version.ref = "arrow" } +arrow-core = { module = "io.arrow-kt:arrow-core", version.ref = "arrow" } +arrow-fx-coroutines = { module = "io.arrow-kt:arrow-fx-coroutines", version.ref = "arrow" } arrow-resilience = { module = "io.arrow-kt:arrow-resilience", version.ref = "arrow" } open-ai = { module = "com.theokanning.openai-gpt3-java:service", version.ref = "openai" } kotlinx-serialization-json = { module = "org.jetbrains.kotlinx:kotlinx-serialization-json", version.ref = "kotlinx-json" } @@ -38,6 +39,11 @@ postgresql = { module = "org.postgresql:postgresql", version.ref = "postgresql" testcontainers-postgresql = { module = "org.testcontainers:postgresql", version.ref = "testcontainers" } [bundles] +arrow = [ + "arrow-core", + "arrow-fx-coroutines", + "arrow-resilience" +] ktor-client = [ "ktor-client", "ktor-client-content-negotiation", From 9a90791eb383ab2de6fde73f152d93ca010ff6ed Mon Sep 17 00:00:00 2001 From: David Vega Lichacz <7826728+realdavidvega@users.noreply.github.com> Date: Thu, 27 Apr 2023 21:02:07 +0200 Subject: [PATCH 3/9] feat: first simple sequencial chains --- .../com/xebia/functional/chains/Chain.kt | 13 ++-- .../xebia/functional/chains/SequenceChain.kt | 6 ++ .../functional/chains/SimpleSequenceChain.kt | 65 +++++++++++++++++++ .../chains/SimpleSequentialChainSpec.kt | 2 + 4 files changed, 80 insertions(+), 6 deletions(-) create mode 100644 src/commonMain/kotlin/com/xebia/functional/chains/SequenceChain.kt create mode 100644 src/commonTest/kotlin/com/xebia/functional/chains/SimpleSequentialChainSpec.kt diff --git a/src/commonMain/kotlin/com/xebia/functional/chains/Chain.kt b/src/commonMain/kotlin/com/xebia/functional/chains/Chain.kt index 84482d2a2..645f16cb1 100644 --- a/src/commonMain/kotlin/com/xebia/functional/chains/Chain.kt +++ b/src/commonMain/kotlin/com/xebia/functional/chains/Chain.kt @@ -1,12 +1,13 @@ package com.xebia.functional.chains import arrow.core.Either -import arrow.core.NonEmptySet import arrow.core.raise.either import arrow.core.raise.ensure interface Chain { - data class InvalidInputs(val reason: String) + + sealed class Error(open val reason: String) + data class InvalidInputs(override val reason: String): Error(reason) data class Config( val inputKeys: Set, @@ -29,7 +30,7 @@ interface Chain { ): Either> = either { ensure((inputKeys subtract inputs.keys).isEmpty()) { - InvalidInputs("The provided inputs: " + + InvalidInputs("The provided inputs: " + inputs.keys.joinToString(", ") { "{$it}" } + " do not match with chain's inputs: " + inputKeys.joinToString(", ") { "{$it}" }) @@ -40,16 +41,16 @@ interface Chain { val config: Config - suspend fun call(inputs: Map): Either> + suspend fun call(inputs: Map): Either> - suspend fun run(input: String): Either> = + suspend fun run(input: String): Either> = either { val preparedInputs = config.createInputs(input).bind() val result = call(preparedInputs).bind() prepareOutputs(preparedInputs, result) } - suspend fun run(inputs: Map): Either> = + suspend fun run(inputs: Map): Either> = either { val preparedInputs = config.createInputs(inputs).bind() val result = call(preparedInputs).bind() diff --git a/src/commonMain/kotlin/com/xebia/functional/chains/SequenceChain.kt b/src/commonMain/kotlin/com/xebia/functional/chains/SequenceChain.kt new file mode 100644 index 000000000..6ad43b400 --- /dev/null +++ b/src/commonMain/kotlin/com/xebia/functional/chains/SequenceChain.kt @@ -0,0 +1,6 @@ +package com.xebia.functional.chains + +interface SequenceChain : Chain { + data class InvalidOutputs(override val reason: String): Chain.Error(reason) + data class InvalidInputsAndOutputs(override val reason: String): Chain.Error(reason) +} \ No newline at end of file diff --git a/src/commonMain/kotlin/com/xebia/functional/chains/SimpleSequenceChain.kt b/src/commonMain/kotlin/com/xebia/functional/chains/SimpleSequenceChain.kt index ed647b5ca..a5c2b2762 100644 --- a/src/commonMain/kotlin/com/xebia/functional/chains/SimpleSequenceChain.kt +++ b/src/commonMain/kotlin/com/xebia/functional/chains/SimpleSequenceChain.kt @@ -1,2 +1,67 @@ package com.xebia.functional.chains +import arrow.core.Either +import arrow.core.mapOrAccumulate +import arrow.core.raise.Raise +import arrow.core.raise.either +import arrow.core.raise.ensure +import arrow.core.raise.ensureNotNull + +fun Raise.SimpleSequentialChain( + chains: List, inputKey: String, outputKey: String, returnAll: Boolean +): SimpleSequentialChain = + SimpleSequentialChain.either(chains, inputKey, outputKey, returnAll).bind() + +class SimpleSequentialChain( + private val chains: List, + private val inputKey: String = "input", + private val outputKey: String = "output", + returnAll: Boolean = false +) : SequenceChain { + + override val config = Chain.Config(setOf(inputKey), setOf(outputKey), returnAll) + + override suspend fun call(inputs: Map): Either> = + either { + val input = validateInput(inputs, inputKey) + val firstRes = chains.first().run(input).bind() + val chainRes = chains.drop(1).fold(firstRes) { acc, chain -> + chain.run(acc).bind() + }.values.first() + mapOf(outputKey to chainRes) + } + + companion object { + fun either( + chains: List, inputKey: String, outputKey: String, returnAll: Boolean + ): Either { + return chains.mapOrAccumulate { chain -> + with(chain.config) { + validateInputKeys(inputKeys) + validateOutputKeys(outputKeys) + } + }.mapLeft { + SequenceChain.InvalidInputsAndOutputs(it.joinToString(transform = Chain.Error::reason)) + }.map { SimpleSequentialChain(chains, inputKey, outputKey, returnAll) } + } + } +} + +private fun Raise.validateOutputKeys(outputKeys: Set): Unit = + ensure(outputKeys.size > 1) { + SequenceChain.InvalidOutputs("The expected outputs are more than one: " + + outputKeys.joinToString(", ") { "{$it}" }) + } + +private fun Raise.validateInputKeys(inputKeys: Set): Unit = + ensure(inputKeys.size > 1) { + Chain.InvalidInputs("The expected inputs are more than one: " + + inputKeys.joinToString(", ") { "{$it}" }) + } + +private fun Raise.validateInput(inputs: Map, inputKey: String): String = + ensureNotNull(inputs[inputKey]) { + Chain.InvalidInputs("The provided inputs: " + + inputs.keys.joinToString(", ") { "{$it}" } + + " do not match with chain's input: {$inputKey}") + } diff --git a/src/commonTest/kotlin/com/xebia/functional/chains/SimpleSequentialChainSpec.kt b/src/commonTest/kotlin/com/xebia/functional/chains/SimpleSequentialChainSpec.kt new file mode 100644 index 000000000..ed647b5ca --- /dev/null +++ b/src/commonTest/kotlin/com/xebia/functional/chains/SimpleSequentialChainSpec.kt @@ -0,0 +1,2 @@ +package com.xebia.functional.chains + From 1b9ac0901078438477c29e431db04af73cf06d04 Mon Sep 17 00:00:00 2001 From: David Vega Lichacz <7826728+realdavidvega@users.noreply.github.com> Date: Thu, 27 Apr 2023 21:02:23 +0200 Subject: [PATCH 4/9] test: refactor tests after changing names --- .../kotlin/com/xebia/functional/chains/ConfigSpec.kt | 10 +++++----- .../kotlin/com/xebia/functional/chains/LLMChainSpec.kt | 10 +++++----- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/commonTest/kotlin/com/xebia/functional/chains/ConfigSpec.kt b/src/commonTest/kotlin/com/xebia/functional/chains/ConfigSpec.kt index 5396ca9bc..f4658fc00 100644 --- a/src/commonTest/kotlin/com/xebia/functional/chains/ConfigSpec.kt +++ b/src/commonTest/kotlin/com/xebia/functional/chains/ConfigSpec.kt @@ -7,19 +7,19 @@ import io.kotest.matchers.shouldBe class ConfigSpec : StringSpec({ "Chain Config should return the inputs properly" { - val config = Chain.Config(setOf("name", "age"), setOf("text"), false) + val config = Chain.Config(setOf("name", "age"), setOf("text")) val result = config.createInputs(mapOf("name" to "foo", "age" to "bar")) result shouldBe Either.Right(mapOf("name" to "foo", "age" to "bar")) } "Chain Config should return the input as a Map" { - val config = Chain.Config(setOf("input"), setOf("text"), false) + val config = Chain.Config(setOf("input"), setOf("text")) val result = config.createInputs("foo") result shouldBe Either.Right(mapOf("input" to "foo")) } "Chain Config should fail when inputs set doesn't contain all inputKeys" { - val config = Chain.Config(setOf("name", "age"), setOf("text"), false) + val config = Chain.Config(setOf("name", "age"), setOf("text")) val result = config.createInputs(mapOf("name" to "foo")) result shouldBe Either.Left( Chain.InvalidInputs("The provided inputs: {name} do not match with chain's inputs: {name}, {age}") @@ -27,7 +27,7 @@ class ConfigSpec : StringSpec({ } "Chain Config should fail when inputs set has different inputKeys" { - val config = Chain.Config(setOf("name", "age"), setOf("text"), false) + val config = Chain.Config(setOf("name", "age"), setOf("text")) val result = config.createInputs(mapOf("name" to "foo", "city" to "NY")) result shouldBe Either.Left( Chain.InvalidInputs("The provided inputs: {name}, {city} do not match with chain's inputs: {name}, {age}") @@ -35,7 +35,7 @@ class ConfigSpec : StringSpec({ } "Chain Config should fail when input is just one and expects more" { - val config = Chain.Config(setOf("name", "age"), setOf("text"), false) + val config = Chain.Config(setOf("name", "age"), setOf("text")) val result = config.createInputs("foo") result shouldBe Either.Left( Chain.InvalidInputs("The expected inputs are more than one: {name}, {age}") diff --git a/src/commonTest/kotlin/com/xebia/functional/chains/LLMChainSpec.kt b/src/commonTest/kotlin/com/xebia/functional/chains/LLMChainSpec.kt index c1eff34c5..8e363bca8 100644 --- a/src/commonTest/kotlin/com/xebia/functional/chains/LLMChainSpec.kt +++ b/src/commonTest/kotlin/com/xebia/functional/chains/LLMChainSpec.kt @@ -16,7 +16,7 @@ class LLMChainSpec : StringSpec({ val template = "Tell me {foo}." either { val prompt = PromptTemplate(template, listOf("foo")) - val chain = LLMChain(llm, prompt, "davinci", "testing", false, 1, 0.0, true) + val chain = LLMChain(llm, prompt, "davinci", "testing", false, 1, 0.0) chain.run("a joke").bind() } shouldBe Either.Right( mapOf("answer" to "I'm not good at jokes") @@ -27,7 +27,7 @@ class LLMChainSpec : StringSpec({ val template = "Tell me {foo}." either { val prompt = PromptTemplate(template, listOf("foo")) - val chain = LLMChain(llm, prompt, "davinci", "testing", false, 1, 0.0, false) + val chain = LLMChain(llm, prompt, "davinci", "testing", false, 1, 0.0, true) chain.run("a joke").bind() } shouldBe Either.Right( mapOf("foo" to "a joke", "answer" to "I'm not good at jokes") @@ -38,7 +38,7 @@ class LLMChainSpec : StringSpec({ val template = "My name is {name} and I'm {age} years old" either { val prompt = PromptTemplate(template, listOf("name", "age")) - val chain = LLMChain(llm, prompt, "davinci", "testing", false, 1, 0.0, false) + val chain = LLMChain(llm, prompt, "davinci", "testing", false, 1, 0.0, true) chain.run(mapOf("age" to "28", "name" to "foo")).bind() } shouldBe Either.Right( mapOf("age" to "28", "name" to "foo", "answer" to "Hello there! Nice to meet you foo") @@ -49,7 +49,7 @@ class LLMChainSpec : StringSpec({ val template = "My name is {name} and I'm {age} years old" either { val prompt = PromptTemplate(template, listOf("name", "age")) - val chain = LLMChain(llm, prompt, "davinci", "testing", false, 1, 0.0, false) + val chain = LLMChain(llm, prompt, "davinci", "testing", false, 1, 0.0) chain.run(mapOf("age" to "28", "brand" to "foo")).bind() } shouldBe Either.Left( Chain.InvalidInputs("The provided inputs: {age}, {brand} do not match with chain's inputs: {name}, {age}") @@ -60,7 +60,7 @@ class LLMChainSpec : StringSpec({ val template = "My name is {name} and I'm {age} years old" either { val prompt = PromptTemplate(template, listOf("name", "age")) - val chain = LLMChain(llm, prompt, "davinci", "testing", false, 1, 0.0, false) + val chain = LLMChain(llm, prompt, "davinci", "testing", false, 1, 0.0) chain.run("foo").bind() } shouldBe Either.Left( Chain.InvalidInputs("The expected inputs are more than one: {name}, {age}") From 1364b5f4c1af32947061b106399669a6086b9592 Mon Sep 17 00:00:00 2001 From: David Vega Lichacz <7826728+realdavidvega@users.noreply.github.com> Date: Thu, 27 Apr 2023 21:50:24 +0200 Subject: [PATCH 5/9] test: create simple sequence chain spec and adjust class --- .../xebia/functional/chains/SequenceChain.kt | 2 +- .../functional/chains/SimpleSequenceChain.kt | 18 ++--- .../xebia/functional/chains/LLMChainSpec.kt | 30 +++------ .../chains/SimpleSequentialChainSpec.kt | 67 +++++++++++++++++++ 4 files changed, 85 insertions(+), 32 deletions(-) diff --git a/src/commonMain/kotlin/com/xebia/functional/chains/SequenceChain.kt b/src/commonMain/kotlin/com/xebia/functional/chains/SequenceChain.kt index 6ad43b400..14652264e 100644 --- a/src/commonMain/kotlin/com/xebia/functional/chains/SequenceChain.kt +++ b/src/commonMain/kotlin/com/xebia/functional/chains/SequenceChain.kt @@ -2,5 +2,5 @@ package com.xebia.functional.chains interface SequenceChain : Chain { data class InvalidOutputs(override val reason: String): Chain.Error(reason) - data class InvalidInputsAndOutputs(override val reason: String): Chain.Error(reason) + data class InvalidKeys(override val reason: String): Chain.Error(reason) } \ No newline at end of file diff --git a/src/commonMain/kotlin/com/xebia/functional/chains/SimpleSequenceChain.kt b/src/commonMain/kotlin/com/xebia/functional/chains/SimpleSequenceChain.kt index a5c2b2762..44dec489c 100644 --- a/src/commonMain/kotlin/com/xebia/functional/chains/SimpleSequenceChain.kt +++ b/src/commonMain/kotlin/com/xebia/functional/chains/SimpleSequenceChain.kt @@ -8,15 +8,15 @@ import arrow.core.raise.ensure import arrow.core.raise.ensureNotNull fun Raise.SimpleSequentialChain( - chains: List, inputKey: String, outputKey: String, returnAll: Boolean + chains: List, inputKey: String = "input", outputKey: String = "output", returnAll: Boolean = false ): SimpleSequentialChain = SimpleSequentialChain.either(chains, inputKey, outputKey, returnAll).bind() -class SimpleSequentialChain( +class SimpleSequentialChain private constructor( private val chains: List, - private val inputKey: String = "input", - private val outputKey: String = "output", - returnAll: Boolean = false + private val inputKey: String, + private val outputKey: String, + returnAll: Boolean ) : SequenceChain { override val config = Chain.Config(setOf(inputKey), setOf(outputKey), returnAll) @@ -34,27 +34,27 @@ class SimpleSequentialChain( companion object { fun either( chains: List, inputKey: String, outputKey: String, returnAll: Boolean - ): Either { + ): Either { return chains.mapOrAccumulate { chain -> with(chain.config) { validateInputKeys(inputKeys) validateOutputKeys(outputKeys) } }.mapLeft { - SequenceChain.InvalidInputsAndOutputs(it.joinToString(transform = Chain.Error::reason)) + SequenceChain.InvalidKeys(it.joinToString(transform = Chain.Error::reason)) }.map { SimpleSequentialChain(chains, inputKey, outputKey, returnAll) } } } } private fun Raise.validateOutputKeys(outputKeys: Set): Unit = - ensure(outputKeys.size > 1) { + ensure(outputKeys.size == 1) { SequenceChain.InvalidOutputs("The expected outputs are more than one: " + outputKeys.joinToString(", ") { "{$it}" }) } private fun Raise.validateInputKeys(inputKeys: Set): Unit = - ensure(inputKeys.size > 1) { + ensure(inputKeys.size == 1) { Chain.InvalidInputs("The expected inputs are more than one: " + inputKeys.joinToString(", ") { "{$it}" }) } diff --git a/src/commonTest/kotlin/com/xebia/functional/chains/LLMChainSpec.kt b/src/commonTest/kotlin/com/xebia/functional/chains/LLMChainSpec.kt index 8e363bca8..b98dee4e5 100644 --- a/src/commonTest/kotlin/com/xebia/functional/chains/LLMChainSpec.kt +++ b/src/commonTest/kotlin/com/xebia/functional/chains/LLMChainSpec.kt @@ -1,15 +1,11 @@ package com.xebia.functional.chains -import arrow.core.Either import arrow.core.raise.either -import com.xebia.functional.llm.openai.CompletionChoice -import com.xebia.functional.llm.openai.CompletionRequest -import com.xebia.functional.llm.openai.EmbeddingRequest -import com.xebia.functional.llm.openai.EmbeddingResult -import com.xebia.functional.llm.openai.OpenAIClient +import com.xebia.functional.llm.openai.* import com.xebia.functional.prompt.PromptTemplate +import io.kotest.assertions.arrow.core.shouldBeLeft +import io.kotest.assertions.arrow.core.shouldBeRight import io.kotest.core.spec.style.StringSpec -import io.kotest.matchers.shouldBe class LLMChainSpec : StringSpec({ "LLMChain should return a prediction with just the output" { @@ -18,9 +14,7 @@ class LLMChainSpec : StringSpec({ val prompt = PromptTemplate(template, listOf("foo")) val chain = LLMChain(llm, prompt, "davinci", "testing", false, 1, 0.0) chain.run("a joke").bind() - } shouldBe Either.Right( - mapOf("answer" to "I'm not good at jokes") - ) + } shouldBeRight mapOf("answer" to "I'm not good at jokes") } "LLMChain should return a prediction with both output and inputs" { @@ -29,9 +23,7 @@ class LLMChainSpec : StringSpec({ val prompt = PromptTemplate(template, listOf("foo")) val chain = LLMChain(llm, prompt, "davinci", "testing", false, 1, 0.0, true) chain.run("a joke").bind() - } shouldBe Either.Right( - mapOf("foo" to "a joke", "answer" to "I'm not good at jokes") - ) + } shouldBeRight mapOf("foo" to "a joke", "answer" to "I'm not good at jokes") } "LLMChain should return a prediction with a more complex template" { @@ -40,9 +32,7 @@ class LLMChainSpec : StringSpec({ val prompt = PromptTemplate(template, listOf("name", "age")) val chain = LLMChain(llm, prompt, "davinci", "testing", false, 1, 0.0, true) chain.run(mapOf("age" to "28", "name" to "foo")).bind() - } shouldBe Either.Right( - mapOf("age" to "28", "name" to "foo", "answer" to "Hello there! Nice to meet you foo") - ) + } shouldBeRight mapOf("age" to "28", "name" to "foo", "answer" to "Hello there! Nice to meet you foo") } "LLMChain should fail when inputs are not the expected ones from the PromptTemplate" { @@ -51,9 +41,7 @@ class LLMChainSpec : StringSpec({ val prompt = PromptTemplate(template, listOf("name", "age")) val chain = LLMChain(llm, prompt, "davinci", "testing", false, 1, 0.0) chain.run(mapOf("age" to "28", "brand" to "foo")).bind() - } shouldBe Either.Left( - Chain.InvalidInputs("The provided inputs: {age}, {brand} do not match with chain's inputs: {name}, {age}") - ) + } shouldBeLeft Chain.InvalidInputs("The provided inputs: {age}, {brand} do not match with chain's inputs: {name}, {age}") } "LLMChain should fail when using just one input but expecting more" { @@ -62,9 +50,7 @@ class LLMChainSpec : StringSpec({ val prompt = PromptTemplate(template, listOf("name", "age")) val chain = LLMChain(llm, prompt, "davinci", "testing", false, 1, 0.0) chain.run("foo").bind() - } shouldBe Either.Left( - Chain.InvalidInputs("The expected inputs are more than one: {name}, {age}") - ) + } shouldBeLeft Chain.InvalidInputs("The expected inputs are more than one: {name}, {age}") } }) diff --git a/src/commonTest/kotlin/com/xebia/functional/chains/SimpleSequentialChainSpec.kt b/src/commonTest/kotlin/com/xebia/functional/chains/SimpleSequentialChainSpec.kt index ed647b5ca..9ef4cd45f 100644 --- a/src/commonTest/kotlin/com/xebia/functional/chains/SimpleSequentialChainSpec.kt +++ b/src/commonTest/kotlin/com/xebia/functional/chains/SimpleSequentialChainSpec.kt @@ -1,2 +1,69 @@ package com.xebia.functional.chains +import arrow.core.Either +import arrow.core.raise.either +import io.kotest.assertions.arrow.core.shouldBeLeft +import io.kotest.assertions.arrow.core.shouldBeRight +import io.kotest.core.spec.style.StringSpec + +class SimpleSequentialChainSpec : StringSpec({ + "SimpleSequenceChain should return a prediction with one chain" { + val chain1 = FakeChain(inputVariables = setOf("foo"), outputVariables = setOf("bar")) + val chains = listOf(chain1) + + either { + val ssc = SimpleSequentialChain(chains = chains, returnAll = true) + ssc.run(mapOf("input" to "123")).bind() + } shouldBeRight mapOf("input" to "123", "output" to "123dr") + } + + "SimpleSequenceChain should return a prediction with more than one chain" { + val chain1 = FakeChain(inputVariables = setOf("foo"), outputVariables = setOf("bar")) + val chain2 = FakeChain(inputVariables = setOf("bar"), outputVariables = setOf("baz")) + val chain3 = FakeChain(inputVariables = setOf("baz"), outputVariables = setOf("dre")) + val chains = listOf(chain1, chain2, chain3) + + either { + val ssc = SimpleSequentialChain(chains = chains, returnAll = true) + ssc.run(mapOf("input" to "123")).bind() + } shouldBeRight mapOf("input" to "123", "output" to "123drdrdr") + } + + "SimpleSequentialChain should fail if multiple input variables are expected" { + val chain1 = FakeChain(inputVariables = setOf("foo"), outputVariables = setOf("bar")) + val chain2 = FakeChain(inputVariables = setOf("bar", "foo"), outputVariables = setOf("baz")) + val chains = listOf(chain1, chain2) + + either { + val ssc = SimpleSequentialChain(chains = chains, returnAll = true) + ssc.run(mapOf("input" to "123")).bind() + } shouldBeLeft SequenceChain.InvalidKeys("The expected inputs are more than one: {bar}, {foo}") + } + + "SimpleSequentialChain should fail if multiple output variables are expected" { + val chain1 = FakeChain(inputVariables = setOf("foo"), outputVariables = setOf("bar", "foo")) + val chain2 = FakeChain(inputVariables = setOf("bar"), outputVariables = setOf("baz")) + val chains = listOf(chain1, chain2) + + either { + val ssc = SimpleSequentialChain(chains = chains, returnAll = true) + ssc.run(mapOf("input" to "123")).bind() + } shouldBeLeft SequenceChain.InvalidKeys("The expected outputs are more than one: {bar}, {foo}") + } +}) + +data class FakeChain(private val inputVariables: Set, private val outputVariables: Set) : Chain { + override val config: Chain.Config = Chain.Config( + inputKeys = inputVariables, + outputKeys = outputVariables, + returnAll = false + ) + + override suspend fun call(inputs: Map): Either> = + either { + val variables = inputVariables.map { inputs[it] }.requireNoNulls() + outputVariables.fold(emptyMap()) { outputs, outputVar -> + outputs + (outputVar to "${variables.joinToString(separator = "")}dr") + } + } +} From 9b7874918484b4dcc08a86fa6941ee911fbdad4a Mon Sep 17 00:00:00 2001 From: David Vega Lichacz <7826728+realdavidvega@users.noreply.github.com> Date: Thu, 27 Apr 2023 23:29:11 +0200 Subject: [PATCH 6/9] feat: accumulate all errors on simple sequence chain --- .../functional/chains/SimpleSequenceChain.kt | 22 +++++++++---------- .../chains/SimpleSequentialChainSpec.kt | 11 ++++++++++ 2 files changed, 22 insertions(+), 11 deletions(-) diff --git a/src/commonMain/kotlin/com/xebia/functional/chains/SimpleSequenceChain.kt b/src/commonMain/kotlin/com/xebia/functional/chains/SimpleSequenceChain.kt index 44dec489c..5a76c32cd 100644 --- a/src/commonMain/kotlin/com/xebia/functional/chains/SimpleSequenceChain.kt +++ b/src/commonMain/kotlin/com/xebia/functional/chains/SimpleSequenceChain.kt @@ -1,11 +1,8 @@ package com.xebia.functional.chains import arrow.core.Either -import arrow.core.mapOrAccumulate -import arrow.core.raise.Raise -import arrow.core.raise.either -import arrow.core.raise.ensure -import arrow.core.raise.ensureNotNull +import arrow.core.NonEmptyList +import arrow.core.raise.* fun Raise.SimpleSequentialChain( chains: List, inputKey: String = "input", outputKey: String = "output", returnAll: Boolean = false @@ -34,16 +31,19 @@ class SimpleSequentialChain private constructor( companion object { fun either( chains: List, inputKey: String, outputKey: String, returnAll: Boolean - ): Either { - return chains.mapOrAccumulate { chain -> - with(chain.config) { - validateInputKeys(inputKeys) - validateOutputKeys(outputKeys) + ): Either = + either { + chains.map { chain -> + either, Chain> { + zipOrAccumulate( + { validateInputKeys(chain.config.inputKeys) }, + { validateOutputKeys(chain.config.outputKeys) } + ) { _, _ -> chain } + }.bind() } }.mapLeft { SequenceChain.InvalidKeys(it.joinToString(transform = Chain.Error::reason)) }.map { SimpleSequentialChain(chains, inputKey, outputKey, returnAll) } - } } } diff --git a/src/commonTest/kotlin/com/xebia/functional/chains/SimpleSequentialChainSpec.kt b/src/commonTest/kotlin/com/xebia/functional/chains/SimpleSequentialChainSpec.kt index 9ef4cd45f..99f7f2f1f 100644 --- a/src/commonTest/kotlin/com/xebia/functional/chains/SimpleSequentialChainSpec.kt +++ b/src/commonTest/kotlin/com/xebia/functional/chains/SimpleSequentialChainSpec.kt @@ -50,6 +50,17 @@ class SimpleSequentialChainSpec : StringSpec({ ssc.run(mapOf("input" to "123")).bind() } shouldBeLeft SequenceChain.InvalidKeys("The expected outputs are more than one: {bar}, {foo}") } + + "SimpleSequentialChain should fail if multiple input and output variables are expected" { + val chain1 = FakeChain(inputVariables = setOf("foo", "bar"), outputVariables = setOf("bar", "foo")) + val chain2 = FakeChain(inputVariables = setOf("bar"), outputVariables = setOf("baz")) + val chains = listOf(chain1, chain2) + + either { + val ssc = SimpleSequentialChain(chains = chains, returnAll = true) + ssc.run(mapOf("input" to "123")).bind() + } shouldBeLeft SequenceChain.InvalidKeys("The expected inputs are more than one: {foo}, {bar}, The expected outputs are more than one: {bar}, {foo}") + } }) data class FakeChain(private val inputVariables: Set, private val outputVariables: Set) : Chain { From e23ba836c11459613428855471e57c7a42fe6a2b Mon Sep 17 00:00:00 2001 From: David Vega Lichacz <7826728+realdavidvega@users.noreply.github.com> Date: Thu, 27 Apr 2023 23:48:58 +0200 Subject: [PATCH 7/9] refactor: remove wildcard --- .../com/xebia/functional/chains/SimpleSequenceChain.kt | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/commonMain/kotlin/com/xebia/functional/chains/SimpleSequenceChain.kt b/src/commonMain/kotlin/com/xebia/functional/chains/SimpleSequenceChain.kt index 5a76c32cd..5f7cc8640 100644 --- a/src/commonMain/kotlin/com/xebia/functional/chains/SimpleSequenceChain.kt +++ b/src/commonMain/kotlin/com/xebia/functional/chains/SimpleSequenceChain.kt @@ -2,7 +2,11 @@ package com.xebia.functional.chains import arrow.core.Either import arrow.core.NonEmptyList -import arrow.core.raise.* +import arrow.core.raise.Raise +import arrow.core.raise.either +import arrow.core.raise.ensure +import arrow.core.raise.ensureNotNull +import arrow.core.raise.zipOrAccumulate fun Raise.SimpleSequentialChain( chains: List, inputKey: String = "input", outputKey: String = "output", returnAll: Boolean = false From 38fd29cd3022720b8905e3d690d8c931683a2f2f Mon Sep 17 00:00:00 2001 From: David Vega Lichacz <7826728+realdavidvega@users.noreply.github.com> Date: Tue, 2 May 2023 08:53:07 +0200 Subject: [PATCH 8/9] feat: apply suggestion of representing chain output model as an enum class --- .../kotlin/com/xebia/functional/chains/Chain.kt | 10 ++++++++-- .../com/xebia/functional/chains/LLMChain.kt | 4 ++-- .../functional/chains/SimpleSequenceChain.kt | 15 +++++++++------ .../com/xebia/functional/chains/LLMChainSpec.kt | 10 +++++++--- .../chains/SimpleSequentialChainSpec.kt | 16 ++++++++-------- 5 files changed, 34 insertions(+), 21 deletions(-) diff --git a/src/commonMain/kotlin/com/xebia/functional/chains/Chain.kt b/src/commonMain/kotlin/com/xebia/functional/chains/Chain.kt index 645f16cb1..35254472f 100644 --- a/src/commonMain/kotlin/com/xebia/functional/chains/Chain.kt +++ b/src/commonMain/kotlin/com/xebia/functional/chains/Chain.kt @@ -6,13 +6,16 @@ import arrow.core.raise.ensure interface Chain { + enum class ChainOutput { InputAndOutput, OnlyOutput } + sealed class Error(open val reason: String) + data class InvalidInputs(override val reason: String): Error(reason) data class Config( val inputKeys: Set, val outputKeys: Set, - val returnAll: Boolean = false + val chainOutput: ChainOutput = ChainOutput.OnlyOutput ) { fun createInputs( inputs: String @@ -60,5 +63,8 @@ interface Chain { private fun prepareOutputs( inputs: Map, outputs: Map ): Map = - if (config.returnAll) inputs + outputs else outputs + when (config.chainOutput) { + ChainOutput.InputAndOutput -> inputs + outputs + ChainOutput.OnlyOutput -> outputs + } } diff --git a/src/commonMain/kotlin/com/xebia/functional/chains/LLMChain.kt b/src/commonMain/kotlin/com/xebia/functional/chains/LLMChain.kt index abdb88b13..4c4c0a571 100644 --- a/src/commonMain/kotlin/com/xebia/functional/chains/LLMChain.kt +++ b/src/commonMain/kotlin/com/xebia/functional/chains/LLMChain.kt @@ -16,10 +16,10 @@ suspend fun LLMChain( echo: Boolean, n: Int, temperature: Double, - returnAll: Boolean = false + chainOutput: Chain.ChainOutput = Chain.ChainOutput.OnlyOutput ): Chain = object : Chain { - override val config: Chain.Config = Chain.Config(promptTemplate.inputKeys.toSet(), setOf("answer"), returnAll) + override val config: Chain.Config = Chain.Config(promptTemplate.inputKeys.toSet(), setOf("answer"), chainOutput) override suspend fun call(inputs: Map): Either> = either { diff --git a/src/commonMain/kotlin/com/xebia/functional/chains/SimpleSequenceChain.kt b/src/commonMain/kotlin/com/xebia/functional/chains/SimpleSequenceChain.kt index 5f7cc8640..d86ca86f4 100644 --- a/src/commonMain/kotlin/com/xebia/functional/chains/SimpleSequenceChain.kt +++ b/src/commonMain/kotlin/com/xebia/functional/chains/SimpleSequenceChain.kt @@ -9,18 +9,21 @@ import arrow.core.raise.ensureNotNull import arrow.core.raise.zipOrAccumulate fun Raise.SimpleSequentialChain( - chains: List, inputKey: String = "input", outputKey: String = "output", returnAll: Boolean = false + chains: List, + inputKey: String = "input", + outputKey: String = "output", + chainOutput: Chain.ChainOutput = Chain.ChainOutput.OnlyOutput ): SimpleSequentialChain = - SimpleSequentialChain.either(chains, inputKey, outputKey, returnAll).bind() + SimpleSequentialChain.either(chains, inputKey, outputKey, chainOutput).bind() class SimpleSequentialChain private constructor( private val chains: List, private val inputKey: String, private val outputKey: String, - returnAll: Boolean + chainOutput: Chain.ChainOutput ) : SequenceChain { - override val config = Chain.Config(setOf(inputKey), setOf(outputKey), returnAll) + override val config = Chain.Config(setOf(inputKey), setOf(outputKey), chainOutput) override suspend fun call(inputs: Map): Either> = either { @@ -34,7 +37,7 @@ class SimpleSequentialChain private constructor( companion object { fun either( - chains: List, inputKey: String, outputKey: String, returnAll: Boolean + chains: List, inputKey: String, outputKey: String, chainOutput: Chain.ChainOutput ): Either = either { chains.map { chain -> @@ -47,7 +50,7 @@ class SimpleSequentialChain private constructor( } }.mapLeft { SequenceChain.InvalidKeys(it.joinToString(transform = Chain.Error::reason)) - }.map { SimpleSequentialChain(chains, inputKey, outputKey, returnAll) } + }.map { SimpleSequentialChain(chains, inputKey, outputKey, chainOutput) } } } diff --git a/src/commonTest/kotlin/com/xebia/functional/chains/LLMChainSpec.kt b/src/commonTest/kotlin/com/xebia/functional/chains/LLMChainSpec.kt index b98dee4e5..24b7e22ab 100644 --- a/src/commonTest/kotlin/com/xebia/functional/chains/LLMChainSpec.kt +++ b/src/commonTest/kotlin/com/xebia/functional/chains/LLMChainSpec.kt @@ -1,7 +1,11 @@ package com.xebia.functional.chains import arrow.core.raise.either -import com.xebia.functional.llm.openai.* +import com.xebia.functional.llm.openai.CompletionChoice +import com.xebia.functional.llm.openai.CompletionRequest +import com.xebia.functional.llm.openai.EmbeddingRequest +import com.xebia.functional.llm.openai.EmbeddingResult +import com.xebia.functional.llm.openai.OpenAIClient import com.xebia.functional.prompt.PromptTemplate import io.kotest.assertions.arrow.core.shouldBeLeft import io.kotest.assertions.arrow.core.shouldBeRight @@ -21,7 +25,7 @@ class LLMChainSpec : StringSpec({ val template = "Tell me {foo}." either { val prompt = PromptTemplate(template, listOf("foo")) - val chain = LLMChain(llm, prompt, "davinci", "testing", false, 1, 0.0, true) + val chain = LLMChain(llm, prompt, "davinci", "testing", false, 1, 0.0, Chain.ChainOutput.InputAndOutput) chain.run("a joke").bind() } shouldBeRight mapOf("foo" to "a joke", "answer" to "I'm not good at jokes") } @@ -30,7 +34,7 @@ class LLMChainSpec : StringSpec({ val template = "My name is {name} and I'm {age} years old" either { val prompt = PromptTemplate(template, listOf("name", "age")) - val chain = LLMChain(llm, prompt, "davinci", "testing", false, 1, 0.0, true) + val chain = LLMChain(llm, prompt, "davinci", "testing", false, 1, 0.0, Chain.ChainOutput.InputAndOutput) chain.run(mapOf("age" to "28", "name" to "foo")).bind() } shouldBeRight mapOf("age" to "28", "name" to "foo", "answer" to "Hello there! Nice to meet you foo") } diff --git a/src/commonTest/kotlin/com/xebia/functional/chains/SimpleSequentialChainSpec.kt b/src/commonTest/kotlin/com/xebia/functional/chains/SimpleSequentialChainSpec.kt index 99f7f2f1f..62a5ec3e7 100644 --- a/src/commonTest/kotlin/com/xebia/functional/chains/SimpleSequentialChainSpec.kt +++ b/src/commonTest/kotlin/com/xebia/functional/chains/SimpleSequentialChainSpec.kt @@ -12,7 +12,7 @@ class SimpleSequentialChainSpec : StringSpec({ val chains = listOf(chain1) either { - val ssc = SimpleSequentialChain(chains = chains, returnAll = true) + val ssc = SimpleSequentialChain(chains = chains, chainOutput = Chain.ChainOutput.InputAndOutput) ssc.run(mapOf("input" to "123")).bind() } shouldBeRight mapOf("input" to "123", "output" to "123dr") } @@ -24,7 +24,7 @@ class SimpleSequentialChainSpec : StringSpec({ val chains = listOf(chain1, chain2, chain3) either { - val ssc = SimpleSequentialChain(chains = chains, returnAll = true) + val ssc = SimpleSequentialChain(chains = chains, chainOutput = Chain.ChainOutput.InputAndOutput) ssc.run(mapOf("input" to "123")).bind() } shouldBeRight mapOf("input" to "123", "output" to "123drdrdr") } @@ -35,7 +35,7 @@ class SimpleSequentialChainSpec : StringSpec({ val chains = listOf(chain1, chain2) either { - val ssc = SimpleSequentialChain(chains = chains, returnAll = true) + val ssc = SimpleSequentialChain(chains = chains, chainOutput = Chain.ChainOutput.InputAndOutput) ssc.run(mapOf("input" to "123")).bind() } shouldBeLeft SequenceChain.InvalidKeys("The expected inputs are more than one: {bar}, {foo}") } @@ -46,7 +46,7 @@ class SimpleSequentialChainSpec : StringSpec({ val chains = listOf(chain1, chain2) either { - val ssc = SimpleSequentialChain(chains = chains, returnAll = true) + val ssc = SimpleSequentialChain(chains = chains, chainOutput = Chain.ChainOutput.InputAndOutput) ssc.run(mapOf("input" to "123")).bind() } shouldBeLeft SequenceChain.InvalidKeys("The expected outputs are more than one: {bar}, {foo}") } @@ -57,17 +57,17 @@ class SimpleSequentialChainSpec : StringSpec({ val chains = listOf(chain1, chain2) either { - val ssc = SimpleSequentialChain(chains = chains, returnAll = true) + val ssc = SimpleSequentialChain(chains = chains, chainOutput = Chain.ChainOutput.InputAndOutput) ssc.run(mapOf("input" to "123")).bind() - } shouldBeLeft SequenceChain.InvalidKeys("The expected inputs are more than one: {foo}, {bar}, The expected outputs are more than one: {bar}, {foo}") + } shouldBeLeft SequenceChain.InvalidKeys("The expected inputs are more than one: {foo}, {bar}, " + + "The expected outputs are more than one: {bar}, {foo}") } }) data class FakeChain(private val inputVariables: Set, private val outputVariables: Set) : Chain { override val config: Chain.Config = Chain.Config( inputKeys = inputVariables, - outputKeys = outputVariables, - returnAll = false + outputKeys = outputVariables ) override suspend fun call(inputs: Map): Either> = From 3b2313d90094a402fd40324738a90292bdc23752 Mon Sep 17 00:00:00 2001 From: David Vega Lichacz <7826728+realdavidvega@users.noreply.github.com> Date: Tue, 2 May 2023 11:52:16 +0200 Subject: [PATCH 9/9] feat: apply suggestion of removing nested either of SimpleSequenceChain --- .../functional/chains/SimpleSequenceChain.kt | 31 ++++++++++--------- .../chains/SimpleSequentialChainSpec.kt | 20 +++--------- 2 files changed, 20 insertions(+), 31 deletions(-) diff --git a/src/commonMain/kotlin/com/xebia/functional/chains/SimpleSequenceChain.kt b/src/commonMain/kotlin/com/xebia/functional/chains/SimpleSequenceChain.kt index d86ca86f4..9f685406c 100644 --- a/src/commonMain/kotlin/com/xebia/functional/chains/SimpleSequenceChain.kt +++ b/src/commonMain/kotlin/com/xebia/functional/chains/SimpleSequenceChain.kt @@ -1,22 +1,22 @@ package com.xebia.functional.chains import arrow.core.Either -import arrow.core.NonEmptyList import arrow.core.raise.Raise import arrow.core.raise.either import arrow.core.raise.ensure import arrow.core.raise.ensureNotNull +import arrow.core.raise.recover import arrow.core.raise.zipOrAccumulate -fun Raise.SimpleSequentialChain( +fun Raise.SimpleSequenceChain( chains: List, inputKey: String = "input", outputKey: String = "output", chainOutput: Chain.ChainOutput = Chain.ChainOutput.OnlyOutput -): SimpleSequentialChain = - SimpleSequentialChain.either(chains, inputKey, outputKey, chainOutput).bind() +): SimpleSequenceChain = + SimpleSequenceChain.either(chains, inputKey, outputKey, chainOutput).bind() -class SimpleSequentialChain private constructor( +class SimpleSequenceChain private constructor( private val chains: List, private val inputKey: String, private val outputKey: String, @@ -37,20 +37,21 @@ class SimpleSequentialChain private constructor( companion object { fun either( - chains: List, inputKey: String, outputKey: String, chainOutput: Chain.ChainOutput - ): Either = + chains: List, + inputKey: String, + outputKey: String, + chainOutput: Chain.ChainOutput + ): Either = either { - chains.map { chain -> - either, Chain> { + val mappedChains: List = chains.map { chain -> + recover({ zipOrAccumulate( { validateInputKeys(chain.config.inputKeys) }, - { validateOutputKeys(chain.config.outputKeys) } - ) { _, _ -> chain } - }.bind() + { validateOutputKeys(chain.config.outputKeys) }) { _, _ -> chain } + }) { raise(SequenceChain.InvalidKeys(it.joinToString(transform = Chain.Error::reason))) } } - }.mapLeft { - SequenceChain.InvalidKeys(it.joinToString(transform = Chain.Error::reason)) - }.map { SimpleSequentialChain(chains, inputKey, outputKey, chainOutput) } + SimpleSequenceChain(mappedChains, inputKey, outputKey, chainOutput) + } } } diff --git a/src/commonTest/kotlin/com/xebia/functional/chains/SimpleSequentialChainSpec.kt b/src/commonTest/kotlin/com/xebia/functional/chains/SimpleSequentialChainSpec.kt index 62a5ec3e7..077139aba 100644 --- a/src/commonTest/kotlin/com/xebia/functional/chains/SimpleSequentialChainSpec.kt +++ b/src/commonTest/kotlin/com/xebia/functional/chains/SimpleSequentialChainSpec.kt @@ -12,7 +12,7 @@ class SimpleSequentialChainSpec : StringSpec({ val chains = listOf(chain1) either { - val ssc = SimpleSequentialChain(chains = chains, chainOutput = Chain.ChainOutput.InputAndOutput) + val ssc = SimpleSequenceChain(chains = chains, chainOutput = Chain.ChainOutput.InputAndOutput) ssc.run(mapOf("input" to "123")).bind() } shouldBeRight mapOf("input" to "123", "output" to "123dr") } @@ -24,7 +24,7 @@ class SimpleSequentialChainSpec : StringSpec({ val chains = listOf(chain1, chain2, chain3) either { - val ssc = SimpleSequentialChain(chains = chains, chainOutput = Chain.ChainOutput.InputAndOutput) + val ssc = SimpleSequenceChain(chains = chains, chainOutput = Chain.ChainOutput.InputAndOutput) ssc.run(mapOf("input" to "123")).bind() } shouldBeRight mapOf("input" to "123", "output" to "123drdrdr") } @@ -35,7 +35,7 @@ class SimpleSequentialChainSpec : StringSpec({ val chains = listOf(chain1, chain2) either { - val ssc = SimpleSequentialChain(chains = chains, chainOutput = Chain.ChainOutput.InputAndOutput) + val ssc = SimpleSequenceChain(chains = chains, chainOutput = Chain.ChainOutput.InputAndOutput) ssc.run(mapOf("input" to "123")).bind() } shouldBeLeft SequenceChain.InvalidKeys("The expected inputs are more than one: {bar}, {foo}") } @@ -46,22 +46,10 @@ class SimpleSequentialChainSpec : StringSpec({ val chains = listOf(chain1, chain2) either { - val ssc = SimpleSequentialChain(chains = chains, chainOutput = Chain.ChainOutput.InputAndOutput) + val ssc = SimpleSequenceChain(chains = chains, chainOutput = Chain.ChainOutput.InputAndOutput) ssc.run(mapOf("input" to "123")).bind() } shouldBeLeft SequenceChain.InvalidKeys("The expected outputs are more than one: {bar}, {foo}") } - - "SimpleSequentialChain should fail if multiple input and output variables are expected" { - val chain1 = FakeChain(inputVariables = setOf("foo", "bar"), outputVariables = setOf("bar", "foo")) - val chain2 = FakeChain(inputVariables = setOf("bar"), outputVariables = setOf("baz")) - val chains = listOf(chain1, chain2) - - either { - val ssc = SimpleSequentialChain(chains = chains, chainOutput = Chain.ChainOutput.InputAndOutput) - ssc.run(mapOf("input" to "123")).bind() - } shouldBeLeft SequenceChain.InvalidKeys("The expected inputs are more than one: {foo}, {bar}, " + - "The expected outputs are more than one: {bar}, {foo}") - } }) data class FakeChain(private val inputVariables: Set, private val outputVariables: Set) : Chain {