Expression Language for LLM driven template replacements #298

Merged · 5 commits · Aug 7, 2023
Changes from all commits
33 changes: 27 additions & 6 deletions core/build.gradle.kts
@@ -67,7 +67,7 @@ kotlin {
api(libs.kotlinx.serialization.json)
api(libs.ktor.utils)
api(projects.xefTokenizer)

implementation(libs.bundles.ktor.client)
implementation(libs.klogging)
implementation(libs.uuid)
}
@@ -87,21 +87,42 @@ kotlin {
implementation(libs.logback)
implementation(libs.skrape)
implementation(libs.rss.reader)
api(libs.ktor.client.cio)
}
}

val jsMain by getting
val jsMain by getting {
dependencies {
api(libs.ktor.client.js)
}
}

val jvmTest by getting {
dependencies {
implementation(libs.kotest.junit5)
}
}

val linuxX64Main by getting
val macosX64Main by getting
val macosArm64Main by getting
val mingwX64Main by getting
val linuxX64Main by getting {
dependencies {
implementation(libs.ktor.client.cio)
}
}
val macosX64Main by getting {
dependencies {
implementation(libs.ktor.client.cio)
}
}
val macosArm64Main by getting {
dependencies {
implementation(libs.ktor.client.cio)
}
}
val mingwX64Main by getting {
dependencies {
implementation(libs.ktor.client.winhttp)
}
}
val linuxX64Test by getting
val macosX64Test by getting
val macosArm64Test by getting
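Reviewer note: each target now declares exactly one Ktor client engine (CIO for the JVM and Linux/macOS native targets, Js for JS, WinHttp for mingwX64, presumably because CIO does not target Windows). A minimal sketch of what this buys in common code, not part of this diff: with a single engine artifact per platform, the client can be constructed without naming an engine.

```kotlin
import io.ktor.client.HttpClient

// commonMain: no engine is named here; Ktor resolves the single engine
// dependency each source set above contributes (CIO, Js, or WinHttp).
val client = HttpClient()
```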
70 changes: 38 additions & 32 deletions core/src/commonMain/kotlin/com/xebia/functional/xef/llm/Chat.kt
@@ -46,7 +46,7 @@ interface Chat : LLM {
): Flow<String> = flow {
val memories: List<Memory> = memories(conversationId, context, promptConfiguration)

val promptWithContext: String =
val promptWithContext: List<Message> =
createPromptWithContextAwareOfTokens(
memories = memories,
ctxInfo = context.similaritySearch(prompt.message, promptConfiguration.docsInContext),
@@ -55,7 +55,7 @@ interface Chat : LLM {
minResponseTokens = promptConfiguration.minResponseTokens
)

val messages: List<Message> = messages(memories, promptWithContext)
val messages: List<Message> = messagesFromMemory(memories) + promptWithContext

fun checkTotalLeftChatTokens(): Int {
val maxContextLength: Int = modelType.maxContextLength
@@ -138,32 +138,22 @@ interface Chat : LLM {

@AiDsl
suspend fun promptMessages(
prompt: Prompt,
messages: List<Message>,
context: VectorStore,
conversationId: ConversationId? = null,
functions: List<CFunction> = emptyList(),
promptConfiguration: PromptConfiguration = PromptConfiguration.DEFAULTS
): List<String> {

val memories: List<Memory> = memories(conversationId, context, promptConfiguration)

val promptWithContext: String =
createPromptWithContextAwareOfTokens(
memories = memories,
ctxInfo = context.similaritySearch(prompt.message, promptConfiguration.docsInContext),
modelType = modelType,
prompt = prompt.message,
minResponseTokens = promptConfiguration.minResponseTokens
)

val messages: List<Message> = messages(memories, promptWithContext)
val allMessages = messagesFromMemory(memories) + messages

fun checkTotalLeftChatTokens(): Int {
val maxContextLength: Int = modelType.maxContextLength
val messagesTokens: Int = tokensFromMessages(messages)
val messagesTokens: Int = tokensFromMessages(allMessages)
val totalLeftTokens: Int = maxContextLength - messagesTokens
if (totalLeftTokens < 0) {
throw AIError.MessagesExceedMaxTokenLength(messages, messagesTokens, maxContextLength)
throw AIError.MessagesExceedMaxTokenLength(allMessages, messagesTokens, maxContextLength)
}
return totalLeftTokens
}
@@ -217,6 +207,29 @@ interface Chat : LLM {
}
}

@AiDsl
suspend fun promptMessages(
prompt: Prompt,
context: VectorStore,
conversationId: ConversationId? = null,
functions: List<CFunction> = emptyList(),
promptConfiguration: PromptConfiguration = PromptConfiguration.DEFAULTS
): List<String> {

val memories: List<Memory> = memories(conversationId, context, promptConfiguration)

val promptWithContext: List<Message> =
createPromptWithContextAwareOfTokens(
memories = memories,
ctxInfo = context.similaritySearch(prompt.message, promptConfiguration.docsInContext),
modelType = modelType,
prompt = prompt.message,
minResponseTokens = promptConfiguration.minResponseTokens
)

return promptMessages(promptWithContext, context, conversationId, functions, promptConfiguration)
}

private suspend fun List<ChoiceWithFunctions>.addChoiceWithFunctionsToMemory(
request: ChatCompletionRequestWithFunctions,
context: VectorStore,
@@ -274,8 +287,8 @@ interface Chat : LLM {
}
}

private fun messages(memories: List<Memory>, promptWithContext: String): List<Message> =
memories.map { it.content } + listOf(Message(Role.USER, promptWithContext, Role.USER.name))
private fun messagesFromMemory(memories: List<Memory>): List<Message> =
memories.map { it.content }

private suspend fun memories(
conversationId: ConversationId?,
@@ -288,13 +301,13 @@ interface Chat : LLM {
emptyList()
}

private fun createPromptWithContextAwareOfTokens(
private suspend fun createPromptWithContextAwareOfTokens(
memories: List<Memory>,
ctxInfo: List<String>,
modelType: ModelType,
prompt: String,
minResponseTokens: Int,
): String {
): List<Message> {
val maxContextLength: Int = modelType.maxContextLength
val promptTokens: Int = modelType.encoding.countTokens(prompt)
val memoryTokens = tokensFromMessages(memories.map { it.content })
@@ -311,17 +324,10 @@ interface Chat : LLM {
// alternatively we could summarize the context, but that's not implemented yet
val ctxTruncated: String = modelType.encoding.truncateText(ctx, remainingTokens)

"""|```Context
|${ctxTruncated}
|```
|The context is related to the question try to answer the `goal` as best as you can
|or provide information about the found content
|```goal
|${prompt}
|```
|ANSWER:
|"""
.trimMargin()
} else prompt
listOf(
Message.assistantMessage { "Context: $ctxTruncated" },
Message.userMessage { prompt }
)
} else listOf(Message.userMessage { prompt })
}
}
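Reviewer note: the net effect in `Chat` is that `promptMessages` is split into a low-level overload taking an explicit `List<Message>` and a `Prompt` convenience that builds the context-aware messages and delegates to it, with retrieved context now carried as a separate assistant message (`"Context: ..."`) instead of being inlined into a single user string. A hypothetical caller-side sketch of the new overload (the function name `ask` is illustrative; `scope.context` is the conversation's `VectorStore`):

```kotlin
import com.xebia.functional.xef.auto.CoreAIScope
import com.xebia.functional.xef.llm.Chat
import com.xebia.functional.xef.llm.models.chat.Message

suspend fun ask(chat: Chat, scope: CoreAIScope): List<String> =
  chat.promptMessages(
    messages = listOf(
      Message.systemMessage { "Answer in one sentence." },
      Message.userMessage { "What does this PR add?" }
    ),
    context = scope.context // memories for a conversationId would be prepended
  )
```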
core/src/commonMain/kotlin/com/xebia/functional/xef/llm/ChatWithFunctions.kt
@@ -7,6 +7,7 @@ import com.xebia.functional.xef.auto.AiDsl
import com.xebia.functional.xef.auto.PromptConfiguration
import com.xebia.functional.xef.llm.models.chat.ChatCompletionRequestWithFunctions
import com.xebia.functional.xef.llm.models.chat.ChatCompletionResponseWithFunctions
import com.xebia.functional.xef.llm.models.chat.Message
import com.xebia.functional.xef.llm.models.functions.CFunction
import com.xebia.functional.xef.llm.models.functions.encodeJsonSchema
import com.xebia.functional.xef.prompt.Prompt
@@ -45,6 +46,29 @@ interface ChatWithFunctions : Chat {
promptConfiguration: PromptConfiguration = PromptConfiguration.DEFAULTS,
): A = prompt(prompt, context, conversationId, functions, serializer, promptConfiguration)

@AiDsl
suspend fun <A> prompt(
messages: List<Message>,
context: VectorStore,
serializer: KSerializer<A>,
conversationId: ConversationId? = null,
functions: List<CFunction> = generateCFunction(serializer.descriptor),
promptConfiguration: PromptConfiguration = PromptConfiguration.DEFAULTS,
): A {
return tryDeserialize(
{ json -> Json.decodeFromString(serializer, json) },
promptConfiguration.maxDeserializationAttempts
) {
promptMessages(
messages = messages,
context = context,
conversationId = conversationId,
functions = functions,
promptConfiguration
)
}
}

@AiDsl
suspend fun <A> prompt(
prompt: Prompt,
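Reviewer note: the new `prompt` overload mirrors the change in `Chat`, running function-call deserialization over an explicit message list; `Expression.run` below is its first consumer. A hypothetical sketch (the names `Sentiment` and `classify` are illustrative, not part of this PR):

```kotlin
import com.xebia.functional.xef.auto.CoreAIScope
import com.xebia.functional.xef.llm.ChatWithFunctions
import com.xebia.functional.xef.llm.models.chat.Message
import kotlinx.serialization.Serializable

@Serializable
data class Sentiment(val label: String, val confidence: Double)

suspend fun classify(model: ChatWithFunctions, scope: CoreAIScope): Sentiment =
  model.prompt(
    messages = listOf(
      Message.systemMessage { "Classify the sentiment of the user message." },
      Message.userMessage { "I love this expression language!" }
    ),
    context = scope.context,
    serializer = Sentiment.serializer(), // schema for the function call
    conversationId = scope.conversationId
  )
```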
core/src/commonMain/kotlin/com/xebia/functional/xef/llm/models/chat/Message.kt
@@ -1,3 +1,14 @@
package com.xebia.functional.xef.llm.models.chat

data class Message(val role: Role, val content: String, val name: String)
data class Message(val role: Role, val content: String, val name: String) {
companion object {
suspend fun systemMessage(message: suspend () -> String) =
Message(role = Role.SYSTEM, content = message(), name = Role.SYSTEM.name)

suspend fun userMessage(message: suspend () -> String) =
Message(role = Role.USER, content = message(), name = Role.USER.name)

suspend fun assistantMessage(message: suspend () -> String) =
Message(role = Role.ASSISTANT, content = message(), name = Role.ASSISTANT.name)
}
}
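Reviewer note: the suspending factories keep the `name` field in sync with the role and let the content lambda itself suspend (e.g. to load a document first). A minimal sketch:

```kotlin
import com.xebia.functional.xef.llm.models.chat.Message

// Equivalent to Message(Role.USER, "Hello", Role.USER.name), but the
// content lambda may suspend before producing the string.
suspend fun greeting(): Message = Message.userMessage { "Hello" }
```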
core/src/commonMain/kotlin/com/xebia/functional/xef/prompt/expressions/Expression.kt
@@ -0,0 +1,100 @@
package com.xebia.functional.xef.prompt.expressions

import com.xebia.functional.xef.auto.CoreAIScope
import com.xebia.functional.xef.auto.PromptConfiguration
import com.xebia.functional.xef.llm.ChatWithFunctions
import com.xebia.functional.xef.llm.models.chat.Message
import com.xebia.functional.xef.llm.models.chat.Role
import com.xebia.functional.xef.prompt.experts.ExpertSystem
import io.github.oshai.kotlinlogging.KLogger
import io.github.oshai.kotlinlogging.KotlinLogging

class Expression(
private val scope: CoreAIScope,
private val model: ChatWithFunctions,
val block: suspend Expression.() -> Unit
) {

private val logger: KLogger = KotlinLogging.logger {}

private val messages: MutableList<Message> = mutableListOf()

private val generationKeys: MutableList<String> = mutableListOf()

suspend fun system(message: suspend () -> String) {
messages.add(Message.systemMessage(message))
}

suspend fun user(message: suspend () -> String) {
messages.add(Message.userMessage(message))
}

suspend fun assistant(message: suspend () -> String) {
messages.add(Message.assistantMessage(message))
}

fun prompt(key: String): String {
generationKeys.add(key)
return "{{$key}}"
}

suspend fun run(
promptConfiguration: PromptConfiguration = PromptConfiguration.DEFAULTS
): ExpressionResult {
block()
val instructionMessage =
Message(
role = Role.USER,
content =
ExpertSystem(
system = "You are an expert in replacing variables in templates",
query =
"""
|I want to replace the following variables in the following template:
|<template>
|${messages.joinToString("\n") { it.content }}
|</template>
|The variables are:
|${generationKeys.joinToString("\n") { it }}
"""
.trimMargin(),
instructions =
listOf(
"Create a `ReplacedValues` object with the `replacements` where the keys are the variable names and the values are the values to replace them with.",
)
)
.message,
name = Role.USER.name
)
val values: ReplacedValues =
model.prompt(
messages = messages + instructionMessage,
context = scope.context,
serializer = ReplacedValues.serializer(),
conversationId = scope.conversationId,
promptConfiguration = promptConfiguration
)
logger.info { "replaced: ${values.replacements.joinToString { it.key }}" }
val replacedTemplate =
messages.fold("") { acc, message ->
val replacedMessage =
generationKeys.fold(message.content) { acc, key ->
acc.replace(
"{{$key}}",
values.replacements.firstOrNull { it.key == key }?.value ?: "{{$key}}"
)
}
acc + replacedMessage + "\n"
}
return ExpressionResult(messages = messages, result = replacedTemplate, values = values)
}

companion object {
suspend fun run(
scope: CoreAIScope,
model: ChatWithFunctions,
block: suspend Expression.() -> Unit
): ExpressionResult = Expression(scope, model, block).run()

}
}
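Reviewer note: end to end, an `Expression` records role-tagged template messages, `prompt(key)` plants a `{{key}}` placeholder while remembering the key, and `run` asks the model (via the `ExpertSystem` instruction above) for a `ReplacedValues` function call, then splices the generated values back into the template. A hypothetical usage sketch (the `poem` function and its strings are illustrative; a `CoreAIScope` and a `ChatWithFunctions` model are assumed to be available):

```kotlin
import com.xebia.functional.xef.auto.CoreAIScope
import com.xebia.functional.xef.llm.ChatWithFunctions
import com.xebia.functional.xef.prompt.expressions.Expression
import com.xebia.functional.xef.prompt.expressions.ExpressionResult

suspend fun poem(scope: CoreAIScope, model: ChatWithFunctions): ExpressionResult =
  Expression.run(scope, model) {
    system { "You are a poet." }
    // prompt("topic") records the key and renders the {{topic}} placeholder
    user { "Write a two-line poem about ${prompt("topic")}." }
  }
// result.result is the template with {{topic}} replaced by the generated
// value; result.values carries the raw key/value Replacements.
```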
core/src/commonMain/kotlin/com/xebia/functional/xef/prompt/expressions/ExpressionResult.kt
@@ -0,0 +1,9 @@
package com.xebia.functional.xef.prompt.expressions

import com.xebia.functional.xef.llm.models.chat.Message

data class ExpressionResult(
val messages: List<Message>,
val result: String,
val values: ReplacedValues,
)
core/src/commonMain/kotlin/com/xebia/functional/xef/prompt/expressions/ReplacedValues.kt
@@ -0,0 +1,10 @@
package com.xebia.functional.xef.prompt.expressions

import com.xebia.functional.xef.auto.Description
import kotlinx.serialization.Serializable

@Serializable
data class ReplacedValues(
@Description(["The values that are generated for the template"])
val replacements: List<Replacement>
)
core/src/commonMain/kotlin/com/xebia/functional/xef/prompt/expressions/Replacement.kt
@@ -0,0 +1,12 @@
package com.xebia.functional.xef.prompt.expressions

import com.xebia.functional.xef.auto.Description
import kotlinx.serialization.Serializable

@Serializable
data class Replacement(
@Description(["The key originally in {{key}} format that was going to get replaced"])
val key: String,
@Description(["The Assistant generated value that the `key` should be replaced with"])
val value: String
)
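Reviewer note: together, `ReplacedValues` and `Replacement` double as the function-call schema (via `generateCFunction(serializer.descriptor)`), so the `@Description` annotations end up in the JSON schema the model sees. A sketch of the payload `Expression.run` expects back (the values are hypothetical):

```kotlin
import com.xebia.functional.xef.prompt.expressions.ReplacedValues
import com.xebia.functional.xef.prompt.expressions.Replacement

val expected = ReplacedValues(
  replacements = listOf(Replacement(key = "topic", value = "autumn rain"))
)
// as the model's function-call JSON:
// {"replacements":[{"key":"topic","value":"autumn rain"}]}
```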