CU-865ca0fvw Pass LLMModel through all chains #43

Merged: 6 commits, May 10, 2023
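
The change in a nutshell: chain constructors no longer hard-code a model string; callers pass an LLMModel, and each request is routed to the completion or chat endpoint according to the model's kind. A minimal sketch of the resulting call shape, modeled on the first diff below (client and vector-store setup omitted; the argument values are illustrative):

import com.xebia.functional.chains.VectorQAChain
import com.xebia.functional.llm.openai.LLMModel
import com.xebia.functional.llm.openai.OpenAIClient
import com.xebia.functional.vectorstores.VectorStore

// The model is now an explicit argument to the chain.
suspend fun questionAnswerChain(openAiClient: OpenAIClient, vectorStore: VectorStore) =
  VectorQAChain(
    openAiClient,
    LLMModel.GPT_3_5_TURBO, // any model from the new LLMModel catalogue
    vectorStore,
    10,       // numOfDocs
    "answer"  // outputVariable
  )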

@@ -10,6 +10,7 @@ import com.xebia.functional.chains.VectorQAChain
import com.xebia.functional.embeddings.OpenAIEmbeddings
import com.xebia.functional.env.OpenAIConfig
import com.xebia.functional.llm.openai.KtorOpenAIClient
+import com.xebia.functional.llm.openai.LLMModel
import com.xebia.functional.llm.openai.OpenAIClient
import com.xebia.functional.tool.search
import com.xebia.functional.vectorstores.LocalVectorStore

@@ -54,6 +55,7 @@ private suspend fun getQuestionAnswer(
val outputVariable = "answer"
val chain = VectorQAChain(
openAiClient,
+LLMModel.GPT_3_5_TURBO,
vectorStore,
numOfDocs,
outputVariable

kotlin/src/commonMain/kotlin/com/xebia/functional/auto/AI.kt (12 changes: 5 additions & 7 deletions)

@@ -15,12 +15,7 @@ import com.xebia.functional.auto.serialization.sample
import com.xebia.functional.chains.VectorQAChain
import com.xebia.functional.embeddings.OpenAIEmbeddings
import com.xebia.functional.env.OpenAIConfig
-import com.xebia.functional.llm.openai.ChatCompletionRequest
-import com.xebia.functional.llm.openai.ChatCompletionResponse
-import com.xebia.functional.llm.openai.KtorOpenAIClient
-import com.xebia.functional.llm.openai.Message
-import com.xebia.functional.llm.openai.OpenAIClient
-import com.xebia.functional.llm.openai.Role
+import com.xebia.functional.llm.openai.*
import com.xebia.functional.logTruncated
import com.xebia.functional.tools.Tool
import com.xebia.functional.vectorstores.LocalVectorStore

@@ -98,9 +93,10 @@ class AIScope(
prompt: String,
serializationConfig: SerializationConfig<A>,
maxAttempts: Int = 5,
+llmModel: LLMModel = LLMModel.GPT_3_5_TURBO,
): A {
logger.logTruncated("AI", "Solving objective: $prompt")
-val result = openAIChatCall(prompt, prompt, serializationConfig)
+val result = openAIChatCall(llmModel, prompt, prompt, serializationConfig)
logger.logTruncated("AI", "Response: $result")
return catch({
json.decodeFromString(serializationConfig.deserializationStrategy, result)

@@ -143,6 +139,7 @@ class AIScope(
}

private suspend fun Raise<AIError>.openAIChatCall(
+llmModel: LLMModel,
question: String,
promptWithContext: String,
serializationConfig: SerializationConfig<*>,

@@ -156,6 +153,7 @@
val outputVariable = "answer"
val chain = VectorQAChain(
openAIClient,
+llmModel,
vectorStore,
numOfDocs,
outputVariable

@@ -3,6 +3,7 @@ package com.xebia.functional.chains
import arrow.core.Either
import com.xebia.functional.AIError
import com.xebia.functional.Document
+import com.xebia.functional.llm.openai.LLMModel
import com.xebia.functional.llm.openai.OpenAIClient
import com.xebia.functional.prompt.PromptTemplate


@@ -14,6 +15,7 @@ interface CombineDocsChain : Chain {
suspend fun CombineDocsChain(
llm: OpenAIClient,
promptTemplate: PromptTemplate<String>,
+llmModel: LLMModel,
documents: List<Document>,
documentVariableName: String,
outputVariable: String,

@@ -34,6 +36,7 @@ suspend fun CombineDocsChain(
val llmChain = LLMChain(
llm,
promptTemplate,
+llmModel,
outputVariable = outputVariable,
chainOutput = chainOutput
)

kotlin/src/commonMain/kotlin/com/xebia/functional/chains/LLMChain.kt (101 changes: 64 additions & 37 deletions)

@@ -3,51 +3,78 @@ package com.xebia.functional.chains
import arrow.core.Either
import arrow.core.raise.either
import com.xebia.functional.AIError
-import com.xebia.functional.AIError.Chain.InvalidInputs
import com.xebia.functional.llm.openai.*
+import com.xebia.functional.llm.openai.LLMModel.Kind.*
import com.xebia.functional.prompt.PromptTemplate

@Suppress("LongParameterList")
suspend fun LLMChain(
-  llm: OpenAIClient,
-  promptTemplate: PromptTemplate<String>,
-  llmModel: String = "gpt-3.5-turbo",
-  user: String = "testing",
-  n: Int = 1,
-  temperature: Double = 0.0,
-  outputVariable: String,
-  chainOutput: Chain.ChainOutput = Chain.ChainOutput.OnlyOutput
+  llm: OpenAIClient,
+  promptTemplate: PromptTemplate<String>,
+  model: LLMModel,
+  user: String = "testing",
+  echo: Boolean = false,
+  n: Int = 1,
+  temperature: Double = 0.0,
+  outputVariable: String,
+  chainOutput: Chain.ChainOutput = Chain.ChainOutput.OnlyOutput
): Chain = object : Chain {

-  private val inputKeys: Set<String> = promptTemplate.inputKeys.toSet()

A contributor commented inline here:
nit: Quite some unrelated formatting changes. We need to fix spotless, I can raise a PR after this one gets merged.

-  private val outputKeys: Set<String> = setOf(outputVariable)
-
-  override val config: Chain.Config = Chain.Config(inputKeys, outputKeys, chainOutput)

-  override suspend fun call(inputs: Map<String, String>): Either<InvalidInputs, Map<String, String>> =
-    either {
-      val prompt = promptTemplate.format(inputs)
-
-      val request = ChatCompletionRequest(
-        model = llmModel,
-        user = user,
-        messages = listOf(
-          Message(
-            role = Role.system.name,
-            content = prompt
-          )
-        ),
-        n = n,
-        temperature = temperature,
-        maxTokens = 256
-      )
-
-      val completions = llm.createChatCompletion(request)
-      formatOutput(completions.choices)
+  private val inputKeys: Set<String> = promptTemplate.inputKeys.toSet()
+  private val outputKeys: Set<String> = setOf(outputVariable)
+
+  override val config: Chain.Config = Chain.Config(inputKeys, outputKeys, chainOutput)
+
+  override suspend fun call(inputs: Map<String, String>): Either<AIError.Chain.InvalidInputs, Map<String, String>> =
+    either {
+      val prompt = promptTemplate.format(inputs)
+      when (model.kind) {
+        Completion -> callCompletionEndpoint(prompt)
+        Chat -> callChatEndpoint(prompt)
+      }
+    }
+
+  private suspend fun callCompletionEndpoint(prompt: String): Map<String, String> {
+    val request = CompletionRequest(
+      model = model.name,
+      user = user,
+      prompt = prompt,
+      echo = echo,
+      n = n,
+      temperature = temperature,
+      maxTokens = 256
+    )
+
+    val completions = llm.createCompletion(request)
+    return formatCompletionOutput(completions)
+  }

-  private fun formatOutput(completions: List<Choice>): Map<String, String> =
-    config.outputKeys.associateWith {
-      completions.joinToString(", ") { it.message.content }
+  private suspend fun callChatEndpoint(prompt: String): Map<String, String> {
+    val request = ChatCompletionRequest(
+      model = model.name,
+      user = user,
+      messages = listOf(
+        Message(
+          Role.system.name,
+          prompt
+        )
+      ),
+      n = n,
+      temperature = temperature,
+      maxTokens = 256
+    )
+
+    val completions = llm.createChatCompletion(request)
+    return formatChatOutput(completions.choices)
+  }
+
+  private fun formatChatOutput(completions: List<Choice>): Map<String, String> =
+    config.outputKeys.associateWith {
+      completions.joinToString(", ") { it.message.content }
+    }
+
+  private fun formatCompletionOutput(completions: List<CompletionChoice>): Map<String, String> =
+    config.outputKeys.associateWith {
+      completions.joinToString(", ") { it.text }
+    }
}
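
For illustration only, not part of the diff: with the refactor above, the same LLMChain entry point serves both endpoint families, and model.kind decides internally between createCompletion and createChatCompletion. A sketch, assuming an OpenAIClient and a PromptTemplate<String> are already in scope:

import com.xebia.functional.llm.openai.LLMModel
import com.xebia.functional.llm.openai.OpenAIClient
import com.xebia.functional.prompt.PromptTemplate

// Identical call sites; only the model argument differs.
suspend fun buildChains(llm: OpenAIClient, promptTemplate: PromptTemplate<String>) {
  // Kind.Completion -> callCompletionEndpoint -> llm.createCompletion
  val completionChain = LLMChain(llm, promptTemplate, LLMModel.TEXT_DAVINCI_003, outputVariable = "answer")
  // Kind.Chat -> callChatEndpoint -> llm.createChatCompletion
  val chatChain = LLMChain(llm, promptTemplate, LLMModel.GPT_4, outputVariable = "answer")
}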

@@ -6,6 +6,7 @@ import arrow.core.raise.either
import arrow.core.raise.recover
import com.xebia.functional.AIError
import com.xebia.functional.Document
+import com.xebia.functional.llm.openai.LLMModel
import com.xebia.functional.llm.openai.OpenAIClient
import com.xebia.functional.prompt.PromptTemplate
import com.xebia.functional.vectorstores.VectorStore

@@ -18,6 +19,7 @@ interface VectorQAChain : Chain {
@Suppress("LongParameterList")
suspend fun VectorQAChain(
llm: OpenAIClient,
+llmModel: LLMModel,
vectorStore: VectorStore,
numOfDocs: Int,
outputVariable: String,

@@ -42,6 +44,7 @@ suspend fun VectorQAChain(
val chain = CombineDocsChain(
llm,
promptTemplate,
+llmModel,
documents,
documentVariableName,
outputVariable,

@@ -119,3 +119,24 @@ data class Usage(
@SerialName("completion_tokens") val completionTokens: Long? = null,
@SerialName("total_tokens") val totalTokens: Long
)

+data class LLMModel(
+  val name: String,
+  val kind: Kind
+) {
+  enum class Kind {
+    Completion, Chat
+  }
+  companion object {
+    val GPT_4 = LLMModel("gpt-4", Kind.Chat)
+    val GPT_4_0314 = LLMModel("gpt-4-0314", Kind.Chat)
+    val GPT_4_32K = LLMModel("gpt-4-32k", Kind.Chat)
+    val GPT_3_5_TURBO = LLMModel("gpt-3.5-turbo", Kind.Chat)
+    val GPT_3_5_TURBO_0301 = LLMModel("gpt-3.5-turbo-0301", Kind.Chat)
+    val TEXT_DAVINCI_003 = LLMModel("text-davinci-003", Kind.Completion)
+    val TEXT_DAVINCI_002 = LLMModel("text-davinci-002", Kind.Completion)
+    val TEXT_CURIE_001 = LLMModel("text-curie-001", Kind.Completion)
+    val TEXT_BABBAGE_001 = LLMModel("text-babbage-001", Kind.Completion)
+    val TEXT_ADA_001 = LLMModel("text-ada-001", Kind.Completion)
+  }
+}
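
Because LLMModel is a plain data class, models outside this catalogue can be described ad hoc; the model name below is only an example:

val customModel = LLMModel("text-davinci-001", LLMModel.Kind.Completion)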

@@ -1,33 +1,51 @@
package com.xebia.functional.chains

-import com.xebia.functional.llm.openai.ChatCompletionRequest
-import com.xebia.functional.llm.openai.ChatCompletionResponse
-import com.xebia.functional.llm.openai.CompletionChoice
-import com.xebia.functional.llm.openai.CompletionRequest
-import com.xebia.functional.llm.openai.EmbeddingRequest
-import com.xebia.functional.llm.openai.EmbeddingResult
-import com.xebia.functional.llm.openai.OpenAIClient
+import com.xebia.functional.llm.openai.*

val testLLM = object : OpenAIClient {
  override suspend fun createCompletion(request: CompletionRequest): List<CompletionChoice> =
    when (request.prompt) {
      "Tell me a joke." ->
        listOf(CompletionChoice("I'm not good at jokes", 1, finishReason = "foo"))

      "My name is foo and I'm 28 years old" ->
        listOf(CompletionChoice("Hello there! Nice to meet you foo", 1, finishReason = "foo"))

      testTemplateFormatted -> listOf(CompletionChoice("I don't know", 1, finishReason = "foo"))
      testTemplateInputsFormatted -> listOf(CompletionChoice("Two inputs, right?", 1, finishReason = "foo"))
      testQATemplateFormatted -> listOf(CompletionChoice("I don't know", 1, finishReason = "foo"))
      else -> listOf(CompletionChoice("foo", 1, finishReason = "bar"))
    }

  override suspend fun createChatCompletion(request: ChatCompletionRequest): ChatCompletionResponse =
-    TODO()
+    when (request.messages.firstOrNull()?.content) {
+      "Tell me a joke." -> fakeChatCompletion("I'm not good at jokes")
+      "My name is foo and I'm 28 years old" -> fakeChatCompletion("Hello there! Nice to meet you foo")
+      testTemplateFormatted -> fakeChatCompletion("I don't know")
+      testTemplateInputsFormatted -> fakeChatCompletion("Two inputs, right?")
+      testQATemplateFormatted -> fakeChatCompletion("I don't know")
+      else -> fakeChatCompletion("foo")
+    }

+  private fun fakeChatCompletion(message: String): ChatCompletionResponse =
+    ChatCompletionResponse(
+      id = "foo",
+      `object` = "foo",
+      created = 1,
+      model = "foo",
+      usage = Usage(1, 1, 1),
+      choices = listOf(
+        Choice(
+          Message(
+            Role.system.name,
+            message
+          ), "foo", index = 0
+        )
+      )
+    )

  override suspend fun createEmbeddings(request: EmbeddingRequest): EmbeddingResult =
    TODO()
}

val testContext = """foo foo foo