Commit bad8a94
CU-865ca0fvw Pass LLMModel through all chains (#43)
Co-authored-by: franciscodr <[email protected]>
1 parent 80e8425 commit bad8a94

10 files changed: +243 -161 lines changed

example/src/main/kotlin/com/xebia/functional/langchain4k/chain/Weather.kt (+2)

```diff
@@ -8,6 +8,7 @@ import com.xebia.functional.chains.VectorQAChain
 import com.xebia.functional.embeddings.OpenAIEmbeddings
 import com.xebia.functional.env.OpenAIConfig
 import com.xebia.functional.llm.openai.KtorOpenAIClient
+import com.xebia.functional.llm.openai.LLMModel
 import com.xebia.functional.llm.openai.OpenAIClient
 import com.xebia.functional.tool.search
 import com.xebia.functional.tools.storeResults
@@ -49,6 +50,7 @@ private suspend fun getQuestionAnswer(
   val outputVariable = "answer"
   val chain = VectorQAChain(
     openAiClient,
+    LLMModel.GPT_3_5_TURBO,
     vectorStore,
     numOfDocs,
     outputVariable
```
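Before this change, the model was a string default buried inside LLMChain (`llmModel: String = "gpt-3.5-turbo"`, visible in the LLMChain.kt diff below), so the example never stated which model it ran against. The migration cost at each call site is one explicit, typed argument. A sketch of the updated call, assuming `openAiClient`, `vectorStore`, `numOfDocs`, and `outputVariable` are defined as elsewhere in Weather.kt:

```kotlin
// Hypothetical call-site view of the one-argument migration; all
// surrounding values are assumed to exist as in Weather.kt.
val chain = VectorQAChain(
  openAiClient,
  LLMModel.GPT_3_5_TURBO, // explicit, typed model choice instead of a hidden string default
  vectorStore,
  numOfDocs,
  outputVariable
)
```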

kotlin/src/commonMain/kotlin/com/xebia/functional/auto/AI.kt (+5 -7)

```diff
@@ -14,12 +14,7 @@ import com.xebia.functional.auto.serialization.sample
 import com.xebia.functional.chains.VectorQAChain
 import com.xebia.functional.embeddings.OpenAIEmbeddings
 import com.xebia.functional.env.OpenAIConfig
-import com.xebia.functional.llm.openai.ChatCompletionRequest
-import com.xebia.functional.llm.openai.ChatCompletionResponse
-import com.xebia.functional.llm.openai.KtorOpenAIClient
-import com.xebia.functional.llm.openai.Message
-import com.xebia.functional.llm.openai.OpenAIClient
-import com.xebia.functional.llm.openai.Role
+import com.xebia.functional.llm.openai.*
 import com.xebia.functional.logTruncated
 import com.xebia.functional.tools.Agent
 import com.xebia.functional.tools.storeResults
@@ -193,9 +188,10 @@ class AIScope(
     prompt: String,
     serializationConfig: SerializationConfig<A>,
     maxAttempts: Int = 5,
+    llmModel: LLMModel = LLMModel.GPT_3_5_TURBO,
   ): A {
     logger.logTruncated("AI", "Solving objective: $prompt")
-    val result = openAIChatCall(prompt, prompt, serializationConfig)
+    val result = openAIChatCall(llmModel, prompt, prompt, serializationConfig)
     logger.logTruncated("AI", "Response: $result")
     return catch({
       json.decodeFromString(serializationConfig.deserializationStrategy, result)
@@ -209,6 +205,7 @@ class AIScope(
   }
 
   private suspend fun Raise<AIError>.openAIChatCall(
+    llmModel: LLMModel,
     question: String,
     promptWithContext: String,
     serializationConfig: SerializationConfig<*>,
@@ -220,6 +217,7 @@ class AIScope(
     val outputVariable = "answer"
     val chain = VectorQAChain(
       openAIClient,
+      llmModel,
      vectorStore,
       numOfDocs,
       outputVariable
```
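Because `llmModel` defaults to `LLMModel.GPT_3_5_TURBO`, existing callers of this AIScope method compile unchanged, while new callers can opt into another model. A minimal standalone sketch of the default-parameter pattern; the `solve` name is hypothetical, since the hunk header elides the real method name:

```kotlin
// Trimmed-down LLMModel, just enough to show the pattern.
data class LLMModel(val name: String) {
  companion object {
    val GPT_3_5_TURBO = LLMModel("gpt-3.5-turbo")
    val GPT_4 = LLMModel("gpt-4")
  }
}

// Hypothetical stand-in for the AIScope method; only the defaulted
// llmModel parameter is the point here.
fun solve(prompt: String, llmModel: LLMModel = LLMModel.GPT_3_5_TURBO): String =
  "solving \"$prompt\" with ${llmModel.name}"

fun main() {
  println(solve("What is 2 + 2?"))                            // default: gpt-3.5-turbo
  println(solve("What is 2 + 2?", llmModel = LLMModel.GPT_4)) // explicit override
}
```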

kotlin/src/commonMain/kotlin/com/xebia/functional/chains/CombineDocsChain.kt (+3)

```diff
@@ -3,6 +3,7 @@ package com.xebia.functional.chains
 import arrow.core.Either
 import com.xebia.functional.AIError
 import com.xebia.functional.Document
+import com.xebia.functional.llm.openai.LLMModel
 import com.xebia.functional.llm.openai.OpenAIClient
 import com.xebia.functional.prompt.PromptTemplate
 
@@ -14,6 +15,7 @@ interface CombineDocsChain : Chain {
 suspend fun CombineDocsChain(
   llm: OpenAIClient,
   promptTemplate: PromptTemplate<String>,
+  llmModel: LLMModel,
   documents: List<Document>,
   documentVariableName: String,
   outputVariable: String,
@@ -34,6 +36,7 @@ suspend fun CombineDocsChain(
   val llmChain = LLMChain(
     llm,
     promptTemplate,
+    llmModel,
     outputVariable = outputVariable,
     chainOutput = chainOutput
   )
```

kotlin/src/commonMain/kotlin/com/xebia/functional/chains/LLMChain.kt (+64 -37)

```diff
@@ -3,51 +3,78 @@ package com.xebia.functional.chains
 import arrow.core.Either
 import arrow.core.raise.either
 import com.xebia.functional.AIError
-import com.xebia.functional.AIError.Chain.InvalidInputs
 import com.xebia.functional.llm.openai.*
+import com.xebia.functional.llm.openai.LLMModel.Kind.*
 import com.xebia.functional.prompt.PromptTemplate
 
 @Suppress("LongParameterList")
 suspend fun LLMChain(
-    llm: OpenAIClient,
-    promptTemplate: PromptTemplate<String>,
-    llmModel: String = "gpt-3.5-turbo",
-    user: String = "testing",
-    n: Int = 1,
-    temperature: Double = 0.0,
-    outputVariable: String,
-    chainOutput: Chain.ChainOutput = Chain.ChainOutput.OnlyOutput
+  llm: OpenAIClient,
+  promptTemplate: PromptTemplate<String>,
+  model: LLMModel,
+  user: String = "testing",
+  echo: Boolean = false,
+  n: Int = 1,
+  temperature: Double = 0.0,
+  outputVariable: String,
+  chainOutput: Chain.ChainOutput = Chain.ChainOutput.OnlyOutput
 ): Chain = object : Chain {
 
-    private val inputKeys: Set<String> = promptTemplate.inputKeys.toSet()
-    private val outputKeys: Set<String> = setOf(outputVariable)
-
-    override val config: Chain.Config = Chain.Config(inputKeys, outputKeys, chainOutput)
-
-    override suspend fun call(inputs: Map<String, String>): Either<InvalidInputs, Map<String, String>> =
-        either {
-            val prompt = promptTemplate.format(inputs)
-
-            val request = ChatCompletionRequest(
-                model = llmModel,
-                user = user,
-                messages = listOf(
-                    Message(
-                        role = Role.system.name,
-                        content = prompt
-                    )
-                ),
-                n = n,
-                temperature = temperature,
-                maxTokens = 256
-            )
-
-            val completions = llm.createChatCompletion(request)
-            formatOutput(completions.choices)
+  private val inputKeys: Set<String> = promptTemplate.inputKeys.toSet()
+  private val outputKeys: Set<String> = setOf(outputVariable)
+
+  override val config: Chain.Config = Chain.Config(inputKeys, outputKeys, chainOutput)
+
+  override suspend fun call(inputs: Map<String, String>): Either<AIError.Chain.InvalidInputs, Map<String, String>> =
+    either {
+      val prompt = promptTemplate.format(inputs)
+      when (model.kind) {
+        Completion -> callCompletionEndpoint(prompt)
+        Chat -> callChatEndpoint(prompt)
+      }
+    }
+
+  private suspend fun callCompletionEndpoint(prompt: String): Map<String, String> {
+    val request = CompletionRequest(
+      model = model.name,
+      user = user,
+      prompt = prompt,
+      echo = echo,
+      n = n,
+      temperature = temperature,
+      maxTokens = 256
+    )
+
+    val completions = llm.createCompletion(request)
+    return formatCompletionOutput(completions)
   }
 
-    private fun formatOutput(completions: List<Choice>): Map<String, String> =
-        config.outputKeys.associateWith {
-            completions.joinToString(", ") { it.message.content }
+  private suspend fun callChatEndpoint(prompt: String): Map<String, String> {
+    val request = ChatCompletionRequest(
+      model = model.name,
+      user = user,
+      messages = listOf(
+        Message(
+          Role.system.name,
+          prompt
+        )
+      ),
+      n = n,
+      temperature = temperature,
+      maxTokens = 256
+    )
+
+    val completions = llm.createChatCompletion(request)
+    return formatChatOutput(completions.choices)
   }
+
+  private fun formatChatOutput(completions: List<Choice>): Map<String, String> =
+    config.outputKeys.associateWith {
+      completions.joinToString(", ") { it.message.content }
+    }
+
+  private fun formatCompletionOutput(completions: List<CompletionChoice>): Map<String, String> =
+    config.outputKeys.associateWith {
+      completions.joinToString(", ") { it.text }
+    }
 }
```
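This is the heart of the commit: `call` formats the prompt once, then routes on `model.kind`, building a `CompletionRequest` for completion engines and a `ChatCompletionRequest` (with the prompt as a single system message) for chat engines. A self-contained sketch of that dispatch, with simplified stand-ins for the project's client and request types:

```kotlin
enum class Kind { Completion, Chat }
data class Model(val name: String, val kind: Kind)

// Simplified stand-in for the two OpenAIClient endpoints LLMChain targets;
// the real methods take full request objects and are suspend functions.
interface Client {
  fun createCompletion(model: String, prompt: String): String
  fun createChatCompletion(model: String, systemMessage: String): String
}

// Same shape as LLMChain.call: one formatted prompt, routed by kind.
fun call(client: Client, model: Model, prompt: String): String =
  when (model.kind) {
    Kind.Completion -> client.createCompletion(model.name, prompt)
    Kind.Chat -> client.createChatCompletion(model.name, prompt)
  }

fun main() {
  val echo = object : Client {
    override fun createCompletion(model: String, prompt: String) =
      "[completions, $model] $prompt"
    override fun createChatCompletion(model: String, systemMessage: String) =
      "[chat, $model] $systemMessage"
  }
  println(call(echo, Model("text-davinci-003", Kind.Completion), "Hello"))
  println(call(echo, Model("gpt-3.5-turbo", Kind.Chat), "Hello"))
}
```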

kotlin/src/commonMain/kotlin/com/xebia/functional/chains/VectorQAChain.kt (+3)

```diff
@@ -6,6 +6,7 @@ import arrow.core.raise.either
 import arrow.core.raise.recover
 import com.xebia.functional.AIError
 import com.xebia.functional.Document
+import com.xebia.functional.llm.openai.LLMModel
 import com.xebia.functional.llm.openai.OpenAIClient
 import com.xebia.functional.prompt.PromptTemplate
 import com.xebia.functional.vectorstores.VectorStore
@@ -18,6 +19,7 @@ interface VectorQAChain : Chain {
 @Suppress("LongParameterList")
 suspend fun VectorQAChain(
   llm: OpenAIClient,
+  llmModel: LLMModel,
   vectorStore: VectorStore,
   numOfDocs: Int,
   outputVariable: String,
@@ -42,6 +44,7 @@ suspend fun VectorQAChain(
   val chain = CombineDocsChain(
     llm,
     promptTemplate,
+    llmModel,
     documents,
     documentVariableName,
     outputVariable,
```

kotlin/src/commonMain/kotlin/com/xebia/functional/llm/openai/models.kt (+21)

```diff
@@ -119,3 +119,24 @@ data class Usage(
   @SerialName("completion_tokens") val completionTokens: Long? = null,
   @SerialName("total_tokens") val totalTokens: Long
 )
+
+data class LLMModel(
+  val name: String,
+  val kind: Kind
+) {
+  enum class Kind {
+    Completion, Chat
+  }
+  companion object {
+    val GPT_4 = LLMModel("gpt-4", Kind.Chat)
+    val GPT_4_0314 = LLMModel("gpt-4-0314", Kind.Chat)
+    val GPT_4_32K = LLMModel("gpt-4-32k", Kind.Chat)
+    val GPT_3_5_TURBO = LLMModel("gpt-3.5-turbo", Kind.Chat)
+    val GPT_3_5_TURBO_0301 = LLMModel("gpt-3.5-turbo-0301", Kind.Chat)
+    val TEXT_DAVINCI_003 = LLMModel("text-davinci-003", Kind.Completion)
+    val TEXT_DAVINCI_002 = LLMModel("text-davinci-002", Kind.Completion)
+    val TEXT_CURIE_001 = LLMModel("text-curie-001", Kind.Completion)
+    val TEXT_BABBAGE_001 = LLMModel("text-babbage-001", Kind.Completion)
+    val TEXT_ADA_001 = LLMModel("text-ada-001", Kind.Completion)
+  }
+}
```
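Because `Kind` is an enum, every `when (model.kind)` is exhaustiveness-checked: adding a third kind later breaks compilation at each dispatch site until it is handled. A usage sketch, assuming the `LLMModel` class above is in scope; the endpoint paths are OpenAI's public routes, shown only for illustration:

```kotlin
// Exhaustive routing: omitting a Kind branch here is a compile error.
fun endpointFor(model: LLMModel): String =
  when (model.kind) {
    LLMModel.Kind.Completion -> "/v1/completions"
    LLMModel.Kind.Chat -> "/v1/chat/completions"
  }

fun main() {
  println(endpointFor(LLMModel.GPT_3_5_TURBO)) // /v1/chat/completions
  // Ad-hoc models (e.g. a hypothetical fine-tune) still work:
  println(endpointFor(LLMModel("my-finetune", LLMModel.Kind.Completion)))
}
```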

kotlin/src/commonTest/kotlin/com/xebia/functional/chains/ChainTestUtils.kt (+40 -22)

```diff
@@ -1,33 +1,51 @@
 package com.xebia.functional.chains
 
-import com.xebia.functional.llm.openai.ChatCompletionRequest
-import com.xebia.functional.llm.openai.ChatCompletionResponse
-import com.xebia.functional.llm.openai.CompletionChoice
-import com.xebia.functional.llm.openai.CompletionRequest
-import com.xebia.functional.llm.openai.EmbeddingRequest
-import com.xebia.functional.llm.openai.EmbeddingResult
-import com.xebia.functional.llm.openai.OpenAIClient
+import com.xebia.functional.llm.openai.*
 
 val testLLM = object : OpenAIClient {
-    override suspend fun createCompletion(request: CompletionRequest): List<CompletionChoice> =
-        when (request.prompt) {
-            "Tell me a joke." ->
-                listOf(CompletionChoice("I'm not good at jokes", 1, finishReason = "foo"))
+  override suspend fun createCompletion(request: CompletionRequest): List<CompletionChoice> =
+    when (request.prompt) {
+      "Tell me a joke." ->
+        listOf(CompletionChoice("I'm not good at jokes", 1, finishReason = "foo"))
 
-            "My name is foo and I'm 28 years old" ->
-                listOf(CompletionChoice("Hello there! Nice to meet you foo", 1, finishReason = "foo"))
+      "My name is foo and I'm 28 years old" ->
+        listOf(CompletionChoice("Hello there! Nice to meet you foo", 1, finishReason = "foo"))
 
-            testTemplateFormatted -> listOf(CompletionChoice("I don't know", 1, finishReason = "foo"))
-            testTemplateInputsFormatted -> listOf(CompletionChoice("Two inputs, right?", 1, finishReason = "foo"))
-            testQATemplateFormatted -> listOf(CompletionChoice("I don't know", 1, finishReason = "foo"))
-            else -> listOf(CompletionChoice("foo", 1, finishReason = "bar"))
-        }
+      testTemplateFormatted -> listOf(CompletionChoice("I don't know", 1, finishReason = "foo"))
+      testTemplateInputsFormatted -> listOf(CompletionChoice("Two inputs, right?", 1, finishReason = "foo"))
+      testQATemplateFormatted -> listOf(CompletionChoice("I don't know", 1, finishReason = "foo"))
+      else -> listOf(CompletionChoice("foo", 1, finishReason = "bar"))
+    }
 
-    override suspend fun createChatCompletion(request: ChatCompletionRequest): ChatCompletionResponse =
-        TODO()
+  override suspend fun createChatCompletion(request: ChatCompletionRequest): ChatCompletionResponse =
+    when (request.messages.firstOrNull()?.content) {
+      "Tell me a joke." -> fakeChatCompletion("I'm not good at jokes")
+      "My name is foo and I'm 28 years old" -> fakeChatCompletion("Hello there! Nice to meet you foo")
+      testTemplateFormatted -> fakeChatCompletion("I don't know")
+      testTemplateInputsFormatted -> fakeChatCompletion("Two inputs, right?")
+      testQATemplateFormatted -> fakeChatCompletion("I don't know")
+      else -> fakeChatCompletion("foo")
+    }
 
-    override suspend fun createEmbeddings(request: EmbeddingRequest): EmbeddingResult =
-        TODO()
+  private fun fakeChatCompletion(message: String): ChatCompletionResponse =
+    ChatCompletionResponse(
+      id = "foo",
+      `object` = "foo",
+      created = 1,
+      model = "foo",
+      usage = Usage(1, 1, 1),
+      choices = listOf(
+        Choice(
+          Message(
+            Role.system.name,
+            message
+          ), "foo", index = 0
+        )
+      )
+    )
+
+  override suspend fun createEmbeddings(request: EmbeddingRequest): EmbeddingResult =
+    TODO()
 }
 
 val testContext = """foo foo foo
```
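The fake client keeps chain tests deterministic and offline: both endpoints now return canned replies keyed on the incoming prompt, and `createChatCompletion` no longer hits `TODO()`, so chains built over Chat models become testable too. A standalone distillation of the pattern; the names here are illustrative, not the project's:

```kotlin
// Prompt-keyed canned responses, the same pattern testLLM uses.
interface FakeLLM {
  fun complete(prompt: String): String
  fun chat(firstMessage: String?): String
}

val fake = object : FakeLLM {
  private fun canned(input: String?): String = when (input) {
    "Tell me a joke." -> "I'm not good at jokes"
    else -> "foo"
  }
  override fun complete(prompt: String) = canned(prompt)
  override fun chat(firstMessage: String?) = canned(firstMessage)
}

fun main() {
  check(fake.complete("Tell me a joke.") == "I'm not good at jokes")
  check(fake.chat(null) == "foo") // unmatched inputs fall through to "foo"
  println("fake LLM behaves deterministically")
}
```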
