LLM SQL interface + Fixes to TOKEN limits in prompts. #81

Merged · 8 commits · May 22, 2023
@@ -10,6 +10,15 @@ sealed interface AIError {
get() = "No response from the AI"
}

data class PromptExceedsMaxTokenLength(
val prompt: String,
val promptTokens: Int,
val maxTokens: Int
) : AIError {
override val reason: String =
"Prompt exceeds max token length: $promptTokens + $maxTokens = ${promptTokens + maxTokens}"
}

data class JsonParsing(
val result: String,
val maxAttempts: Int,
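For reference, a minimal sketch of how downstream code could surface the new error variant; the report helper and its message are illustrative assumptions, not part of this change:

import com.xebia.functional.xef.AIError

// Illustrative only: turn an AIError into a user-facing message,
// treating the new token-limit case specially.
fun report(error: AIError): String =
  when (error) {
    is AIError.PromptExceedsMaxTokenLength ->
      "Prompt needs ${error.promptTokens} tokens but the model allows at most ${error.maxTokens}"
    else -> error.reason
  }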
@@ -97,7 +97,8 @@ suspend fun <A> AIScope.prompt(
echo: Boolean = false,
n: Int = 1,
temperature: Double = 0.0,
bringFromContext: Int = 10
bringFromContext: Int = 10,
minResponseTokens: Int = 500,
): A {
val serializationConfig: SerializationConfig<A> =
SerializationConfig(
@@ -128,7 +129,8 @@
echo,
n,
temperature,
bringFromContext
bringFromContext,
minResponseTokens
)
}
}
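A usage sketch for the widened prompt signature. It assumes the reified overload that infers a kotlinx.serialization serializer for the target type; the import path for prompt, the Country class, and the literal values are illustrative:

import com.xebia.functional.xef.auto.ai
import com.xebia.functional.xef.auto.getOrThrow
import com.xebia.functional.xef.auto.prompt
import kotlinx.serialization.Serializable

@Serializable data class Country(val name: String, val capital: String)

suspend fun main() = ai {
  // Reserve at least 1000 tokens for the model's answer instead of the default 500.
  val country: Country = prompt("Describe a European country and its capital", minResponseTokens = 1000)
  println("${country.name} -> ${country.capital}")
}.getOrThrow()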
160 changes: 130 additions & 30 deletions core/src/commonMain/kotlin/com/xebia/functional/xef/auto/LLMAgent.kt
@@ -1,11 +1,15 @@
package com.xebia.functional.xef.auto

import com.xebia.functional.xef.llm.openai.ChatCompletionRequest
import com.xebia.functional.xef.llm.openai.CompletionRequest
import com.xebia.functional.xef.llm.openai.LLMModel
import com.xebia.functional.xef.llm.openai.Message
import com.xebia.functional.xef.llm.openai.Role
import arrow.core.raise.Raise
import arrow.core.raise.ensure
import com.xebia.functional.tokenizer.truncateText
import com.xebia.functional.xef.AIError
import com.xebia.functional.xef.llm.openai.*
import com.xebia.functional.xef.prompt.Prompt
import io.github.oshai.KLogger
import io.github.oshai.KotlinLogging

private val logger: KLogger by lazy { KotlinLogging.logger {} }

@AiDsl
suspend fun AIScope.promptMessage(
@@ -15,9 +19,19 @@ suspend fun AIScope.promptMessage(
echo: Boolean = false,
n: Int = 1,
temperature: Double = 0.0,
bringFromContext: Int = 10
bringFromContext: Int = 10,
minResponseTokens: Int = 500
): List<String> =
promptMessage(Prompt(question), model, user, echo, n, temperature, bringFromContext)
promptMessage(
Prompt(question),
model,
user,
echo,
n,
temperature,
bringFromContext,
minResponseTokens
)

@AiDsl
suspend fun AIScope.promptMessage(
@@ -27,46 +41,92 @@ suspend fun AIScope.promptMessage(
echo: Boolean = false,
n: Int = 1,
temperature: Double = 0.0,
bringFromContext: Int = 10
bringFromContext: Int = 10,
minResponseTokens: Int
): List<String> {
val ctxInfo = context.similaritySearch(prompt.message, bringFromContext)
val promptWithContext =
if (ctxInfo.isNotEmpty()) {
"""|Instructions: Use the [Information] below delimited by 3 backticks to accomplish
|the [Objective] at the end of the prompt.
|Try to match the data returned in the [Objective] with this [Information] as best as you can.
|[Information]:
|```
|${ctxInfo.joinToString("\n")}
|```
|$prompt"""
.trimMargin()
} else prompt.message

return when (model.kind) {
LLMModel.Kind.Completion ->
callCompletionEndpoint(promptWithContext, model, user, echo, n, temperature)
LLMModel.Kind.Chat -> callChatEndpoint(promptWithContext, model, user, n, temperature)
callCompletionEndpoint(
prompt.message,
model,
user,
echo,
n,
temperature,
bringFromContext,
minResponseTokens
)
LLMModel.Kind.Chat ->
callChatEndpoint(
prompt.message,
model,
user,
n,
temperature,
bringFromContext,
minResponseTokens
)
}
}

private fun Raise<AIError>.createPromptWithContextAwareOfTokens(
ctxInfo: List<String>,
model: LLMModel,
prompt: String,
minResponseTokens: Int,
): String {
val remainingTokens =
model.modelType.maxContextLength -
model.modelType.encoding.countTokens(prompt) -
minResponseTokens
return if (ctxInfo.isNotEmpty() && remainingTokens > minResponseTokens) {
val ctx = ctxInfo.joinToString("\n")
val promptTokens = model.modelType.encoding.countTokens(prompt)
ensure(promptTokens < model.modelType.maxContextLength) {
raise(
AIError.PromptExceedsMaxTokenLength(prompt, promptTokens, model.modelType.maxContextLength)
)
}
// truncate the context if it's too long based on the max tokens calculated considering the
// existing prompt tokens
// alternatively we could summarize the context, but that's not implemented yet
val maxTokens = model.modelType.maxContextLength - promptTokens - minResponseTokens
val ctxTruncated = model.modelType.encoding.truncateText(ctx, maxTokens)
"""|```Context
|${ctxTruncated}
|```
|The context is related to the question try to answer the `goal` as best as you can
|or provide information about the found content
|```goal
|${prompt}
|```
|ANSWER:
|"""
.trimMargin()
} else prompt
}

private suspend fun AIScope.callCompletionEndpoint(
prompt: String,
model: LLMModel,
user: String = "testing",
echo: Boolean = false,
n: Int = 1,
temperature: Double = 0.0
temperature: Double = 0.0,
bringFromContext: Int,
minResponseTokens: Int,
): List<String> {
val (promptWithContext, maxTokens) =
promptWithContextAndRemainingTokens("", prompt, bringFromContext, model, minResponseTokens)
val request =
CompletionRequest(
model = model.name,
user = user,
prompt = prompt,
prompt = promptWithContext,
echo = echo,
n = n,
temperature = temperature,
maxTokens = 1024
maxTokens = maxTokens
)
return openAIClient.createCompletion(request).map { it.text }
}
@@ -76,16 +136,56 @@ private suspend fun AIScope.callChatEndpoint(
model: LLMModel,
user: String = "testing",
n: Int = 1,
temperature: Double = 0.0
temperature: Double = 0.0,
bringFromContext: Int,
minResponseTokens: Int
): List<String> {
val role = Role.system.name
val (promptWithContext, maxTokens) =
promptWithContextAndRemainingTokens(role, prompt, bringFromContext, model, minResponseTokens)
val request =
ChatCompletionRequest(
model = model.name,
user = user,
messages = listOf(Message(Role.system.name, prompt)),
messages = listOf(Message(role, promptWithContext)),
n = n,
temperature = temperature,
maxTokens = 1024
maxTokens = maxTokens
)
return openAIClient.createChatCompletion(request).choices.map { it.message.content }
}

private suspend fun AIScope.promptWithContextAndRemainingTokens(
role: String,
prompt: String,
bringFromContext: Int,
model: LLMModel,
minResponseTokens: Int
): Pair<String, Int> {
val ctxInfo = context.similaritySearch(prompt, bringFromContext)
val promptWithContext =
createPromptWithContextAwareOfTokens(
ctxInfo = ctxInfo,
model = model,
prompt = prompt,
minResponseTokens = minResponseTokens
)
val roleTokens = model.modelType.encoding.countTokens(role)
val padding = 20 // reserve 20 tokens for additional symbols around the context
val promptTokens = model.modelType.encoding.countTokens(promptWithContext)
val takenTokens = roleTokens + promptTokens + padding
val totalLeftTokens = model.modelType.maxContextLength - takenTokens
if (totalLeftTokens < 0) {
raise(
AIError.PromptExceedsMaxTokenLength(
promptWithContext,
takenTokens,
model.modelType.maxContextLength
)
)
}
logger.debug {
"Tokens: used: $takenTokens, model max: ${model.modelType.maxContextLength}, left: $totalLeftTokens"
}
return Pair(promptWithContext, totalLeftTokens)
}
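A worked sketch of the budget arithmetic above, using gpt-3.5-turbo's 4096-token window; the concrete token counts are made-up numbers for illustration:

fun main() {
  val maxContextLength = 4096  // gpt-3.5-turbo
  val roleTokens = 1           // "system"
  val promptTokens = 3000      // prompt plus truncated context
  val padding = 20             // same reserve as promptWithContextAndRemainingTokens
  val takenTokens = roleTokens + promptTokens + padding   // 3021
  val totalLeftTokens = maxContextLength - takenTokens    // 1075 tokens left for the completion
  println("used $takenTokens, left $totalLeftTokens")
}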
@@ -1,5 +1,6 @@
package com.xebia.functional.xef.llm.openai

import com.xebia.functional.tokenizer.ModelType
import kotlin.jvm.JvmInline
import kotlin.jvm.JvmStatic
import kotlinx.serialization.SerialName
@@ -115,31 +116,37 @@ data class Usage(
@SerialName("total_tokens") val totalTokens: Long
)

data class LLMModel(val name: String, val kind: Kind, val contextLength: Int) {
data class LLMModel(val name: String, val kind: Kind, val modelType: ModelType) {
enum class Kind {
Completion,
Chat
}

companion object {
@JvmStatic val GPT_4 = LLMModel("gpt-4", Kind.Chat, 8192)
@JvmStatic val GPT_4 = LLMModel("gpt-4", Kind.Chat, ModelType.GPT_4)

@JvmStatic val GPT_4_0314 = LLMModel("gpt-4-0314", Kind.Chat, 8192)
@JvmStatic val GPT_4_0314 = LLMModel("gpt-4-0314", Kind.Chat, ModelType.GPT_4)

@JvmStatic val GPT_4_32K = LLMModel("gpt-4-32k", Kind.Chat, 32768)
@JvmStatic val GPT_4_32K = LLMModel("gpt-4-32k", Kind.Chat, ModelType.GPT_4_32K)

@JvmStatic val GPT_3_5_TURBO = LLMModel("gpt-3.5-turbo", Kind.Chat, 4096)
@JvmStatic val GPT_3_5_TURBO = LLMModel("gpt-3.5-turbo", Kind.Chat, ModelType.GPT_3_5_TURBO)

@JvmStatic val GPT_3_5_TURBO_0301 = LLMModel("gpt-3.5-turbo-0301", Kind.Chat, 4096)
@JvmStatic
val GPT_3_5_TURBO_0301 = LLMModel("gpt-3.5-turbo-0301", Kind.Chat, ModelType.GPT_3_5_TURBO)

@JvmStatic val TEXT_DAVINCI_003 = LLMModel("text-davinci-003", Kind.Completion, 4097)
@JvmStatic
val TEXT_DAVINCI_003 = LLMModel("text-davinci-003", Kind.Completion, ModelType.TEXT_DAVINCI_003)

@JvmStatic val TEXT_DAVINCI_002 = LLMModel("text-davinci-002", Kind.Completion, 4097)
@JvmStatic
val TEXT_DAVINCI_002 = LLMModel("text-davinci-002", Kind.Completion, ModelType.TEXT_DAVINCI_002)

@JvmStatic val TEXT_CURIE_001 = LLMModel("text-curie-001", Kind.Completion, 2049)
@JvmStatic
val TEXT_CURIE_001 =
LLMModel("text-curie-001", Kind.Completion, ModelType.TEXT_SIMILARITY_CURIE_001)

@JvmStatic val TEXT_BABBAGE_001 = LLMModel("text-babbage-001", Kind.Completion, 2049)
@JvmStatic
val TEXT_BABBAGE_001 = LLMModel("text-babbage-001", Kind.Completion, ModelType.TEXT_BABBAGE_001)

@JvmStatic val TEXT_ADA_001 = LLMModel("text-ada-001", Kind.Completion, 2049)
@JvmStatic val TEXT_ADA_001 = LLMModel("text-ada-001", Kind.Completion, ModelType.TEXT_ADA_001)
}
}
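A small sketch of what replacing the raw contextLength with ModelType buys at call sites: the encoding and the context window now travel together with the model. countTokens and maxContextLength are the same members used in the LLMAgent.kt changes above; the sample prompt is illustrative:

import com.xebia.functional.xef.llm.openai.LLMModel

fun main() {
  val model = LLMModel.GPT_3_5_TURBO
  val prompt = "Summarise last quarter's sales figures."

  // Tokenizer and context window come from the same ModelType, so they cannot drift apart.
  val promptTokens = model.modelType.encoding.countTokens(prompt)
  val remaining = model.modelType.maxContextLength - promptTokens
  println("$promptTokens tokens used, $remaining left in the ${model.name} window")
}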
5 changes: 5 additions & 0 deletions example/build.gradle.kts
@@ -1,3 +1,5 @@
import org.jetbrains.kotlin.gradle.tasks.KotlinCompile

plugins {
id(libs.plugins.kotlin.jvm.get().pluginId)
id(libs.plugins.kotlinx.serialization.get().pluginId)
@@ -19,10 +21,13 @@ dependencies {
implementation(projects.xefCore)
implementation(projects.xefFilesystem)
implementation(projects.xefPdf)
implementation(projects.xefSql)
implementation(projects.tokenizer)
implementation(libs.kotlinx.serialization.json)
implementation(libs.logback)
implementation(libs.klogging)
implementation(libs.bundles.arrow)
implementation(libs.okio)
implementation(libs.jdbc.mysql.connector)
api(libs.bundles.ktor.client)
}
@@ -0,0 +1,58 @@
package com.xebia.functional.xef.auto.sql

import arrow.core.raise.catch
import com.xebia.functional.tokenizer.ModelType
import com.xebia.functional.xef.auto.ai
import com.xebia.functional.xef.auto.getOrThrow
import com.xebia.functional.xef.auto.promptMessage
import com.xebia.functional.xef.sql.SQL
import com.xebia.functional.xef.sql.jdbc.JdbcConfig

val config = JdbcConfig(
vendor = System.getenv("XEF_SQL_DB_VENDOR") ?: "mysql",
host = System.getenv("XEF_SQL_DB_HOST") ?: "localhost",
username = System.getenv("XEF_SQL_DB_USER") ?: "user",
password = System.getenv("XEF_SQL_DB_PASSWORD") ?: "password",
port = System.getenv("XEF_SQL_DB_PORT")?.toInt() ?: 3306,
database = System.getenv("XEF_SQL_DB_DATABASE") ?: "database",
llmModelType = ModelType.GPT_3_5_TURBO
)

suspend fun main() = ai {
SQL.fromJdbcConfig(config) {
println("llmdb> Welcome to the LLMDB (An LLM interface to your SQL Database) !")
println("llmdb> You can ask me questions about the database and I will try to answer them.")
println("llmdb> You can type `exit` to exit the program.")
println("llmdb> Loading recommended prompts...")
val interestingPrompts = getInterestingPromptsForDatabase()
interestingPrompts.forEach {
println("llmdb> ${it}")
}
while (true) {
// a cli chat with the content
print("user> ")
val input = readln()
if (input == "exit") break
catch({
extendContext(*promptQuery(input).toTypedArray())
val result = promptMessage("""|
|You are a database assistant that helps users to query and summarize results from the database.
|Instructions:
|1. Summarize the information provided in the `Context` and follow to step 2.
|2. If the information relates to the `input` then answer the question otherwise return just the summary.
|```input
|$input
|```
|3. Try to answer and provide information with as much detail as you can
""".trimMargin(), bringFromContext = 200)
result.forEach {
println("llmdb> ${it}")
}
}, { exception ->
println("llmdb> ${exception.message}")
exception.printStackTrace()
})
}
}
}.getOrThrow()
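
For contrast with the interactive loop above, a non-interactive sketch that answers a single question and prints the summary. It reuses config, promptQuery, extendContext and promptMessage exactly as the example does; the oneShot name and the prompt wording are assumptions:

suspend fun oneShot(question: String) = ai {
  SQL.fromJdbcConfig(config) {
    // Pull the rows relevant to the question into the context, then summarise them.
    extendContext(*promptQuery(question).toTypedArray())
    val answers = promptMessage(
      "Summarise the information in the `Context` that answers: $question",
      bringFromContext = 200
    )
    answers.forEach { println("llmdb> $it") }
  }
}.getOrThrow()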

2 changes: 1 addition & 1 deletion example/src/main/resources/logback.xml
@@ -16,7 +16,7 @@
<appender-ref ref="NOOP"/>
</root>

<logger name="AutoAI" level="debug">
<logger name="com.xebia.functional.xef" level="debug">
<appender-ref ref="STDOUT" />
</logger>
