|
1 | 1 | package com.xebia.functional.tokenizer
|
2 | 2 |
|
| 3 | +import kotlin.math.roundToInt |
| 4 | + |
3 | 5 | /** The result of encoding operation. */
|
4 | 6 | data class EncodingResult(
|
5 | 7 | val tokens: List<Int>,
|
@@ -199,3 +201,34 @@ interface Encoding {
|
199 | 201 | */
|
200 | 202 | fun decodeBytes(tokens: List<Int>): ByteArray
|
201 | 203 | }
|
| 204 | + |
/**
 * Truncates the given [text] so that its token count, as reported by [Encoding.countTokens],
 * does not exceed [maxTokens]. Characters are removed from the tail of the [text].
 *
 * The length of the kept prefix is estimated from the ratio of [maxTokens] to the
 * total amount of tokens in the [text]. If the estimated prefix still has too many tokens,
 * the function calls itself again on that prefix.
 *
 * Every pass drops at least one character, so the recursion always terminates. If no
 * non-empty prefix fits into [maxTokens] (for example a single emoticon that needs more
 * tokens than [maxTokens], or a negative [maxTokens]), the result is the empty string.
 *
 * The cut never leaves a dangling high surrogate at the end of the result.
 *
 * **WARNING:** This method might truncate crucial information from your prompt,
 * and as a result might degrade reliability of your prompts.
 *
 * @param text the text to shorten.
 * @param maxTokens the maximum amount of tokens the returned text may have.
 * @return [text] itself when it already fits, otherwise a prefix of [text].
 */
tailrec fun Encoding.truncateText(text: String, maxTokens: Int): String {
    val tokenCount = countTokens(text)
    // Nothing to do when the text already fits; an empty text cannot be shortened any further.
    if (tokenCount <= maxTokens || text.isEmpty()) return text

    val percentage = maxTokens.toDouble() / tokenCount.toDouble()
    // Clamp to 0..length-1: rounding may yield the full length again (no progress, endless
    // recursion) and a negative maxTokens would yield a negative length.
    val estimatedLength = (text.length * percentage).roundToInt().coerceIn(0, text.length - 1)
    // Do not cut between the two halves of a surrogate pair.
    val truncatedTextLength =
        if (estimatedLength > 0 && text[estimatedLength - 1].isHighSurrogate()) estimatedLength - 1
        else estimatedLength
    val result = text.substring(0, truncatedTextLength)

    return if (countTokens(result) > maxTokens) truncateText(result, maxTokens) else result
}
0 commit comments