Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Text Pattern Prompts #103

Merged
merged 7 commits into from
May 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ private suspend fun AIScope.callCompletionEndpoint(
n: Int = 1,
temperature: Double = 0.0,
bringFromContext: Int,
minResponseTokens: Int,
minResponseTokens: Int
): List<String> {
val promptWithContext: String =
promptWithContext(prompt, bringFromContext, model.modelType, minResponseTokens)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ data class ChatCompletionRequest(
@SerialName("max_tokens") val maxTokens: Int? = null,
@SerialName("presence_penalty") val presencePenalty: Double = 0.0,
@SerialName("frequency_penalty") val frequencyPenalty: Double = 0.0,
@SerialName("logit_bias") val logitBias: Map<String, Double>? = emptyMap(),
@SerialName("logit_bias") val logitBias: Map<String, Int> = emptyMap(),
val user: String?
)

Expand Down
161 changes: 161 additions & 0 deletions core/src/jvmMain/kotlin/com/xebia/functional/xef/agents/LLMAgent.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
package com.xebia.functional.xef.agents

import com.xebia.functional.tokenizer.EncodingType
import com.xebia.functional.xef.auto.AIScope
import com.xebia.functional.xef.llm.openai.ChatCompletionRequest
import com.xebia.functional.xef.llm.openai.CompletionRequest
import com.xebia.functional.xef.llm.openai.LLMModel
import com.xebia.functional.xef.llm.openai.Message
import com.xebia.functional.xef.llm.openai.Role
import java.util.regex.Matcher

/**
 * Generates a completion for [prompt] whose output is constrained, token by token,
 * to remain a viable prefix of a string matching [pattern].
 *
 * @param prompt text sent to the model.
 * @param pattern regular expression the generated text must conform to.
 * @param model model used for generation; defaults to GPT-3.5 Turbo.
 * @param user identifier forwarded to the OpenAI API.
 * @param n number of completions requested per step.
 * @param echo whether the completion endpoint echoes the prompt back.
 * @param temperature sampling temperature; 0.0 is deterministic.
 * @param maxNewTokens upper bound on generated tokens.
 * @param stopAfterMatch when true, generation stops as soon as the accumulated
 *   completion fully matches [pattern].
 * @return the accumulated completion text.
 */
suspend fun AIScope.patternPrompt(
  prompt: String,
  pattern: Regex,
  model: LLMModel = LLMModel.GPT_3_5_TURBO,
  user: String = "testing",
  n: Int = 1,
  echo: Boolean = false,
  temperature: Double = 0.0,
  maxNewTokens: Int = 30,
  stopAfterMatch: Boolean = true
): String =
  patternPrompt(
    prompt = prompt,
    pattern = pattern,
    model = model,
    user = user,
    n = n,
    echo = echo,
    temperature = temperature,
    maxNewTokens = maxNewTokens,
    stopAfterMatch = stopAfterMatch,
    genTokens = 0,
    partialCompletion = "",
    tokenFilter = TokenFilter(model.modelType.encodingType)
  )

/**
 * Recursive driver: generates one token per step, biasing the vocabulary via
 * [tokenFilter] so the accumulated [partialCompletion] stays a viable prefix of
 * a [pattern] match, until [maxNewTokens] is reached or a full match occurs.
 *
 * @param genTokens number of tokens generated so far.
 * @param partialCompletion text accumulated across previous steps.
 * @return the accumulated completion text.
 */
private suspend fun AIScope.patternPrompt(
  prompt: String,
  pattern: Regex,
  model: LLMModel,
  user: String,
  n: Int,
  echo: Boolean,
  temperature: Double,
  maxNewTokens: Int,
  stopAfterMatch: Boolean,
  genTokens: Int,
  partialCompletion: String,
  tokenFilter: TokenFilter
): String {
  // Token budget exhausted: return whatever has been accumulated.
  if (genTokens >= maxNewTokens) return partialCompletion

  // Bias generation towards tokens that keep the completion pattern-viable.
  val logitBias: Map<String, Int> = tokenFilter.buildLogitBias(partialCompletion, pattern)

  val outputCompletion: List<String> =
    patternPrompt(model, user, prompt, echo, n, temperature, logitBias)

  // The API may return no choices; stop gracefully instead of crashing on [0].
  val nextToken: String = outputCompletion.firstOrNull() ?: return partialCompletion

  val nextPartialCompletion: String = partialCompletion + nextToken

  if (stopAfterMatch && pattern.matches(nextPartialCompletion)) {
    return nextPartialCompletion
  }

  return patternPrompt(
    prompt + nextToken,
    pattern,
    model,
    user,
    n,
    echo,
    temperature,
    maxNewTokens,
    stopAfterMatch,
    genTokens = genTokens + 1,
    nextPartialCompletion,
    tokenFilter
  )
}

/**
 * Issues a single one-token generation step against the appropriate endpoint
 * for [model] (completion vs. chat), applying [logitBias] to restrict sampling.
 *
 * @return the text of each returned choice.
 */
private suspend fun AIScope.patternPrompt(
  model: LLMModel,
  user: String,
  prompt: String,
  echo: Boolean,
  n: Int,
  temperature: Double,
  logitBias: Map<String, Int>
): List<String> =
  when (model.kind) {
    LLMModel.Kind.Completion ->
      openAIClient
        .createCompletion(
          CompletionRequest(
            model = model.name,
            user = user,
            prompt = prompt,
            echo = echo,
            n = n,
            temperature = temperature,
            maxTokens = 1,
            logitBias = logitBias
          )
        )
        .choices
        .map { it.text }
    LLMModel.Kind.Chat ->
      openAIClient
        .createChatCompletion(
          ChatCompletionRequest(
            model = model.name,
            messages = listOf(Message(Role.system.name, prompt)),
            temperature = temperature,
            n = n,
            user = user,
            maxTokens = 1,
            logitBias = logitBias
          )
        )
        .choices
        .map { it.message.content }
  }

/**
 * Builds OpenAI `logit_bias` maps restricting generation to tokens that keep a
 * partial completion compatible with a target [Regex].
 */
interface TokenFilter {
  // Decoded text for every token id in the model's vocabulary.
  val tokensCache: Map<Int, String>

  /**
   * Returns a logit-bias map of token ids (as strings) whose decoded text,
   * appended to [partialCompletion], still partially matches [pattern].
   */
  fun buildLogitBias(partialCompletion: String, pattern: Regex): Map<String, Int>

  companion object {
    // Hoisted so the vocabulary-loading loop does not rebuild the regex per line.
    private val whitespace = Regex("\\s+")

    operator fun invoke(encodingType: EncodingType): TokenFilter =
      object : TokenFilter {
        override val tokensCache: Map<Int, String> = encodingType.buildDecodedTokensCache()

        override fun buildLogitBias(partialCompletion: String, pattern: Regex): Map<String, Int> =
          buildMap {
            // OpenAI caps logit_bias at 300 entries per request.
            val openAILimit = 300
            // A bias of 100 effectively forces sampling from the allowed set.
            val exclusiveBias = 100
            tokensCache
              .asSequence()
              .filter { pattern.partialMatch(partialCompletion + it.value) }
              .take(openAILimit)
              .forEach { put("${it.key}", exclusiveBias) }
          }

        // Decodes every token id listed in the encoding's base vocabulary,
        // then adds the special tokens (id -> literal text).
        private fun EncodingType.buildDecodedTokensCache(): Map<Int, String> = buildMap {
          base.lineSequence().forEach { line ->
            val (_, rank) = line.split(whitespace, limit = 2)
            val tokenId: Int = rank.toInt()
            val token: String = encodingType.encoding.decode(listOf(tokenId))
            put(tokenId, token)
          }
          specialTokensBase.forEach { put(it.value, it.key) }
        }

        // True when input fully matches, or when the matcher ran out of input
        // (hitEnd), i.e. more characters could still complete a match.
        private fun Regex.partialMatch(input: String): Boolean {
          val matcher: Matcher = toPattern().matcher(input)
          // matches() must run before hitEnd() so hitEnd reflects this attempt;
          // short-circuit || is equivalent to the eager .or() used previously.
          return matcher.matches() || matcher.hitEnd()
        }
      }
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package com.xebia.functional.xef.auto.pattern

import com.xebia.functional.xef.agents.patternPrompt
import com.xebia.functional.xef.auto.ai
import com.xebia.functional.xef.auto.getOrElse
import com.xebia.functional.xef.auto.prompt
import kotlinx.serialization.decodeFromString
import kotlinx.serialization.json.Json

/**
 * Example: constrain generation to a three-element JSON array of single
 * lowercase letters, then decode and print it. Set [enableComparison] to true
 * to also run the unconstrained prompt for comparison.
 */
suspend fun main() {
  val enableComparison = false

  ai {
    val goal = "Return the first three letters of the alphabet in a json array: "

    // Force the output shape token-by-token via the pattern filter.
    val patternResponse: String =
      patternPrompt(
        prompt = goal,
        pattern = Regex("""\["[a-z]", "[a-z]", "[a-z]"]"""),
        maxNewTokens = 20
      )
    val list: List<String> = Json.decodeFromString(patternResponse)
    println(list)

    if (enableComparison) {
      val response: List<String> = prompt(goal)
      println(response)
    }
  }.getOrElse { println(it) }
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,124 +2,84 @@

package com.xebia.functional.tokenizer

import com.xebia.functional.tokenizer.internal.SPECIAL_TOKENS_CL100K_BASE
import com.xebia.functional.tokenizer.internal.SPECIAL_TOKENS_P50K_EDIT
import com.xebia.functional.tokenizer.internal.SPECIAL_TOKENS_X50K_BASE
import com.xebia.functional.tokenizer.internal.cl100k_base
import com.xebia.functional.tokenizer.internal.p50k_base
import com.xebia.functional.tokenizer.internal.r50k_base
import kotlin.io.encoding.Base64
import kotlin.io.encoding.ExperimentalEncodingApi

enum class EncodingType(@Suppress("UNUSED_PARAMETER") name: String) {
R50K_BASE("r50k_base") {
override val encoding by lazy { EncodingFactory.r50kBase() }
},
P50K_BASE("p50k_base") {
override val encoding by lazy { EncodingFactory.p50kBase() }
},
P50K_EDIT("p50k_edit") {
override val encoding by lazy { EncodingFactory.p50kEdit() }
},
CL100K_BASE("cl100k_base") {
override val encoding by lazy { EncodingFactory.cl100kBase() }
};
R50K_BASE("r50k_base") {
override val base: String = r50k_base
override val regex: Regex = p50k_regex
override val specialTokensBase: Map<String, Int> = SPECIAL_TOKENS_P50K_EDIT
override val encoding by lazy {
EncodingFactory.fromPredefinedParameters(
name, regex, base, specialTokensBase
)
}
},
P50K_BASE("p50k_base") {
override val base: String = p50k_base
override val regex: Regex = p50k_regex
override val specialTokensBase: Map<String, Int> = SPECIAL_TOKENS_X50K_BASE
override val encoding by lazy {
EncodingFactory.fromPredefinedParameters(
name, regex, base, specialTokensBase
)
}
},
P50K_EDIT("p50k_edit") {
override val base: String = p50k_base
override val regex: Regex = p50k_regex
override val specialTokensBase: Map<String, Int> = SPECIAL_TOKENS_P50K_EDIT
override val encoding by lazy {
EncodingFactory.fromPredefinedParameters(
name, regex, base, specialTokensBase
)
}
},
CL100K_BASE("cl100k_base") {
override val base: String = cl100k_base
override val regex: Regex = cl100k_base_regex
override val specialTokensBase: Map<String, Int> = SPECIAL_TOKENS_CL100K_BASE
override val encoding by lazy {
EncodingFactory.fromPredefinedParameters(
name, regex, base, specialTokensBase
)
}
};

abstract val encoding: Encoding
abstract val base: String
abstract val regex: Regex
abstract val specialTokensBase: Map<String, Int>
abstract val encoding: Encoding
}

private object EncodingFactory {
private const val ENDOFTEXT = "<|endoftext|>"
private const val FIM_PREFIX = "<|fim_prefix|>"
private const val FIM_MIDDLE = "<|fim_middle|>"
private const val FIM_SUFFIX = "<|fim_suffix|>"
private const val ENDOFPROMPT = "<|endofprompt|>"

private val SPECIAL_TOKENS_X50K_BASE: Map<String, Int> = HashMap<String, Int>(1).apply {
put(ENDOFTEXT, 50256)
}

private val SPECIAL_TOKENS_P50K_EDIT: Map<String, Int> = HashMap<String, Int>(4).apply {
put(ENDOFTEXT, 50256)
put(FIM_PREFIX, 50281)
put(FIM_MIDDLE, 50282)
put(FIM_SUFFIX, 50283)
}

private val SPECIAL_TOKENS_CL100K_BASE: Map<String, Int> = HashMap<String, Int>(5).apply {
put(ENDOFTEXT, 100257)
put(FIM_PREFIX, 100258)
put(FIM_MIDDLE, 100259)
put(FIM_SUFFIX, 100260)
put(ENDOFPROMPT, 100276)
}

/**
* Returns an [Encoding] instance for the r50k_base encoding.
*
* @return an [Encoding] instance for the r50k_base encoding
*/
fun r50kBase(): Encoding = fromPredefinedParameters(
"r50k_base",
p50k_regex,
r50k_base,
SPECIAL_TOKENS_X50K_BASE
)

/**
* Returns an [Encoding] instance for the p50k_base encoding.
*
* @return an [Encoding] instance for the p50k_base encoding
*/
fun p50kBase(): Encoding = fromPredefinedParameters(
"p50k_base",
p50k_regex,
p50k_base,
SPECIAL_TOKENS_X50K_BASE
)

/**
* Returns an [Encoding] instance for the p50k_edit encoding.
*
* @return an [Encoding] instance for the p50k_edit encoding
*/
fun p50kEdit(): Encoding = fromPredefinedParameters(
"p50k_edit",
p50k_regex,
p50k_base,
SPECIAL_TOKENS_P50K_EDIT
)

fun cl100kBase(): Encoding = fromPredefinedParameters(
"cl100k_base",
cl100k_base_regex,
cl100k_base,
SPECIAL_TOKENS_CL100K_BASE
)

/**
* Returns an [Encoding] instance for the given GPT BytePairEncoding parameters.
*
* @param parameters the GPT BytePairEncoding parameters
* @return an [Encoding] instance for the given GPT BytePairEncoding parameters
*/
fun fromParameters(parameters: GptBytePairEncodingParams): Encoding =
GptBytePairEncoding(parameters)
fun fromPredefinedParameters(
name: String,
regex: Regex,
base: String,
specialTokens: Map<String, Int>
): Encoding {
val params = GptBytePairEncodingParams(name, regex, loadMergeableRanks(base), specialTokens)
return fromParameters(params)
}

private fun fromPredefinedParameters(
name: String,
regex: Regex,
base: String,
specialTokens: Map<String, Int>
): Encoding {
val params = GptBytePairEncodingParams(name, regex, loadMergeableRanks(base), specialTokens)
return fromParameters(params)
}
private fun fromParameters(parameters: GptBytePairEncodingParams): Encoding =
GptBytePairEncoding(parameters)

private fun loadMergeableRanks(base: String): Map<ByteArray, Int> =
buildMap {
base.lineSequence().forEach { line ->
val (token, rank) = line.split(Regex("\\s+"), limit = 2)
put(Base64.decode(token.encodeToByteArray()), rank.toInt())
}
}
fun loadMergeableRanks(base: String): Map<ByteArray, Int> =
buildMap {
base.lineSequence().forEach { line ->
val (token, rank) = line.split(Regex("\\s+"), limit = 2)
put(Base64.decode(token.encodeToByteArray()), rank.toInt())
}
}
}

expect val cl100k_base_regex: Regex
Expand Down
Loading