Skip to content

Commit 7b29ab5

Browse files
YawolfnomisRev
andauthored
SequenceChain Implementation (#29)
* Adding SequenceChain as a Chain * SequenceChain implementation * Update kotlin/src/commonMain/kotlin/com/xebia/functional/chains/SequenceChain.kt Co-authored-by: Simon Vergauwen <[email protected]> * Update kotlin/src/commonMain/kotlin/com/xebia/functional/chains/SequenceChain.kt Co-authored-by: Simon Vergauwen <[email protected]> * Some requested changes --------- Co-authored-by: yago <[email protected]> Co-authored-by: Simon Vergauwen <[email protected]>
1 parent a610e08 commit 7b29ab5

File tree

6 files changed

+217
-14
lines changed

6 files changed

+217
-14
lines changed

kotlin/src/commonMain/kotlin/com/xebia/functional/chains/Chain.kt

+2
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ interface Chain {
1414

1515
data class InvalidInputs(override val reason: String): Error(reason)
1616

17+
data class InvalidOutputs(override val reason: String): Error(reason)
18+
1719
data class Config(
1820
val inputKeys: Set<String>,
1921
val outputKeys: Set<String>,
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,91 @@
11
package com.xebia.functional.chains
22

3-
interface SequenceChain : Chain {
4-
data class InvalidOutputs(override val reason: String): Chain.Error(reason)
5-
data class InvalidKeys(override val reason: String): Chain.Error(reason)
3+
import arrow.core.Either
4+
import arrow.core.flatten
5+
import arrow.core.raise.either
6+
import arrow.core.raise.ensure
7+
import arrow.core.raise.Raise
8+
import arrow.core.raise.recover
9+
import arrow.core.raise.zipOrAccumulate
10+
import arrow.core.raise.mapOrAccumulate
11+
12+
fun Raise<Chain.Error>.SequenceChain(
13+
chains: List<Chain>,
14+
inputVariables: List<String>,
15+
outputVariables: List<String>,
16+
chainOutput: Chain.ChainOutput = Chain.ChainOutput.OnlyOutput
17+
): SequenceChain =
18+
SequenceChain.either(chains, inputVariables, outputVariables, chainOutput).bind()
19+
20+
open class SequenceChain(
21+
private val chains: List<Chain>,
22+
private val inputVariables: List<String>,
23+
private val outputVariables: List<String>,
24+
chainOutput: Chain.ChainOutput = Chain.ChainOutput.OnlyOutput
25+
) : Chain {
26+
data class InvalidOutputs(override val reason: String) : Chain.Error(reason)
27+
data class InvalidKeys(override val reason: String) : Chain.Error(reason)
28+
29+
override val config = Chain.Config(inputVariables.toSet(), outputVariables.toSet(), chainOutput)
30+
31+
private val outputs = when (chainOutput) {
32+
Chain.ChainOutput.OnlyOutput -> outputVariables
33+
Chain.ChainOutput.InputAndOutput -> outputVariables.plus(inputVariables)
34+
}
35+
36+
override suspend fun call(inputs: Map<String, String>): Either<Chain.Error, Map<String, String>> =
37+
either {
38+
val chainRes = chains.fold(inputs) { inputs0, chain ->
39+
chain.run(inputs0).map { inputs0 + it }.bind()
40+
}
41+
chainRes.filter { it.key in outputs }
42+
}
43+
44+
companion object {
45+
fun either(
46+
chains: List<Chain>,
47+
inputVariables: List<String>,
48+
outputVariables: List<String>,
49+
chainOutput: Chain.ChainOutput
50+
): Either<InvalidKeys, SequenceChain> =
51+
either {
52+
val allOutputs = chains.map { it.config.outputKeys }.toSet().flatten()
53+
val mappedChains: List<Chain> = recover({
54+
mapOrAccumulate(chains) { chain ->
55+
zipOrAccumulate(
56+
{ validateSequenceOutputs(outputVariables, allOutputs) },
57+
{ validateInputsOverlapping(inputVariables, allOutputs) },
58+
) { _, _ -> chain }
59+
}
60+
}) { raise(InvalidKeys(reason = it.flatten().joinToString(transform = Chain.Error::reason))) }
61+
SequenceChain(mappedChains, inputVariables, outputVariables, chainOutput)
62+
}
63+
}
664
}
65+
66+
private fun Raise<Chain.InvalidOutputs>.validateSequenceOutputs(
67+
sequenceOutputs: List<String>,
68+
chainOutputs: List<String>
69+
): Unit =
70+
ensure(sequenceOutputs.isNotEmpty() && sequenceOutputs.all { it in chainOutputs }) {
71+
Chain.InvalidOutputs("The provided outputs: " +
72+
sequenceOutputs.joinToString(", ") { "{$it}" } +
73+
" do not exist in chains' outputs: " +
74+
chainOutputs.joinToString { "{$it}" }
75+
)
76+
}
77+
78+
private fun Raise<Chain.InvalidInputs>.validateInputsOverlapping(
79+
sequenceInputs: List<String>,
80+
chainOutputs: List<String>
81+
): Unit =
82+
ensure(sequenceInputs.isNotEmpty() && sequenceInputs.all { it !in chainOutputs }) {
83+
Chain.InvalidInputs("The provided inputs: " +
84+
sequenceInputs.joinToString { "{$it}" } +
85+
" overlap with chain's outputs: " +
86+
chainOutputs.joinToString { "{$it}" }
87+
88+
)
89+
}
90+
91+

kotlin/src/commonMain/kotlin/com/xebia/functional/chains/SimpleSequenceChain.kt

+5-8
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,7 @@
11
package com.xebia.functional.chains
22

33
import arrow.core.Either
4-
import arrow.core.raise.Raise
5-
import arrow.core.raise.either
6-
import arrow.core.raise.ensure
7-
import arrow.core.raise.recover
8-
import arrow.core.raise.zipOrAccumulate
4+
import arrow.core.raise.*
95

106
fun Raise<Chain.Error>.SimpleSequenceChain(
117
chains: List<Chain>,
@@ -20,7 +16,7 @@ class SimpleSequenceChain private constructor(
2016
private val inputKey: String,
2117
private val outputKey: String,
2218
chainOutput: Chain.ChainOutput
23-
) : SequenceChain {
19+
) : SequenceChain(chains, listOf(inputKey), listOf(outputKey), chainOutput) {
2420

2521
override val config = Chain.Config(setOf(inputKey), setOf(outputKey), chainOutput)
2622

@@ -40,14 +36,14 @@ class SimpleSequenceChain private constructor(
4036
inputKey: String,
4137
outputKey: String,
4238
chainOutput: Chain.ChainOutput
43-
): Either<SequenceChain.InvalidKeys, SimpleSequenceChain> =
39+
): Either<InvalidKeys, SimpleSequenceChain> =
4440
either {
4541
val mappedChains: List<Chain> = chains.map { chain ->
4642
recover({
4743
zipOrAccumulate(
4844
{ validateInputKeys(chain.config.inputKeys) },
4945
{ validateOutputKeys(chain.config.outputKeys) }) { _, _ -> chain }
50-
}) { raise(SequenceChain.InvalidKeys(it.joinToString(transform = Chain.Error::reason))) }
46+
}) { raise(InvalidKeys(reason = it.joinToString(transform = Chain.Error::reason))) }
5147
}
5248
SimpleSequenceChain(mappedChains, inputKey, outputKey, chainOutput)
5349
}
@@ -65,3 +61,4 @@ private fun Raise<Chain.InvalidInputs>.validateInputKeys(inputKeys: Set<String>)
6561
Chain.InvalidInputs("The expected inputs are more than one: " +
6662
inputKeys.joinToString(", ") { "{$it}" })
6763
}
64+
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
package com.xebia.functional.chains
2+
3+
import arrow.core.raise.either
4+
import io.kotest.assertions.arrow.core.shouldBeLeft
5+
import io.kotest.assertions.arrow.core.shouldBeRight
6+
import io.kotest.core.spec.style.StringSpec
7+
8+
class SequenceChainSpec : StringSpec({
9+
"SequenceChain should return a prediction with one Chain" {
10+
val chain1 = FakeChain(inputVariables = setOf("foo"), outputVariables = setOf("bar"))
11+
val chains = listOf(chain1)
12+
13+
either {
14+
val sc = SequenceChain(
15+
chains = chains,
16+
inputVariables = listOf("foo"),
17+
outputVariables = listOf("bar"),
18+
chainOutput = Chain.ChainOutput.InputAndOutput
19+
)
20+
sc.run(mapOf("foo" to "123")).bind()
21+
} shouldBeRight mapOf("foo" to "123", "bar" to "123dr")
22+
}
23+
24+
"SequenceChain should return a prediction on a single input chain" {
25+
val chain1 = FakeChain(inputVariables = setOf("foo"), outputVariables = setOf("bar"))
26+
val chain2 = FakeChain(inputVariables = setOf("bar"), outputVariables = setOf("baz"))
27+
val chains = listOf(chain1, chain2)
28+
29+
either {
30+
val sc = SequenceChain(
31+
chains = chains,
32+
inputVariables = listOf("foo"),
33+
outputVariables = listOf("baz"),
34+
chainOutput = Chain.ChainOutput.InputAndOutput
35+
)
36+
sc.run(mapOf("foo" to "123")).bind()
37+
} shouldBeRight mapOf("foo" to "123", "baz" to "123drdr")
38+
}
39+
40+
"SequenceChain should return a prediction on a multiple input chain" {
41+
val chain1 = FakeChain(inputVariables = setOf("foo", "test"), outputVariables = setOf("bar"))
42+
val chain2 = FakeChain(inputVariables = setOf("bar", "foo"), outputVariables = setOf("baz"))
43+
val chains = listOf(chain1, chain2)
44+
45+
either {
46+
val sc = SequenceChain(
47+
chains = chains,
48+
inputVariables = listOf("foo", "test"),
49+
outputVariables = listOf("baz"),
50+
chainOutput = Chain.ChainOutput.InputAndOutput
51+
)
52+
sc.run(mapOf("foo" to "123", "test" to "456")).bind()
53+
} shouldBeRight mapOf("foo" to "123", "test" to "456", "baz" to "123456dr123dr")
54+
}
55+
56+
"SequenceChain should return a prediction on a multiple output chain" {
57+
val chain1 = FakeChain(inputVariables = setOf("foo"), outputVariables = setOf("bar", "test"))
58+
val chain2 = FakeChain(inputVariables = setOf("bar", "foo"), outputVariables = setOf("baz"))
59+
val chains = listOf(chain1, chain2)
60+
61+
either {
62+
val sc = SequenceChain(
63+
chains = chains,
64+
inputVariables = listOf("foo"),
65+
outputVariables = listOf("baz"),
66+
chainOutput = Chain.ChainOutput.InputAndOutput
67+
)
68+
sc.run(mapOf("foo" to "123")).bind()
69+
} shouldBeRight mapOf("foo" to "123", "baz" to "123dr123dr")
70+
}
71+
72+
"SequenceChain should fail when input variables are missing" {
73+
val chain1 = FakeChain(inputVariables = setOf("foo"), outputVariables = setOf("bar"))
74+
val chain2 = FakeChain(inputVariables = setOf("bar", "test"), outputVariables = setOf("baz"))
75+
val chains = listOf(chain1, chain2)
76+
77+
either {
78+
val sc = SequenceChain(
79+
chains = chains,
80+
inputVariables = listOf("foo"),
81+
outputVariables = listOf("baz"),
82+
chainOutput = Chain.ChainOutput.InputAndOutput
83+
)
84+
sc.run(mapOf("foo" to "123")).bind()
85+
} shouldBeLeft Chain.InvalidInputs("The provided inputs: {foo}, {bar} do not match with chain's inputs: {bar}, {test}")
86+
}
87+
88+
"SequenceChain should fail when output variables are missing" {
89+
val chain1 = FakeChain(inputVariables = setOf("foo"), outputVariables = setOf("bar"))
90+
val chain2 = FakeChain(inputVariables = setOf("bar"), outputVariables = setOf("baz"))
91+
val chains = listOf(chain1, chain2)
92+
93+
either {
94+
val sc = SequenceChain.either(
95+
chains = chains,
96+
inputVariables = listOf("foo"),
97+
outputVariables = listOf("test"),
98+
chainOutput = Chain.ChainOutput.InputAndOutput
99+
).bind()
100+
sc.run(mapOf("foo" to "123")).bind()
101+
} shouldBeLeft SequenceChain.InvalidKeys("The provided outputs: {test} do not exist in chains' outputs: {bar}, {baz}")
102+
}
103+
104+
"SequenceChain should fail when input variables are overlapping" {
105+
val chain1 = FakeChain(inputVariables = setOf("foo"), outputVariables = setOf("bar", "test"))
106+
val chain2 = FakeChain(inputVariables = setOf("bar"), outputVariables = setOf("baz"))
107+
val chains = listOf(chain1, chain2)
108+
109+
either {
110+
val sc = SequenceChain.either(
111+
chains = chains,
112+
inputVariables = listOf("foo", "test"),
113+
outputVariables = listOf("baz"),
114+
chainOutput = Chain.ChainOutput.InputAndOutput
115+
).bind()
116+
sc.run(mapOf("foo" to "123")).bind()
117+
} shouldBeLeft SequenceChain.InvalidKeys("The provided inputs: {foo}, {test} overlap with chain's outputs: {bar}, {test}, {baz}")
118+
}
119+
})

scala/build.gradle.kts

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ plugins {
77
}
88

99
dependencies {
10-
implementation(projects.langchain4kKotlin)
10+
//implementation(projects.langchain4kKotlin)
1111
implementation(libs.kotlinx.coroutines)
1212
implementation(libs.ciris.core)
1313
implementation(libs.ciris.refined)

scala/src/test/scala/com/xebia/functional/chains/SequentialChainSpec.scala

+2-2
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ class SequentialChainSpec extends CatsEffectSuite:
8989
interceptIO[MissingOutputVariablesError](output)
9090
}
9191

92-
test("Test SequentialChainruns when valid outputs are specified.") {
92+
test("Test SequentialChain runs when valid outputs are specified.") {
9393
val chain1 = FakeChain(inputVariables = Set("foo"), outputVariables = Set("bar"))
9494
val chain2 = FakeChain(inputVariables = Set("bar"), outputVariables = Set("baz"))
9595
val chains = NonEmptySeq(chain1, Seq(chain2))
@@ -105,7 +105,7 @@ class SequentialChainSpec extends CatsEffectSuite:
105105
assertIO(output, expectedOutput)
106106
}
107107

108-
test("Test SequentialChainruns error is raised when input variables are overlapping.") {
108+
test("Test SequentialChain runs error is raised when input variables are overlapping.") {
109109
val chain1 = FakeChain(inputVariables = Set("foo"), outputVariables = Set("bar", "test"))
110110
val chain2 = FakeChain(inputVariables = Set("bar"), outputVariables = Set("baz"))
111111
val chains = NonEmptySeq(chain1, Seq(chain2))

0 commit comments

Comments
 (0)