Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SequenceChain Implementation #29

Merged
merged 7 commits into from
May 9, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ interface Chain {

data class InvalidInputs(override val reason: String): Error(reason)

data class InvalidOutputs(override val reason: String): Error(reason)

data class Config(
val inputKeys: Set<String>,
val outputKeys: Set<String>,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,90 @@
package com.xebia.functional.chains

interface SequenceChain : Chain {
data class InvalidOutputs(override val reason: String): Chain.Error(reason)
data class InvalidKeys(override val reason: String): Chain.Error(reason)
import arrow.core.Either
import arrow.core.flatten
import arrow.core.raise.either
import arrow.core.raise.ensure
import arrow.core.raise.Raise
import arrow.core.raise.recover
import arrow.core.raise.zipOrAccumulate

fun Raise<Chain.Error>.SequenceChain(
chains: List<Chain>,
inputVariables: List<String>,
outputVariables: List<String>,
chainOutput: Chain.ChainOutput = Chain.ChainOutput.OnlyOutput
): SequenceChain =
SequenceChain.either(chains, inputVariables, outputVariables, chainOutput).bind()

open class SequenceChain(
Yawolf marked this conversation as resolved.
Show resolved Hide resolved
private val chains: List<Chain>,
private val inputVariables: List<String>,
private val outputVariables: List<String>,
chainOutput: Chain.ChainOutput = Chain.ChainOutput.OnlyOutput
) : Chain {
data class InvalidOutputs(override val reason: String) : Chain.Error(reason)
data class InvalidKeys(override val reason: String) : Chain.Error(reason)

override val config = Chain.Config(inputVariables.toSet(), outputVariables.toSet(), chainOutput)

private val outputs = when (chainOutput) {
Chain.ChainOutput.OnlyOutput -> outputVariables
Chain.ChainOutput.InputAndOutput -> outputVariables.plus(inputVariables)
}

override suspend fun call(inputs: Map<String, String>): Either<Chain.Error, Map<String, String>> =
either {
val chainRes = chains.fold(inputs) { inputs0, chain ->
chain.run(inputs0).map { inputs0.plus(it) }.bind()
Yawolf marked this conversation as resolved.
Show resolved Hide resolved
}
chainRes.filter { it.key in outputs }
}

companion object {
fun either(
chains: List<Chain>,
inputVariables: List<String>,
outputVariables: List<String>,
chainOutput: Chain.ChainOutput
): Either<InvalidKeys, SequenceChain> =
either {
val allOutputs = chains.map { it.config.outputKeys }.toSet().flatten()
val mappedChains: List<Chain> = chains.map { chain ->
recover({
zipOrAccumulate(
{ validateSequenceOutputs(outputVariables, allOutputs) },
{ validateInputsOverlapping(inputVariables, allOutputs) },
) { _, _ -> chain }
}) { raise(InvalidKeys(reason = it.joinToString(transform = Chain.Error::reason))) }
}
Yawolf marked this conversation as resolved.
Show resolved Hide resolved
SequenceChain(mappedChains, inputVariables, outputVariables, chainOutput)
}
}
}

private fun Raise<Chain.InvalidOutputs>.validateSequenceOutputs(
sequenceOutputs: List<String>,
chainOutputs: List<String>
): Unit =
ensure(sequenceOutputs.isNotEmpty() && sequenceOutputs.all { it in chainOutputs }) {
Chain.InvalidOutputs("The provided outputs: " +
sequenceOutputs.joinToString(", ") { "{$it}" } +
" do not exist in chains' outputs: " +
chainOutputs.joinToString { "{$it}" }
)
}

private fun Raise<Chain.InvalidInputs>.validateInputsOverlapping(
sequenceInputs: List<String>,
chainOutputs: List<String>
): Unit =
ensure(sequenceInputs.isNotEmpty() && sequenceInputs.all { it !in chainOutputs }) {
Chain.InvalidInputs("The provided inputs: " +
sequenceInputs.joinToString { "{$it}" } +
" overlap with chain's outputs: " +
chainOutputs.joinToString { "{$it}" }

)
}


Original file line number Diff line number Diff line change
@@ -1,11 +1,7 @@
package com.xebia.functional.chains

import arrow.core.Either
import arrow.core.raise.Raise
import arrow.core.raise.either
import arrow.core.raise.ensure
import arrow.core.raise.recover
import arrow.core.raise.zipOrAccumulate
import arrow.core.raise.*

fun Raise<Chain.Error>.SimpleSequenceChain(
chains: List<Chain>,
Expand All @@ -20,7 +16,7 @@ class SimpleSequenceChain private constructor(
private val inputKey: String,
private val outputKey: String,
chainOutput: Chain.ChainOutput
) : SequenceChain {
) : SequenceChain(chains, listOf(inputKey), listOf(outputKey), chainOutput) {

override val config = Chain.Config(setOf(inputKey), setOf(outputKey), chainOutput)

Expand All @@ -40,14 +36,14 @@ class SimpleSequenceChain private constructor(
inputKey: String,
outputKey: String,
chainOutput: Chain.ChainOutput
): Either<SequenceChain.InvalidKeys, SimpleSequenceChain> =
): Either<InvalidKeys, SimpleSequenceChain> =
either {
val mappedChains: List<Chain> = chains.map { chain ->
Yawolf marked this conversation as resolved.
Show resolved Hide resolved
recover({
zipOrAccumulate(
{ validateInputKeys(chain.config.inputKeys) },
{ validateOutputKeys(chain.config.outputKeys) }) { _, _ -> chain }
}) { raise(SequenceChain.InvalidKeys(it.joinToString(transform = Chain.Error::reason))) }
}) { raise(InvalidKeys(reason = it.joinToString(transform = Chain.Error::reason))) }
}
SimpleSequenceChain(mappedChains, inputKey, outputKey, chainOutput)
}
Expand All @@ -65,3 +61,4 @@ private fun Raise<Chain.InvalidInputs>.validateInputKeys(inputKeys: Set<String>)
Chain.InvalidInputs("The expected inputs are more than one: " +
inputKeys.joinToString(", ") { "{$it}" })
}

Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
package com.xebia.functional.chains

import arrow.core.raise.either
import io.kotest.assertions.arrow.core.shouldBeLeft
import io.kotest.assertions.arrow.core.shouldBeRight
import io.kotest.core.spec.style.StringSpec

class SequenceChainSpec : StringSpec({
"SequenceChain should return a prediction with one Chain" {
val chain1 = FakeChain(inputVariables = setOf("foo"), outputVariables = setOf("bar"))
val chains = listOf(chain1)

either {
val sc = SequenceChain(
chains = chains,
inputVariables = listOf("foo"),
outputVariables = listOf("bar"),
chainOutput = Chain.ChainOutput.InputAndOutput
)
sc.run(mapOf("foo" to "123")).bind()
} shouldBeRight mapOf("foo" to "123", "bar" to "123dr")
}

"SequenceChain should return a prediction on a single input chain" {
val chain1 = FakeChain(inputVariables = setOf("foo"), outputVariables = setOf("bar"))
val chain2 = FakeChain(inputVariables = setOf("bar"), outputVariables = setOf("baz"))
val chains = listOf(chain1, chain2)

either {
val sc = SequenceChain(
chains = chains,
inputVariables = listOf("foo"),
outputVariables = listOf("baz"),
chainOutput = Chain.ChainOutput.InputAndOutput
)
sc.run(mapOf("foo" to "123")).bind()
} shouldBeRight mapOf("foo" to "123", "baz" to "123drdr")
}

"SequenceChain should return a prediction on a multiple input chain" {
val chain1 = FakeChain(inputVariables = setOf("foo", "test"), outputVariables = setOf("bar"))
val chain2 = FakeChain(inputVariables = setOf("bar", "foo"), outputVariables = setOf("baz"))
val chains = listOf(chain1, chain2)

either {
val sc = SequenceChain(
chains = chains,
inputVariables = listOf("foo", "test"),
outputVariables = listOf("baz"),
chainOutput = Chain.ChainOutput.InputAndOutput
)
sc.run(mapOf("foo" to "123", "test" to "456")).bind()
} shouldBeRight mapOf("foo" to "123", "test" to "456", "baz" to "123456dr123dr")
}

"SequenceChain should return a prediction on a multiple output chain" {
val chain1 = FakeChain(inputVariables = setOf("foo"), outputVariables = setOf("bar", "test"))
val chain2 = FakeChain(inputVariables = setOf("bar", "foo"), outputVariables = setOf("baz"))
val chains = listOf(chain1, chain2)

either {
val sc = SequenceChain(
chains = chains,
inputVariables = listOf("foo"),
outputVariables = listOf("baz"),
chainOutput = Chain.ChainOutput.InputAndOutput
)
sc.run(mapOf("foo" to "123")).bind()
} shouldBeRight mapOf("foo" to "123", "baz" to "123dr123dr")
}

"SequenceChain should fail when input variables are missing" {
val chain1 = FakeChain(inputVariables = setOf("foo"), outputVariables = setOf("bar"))
val chain2 = FakeChain(inputVariables = setOf("bar", "test"), outputVariables = setOf("baz"))
val chains = listOf(chain1, chain2)

either {
val sc = SequenceChain(
chains = chains,
inputVariables = listOf("foo"),
outputVariables = listOf("baz"),
chainOutput = Chain.ChainOutput.InputAndOutput
)
sc.run(mapOf("foo" to "123")).bind()
} shouldBeLeft Chain.InvalidInputs("The provided inputs: {foo}, {bar} do not match with chain's inputs: {bar}, {test}")
}

"SequenceChain should fail when output variables are missing" {
val chain1 = FakeChain(inputVariables = setOf("foo"), outputVariables = setOf("bar"))
val chain2 = FakeChain(inputVariables = setOf("bar"), outputVariables = setOf("baz"))
val chains = listOf(chain1, chain2)

either {
val sc = SequenceChain.either(
chains = chains,
inputVariables = listOf("foo"),
outputVariables = listOf("test"),
chainOutput = Chain.ChainOutput.InputAndOutput
).bind()
sc.run(mapOf("foo" to "123")).bind()
} shouldBeLeft SequenceChain.InvalidKeys("The provided outputs: {test} do not exist in chains' outputs: {bar}, {baz}")
}

"SequenceChain should fail when input variables are overlapping" {
val chain1 = FakeChain(inputVariables = setOf("foo"), outputVariables = setOf("bar", "test"))
val chain2 = FakeChain(inputVariables = setOf("bar"), outputVariables = setOf("baz"))
val chains = listOf(chain1, chain2)

either {
val sc = SequenceChain.either(
chains = chains,
inputVariables = listOf("foo", "test"),
outputVariables = listOf("baz"),
chainOutput = Chain.ChainOutput.InputAndOutput
).bind()
sc.run(mapOf("foo" to "123")).bind()
} shouldBeLeft SequenceChain.InvalidKeys("The provided inputs: {foo}, {test} overlap with chain's outputs: {bar}, {test}, {baz}")
}
})
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ class SequentialChainSpec extends CatsEffectSuite:
interceptIO[MissingOutputVariablesError](output)
}

test("Test SequentialChainruns when valid outputs are specified.") {
test("Test SequentialChain runs when valid outputs are specified.") {
val chain1 = FakeChain(inputVariables = Set("foo"), outputVariables = Set("bar"))
val chain2 = FakeChain(inputVariables = Set("bar"), outputVariables = Set("baz"))
val chains = NonEmptySeq(chain1, Seq(chain2))
Expand All @@ -105,7 +105,7 @@ class SequentialChainSpec extends CatsEffectSuite:
assertIO(output, expectedOutput)
}

test("Test SequentialChainruns error is raised when input variables are overlapping.") {
test("Test SequentialChain runs error is raised when input variables are overlapping.") {
val chain1 = FakeChain(inputVariables = Set("foo"), outputVariables = Set("bar", "test"))
val chain2 = FakeChain(inputVariables = Set("bar"), outputVariables = Set("baz"))
val chains = NonEmptySeq(chain1, Seq(chain2))
Expand Down