Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion pkg/kvcache/indexer.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,10 @@ func (k *Indexer) GetPodScores(ctx context.Context, renderReq *preprocessing.Ren
traceLogger := log.FromContext(ctx).V(logging.TRACE).WithName("kvcache.GetPodScores")

// 1. tokenize prompt
tokens := k.tokenizersPool.Tokenize(renderReq, prompt, modelName)
tokens, err := k.tokenizersPool.Tokenize(renderReq, prompt, modelName)
if err != nil {
return nil, fmt.Errorf("failed to tokenize: %w", err)
}

// 2. get block keys
blockKeys := k.tokensProcessor.TokensToKVBlockKeys(nil, tokens, modelName)
Expand Down
35 changes: 31 additions & 4 deletions pkg/tokenization/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import (
"context"
"errors"
"fmt"
"sync"

Expand Down Expand Up @@ -62,6 +63,7 @@

// tokenizationResponse holds the result of a tokenization operation.
// It is the value carried by a task's result channel; Tokenize receives one
// response from that channel and returns its Tokens and err to the caller.
type tokenizationResponse struct {
// err is the failure reported by the worker. The error path in processTask
// builds a response with only this field set.
err error
// Tokens is the slice handed back to the Tokenize caller alongside err.
Tokens []uint32
}

Expand Down Expand Up @@ -146,7 +148,7 @@
}

// Tokenize queues a task and blocks until the final result is available.
func (pool *Pool) Tokenize(renderReq *preprocessing.RenderJinjaTemplateRequest, prompt, modelName string) []uint32 {
func (pool *Pool) Tokenize(renderReq *preprocessing.RenderJinjaTemplateRequest, prompt, modelName string) ([]uint32, error) {
resultCh := make(chan tokenizationResponse, 1)
pool.queue.Add(Task{
RenderReq: renderReq,
Expand All @@ -156,8 +158,7 @@
})

res := <-resultCh
tokens := res.Tokens
return tokens
return res.Tokens, res.err
}

// Run launches worker goroutines that process tasks until the context is
Expand All @@ -184,7 +185,9 @@
}

// Process the task.
if err := pool.processTask(task); err == nil {
var fatalErr FatalInitError
// If the error is fatal, remove the task from the queue.
if err := pool.processTask(task); err == nil || errors.As(err, &fatalErr) {
pool.queue.Forget(task)
} else {
pool.queue.AddRateLimited(task)
Expand All @@ -193,14 +196,36 @@
}
}

// FatalInitError is an unrecoverable failure while initializing the target tokenizer.
// CachedTokenizer.Encode returns it when the tokenizer for the requested model
// cannot be obtained. The pool worker loop detects it with errors.As and
// forgets the task instead of re-queueing it with rate limiting.
type FatalInitError struct {
// err is the underlying cause; Unwrap exposes it and Error embeds its message.
err error
}

func (fe FatalInitError) Error() string {
Copy link

Copilot AI Dec 11, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The Error method could panic if fe.err is nil. While this may not happen in normal operation, defensive programming suggests adding a nil check to prevent potential panics.

Suggested change
func (fe FatalInitError) Error() string {
func (fe FatalInitError) Error() string {
if fe.err == nil {
return "fatal init error: <nil>"
}

Copilot uses AI. Check for mistakes.
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this will happen in practice... 🤔

return fmt.Sprintf("fatal init error: %s", fe.err.Error())
}

Comment thread
evacchi marked this conversation as resolved.
// Unwrap returns the wrapped cause so that errors.Is and errors.As can
// inspect the error chain beneath a FatalInitError.
func (fe FatalInitError) Unwrap() error {
return fe.err
}

Check failure on line 210 in pkg/tokenization/pool.go

View workflow job for this annotation

GitHub Actions / lint-and-test

File is not properly formatted (goimports)

Check failure on line 210 in pkg/tokenization/pool.go

View workflow job for this annotation

GitHub Actions / lint-and-test

File is not properly formatted (gofumpt)
// processTask tokenizes the prompt and updates the indexer.
// It sends exactly one response (success or error) if ResultCh is provided.
func (pool *Pool) processTask(task Task) error {
reportErr := func(task Task, err error) {
if task.ResultCh != nil {
// On failure, send the response if a channel is provided and close the channel.
resp := tokenizationResponse{err: err}
task.ResultCh <- resp
close(task.ResultCh)
}
}

if task.RenderReq != nil {
var err error
task.Prompt, err = pool.tokenizer.RenderChatTemplate(task.ModelName, task.RenderReq)
if err != nil {
log.Log.Error(err, "failed to render chat template", "modelName", task.ModelName)
reportErr(task, err)
return err
}
}
Expand All @@ -212,12 +237,14 @@
tokens, offsets, err := pool.tokenizer.Encode(task.Prompt, task.ModelName)
if err != nil {
log.Log.Error(err, "failed to encode tokens", "prompt", task.Prompt, "modelName", task.ModelName)
reportErr(task, err)
return err
}

// update the indexer with the new tokenization
if e := pool.indexer.AddTokenization(task.Prompt, tokens, offsets); e != nil {
err = fmt.Errorf("tokenization failed for model %s: %w", task.ModelName, e)
reportErr(task, err)
return err
}

Expand Down
65 changes: 64 additions & 1 deletion pkg/tokenization/pool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,17 @@
return tokens, 0.0
}

type FailingMockIndexer struct {
MockIndexer
}

// FindLongestContainedTokens records the call on the embedded mock and returns
// the token slice configured for it. The overlap ratio is always reported as
// 0.1 below defaultMinPrefixOverlapRatio, whatever ratio the expectation was
// configured to return.
//
//nolint:gocritic // unnamedResult: tokens and overlapRatio are self-explanatory from context
func (m *FailingMockIndexer) FindLongestContainedTokens(prompt string) ([]uint32, float64) {
	recorded := m.Called(prompt)
	found := recorded.Get(0).([]uint32) //nolint:errcheck // unused mock
	return found, defaultMinPrefixOverlapRatio - .1
}

func TestPool_ProcessTask(t *testing.T) {
mockIndexer := &MockIndexer{}
mockTokenizer := &MockTokenizer{}
Expand Down Expand Up @@ -169,6 +180,58 @@
mockIndexer.AssertExpectations(t)
}

// TestPool_RunIntegrationFailed checks that Tokenize surfaces an error for
// every prompt when the pool's local tokenizer mapping points the test model
// at a fresh temporary directory.
func TestPool_RunIntegrationFailed(t *testing.T) {
	if testing.Short() {
		t.Skip("Skipping tokenizer integration test in short mode")
	}

	indexer := &FailingMockIndexer{}
	inputs := []string{"hello world", "this is a test", "unicode test: 世界"}

	// Register one mock expectation per prompt.
	for i := range inputs {
		indexer.On("FindLongestContainedTokens", inputs[i]).Return([]uint32{}, defaultMinPrefixOverlapRatio-.1)
	}

	// Build a misconfigured pool: the model maps to a directory created by
	// t.TempDir() rather than to real tokenizer files.
	tokenizerDirs := map[string]string{
		testModelName: t.TempDir(),
	}
	cfg := &Config{
		WorkersCount:          1,
		LocalTokenizerConfig:  &LocalTokenizerConfig{ModelTokenizerMap: tokenizerDirs},
		MinPrefixOverlapRatio: defaultMinPrefixOverlapRatio,
	}

	pool, err := NewTokenizationPool(cfg, indexer)
	require.NoError(t, err)

	// The pool runs under a context bounded to five seconds.
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	// Start the workers; stopped is closed once Run returns.
	stopped := make(chan struct{})
	go func() {
		defer close(stopped)
		pool.Run(ctx)
	}()

	// Every Tokenize call is expected to come back with an error.
	for _, p := range inputs {
		_, tokenizeErr := pool.Tokenize(nil, p, testModelName)
		require.Error(t, tokenizeErr)
	}

	// Pause, then stop the pool and wait for Run to return before checking
	// the mock expectations.
	time.Sleep(2 * time.Second)
	cancel()
	<-stopped

	indexer.AssertExpectations(t)
}

func generateRandomSentence(wordLength, maxWords int, rng *rand.Rand) string {
numWords := rng.Intn(maxWords) + 1
words := make([]string, numWords)
Expand Down Expand Up @@ -269,7 +332,7 @@
for i := 0; b.Loop(); i++ {
prompt := generateRandomSentence(benchmarkWordLength, benchmarkMaxWords, rng)
model := benchmarkModels[i%len(benchmarkModels)]
pool.Tokenize(nil, prompt, model)
_, _ = pool.Tokenize(nil, prompt, model) //nolint:errcheck

Check failure on line 335 in pkg/tokenization/pool_test.go

View workflow job for this annotation

GitHub Actions / lint-and-test

whyNoLint: include an explanation for nolint directive (gocritic)
}

b.StopTimer()
Expand Down
2 changes: 1 addition & 1 deletion pkg/tokenization/tokenizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -400,7 +400,7 @@ func (t *CachedTokenizer) RenderChatTemplate(
func (t *CachedTokenizer) Encode(input, modelName string) ([]uint32, []tokenizers.Offset, error) {
tokenizer, err := t.get(modelName)
if err != nil {
return nil, nil, fmt.Errorf("failed to get tokenizer for model %q: %w", modelName, err)
return nil, nil, FatalInitError{err: fmt.Errorf("failed to get tokenizer for model %q: %w", modelName, err)}
}

encodeOptions := []tokenizers.EncodeOption{
Expand Down
Loading