Skip to content

Commit

Permalink
feat: apply suggestion of representing chain output model as an enum …
Browse files Browse the repository at this point in the history
…class
  • Loading branch information
realdavidvega committed May 2, 2023
1 parent e23ba83 commit 38fd29c
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 21 deletions.
10 changes: 8 additions & 2 deletions src/commonMain/kotlin/com/xebia/functional/chains/Chain.kt
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,16 @@ import arrow.core.raise.ensure

interface Chain {

enum class ChainOutput { InputAndOutput, OnlyOutput }

sealed class Error(open val reason: String)

data class InvalidInputs(override val reason: String): Error(reason)

data class Config(
val inputKeys: Set<String>,
val outputKeys: Set<String>,
val returnAll: Boolean = false
val chainOutput: ChainOutput = ChainOutput.OnlyOutput
) {
fun createInputs(
inputs: String
Expand Down Expand Up @@ -60,5 +63,8 @@ interface Chain {
private fun prepareOutputs(
inputs: Map<String, String>, outputs: Map<String, String>
): Map<String, String> =
if (config.returnAll) inputs + outputs else outputs
when (config.chainOutput) {
ChainOutput.InputAndOutput -> inputs + outputs
ChainOutput.OnlyOutput -> outputs
}
}
4 changes: 2 additions & 2 deletions src/commonMain/kotlin/com/xebia/functional/chains/LLMChain.kt
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@ suspend fun LLMChain(
echo: Boolean,
n: Int,
temperature: Double,
returnAll: Boolean = false
chainOutput: Chain.ChainOutput = Chain.ChainOutput.OnlyOutput
): Chain = object : Chain {

override val config: Chain.Config = Chain.Config(promptTemplate.inputKeys.toSet(), setOf("answer"), returnAll)
override val config: Chain.Config = Chain.Config(promptTemplate.inputKeys.toSet(), setOf("answer"), chainOutput)

override suspend fun call(inputs: Map<String, String>): Either<Chain.InvalidInputs, Map<String, String>> =
either {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,21 @@ import arrow.core.raise.ensureNotNull
import arrow.core.raise.zipOrAccumulate

fun Raise<Chain.Error>.SimpleSequentialChain(
chains: List<Chain>, inputKey: String = "input", outputKey: String = "output", returnAll: Boolean = false
chains: List<Chain>,
inputKey: String = "input",
outputKey: String = "output",
chainOutput: Chain.ChainOutput = Chain.ChainOutput.OnlyOutput
): SimpleSequentialChain =
SimpleSequentialChain.either(chains, inputKey, outputKey, returnAll).bind()
SimpleSequentialChain.either(chains, inputKey, outputKey, chainOutput).bind()

class SimpleSequentialChain private constructor(
private val chains: List<Chain>,
private val inputKey: String,
private val outputKey: String,
returnAll: Boolean
chainOutput: Chain.ChainOutput
) : SequenceChain {

override val config = Chain.Config(setOf(inputKey), setOf(outputKey), returnAll)
override val config = Chain.Config(setOf(inputKey), setOf(outputKey), chainOutput)

override suspend fun call(inputs: Map<String, String>): Either<Chain.Error, Map<String, String>> =
either {
Expand All @@ -34,7 +37,7 @@ class SimpleSequentialChain private constructor(

companion object {
fun either(
chains: List<Chain>, inputKey: String, outputKey: String, returnAll: Boolean
chains: List<Chain>, inputKey: String, outputKey: String, chainOutput: Chain.ChainOutput
): Either<SequenceChain.InvalidKeys, SimpleSequentialChain> =
either {
chains.map { chain ->
Expand All @@ -47,7 +50,7 @@ class SimpleSequentialChain private constructor(
}
}.mapLeft {
SequenceChain.InvalidKeys(it.joinToString(transform = Chain.Error::reason))
}.map { SimpleSequentialChain(chains, inputKey, outputKey, returnAll) }
}.map { SimpleSequentialChain(chains, inputKey, outputKey, chainOutput) }
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
package com.xebia.functional.chains

import arrow.core.raise.either
import com.xebia.functional.llm.openai.*
import com.xebia.functional.llm.openai.CompletionChoice
import com.xebia.functional.llm.openai.CompletionRequest
import com.xebia.functional.llm.openai.EmbeddingRequest
import com.xebia.functional.llm.openai.EmbeddingResult
import com.xebia.functional.llm.openai.OpenAIClient
import com.xebia.functional.prompt.PromptTemplate
import io.kotest.assertions.arrow.core.shouldBeLeft
import io.kotest.assertions.arrow.core.shouldBeRight
Expand All @@ -21,7 +25,7 @@ class LLMChainSpec : StringSpec({
val template = "Tell me {foo}."
either {
val prompt = PromptTemplate(template, listOf("foo"))
val chain = LLMChain(llm, prompt, "davinci", "testing", false, 1, 0.0, true)
val chain = LLMChain(llm, prompt, "davinci", "testing", false, 1, 0.0, Chain.ChainOutput.InputAndOutput)
chain.run("a joke").bind()
} shouldBeRight mapOf("foo" to "a joke", "answer" to "I'm not good at jokes")
}
Expand All @@ -30,7 +34,7 @@ class LLMChainSpec : StringSpec({
val template = "My name is {name} and I'm {age} years old"
either {
val prompt = PromptTemplate(template, listOf("name", "age"))
val chain = LLMChain(llm, prompt, "davinci", "testing", false, 1, 0.0, true)
val chain = LLMChain(llm, prompt, "davinci", "testing", false, 1, 0.0, Chain.ChainOutput.InputAndOutput)
chain.run(mapOf("age" to "28", "name" to "foo")).bind()
} shouldBeRight mapOf("age" to "28", "name" to "foo", "answer" to "Hello there! Nice to meet you foo")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class SimpleSequentialChainSpec : StringSpec({
val chains = listOf(chain1)

either {
val ssc = SimpleSequentialChain(chains = chains, returnAll = true)
val ssc = SimpleSequentialChain(chains = chains, chainOutput = Chain.ChainOutput.InputAndOutput)
ssc.run(mapOf("input" to "123")).bind()
} shouldBeRight mapOf("input" to "123", "output" to "123dr")
}
Expand All @@ -24,7 +24,7 @@ class SimpleSequentialChainSpec : StringSpec({
val chains = listOf(chain1, chain2, chain3)

either {
val ssc = SimpleSequentialChain(chains = chains, returnAll = true)
val ssc = SimpleSequentialChain(chains = chains, chainOutput = Chain.ChainOutput.InputAndOutput)
ssc.run(mapOf("input" to "123")).bind()
} shouldBeRight mapOf("input" to "123", "output" to "123drdrdr")
}
Expand All @@ -35,7 +35,7 @@ class SimpleSequentialChainSpec : StringSpec({
val chains = listOf(chain1, chain2)

either {
val ssc = SimpleSequentialChain(chains = chains, returnAll = true)
val ssc = SimpleSequentialChain(chains = chains, chainOutput = Chain.ChainOutput.InputAndOutput)
ssc.run(mapOf("input" to "123")).bind()
} shouldBeLeft SequenceChain.InvalidKeys("The expected inputs are more than one: {bar}, {foo}")
}
Expand All @@ -46,7 +46,7 @@ class SimpleSequentialChainSpec : StringSpec({
val chains = listOf(chain1, chain2)

either {
val ssc = SimpleSequentialChain(chains = chains, returnAll = true)
val ssc = SimpleSequentialChain(chains = chains, chainOutput = Chain.ChainOutput.InputAndOutput)
ssc.run(mapOf("input" to "123")).bind()
} shouldBeLeft SequenceChain.InvalidKeys("The expected outputs are more than one: {bar}, {foo}")
}
Expand All @@ -57,17 +57,17 @@ class SimpleSequentialChainSpec : StringSpec({
val chains = listOf(chain1, chain2)

either {
val ssc = SimpleSequentialChain(chains = chains, returnAll = true)
val ssc = SimpleSequentialChain(chains = chains, chainOutput = Chain.ChainOutput.InputAndOutput)
ssc.run(mapOf("input" to "123")).bind()
} shouldBeLeft SequenceChain.InvalidKeys("The expected inputs are more than one: {foo}, {bar}, The expected outputs are more than one: {bar}, {foo}")
} shouldBeLeft SequenceChain.InvalidKeys("The expected inputs are more than one: {foo}, {bar}, " +
"The expected outputs are more than one: {bar}, {foo}")
}
})

data class FakeChain(private val inputVariables: Set<String>, private val outputVariables: Set<String>) : Chain {
override val config: Chain.Config = Chain.Config(
inputKeys = inputVariables,
outputKeys = outputVariables,
returnAll = false
outputKeys = outputVariables
)

override suspend fun call(inputs: Map<String, String>): Either<Chain.InvalidInputs, Map<String, String>> =
Expand Down

0 comments on commit 38fd29c

Please sign in to comment.