diff --git a/pkg/kvcache/indexer.go b/pkg/kvcache/indexer.go index 483a61464..38268acfc 100644 --- a/pkg/kvcache/indexer.go +++ b/pkg/kvcache/indexer.go @@ -135,7 +135,10 @@ func (k *Indexer) GetPodScores(ctx context.Context, renderReq *preprocessing.Ren traceLogger := log.FromContext(ctx).V(logging.TRACE).WithName("kvcache.GetPodScores") // 1. tokenize prompt - tokens := k.tokenizersPool.Tokenize(renderReq, prompt, modelName) + tokens, err := k.tokenizersPool.Tokenize(renderReq, prompt, modelName) + if err != nil { + return nil, fmt.Errorf("failed to tokenize: %w", err) + } // 2. get block keys blockKeys := k.tokensProcessor.TokensToKVBlockKeys(nil, tokens, modelName) diff --git a/pkg/tokenization/pool.go b/pkg/tokenization/pool.go index b1f9f0dd0..ad1927752 100644 --- a/pkg/tokenization/pool.go +++ b/pkg/tokenization/pool.go @@ -18,6 +18,7 @@ package tokenization import ( "context" + "errors" "fmt" "sync" @@ -62,6 +63,7 @@ func DefaultConfig() (*Config, error) { // tokenizationResponse holds the result of a tokenization operation. type tokenizationResponse struct { + err error Tokens []uint32 } @@ -146,7 +148,7 @@ func (pool *Pool) EnqueueTokenization(prompt, modelName string) { } // Tokenize queues a task and blocks until the final result is available. -func (pool *Pool) Tokenize(renderReq *preprocessing.RenderJinjaTemplateRequest, prompt, modelName string) []uint32 { +func (pool *Pool) Tokenize(renderReq *preprocessing.RenderJinjaTemplateRequest, prompt, modelName string) ([]uint32, error) { resultCh := make(chan tokenizationResponse, 1) pool.queue.Add(Task{ RenderReq: renderReq, @@ -156,8 +158,7 @@ func (pool *Pool) Tokenize(renderReq *preprocessing.RenderJinjaTemplateRequest, }) res := <-resultCh - tokens := res.Tokens - return tokens + return res.Tokens, res.err } // Run launches worker goroutines that process tasks until the context is @@ -184,7 +185,9 @@ func (pool *Pool) workerLoop(_ int) { } // Process the task. - if err := pool.processTask(task); err == nil { + var fatalErr FatalInitError + // If the error is fatal, remove the task from the queue. + if err := pool.processTask(task); err == nil || errors.As(err, &fatalErr) { pool.queue.Forget(task) } else { pool.queue.AddRateLimited(task) @@ -193,14 +196,36 @@ func (pool *Pool) workerLoop(_ int) { } } +// FatalInitError is an unrecoverable failure while initializing the target tokenizer. +type FatalInitError struct { + err error +} + +func (fe FatalInitError) Error() string { + return fmt.Sprintf("fatal init error: %s", fe.err.Error()) +} + +func (fe FatalInitError) Unwrap() error { + return fe.err +} // processTask tokenizes the prompt and updates the indexer. // It sends exactly one response (success or error) if ResultCh is provided. func (pool *Pool) processTask(task Task) error { + reportErr := func(task Task, err error) { + if task.ResultCh != nil { + // On failure, send the response if a channel is provided and close the channel. + resp := tokenizationResponse{err: err} + task.ResultCh <- resp + close(task.ResultCh) + } + } + if task.RenderReq != nil { var err error task.Prompt, err = pool.tokenizer.RenderChatTemplate(task.ModelName, task.RenderReq) if err != nil { log.Log.Error(err, "failed to render chat template", "modelName", task.ModelName) + reportErr(task, err) return err } } @@ -212,12 +237,14 @@ func (pool *Pool) processTask(task Task) error { tokens, offsets, err := pool.tokenizer.Encode(task.Prompt, task.ModelName) if err != nil { log.Log.Error(err, "failed to encode tokens", "prompt", task.Prompt, "modelName", task.ModelName) + reportErr(task, err) return err } // update the indexer with the new tokenization if e := pool.indexer.AddTokenization(task.Prompt, tokens, offsets); e != nil { err = fmt.Errorf("tokenization failed for model %s: %w", task.ModelName, e) + reportErr(task, err) return err } diff --git a/pkg/tokenization/pool_test.go b/pkg/tokenization/pool_test.go index a0d85ceb8..7cfcb7c83 100644 --- a/pkg/tokenization/pool_test.go +++ b/pkg/tokenization/pool_test.go @@ -83,6 +83,17 @@ func (m *MockIndexer) FindLongestContainedTokens(prompt string) ([]uint32, float return tokens, 0.0 } +type FailingMockIndexer struct { + MockIndexer +} + +//nolint:gocritic // unnamedResult: tokens and overlapRatio are self-explanatory from context +func (m *FailingMockIndexer) FindLongestContainedTokens(prompt string) ([]uint32, float64) { + args := m.Called(prompt) + tokens := args.Get(0).([]uint32) //nolint:errcheck // unused mock + return tokens, defaultMinPrefixOverlapRatio - .1 +} + func TestPool_ProcessTask(t *testing.T) { mockIndexer := &MockIndexer{} mockTokenizer := &MockTokenizer{} @@ -169,6 +180,58 @@ func TestPool_RunIntegration(t *testing.T) { mockIndexer.AssertExpectations(t) } +func TestPool_RunIntegrationFailed(t *testing.T) { + if testing.Short() { + t.Skip("Skipping tokenizer integration test in short mode") + } + + mockIndexer := &FailingMockIndexer{} + + prompts := []string{"hello world", "this is a test", "unicode test: 世界"} + + // Setup mock expectations for each prompt + for _, prompt := range prompts { + mockIndexer.On("FindLongestContainedTokens", prompt).Return([]uint32{}, defaultMinPrefixOverlapRatio-.1) + } + + // Setup a misconfigured tokenization pool + config := &Config{ + WorkersCount: 1, + LocalTokenizerConfig: &LocalTokenizerConfig{ + ModelTokenizerMap: map[string]string{ + testModelName: t.TempDir(), + }, + }, + MinPrefixOverlapRatio: defaultMinPrefixOverlapRatio, + } + + pool, err := NewTokenizationPool(config, mockIndexer) + require.NoError(t, err) + + // Create context for the pool + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + // Run pool + done := make(chan struct{}) + go func() { + defer close(done) + pool.Run(ctx) + }() + + // We expect all tokenizers to terminate reporting an error + for _, prompt := range prompts { + _, err := pool.Tokenize(nil, prompt, testModelName) + require.Error(t, err) + } + + time.Sleep(2 * time.Second) + cancel() + <-done + + mockIndexer.AssertExpectations(t) +} + func generateRandomSentence(wordLength, maxWords int, rng *rand.Rand) string { numWords := rng.Intn(maxWords) + 1 words := make([]string, numWords) @@ -269,7 +332,7 @@ func BenchmarkSyncTokenizationStress(b *testing.B) { for i := 0; b.Loop(); i++ { prompt := generateRandomSentence(benchmarkWordLength, benchmarkMaxWords, rng) model := benchmarkModels[i%len(benchmarkModels)] - pool.Tokenize(nil, prompt, model) + _, _ = pool.Tokenize(nil, prompt, model) //nolint:errcheck } b.StopTimer() diff --git a/pkg/tokenization/tokenizer.go b/pkg/tokenization/tokenizer.go index 603b32421..97f4540c2 100644 --- a/pkg/tokenization/tokenizer.go +++ b/pkg/tokenization/tokenizer.go @@ -400,7 +400,7 @@ func (t *CachedTokenizer) RenderChatTemplate( func (t *CachedTokenizer) Encode(input, modelName string) ([]uint32, []tokenizers.Offset, error) { tokenizer, err := t.get(modelName) if err != nil { - return nil, nil, fmt.Errorf("failed to get tokenizer for model %q: %w", modelName, err) + return nil, nil, FatalInitError{err: fmt.Errorf("failed to get tokenizer for model %q: %w", modelName, err)} } encodeOptions := []tokenizers.EncodeOption{