@@ -3,51 +3,78 @@ package com.xebia.functional.chains
3
3
import arrow.core.Either
4
4
import arrow.core.raise.either
5
5
import com.xebia.functional.AIError
6
- import com.xebia.functional.AIError.Chain.InvalidInputs
7
6
import com.xebia.functional.llm.openai.*
7
+ import com.xebia.functional.llm.openai.LLMModel.Kind.*
8
8
import com.xebia.functional.prompt.PromptTemplate
9
9
10
10
/**
 * Builds a [Chain] that formats [promptTemplate] with the chain's inputs and sends the
 * resulting prompt to [llm], dispatching on [LLMModel.Kind] to either the plain-text
 * completion endpoint or the chat endpoint.
 *
 * @param llm client used to reach the OpenAI API.
 * @param promptTemplate template whose input keys define this chain's required inputs.
 * @param model model to query; [LLMModel.Kind] selects the endpoint, `model.name` is sent as the model id.
 * @param user end-user identifier forwarded to the API.
 * @param echo whether the completion endpoint should echo the prompt back (completion kind only).
 * @param n number of completions to request.
 * @param temperature sampling temperature.
 * @param outputVariable single key under which the joined completion text is returned.
 * @param chainOutput controls whether the chain surfaces only its output or inputs as well.
 * @param maxTokens upper bound on generated tokens per request (defaults to the previous hard-coded 256).
 */
@Suppress("LongParameterList")
suspend fun LLMChain(
  llm: OpenAIClient,
  promptTemplate: PromptTemplate<String>,
  model: LLMModel,
  user: String = "testing",
  echo: Boolean = false,
  n: Int = 1,
  temperature: Double = 0.0,
  outputVariable: String,
  chainOutput: Chain.ChainOutput = Chain.ChainOutput.OnlyOutput,
  maxTokens: Int = 256,
): Chain = object : Chain {

  private val inputKeys: Set<String> = promptTemplate.inputKeys.toSet()
  private val outputKeys: Set<String> = setOf(outputVariable)

  override val config: Chain.Config = Chain.Config(inputKeys, outputKeys, chainOutput)

  /**
   * Formats [inputs] into a prompt and queries the endpoint matching [model.kind].
   * Fails with [AIError.Chain.InvalidInputs] when [inputs] do not satisfy [inputKeys].
   */
  override suspend fun call(inputs: Map<String, String>): Either<AIError.Chain.InvalidInputs, Map<String, String>> =
    either {
      val prompt = promptTemplate.format(inputs)
      // Exhaustive over LLMModel.Kind — no else branch, so the compiler
      // flags this site when a new kind is added.
      when (model.kind) {
        Completion -> callCompletionEndpoint(prompt)
        Chat -> callChatEndpoint(prompt)
      }
    }

  // Sends [prompt] to the plain-text completion endpoint.
  private suspend fun callCompletionEndpoint(prompt: String): Map<String, String> {
    val request = CompletionRequest(
      model = model.name,
      user = user,
      prompt = prompt,
      echo = echo,
      n = n,
      temperature = temperature,
      maxTokens = maxTokens,
    )
    val completions = llm.createCompletion(request)
    return formatOutput(completions.map { it.text })
  }

  // Sends [prompt] as a single system message to the chat endpoint.
  private suspend fun callChatEndpoint(prompt: String): Map<String, String> {
    val request = ChatCompletionRequest(
      model = model.name,
      user = user,
      messages = listOf(
        Message(
          role = Role.system.name,
          content = prompt,
        )
      ),
      n = n,
      temperature = temperature,
      maxTokens = maxTokens,
    )
    val completions = llm.createChatCompletion(request)
    return formatOutput(completions.choices.map { it.message.content })
  }

  // Shared formatter for both endpoints: every output key maps to the
  // comma-joined completion texts. Replaces the duplicated
  // formatChatOutput/formatCompletionOutput pair, which differed only in
  // how the text was projected out of each choice type.
  private fun formatOutput(outputs: List<String>): Map<String, String> =
    config.outputKeys.associateWith { outputs.joinToString(", ") }
}