Skip to content

Commit 62070bc

Browse files
authored
Add Encoding.truncateText (#84)
* Add Encoding.truncateText * Add warning
1 parent 949402a commit 62070bc

File tree

2 files changed

+53
-0
lines changed
  • tokenizer/src

2 files changed

+53
-0
lines changed

tokenizer/src/commonMain/kotlin/com/xebia/functional/tokenizer/Encoding.kt

+33
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
package com.xebia.functional.tokenizer
22

3+
import kotlin.math.roundToInt
4+
35
/** The result of encoding operation. */
46
data class EncodingResult(
57
val tokens: List<Int>,
@@ -199,3 +201,34 @@ interface Encoding {
199201
*/
200202
fun decodeBytes(tokens: List<Int>): ByteArray
201203
}
204+
205+
/**
206+
* Truncates the given [text] to the given [maxTokens] by removing tokens from the end of the text.
207+
* It removes tokens from the tail of the [text].
208+
* Tokens are chosen to be removed based on the percentage of the [maxTokens]
209+
* compared to the total amount of tokens in the [text].
210+
*
211+
* If the truncation fails,
212+
* it will retry by recursively calling this function until a text with maxTokens is found.
213+
*
214+
* **WARNING:** for small [maxTokens] this function may hang forever,
215+
* some [text] like emoticons, or special characters have token length of 9.
216+
* So trying to truncateText to maxToken = 5 might hang forever for them.
217+
*
218+
* **WARNING:** This method might truncate crucial information from your prompt,
219+
* and as a result might degrade reliability of your prompts.
220+
*/
221+
tailrec fun Encoding.truncateText(text: String, maxTokens: Int): String {
222+
val tokenCount = countTokens(text)
223+
return if (tokenCount <= maxTokens) text
224+
else {
225+
val percentage = maxTokens.toDouble() / tokenCount.toDouble()
226+
val truncatedTextLength = (text.length * percentage).roundToInt()
227+
val result = text.substring(0, truncatedTextLength)
228+
val tokenCountResult = countTokens(result)
229+
when {
230+
tokenCountResult >= maxTokens -> truncateText(result, maxTokens)
231+
else -> result
232+
}
233+
}
234+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
package com.xebia.functional.tokenizer
2+
3+
import com.goncalossilva.resources.Resource
4+
import io.kotest.assertions.withClue
5+
import io.kotest.matchers.ints.shouldBeLessThan
6+
import io.kotest.matchers.ints.shouldBeLessThanOrEqual
7+
import kotlin.test.Test
8+
9+
class EncodingTest {
10+
private val resource = Resource("src/commonTest/resources/cl100k_base_encodings.csv")
11+
private val ENCODING = EncodingType.CL100K_BASE.encoding
12+
13+
@Test
14+
fun truncateText() {
15+
resource.splitCSV().forEach { (input, _, _) ->
16+
val result = ENCODING.truncateText(input, 10)
17+
ENCODING.countTokens(result) shouldBeLessThanOrEqual 10
18+
}
19+
}
20+
}

0 commit comments

Comments
 (0)