From c5f0d54473af96c433cceb2e463d8cb027e5fb86 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ra=C3=BAl=20Raja=20Mart=C3=ADnez?= Date: Fri, 18 Aug 2023 11:18:59 +0200 Subject: [PATCH] Move memory and prompt calculations out of the Chat interface (#332) * Simplify access to functions and prompts * Code review during pair session with Javi and Jose Carlos * Spotlless aply * spotless * Move memory and prompt calculations out of the Chat interface * Private methods --------- Co-authored-by: Javi Pacheco --- .../com/xebia/functional/xef/llm/Chat.kt | 258 +++--------------- .../functional/xef/llm/MemoryManagement.kt | 104 +++++++ .../functional/xef/llm/PromptCalculator.kt | 119 ++++++++ 3 files changed, 262 insertions(+), 219 deletions(-) create mode 100644 core/src/commonMain/kotlin/com/xebia/functional/xef/llm/MemoryManagement.kt create mode 100644 core/src/commonMain/kotlin/com/xebia/functional/xef/llm/PromptCalculator.kt diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/Chat.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/Chat.kt index 00dfabd67..dddc4d14a 100644 --- a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/Chat.kt +++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/Chat.kt @@ -1,16 +1,11 @@ package com.xebia.functional.xef.llm import com.xebia.functional.tokenizer.ModelType -import com.xebia.functional.tokenizer.truncateText import com.xebia.functional.xef.AIError import com.xebia.functional.xef.auto.AiDsl import com.xebia.functional.xef.auto.Conversation import com.xebia.functional.xef.llm.models.chat.* import com.xebia.functional.xef.prompt.Prompt -import com.xebia.functional.xef.prompt.configuration.PromptConfiguration -import com.xebia.functional.xef.prompt.templates.assistant -import com.xebia.functional.xef.vectorstores.Memory -import io.ktor.util.date.* import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.flow import kotlinx.coroutines.flow.onCompletion @@ -27,14 +22,14 @@ interface Chat : LLM { @AiDsl fun promptStreaming(prompt: Prompt, scope: Conversation): Flow = flow { - val messagesForRequest = - fitMessagesByTokens(prompt.messages, scope, modelType, prompt.configuration) + val messagesForRequestPrompt = + PromptCalculator.adaptPromptToConversationAndModel(prompt, scope, this@Chat) val request = ChatCompletionRequest( model = name, user = prompt.configuration.user, - messages = messagesForRequest, + messages = messagesForRequestPrompt.messages, n = prompt.configuration.numberOfPredictions, temperature = prompt.configuration.temperature, maxTokens = prompt.configuration.minResponseTokens, @@ -49,7 +44,7 @@ interface Chat : LLM { buffer.append(text) } } - .onCompletion { addMemoriesAfterStream(request, scope, buffer) } + .onCompletion { MemoryManagement.addMemoriesAfterStream(this@Chat, request, scope, buffer) } .collect { emit(it.choices.mapNotNull { it.delta?.content }.joinToString("")) } } @@ -60,230 +55,55 @@ interface Chat : LLM { @AiDsl suspend fun promptMessages(prompt: Prompt, scope: Conversation): List { - val messagesForRequest = - fitMessagesByTokens(prompt.messages, scope, modelType, prompt.configuration) + val adaptedPrompt = PromptCalculator.adaptPromptToConversationAndModel(prompt, scope, this@Chat) fun chatRequest(): ChatCompletionRequest = ChatCompletionRequest( model = name, - user = prompt.configuration.user, - messages = messagesForRequest, - n = prompt.configuration.numberOfPredictions, - temperature = prompt.configuration.temperature, - maxTokens = prompt.configuration.minResponseTokens, + user 
= adaptedPrompt.configuration.user, + messages = adaptedPrompt.messages, + n = adaptedPrompt.configuration.numberOfPredictions, + temperature = adaptedPrompt.configuration.temperature, + maxTokens = adaptedPrompt.configuration.minResponseTokens, ) fun withFunctionsRequest(): ChatCompletionRequestWithFunctions = ChatCompletionRequestWithFunctions( model = name, - user = prompt.configuration.user, - messages = messagesForRequest, - n = prompt.configuration.numberOfPredictions, - temperature = prompt.configuration.temperature, - maxTokens = prompt.configuration.minResponseTokens, - functions = listOfNotNull(prompt.function), - functionCall = mapOf("name" to (prompt.function?.name ?: "")) + user = adaptedPrompt.configuration.user, + messages = adaptedPrompt.messages, + n = adaptedPrompt.configuration.numberOfPredictions, + temperature = adaptedPrompt.configuration.temperature, + maxTokens = adaptedPrompt.configuration.minResponseTokens, + functions = listOfNotNull(adaptedPrompt.function), + functionCall = mapOf("name" to (adaptedPrompt.function?.name ?: "")) ) - return when (this) { - is ChatWithFunctions -> - // we only support functions for now with GPT_3_5_TURBO_FUNCTIONS - if (modelType == ModelType.GPT_3_5_TURBO_FUNCTIONS) { - val request = withFunctionsRequest() - createChatCompletionWithFunctions(request) - .choices - .addChoiceWithFunctionsToMemory(request, scope) - .mapNotNull { it.message?.functionCall?.arguments } - } else { - val request = chatRequest() - createChatCompletion(request).choices.addChoiceToMemory(request, scope).mapNotNull { - it.message?.content + return MemoryManagement.run { + when (this@Chat) { + is ChatWithFunctions -> + // we only support functions for now with GPT_3_5_TURBO_FUNCTIONS + if (modelType == ModelType.GPT_3_5_TURBO_FUNCTIONS) { + val request = withFunctionsRequest() + createChatCompletionWithFunctions(request) + .choices + .addChoiceWithFunctionsToMemory(this@Chat, request, scope) + .mapNotNull { it.message?.functionCall?.arguments } + } else { + val request = chatRequest() + createChatCompletion(request) + .choices + .addChoiceToMemory(this@Chat, request, scope) + .mapNotNull { it.message?.content } } - } - else -> { - val request = chatRequest() - createChatCompletion(request).choices.addChoiceToMemory(request, scope).mapNotNull { - it.message?.content + else -> { + val request = chatRequest() + createChatCompletion(request) + .choices + .addChoiceToMemory(this@Chat, request, scope) + .mapNotNull { it.message?.content } } } } } - - private suspend fun addMemoriesAfterStream( - request: ChatCompletionRequest, - scope: Conversation, - buffer: StringBuilder, - ) { - val lastRequestMessage = request.messages.lastOrNull() - val cid = scope.conversationId - if (cid != null && lastRequestMessage != null) { - val requestMemory = - Memory( - conversationId = cid, - content = lastRequestMessage, - timestamp = getTimeMillis(), - approxTokens = tokensFromMessages(listOf(lastRequestMessage)) - ) - val responseMessage = - Message(role = Role.ASSISTANT, content = buffer.toString(), name = Role.ASSISTANT.name) - val responseMemory = - Memory( - conversationId = cid, - content = responseMessage, - timestamp = getTimeMillis(), - approxTokens = tokensFromMessages(listOf(responseMessage)) - ) - scope.store.addMemories(listOf(requestMemory, responseMemory)) - } - } - - private suspend fun List.addChoiceWithFunctionsToMemory( - request: ChatCompletionRequestWithFunctions, - scope: Conversation - ): List = also { - val firstChoice = firstOrNull() - val requestUserMessage 
= request.messages.lastOrNull() - val cid = scope.conversationId - if (requestUserMessage != null && firstChoice != null && cid != null) { - val role = firstChoice.message?.role?.uppercase()?.let { Role.valueOf(it) } ?: Role.USER - - val requestMemory = - Memory( - conversationId = cid, - content = requestUserMessage, - timestamp = getTimeMillis(), - approxTokens = tokensFromMessages(listOf(requestUserMessage)) - ) - val firstChoiceMessage = - Message( - role = role, - content = firstChoice.message?.content - ?: firstChoice.message?.functionCall?.arguments ?: "", - name = role.name - ) - val firstChoiceMemory = - Memory( - conversationId = cid, - content = firstChoiceMessage, - timestamp = getTimeMillis(), - approxTokens = tokensFromMessages(listOf(firstChoiceMessage)) - ) - scope.store.addMemories(listOf(requestMemory, firstChoiceMemory)) - } - } - - private suspend fun List.addChoiceToMemory( - request: ChatCompletionRequest, - scope: Conversation - ): List = also { - val firstChoice = firstOrNull() - val requestUserMessage = request.messages.lastOrNull() - val cid = scope.conversationId - if (requestUserMessage != null && firstChoice != null && cid != null) { - val role = firstChoice.message?.role?.name?.uppercase()?.let { Role.valueOf(it) } ?: Role.USER - val requestMemory = - Memory( - conversationId = cid, - content = requestUserMessage, - timestamp = getTimeMillis(), - approxTokens = tokensFromMessages(listOf(requestUserMessage)) - ) - val firstChoiceMessage = - Message(role = role, content = firstChoice.message?.content ?: "", name = role.name) - val firstChoiceMemory = - Memory( - conversationId = cid, - content = firstChoiceMessage, - timestamp = getTimeMillis(), - approxTokens = tokensFromMessages(listOf(firstChoiceMessage)) - ) - scope.store.addMemories(listOf(requestMemory, firstChoiceMemory)) - } - } - - private fun messagesFromMemory(memories: List): List = - memories.map { it.content } - - private suspend fun Conversation.memories(limitTokens: Int): List { - val cid = conversationId - return if (cid != null) { - store.memories(cid, limitTokens) - } else { - emptyList() - } - } - - private suspend fun fitMessagesByTokens( - messages: List, - scope: Conversation, - modelType: ModelType, - promptConfiguration: PromptConfiguration, - ): List { - - // calculate tokens for history and context - val maxContextLength: Int = modelType.maxContextLength - val remainingTokens: Int = maxContextLength - promptConfiguration.minResponseTokens - - val messagesTokens = tokensFromMessages(messages) - - if (messagesTokens >= remainingTokens) { - throw AIError.PromptExceedsMaxRemainingTokenLength(messagesTokens, remainingTokens) - } - - val remainingTokensForContexts = remainingTokens - messagesTokens - - val historyPercent = promptConfiguration.messagePolicy.historyPercent - val contextPercent = promptConfiguration.messagePolicy.contextPercent - - val maxHistoryTokens = (remainingTokensForContexts * historyPercent) / 100 - val maxContextTokens = (remainingTokensForContexts * contextPercent) / 100 - - // calculate messages for history based on tokens - - val memories: List = - scope.memories(maxHistoryTokens + promptConfiguration.messagePolicy.historyPaddingTokens) - - val historyAllowed = - if (memories.isNotEmpty()) { - val history = messagesFromMemory(memories) - - // since we have the approximate tokens in memory, we need to fit the messages back to the - // number of tokens if necessary - val historyTokens = tokensFromMessages(history) - if (historyTokens <= maxHistoryTokens) history - 
else { - val historyMessagesWithTokens = history.map { Pair(it, tokensFromMessages(listOf(it))) } - - val totalTokenWithMessages = - historyMessagesWithTokens.foldRight(Pair(0, emptyList())) { pair, acc -> - if (acc.first + pair.second > maxHistoryTokens) { - acc - } else { - Pair(acc.first + pair.second, acc.second + pair.first) - } - } - totalTokenWithMessages.second.reversed() - } - } else emptyList() - - // calculate messages for context based on tokens - val ctxInfo = - scope.store.similaritySearch( - messages.joinToString("\n") { it.content }, - promptConfiguration.docsInContext, - ) - - val contextAllowed = - if (ctxInfo.isNotEmpty()) { - val ctx: String = ctxInfo.joinToString("\n") - - val ctxTruncated: String = modelType.encoding.truncateText(ctx, maxContextTokens) - - Prompt { +assistant(ctxTruncated) }.messages - } else { - emptyList() - } - - return contextAllowed + historyAllowed + messages - } } diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/MemoryManagement.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/MemoryManagement.kt new file mode 100644 index 000000000..a62cf34a7 --- /dev/null +++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/MemoryManagement.kt @@ -0,0 +1,104 @@ +package com.xebia.functional.xef.llm + +import com.xebia.functional.xef.auto.Conversation +import com.xebia.functional.xef.llm.models.chat.* +import com.xebia.functional.xef.vectorstores.Memory +import io.ktor.util.date.* + +internal object MemoryManagement { + + internal suspend fun addMemoriesAfterStream( + chat: Chat, + request: ChatCompletionRequest, + scope: Conversation, + buffer: StringBuilder, + ) { + val lastRequestMessage = request.messages.lastOrNull() + val cid = scope.conversationId + if (cid != null && lastRequestMessage != null) { + val requestMemory = + Memory( + conversationId = cid, + content = lastRequestMessage, + timestamp = getTimeMillis(), + approxTokens = chat.tokensFromMessages(listOf(lastRequestMessage)) + ) + val responseMessage = + Message(role = Role.ASSISTANT, content = buffer.toString(), name = Role.ASSISTANT.name) + val responseMemory = + Memory( + conversationId = cid, + content = responseMessage, + timestamp = getTimeMillis(), + approxTokens = chat.tokensFromMessages(listOf(responseMessage)) + ) + scope.store.addMemories(listOf(requestMemory, responseMemory)) + } + } + + internal suspend fun List.addChoiceWithFunctionsToMemory( + chat: Chat, + request: ChatCompletionRequestWithFunctions, + scope: Conversation + ): List = also { + val firstChoice = firstOrNull() + val requestUserMessage = request.messages.lastOrNull() + val cid = scope.conversationId + if (requestUserMessage != null && firstChoice != null && cid != null) { + val role = firstChoice.message?.role?.uppercase()?.let { Role.valueOf(it) } ?: Role.USER + + val requestMemory = + Memory( + conversationId = cid, + content = requestUserMessage, + timestamp = getTimeMillis(), + approxTokens = chat.tokensFromMessages(listOf(requestUserMessage)) + ) + val firstChoiceMessage = + Message( + role = role, + content = firstChoice.message?.content + ?: firstChoice.message?.functionCall?.arguments ?: "", + name = role.name + ) + val firstChoiceMemory = + Memory( + conversationId = cid, + content = firstChoiceMessage, + timestamp = getTimeMillis(), + approxTokens = chat.tokensFromMessages(listOf(firstChoiceMessage)) + ) + scope.store.addMemories(listOf(requestMemory, firstChoiceMemory)) + } + } + + internal suspend fun List.addChoiceToMemory( + chat: Chat, + request: 
ChatCompletionRequest, + scope: Conversation + ): List = also { + val firstChoice = firstOrNull() + val requestUserMessage = request.messages.lastOrNull() + val cid = scope.conversationId + if (requestUserMessage != null && firstChoice != null && cid != null) { + val role = firstChoice.message?.role?.name?.uppercase()?.let { Role.valueOf(it) } ?: Role.USER + val requestMemory = + Memory( + conversationId = cid, + content = requestUserMessage, + timestamp = getTimeMillis(), + approxTokens = chat.tokensFromMessages(listOf(requestUserMessage)) + ) + val firstChoiceMessage = + Message(role = role, content = firstChoice.message?.content ?: "", name = role.name) + val firstChoiceMemory = + Memory( + conversationId = cid, + content = firstChoiceMessage, + timestamp = getTimeMillis(), + approxTokens = chat.tokensFromMessages(listOf(firstChoiceMessage)) + ) + scope.store.addMemories(listOf(requestMemory, firstChoiceMemory)) + } + } +} diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/PromptCalculator.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/PromptCalculator.kt new file mode 100644 index 000000000..3f5f9ae2e --- /dev/null +++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/PromptCalculator.kt @@ -0,0 +1,119 @@ +package com.xebia.functional.xef.llm + +import com.xebia.functional.tokenizer.truncateText +import com.xebia.functional.xef.AIError +import com.xebia.functional.xef.auto.Conversation +import com.xebia.functional.xef.llm.models.chat.Message +import com.xebia.functional.xef.prompt.Prompt +import com.xebia.functional.xef.prompt.templates.assistant +import com.xebia.functional.xef.vectorstores.Memory + +internal object PromptCalculator { + + suspend fun adaptPromptToConversationAndModel( + prompt: Prompt, + scope: Conversation, + chat: Chat + ): Prompt { + + // calculate tokens for history and context + val remainingTokensForContexts = calculateRemainingTokensForContext(chat, prompt) + + val maxHistoryTokens = calculateMaxHistoryTokens(prompt, remainingTokensForContexts) + + val maxContextTokens = calculateMaxContextTokens(prompt, remainingTokensForContexts) + + // calculate messages for history based on tokens + + val memories: List = + scope.memories(maxHistoryTokens + prompt.configuration.messagePolicy.historyPaddingTokens) + + val historyAllowed = calculateMessagesFromHistory(chat, memories, maxHistoryTokens) + + // calculate messages for context based on tokens + val ctxInfo = + scope.store.similaritySearch( + prompt.messages.joinToString("\n") { it.content }, + prompt.configuration.docsInContext, + ) + + val contextAllowed = + if (ctxInfo.isNotEmpty()) { + val ctx: String = ctxInfo.joinToString("\n") + + val ctxTruncated: String = chat.modelType.encoding.truncateText(ctx, maxContextTokens) + + Prompt { +assistant(ctxTruncated) }.messages + } else { + emptyList() + } + + return prompt.copy(messages = contextAllowed + historyAllowed + prompt.messages) + } + + private fun messagesFromMemory(memories: List): List = + memories.map { it.content } + + private fun calculateMessagesFromHistory( + chat: Chat, + memories: List, + maxHistoryTokens: Int + ) = + if (memories.isNotEmpty()) { + val history = messagesFromMemory(memories) + + // since we have the approximate tokens in memory, we need to fit the messages back to the + // number of tokens if necessary + val historyTokens = chat.tokensFromMessages(history) + if (historyTokens <= maxHistoryTokens) history + else { + val historyMessagesWithTokens = + history.map { Pair(it, 
chat.tokensFromMessages(listOf(it))) } + + val totalTokenWithMessages = + historyMessagesWithTokens.foldRight(Pair(0, emptyList())) { pair, acc -> + if (acc.first + pair.second > maxHistoryTokens) { + acc + } else { + Pair(acc.first + pair.second, acc.second + pair.first) + } + } + totalTokenWithMessages.second.reversed() + } + } else emptyList() + + private fun calculateMaxContextTokens(prompt: Prompt, remainingTokensForContexts: Int): Int { + val contextPercent = prompt.configuration.messagePolicy.contextPercent + val maxContextTokens = (remainingTokensForContexts * contextPercent) / 100 + return maxContextTokens + } + + private fun calculateMaxHistoryTokens(prompt: Prompt, remainingTokensForContexts: Int): Int { + val historyPercent = prompt.configuration.messagePolicy.historyPercent + val maxHistoryTokens = (remainingTokensForContexts * historyPercent) / 100 + return maxHistoryTokens + } + + private fun calculateRemainingTokensForContext(chat: Chat, prompt: Prompt): Int { + val maxContextLength: Int = chat.modelType.maxContextLength + val remainingTokens: Int = maxContextLength - prompt.configuration.minResponseTokens + + val messagesTokens = chat.tokensFromMessages(prompt.messages) + + if (messagesTokens >= remainingTokens) { + throw AIError.PromptExceedsMaxRemainingTokenLength(messagesTokens, remainingTokens) + } + + val remainingTokensForContexts = remainingTokens - messagesTokens + return remainingTokensForContexts + } + + private suspend fun Conversation.memories(limitTokens: Int): List { + val cid = conversationId + return if (cid != null) { + store.memories(cid, limitTokens) + } else { + emptyList() + } + } +}
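
Note on the budget split performed by the new PromptCalculator: adaptPromptToConversationAndModel first reserves prompt.configuration.minResponseTokens out of the model's maxContextLength, fails with AIError.PromptExceedsMaxRemainingTokenLength if the prompt's own messages already exceed what is left, and then divides the remainder between conversation history and similarity-search context according to messagePolicy.historyPercent and messagePolicy.contextPercent. The following standalone Kotlin sketch walks through that arithmetic with hard-coded example numbers; the values are illustrative assumptions, not figures taken from the patch or from any real ModelType/PromptConfiguration.

// Standalone sketch of the token-budget split in PromptCalculator.
// All numbers below are assumptions chosen only to make the arithmetic visible.
fun main() {
  val maxContextLength = 4_097      // e.g. a GPT-3.5-sized context window
  val minResponseTokens = 500       // reserved for the model's reply
  val promptTokens = 1_000          // tokens already used by prompt.messages
  val historyPercent = 50           // messagePolicy.historyPercent
  val contextPercent = 50           // messagePolicy.contextPercent

  val remainingTokens = maxContextLength - minResponseTokens
  // Mirrors the PromptExceedsMaxRemainingTokenLength check in the patch.
  require(promptTokens < remainingTokens) {
    "prompt does not fit: $promptTokens >= $remainingTokens"
  }

  val remainingTokensForContexts = remainingTokens - promptTokens
  val maxHistoryTokens = (remainingTokensForContexts * historyPercent) / 100
  val maxContextTokens = (remainingTokensForContexts * contextPercent) / 100

  println("history budget: $maxHistoryTokens tokens")   // 1298
  println("context budget: $maxContextTokens tokens")   // 1298
}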
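
Note on history trimming: calculateMessagesFromHistory walks the recovered memories from newest to oldest (foldRight), keeps accumulating messages while the running token count stays within maxHistoryTokens, and finally reverses the kept list back into chronological order. The standalone Kotlin sketch below reproduces that fold; tokensOf is a hypothetical stand-in for Chat.tokensFromMessages and the sample messages are invented, so treat it as an illustration of the technique rather than code from the patch.

// Standalone sketch of the foldRight-based trimming in calculateMessagesFromHistory.
// tokensOf is a hypothetical, deliberately crude token counter (assumption).
fun tokensOf(message: String): Int = message.length / 4

fun fitHistoryToBudget(history: List<String>, maxHistoryTokens: Int): List<String> {
  val historyWithTokens = history.map { it to tokensOf(it) }

  // foldRight starts from the most recent message; a message is kept only
  // while the accumulated token count stays within the budget.
  val (_, keptNewestFirst) =
    historyWithTokens.foldRight(0 to emptyList<String>()) { (msg, tokens), (used, kept) ->
      if (used + tokens > maxHistoryTokens) used to kept
      else (used + tokens) to (kept + msg)
    }

  // The fold accumulates newest-first, so reverse back to chronological order,
  // matching totalTokenWithMessages.second.reversed() in the patch.
  return keptNewestFirst.reversed()
}

fun main() {
  val history = listOf("first question", "first answer", "second question", "second answer")
  // With a budget of 8 "tokens" only the two most recent messages survive:
  // prints [second question, second answer]
  println(fitHistoryToBudget(history, maxHistoryTokens = 8))
}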