CoreAIScope: Refactor code #208

Merged · 1 commit · Jun 27, 2023
289 changes: 91 additions & 198 deletions core/src/commonMain/kotlin/com/xebia/functional/xef/auto/CoreAIScope.kt
@@ -195,142 +195,97 @@ class CoreAIScope(
    bringFromContext: Int = this.docsInContext,
    minResponseTokens: Int
  ): List<String> {
-   return when (model.kind) {
-     LLMModel.Kind.Completion ->
-       callCompletionEndpoint(
-         prompt.message,
-         model,
-         user,
-         echo,
-         numberOfPredictions,
-         temperature,
-         bringFromContext,
-         minResponseTokens
-       )
-     LLMModel.Kind.Chat ->
-       callChatEndpoint(
-         prompt.message,
-         model,
-         user,
-         numberOfPredictions,
-         temperature,
-         bringFromContext,
-         minResponseTokens
-       )
-     LLMModel.Kind.ChatWithFunctions ->
-       callChatEndpointWithFunctionsSupport(
-         prompt.message,
-         model,
-         functions,
-         user,
-         numberOfPredictions,
-         temperature,
-         bringFromContext,
-         minResponseTokens
-       )
-         .map { it.arguments }
-   }
- }

- private suspend fun callCompletionEndpoint(
-   prompt: String,
-   model: LLMModel,
-   user: String,
-   echo: Boolean,
-   n: Int,
-   temperature: Double,
-   bringFromContext: Int,
-   minResponseTokens: Int
- ): List<String> {
    val promptWithContext: String =
-     promptWithContext(prompt, bringFromContext, model.modelType, minResponseTokens)
+     createPromptWithContextAwareOfTokens(
+       ctxInfo = context.similaritySearch(prompt.message, bringFromContext),
+       modelType = model.modelType,
+       prompt = prompt.message,
+       minResponseTokens = minResponseTokens
+     )

-   val maxTokens: Int = checkTotalLeftTokens(model.modelType, "", promptWithContext)
+   fun checkTotalLeftTokens(role: String): Int =
+     with(model.modelType) {
+       val roleTokens: Int = encoding.countTokens(role)
+       val padding = 20 // reserve 20 tokens for additional symbols around the context
+       val promptTokens: Int = encoding.countTokens(promptWithContext)
+       val takenTokens: Int = roleTokens + promptTokens + padding
+       val totalLeftTokens: Int = maxContextLength - takenTokens
+       if (totalLeftTokens < 0) {
+         throw AIError.PromptExceedsMaxTokenLength(
+           promptWithContext,
+           takenTokens,
+           maxContextLength
+         )
+       }
+       logger.debug {
+         "Tokens -- used: $takenTokens, model max: $maxContextLength, left: $totalLeftTokens"
+       }
+       totalLeftTokens
+     }

-   val request =
+   suspend fun buildCompletionRequest(): CompletionRequest =
      CompletionRequest(
        model = model.name,
        user = user,
        prompt = promptWithContext,
        echo = echo,
-       n = n,
+       n = numberOfPredictions,
        temperature = temperature,
-       maxTokens = maxTokens
+       maxTokens = checkTotalLeftTokens("")
      )
-   return AIClient.createCompletion(request).choices.map { it.text }
- }

- private suspend fun callChatEndpoint(
-   prompt: String,
-   model: LLMModel,
-   user: String,
-   n: Int,
-   temperature: Double,
-   bringFromContext: Int,
-   minResponseTokens: Int
- ): List<String> {
-   val role: String = Role.system.name
-   val promptWithContext: String =
-     promptWithContext(prompt, bringFromContext, model.modelType, minResponseTokens)
-   val messages: List<Message> = listOf(Message(role, promptWithContext))
-   val maxTokens: Int = checkTotalLeftChatTokens(messages, model)
-   val request =
-     ChatCompletionRequest(
+   fun checkTotalLeftChatTokens(messages: List<Message>): Int {
+     val maxContextLength: Int = model.modelType.maxContextLength
+     val messagesTokens: Int = tokensFromMessages(messages, model)
+     val totalLeftTokens: Int = maxContextLength - messagesTokens
+     if (totalLeftTokens < 0) {
+       throw AIError.MessagesExceedMaxTokenLength(messages, messagesTokens, maxContextLength)
+     }
+     logger.debug {
+       "Tokens -- used: $messagesTokens, model max: $maxContextLength, left: $totalLeftTokens"
+     }
+     return totalLeftTokens
+   }

+   suspend fun buildChatRequest(): ChatCompletionRequest {
+     val messages: List<Message> = listOf(Message(Role.system.name, promptWithContext))
+     return ChatCompletionRequest(
        model = model.name,
        user = user,
        messages = messages,
-       n = n,
+       n = numberOfPredictions,
        temperature = temperature,
-       maxTokens = maxTokens
+       maxTokens = checkTotalLeftChatTokens(messages)
      )
-   return AIClient.createChatCompletion(request).choices.map { it.message.content }
- }
+   }

- private suspend fun callChatEndpointWithFunctionsSupport(
-   prompt: String,
-   model: LLMModel,
-   functions: List<CFunction>,
-   user: String,
-   n: Int,
-   temperature: Double,
-   bringFromContext: Int,
-   minResponseTokens: Int
- ): List<FunctionCall> {
-   val role: String = Role.user.name
-   val firstFnName: String? = functions.firstOrNull()?.name
-   val promptWithContext: String =
-     promptWithContext(prompt, bringFromContext, model.modelType, minResponseTokens)
-   val messages: List<Message> = listOf(Message(role, promptWithContext))
-   val maxTokens: Int = checkTotalLeftChatTokens(messages, model)
-   val request =
-     ChatCompletionRequestWithFunctions(
+   suspend fun chatWithFunctionsRequest(): ChatCompletionRequestWithFunctions {
+     val role: String = Role.user.name
+     val firstFnName: String? = functions.firstOrNull()?.name
+     val messages: List<Message> = listOf(Message(role, promptWithContext))
+     return ChatCompletionRequestWithFunctions(
        model = model.name,
        user = user,
        messages = messages,
-       n = n,
+       n = numberOfPredictions,
        temperature = temperature,
-       maxTokens = maxTokens,
+       maxTokens = checkTotalLeftChatTokens(messages),
        functions = functions,
        functionCall = mapOf("name" to (firstFnName ?: ""))
      )
-   return AIClient.createChatCompletionWithFunctions(request).choices.map {
-     it.message.functionCall
-   }
- }
+   }

- private suspend fun promptWithContext(
-   prompt: String,
-   bringFromContext: Int,
-   modelType: ModelType,
-   minResponseTokens: Int
- ): String {
-   val ctxInfo: List<String> = context.similaritySearch(prompt, bringFromContext)
-   return createPromptWithContextAwareOfTokens(
-     ctxInfo = ctxInfo,
-     modelType = modelType,
-     prompt = prompt,
-     minResponseTokens = minResponseTokens
-   )
+   return when (model.kind) {
+     LLMModel.Kind.Completion ->
+       AIClient.createCompletion(buildCompletionRequest()).choices.map { it.text }
+     LLMModel.Kind.Chat ->
+       AIClient.createChatCompletion(buildChatRequest()).choices.map { it.message.content }
+     LLMModel.Kind.ChatWithFunctions ->
+       AIClient.createChatCompletionWithFunctions(chatWithFunctionsRequest()).choices.map {
+         it.message.functionCall.arguments
+       }
+   }
  }

  private fun createPromptWithContextAwareOfTokens(
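The heart of this first hunk: three per-endpoint private methods (callCompletionEndpoint, callChatEndpoint, callChatEndpointWithFunctionsSupport), each re-threading the same parameters, collapse into local builders that close over the enclosing promptMessages scope, with a single dispatch at the end. A minimal self-contained sketch of that pattern — simplified, hypothetical types, not the actual xef API:

enum class Kind { COMPLETION, CHAT }

data class Request(val prompt: String, val user: String, val n: Int)

fun interface Client {
  fun send(request: Request): List<String>
}

fun promptMessages(client: Client, kind: Kind, prompt: String, user: String, n: Int): List<String> {
  // Shared preparation runs exactly once, instead of once per helper.
  val enriched = "context...\n$prompt" // stand-in for similarity-search enrichment

  // Local builders capture enriched/user/n from the scope; no parameter re-threading.
  fun completionRequest() = Request(enriched, user, n)
  fun chatRequest() = Request("system: $enriched", user, n)

  // Dispatch lives in one place at the end, mirroring the refactored when (model.kind).
  return when (kind) {
    Kind.COMPLETION -> client.send(completionRequest())
    Kind.CHAT -> client.send(chatRequest())
  }
}

fun main() {
  // Echo client: returns n copies of the prompt, enough to exercise the flow.
  val echo = Client { request -> List(request.n) { request.prompt } }
  println(promptMessages(echo, Kind.CHAT, "hello", user = "u1", n = 2))
}

The trade-off is one longer enclosing function in exchange for no drift between the three call paths, which is where most of the net -107 lines in this diff come from.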
@@ -368,103 +323,41 @@ class CoreAIScope(
    } else prompt
  }

- private fun checkTotalLeftTokens(
-   modelType: ModelType,
-   role: String,
-   promptWithContext: String
- ): Int =
-   with(modelType) {
-     val roleTokens: Int = encoding.countTokens(role)
-     val padding = 20 // reserve 20 tokens for additional symbols around the context
-     val promptTokens: Int = encoding.countTokens(promptWithContext)
-     val takenTokens: Int = roleTokens + promptTokens + padding
-     val totalLeftTokens: Int = maxContextLength - takenTokens
-     if (totalLeftTokens < 0) {
-       throw AIError.PromptExceedsMaxTokenLength(promptWithContext, takenTokens, maxContextLength)
-     }
+ private fun tokensFromMessages(messages: List<Message>, model: LLMModel): Int {
+   fun Encoding.countTokensFromMessages(tokensPerMessage: Int, tokensPerName: Int): Int =
+     messages.sumOf { message ->
+       countTokens(message.role) +
+         countTokens(message.content) +
+         tokensPerMessage +
+         (message.name?.let { tokensPerName } ?: 0)
+     } + 3

+   fun fallBackTo(fallbackModel: LLMModel, paddingTokens: Int): Int {
      logger.debug {
-       "Tokens -- used: $takenTokens, model max: $maxContextLength, left: $totalLeftTokens"
+       "Warning: ${model.name} may change over time. " +
+         "Returning messages num tokens assuming ${fallbackModel.name} + $paddingTokens padding tokens."
      }
-     totalLeftTokens
+     return tokensFromMessages(messages, fallbackModel) + paddingTokens
    }

- private fun checkTotalLeftChatTokens(messages: List<Message>, model: LLMModel): Int {
-   val maxContextLength: Int = model.modelType.maxContextLength
-   val messagesTokens: Int = tokensFromMessages(messages, model)
-   val totalLeftTokens: Int = maxContextLength - messagesTokens
-   if (totalLeftTokens < 0) {
-     throw AIError.MessagesExceedMaxTokenLength(messages, messagesTokens, maxContextLength)
-   }
-   logger.debug {
-     "Tokens -- used: $messagesTokens, model max: $maxContextLength, left: $totalLeftTokens"
-   }
-   return totalLeftTokens
- }

- private fun tokensFromMessages(messages: List<Message>, model: LLMModel): Int =
-   when (model) {
-     LLMModel.GPT_3_5_TURBO_FUNCTIONS -> {
-       val paddingTokens = 200 // reserved for functions
-       val fallbackModel: LLMModel = LLMModel.GPT_3_5_TURBO_0301
-       logger.debug {
-         "Warning: ${model.name} may change over time. " +
-           "Returning messages num tokens assuming ${fallbackModel.name} + $paddingTokens padding tokens."
-       }
-       tokensFromMessages(messages, fallbackModel) + paddingTokens
-     }
-     LLMModel.GPT_3_5_TURBO -> {
-       val paddingTokens = 5 // otherwise if the model changes, it might later fail
-       val fallbackModel: LLMModel = LLMModel.GPT_3_5_TURBO_0301
-       logger.debug {
-         "Warning: ${model.name} may change over time. " +
-           "Returning messages num tokens assuming ${fallbackModel.name} + $paddingTokens padding tokens."
-       }
-       tokensFromMessages(messages, fallbackModel) + paddingTokens
-     }
+   return when (model) {
+     LLMModel.GPT_3_5_TURBO_FUNCTIONS ->
+       // paddingToken = 200: reserved for functions
+       fallBackTo(fallbackModel = LLMModel.GPT_3_5_TURBO_0301, paddingTokens = 200)
+     LLMModel.GPT_3_5_TURBO ->
+       // otherwise if the model changes, it might later fail
+       fallBackTo(fallbackModel = LLMModel.GPT_3_5_TURBO_0301, paddingTokens = 5)
      LLMModel.GPT_4,
-     LLMModel.GPT_4_32K -> {
-       val paddingTokens = 5 // otherwise if the model changes, it might later fail
-       val fallbackModel: LLMModel = LLMModel.GPT_4_0314
-       logger.debug {
-         "Warning: ${model.name} may change over time. " +
-           "Returning messages num tokens assuming ${fallbackModel.name} + $paddingTokens padding tokens."
-       }
-       tokensFromMessages(messages, fallbackModel) + paddingTokens
-     }
+     LLMModel.GPT_4_32K ->
+       // otherwise if the model changes, it might later fail
+       fallBackTo(fallbackModel = LLMModel.GPT_4_0314, paddingTokens = 5)
      LLMModel.GPT_3_5_TURBO_0301 ->
-       model.modelType.encoding.countTokensFromMessages(
-         messages,
-         tokensPerMessage = 4,
-         tokensPerName = 0
-       )
+       model.modelType.encoding.countTokensFromMessages(tokensPerMessage = 4, tokensPerName = 0)
      LLMModel.GPT_4_0314 ->
-       model.modelType.encoding.countTokensFromMessages(
-         messages,
-         tokensPerMessage = 3,
-         tokensPerName = 2
-       )
-     else -> {
-       val paddingTokens = 20
-       val fallbackModel: LLMModel = LLMModel.GPT_3_5_TURBO_0301
-       logger.debug {
-         "Warning: calculation of tokens is partially supported for ${model.name}. " +
-           "Returning messages num tokens assuming ${fallbackModel.name} + $paddingTokens padding tokens."
-       }
-       tokensFromMessages(messages, fallbackModel) + paddingTokens
-     }
+       model.modelType.encoding.countTokensFromMessages(tokensPerMessage = 3, tokensPerName = 2)
+     else -> fallBackTo(fallbackModel = LLMModel.GPT_3_5_TURBO_0301, paddingTokens = 20)
    }

- private fun Encoding.countTokensFromMessages(
-   messages: List<Message>,
-   tokensPerMessage: Int,
-   tokensPerName: Int
- ): Int =
-   messages.sumOf { message ->
-     countTokens(message.role) +
-       countTokens(message.content) +
-       tokensPerMessage +
-       (message.name?.let { tokensPerName } ?: 0)
-   } + 3
}
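The second hunk plays the same card on token counting: four near-identical when branches — log a drift warning, delegate to a pinned fallback model, add padding — become one parameterized local fallBackTo helper, and countTokensFromMessages moves inside tokensFromMessages so it can capture messages directly. A minimal sketch of that consolidation under the same caveat (hypothetical names and a crude length-based count, not the xef API):

enum class Model { STABLE_0301, TURBO_ALIAS, UNKNOWN }

fun countTokens(messages: List<String>, model: Model): Int {
  // One helper replaces N copy-pasted branches; each branch keeps only what varies.
  fun fallBackTo(fallback: Model, paddingTokens: Int): Int {
    println("warning: ${model.name} may drift; counting as ${fallback.name} + $paddingTokens padding")
    return countTokens(messages, fallback) + paddingTokens
  }

  return when (model) {
    // Pinned snapshot: count directly (4 tokens of per-message overhead, 3 for priming).
    Model.STABLE_0301 -> messages.sumOf { it.length / 4 + 4 } + 3
    // Moving aliases: defer to the snapshot plus a safety margin.
    Model.TURBO_ALIAS -> fallBackTo(Model.STABLE_0301, paddingTokens = 5)
    Model.UNKNOWN -> fallBackTo(Model.STABLE_0301, paddingTokens = 20)
  }
}

fun main() {
  println(countTokens(listOf("hello world", "how are you"), Model.TURBO_ALIAS))
}

Keeping fallBackTo local also lets it recurse through the single counting entry point instead of exposing another private helper on the class.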

/**
 * Run a [prompt] describing the images you want to generate within the context of [CoreAIScope].