From 06c95b47a17264b98e7c081bacb2804a0eb64c77 Mon Sep 17 00:00:00 2001 From: HyunKyun Moon Date: Thu, 18 Dec 2025 16:38:47 +0000 Subject: [PATCH 1/5] change to vllm --- .gitignore | 6 +- docs/architecture.md | 2 +- examples/kv_events/online/main.go | 22 +- examples/testdata/data.go | 2 +- go.mod | 9 +- go.sum | 18 + pkg/kvcache/indexer.go | 8 +- pkg/preprocessing/chat_completions/README.md | 30 +- .../chat_completions/cgo_functions.c | 231 +++++------ .../chat_completions/cgo_functions.go | 135 ++++--- .../chat_completions/cgo_functions.h | 18 +- .../chat_completions/cgo_functions_test.go | 379 +++++++++--------- .../render_jinja_template_wrapper.py | 259 ------------ .../chat_completions/requirements.txt | 8 +- .../chat_completions/tokenizer_wrapper.py | 195 +++++++++ pkg/tokenization/pool.go | 58 +-- pkg/tokenization/pool_test.go | 62 ++- pkg/tokenization/tokenizer.go | 228 +++++------ pkg/tokenization/tokenizer_test.go | 112 ++++-- pkg/tokenization/uds_tokenizer.go | 16 +- tests/e2e/redis_mock/e2e_suite_test.go | 30 +- tests/e2e/redis_mock/e2e_test.go | 360 +++++++++-------- 22 files changed, 1091 insertions(+), 1097 deletions(-) delete mode 100644 pkg/preprocessing/chat_completions/render_jinja_template_wrapper.py create mode 100644 pkg/preprocessing/chat_completions/tokenizer_wrapper.py diff --git a/.gitignore b/.gitignore index 5a2cede74..b9df47691 100644 --- a/.gitignore +++ b/.gitignore @@ -66,4 +66,8 @@ _cgo_* /hack/tools # Tokenizer binaries -/lib \ No newline at end of file +/lib + +# UDS tokenizer files +services/uds_tokenizer/models/* +!services/uds_tokenizer/models/README.md \ No newline at end of file diff --git a/docs/architecture.md b/docs/architecture.md index 14f2d20d6..8dadfa741 100644 --- a/docs/architecture.md +++ b/docs/architecture.md @@ -179,4 +179,4 @@ The Indexer relies on several libraries and tools: * Used for the event processing pool and communication between components. * Requires `libzmq` library to be installed on the system. * **Python**: Required to run a CGO binding for the `chat_completions_template` package. - * Used for jinja2 templating of chat completions requests. \ No newline at end of file + * Used for vllm templating of chat completions requests. \ No newline at end of file diff --git a/examples/kv_events/online/main.go b/examples/kv_events/online/main.go index bfb31d59d..f66652ddb 100644 --- a/examples/kv_events/online/main.go +++ b/examples/kv_events/online/main.go @@ -61,7 +61,7 @@ const ( // ChatCompletionsRequest holds the fields needed for chat-completions rendering. type ChatCompletionsRequest struct { Model string `json:"model"` - *preprocessing.RenderJinjaTemplateRequest + *preprocessing.ApplyChatTemplateRequest } func main() { @@ -318,34 +318,18 @@ func setupUnifiedHTTPEndpoints( logger.Info("Created ChatCompletions", "req", req) // Get chat template for the model if not provided - if req.ChatTemplate == "" { - templateReq := preprocessing.FetchChatTemplateRequest{ - Model: req.Model, - Token: os.Getenv(envHFToken), - } - - var err error - req.ChatTemplate, req.ChatTemplateKWArgs, err = chatTemplatingProcessor.FetchChatTemplate(ctx, templateReq) - if err != nil { - http.Error(w, fmt.Sprintf("Failed to get chat template: %v", err), http.StatusInternalServerError) - return - } - } - - response, err := chatTemplatingProcessor.RenderChatTemplate(ctx, req.RenderJinjaTemplateRequest) + renderedPrompt, err := chatTemplatingProcessor.ApplyChatTemplate(ctx, req.ApplyChatTemplateRequest) if err != nil { http.Error(w, fmt.Sprintf("Failed to render chat template: %v", err), http.StatusInternalServerError) return } // Use KV-cache to score the rendered template - if len(response.RenderedChats) == 0 { + if renderedPrompt == "" { http.Error(w, "No rendered chats found in response", http.StatusInternalServerError) return } - renderedPrompt := response.RenderedChats[0] - // Get score pods, err := kvCacheIndexer.GetPodScores(ctx, nil, renderedPrompt, req.Model, nil) if err != nil { diff --git a/examples/testdata/data.go b/examples/testdata/data.go index 1806b972c..cdc5e6f35 100644 --- a/examples/testdata/data.go +++ b/examples/testdata/data.go @@ -24,7 +24,7 @@ const ( ModelName = "bert-base-uncased" ) -var RenderReq *preprocessing.RenderJinjaTemplateRequest = nil +var RenderReq *preprocessing.ApplyChatTemplateRequest = nil //go:embed prompt.txt var Prompt string diff --git a/go.mod b/go.mod index 1940cbcbd..12bf6f95c 100644 --- a/go.mod +++ b/go.mod @@ -18,6 +18,7 @@ require ( github.com/stretchr/testify v1.10.0 github.com/vmihailenco/msgpack/v5 v5.4.1 go.uber.org/multierr v1.11.0 + go.uber.org/zap v1.27.0 golang.org/x/net v0.38.0 google.golang.org/grpc v1.68.1 google.golang.org/protobuf v1.36.5 @@ -31,11 +32,14 @@ require ( github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/emicklei/go-restful/v3 v3.11.0 // indirect + github.com/evanphx/json-patch/v5 v5.9.11 // indirect + github.com/fsnotify/fsnotify v1.7.0 // indirect github.com/go-logr/zapr v1.3.0 // indirect github.com/go-openapi/jsonpointer v0.21.0 // indirect github.com/go-openapi/jsonreference v0.20.2 // indirect github.com/go-openapi/swag v0.23.0 // indirect github.com/gogo/protobuf v1.3.2 // indirect + github.com/google/btree v1.1.3 // indirect github.com/google/gnostic-models v0.6.9 // indirect github.com/google/go-cmp v0.7.0 // indirect github.com/google/uuid v1.6.0 // indirect @@ -49,21 +53,24 @@ require ( github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/prometheus/common v0.62.0 // indirect github.com/prometheus/procfs v0.15.1 // indirect + github.com/spf13/pflag v1.0.5 // indirect github.com/stretchr/objx v0.5.2 // indirect github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect github.com/x448/float16 v0.8.4 // indirect github.com/yuin/gopher-lua v1.1.1 // indirect - go.uber.org/zap v1.27.0 // indirect golang.org/x/oauth2 v0.27.0 // indirect + golang.org/x/sync v0.12.0 // indirect golang.org/x/sys v0.35.0 // indirect golang.org/x/term v0.30.0 // indirect golang.org/x/text v0.23.0 // indirect golang.org/x/time v0.9.0 // indirect + gomodules.xyz/jsonpatch/v2 v2.4.0 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20241209162323-e6fa225c2576 // indirect gopkg.in/evanphx/json-patch.v4 v4.12.0 // indirect gopkg.in/inf.v0 v0.9.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect k8s.io/api v0.33.0 // indirect + k8s.io/apiextensions-apiserver v0.33.0 // indirect k8s.io/klog/v2 v2.130.1 // indirect k8s.io/kube-openapi v0.0.0-20250318190949-c8a335a9a2ff // indirect k8s.io/utils v0.0.0-20241104100929-3ea5e8cea738 // indirect diff --git a/go.sum b/go.sum index 1bd696c3c..c3450867d 100644 --- a/go.sum +++ b/go.sum @@ -2,6 +2,8 @@ github.com/alicebob/miniredis/v2 v2.35.0 h1:QwLphYqCEAo1eu1TqPRN2jgVMPBweeQcR21j github.com/alicebob/miniredis/v2 v2.35.0/go.mod h1:TcL7YfarKPGDAthEtl5NBeHZfeUQj6OXMm/+iu5cLMM= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= +github.com/blang/semver/v4 v4.0.0 h1:1PFHFE6yCCTv8C1TeyNNarDzntLi7wMI5i/pzqYIsAM= +github.com/blang/semver/v4 v4.0.0/go.mod h1:IbckMUScFkM3pff0VJDNKRiT6TG/YpiHIM2yvyW5YoQ= github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c= github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= @@ -25,6 +27,12 @@ github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkp github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= github.com/emicklei/go-restful/v3 v3.11.0 h1:rAQeMHw1c7zTmncogyy8VvRZwtkmkZ4FxERmMY4rD+g= github.com/emicklei/go-restful/v3 v3.11.0/go.mod h1:6n3XBCmQQb25CM2LCACGz8ukIrRry+4bhvbpWn3mrbc= +github.com/evanphx/json-patch v0.5.2 h1:xVCHIVMUu1wtM/VkR9jVZ45N3FhZfYMMYGorLCR8P3k= +github.com/evanphx/json-patch v0.5.2/go.mod h1:ZWS5hhDbVDyob71nXKNL0+PWn6ToqBHMikGIFbs31qQ= +github.com/evanphx/json-patch/v5 v5.9.11 h1:/8HVnzMq13/3x9TPvjG08wUGqBTmZBsCWzjTM0wiaDU= +github.com/evanphx/json-patch/v5 v5.9.11/go.mod h1:3j+LviiESTElxA4p3EMKAB9HXj3/XEtnUf6OZxqIQTM= +github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= +github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= github.com/fxamacker/cbor/v2 v2.7.0 h1:iM5WgngdRBanHcxugY4JySA0nk1wZorNOpTgCMedv5E= github.com/fxamacker/cbor/v2 v2.7.0/go.mod h1:pxXPTn3joSm21Gbwsv0w9OSA2y1HFR9qXEeXQVeNoDQ= github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY= @@ -45,12 +53,16 @@ github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= +github.com/google/btree v1.1.3 h1:CVpQJjYgC4VbzxeGVHfvZrv1ctoYCAI8vbl07Fcxlyg= +github.com/google/btree v1.1.3/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4= github.com/google/gnostic-models v0.6.9 h1:MU/8wDLif2qCXZmzncUQ/BOfxWfthHi63KqpoNbWqVw= github.com/google/gnostic-models v0.6.9/go.mod h1:CiWsm0s6BSQd1hRn8/QmxqB6BesYcbSZxsz9b0KuDBw= github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0= +github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/pprof v0.0.0-20241029153458-d1b30febd7db h1:097atOisP2aRj7vFgYQBbFN4U4JNXUNYpxael3UzMyo= github.com/google/pprof v0.0.0-20241029153458-d1b30febd7db/go.mod h1:vavhavw2zAxS5dIdcRluK6cSGGPlZynqzFM8NdvU144= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= @@ -151,6 +163,8 @@ golang.org/x/oauth2 v0.27.0/go.mod h1:onh5ek6nERTohokkhCD/y2cV4Do3fxFHFuAejCkRWT golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw= +golang.org/x/sync v0.12.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -174,6 +188,8 @@ golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8T golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +gomodules.xyz/jsonpatch/v2 v2.4.0 h1:Ci3iUJyx9UeRx7CeFN8ARgGbkESwJK+KB9lLcWxY/Zw= +gomodules.xyz/jsonpatch/v2 v2.4.0/go.mod h1:AH3dM2RI6uoBZxn3LVrfvJ3E0/9dG4cSrbuBJT4moAY= google.golang.org/genproto/googleapis/rpc v0.0.0-20241209162323-e6fa225c2576 h1:8ZmaLZE4XWrtU3MyClkYqqtl6Oegr3235h7jxsDyqCY= google.golang.org/genproto/googleapis/rpc v0.0.0-20241209162323-e6fa225c2576/go.mod h1:5uTbfoYQed2U9p3KIj2/Zzm02PYhndfdmML0qC3q3FU= google.golang.org/grpc v1.68.1 h1:oI5oTa11+ng8r8XMMN7jAOmWfPZWbYpCFaMUTACxkM0= @@ -192,6 +208,8 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= k8s.io/api v0.33.0 h1:yTgZVn1XEe6opVpP1FylmNrIFWuDqe2H0V8CT5gxfIU= k8s.io/api v0.33.0/go.mod h1:CTO61ECK/KU7haa3qq8sarQ0biLq2ju405IZAd9zsiM= +k8s.io/apiextensions-apiserver v0.33.0 h1:d2qpYL7Mngbsc1taA4IjJPRJ9ilnsXIrndH+r9IimOs= +k8s.io/apiextensions-apiserver v0.33.0/go.mod h1:VeJ8u9dEEN+tbETo+lFkwaaZPg6uFKLGj5vyNEwwSzc= k8s.io/apimachinery v0.33.0 h1:1a6kHrJxb2hs4t8EE5wuR/WxKDwGN1FKH3JvDtA0CIQ= k8s.io/apimachinery v0.33.0/go.mod h1:BHW0YOu7n22fFv/JkYOEfkUYNRN0fj0BlvMFWA7b+SM= k8s.io/client-go v0.33.0 h1:UASR0sAYVUzs2kYuKn/ZakZlcs2bEHaizrrHUZg0G98= diff --git a/pkg/kvcache/indexer.go b/pkg/kvcache/indexer.go index d0642c483..3193b25dc 100644 --- a/pkg/kvcache/indexer.go +++ b/pkg/kvcache/indexer.go @@ -134,13 +134,13 @@ func (k *Indexer) KVBlockIndex() kvblock.Index { // relevant. // // The function returns a map of pod identifiers to scores. -func (k *Indexer) GetPodScores(ctx context.Context, renderReq *preprocessing.RenderJinjaTemplateRequest, prompt, modelName string, +func (k *Indexer) GetPodScores(ctx context.Context, renderReq *preprocessing.ApplyChatTemplateRequest, prompt, modelName string, podIdentifiers []string, ) (map[string]float64, error) { traceLogger := log.FromContext(ctx).V(logging.TRACE).WithName("kvcache.GetPodScores") // 1. tokenize prompt - tokens := k.tokenizersPool.Tokenize(renderReq, prompt) + tokens := k.tokenizersPool.Tokenize(renderReq, prompt, modelName) // 2. get block keys blockKeys := k.tokensProcessor.TokensToKVBlockKeys(nil, tokens, modelName) @@ -183,7 +183,3 @@ func podsPerKeyPrintHelper(ks map[kvblock.Key][]kvblock.PodEntry) string { return flattened } - -func (k *Indexer) SetTokenizer(tokenizer tokenization.Tokenizer, modelName string) { - k.tokenizersPool.SetTokenizer(tokenizer, modelName) -} diff --git a/pkg/preprocessing/chat_completions/README.md b/pkg/preprocessing/chat_completions/README.md index c3b29c100..a3103b36f 100644 --- a/pkg/preprocessing/chat_completions/README.md +++ b/pkg/preprocessing/chat_completions/README.md @@ -46,10 +46,10 @@ The following is the major request structure used for templating: - Some fields are provided by the router (serving vLLM's OpenAI-compatible API). - Some fields are fetched from the model's tokenizer (e.g., chat template). -The `RenderJinjaTemplateRequest` matches the `transformers` library's `ChatTemplateRequest` structure, which is used to render the chat template. +The `ApplyChatTemplateRequest` matches the `transformers` library's `ChatTemplateRequest` structure, which is used to render the chat template. -**RenderJinjaTemplateRequest accepts these fields, that match the `render_jinja_template`'s expected parameters:** -- `Conversations` - List of message lists (role/content pairs) +**ApplyChatTemplateRequest accepts these fields, that match the `apply_chat_template`'s expected parameters:** +- `Conversation` - List of message lists (role/content pairs) - `Tools` - (Optional) List of tool schemas - `Documents` - (Optional) List of document dicts - `ChatTemplate` - (Optional) Override for the chat template @@ -70,25 +70,17 @@ The templating process (steps 1.1-1.4) handles the conversion from structured re └── cgo_functions.go:NewChatTemplatingProcessor() └── Creates ChatTemplatingProcessor struct with initialized=false -1.2. **ChatTemplate Fetching**: wrapper.FetchChatTemplate(ctx, getReq) - ├── cgo_functions.go:FetchChatTemplate(ctx, req) - │ ├── Initialize() Python interpreter via CGO - │ ├── executePythonCode() - **CGO Binding** to Python - │ └── **Python Wrapper**: render_jinja_template_wrapper.py:get_model_chat_template() - │ └── Uses Hugging Face AutoTokenizer to fetch model template - └── Returns: (template, template_vars) - -1.3. **ChatTemplate Rendering**: wrapper.RenderChatTemplate(ctx, req) +1.2. **ChatTemplate Rendering**: wrapper.RenderChatTemplate(ctx, req) ├── cgo_functions.go:RenderChatTemplate(ctx, req) │ ├── Initialize() Python interpreter via CGO (if not already done) │ ├── executePythonCode() - **CGO Binding** to Python - │ └── **Python Wrapper**: render_jinja_template_wrapper.py:render_jinja_template() - │ └── Imports render_jinja_template from transformers.utils.chat_template_utils - │ └── Uses transformers library's core template rendering functionality - └── Returns: RenderJinjaTemplateResponse + │ └── **Python Wrapper**: tokenizer_wrapper.py:apply_chat_template() + │ └── Imports apply_chat_template from vllm.transformers_utils.tokenizer + │ └── Uses vllm library's core template rendering functionality + └── Returns: String -1.4. **Extract Flattened Prompt** - └── prompt := resp.RenderedChats[0] +1.3. **Extract Flattened Prompt** + └── prompt := response └── Continue with existing pipeline: Tokenize → KV Block Keys → Pod Scoring ``` ### Optimized Preprocessing Architecture @@ -100,7 +92,7 @@ The templating process (steps 1.1-1.4) handles the conversion from structured re - **Thread-Safe Initialization**: Global locks prevent multiple initializations ##### **Function Caching** -- **Cached Python Functions**: `render_jinja_template` and `get_model_chat_template` cached globally +- **Cached Python Functions**: `apply_chat_template` and `encode` cached globally - **Module-Level Caching**: Python modules imported once and reused - **Thread Safety**: GIL management for concurrent access diff --git a/pkg/preprocessing/chat_completions/cgo_functions.c b/pkg/preprocessing/chat_completions/cgo_functions.c index 9eef79151..57d732250 100644 --- a/pkg/preprocessing/chat_completions/cgo_functions.c +++ b/pkg/preprocessing/chat_completions/cgo_functions.c @@ -14,20 +14,20 @@ See the License for the specific language governing permissions and limitations under the License. */ -#include // for getpid() and usleep() +#include // for getpid() and usleep() #include "cgo_functions.h" // Global variables for caching PyObject* g_chat_template_module = NULL; -PyObject* g_render_jinja_template_func = NULL; -PyObject* g_get_model_chat_template_func = NULL; +PyObject* g_apply_chat_template_func = NULL; +PyObject* g_encode_func = NULL; int g_initialized = 0; int g_python_initialized = 0; // Process-level global initialization tracking static int g_process_initialized = 0; -static int g_finalized = 0; +static int g_finalized = 0; static pid_t g_init_pid = 0; // Thread safety for initialization @@ -42,7 +42,8 @@ int Py_InitializeGo() { if (g_process_initialized) { if (g_init_pid != getpid()) { printf("[C] Py_InitializeGo WARNING - Different PID trying to initialize (init_pid: %d, current_pid: %d)\n", g_init_pid, getpid()); - } else { + } + else { printf("[C] Py_InitializeGo - Already initialized in this process (PID: %d)\n", getpid()); } return 0; @@ -95,24 +96,24 @@ void Py_FinalizeGo() { printf("[C] Py_FinalizeGo - Already finalized, skipping\n"); return; } - + // Mark as finalized first to prevent race conditions g_finalized = 1; - + // Clean up module references safely - if (g_render_jinja_template_func) { - Py_DECREF(g_render_jinja_template_func); - g_render_jinja_template_func = NULL; - } - if (g_get_model_chat_template_func) { - Py_DECREF(g_get_model_chat_template_func); - g_get_model_chat_template_func = NULL; + if (g_apply_chat_template_func) { + Py_DECREF(g_apply_chat_template_func); + g_apply_chat_template_func = NULL; } if (g_chat_template_module) { Py_DECREF(g_chat_template_module); g_chat_template_module = NULL; } - + if (g_encode_func) { + Py_DECREF(g_encode_func); + g_encode_func = NULL; + } + // Reset state without finalizing Python // Python will be cleaned up when the process exits g_python_initialized = 0; @@ -149,7 +150,7 @@ const char* PyUnicode_AsGoString(PyObject* obj) { // Initialize the cached module and functions (call once at startup) int Py_InitChatTemplateModule() { - + // Thread-safe initialization check if (g_init_lock == NULL) { g_init_lock = PyThread_allocate_lock(); @@ -158,38 +159,38 @@ int Py_InitChatTemplateModule() { return -1; } } - + PyThread_acquire_lock(g_init_lock, NOWAIT_LOCK); - + // Check if already initialized if (g_initialized) { printf("[C] Py_InitChatTemplateModule - Already initialized globally, returning\n"); PyThread_release_lock(g_init_lock); return 0; } - + // Ensure Python is initialized if (!g_python_initialized) { printf("[C] Py_InitChatTemplateModule ERROR - Python not initialized\n"); PyThread_release_lock(g_init_lock); return -1; } - + // Acquire GIL for module initialization PyGILState_STATE gil_state = PyGILState_Ensure(); - - + + // Import the chat template wrapper module AFTER setting up the path - g_chat_template_module = PyImport_ImportModule("render_jinja_template_wrapper"); + g_chat_template_module = PyImport_ImportModule("tokenizer_wrapper"); if (!g_chat_template_module) { - printf("[C] Py_InitChatTemplateModule ERROR - Failed to import render_jinja_template_wrapper module\n"); + printf("[C] Py_InitChatTemplateModule ERROR - Failed to import tokenizer_wrapper module\n"); PyErr_Print(); PyGILState_Release(gil_state); PyThread_release_lock(g_init_lock); return -1; } - + // Get the module dictionary PyObject* module_dict = PyModule_GetDict(g_chat_template_module); if (!module_dict) { @@ -198,30 +199,30 @@ int Py_InitChatTemplateModule() { PyThread_release_lock(g_init_lock); return -1; } - - // Get the render_jinja_template function - g_render_jinja_template_func = PyDict_GetItemString(module_dict, "render_jinja_template"); - if (!g_render_jinja_template_func || !PyCallable_Check(g_render_jinja_template_func)) { - printf("[C] Py_InitChatTemplateModule ERROR - render_jinja_template function not found or not callable\n"); + + // Get the apply_chat_template function + g_apply_chat_template_func = PyDict_GetItemString(module_dict, "apply_chat_template"); + if (!g_apply_chat_template_func || !PyCallable_Check(g_apply_chat_template_func)) { + printf("[C] Py_InitChatTemplateModule ERROR - apply_chat_template function not found or not callable\n"); PyGILState_Release(gil_state); PyThread_release_lock(g_init_lock); return -1; } - Py_INCREF(g_render_jinja_template_func); // Keep a reference - - // Get the get_model_chat_template function - g_get_model_chat_template_func = PyDict_GetItemString(module_dict, "get_model_chat_template"); - if (!g_get_model_chat_template_func || !PyCallable_Check(g_get_model_chat_template_func)) { - printf("[C] Py_InitChatTemplateModule ERROR - get_model_chat_template function not found or not callable\n"); + Py_INCREF(g_apply_chat_template_func); // Keep a reference + + // Get the encode function + g_encode_func = PyDict_GetItemString(module_dict, "encode"); + if (!g_encode_func || !PyCallable_Check(g_encode_func)) { + printf("[C] Py_InitChatTemplateModule ERROR - encode function not found or not callable\n"); PyGILState_Release(gil_state); PyThread_release_lock(g_init_lock); return -1; } - Py_INCREF(g_get_model_chat_template_func); // Keep a reference - + Py_INCREF(g_encode_func); // Keep a reference + // Release GIL PyGILState_Release(gil_state); - + g_initialized = 1; PyThread_release_lock(g_init_lock); return 0; @@ -229,38 +230,38 @@ int Py_InitChatTemplateModule() { -// Call the cached render_jinja_template function -char* Py_CallRenderJinjaTemplate(const char* json_request) { +// Call the cached apply_chat_template function +char* Py_CallApplyChatTemplate(const char* json_request) { // Try direct call first (fast path) - char* result = Py_CallRenderJinjaTemplateInternal(json_request); + char* result = Py_CallApplyChatTemplateInternal(json_request); if (result != NULL) { return result; // Success on first try } - + // If failed, just return NULL (no retry, no reload) - return NULL; + return NULL; } // Internal function that does the actual work -char* Py_CallRenderJinjaTemplateInternal(const char* json_request) { +char* Py_CallApplyChatTemplateInternal(const char* json_request) { // Check if Python interpreter is still valid if (!Py_IsInitialized()) { - printf("[C] Py_CallRenderJinjaTemplateInternal ERROR - Python interpreter not initialized\n"); + printf("[C] Py_CallApplyChatTemplateInternal ERROR - Python interpreter not initialized\n"); return NULL; } - + // Simple validation if (!json_request) { - printf("[C] Py_CallRenderJinjaTemplateInternal ERROR - Input is NULL\n"); + printf("[C] Py_CallApplyChatTemplateInternal ERROR - Input is NULL\n"); return NULL; } - + // Acquire GIL for Python operations - PyGILState_STATE gil_state = PyGILState_Ensure(); + PyGILState_STATE gil_state = PyGILState_Ensure(); // Create Python string from JSON request PyObject* py_json = PyUnicode_FromString(json_request); if (!py_json) { - printf("[C] Py_CallRenderJinjaTemplateInternal ERROR - Failed to create Python string\n"); + printf("[C] Py_CallApplyChatTemplateInternal ERROR - Failed to create Python string\n"); PyGILState_Release(gil_state); return NULL; } @@ -268,134 +269,138 @@ char* Py_CallRenderJinjaTemplateInternal(const char* json_request) { // Create arguments tuple PyObject* args = PyTuple_Pack(1, py_json); if (!args) { - printf("[C] Py_CallRenderJinjaTemplateInternal ERROR - Failed to create args tuple\n"); + printf("[C] Py_CallApplyChatTemplateInternal ERROR - Failed to create args tuple\n"); Py_DECREF(py_json); PyGILState_Release(gil_state); return NULL; - } + } // Call the cached function - PyObject* py_result = PyObject_CallObject(g_render_jinja_template_func, args); - + PyObject* py_result = PyObject_CallObject(g_apply_chat_template_func, args); + // Clean up args Py_DECREF(args); Py_DECREF(py_json); - + char* cresult = NULL; if (py_result) { // Convert to C string const char* s = PyUnicode_AsUTF8(py_result); if (s) { cresult = strdup(s); - } else { - printf("[C] Py_CallRenderJinjaTemplateInternal ERROR - Failed to convert result to C string\n"); + } + else { + printf("[C] Py_CallApplyChatTemplateInternal ERROR - Failed to convert result to C string\n"); } Py_DECREF(py_result); - } else { - printf("[C] Py_CallRenderJinjaTemplateInternal ERROR - Python function returned NULL\n"); + } + else { + printf("[C] Py_CallApplyChatTemplateInternal ERROR - Python function returned NULL\n"); PyErr_Print(); fflush(stderr); } - + // Release GIL PyGILState_Release(gil_state); - + return cresult; } -// Call the cached get_model_chat_template function -char* Py_CallGetModelChatTemplate(const char* json_request) { +// Call the cached encode function +char* Py_CallEncode(const char* json_request) { // Try direct call first (fast path) - char* result = Py_CallGetModelChatTemplateInternal(json_request); + char* result = Py_CallEncodeInternal(json_request); if (result != NULL) { return result; // Success on first try } - + // If failed, just return NULL (no retry, no reload) return NULL; } // Internal function that does the actual work -char* Py_CallGetModelChatTemplateInternal(const char* json_request) { +char* Py_CallEncodeInternal(const char* json_request) { // Check if Python is initialized if (!g_python_initialized) { - printf("[C] Py_CallGetModelChatTemplateInternal ERROR - Python not initialized\n"); + printf("[C] Py_CallEncodeInternal ERROR - Python not initialized\n"); fflush(stdout); return NULL; } - + // Validate cached function - if (!g_get_model_chat_template_func) { - printf("[C] Py_CallGetModelChatTemplateInternal ERROR - Cached function is NULL\n"); + if (!g_encode_func) { + printf("[C] Py_CallEncodeInternal ERROR - Cached function is NULL\n"); fflush(stdout); return NULL; } - + // Validate that the cached function is still a valid Python object fflush(stdout); - if (!PyCallable_Check(g_get_model_chat_template_func)) { - printf("[C] Py_CallGetModelChatTemplateInternal ERROR - Cached function is not callable (corrupted?)\n"); + if (!PyCallable_Check(g_encode_func)) { + printf("[C] Py_CallEncodeInternal ERROR - Cached function is not callable (corrupted?)\n"); fflush(stdout); return NULL; } - + // Validate input if (!json_request) { - printf("[C] Py_CallGetModelChatTemplateInternal ERROR - Input is NULL\n"); + printf("[C] Py_CallEncodeInternal ERROR - Input is NULL\n"); fflush(stdout); return NULL; } - + // Acquire GIL for Python operations PyGILState_STATE gil_state = PyGILState_Ensure(); - + // Create Python string from JSON request PyObject* py_json = PyUnicode_FromString(json_request); if (!py_json) { - printf("[C] Py_CallGetModelChatTemplateInternal ERROR - Failed to create Python string\n"); + printf("[C] Py_CallEncodeInternal ERROR - Failed to create Python string\n"); fflush(stdout); PyGILState_Release(gil_state); return NULL; } - + // Create arguments tuple PyObject* args = PyTuple_Pack(1, py_json); if (!args) { - printf("[C] Py_CallGetModelChatTemplateInternal ERROR - Failed to create args tuple\n"); + printf("[C] Py_CallEncodeInternal ERROR - Failed to create args tuple\n"); fflush(stdout); Py_DECREF(py_json); PyGILState_Release(gil_state); return NULL; } - + // Call the cached function - PyObject* py_result = PyObject_CallObject(g_get_model_chat_template_func, args); - + PyObject* py_result = PyObject_CallObject(g_encode_func, args); + // Clean up args Py_DECREF(args); Py_DECREF(py_json); - + char* cresult = NULL; if (py_result) { // Convert to C string const char* s = PyUnicode_AsUTF8(py_result); if (s) { cresult = strdup(s); - } else { - printf("[C] Py_CallGetModelChatTemplateInternal ERROR - Failed to convert result to C string\n"); + } + else { + printf("[C] Py_CallEncodeInternal ERROR - Failed to convert result to C string\n"); fflush(stdout); } Py_DECREF(py_result); - } else { - printf("[C] Py_CallGetModelChatTemplateInternal ERROR - Python function returned NULL\n"); + } + else { + printf("[C] Py_CallEncodeInternal ERROR - Python function returned NULL\n"); fflush(stdout); PyErr_Print(); fflush(stderr); } - + // Release GIL PyGILState_Release(gil_state); - + return cresult; } @@ -405,9 +410,9 @@ char* Py_ClearCaches() { printf("[C] Py_ClearCaches ERROR - Module not initialized\n"); return NULL; } - + PyGILState_STATE gil_state = PyGILState_Ensure(); - + // Call the clear_caches function PyObject* clear_caches_func = PyDict_GetItemString(PyModule_GetDict(g_chat_template_module), "clear_caches"); if (!clear_caches_func || !PyCallable_Check(clear_caches_func)) { @@ -415,7 +420,7 @@ char* Py_ClearCaches() { PyGILState_Release(gil_state); return NULL; } - + PyObject* result = PyObject_CallObject(clear_caches_func, NULL); if (!result) { printf("[C] Py_ClearCaches ERROR - Failed to call clear_caches function\n"); @@ -423,7 +428,7 @@ char* Py_ClearCaches() { PyGILState_Release(gil_state); return NULL; } - + // Convert result to C string const char* result_str = PyUnicode_AsUTF8(result); if (!result_str) { @@ -432,11 +437,11 @@ char* Py_ClearCaches() { PyGILState_Release(gil_state); return NULL; } - + char* c_result = strdup(result_str); Py_DECREF(result); PyGILState_Release(gil_state); - + return c_result; } @@ -444,50 +449,50 @@ char* Py_ClearCaches() { void Py_CleanupChatTemplateModule() { if (g_initialized && Py_IsInitialized()) { PyGILState_STATE state = PyGILState_Ensure(); - Py_XDECREF(g_render_jinja_template_func); - Py_XDECREF(g_get_model_chat_template_func); + Py_XDECREF(g_apply_chat_template_func); + Py_XDECREF(g_encode_func); Py_XDECREF(g_chat_template_module); - g_render_jinja_template_func = NULL; - g_get_model_chat_template_func = NULL; + g_apply_chat_template_func = NULL; + g_encode_func = NULL; g_chat_template_module = NULL; g_initialized = 0; PyGILState_Release(state); } -} +} // Re-initialize Python interpreter state -int Py_ReinitializeGo() { +int Py_ReinitializeGo() { // Reset global flags g_initialized = 0; g_python_initialized = 0; g_process_initialized = 0; - + // Clean up cached objects - if (g_render_jinja_template_func) { - Py_DECREF(g_render_jinja_template_func); - g_render_jinja_template_func = NULL; + if (g_apply_chat_template_func) { + Py_DECREF(g_apply_chat_template_func); + g_apply_chat_template_func = NULL; } - if (g_get_model_chat_template_func) { - Py_DECREF(g_get_model_chat_template_func); - g_get_model_chat_template_func = NULL; + if (g_encode_func) { + Py_DECREF(g_encode_func); + g_encode_func = NULL; } if (g_chat_template_module) { Py_DECREF(g_chat_template_module); g_chat_template_module = NULL; } - + // Re-initialize int result = Py_InitializeGo(); if (result != 0) { printf("[C] Py_ReinitializeGo ERROR - Failed to re-initialize Python\n"); return result; } - + result = Py_InitChatTemplateModule(); if (result != 0) { printf("[C] Py_ReinitializeGo ERROR - Failed to re-initialize chat template module\n"); return result; } - + return 0; } diff --git a/pkg/preprocessing/chat_completions/cgo_functions.go b/pkg/preprocessing/chat_completions/cgo_functions.go index 68752d1c1..c2600462f 100644 --- a/pkg/preprocessing/chat_completions/cgo_functions.go +++ b/pkg/preprocessing/chat_completions/cgo_functions.go @@ -32,18 +32,28 @@ import ( "github.com/llm-d/llm-d-kv-cache/pkg/utils/logging" "sigs.k8s.io/controller-runtime/pkg/log" ) +import "github.com/daulet/tokenizers" -// ChatMessage represents a single message in a conversation. -type ChatMessage struct { +// Conversation represents a single message in a conversation. +type Conversation struct { Role string `json:"role"` Content string `json:"content"` } -// RenderJinjaTemplateRequest represents the request to render a chat template. -type RenderJinjaTemplateRequest struct { - // `conversations` is the transformers name, but we use `messages` for consistency with OpenAI API. +type ChatTemplateRequest struct { + IsLocal bool `json:"is_local,omitempty"` + DownloadDir string `json:"download_dir,omitempty"` + Model string `json:"model"` + Revision string `json:"revision,omitempty"` + Token string `json:"token,omitempty"` +} + +// ApplyChatTemplateRequest represents the request to render a chat template. +type ApplyChatTemplateRequest struct { + // `conversation` is the transformers name, but we use `messages` for consistency with OpenAI API. // The Python wrapper will handle converting this to a batched list if needed. - Conversations []ChatMessage `json:"messages"` + ChatTemplateRequest + Conversation []Conversation `json:"conversation"` Tools []interface{} `json:"tools,omitempty"` Documents []interface{} `json:"documents,omitempty"` ChatTemplate string `json:"chat_template,omitempty"` @@ -53,13 +63,19 @@ type RenderJinjaTemplateRequest struct { ChatTemplateKWArgs map[string]interface{} `json:"chat_template_kwargs,omitempty"` } -// DeepCopy creates a deep copy of the RenderJinjaTemplateRequest. -func (req *RenderJinjaTemplateRequest) DeepCopy() (*RenderJinjaTemplateRequest, error) { +type EncodeRequest struct { + ChatTemplateRequest + Text string `json:"text"` + AddSpecialTokens bool `json:"add_special_tokens,omitempty"` +} + +// DeepCopy creates a deep copy of the ApplyChatTemplateRequest. +func (req *ApplyChatTemplateRequest) DeepCopy() (*ApplyChatTemplateRequest, error) { b, err := json.Marshal(req) if err != nil { return nil, err } - var out RenderJinjaTemplateRequest + var out ApplyChatTemplateRequest err = json.Unmarshal(b, &out) if err != nil { return nil, err @@ -67,35 +83,28 @@ func (req *RenderJinjaTemplateRequest) DeepCopy() (*RenderJinjaTemplateRequest, return &out, nil } -// RenderJinjaTemplateResponse represents the response from rendering a chat template. -type RenderJinjaTemplateResponse struct { - RenderedChats []string `json:"rendered_chats"` - GenerationIndices [][][]int `json:"generation_indices"` -} - -// FetchChatTemplateRequest represents the request to fetch a chat template. -// This is needed if the fields are not set in the `RenderJinjaTemplateRequest`. -// When called, it will fetch the `chat_template` from the tokenizer. -// If the tokenizer is not present, it will be fetched from HuggingFace using -// the `token` if provided. -type FetchChatTemplateRequest struct { - Model string `json:"model"` - ChatTemplate string `json:"chat_template,omitempty"` - Tools []interface{} `json:"tools,omitempty"` - Revision string `json:"revision,omitempty"` - Token string `json:"token,omitempty"` - IsLocalPath bool `json:"is_local_path,omitempty"` +// DeepCopy creates a deep copy of the EncodeRequest. +func (req *EncodeRequest) DeepCopy() (*EncodeRequest, error) { + b, err := json.Marshal(req) + if err != nil { + return nil, err + } + var out EncodeRequest + err = json.Unmarshal(b, &out) + if err != nil { + return nil, err + } + return &out, nil } -// FetchChatTemplateResponse represents the response from fetching a chat template. -type FetchChatTemplateResponse struct { - ChatTemplate string `json:"chat_template,omitempty"` - ChatTemplateKWArgs map[string]interface{} `json:"chat_template_kwargs,omitempty"` +type EncodeResponse struct { + TokenIDs []uint32 `json:"input_ids"` + OffsetMappings []tokenizers.Offset `json:"offset_mapping"` } // ChatTemplatingProcessor is a processor that handles chat template rendering // using a cached Python function. Once the Python interpreter is initialized, -// it caches the `transformers` function `render_jinja_template` for rendering +// it caches the `vllm` function `apply_chat_template` for rendering // chat templates. It also provides a method to fetch chat templates from the // tokenizer or HuggingFace if the tokenizer is not present. type ChatTemplatingProcessor struct{} @@ -128,76 +137,66 @@ func (w *ChatTemplatingProcessor) Finalize() { C.Py_FinalizeGo() } -// RenderChatTemplate renders a chat template using the cached Python function. -// It calls the Python `transformers` function `render_jinja_template` with the provided request. -// -//nolint:gocritic // hugeParam: req is passed by value intentionally for immutability, but can consider using pointer. -func (w *ChatTemplatingProcessor) RenderChatTemplate(ctx context.Context, - req *RenderJinjaTemplateRequest, -) (*RenderJinjaTemplateResponse, error) { - traceLogger := log.FromContext(ctx).V(logging.TRACE).WithName("RenderChatTemplate") +// ApplyChatTemplate renders a chat template using the cached Python function. +// It calls the Python `vllm` function `apply_chat_template` with the provided request. +func (w *ChatTemplatingProcessor) ApplyChatTemplate(ctx context.Context, + req *ApplyChatTemplateRequest, +) (string, error) { + traceLogger := log.FromContext(ctx).V(logging.TRACE).WithName("ApplyChatTemplate") if req == nil { traceLogger.Error(nil, "Received nil request") - return nil, fmt.Errorf("received nil request") + return "", fmt.Errorf("received nil request") } - // Convert request to JSON reqJSON, err := json.Marshal(req) if err != nil { traceLogger.Error(err, "Failed to marshal request") - return nil, fmt.Errorf("failed to marshal request: %w", err) + return "", fmt.Errorf("failed to marshal request: %w", err) } + traceLogger.Info("Applying chat template", "req", string(reqJSON)) // Call the cached Python function - cResult := C.Py_CallRenderJinjaTemplate(C.CString(string(reqJSON))) + cResult := C.Py_CallApplyChatTemplate(C.CString(string(reqJSON))) if cResult == nil { traceLogger.Error(nil, "C function returned nil") - return nil, fmt.Errorf("python render_jinja_template failed") + return "", fmt.Errorf("python apply_chat_template failed") } defer C.free(unsafe.Pointer(cResult)) - resultJSON := C.GoString(cResult) - - // Parse the response - var response RenderJinjaTemplateResponse - if err := json.Unmarshal([]byte(resultJSON), &response); err != nil { - traceLogger.Error(err, "Failed to unmarshal response") - return nil, fmt.Errorf("failed to unmarshal response: %w", err) - } - return &response, nil + return C.GoString(cResult), nil } -// FetchChatTemplate fetches the model chat template using the cached Python function. -// -//nolint:gocritic // hugeParam: req is passed by value intentionally for immutability, but can consider using pointer. -func (w *ChatTemplatingProcessor) FetchChatTemplate( +// Encode RenderedString. +func (w *ChatTemplatingProcessor) Encode( ctx context.Context, - req FetchChatTemplateRequest, -) (string, map[string]interface{}, error) { - traceLogger := log.FromContext(ctx).V(logging.TRACE).WithName("FetchChatTemplate") - + req *EncodeRequest, +) ([]uint32, []tokenizers.Offset, error) { + traceLogger := log.FromContext(ctx).V(logging.TRACE).WithName("Encode") // Convert request to JSON reqJSON, err := json.Marshal(req) if err != nil { traceLogger.Error(err, "Failed to marshal request") - return "", nil, fmt.Errorf("failed to marshal request: %w", err) + return nil, nil, fmt.Errorf("failed to marshal request: %w", err) } + + traceLogger.Info("Encoding text", "req", reqJSON) + // Call the cached Python function - cResult := C.Py_CallGetModelChatTemplate(C.CString(string(reqJSON))) + cResult := C.Py_CallEncode(C.CString(string(reqJSON))) if cResult == nil { traceLogger.Error(nil, "C function returned nil") - return "", nil, fmt.Errorf("python get_model_chat_template failed") + return nil, nil, fmt.Errorf("python encode failed") } defer C.free(unsafe.Pointer(cResult)) resultJSON := C.GoString(cResult) // Parse the response - var response FetchChatTemplateResponse + var response EncodeResponse if err := json.Unmarshal([]byte(resultJSON), &response); err != nil { traceLogger.Error(err, "Failed to unmarshal response") - return "", nil, fmt.Errorf("failed to unmarshal response: %w", err) + return nil, nil, fmt.Errorf("failed to unmarshal response: %w", err) } - return response.ChatTemplate, response.ChatTemplateKWArgs, nil + return response.TokenIDs, response.OffsetMappings, nil } // ClearCaches clears all caches for testing purposes. diff --git a/pkg/preprocessing/chat_completions/cgo_functions.h b/pkg/preprocessing/chat_completions/cgo_functions.h index 91b454c2f..82bc5a1ec 100644 --- a/pkg/preprocessing/chat_completions/cgo_functions.h +++ b/pkg/preprocessing/chat_completions/cgo_functions.h @@ -48,23 +48,23 @@ const char* PyUnicode_AsGoString(PyObject* obj); // Global variables to hold cached module and functions extern PyObject* g_chat_template_module; -extern PyObject* g_render_jinja_template_func; -extern PyObject* g_get_model_chat_template_func; +extern PyObject* g_apply_chat_template_func; +extern PyObject* g_encode_func; // Initialize the cached module and functions (call once at startup) int Py_InitChatTemplateModule(); -// Call the cached render_jinja_template function -char* Py_CallRenderJinjaTemplate(const char* json_request); +// Call the cached apply_chat_template function +char* Py_CallApplyChatTemplate(const char* json_request); // Internal function that does the actual work -char* Py_CallRenderJinjaTemplateInternal(const char* json_request); +char* Py_CallApplyChatTemplateInternal(const char* json_request); -// Call the cached get_model_chat_template function -char* Py_CallGetModelChatTemplate(const char* json_request); +// Call the cached encode function +char* Py_CallEncode(const char* json_request); // Internal function that does the actual work -char* Py_CallGetModelChatTemplateInternal(const char* json_request); +char* Py_CallEncodeInternal(const char* json_request); // Clear all caches for testing purposes char* Py_ClearCaches(void); @@ -75,4 +75,4 @@ void Py_CleanupChatTemplateModule(); // Re-initialize Python interpreter state int Py_ReinitializeGo(); -#endif // CGO_FUNCTIONS_H \ No newline at end of file +#endif // CGO_FUNCTIONS_H \ No newline at end of file diff --git a/pkg/preprocessing/chat_completions/cgo_functions_test.go b/pkg/preprocessing/chat_completions/cgo_functions_test.go index 19a9048f9..3467b17dc 100644 --- a/pkg/preprocessing/chat_completions/cgo_functions_test.go +++ b/pkg/preprocessing/chat_completions/cgo_functions_test.go @@ -28,14 +28,15 @@ import ( preprocessing "github.com/llm-d/llm-d-kv-cache/pkg/preprocessing/chat_completions" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/controller-runtime/pkg/log" + "sigs.k8s.io/controller-runtime/pkg/log/zap" ) // Global singleton wrapper to prevent multiple Python interpreter initializations. var ( globalWrapper *preprocessing.ChatTemplatingProcessor globalWrapperOnce sync.Once - globalWrapperMu sync.Mutex ) // getGlobalWrapper returns a singleton wrapper instance. @@ -50,8 +51,8 @@ func getGlobalWrapper() *preprocessing.ChatTemplatingProcessor { return globalWrapper } -// TestGetModelChatTemplate tests the get_model_chat_template function. -func TestGetModelChatTemplate(t *testing.T) { +// TestEncode tests the encode function. +func TestEncode(t *testing.T) { wrapper := getGlobalWrapper() // Clear caches to ensure accurate timing measurements @@ -59,61 +60,54 @@ func TestGetModelChatTemplate(t *testing.T) { require.NoError(t, err, "Failed to clear caches") tests := []struct { - name string - modelName string - revision string - token string - expectTemplate bool + name string + modelName string + revision string + hfToken string + expectToken uint32 }{ { - name: "IBM Granite Model", - modelName: "ibm-granite/granite-3.3-8b-instruct", - expectTemplate: true, + name: "IBM Granite Model", + modelName: "ibm-granite/granite-3.3-8b-instruct", + expectToken: 8279, }, { - name: "DialoGPT Model", - modelName: "microsoft/DialoGPT-medium", - expectTemplate: true, + name: "DialoGPT Model", + modelName: "microsoft/DialoGPT-medium", + expectToken: 15496, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - request := preprocessing.FetchChatTemplateRequest{ - Model: tt.modelName, - Revision: tt.revision, - Token: tt.token, + request := &preprocessing.EncodeRequest{ + ChatTemplateRequest: preprocessing.ChatTemplateRequest{ + Model: tt.modelName, + Revision: tt.revision, + Token: tt.hfToken, + }, + Text: "Hello, how are you?", } // Profile the function call start := time.Now() - template, templateVars, err := wrapper.FetchChatTemplate(context.Background(), request) + tokens, offsets, err := wrapper.Encode(context.Background(), request) duration := time.Since(start) // Log performance - t.Logf("Model: %s, Duration: %v, ChatTemplate length: %d", tt.modelName, duration, len(template)) - - if tt.expectTemplate { - // Models that should have templates - require.NoError(t, err, "FetchChatTemplate should not return an error") - assert.NotEmpty(t, template, "ChatTemplate should not be empty") - assert.NotNil(t, templateVars, "ChatTemplate vars should not be nil") - assert.Contains(t, template, "messages", "ChatTemplate should contain messages") - } else { - // Models that don't have chat templates - if err != nil { - t.Logf("Expected error for model without chat template: %v", err) - } else { - // Some models might return empty template instead of error - t.Logf("Model returned empty template (expected for non-chat models)") - } - } + t.Logf("Model: %s, Duration: %v, Tokens length: %d", tt.modelName, duration, len(tokens)) + + // Models that should have templates + require.NoError(t, err, "Encode should not return an error") + assert.NotEmpty(t, tokens, "Tokens should not be empty") + assert.NotNil(t, offsets, "Offsets should not be nil") + assert.Contains(t, tokens, tt.expectToken, "Tokens should contain expected token") }) } } -// TestRenderJinjaTemplate tests the render_jinja_template function. -func TestRenderJinjaTemplate(t *testing.T) { +// TestApplyChatTemplate tests the apply_chat_template function. +func TestApplyChatTemplate(t *testing.T) { wrapper := getGlobalWrapper() // Clear caches to ensure accurate timing measurements @@ -140,12 +134,12 @@ func TestRenderJinjaTemplate(t *testing.T) { tests := []struct { name string template string - messages []preprocessing.ChatMessage + messages []preprocessing.Conversation }{ { name: "Simple ChatTemplate", template: simpleTemplate, - messages: []preprocessing.ChatMessage{ + messages: []preprocessing.Conversation{ {Role: "user", Content: "Hello"}, {Role: "assistant", Content: "Hi there!"}, }, @@ -153,7 +147,7 @@ func TestRenderJinjaTemplate(t *testing.T) { { name: "Complex ChatTemplate with System Message", template: complexTemplate, - messages: []preprocessing.ChatMessage{ + messages: []preprocessing.Conversation{ {Role: "system", Content: "You are a helpful AI assistant."}, {Role: "user", Content: "What is the weather like?"}, {Role: "assistant", Content: "I don't have access to real-time weather data."}, @@ -162,7 +156,7 @@ func TestRenderJinjaTemplate(t *testing.T) { { name: "Complex ChatTemplate without System Message", template: complexTemplate, - messages: []preprocessing.ChatMessage{ + messages: []preprocessing.Conversation{ {Role: "user", Content: "Tell me a joke"}, {Role: "assistant", Content: "Why don't scientists trust atoms? Because they make up everything!"}, }, @@ -171,26 +165,28 @@ func TestRenderJinjaTemplate(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - request := &preprocessing.RenderJinjaTemplateRequest{ - Conversations: tt.messages, - ChatTemplate: tt.template, + request := &preprocessing.ApplyChatTemplateRequest{ + ChatTemplateRequest: preprocessing.ChatTemplateRequest{ + IsLocal: false, + Model: "facebook/opt-125m", + }, + Conversation: tt.messages, + ChatTemplate: tt.template, } // Profile the function call start := time.Now() - response, err := wrapper.RenderChatTemplate(context.Background(), request) + rendered, err := wrapper.ApplyChatTemplate(context.Background(), request) duration := time.Since(start) // Assertions - require.NoError(t, err, "RenderChatTemplate should not return an error") - assert.NotNil(t, response, "Response should not be nil") - assert.NotEmpty(t, response.RenderedChats, "Rendered chats should not be empty") + require.NoError(t, err, "ApplyChatTemplate should not return an error") + assert.NotNil(t, rendered, "Response should not be nil") + assert.NotEmpty(t, rendered, "Rendered chats should not be empty") // Log performance - t.Logf("ChatTemplate: %s, Duration: %v, Rendered length: %d", tt.name, duration, len(response.RenderedChats[0])) + t.Logf("ChatTemplate: %s, Duration: %v, Rendered length: %d", tt.name, duration, len(rendered)) - // Verify rendered content - rendered := response.RenderedChats[0] for _, message := range tt.messages { // For complex templates, the role might not be explicitly shown in output // but the content should always be present @@ -205,8 +201,8 @@ func TestRenderJinjaTemplate(t *testing.T) { } } -// TestTemplateCaching tests the caching functionality. -func TestTemplateCaching(t *testing.T) { +// TestTokenizerCaching tests the caching functionality. +func TestTokenizerCaching(t *testing.T) { wrapper := getGlobalWrapper() // Clear all caches to ensure we start with a clean state @@ -214,27 +210,30 @@ func TestTemplateCaching(t *testing.T) { require.NoError(t, err, "Failed to clear caches") modelName := "ibm-granite/granite-3.3-8b-instruct" - request := preprocessing.FetchChatTemplateRequest{ - Model: modelName, + request := &preprocessing.EncodeRequest{ + ChatTemplateRequest: preprocessing.ChatTemplateRequest{ + Model: modelName, + }, + Text: "What is the capital of France?", } // First call - should be cache miss t.Log("=== First call (Cache MISS) ===") start := time.Now() - template1, vars1, err := wrapper.FetchChatTemplate(context.Background(), request) + tokens1, offset1, err := wrapper.Encode(context.Background(), request) duration1 := time.Since(start) require.NoError(t, err, "First call should not return an error") // Second call - should be cache hit t.Log("=== Second call (Cache HIT) ===") start = time.Now() - template2, vars2, err := wrapper.FetchChatTemplate(context.Background(), request) + tokens2, offset2, err := wrapper.Encode(context.Background(), request) duration2 := time.Since(start) require.NoError(t, err, "Second call should not return an error") // Verify results are identical - assert.Equal(t, template1, template2, "Cached and non-cached results should be identical") - assert.Equal(t, vars1, vars2, "Cached and non-cached vars should be identical") + assert.Equal(t, tokens1, tokens2, "Cached and non-cached results should be identical") + assert.Equal(t, offset1, offset2, "Cached and non-cached vars should be identical") // Verify performance improvement t.Logf("First call duration: %v, Second call duration: %v, Speedup: %.1fx", @@ -255,13 +254,13 @@ func TestChatCompletionsIntegration(t *testing.T) { tests := []struct { name string modelName string - conversation []preprocessing.ChatMessage + conversation []preprocessing.Conversation description string }{ { name: "Simple Conversation", modelName: "ibm-granite/granite-3.3-8b-instruct", - conversation: []preprocessing.ChatMessage{ + conversation: []preprocessing.Conversation{ {Role: "user", Content: "What is the capital of France?"}, {Role: "assistant", Content: "The capital of France is Paris."}, }, @@ -270,7 +269,7 @@ func TestChatCompletionsIntegration(t *testing.T) { { name: "Multi-turn Conversation", modelName: "microsoft/DialoGPT-medium", - conversation: []preprocessing.ChatMessage{ + conversation: []preprocessing.Conversation{ {Role: "user", Content: "Hello, how are you?"}, {Role: "assistant", Content: "I'm doing well, thank you! How can I help you today?"}, {Role: "user", Content: "Can you tell me about machine learning?"}, @@ -282,7 +281,7 @@ func TestChatCompletionsIntegration(t *testing.T) { { name: "System Message Conversation", modelName: "ibm-granite/granite-3.3-8b-instruct", - conversation: []preprocessing.ChatMessage{ + conversation: []preprocessing.Conversation{ {Role: "system", Content: "You are a helpful AI assistant specialized in coding."}, {Role: "user", Content: "Write a Python function to calculate fibonacci numbers."}, {Role: "assistant", Content: "Here's a Python function to calculate fibonacci numbers:\n" + @@ -293,7 +292,7 @@ func TestChatCompletionsIntegration(t *testing.T) { { name: "Simple Conversation (Repeated)", modelName: "ibm-granite/granite-3.3-8b-instruct", - conversation: []preprocessing.ChatMessage{ + conversation: []preprocessing.Conversation{ {Role: "user", Content: "What is the capital of France?"}, {Role: "assistant", Content: "The capital of France is Paris."}, }, @@ -305,33 +304,25 @@ func TestChatCompletionsIntegration(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Logf("Testing: %s - %s", tt.name, tt.description) - // Step 1: Get the model's chat template start := time.Now() - templateRequest := preprocessing.FetchChatTemplateRequest{ - Model: tt.modelName, - } - template, templateVars, err := wrapper.FetchChatTemplate(context.Background(), templateRequest) templateDuration := time.Since(start) - require.NoError(t, err, "Failed to get model chat template") - assert.NotEmpty(t, template, "ChatTemplate should not be empty") - // Step 2: Render the conversation using the template + // Step 1: Render the conversation using the template start = time.Now() - renderRequest := &preprocessing.RenderJinjaTemplateRequest{ - Conversations: tt.conversation, - ChatTemplate: template, - ChatTemplateKWArgs: templateVars, + renderRequest := &preprocessing.ApplyChatTemplateRequest{ + ChatTemplateRequest: preprocessing.ChatTemplateRequest{ + IsLocal: false, + Model: tt.modelName, + }, + Conversation: tt.conversation, } - response, err := wrapper.RenderChatTemplate(context.Background(), renderRequest) + rendered, err := wrapper.ApplyChatTemplate(context.Background(), renderRequest) renderDuration := time.Since(start) require.NoError(t, err, "Failed to render chat template") - assert.NotNil(t, response, "Response should not be nil") - assert.NotEmpty(t, response.RenderedChats, "Rendered chats should not be empty") - - // Step 3: Verify the rendered output - rendered := response.RenderedChats[0] - assert.NotEmpty(t, rendered, "Rendered chat should not be empty") + assert.NotNil(t, rendered, "Response should not be nil") + assert.NotEmpty(t, rendered, "Rendered chats should not be empty") + // Step 2: Verify the rendered output // Verify all conversation messages are present in the rendered output for _, message := range tt.conversation { assert.Contains(t, rendered, message.Content, "Rendered content should contain message content") @@ -379,7 +370,7 @@ func TestLongChatCompletions(t *testing.T) { require.NoError(t, err, "Failed to clear caches") // Create a long conversation - longConversation := []preprocessing.ChatMessage{ + longConversation := []preprocessing.Conversation{ {Role: "system", Content: "You are an expert software engineer with deep knowledge of Go, Python, " + "and system design. " + "Provide detailed, accurate responses."}, @@ -408,28 +399,23 @@ func TestLongChatCompletions(t *testing.T) { modelName := "ibm-granite/granite-3.3-8b-instruct" t.Run("Long Conversation Processing", func(t *testing.T) { - // Get template start := time.Now() - templateRequest := preprocessing.FetchChatTemplateRequest{ - Model: modelName, - } - template, templateVars, err := wrapper.FetchChatTemplate(context.Background(), templateRequest) templateDuration := time.Since(start) - require.NoError(t, err, "Failed to get model chat template") // Render long conversation start = time.Now() - renderRequest := &preprocessing.RenderJinjaTemplateRequest{ - Conversations: longConversation, - ChatTemplate: template, - ChatTemplateKWArgs: templateVars, + renderRequest := &preprocessing.ApplyChatTemplateRequest{ + ChatTemplateRequest: preprocessing.ChatTemplateRequest{ + IsLocal: false, + Model: modelName, + }, + Conversation: longConversation, } - response, err := wrapper.RenderChatTemplate(context.Background(), renderRequest) + rendered, err := wrapper.ApplyChatTemplate(context.Background(), renderRequest) renderDuration := time.Since(start) require.NoError(t, err, "Failed to render long conversation") // Verify results - rendered := response.RenderedChats[0] assert.NotEmpty(t, rendered, "Long conversation should render successfully") assert.Greater(t, len(rendered), 1000, "Long conversation should produce substantial output") @@ -446,16 +432,19 @@ func TestLongChatCompletions(t *testing.T) { }) } -// BenchmarkGetModelChatTemplate benchmarks the template fetching performance. -func BenchmarkGetModelChatTemplate(b *testing.B) { +// BenchmarkEncode benchmarks the encode performance. +func BenchmarkEncode(b *testing.B) { wrapper := getGlobalWrapper() // Clear caches to ensure accurate timing measurements err := preprocessing.ClearCaches(context.Background()) require.NoError(b, err, "Failed to clear caches") - request := preprocessing.FetchChatTemplateRequest{ - Model: "ibm-granite/granite-3.3-8b-instruct", + request := &preprocessing.EncodeRequest{ + ChatTemplateRequest: preprocessing.ChatTemplateRequest{ + Model: "ibm-granite/granite-3.3-8b-instruct", + }, + Text: "What is the capital of France?", } // Track first iteration time and total time @@ -465,7 +454,7 @@ func BenchmarkGetModelChatTemplate(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { start := time.Now() - _, _, err := wrapper.FetchChatTemplate(context.Background(), request) + _, _, err := wrapper.Encode(context.Background(), request) require.NoError(b, err, "Benchmark should not return errors") iterTime := time.Since(start) @@ -489,28 +478,23 @@ func BenchmarkGetModelChatTemplate(b *testing.B) { b.ReportMetric(float64(warmAvg.Nanoseconds()), "ns/op_warm") } -// BenchmarkRenderJinjaTemplate benchmarks the template rendering performance. -func BenchmarkRenderJinjaTemplate(b *testing.B) { +// BenchmarkApplyChatTemplate benchmarks the template rendering performance. +func BenchmarkApplyChatTemplate(b *testing.B) { wrapper := getGlobalWrapper() // Clear caches to ensure accurate timing measurements err := preprocessing.ClearCaches(context.Background()) require.NoError(b, err, "Failed to clear caches") - // Get template first - templateRequest := preprocessing.FetchChatTemplateRequest{ - Model: "ibm-granite/granite-3.3-8b-instruct", - } - template, templateVars, err := wrapper.FetchChatTemplate(context.Background(), templateRequest) - require.NoError(b, err, "Failed to get template for benchmark") - - request := &preprocessing.RenderJinjaTemplateRequest{ - Conversations: []preprocessing.ChatMessage{ + request := &preprocessing.ApplyChatTemplateRequest{ + ChatTemplateRequest: preprocessing.ChatTemplateRequest{ + IsLocal: false, + Model: "ibm-granite/granite-3.3-8b-instruct", + }, + Conversation: []preprocessing.Conversation{ {Role: "user", Content: "Hello"}, {Role: "assistant", Content: "Hi there!"}, }, - ChatTemplate: template, - ChatTemplateKWArgs: templateVars, } // Track first iteration time and total time @@ -520,7 +504,7 @@ func BenchmarkRenderJinjaTemplate(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { start := time.Now() - _, err := wrapper.RenderChatTemplate(context.Background(), request) + _, err := wrapper.ApplyChatTemplate(context.Background(), request) require.NoError(b, err, "Benchmark should not return errors") iterTime := time.Since(start) @@ -596,8 +580,12 @@ func runVLLMValidationTest(t *testing.T, modelName, expectedVLLMOutput string) { wrapper := getGlobalWrapper() // Test case based on the provided vLLM request - request := &preprocessing.RenderJinjaTemplateRequest{ - Conversations: []preprocessing.ChatMessage{ + request := &preprocessing.ApplyChatTemplateRequest{ + ChatTemplateRequest: preprocessing.ChatTemplateRequest{ + IsLocal: false, + Model: modelName, + }, + Conversation: []preprocessing.Conversation{ {Role: "user", Content: "What is the weather in Paris?"}, {Role: "assistant", Content: "Let me check that for you."}, }, @@ -614,27 +602,9 @@ func runVLLMValidationTest(t *testing.T, modelName, expectedVLLMOutput string) { }, } - // Step 1: Get the chat template from the specified model - templateRequest := preprocessing.FetchChatTemplateRequest{ - Model: modelName, - } - template, templateVars, err := wrapper.FetchChatTemplate(context.Background(), templateRequest) - require.NoError(t, err, "Failed to get chat template") - assert.NotEmpty(t, template, "ChatTemplate should not be empty") - - // Step 2: Update the request with the actual template and template variables - request.ChatTemplate = template - if templateVars != nil { - // Use the template variables from the model (contains special tokens like eos_token) - request.ChatTemplateKWArgs = templateVars - } - - // Step 3: Render the conversation with the template - response, err := wrapper.RenderChatTemplate(context.Background(), request) - require.NoError(t, err, "Failed to render chat template") - require.Len(t, response.RenderedChats, 1, "Should have one rendered chat") - - renderedOutput := response.RenderedChats[0] + // Render the conversation with the template + renderedOutput, err := wrapper.ApplyChatTemplate(context.Background(), request) + require.NoError(t, err, "Failed to apply chat template") // Step 4: Compare results with flexible date handling compareVLLMOutput(t, renderedOutput, expectedVLLMOutput) @@ -679,31 +649,34 @@ func compareVLLMOutput(t *testing.T, renderedOutput, expectedVLLMOutput string) t.Fail() // Mark test as failed } -// TestFetchChatTemplateLocalPath tests fetching chat templates from local paths. -func TestFetchChatTemplateLocalPath(t *testing.T) { +// TestEncodeLocalPath tests encode from local paths. +func TestEncodeLocalPath(t *testing.T) { wrapper := getGlobalWrapper() // Get the path to the test model tokenizer // The testdata directory is in pkg/tokenization/testdata testModelPath := "../../tokenization/testdata/test-model" - request := preprocessing.FetchChatTemplateRequest{ - Model: testModelPath, - IsLocalPath: true, + request := &preprocessing.EncodeRequest{ + ChatTemplateRequest: preprocessing.ChatTemplateRequest{ + Model: testModelPath, + IsLocal: true, + }, + Text: "Hello from local tokenizer!", } - // Fetch the chat template - template, templateVars, err := wrapper.FetchChatTemplate(context.Background(), request) + // Encode the text using the local tokenizer + tokens, offset, err := wrapper.Encode(context.Background(), request) // Assertions - require.NoError(t, err, "FetchChatTemplate should not return an error for local path") - assert.NotEmpty(t, template, "ChatTemplate should not be empty") - assert.NotNil(t, templateVars, "ChatTemplate vars should not be nil") + require.NoError(t, err, "Encode should not return an error for local path") + assert.NotEmpty(t, tokens, "tokens should not be empty") + assert.NotNil(t, offset, "offset should not be nil") // Verify the template contains expected content - assert.Contains(t, template, "messages", "ChatTemplate should contain messages variable") - t.Logf("Fetched local template: %s", template) - t.Logf("Template vars: %+v", templateVars) + assert.Contains(t, tokens, uint32(7592), "tokens should contain 7592(hello)") + t.Logf("Fetched local template: %v", tokens) + t.Logf("Template vars: %+v", offset) } // TestRenderChatTemplateWithLocalTemplate tests rendering with a locally fetched template. @@ -713,40 +686,32 @@ func TestRenderChatTemplateWithLocalTemplate(t *testing.T) { // Get the path to the test model tokenizer testModelPath := "../../tokenization/testdata/test-model" - // First, fetch the template - fetchRequest := preprocessing.FetchChatTemplateRequest{ - Model: testModelPath, - IsLocalPath: true, - } - - template, templateVars, err := wrapper.FetchChatTemplate(context.Background(), fetchRequest) - require.NoError(t, err, "FetchChatTemplate should not return an error") - // Now render a conversation using the fetched template - renderRequest := &preprocessing.RenderJinjaTemplateRequest{ - Conversations: []preprocessing.ChatMessage{ + renderRequest := &preprocessing.ApplyChatTemplateRequest{ + ChatTemplateRequest: preprocessing.ChatTemplateRequest{ + IsLocal: true, + Model: testModelPath, + }, + Conversation: []preprocessing.Conversation{ {Role: "user", Content: "Hello from local tokenizer!"}, {Role: "assistant", Content: "Hi! I'm using a locally loaded template."}, }, - ChatTemplate: template, - ChatTemplateKWArgs: templateVars, } - response, err := wrapper.RenderChatTemplate(context.Background(), renderRequest) - require.NoError(t, err, "RenderChatTemplate should not return an error") - assert.NotNil(t, response, "Response should not be nil") - assert.NotEmpty(t, response.RenderedChats, "Rendered chats should not be empty") + rendered, err := wrapper.ApplyChatTemplate(context.Background(), renderRequest) + require.NoError(t, err, "ApplyChatTemplate should not return an error") + assert.NotNil(t, rendered, "Response should not be nil") + assert.NotEmpty(t, rendered, "Rendered chats should not be empty") // Verify the rendered content - rendered := response.RenderedChats[0] assert.Contains(t, rendered, "Hello from local tokenizer!", "Rendered content should contain user message") assert.Contains(t, rendered, "Hi! I'm using a locally loaded template.", "Rendered content should contain assistant message") t.Logf("Rendered chat with local template:\n%s", rendered) } -// TestFetchChatTemplateLocalPathCaching tests that local templates are cached properly. -func TestFetchChatTemplateLocalPathCaching(t *testing.T) { +// TestRenderChatTemplateLocalPathCaching tests that local templates are cached properly. +func TestRenderChatTemplateLocalPathCaching(t *testing.T) { wrapper := getGlobalWrapper() // Clear caches first @@ -754,26 +719,31 @@ func TestFetchChatTemplateLocalPathCaching(t *testing.T) { require.NoError(t, err, "Failed to clear caches") testModelPath := "../../tokenization/testdata/test-model" - request := preprocessing.FetchChatTemplateRequest{ - Model: testModelPath, - IsLocalPath: true, + request := &preprocessing.ApplyChatTemplateRequest{ + ChatTemplateRequest: preprocessing.ChatTemplateRequest{ + Model: testModelPath, + IsLocal: true, + }, + Conversation: []preprocessing.Conversation{ + {Role: "user", Content: "Hello from local tokenizer!"}, + {Role: "assistant", Content: "Hi! I'm using a locally loaded template."}, + }, } // First call - cache miss start := time.Now() - template1, vars1, err := wrapper.FetchChatTemplate(context.Background(), request) + rendered1, err := wrapper.ApplyChatTemplate(context.Background(), request) duration1 := time.Since(start) require.NoError(t, err, "First call should not return an error") // Second call - cache hit start = time.Now() - template2, vars2, err := wrapper.FetchChatTemplate(context.Background(), request) + rendered2, err := wrapper.ApplyChatTemplate(context.Background(), request) duration2 := time.Since(start) require.NoError(t, err, "Second call should not return an error") // Verify results are identical - assert.Equal(t, template1, template2, "Cached and non-cached templates should be identical") - assert.Equal(t, vars1, vars2, "Cached and non-cached vars should be identical") + assert.Equal(t, rendered1, rendered2, "Cached and non-cached templates should be identical") // Cache hit should be faster t.Logf("First call (cache miss): %v, Second call (cache hit): %v, Speedup: %.1fx", @@ -781,47 +751,55 @@ func TestFetchChatTemplateLocalPathCaching(t *testing.T) { assert.Less(t, duration2, duration1, "Cache hit should be faster than cache miss") } -// TestFetchChatTemplateLocalPathWithFile tests loading from a specific tokenizer.json file path. -func TestFetchChatTemplateLocalPathWithFile(t *testing.T) { +// TestRenderChatTemplateLocalPathWithFile tests loading from a specific tokenizer.json file path. +func TestRenderChatTemplateLocalPathWithFile(t *testing.T) { wrapper := getGlobalWrapper() // Test with the full path to tokenizer.json //nolint:gosec // This is a test file path, not a credential testTokenizerPath := "../../tokenization/testdata/test-model/tokenizer.json" - request := preprocessing.FetchChatTemplateRequest{ - Model: testTokenizerPath, - IsLocalPath: true, + request := &preprocessing.ApplyChatTemplateRequest{ + ChatTemplateRequest: preprocessing.ChatTemplateRequest{ + Model: testTokenizerPath, + IsLocal: true, + }, + Conversation: []preprocessing.Conversation{ + {Role: "user", Content: "Hello from local tokenizer file!"}, + {Role: "assistant", Content: "Hi! I'm using a locally loaded template from file."}, + }, } - // Fetch the chat template - should extract directory and load from there - template, templateVars, err := wrapper.FetchChatTemplate(context.Background(), request) + // Render the chat template - should extract directory and load from there + rendered, err := wrapper.ApplyChatTemplate(context.Background(), request) // Assertions - require.NoError(t, err, "FetchChatTemplate should handle file path and extract directory") - assert.NotEmpty(t, template, "ChatTemplate should not be empty") - assert.NotNil(t, templateVars, "ChatTemplate vars should not be nil") - assert.Contains(t, template, "messages", "ChatTemplate should contain messages variable") + require.NoError(t, err, "ApplyChatTemplate should handle file path and extract directory") + assert.NotEmpty(t, rendered, "ChatTemplate should not be empty") - t.Logf("Fetched template from file path: %s", template) + t.Logf("Fetched template from file path: %s", rendered) } -// TestFetchChatTemplateLocalPathNonExistent tests error handling for non-existent local paths. -func TestFetchChatTemplateLocalPathNonExistent(t *testing.T) { +// TestRenderChatTemplateLocalPathNonExistent tests error handling for non-existent local paths. +func TestRenderChatTemplateLocalPathNonExistent(t *testing.T) { wrapper := getGlobalWrapper() - request := preprocessing.FetchChatTemplateRequest{ - Model: "/non/existent/path", - IsLocalPath: true, + request := &preprocessing.ApplyChatTemplateRequest{ + ChatTemplateRequest: preprocessing.ChatTemplateRequest{ + Model: "/non/existent/path", + IsLocal: true, + }, + Conversation: []preprocessing.Conversation{ + {Role: "user", Content: "This should fail."}, + }, } // This should return an error - template, templateVars, err := wrapper.FetchChatTemplate(context.Background(), request) + rendered, err := wrapper.ApplyChatTemplate(context.Background(), request) // Assertions - assert.Error(t, err, "FetchChatTemplate should return an error for non-existent path") - assert.Empty(t, template, "ChatTemplate should be empty on error") - assert.Nil(t, templateVars, "ChatTemplate vars should be nil on error") + assert.Error(t, err, "ApplyChatTemplate should return an error for non-existent path") + assert.Empty(t, rendered, "Rendered should be empty on error") t.Logf("Expected error for non-existent path: %v", err) } @@ -830,6 +808,7 @@ func TestFetchChatTemplateLocalPathNonExistent(t *testing.T) { func TestMain(m *testing.M) { // Create a new processor to handle initialization. processor := preprocessing.NewChatTemplatingProcessor() + ctrl.SetLogger(zap.New(zap.UseDevMode(true))) // Set up: Initialize the Python interpreter. log.Log.Info("Initializing Python interpreter for tests...") diff --git a/pkg/preprocessing/chat_completions/render_jinja_template_wrapper.py b/pkg/preprocessing/chat_completions/render_jinja_template_wrapper.py deleted file mode 100644 index b87fbbefd..000000000 --- a/pkg/preprocessing/chat_completions/render_jinja_template_wrapper.py +++ /dev/null @@ -1,259 +0,0 @@ -# Copyright 2025 The llm-d Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -#!/usr/bin/env python3 -""" -Standalone wrapper for render_jinja_template function from transformers. -""" - -import json -import logging -import sys -from typing import Optional, Union - -# Import core functions from transformers - moved to function level to avoid import errors -TRANSFORMERS_AVAILABLE = None # Will be set when first needed - -def _ensure_transformers_available(): - """Ensure transformers is available, importing it if needed.""" - global TRANSFORMERS_AVAILABLE - if TRANSFORMERS_AVAILABLE is None: - try: - print("[Python] Attempting to import transformers...") - from transformers.utils.chat_template_utils import render_jinja_template as transformers_render_jinja_template, get_json_schema - from transformers import AutoTokenizer - print("[Python] Successfully imported transformers!") - TRANSFORMERS_AVAILABLE = True - return True - except ImportError as e: - print(f"[Python] Failed to import transformers: {e}") - print("[Python] Ensure the 'transformers' library is installed in the Python environment.") - TRANSFORMERS_AVAILABLE = False - return False - return TRANSFORMERS_AVAILABLE - -# Basic logging setup -logger = logging.getLogger(__name__) - -# Module-level cache for templates -_template_cache = {} -_cache_lock = None - -def _get_cache_lock(): - """Get or create a threading lock for cache access.""" - global _cache_lock - if _cache_lock is None: - import threading - _cache_lock = threading.Lock() - return _cache_lock - - -def _collect_template_vars(tokenizer): - """Collect extra rendering variables from a tokenizer.""" - kwargs = {} - for k in ["bos_token", "eos_token", "eot_token", "pad_token", "unk_token", "sep_token", "additional_special_tokens"]: - v = getattr(tokenizer, k, None) - if v is not None: - kwargs[k] = v - return kwargs - - -def clear_caches(): - """Clear all caches for testing purposes.""" - lock = _get_cache_lock() - with lock: - global _template_cache - _template_cache.clear() - return "Caches cleared" - - -def render_jinja_template(request_json): - """ - Render a chat template using the transformers library. - This function is aligned with the Go cgo_functions.go structs. - - Args: - request_json (str): JSON string containing the request parameters: - - conversations (list): List of conversation lists - - chat_template (str, optional): The template to use - - tools (list, optional): Tool schemas - - documents (list, optional): Document schemas - - return_assistant_tokens_mask (bool, optional): Whether to return assistant tokens mask - - continue_final_message (bool, optional): Whether to continue final message - - add_generation_prompt (bool, optional): Whether to add generation prompt - - kwargs (dict, optional): Additional rendering variables - Returns: - str: JSON string containing 'rendered_chats' and 'generation_indices' keys. - """ - if not _ensure_transformers_available(): - raise ImportError("transformers library is required for render_jinja_template") - - # Import the modules we need - from transformers.utils.chat_template_utils import render_jinja_template as transformers_render_jinja_template - - # Parse the JSON request - request = json.loads(request_json) - - # Align Go's `messages` field with transformers' `conversations` parameter. - if 'messages' in request: - request['conversations'] = [request.pop('messages')] # wrap to match expected format - - try: - # Get template_vars and spread them as individual arguments - template_vars = request.pop('chat_template_kwargs', {}) - request.update(template_vars) - - rendered_chats, generation_indices = transformers_render_jinja_template(**request) - - except Exception as e: - raise - - # Return as JSON string, aligning with the Go response struct. - result = json.dumps({ - "rendered_chats": rendered_chats, - "generation_indices": generation_indices - }) - return result - - -def get_model_chat_template(request_json): - """ - Load a tokenizer from Hugging Face Hub or local path and return its chat template string and required variables. - Args: - request_json (str): JSON string containing the request parameters: - - model (str): The model ID or path (HF model ID, local directory path, or path to tokenizer file). - - chat_template (str, optional): The template name or string to use. - - tools (list[dict], optional): Tool schemas to pass. - - revision (str, optional): Model revision. - - token (str, optional): Hugging Face token for private models. - - is_local_path (bool, optional): Whether the model is a local path (default: False). - Returns: - str: JSON string containing 'template' and 'kwargs' keys, aligning with the Go response struct. - """ - if not _ensure_transformers_available(): - print("[Python] get_model_chat_template ERROR - Transformers not available") - raise ImportError("transformers library is required for get_model_chat_template") - - # Parse the JSON request - request = json.loads(request_json) - - model_name = request.get("model") - chat_template = request.get("chat_template") - tools = request.get("tools") - revision = request.get("revision") - token = request.get("token") - is_local_path = request.get("is_local_path", False) - - if not model_name: - print("[Python] get_model_chat_template ERROR - model_name is required") - raise ValueError("model_name is required in request") - - # Create cache key - cache_key = f"{model_name}:{revision or 'main'}:{token or 'none'}:{is_local_path}" - - # Check cache first - lock = _get_cache_lock() - with lock: - if cache_key in _template_cache: - cached_result = _template_cache[cache_key] - # If a specific chat_template was requested, override the cached template - if chat_template is not None: - cached_result["template"] = chat_template - return json.dumps(cached_result) - - # Import the modules we need - from transformers import AutoTokenizer - import os - - # Determine if we're loading from local path or HuggingFace - if is_local_path: - # For local paths, model_name can be either a directory containing tokenizer files - # or a path to a specific tokenizer file. Ensure we extract the directory if needed. - if os.path.isfile(model_name): - # If it's a file path (tokenizer.json), get the directory - tokenizer_dir = os.path.dirname(model_name) - else: - # If it's already a directory, use it directly - tokenizer_dir = model_name - - print(f"[Python] Loading tokenizer from local path: {tokenizer_dir}") - tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir, local_files_only=True, trust_remote_code=True) - else: - # Load from Hugging Face - print(f"[Python] Loading tokenizer from HuggingFace: {model_name}") - tokenizer = AutoTokenizer.from_pretrained(model_name, revision=revision, token=token, trust_remote_code=True) - - template = tokenizer.chat_template if chat_template is None else chat_template - - # Collect special tokens - template_vars = _collect_template_vars(tokenizer) - - # Cache the result, aligning with the Go response struct. - result = {"chat_template": template, "chat_template_kwargs": template_vars} - with lock: - _template_cache[cache_key] = result.copy() # Cache a copy to avoid reference issues - - return json.dumps(result) - - -def main(): - """Example usage and testing function.""" - if not _ensure_transformers_available(): - print("Error: transformers library is required but not available") - print("Please install transformers: pip install transformers") - return - - if len(sys.argv) < 2: - print("Usage: python render_jinja_template_wrapper.py [conversation_json]") - print("Example:") - print('python render_jinja_template_wrapper.py "{% for message in messages %}{{ message.role }}: {{ message.content }}\\n{% endfor %}"') - return - - chat_template = sys.argv[1] - - # Default conversation if none provided - conversation = [ - {"role": "user", "content": "Hello!"}, - {"role": "assistant", "content": "Hi there! How can I help you today?"} - ] - - if len(sys.argv) > 2: - try: - conversation = json.loads(sys.argv[2]) - except json.JSONDecodeError: - print("Error: Invalid JSON for conversation") - return - - try: - # Construct the request JSON string similar to how Go would - request_str = json.dumps({ - "conversations": [conversation], - "chat_template": chat_template - }) - response_str = render_jinja_template(request_str) - response_data = json.loads(response_str) - - rendered = response_data['rendered_chats'] - generation_indices = response_data['generation_indices'] - - print("Rendered chat:") - print(rendered[0]) - if generation_indices and len(generation_indices) > 0 and generation_indices[0]: - print(f"Generation indices: {generation_indices[0]}") - except Exception as e: - print(f"Error: {e}") - - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/pkg/preprocessing/chat_completions/requirements.txt b/pkg/preprocessing/chat_completions/requirements.txt index 649c31de8..88f4bf039 100644 --- a/pkg/preprocessing/chat_completions/requirements.txt +++ b/pkg/preprocessing/chat_completions/requirements.txt @@ -1,7 +1 @@ ---extra-index-url https://download.pytorch.org/whl/cpu - -packaging==24.2 -pillow==11.2.1 -torch==2.5.1 -transformers>=4.53.0,<4.57.2 -jinja2>=2.11 +vllm>=0.11.0 \ No newline at end of file diff --git a/pkg/preprocessing/chat_completions/tokenizer_wrapper.py b/pkg/preprocessing/chat_completions/tokenizer_wrapper.py new file mode 100644 index 000000000..7496f3816 --- /dev/null +++ b/pkg/preprocessing/chat_completions/tokenizer_wrapper.py @@ -0,0 +1,195 @@ +# Copyright 2025 The llm-d Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +#!/usr/bin/env python3 +""" +Standalone wrapper for tokenizer from vllm. +""" + +import json +import logging +import os +import sys +from typing import Optional, Union +from vllm.transformers_utils.tokenizer import get_tokenizer + +# Basic logging setup +logger = logging.getLogger(__name__) + +# Module-level cache for templates +_tokenizer_cache = {} +_tokenizer_cache_lock = None + +def _get_tokenizer_cache_lock(): + """Get or create a threading lock for tokenizer cache access.""" + global _tokenizer_cache_lock + if _tokenizer_cache_lock is None: + import threading + _tokenizer_cache_lock = threading.RLock() + return _tokenizer_cache_lock + +def clear_caches(): + """Clear the tokenizer cache for testing purposes.""" + lock = _get_tokenizer_cache_lock() + with lock: + global _tokenizer_cache + _tokenizer_cache.clear() + return "Tokenizer caches cleared" + +def apply_chat_template(request_json): + """ + Render a chat template using the transformers library. + This function is aligned with the Go cgo_functions.go structs. + + Args: + request_json (str): JSON string containing the request parameters: + - is_local (bool, optional): Whether the model is local. + - model (str): The model ID or path (HF model ID, local directory path, or path to tokenizer file). + - revision (str, optional): Model revision. + - token (str, optional): Hugging Face token for private models. + - conversations (list): List of conversation lists + - chat_template (str, optional): The template to use + - tools (list, optional): Tool schemas + - documents (list, optional): Document schemas + - return_assistant_tokens_mask (bool, optional): Whether to return assistant tokens mask + - continue_final_message (bool, optional): Whether to continue final message + - add_generation_prompt (bool, optional): Whether to add generation prompt + - kwargs (dict, optional): Additional rendering variables + - isVLLM / - isSGLang enum + + Returns: + str: JSON string containing 'rendered_chats' and 'generation_indices' keys. + """ + + try: + # Parse the JSON request + request = json.loads(request_json) + + # Get template_vars and spread them as individual arguments + template_vars = request.pop('chat_template_kwargs', {}) + request.update(template_vars) + + model_name = request.pop("model") + revision = request.get("revision", None) + is_local = request.pop("is_local", False) + token = request.pop("token", "") + download_dir = request.pop("download_dir", None) + + if is_local and os.path.isfile(model_name): + # If it's a file path (tokenizer.json), get the directory + model_name = os.path.dirname(model_name) + + lock = _get_tokenizer_cache_lock() + with lock: + cache_key = f"{model_name}:{revision or 'main'}:{is_local}" + tokenizer = _tokenizer_cache.get(cache_key) + if tokenizer is None: + os.environ["HF_TOKEN"] = token + tokenizer = get_tokenizer(model_name, trust_remote_code=True, revision=revision, download_dir=download_dir) + _tokenizer_cache[cache_key] = tokenizer + + request["tokenize"] = False + return tokenizer.apply_chat_template(**request) + + except Exception as e: + raise RuntimeError(f"Error applying chat template: {e}") from e + + +def encode(request_json: str) -> str: + """ + Encode text using the specified tokenizer. + + Args: + request_json (str): JSON string containing: + - model (str): The model ID or path. + - revision (str, optional): Model revision. + - is_local (bool, optional): Whether the model is local. + - download_dir (str, optional): Directory to download the model. + - text (str): The text to encode. + - token (str, optional): Hugging Face token for private models. + - isVLLM (bool, optional): Whether to use VLLM tokenizer. + - isSGLang (bool, optional): Whether to use SG-Lang tokenizer. + - .... + + Returns: + str: JSON string containing 'encoded_texts' key with list of token ID lists. + """ + try: + request = json.loads(request_json) + model_name = request["model"] + revision = request.get("revision", None) + is_local = request.get("is_local", False) + download_dir = request.pop("download_dir", None) + text = request["text"] + token = request.get("token", "") + add_special_tokens = request.get("add_special_tokens", False) + + if is_local and os.path.isfile(model_name): + # If it's a file path (tokenizer.json), get the directory + model_name = os.path.dirname(model_name) + + lock = _get_tokenizer_cache_lock() + with lock: + cache_key = f"{model_name}:{revision or 'main'}:{is_local}" + tokenizer = _tokenizer_cache.get(cache_key) + if tokenizer is None: + os.environ["HF_TOKEN"] = token + tokenizer = get_tokenizer(model_name, trust_remote_code=True, revision=revision, download_dir=download_dir) + _tokenizer_cache[cache_key] = tokenizer + + return json.dumps(tokenizer(text, return_offsets_mapping=True, add_special_tokens=add_special_tokens).data) + + except Exception as e: + raise RuntimeError(f"Error encoding texts: {e}") from e + +def main(): + """Example usage and testing function.""" + + if len(sys.argv) < 2: + print("Usage: python tokenizer_wrapper.py [conversation_json]") + print("Example:") + print('python tokenizer_wrapper.py "{% for message in messages %}{{ message.role }}: {{ message.content }}\\n{% endfor %}"') + return + + chat_template = sys.argv[1] + + # Default conversation if none provided + conversation = [ + {"role": "user", "content": "Hello!"}, + {"role": "assistant", "content": "Hi there! How can I help you today?"} + ] + + if len(sys.argv) > 2: + try: + conversation = json.loads(sys.argv[2]) + except json.JSONDecodeError: + print("Error: Invalid JSON for conversation") + return + + try: + # Construct the request JSON string similar to how Go would + request_str = json.dumps({ + "model": "facebook/opt-125m", + "conversation": [conversation], + "chat_template": chat_template + }) + response = apply_chat_template(request_str) + + print("Rendered chat:") + print(response[0]) + except Exception as e: + print(f"Error: {e}") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/pkg/tokenization/pool.go b/pkg/tokenization/pool.go index c67eb7aa8..f93bd9041 100644 --- a/pkg/tokenization/pool.go +++ b/pkg/tokenization/pool.go @@ -35,8 +35,6 @@ const ( // Config holds the configuration for the TokenizationPool. type Config struct { - // Base model name for the tokenizer. - ModelName string `json:"modelName"` // Number of worker goroutines for processing tokenization tasks. WorkersCount int `json:"workersCount"` // Minimum overlap ratio to skip full tokenization and use cached prefix tokens. @@ -69,7 +67,7 @@ type tokenizationResponse struct { // Task represents a unit of work for tokenizing a prompt. type Task struct { - RenderReq *preprocessing.RenderJinjaTemplateRequest + RenderReq *preprocessing.ApplyChatTemplateRequest Prompt string ModelName string ResultCh chan<- tokenizationResponse // nil => fire-and-forget @@ -77,11 +75,10 @@ type Task struct { // Pool encapsulates the queue, worker pool, and token indexer. type Pool struct { - modelName string // base model name for tokenization - workers int - queue workqueue.TypedRateLimitingInterface[Task] - wg sync.WaitGroup - indexer prefixstore.Indexer + workers int + queue workqueue.TypedRateLimitingInterface[Task] + wg sync.WaitGroup + indexer prefixstore.Indexer // Tokenizer is configured for the specific model this pool handles. // It's shared between all pool workers. Since the tokenizer @@ -95,8 +92,8 @@ type Pool struct { // NewTokenizationPool initializes a TokenizationPool with the specified number // of workers and the provided Indexer. func NewTokenizationPool(config *Config, store prefixstore.Indexer) (*Pool, error) { - if config == nil || config.ModelName == "" { - return nil, fmt.Errorf("config and config.ModelName cannot be nil or empty") + if config == nil { + return nil, fmt.Errorf("config cannot be nil or empty") } if !config.LocalTokenizerConfig.IsEnabled() && @@ -108,7 +105,7 @@ func NewTokenizationPool(config *Config, store prefixstore.Indexer) (*Pool, erro tokenizers := make([]Tokenizer, 0, 3) if config.LocalTokenizerConfig.IsEnabled() { - localTokenizer, err := NewCachedLocalTokenizer(config.ModelName, *config.LocalTokenizerConfig) + localTokenizer, err := NewCachedLocalTokenizer(config.LocalTokenizerConfig) if err != nil { return nil, fmt.Errorf("failed to create local tokenizer: %w", err) } @@ -124,7 +121,7 @@ func NewTokenizationPool(config *Config, store prefixstore.Indexer) (*Pool, erro } if config.HFTokenizerConfig.IsEnabled() { - hfTokenizer, err := NewCachedHFTokenizer(config.ModelName, config.HFTokenizerConfig) + hfTokenizer, err := NewCachedHFTokenizer(config.HFTokenizerConfig) if err != nil { return nil, fmt.Errorf("failed to create HuggingFace tokenizer: %w", err) } @@ -132,7 +129,6 @@ func NewTokenizationPool(config *Config, store prefixstore.Indexer) (*Pool, erro } return &Pool{ - modelName: config.ModelName, workers: config.WorkersCount, queue: workqueue.NewTypedRateLimitingQueue(workqueue.DefaultTypedControllerRateLimiter[Task]()), indexer: store, @@ -143,19 +139,21 @@ func NewTokenizationPool(config *Config, store prefixstore.Indexer) (*Pool, erro // EnqueueTokenization enqueues a new tokenization task. // This method only enqueues the task and does not start processing it. -func (pool *Pool) EnqueueTokenization(prompt string) { +func (pool *Pool) EnqueueTokenization(prompt, modelName string) { task := Task{ - Prompt: prompt, + Prompt: prompt, + ModelName: modelName, } pool.queue.Add(task) } // Tokenize queues a task and blocks until the final result is available. -func (pool *Pool) Tokenize(renderReq *preprocessing.RenderJinjaTemplateRequest, prompt string) []uint32 { +func (pool *Pool) Tokenize(renderReq *preprocessing.ApplyChatTemplateRequest, prompt, modelName string) []uint32 { resultCh := make(chan tokenizationResponse, 1) pool.queue.Add(Task{ RenderReq: renderReq, Prompt: prompt, + ModelName: modelName, ResultCh: resultCh, }) @@ -200,22 +198,37 @@ func (pool *Pool) workerLoop(_ int) { // processTask tokenizes the prompt and updates the indexer. // It sends exactly one response (success or error) if ResultCh is provided. func (pool *Pool) processTask(task Task) error { + addSpecialTokens := true if task.RenderReq != nil { + task.RenderReq.Model = task.ModelName var err error - task.Prompt, err = pool.tokenizer.RenderChatTemplate(pool.modelName, task.RenderReq) + task.Prompt, err = pool.tokenizer.ApplyChatTemplate(task.ModelName, task.RenderReq) if err != nil { - log.Log.Error(err, "failed to render chat template") + log.Log.Error(err, "failed to render chat template", "modelName", task.ModelName) return err } + addSpecialTokens = false } tokenIDs, overlapRatio := pool.indexer.FindLongestContainedTokens(task.Prompt) // if the overlap ratio is low, get the full tokenization if overlapRatio < pool.minPrefixOverlapRatio { - tokens, offsets, err := pool.tokenizer.Encode(task.Prompt, pool.modelName) + tokens, offsets, err := pool.tokenizer.Encode(&preprocessing.EncodeRequest{ + ChatTemplateRequest: preprocessing.ChatTemplateRequest{ + Model: task.ModelName, + Revision: func() string { + if task.RenderReq == nil { + return "" + } + return task.RenderReq.Revision + }(), + }, + Text: task.Prompt, + AddSpecialTokens: addSpecialTokens, + }) if err != nil { - log.Log.Error(err, "failed to encode tokens", "prompt", task.Prompt) + log.Log.Error(err, "failed to encode tokens", "prompt", task.Prompt, "modelName", task.ModelName) return err } @@ -239,8 +252,3 @@ func (pool *Pool) processTask(task Task) error { return nil } - -func (pool *Pool) SetTokenizer(tokenizer Tokenizer, modelName string) { - pool.tokenizer = tokenizer - pool.modelName = modelName -} diff --git a/pkg/tokenization/pool_test.go b/pkg/tokenization/pool_test.go index 3e680a118..88038cd2a 100644 --- a/pkg/tokenization/pool_test.go +++ b/pkg/tokenization/pool_test.go @@ -50,16 +50,32 @@ type MockTokenizer struct { mock.Mock } -func (m *MockTokenizer) RenderChatTemplate( - prompt string, renderReq *preprocessing.RenderJinjaTemplateRequest, +func (m *MockTokenizer) ApplyChatTemplate( + prompt string, renderReq *preprocessing.ApplyChatTemplateRequest, ) (string, error) { args := m.Called(prompt, renderReq) return args.String(0), args.Error(1) } -func (m *MockTokenizer) Encode(input, modelName string) ([]uint32, []tokenizers.Offset, error) { - args := m.Called(input, modelName) - return args.Get(0).([]uint32), args.Get(1).([]tokenizers.Offset), args.Error(2) //nolint:errcheck // return mocked values +func (m *MockTokenizer) Encode(req *preprocessing.EncodeRequest) ([]uint32, []tokenizers.Offset, error) { + args := m.Called(req) + tokenIface := args.Get(0) + if tokenIface == nil { + return nil, nil, args.Error(2) + } + tokens, ok := tokenIface.([]uint32) + if !ok { + panic("MockTokenizer.Encode: expected []uint32 from mock, got unexpected type") + } + offsetIface := args.Get(1) + if offsetIface == nil { + return nil, nil, args.Error(2) + } + offsets, ok := offsetIface.([]tokenizers.Offset) + if !ok { + panic("MockTokenizer.Encode: expected []tokenizers.Offset from mock, got unexpected type") + } + return tokens, offsets, args.Error(2) } func (m *MockTokenizer) Type() string { @@ -79,7 +95,11 @@ func (m *MockIndexer) AddTokenization(prompt string, tokens []uint32, offsets [] //nolint:gocritic // unnamedResult: tokens and overlapRatio are self-explanatory from context func (m *MockIndexer) FindLongestContainedTokens(prompt string) ([]uint32, float64) { args := m.Called(prompt) - tokens := args.Get(0).([]uint32) //nolint:errcheck // unused mock + tokensIface := args.Get(0) + tokens, ok := tokensIface.([]uint32) + if !ok { + panic("MockIndexer.FindLongestContainedTokens: expected []uint32 from mock, got unexpected type") + } return tokens, 0.0 } @@ -88,7 +108,6 @@ func TestPool_ProcessTask(t *testing.T) { mockTokenizer := &MockTokenizer{} pool := &Pool{ - modelName: testModelName, workers: 1, indexer: mockIndexer, tokenizer: mockTokenizer, @@ -96,7 +115,8 @@ func TestPool_ProcessTask(t *testing.T) { } task := Task{ - Prompt: "hello world", + Prompt: "hello world", + ModelName: testModelName, } // Setup specific mock return values @@ -106,7 +126,13 @@ func TestPool_ProcessTask(t *testing.T) { // Mock FindLongestContainedTokens to return low overlap ratio mockIndexer.On("FindLongestContainedTokens", task.Prompt).Return([]uint32{}, 0.0) - mockTokenizer.On("Encode", task.Prompt, testModelName).Return(expectedTokens, expectedOffsets, nil) + mockTokenizer.On("Encode", &preprocessing.EncodeRequest{ + ChatTemplateRequest: preprocessing.ChatTemplateRequest{ + Model: task.ModelName, + }, + Text: task.Prompt, + AddSpecialTokens: true, + }).Return(expectedTokens, expectedOffsets, nil) // Verify that indexer receives exactly the same tokens and offsets that tokenizer returned mockIndexer.On("AddTokenization", task.Prompt, expectedTokens, expectedOffsets).Return(nil) @@ -137,7 +163,6 @@ func TestPool_RunIntegration(t *testing.T) { } config := &Config{ - ModelName: testModelName, WorkersCount: 5, HFTokenizerConfig: DefaultHFTokenizerConfig(), MinPrefixOverlapRatio: defaultMinPrefixOverlapRatio, @@ -151,7 +176,7 @@ func TestPool_RunIntegration(t *testing.T) { defer cancel() for _, prompt := range prompts { - pool.EnqueueTokenization(prompt) + pool.EnqueueTokenization(prompt, testModelName) } // Run pool @@ -183,11 +208,10 @@ func generateRandomSentence(wordLength, maxWords int, rng *rand.Rand) string { return strings.Join(words, " ") } -func setupStressTest(b *testing.B, modelName string) *Pool { +func setupStressTest(b *testing.B) *Pool { b.Helper() config := &Config{ - ModelName: modelName, WorkersCount: benchmarkWorkerCount, HFTokenizerConfig: DefaultHFTokenizerConfig(), MinPrefixOverlapRatio: defaultMinPrefixOverlapRatio, @@ -208,15 +232,16 @@ func BenchmarkAsyncTokenizationStress(b *testing.B) { for _, modelName := range benchmarkModels { b.Run(modelName, func(b *testing.B) { - pool := setupStressTest(b, modelName) + pool := setupStressTest(b) // Return RNG for on-demand prompt generation rng := rand.New(rand.NewSource(benchmarkSeed)) //nolint:gosec // Test code - weak random is acceptable // Generate and enqueue prompts on-the-fly to avoid memory bloat - for range b.N { + for i := range b.N { prompt := generateRandomSentence(benchmarkWordLength, benchmarkMaxWords, rng) - pool.EnqueueTokenization(prompt) + modelName := benchmarkModels[i%len(benchmarkModels)] + pool.EnqueueTokenization(prompt, modelName) } // Create context for the pool @@ -249,7 +274,7 @@ func BenchmarkSyncTokenizationStress(b *testing.B) { for _, modelName := range benchmarkModels { b.Run(modelName, func(b *testing.B) { - pool := setupStressTest(b, modelName) + pool := setupStressTest(b) // Return RNG for on-demand prompt generation rng := rand.New(rand.NewSource(benchmarkSeed)) //nolint:gosec // Test code - weak random is acceptable @@ -266,7 +291,8 @@ func BenchmarkSyncTokenizationStress(b *testing.B) { // Submit tokenization requests in a loop until limit for i := 0; b.Loop(); i++ { prompt := generateRandomSentence(benchmarkWordLength, benchmarkMaxWords, rng) - pool.Tokenize(nil, prompt) + model := benchmarkModels[i%len(benchmarkModels)] + pool.Tokenize(nil, prompt, model) } b.StopTimer() diff --git a/pkg/tokenization/tokenizer.go b/pkg/tokenization/tokenizer.go index 9339de27a..357c13c4a 100644 --- a/pkg/tokenization/tokenizer.go +++ b/pkg/tokenization/tokenizer.go @@ -34,9 +34,9 @@ import ( // Tokenizer interface defines the methods for tokenization. type Tokenizer interface { - RenderChatTemplate(string, *preprocessing.RenderJinjaTemplateRequest) (string, error) + ApplyChatTemplate(string, *preprocessing.ApplyChatTemplateRequest) (string, error) // Encode tokenizes the input string and returns the token IDs and offsets. - Encode(input, modelName string) ([]uint32, []tokenizers.Offset, error) + Encode(*preprocessing.EncodeRequest) ([]uint32, []tokenizers.Offset, error) Type() string } @@ -256,40 +256,33 @@ func parseHFCacheModelName(dirName string) (string, bool) { return strings.Join(parts, "/"), true } -type tokenizerProvider interface { - get(modelName string) (*tokenizers.Tokenizer, error) - - getFetchChatTemplateRequest(modelName string) (preprocessing.FetchChatTemplateRequest, error) -} - -// CachedTokenizer implements the Tokenizer interface for a specific model. -// It holds a single tokenizer instance that is initialized at creation time -// for the target model, providing efficient tokenization without caching overhead. type CachedTokenizer struct { - tokenizer *tokenizers.Tokenizer - tokenizerProvider tokenizerProvider chatTemplateRenderer *preprocessing.ChatTemplatingProcessor } +type HFCachedTokenizer struct { + CachedTokenizer + hfTokenizerConfig *HFTokenizerConfig +} +type LocalCachedTokenizer struct { + CachedTokenizer + localTokenizerConfig *LocalTokenizerConfig +} + // NewCachedHFTokenizer creates a new instance of CachedTokenizer downloading tokenizer configs from HuggingFace with // the provided configuration. -func NewCachedHFTokenizer(modelID string, config *HFTokenizerConfig) (Tokenizer, error) { - tokenizerProvider := newHFTokenizerProvider(config) - tokenizer, err := tokenizerProvider.get(modelID) - if err != nil { - return nil, fmt.Errorf("failed to get tokenizer for model %q: %w", modelID, err) - } - +func NewCachedHFTokenizer(config *HFTokenizerConfig) (*HFCachedTokenizer, error) { chatTemplateRenderer := preprocessing.NewChatTemplatingProcessor() - err = chatTemplateRenderer.Initialize() + err := chatTemplateRenderer.Initialize() if err != nil { return nil, fmt.Errorf("failed to initialize chat templater: %w", err) } - return &CachedTokenizer{ - tokenizer: tokenizer, - tokenizerProvider: tokenizerProvider, - chatTemplateRenderer: chatTemplateRenderer, + return &HFCachedTokenizer{ + CachedTokenizer: CachedTokenizer{ + chatTemplateRenderer: chatTemplateRenderer, + }, + hfTokenizerConfig: config, }, nil } @@ -301,69 +294,92 @@ func NewCachedHFTokenizer(modelID string, config *HFTokenizerConfig) (Tokenizer, // - Pre-loaded models in containerized deployments // - Reducing startup latency by avoiding downloads // -// The tokenizer is initialized for a specific model at creation time. -func NewCachedLocalTokenizer(modelName string, config LocalTokenizerConfig) (Tokenizer, error) { - if err := discoverLocalTokenizerMap(&config); err != nil { +// The tokenizer uses an LRU cache to keep frequently used tokenizers in memory, +// avoiding repeated file I/O for the same models. +func NewCachedLocalTokenizer(config *LocalTokenizerConfig) (*LocalCachedTokenizer, error) { + if err := discoverLocalTokenizerMap(config); err != nil { return nil, fmt.Errorf("failed to discover local tokenizer map: %w", err) } - tokenizerProvider := &localTokenizerProvider{ - cfg: config, - } - tokenizer, err := tokenizerProvider.get(modelName) - if err != nil { - return nil, fmt.Errorf("failed to get tokenizer for model map: %w", err) - } - chatTemplater := preprocessing.NewChatTemplatingProcessor() - err = chatTemplater.Initialize() + err := chatTemplater.Initialize() if err != nil { return nil, fmt.Errorf("failed to initialize chat templater: %w", err) } - return &CachedTokenizer{ - tokenizer: tokenizer, - tokenizerProvider: tokenizerProvider, - chatTemplateRenderer: chatTemplater, + return &LocalCachedTokenizer{ + CachedTokenizer: CachedTokenizer{ + chatTemplateRenderer: chatTemplater, + }, + localTokenizerConfig: config, }, nil } -func (t *CachedTokenizer) RenderChatTemplate( - modelName string, renderReq *preprocessing.RenderJinjaTemplateRequest, +func (t *HFCachedTokenizer) ApplyChatTemplate( + modelName string, req *preprocessing.ApplyChatTemplateRequest, ) (string, error) { ctx := context.TODO() - if renderReq.ChatTemplate == "" { - req, err := t.tokenizerProvider.getFetchChatTemplateRequest(modelName) - if err != nil { - return "", fmt.Errorf("failed to create fetch chat template request: %w", err) - } - renderReq.ChatTemplate, renderReq.ChatTemplateKWArgs, err = t.chatTemplateRenderer.FetchChatTemplate( - ctx, req, - ) - if err != nil { - return "", fmt.Errorf("failed to fetch chat template: %w", err) - } + req.IsLocal = false + req.DownloadDir = t.hfTokenizerConfig.TokenizersCacheDir + req.Token = t.hfTokenizerConfig.HuggingFaceToken + res, err := t.chatTemplateRenderer.ApplyChatTemplate(ctx, req) + if err != nil { + return "", fmt.Errorf("failed to render chat template: %w", err) } - res, err := t.chatTemplateRenderer.RenderChatTemplate(ctx, renderReq) + return res, nil +} + +func (t *LocalCachedTokenizer) ApplyChatTemplate( + modelName string, req *preprocessing.ApplyChatTemplateRequest, +) (string, error) { + ctx := context.TODO() + + req.IsLocal = true + path, ok := t.localTokenizerConfig.ModelTokenizerMap[req.Model] + if !ok { + return "", fmt.Errorf("tokenizer for model %q not found", modelName) + } + req.Model = filepath.Dir(path) + res, err := t.chatTemplateRenderer.ApplyChatTemplate(ctx, req) if err != nil { return "", fmt.Errorf("failed to render chat template: %w", err) } - return res.RenderedChats[0], nil + return res, nil } // Encode converts a string into token IDs. -// The modelName parameter is ignored since this tokenizer is bound to a specific model. -func (t *CachedTokenizer) Encode(input, _ string) ([]uint32, []tokenizers.Offset, error) { - encodeOptions := []tokenizers.EncodeOption{ - tokenizers.WithReturnTypeIDs(), - tokenizers.WithReturnOffsets(), +func (t *HFCachedTokenizer) Encode(req *preprocessing.EncodeRequest) ([]uint32, []tokenizers.Offset, error) { + ctx := context.TODO() + + req.IsLocal = false + req.DownloadDir = t.hfTokenizerConfig.TokenizersCacheDir + req.Token = t.hfTokenizerConfig.HuggingFaceToken + tokens, offsets, err := t.chatTemplateRenderer.Encode(ctx, req) + if err != nil { + return nil, nil, fmt.Errorf("failed to encode: %w", err) } - resp := t.tokenizer.EncodeWithOptions(input, false, encodeOptions...) - return resp.IDs, resp.Offsets, nil + return tokens, offsets, nil +} + +func (t *LocalCachedTokenizer) Encode(req *preprocessing.EncodeRequest) ([]uint32, []tokenizers.Offset, error) { + ctx := context.TODO() + + req.IsLocal = true + path, ok := t.localTokenizerConfig.ModelTokenizerMap[req.Model] + if !ok { + return nil, nil, fmt.Errorf("tokenizer for model %q not found", req.Model) + } + req.Model = filepath.Dir(path) + tokens, offsets, err := t.chatTemplateRenderer.Encode(ctx, req) + if err != nil { + return nil, nil, fmt.Errorf("failed to encode: %w", err) + } + + return tokens, offsets, nil } func (t *CachedTokenizer) Type() string { @@ -381,75 +397,6 @@ func getTokenizerCacheDir() string { return filepath.Join(base, "..", "..", "bin") } -// hfTokenizerProvider implements tokenizerProvider by downloading tokenizers from HuggingFace. -// It uses the HuggingFace tokenizers library to fetch tokenizer configurations from the HuggingFace Hub. -type hfTokenizerProvider struct { - cfgOpt tokenizers.TokenizerConfigOption - authToken string -} - -// newHFTokenizerProvider creates a new hfTokenizerProvider with the given configuration. -func newHFTokenizerProvider(config *HFTokenizerConfig) *hfTokenizerProvider { - var cfg tokenizers.TokenizerConfigOption - - if config != nil && config.TokenizersCacheDir != "" { - cfg = tokenizers.WithCacheDir(config.TokenizersCacheDir) - } - if config != nil && config.HuggingFaceToken != "" { - cfg = tokenizers.WithAuthToken(config.HuggingFaceToken) - } - - return &hfTokenizerProvider{ - cfgOpt: cfg, - authToken: config.HuggingFaceToken, - } -} - -// getTokenizer downloads and returns a tokenizer from HuggingFace for the specified model. -// The tokenizer is downloaded from https://huggingface.co/{modelName}. -func (p *hfTokenizerProvider) get(modelName string) (*tokenizers.Tokenizer, error) { - return tokenizers.FromPretrained(modelName, p.cfgOpt) -} - -func (p *hfTokenizerProvider) getFetchChatTemplateRequest(modelName string) (preprocessing.FetchChatTemplateRequest, error) { - return preprocessing.FetchChatTemplateRequest{ - Model: modelName, - Token: p.authToken, - IsLocalPath: false, - }, nil -} - -// localTokenizerProvider implements tokenizerProvider by loading tokenizers from local files. -// It looks up the tokenizer file path in the configuration mapping and loads it from disk. -type localTokenizerProvider struct { - cfg LocalTokenizerConfig -} - -// getTokenizer loads and returns a tokenizer from a local file for the specified model. -// It looks up the file path in the config mapping and loads the tokenizer file. -// Returns an error if the model name is not found in the mapping. -func (p *localTokenizerProvider) get(modelName string) (*tokenizers.Tokenizer, error) { - path, ok := p.cfg.ModelTokenizerMap[modelName] - if !ok { - return nil, fmt.Errorf("tokenizer for model %q not found", modelName) - } - return tokenizers.FromFile(path) -} - -func (p *localTokenizerProvider) getFetchChatTemplateRequest(modelName string) (preprocessing.FetchChatTemplateRequest, error) { - req := preprocessing.FetchChatTemplateRequest{ - IsLocalPath: true, - } - - path, ok := p.cfg.ModelTokenizerMap[modelName] - if !ok { - return req, fmt.Errorf("tokenizer for model %q not found", modelName) - } - req.Model = filepath.Dir(path) - - return req, nil -} - // CompositeTokenizer implements the Tokenizer interface with a fallback mechanism. // It tries each tokenizer in order until one succeeds. This allows for graceful // fallback from local tokenizers to HuggingFace tokenizers. @@ -471,18 +418,18 @@ type CompositeTokenizer struct { Tokenizers []Tokenizer } -func (c *CompositeTokenizer) RenderChatTemplate( - modelName string, renderReq *preprocessing.RenderJinjaTemplateRequest, +func (c *CompositeTokenizer) ApplyChatTemplate( + modelName string, req *preprocessing.ApplyChatTemplateRequest, ) (string, error) { var rErr error for _, tokenizer := range c.Tokenizers { - copiedRenderReq, err := renderReq.DeepCopy() + copiedReq, err := req.DeepCopy() if err != nil { rErr = multierr.Append(rErr, fmt.Errorf("failed to copy render request: %w", err)) continue } start := time.Now() - rendered, err := tokenizer.RenderChatTemplate(modelName, copiedRenderReq) + rendered, err := tokenizer.ApplyChatTemplate(modelName, copiedReq) metrics.RenderChatTemplateLatency.WithLabelValues(tokenizer.Type()).Observe(time.Since(start).Seconds()) if err != nil { rErr = multierr.Append(rErr, err) @@ -503,11 +450,16 @@ func (c *CompositeTokenizer) RenderChatTemplate( // 4. If all fail, returns all accumulated errors // // This enables prioritizing local tokenizers while maintaining HuggingFace as a fallback. -func (c *CompositeTokenizer) Encode(input, modelName string) ([]uint32, []tokenizers.Offset, error) { +func (c *CompositeTokenizer) Encode(req *preprocessing.EncodeRequest) ([]uint32, []tokenizers.Offset, error) { var rErr error for _, tokenizer := range c.Tokenizers { + copiedReq, err := req.DeepCopy() + if err != nil { + rErr = multierr.Append(rErr, fmt.Errorf("failed to copy encode request: %w", err)) + continue + } start := time.Now() - ids, offsets, err := tokenizer.Encode(input, modelName) + ids, offsets, err := tokenizer.Encode(copiedReq) metrics.TokenizationLatency.WithLabelValues(tokenizer.Type()).Observe(time.Since(start).Seconds()) if err != nil { rErr = multierr.Append(rErr, err) diff --git a/pkg/tokenization/tokenizer_test.go b/pkg/tokenization/tokenizer_test.go index 8822b6f03..17cabc586 100644 --- a/pkg/tokenization/tokenizer_test.go +++ b/pkg/tokenization/tokenizer_test.go @@ -36,13 +36,13 @@ type DummyTokenizer struct { returnError bool } -func (d *DummyTokenizer) RenderChatTemplate( - prompt string, renderReq *preprocessing.RenderJinjaTemplateRequest, +func (d *DummyTokenizer) ApplyChatTemplate( + prompt string, renderReq *preprocessing.ApplyChatTemplateRequest, ) (string, error) { return prompt, nil } -func (d *DummyTokenizer) Encode(input, modelName string) ([]uint32, []tokenizers.Offset, error) { +func (d *DummyTokenizer) Encode(*preprocessing.EncodeRequest) ([]uint32, []tokenizers.Offset, error) { if d.returnError { return nil, nil, fmt.Errorf("dummy tokenizer error") } @@ -61,27 +61,36 @@ func TestCachedHFTokenizer_Encode(t *testing.T) { config := &HFTokenizerConfig{ TokenizersCacheDir: t.TempDir(), } - tokenizer, err := NewCachedHFTokenizer(testModelName, config) + tokenizer, err := NewCachedHFTokenizer(config) require.NoError(t, err) require.NotNil(t, tokenizer) tests := []struct { - name string - input string + name string + input string + modelName string }{ { - name: "simple text", - input: "hello world", + name: "simple text", + input: "hello world", + modelName: testModelName, }, { - name: "empty string", - input: "", + name: "empty string", + input: "", + modelName: testModelName, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - tokenIds, offsets, err := tokenizer.Encode(tt.input, testModelName) + tokenIds, offsets, err := tokenizer.Encode( + &preprocessing.EncodeRequest{ + ChatTemplateRequest: preprocessing.ChatTemplateRequest{ + Model: tt.modelName, + }, + Text: tt.input, + }) assert.NoError(t, err) assert.GreaterOrEqual(t, len(tokenIds), 0) @@ -95,7 +104,7 @@ func TestCachedHFTokenizer_CacheTokenizer(t *testing.T) { t.Skip("Skipping tokenizer integration test in short mode") } - tokenizer, err := NewCachedHFTokenizer(testModelName, &HFTokenizerConfig{ + tokenizer, err := NewCachedHFTokenizer(&HFTokenizerConfig{ TokenizersCacheDir: t.TempDir(), }) require.NoError(t, err) @@ -105,11 +114,21 @@ func TestCachedHFTokenizer_CacheTokenizer(t *testing.T) { input := "test input" // First call - loads tokenizer - tokenIds1, offsets1, err1 := tokenizer.Encode(input, testModelName) + tokenIds1, offsets1, err1 := tokenizer.Encode(&preprocessing.EncodeRequest{ + ChatTemplateRequest: preprocessing.ChatTemplateRequest{ + Model: testModelName, + }, + Text: input, + }) require.NoError(t, err1) // Second call - should use cached tokenizer - tokenIds2, offsets2, err2 := tokenizer.Encode(input, testModelName) + tokenIds2, offsets2, err2 := tokenizer.Encode(&preprocessing.EncodeRequest{ + ChatTemplateRequest: preprocessing.ChatTemplateRequest{ + Model: testModelName, + }, + Text: input, + }) require.NoError(t, err2) // Results should be identical @@ -122,23 +141,32 @@ func TestCachedHFTokenizer_InvalidModel(t *testing.T) { t.Skip("Skipping tokenizer integration test in short mode") } - tokenizer, err := NewCachedHFTokenizer("non-existent/model", &HFTokenizerConfig{ + tokenizer, err := NewCachedHFTokenizer(&HFTokenizerConfig{ TokenizersCacheDir: t.TempDir(), }) + require.NoError(t, err) + require.NotNil(t, tokenizer) - // Assert that an error occurred and tokenizer is nil + // Test with non-existent model + tokenIds, offsets, err := tokenizer.Encode(&preprocessing.EncodeRequest{ + ChatTemplateRequest: preprocessing.ChatTemplateRequest{ + Model: "non-existent/model", + }, + Text: "test", + }) assert.Error(t, err) - assert.Nil(t, tokenizer) + assert.Nil(t, tokenIds) + assert.Nil(t, offsets) } func TestCachedLocalTokenizer_Encode(t *testing.T) { modelName := "test-model" - config := LocalTokenizerConfig{ + config := &LocalTokenizerConfig{ ModelTokenizerMap: map[string]string{ modelName: "testdata/test-model/tokenizer.json", }, } - tokenizer, err := NewCachedLocalTokenizer(modelName, config) + tokenizer, err := NewCachedLocalTokenizer(config) require.NoError(t, err) require.NotNil(t, tokenizer) @@ -161,7 +189,12 @@ func TestCachedLocalTokenizer_Encode(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - tokenIds, offsets, err := tokenizer.Encode(tt.input, tt.modelName) + tokenIds, offsets, err := tokenizer.Encode(&preprocessing.EncodeRequest{ + ChatTemplateRequest: preprocessing.ChatTemplateRequest{ + Model: tt.modelName, + }, + Text: tt.input, + }) assert.NoError(t, err) assert.GreaterOrEqual(t, len(tokenIds), 0) @@ -171,28 +204,38 @@ func TestCachedLocalTokenizer_Encode(t *testing.T) { } func TestCachedLocalTokenizer_InvalidModel(t *testing.T) { - modelName := "test-model" invalidModelName := "non-existent-model" - config := LocalTokenizerConfig{ + config := &LocalTokenizerConfig{ ModelTokenizerMap: map[string]string{ - modelName: "testdata/test-model/tokenizer.json", + invalidModelName: "testdata/test-model/tokenizer.json", }, } - tokenizer, err := NewCachedLocalTokenizer(invalidModelName, config) - require.Error(t, err) - require.Nil(t, tokenizer) + tokenizer, err := NewCachedLocalTokenizer(config) + require.NoError(t, err) + require.NotNil(t, tokenizer) } func TestCachedLocalTokenizer_InvalidPath(t *testing.T) { modelName := "invalid-model" - config := LocalTokenizerConfig{ + config := &LocalTokenizerConfig{ ModelTokenizerMap: map[string]string{ modelName: "testdata/non-existent/tokenizer.json", }, } - tokenizer, err := NewCachedLocalTokenizer(modelName, config) - require.Error(t, err) - require.Nil(t, tokenizer) + tokenizer, err := NewCachedLocalTokenizer(config) + require.NoError(t, err) + require.NotNil(t, tokenizer) + + // Test with model that points to non-existent file + tokenIds, offsets, err := tokenizer.Encode(&preprocessing.EncodeRequest{ + ChatTemplateRequest: preprocessing.ChatTemplateRequest{ + Model: "invalid-model", + }, + Text: "test", + }) + assert.Error(t, err) + assert.Nil(t, tokenIds) + assert.Nil(t, offsets) } func TestCompositeTokenizer_FallbackBehavior(t *testing.T) { @@ -201,7 +244,7 @@ func TestCompositeTokenizer_FallbackBehavior(t *testing.T) { } dummyTokenizer := &DummyTokenizer{returnError: true} - hfTokenizer, err := NewCachedHFTokenizer(testModelName, &HFTokenizerConfig{ + hfTokenizer, err := NewCachedHFTokenizer(&HFTokenizerConfig{ TokenizersCacheDir: t.TempDir(), }) @@ -210,8 +253,13 @@ func TestCompositeTokenizer_FallbackBehavior(t *testing.T) { composite := &CompositeTokenizer{ Tokenizers: []Tokenizer{dummyTokenizer, hfTokenizer}, } + tokenIds, offsets, err := composite.Encode(&preprocessing.EncodeRequest{ + ChatTemplateRequest: preprocessing.ChatTemplateRequest{ + Model: testModelName, + }, + Text: "hello world", + }) - tokenIds, offsets, err := composite.Encode("hello world", testModelName) assert.NoError(t, err) assert.GreaterOrEqual(t, len(tokenIds), 0) assert.Equal(t, len(tokenIds), len(offsets)) diff --git a/pkg/tokenization/uds_tokenizer.go b/pkg/tokenization/uds_tokenizer.go index d9cc9e1ba..405d3a0bb 100644 --- a/pkg/tokenization/uds_tokenizer.go +++ b/pkg/tokenization/uds_tokenizer.go @@ -105,18 +105,18 @@ func NewUdsTokenizer(config *UdsTokenizerConfig) (Tokenizer, error) { } // Encode tokenizes the input string and returns the token IDs and offsets. -func (u *UdsTokenizer) Encode(input, modelName string) ([]uint32, []tokenizers.Offset, error) { - req, err := http.NewRequestWithContext( +func (u *UdsTokenizer) Encode(req *preprocessing.EncodeRequest) ([]uint32, []tokenizers.Offset, error) { + httpReq, err := http.NewRequestWithContext( context.Background(), http.MethodPost, u.baseURL+"/tokenize", - strings.NewReader(input), + strings.NewReader(req.Text), ) if err != nil { return nil, nil, fmt.Errorf("failed to create request: %w", err) } - respBody, err := u.executeRequest(req, defaultTimeout, defaultMaxRetries) + respBody, err := u.executeRequest(httpReq, defaultTimeout, defaultMaxRetries) if err != nil { return nil, nil, fmt.Errorf("tokenize request failed: %w", err) } @@ -129,11 +129,11 @@ func (u *UdsTokenizer) Encode(input, modelName string) ([]uint32, []tokenizers.O return tokenized.InputIDs, tokenized.OffsetMapping, nil } -// RenderChatTemplate renders a chat template using the UDS tokenizer service. -func (u *UdsTokenizer) RenderChatTemplate( - _ string, renderReq *preprocessing.RenderJinjaTemplateRequest, +// ApplyChatTemplate renders a chat template using the UDS tokenizer service. +func (u *UdsTokenizer) ApplyChatTemplate( + _ string, renderReq *preprocessing.ApplyChatTemplateRequest, ) (string, error) { - messagesBytes, err := json.Marshal(renderReq.Conversations) + messagesBytes, err := json.Marshal(renderReq.Conversation) if err != nil { return "", fmt.Errorf("failed to marshal chat-completions messages: %w", err) } diff --git a/tests/e2e/redis_mock/e2e_suite_test.go b/tests/e2e/redis_mock/e2e_suite_test.go index 09adc4341..393d2ad53 100644 --- a/tests/e2e/redis_mock/e2e_suite_test.go +++ b/tests/e2e/redis_mock/e2e_suite_test.go @@ -19,11 +19,13 @@ package e2e import ( "context" + "path/filepath" "testing" "github.com/go-logr/logr/testr" "github.com/llm-d/llm-d-kv-cache/pkg/kvcache" "github.com/llm-d/llm-d-kv-cache/pkg/kvcache/kvblock" + preprocessing "github.com/llm-d/llm-d-kv-cache/pkg/preprocessing/chat_completions" "github.com/llm-d/llm-d-kv-cache/pkg/tokenization" "github.com/llm-d/llm-d-kv-cache/pkg/utils" "github.com/stretchr/testify/suite" @@ -68,15 +70,26 @@ func (s *KVCacheSuite) SetupTest() { s.config, err = kvcache.NewDefaultConfig() s.Require().NoError(err) - s.config.TokenizersPoolConfig.ModelName = defaultModelName s.config.PrefixStoreConfig.BlockSize = 4 s.config.TokenProcessorConfig.BlockSize = 4 - hfTokenizer, err := tokenization.NewCachedHFTokenizer(defaultModelName, s.config.TokenizersPoolConfig.HFTokenizerConfig) + // Configure the indexer's tokenization pool to support local models + // This is needed because GetPodScores uses the indexer's internal pool for tokenization + testDataPath, err := filepath.Abs("testdata") + s.Require().NoError(err) + + s.config.TokenizersPoolConfig.LocalTokenizerConfig.AutoDiscoveryDir = testDataPath + + localTokenizer, err := tokenization.NewCachedLocalTokenizer(s.config.TokenizersPoolConfig.LocalTokenizerConfig) + s.Require().NoError(err) + + hfTokenizer, err := tokenization.NewCachedHFTokenizer(s.config.TokenizersPoolConfig.HFTokenizerConfig) s.Require().NoError(err) // Use composite tokenizer: try local first, then fall back to HF - s.tokenizer = hfTokenizer + s.tokenizer = &tokenization.CompositeTokenizer{ + Tokenizers: []tokenization.Tokenizer{localTokenizer, hfTokenizer}, + } s.tokensProcessor = kvblock.NewChunkedTokenDatabase(s.config.TokenProcessorConfig) @@ -96,7 +109,11 @@ func (s *KVCacheSuite) SetupTest() { func (s *KVCacheSuite) promptToEngineAndRequestKeys( prompt, model string, ) (engineKeys, requestKeys []kvblock.Key) { - tokens, _, err := s.tokenizer.Encode(prompt, model) + tokens, _, err := s.tokenizer.Encode(&preprocessing.EncodeRequest{ + ChatTemplateRequest: preprocessing.ChatTemplateRequest{Model: model}, + Text: prompt, + AddSpecialTokens: true, + }) s.Require().NoError(err) requestKeys = s.tokensProcessor.TokensToKVBlockKeys(nil, tokens, model) @@ -122,11 +139,6 @@ func (s *KVCacheSuite) addEntriesToIndex(engineKeys, requestKeys []kvblock.Key, s.Require().NoError(err) } -func (s *KVCacheSuite) SetTokenizer(tokenizer tokenization.Tokenizer, modelName string) { - s.tokenizer = tokenizer - s.indexer.SetTokenizer(tokenizer, modelName) -} - // TestKVCacheSuite runs the KVCacheSuite using testify's suite runner. func TestKVCacheSuite(t *testing.T) { suite.Run(t, new(KVCacheSuite)) diff --git a/tests/e2e/redis_mock/e2e_test.go b/tests/e2e/redis_mock/e2e_test.go index d3f3f9899..830fecb8d 100644 --- a/tests/e2e/redis_mock/e2e_test.go +++ b/tests/e2e/redis_mock/e2e_test.go @@ -28,6 +28,11 @@ import ( "github.com/llm-d/llm-d-kv-cache/pkg/tokenization" ) +const ( + localTestModelDir = "testdata/test-model" + localLlama3ModelDir = "testdata/local-llama3" +) + // ChatMessage represents a single message in a conversation. type ChatMessage struct { Role string `json:"role"` @@ -36,29 +41,14 @@ type ChatMessage struct { // ChatTemplateRequest represents the request to render a chat template. type ChatTemplateRequest struct { - Conversations [][]ChatMessage `json:"conversations"` - ChatTemplate string `json:"chatTemplate"` - TemplateVars map[string]interface{} `json:"templateVars,omitempty"` -} - -// ChatTemplateResponse represents the response from the Python function. -type ChatTemplateResponse struct { - RenderedChats []string `json:"renderedChats"` - GenerationIndices [][][]int `json:"generationIndices"` -} - -// GetChatTemplateRequest represents the request to get a model's chat template. -type GetChatTemplateRequest struct { - ModelName string `json:"modelName"` - Revision string `json:"revision,omitempty"` - Token string `json:"token,omitempty"` + Conversation [][]ChatMessage `json:"conversation"` } // convertToPreprocessingChatMessages converts e2e ChatMessage to preprocessing ChatMessage. -func convertToPreprocessingChatMessages(messages []ChatMessage) []preprocessing.ChatMessage { - result := make([]preprocessing.ChatMessage, len(messages)) +func convertToPreprocessingChatMessages(messages []ChatMessage) []preprocessing.Conversation { + result := make([]preprocessing.Conversation, len(messages)) for i, msg := range messages { - result[i] = preprocessing.ChatMessage{ + result[i] = preprocessing.Conversation{ Role: msg.Role, Content: msg.Content, } @@ -73,24 +63,10 @@ func NewMockChatTemplateWrapper() *MockChatTemplateWrapper { return &MockChatTemplateWrapper{} } -//nolint:nonamedreturns // Mock implementation uses named returns for clarity and consistency with interface. -func (w *MockChatTemplateWrapper) GetModelChatTemplate( - req GetChatTemplateRequest, -) (template string, templateVars map[string]interface{}, err error) { - // Mock implementation that returns a simple template. - template = `{% for message in messages %}{{ message.role }}: {{ message.content }} -{% endfor %}` - templateVars = map[string]interface{}{ - "bos_token": "", - "eos_token": "", - } - return template, templateVars, nil -} - -func (w *MockChatTemplateWrapper) RenderChatTemplate(req ChatTemplateRequest) (*ChatTemplateResponse, error) { +func (w *MockChatTemplateWrapper) ApplyChatTemplate(req *ChatTemplateRequest) (string, error) { // Mock implementation that renders the template. - renderedChats := make([]string, 0, len(req.Conversations)) - for _, conversation := range req.Conversations { + renderedChats := make([]string, 0, len(req.Conversation)) + for _, conversation := range req.Conversation { rendered := "" for _, message := range conversation { rendered += message.Role + ": " + message.Content + "\n" @@ -98,10 +74,7 @@ func (w *MockChatTemplateWrapper) RenderChatTemplate(req ChatTemplateRequest) (* renderedChats = append(renderedChats, rendered) } - return &ChatTemplateResponse{ - RenderedChats: renderedChats, - GenerationIndices: [][][]int{}, - }, nil + return strings.Join(renderedChats, "\n"), nil } // TestBasicE2E verifies that the indexer initially returns no scores for the first prompt and @@ -257,30 +230,15 @@ func (s *KVCacheSuite) TestChatCompletionsE2E() { }, } - // Step 1: Get the model's chat template. - templateRequest := GetChatTemplateRequest{ - ModelName: "ibm-granite/granite-3.3-8b-instruct", + // Step 1: Render the conversation using the template. + renderRequest := &ChatTemplateRequest{ + Conversation: conversation, } - template, templateVars, err := wrapper.GetModelChatTemplate(templateRequest) - s.Require().NoError(err, "Failed to get model chat template") - s.Require().NotEmpty(template, "ChatTemplate should not be empty") - - // Step 2: Render the conversation using the template. - renderRequest := ChatTemplateRequest{ - Conversations: conversation, - ChatTemplate: template, - TemplateVars: templateVars, - } - response, err := wrapper.RenderChatTemplate(renderRequest) + flattenedPrompt, err := wrapper.ApplyChatTemplate(renderRequest) s.Require().NoError(err, "Failed to render chat template") - s.Require().NotNil(response, "Response should not be nil") - s.Require().NotEmpty(response.RenderedChats, "Rendered chats should not be empty") + s.Require().NotEmpty(flattenedPrompt, "Response should not be empty") - // Step 3: Extract the flattened prompt from the rendered template. - flattenedPrompt := response.RenderedChats[0] - s.Require().NotEmpty(flattenedPrompt, "Flattened prompt should not be empty") - - // Step 4: Use the flattened prompt for KV-cache lookup (similar to TestBasicE2E). + // Step 2: Use the flattened prompt for KV-cache lookup (similar to TestBasicE2E). engineKeys, requestKeys := s.promptToEngineAndRequestKeys(flattenedPrompt, "ibm-granite/granite-3.3-8b-instruct") fakePodList := []string{s.Pod1IP} @@ -330,31 +288,16 @@ func (s *KVCacheSuite) TestLongChatCompletionsE2E() { }, } - // Step 1: Get the model's chat template. - templateRequest := GetChatTemplateRequest{ - ModelName: "ibm-granite/granite-3.3-8b-instruct", - } - template, templateVars, err := wrapper.GetModelChatTemplate(templateRequest) - s.Require().NoError(err, "Failed to get model chat template") - s.Require().NotEmpty(template, "ChatTemplate should not be empty") - - // Step 2: Render the long conversation. - renderRequest := ChatTemplateRequest{ - Conversations: longConversation, - ChatTemplate: template, - TemplateVars: templateVars, + // Step 1: Render the long conversation. + renderRequest := &ChatTemplateRequest{ + Conversation: longConversation, } - response, err := wrapper.RenderChatTemplate(renderRequest) + flattenedPrompt, err := wrapper.ApplyChatTemplate(renderRequest) s.Require().NoError(err, "Failed to render long conversation") - s.Require().NotNil(response, "Response should not be nil") - s.Require().NotEmpty(response.RenderedChats, "Rendered chats should not be empty") - - // Step 3: Extract the flattened prompt. - flattenedPrompt := response.RenderedChats[0] s.Require().NotEmpty(flattenedPrompt, "Flattened prompt should not be empty") s.Require().Greater(len(flattenedPrompt), 1000, "Long conversation should produce substantial output") - // Step 4: Test KV-cache with the long flattened prompt. + // Step 3: Test KV-cache with the long flattened prompt. engineKeys, requestKeys := s.promptToEngineAndRequestKeys(flattenedPrompt, "ibm-granite/granite-3.3-8b-instruct") fakePodList := []string{s.Pod1IP} @@ -381,21 +324,26 @@ func (s *KVCacheSuite) TestLongChatCompletionsE2E() { func (s *KVCacheSuite) TestCacheHitWithLocalTokenizer() { // Create a local tokenizer using the testdata modelName := "test-model" - localTokenizer, err := tokenization.NewCachedLocalTokenizer(modelName, tokenization.LocalTokenizerConfig{ + localTokenizer, err := tokenization.NewCachedLocalTokenizer(&tokenization.LocalTokenizerConfig{ ModelTokenizerMap: map[string]string{ - modelName: "testdata/test-model/tokenizer.json", + modelName: localTestModelDir + "/tokenizer.json", }, }) s.Require().NoError(err) s.Require().NotNil(localTokenizer) - s.SetTokenizer(localTokenizer, modelName) - prompt := "What is the capital of France?" fakePodList := []string{s.Pod1IP} // Tokenize using local tokenizer - tokens, offsets, err := localTokenizer.Encode(prompt, modelName) + tokens, offsets, err := localTokenizer.Encode(&preprocessing.EncodeRequest{ + ChatTemplateRequest: preprocessing.ChatTemplateRequest{ + Model: modelName, + IsLocal: true, + }, + Text: prompt, + AddSpecialTokens: true, + }) s.Require().NoError(err) s.Require().NotEmpty(tokens) s.Require().Equal(len(tokens), len(offsets), "tokens and offsets should have same length") @@ -415,7 +363,14 @@ func (s *KVCacheSuite) TestCacheHitWithLocalTokenizer() { s.T().Logf("GetPodScores returned score: %v", pods[s.Pod1IP]) // Also verify that tokenizing the same prompt again produces same block keys - tokens2, _, err := localTokenizer.Encode(prompt, modelName) + tokens2, _, err := localTokenizer.Encode(&preprocessing.EncodeRequest{ + ChatTemplateRequest: preprocessing.ChatTemplateRequest{ + Model: modelName, + IsLocal: true, + }, + Text: prompt, + AddSpecialTokens: true, + }) s.Require().NoError(err) requestKeys2 := s.tokensProcessor.TokensToKVBlockKeys(nil, tokens2, modelName) s.Require().Equal(requestKeys, requestKeys2, "same prompt should produce same block keys") @@ -434,43 +389,51 @@ func (s *KVCacheSuite) TestHFCacheStructureDiscoveryE2E() { testModelPath := filepath.Join(tmpDir, "models--test-org--test-model", "snapshots", "abc123") require.NoError(s.T(), os.MkdirAll(testModelPath, 0o755)) - // Copy the test tokenizer file - srcTokenizer := "testdata/test-model/tokenizer.json" - dstTokenizer := filepath.Join(testModelPath, "tokenizer.json") - srcData, err := os.ReadFile(srcTokenizer) - require.NoError(s.T(), err) - require.NoError(s.T(), os.WriteFile(dstTokenizer, srcData, 0o600)) + // Copy the test tokenizer + require.NoError(s.T(), os.CopyFS(testModelPath, os.DirFS(localTestModelDir))) // Create tokenizer config with auto-discovery - config := tokenization.LocalTokenizerConfig{ + config := &tokenization.LocalTokenizerConfig{ AutoDiscoveryDir: tmpDir, AutoDiscoveryTokenizerFileName: "tokenizer.json", } - localTokenizer, err := tokenization.NewCachedLocalTokenizer(modelName, config) + localTokenizer, err := tokenization.NewCachedLocalTokenizer(config) s.Require().NoError(err) s.Require().NotNil(localTokenizer) - s.SetTokenizer(localTokenizer, modelName) - prompt := "What is the capital of France?" fakePodList := []string{s.Pod1IP} // Tokenize using the auto-discovered HF cache tokenizer - tokens, offsets, err := localTokenizer.Encode(prompt, modelName) + tokens, offsets, err := localTokenizer.Encode(&preprocessing.EncodeRequest{ + ChatTemplateRequest: preprocessing.ChatTemplateRequest{ + Model: modelName, + IsLocal: true, + }, + Text: prompt, + AddSpecialTokens: true, + }) s.Require().NoError(err) s.Require().NotEmpty(tokens) s.Require().Equal(len(tokens), len(offsets), "tokens and offsets should have same length") s.T().Logf("HF cache auto-discovery produced %d tokens for model %q", len(tokens), modelName) - // Convert tokens to KV block keys using promptToEngineAndRequestKeys with local tokenizer - engineKeys1, requestKeys := s.promptToEngineAndRequestKeys(prompt, modelName) + // Convert tokens to KV block keys + requestKeys := s.tokensProcessor.TokensToKVBlockKeys(nil, tokens, modelName) + s.Require().NotEmpty(requestKeys) // Add entries to the index - s.addEntriesToIndex(engineKeys1, requestKeys, fakePodList) - + s.addEntriesToIndex(requestKeys, requestKeys, fakePodList) // Verify retrieval - tokens2, _, err := localTokenizer.Encode(prompt, modelName) + tokens2, _, err := localTokenizer.Encode(&preprocessing.EncodeRequest{ + ChatTemplateRequest: preprocessing.ChatTemplateRequest{ + Model: modelName, + IsLocal: true, + }, + Text: prompt, + AddSpecialTokens: true, + }) s.Require().NoError(err) requestKeys2 := s.tokensProcessor.TokensToKVBlockKeys(nil, tokens2, modelName) s.Require().Equal(requestKeys, requestKeys2, "same prompt should produce same block keys") @@ -488,12 +451,12 @@ func (s *KVCacheSuite) TestLocalTokenizerChatTemplateE2E() { }{ { name: "test-model", - modelDir: "testdata/test-model", + modelDir: localTestModelDir, modelName: "test-model", }, { name: "local-llama3", - modelDir: "testdata/local-llama3", + modelDir: localLlama3ModelDir, modelName: "local-llama3", }, } @@ -504,7 +467,7 @@ func (s *KVCacheSuite) TestLocalTokenizerChatTemplateE2E() { testModelDir, err := filepath.Abs(tc.modelDir) s.Require().NoError(err) - localTokenizer, err := tokenization.NewCachedLocalTokenizer(tc.modelName, tokenization.LocalTokenizerConfig{ + localTokenizer, err := tokenization.NewCachedLocalTokenizer(&tokenization.LocalTokenizerConfig{ ModelTokenizerMap: map[string]string{ tc.modelName: filepath.Join(testModelDir, "tokenizer.json"), }, @@ -512,8 +475,6 @@ func (s *KVCacheSuite) TestLocalTokenizerChatTemplateE2E() { s.Require().NoError(err) s.Require().NotNil(localTokenizer) - s.SetTokenizer(localTokenizer, tc.modelName) - // Test conversation conversation := []ChatMessage{ {Role: "user", Content: "What is machine learning?"}, @@ -523,11 +484,15 @@ func (s *KVCacheSuite) TestLocalTokenizerChatTemplateE2E() { // Step 1: Render the conversation into a flattened prompt using local chat template // This tests the full integration: Go -> CGO -> Python -> Local Tokenizer - renderReq := &preprocessing.RenderJinjaTemplateRequest{ - Conversations: convertToPreprocessingChatMessages(conversation), + renderReq := &preprocessing.ApplyChatTemplateRequest{ + ChatTemplateRequest: preprocessing.ChatTemplateRequest{ + IsLocal: true, + Model: tc.modelName, + }, + Conversation: convertToPreprocessingChatMessages(conversation), } - renderedPrompt, err := localTokenizer.RenderChatTemplate(tc.modelName, renderReq) - s.Require().NoError(err, "RenderChatTemplate should succeed with local tokenizer") + renderedPrompt, err := localTokenizer.ApplyChatTemplate(tc.modelName, renderReq) + s.Require().NoError(err, "ApplyChatTemplate should succeed with local tokenizer") s.Require().NotEmpty(renderedPrompt, "Rendered prompt should not be empty") s.T().Logf("Rendered prompt from local template:\n%s", renderedPrompt) @@ -537,7 +502,14 @@ func (s *KVCacheSuite) TestLocalTokenizerChatTemplateE2E() { s.Require().Contains(renderedPrompt, "Give me an example", "rendered prompt should contain second user message") // Step 2: Tokenize the rendered prompt using the same local tokenizer - tokens, offsets, err := localTokenizer.Encode(renderedPrompt, tc.modelName) + tokens, offsets, err := localTokenizer.Encode(&preprocessing.EncodeRequest{ + ChatTemplateRequest: preprocessing.ChatTemplateRequest{ + IsLocal: true, + Model: tc.modelName, + }, + Text: renderedPrompt, + AddSpecialTokens: true, + }) s.Require().NoError(err, "Encode should succeed") s.Require().NotEmpty(tokens, "Tokens should not be empty") s.Require().Equal(len(tokens), len(offsets), "Tokens and offsets should have same length") @@ -558,14 +530,25 @@ func (s *KVCacheSuite) TestLocalTokenizerChatTemplateE2E() { s.T().Logf("GetPodScores returned score: %v for rendered chat template", pods[s.Pod1IP]) // Also verify by rendering and tokenizing the same conversation again - renderReq2 := &preprocessing.RenderJinjaTemplateRequest{ - Conversations: convertToPreprocessingChatMessages(conversation), + renderReq2 := &preprocessing.ApplyChatTemplateRequest{ + ChatTemplateRequest: preprocessing.ChatTemplateRequest{ + IsLocal: true, + Model: tc.modelName, + }, + Conversation: convertToPreprocessingChatMessages(conversation), } - renderedPrompt2, err := localTokenizer.RenderChatTemplate(tc.modelName, renderReq2) + renderedPrompt2, err := localTokenizer.ApplyChatTemplate(tc.modelName, renderReq2) s.Require().NoError(err) s.Require().Equal(renderedPrompt, renderedPrompt2, "Same conversation should render identically") - tokens2, _, err := localTokenizer.Encode(renderedPrompt2, tc.modelName) + tokens2, _, err := localTokenizer.Encode(&preprocessing.EncodeRequest{ + ChatTemplateRequest: preprocessing.ChatTemplateRequest{ + IsLocal: true, + Model: tc.modelName, + }, + Text: renderedPrompt2, + AddSpecialTokens: true, + }) s.Require().NoError(err) requestKeys2 := s.tokensProcessor.TokensToKVBlockKeys(nil, tokens2, tc.modelName) s.Require().Equal(requestKeys, requestKeys2, "Same conversation should produce same block keys") @@ -584,12 +567,12 @@ func (s *KVCacheSuite) TestLocalTokenizerChatTemplateMultiTurnE2E() { }{ { name: "test-model", - modelDir: "testdata/test-model", + modelDir: localTestModelDir, modelName: "test-model", }, { name: "local-llama3", - modelDir: "testdata/local-llama3", + modelDir: localLlama3ModelDir, modelName: "local-llama3", }, } @@ -599,15 +582,13 @@ func (s *KVCacheSuite) TestLocalTokenizerChatTemplateMultiTurnE2E() { testModelDir, err := filepath.Abs(tc.modelDir) s.Require().NoError(err) - localTokenizer, err := tokenization.NewCachedLocalTokenizer(tc.modelName, tokenization.LocalTokenizerConfig{ + localTokenizer, err := tokenization.NewCachedLocalTokenizer(&tokenization.LocalTokenizerConfig{ ModelTokenizerMap: map[string]string{ tc.modelName: filepath.Join(testModelDir, "tokenizer.json"), }, }) s.Require().NoError(err) - s.SetTokenizer(localTokenizer, tc.modelName) - fakePodList := []string{s.Pod1IP} // Start with a short conversation @@ -618,13 +599,24 @@ func (s *KVCacheSuite) TestLocalTokenizerChatTemplateMultiTurnE2E() { } // Render and cache the short conversation - shortReq := &preprocessing.RenderJinjaTemplateRequest{ - Conversations: convertToPreprocessingChatMessages(shortConversation), + shortReq := &preprocessing.ApplyChatTemplateRequest{ + ChatTemplateRequest: preprocessing.ChatTemplateRequest{ + IsLocal: true, + Model: tc.modelName, + }, + Conversation: convertToPreprocessingChatMessages(shortConversation), } - shortPrompt, err := localTokenizer.RenderChatTemplate(tc.modelName, shortReq) + shortPrompt, err := localTokenizer.ApplyChatTemplate(tc.modelName, shortReq) s.Require().NoError(err) s.T().Logf("Short prompt length: %d chars", len(shortPrompt)) - shortTokens, _, err := localTokenizer.Encode(shortPrompt, tc.modelName) + shortTokens, _, err := localTokenizer.Encode(&preprocessing.EncodeRequest{ + ChatTemplateRequest: preprocessing.ChatTemplateRequest{ + IsLocal: true, + Model: tc.modelName, + }, + Text: shortPrompt, + AddSpecialTokens: true, + }) s.Require().NoError(err) shortEngineKeys, shortRequestKeys := s.promptToEngineAndRequestKeys(shortPrompt, tc.modelName) s.addEntriesToIndex(shortEngineKeys, shortRequestKeys, fakePodList) @@ -645,15 +637,26 @@ func (s *KVCacheSuite) TestLocalTokenizerChatTemplateMultiTurnE2E() { } // Render and test the extended conversation - extendedReq := &preprocessing.RenderJinjaTemplateRequest{ - Conversations: convertToPreprocessingChatMessages(extendedConversation), + extendedReq := &preprocessing.ApplyChatTemplateRequest{ + ChatTemplateRequest: preprocessing.ChatTemplateRequest{ + IsLocal: true, + Model: tc.modelName, + }, + Conversation: convertToPreprocessingChatMessages(extendedConversation), } - extendedPrompt, err := localTokenizer.RenderChatTemplate(tc.modelName, extendedReq) + extendedPrompt, err := localTokenizer.ApplyChatTemplate(tc.modelName, extendedReq) s.Require().NoError(err) s.T().Logf("Extended prompt: %q (length: %d)", extendedPrompt, len(extendedPrompt)) s.Require().Greater(len(extendedPrompt), len(shortPrompt), "Extended conversation should be longer") - extendedTokens, _, err := localTokenizer.Encode(extendedPrompt, tc.modelName) + extendedTokens, _, err := localTokenizer.Encode(&preprocessing.EncodeRequest{ + ChatTemplateRequest: preprocessing.ChatTemplateRequest{ + IsLocal: true, + Model: tc.modelName, + }, + Text: extendedPrompt, + AddSpecialTokens: true, + }) s.Require().NoError(err) extendedEngineKeys, extendedRequestKeys := s.promptToEngineAndRequestKeys(extendedPrompt, tc.modelName) @@ -703,12 +706,12 @@ func (s *KVCacheSuite) TestLocalVsHFChatTemplateConsistency() { }{ { name: "test-model", - modelDir: "testdata/test-model", + modelDir: localTestModelDir, modelName: "test-model", }, { name: "local-llama3", - modelDir: "testdata/local-llama3", + modelDir: localLlama3ModelDir, modelName: "local-llama3", }, } @@ -728,30 +731,39 @@ func (s *KVCacheSuite) TestLocalVsHFChatTemplateConsistency() { s.Require().FileExists(filepath.Join(testModelDir, "config.json"), "config.json should exist") s.Require().FileExists(filepath.Join(testModelDir, "tokenizer.json"), "tokenizer.json should exist") - localTokenizer, err := tokenization.NewCachedLocalTokenizer(tc.modelName, tokenization.LocalTokenizerConfig{ + localTokenizer, err := tokenization.NewCachedLocalTokenizer(&tokenization.LocalTokenizerConfig{ ModelTokenizerMap: map[string]string{ tc.modelName: filepath.Join(testModelDir, "tokenizer.json"), }, }) s.Require().NoError(err) - s.SetTokenizer(localTokenizer, tc.modelName) - conversation := []ChatMessage{ {Role: "user", Content: "Test message"}, {Role: "assistant", Content: "Test response"}, } // Render with local tokenizer - req1 := &preprocessing.RenderJinjaTemplateRequest{ - Conversations: convertToPreprocessingChatMessages(conversation), + req1 := &preprocessing.ApplyChatTemplateRequest{ + ChatTemplateRequest: preprocessing.ChatTemplateRequest{ + IsLocal: true, + Model: tc.modelName, + }, + Conversation: convertToPreprocessingChatMessages(conversation), } - localRendered, err := localTokenizer.RenderChatTemplate(tc.modelName, req1) + localRendered, err := localTokenizer.ApplyChatTemplate(tc.modelName, req1) s.Require().NoError(err) s.Require().NotEmpty(localRendered) // Tokenize with local tokenizer - localTokens, _, err := localTokenizer.Encode(localRendered, tc.modelName) + localTokens, _, err := localTokenizer.Encode(&preprocessing.EncodeRequest{ + ChatTemplateRequest: preprocessing.ChatTemplateRequest{ + IsLocal: true, + Model: tc.modelName, + }, + Text: localRendered, + AddSpecialTokens: true, + }) s.Require().NoError(err) s.T().Logf("Local tokenizer: rendered=%d chars, tokens=%d", len(localRendered), len(localTokens)) @@ -767,16 +779,27 @@ func (s *KVCacheSuite) TestLocalVsHFChatTemplateConsistency() { s.T().Logf("GetPodScores returned score: %v", pods[s.Pod1IP]) // Render the same conversation again to test caching and consistency - req2 := &preprocessing.RenderJinjaTemplateRequest{ - Conversations: convertToPreprocessingChatMessages(conversation), + req2 := &preprocessing.ApplyChatTemplateRequest{ + ChatTemplateRequest: preprocessing.ChatTemplateRequest{ + IsLocal: true, + Model: tc.modelName, + }, + Conversation: convertToPreprocessingChatMessages(conversation), } - localRendered2, err := localTokenizer.RenderChatTemplate(tc.modelName, req2) + localRendered2, err := localTokenizer.ApplyChatTemplate(tc.modelName, req2) s.Require().NoError(err) s.Require().Equal(localRendered, localRendered2, "Rendering the same conversation twice should produce identical output (tests caching)") // Tokenize again - localTokens2, _, err := localTokenizer.Encode(localRendered2, tc.modelName) + localTokens2, _, err := localTokenizer.Encode(&preprocessing.EncodeRequest{ + ChatTemplateRequest: preprocessing.ChatTemplateRequest{ + IsLocal: true, + Model: tc.modelName, + }, + Text: localRendered2, + AddSpecialTokens: true, + }) s.Require().NoError(err) s.Require().Equal(localTokens, localTokens2, "Tokenizing the same prompt twice should produce identical tokens") @@ -789,36 +812,38 @@ func (s *KVCacheSuite) TestLocalVsHFChatTemplateConsistency() { // TestLocalTokenizerChatTemplateErrorHandling tests error cases for local chat templates. func (s *KVCacheSuite) TestLocalTokenizerChatTemplateErrorHandling() { modelName := "test-model" - testModelDir, err := filepath.Abs("testdata/test-model") + testModelDir, err := filepath.Abs(localTestModelDir) s.Require().NoError(err) - localTokenizer, err := tokenization.NewCachedLocalTokenizer(modelName, tokenization.LocalTokenizerConfig{ + localTokenizer, err := tokenization.NewCachedLocalTokenizer(&tokenization.LocalTokenizerConfig{ ModelTokenizerMap: map[string]string{ modelName: filepath.Join(testModelDir, "tokenizer.json"), }, }) s.Require().NoError(err) - s.SetTokenizer(localTokenizer, modelName) - conversation := []ChatMessage{ {Role: "user", Content: "Test"}, } // Test 1: Non-existent model - reqNonExistent := &preprocessing.RenderJinjaTemplateRequest{ - Conversations: convertToPreprocessingChatMessages(conversation), + reqNonExistent := &preprocessing.ApplyChatTemplateRequest{ + Conversation: convertToPreprocessingChatMessages(conversation), } - _, err = localTokenizer.RenderChatTemplate("non-existent-model", reqNonExistent) + _, err = localTokenizer.ApplyChatTemplate("non-existent-model", reqNonExistent) s.Require().Error(err, "Should return error for non-existent model") s.T().Logf("Expected error for non-existent model: %v", err) // Test 2: Empty conversation emptyConversation := []ChatMessage{} - reqEmpty := &preprocessing.RenderJinjaTemplateRequest{ - Conversations: convertToPreprocessingChatMessages(emptyConversation), + reqEmpty := &preprocessing.ApplyChatTemplateRequest{ + ChatTemplateRequest: preprocessing.ChatTemplateRequest{ + IsLocal: true, + Model: modelName, + }, + Conversation: convertToPreprocessingChatMessages(emptyConversation), } - rendered, err := localTokenizer.RenderChatTemplate("test-model", reqEmpty) + rendered, err := localTokenizer.ApplyChatTemplate(modelName, reqEmpty) // This might succeed with empty output or fail depending on template // Either is acceptable behavior if err == nil { @@ -839,12 +864,12 @@ func (s *KVCacheSuite) TestLocalTokenizerChatTemplateLongConversation() { }{ { name: "test-model", - modelDir: "testdata/test-model", + modelDir: localTestModelDir, modelName: "test-model", }, { name: "local-llama3", - modelDir: "testdata/local-llama3", + modelDir: localLlama3ModelDir, modelName: "local-llama3", }, } @@ -854,18 +879,16 @@ func (s *KVCacheSuite) TestLocalTokenizerChatTemplateLongConversation() { testModelDir, err := filepath.Abs(tc.modelDir) s.Require().NoError(err) - localTokenizer, err := tokenization.NewCachedLocalTokenizer(tc.modelName, tokenization.LocalTokenizerConfig{ + localTokenizer, err := tokenization.NewCachedLocalTokenizer(&tokenization.LocalTokenizerConfig{ ModelTokenizerMap: map[string]string{ tc.modelName: filepath.Join(testModelDir, "tokenizer.json"), }, }) s.Require().NoError(err) - s.SetTokenizer(localTokenizer, tc.modelName) - // Create a very long conversation (100 turns) longConversation := make([]ChatMessage, 0, 200) - for i := 0; i < 100; i++ { + for range 100 { longConversation = append(longConversation, ChatMessage{ Role: "user", @@ -879,17 +902,28 @@ func (s *KVCacheSuite) TestLocalTokenizerChatTemplateLongConversation() { } // Render the long conversation - reqLong := &preprocessing.RenderJinjaTemplateRequest{ - Conversations: convertToPreprocessingChatMessages(longConversation), + reqLong := &preprocessing.ApplyChatTemplateRequest{ + ChatTemplateRequest: preprocessing.ChatTemplateRequest{ + IsLocal: true, + Model: tc.modelName, + }, + Conversation: convertToPreprocessingChatMessages(longConversation), } - renderedPrompt, err := localTokenizer.RenderChatTemplate(tc.modelName, reqLong) + renderedPrompt, err := localTokenizer.ApplyChatTemplate(tc.modelName, reqLong) s.Require().NoError(err) s.Require().NotEmpty(renderedPrompt) s.Require().Greater(len(renderedPrompt), 1000, "Long conversation should produce substantial output") s.T().Logf("Long conversation rendered to %d characters", len(renderedPrompt)) // Tokenize - tokens, offsets, err := localTokenizer.Encode(renderedPrompt, tc.modelName) + tokens, offsets, err := localTokenizer.Encode(&preprocessing.EncodeRequest{ + ChatTemplateRequest: preprocessing.ChatTemplateRequest{ + IsLocal: true, + Model: tc.modelName, + }, + Text: renderedPrompt, + AddSpecialTokens: true, + }) s.Require().NoError(err) s.Require().NotEmpty(tokens) s.Require().Equal(len(tokens), len(offsets)) From 14c7787088b34fe2ee8a917692911c6d0b694e53 Mon Sep 17 00:00:00 2001 From: HyunKyun Moon Date: Thu, 18 Dec 2025 16:13:03 +0000 Subject: [PATCH 2/5] decoupling daulet/tokenizers --- .gitignore | 3 -- Dockerfile | 10 ++---- Makefile | 32 ++++--------------- docs/architecture.md | 2 -- go.mod | 1 - go.sum | 2 -- .../chat_completions/cgo_functions.go | 7 ++-- pkg/preprocessing/chat_completions/type.go | 22 +++++++++++++ pkg/tokenization/pool_test.go | 11 +++---- pkg/tokenization/prefixstore/indexer.go | 6 ++-- pkg/tokenization/prefixstore/lru_store.go | 4 +-- .../prefixstore/lru_store_test.go | 6 ++-- pkg/tokenization/prefixstore/trie_store.go | 6 ++-- pkg/tokenization/tokenizer.go | 9 +++--- pkg/tokenization/tokenizer_test.go | 5 ++- pkg/tokenization/uds_tokenizer.go | 7 ++-- 16 files changed, 57 insertions(+), 76 deletions(-) create mode 100644 pkg/preprocessing/chat_completions/type.go diff --git a/.gitignore b/.gitignore index b9df47691..fdb97fffd 100644 --- a/.gitignore +++ b/.gitignore @@ -65,9 +65,6 @@ _cgo_* .gopls/ /hack/tools -# Tokenizer binaries -/lib - # UDS tokenizer files services/uds_tokenizer/models/* !services/uds_tokenizer/models/README.md \ No newline at end of file diff --git a/Dockerfile b/Dockerfile index 3e5d08744..10de7d698 100644 --- a/Dockerfile +++ b/Dockerfile @@ -45,12 +45,6 @@ RUN python3.12 -m pip install --upgrade pip setuptools wheel && \ COPY examples/kv_events examples/kv_events COPY . . -# HuggingFace tokenizer bindings -RUN mkdir -p lib -ARG RELEASE_VERSION=v1.22.1 -RUN curl -L https://github.com/daulet/tokenizers/releases/download/${RELEASE_VERSION}/libtokenizers.${TARGETOS}-${TARGETARCH}.tar.gz | tar -xz -C lib -RUN ranlib lib/*.a - # Set up Python environment variables needed for the build ENV PYTHONPATH=/workspace/pkg/preprocessing/chat_completions:/usr/lib64/python3.9/site-packages:/usr/lib/python3.9/site-packages ENV PYTHON=python3.9 @@ -58,8 +52,8 @@ ENV PYTHON=python3.9 # Build the application with CGO enabled. # We export CGO_CFLAGS and CGO_LDFLAGS using python3.12-config to ensure the Go compiler # can find the Python headers and libraries correctly. This mirrors the fix from the Makefile. -RUN export CGO_CFLAGS="$(python3.12-config --cflags) -I/workspace/lib" && \ - export CGO_LDFLAGS="$(python3.12-config --ldflags --embed) -L/workspace/lib -ltokenizers -ldl -lm" && \ +RUN export CGO_CFLAGS="$(python3.12-config --cflags)" && \ + export CGO_LDFLAGS="$(python3.12-config --ldflags --embed) -ldl -lm" && \ CGO_ENABLED=1 GOOS=${TARGETOS:-linux} GOARCH=${TARGETARCH:-amd64} \ go build -a -o bin/kv-cache-manager examples/kv_events/online/main.go diff --git a/Makefile b/Makefile index 58914bb61..7e1aa9f13 100644 --- a/Makefile +++ b/Makefile @@ -23,26 +23,6 @@ SRC = $(shell find . -type f -name '*.go') help: ## Print help @awk 'BEGIN {FS = ":.*##"; printf "\nUsage:\n make \033[36m\033[0m\n"} /^[a-zA-Z_0-9-]+:.*?##/ { printf " \033[36m%-15s\033[0m %s\n", $$1, $$2 } /^##@/ { printf "\n\033[1m%s\033[0m\n", substr($$0, 5) } ' $(MAKEFILE_LIST) -##@ Tokenizer & Linking - -TOKENIZER_LIB = lib/libtokenizers.a - -# Extract RELEASE_VERSION from Dockerfile -TOKENIZER_VERSION := $(shell grep '^ARG RELEASE_VERSION=' Dockerfile | cut -d'=' -f2) - -.PHONY: download-tokenizer -download-tokenizer: $(TOKENIZER_LIB) -$(TOKENIZER_LIB): - ## Download the HuggingFace tokenizer bindings. - @echo "Downloading HuggingFace tokenizer bindings for version $(TOKENIZER_VERSION)..." - mkdir -p lib - if [ "$(TARGETOS)" = "darwin" ] && [ "$(TARGETARCH)" = "amd64" ]; then \ - curl -L https://github.com/daulet/tokenizers/releases/download/$(TOKENIZER_VERSION)/libtokenizers.$(TARGETOS)-x86_64.tar.gz | tar -xz -C lib; \ - else \ - curl -L https://github.com/daulet/tokenizers/releases/download/$(TOKENIZER_VERSION)/libtokenizers.$(TARGETOS)-$(TARGETARCH).tar.gz | tar -xz -C lib; \ - fi - ranlib lib/*.a - ##@ Python Configuration PYTHON_VERSION := 3.12 @@ -87,8 +67,8 @@ else endif # Final CGO flags with all dependencies -CGO_CFLAGS_FINAL := $(PYTHON_CFLAGS) -Ilib -CGO_LDFLAGS_FINAL := $(PYTHON_LDFLAGS) $(PYTHON_LIBS) -Llib -ltokenizers -ldl -lm +CGO_CFLAGS_FINAL := $(PYTHON_CFLAGS) +CGO_LDFLAGS_FINAL := $(PYTHON_LDFLAGS) $(PYTHON_LIBS) -ldl -lm .PHONY: detect-python detect-python: ## Detects Python and prints the configuration. @@ -187,17 +167,17 @@ export PYTHONPATH=$(shell pwd)/pkg/preprocessing/chat_completions:$(VENV_DIR)/li test: unit-test e2e-test ## Run all tests .PHONY: unit-test -unit-test: download-tokenizer install-python-deps download-zmq ## Run unit tests +unit-test: install-python-deps download-zmq ## Run unit tests @printf "\033[33;1m==== Running unit tests ====\033[0m\n" @go test -v ./pkg/... .PHONY: e2e-test -e2e-test: download-tokenizer download-local-llama3 install-python-deps download-zmq ## Run end-to-end tests +e2e-test: download-local-llama3 install-python-deps download-zmq ## Run end-to-end tests @printf "\033[33;1m==== Running e2e tests ====\033[0m\n" @go test -v ./tests/... .PHONY: bench -bench: download-tokenizer install-python-deps download-zmq ## Run benchmarks +bench: install-python-deps download-zmq ## Run benchmarks @printf "\033[33;1m==== Running chat template benchmarks ====\033[0m\n" @go test -bench=. -benchmem ./pkg/preprocessing/chat_completions/ @printf "\033[33;1m==== Running tokenization benchmarks ====\033[0m\n" @@ -211,7 +191,7 @@ run: build ## Run the application locally ##@ Build .PHONY: build -build: check-go download-tokenizer install-python-deps download-zmq ## Build the application binary +build: check-go install-python-deps download-zmq ## Build the application binary @printf "\033[33;1m==== Building application binary ====\033[0m\n" @go build -o bin/$(PROJECT_NAME) examples/kv_events/online/main.go @echo "✅ Built examples/kv_events/online/main.go -> bin/$(PROJECT_NAME)" diff --git a/docs/architecture.md b/docs/architecture.md index 8dadfa741..3df10129a 100644 --- a/docs/architecture.md +++ b/docs/architecture.md @@ -173,8 +173,6 @@ Efficiently handling tokenization is critical for performance. The system is des ## Dependencies The Indexer relies on several libraries and tools: -* **[daulet/tokenizers](https://github.com/daulet/tokenizers)**: Go bindings for the HuggingFace Tokenizers library. - * Used for tokenization of prompts. * **[pebbe/zmq4](https://github.com/pebbe/zmq4)**: Go bindings for ZeroMQ. * Used for the event processing pool and communication between components. * Requires `libzmq` library to be installed on the system. diff --git a/go.mod b/go.mod index 12bf6f95c..e1572a115 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,6 @@ go 1.24.1 require ( github.com/alicebob/miniredis/v2 v2.35.0 github.com/cespare/xxhash/v2 v2.3.0 - github.com/daulet/tokenizers v1.22.1 github.com/dgraph-io/ristretto/v2 v2.3.0 github.com/dustin/go-humanize v1.0.1 github.com/fxamacker/cbor/v2 v2.7.0 diff --git a/go.sum b/go.sum index c3450867d..3e34e5df4 100644 --- a/go.sum +++ b/go.sum @@ -11,8 +11,6 @@ github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0 github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= -github.com/daulet/tokenizers v1.22.1 h1:3wzAFIxfgRuqGKka8xdkeTbctDmmqOOs12GofqdorpM= -github.com/daulet/tokenizers v1.22.1/go.mod h1:tGnMdZthXdcWY6DGD07IygpwJqiPvG85FQUnhs/wSCs= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= diff --git a/pkg/preprocessing/chat_completions/cgo_functions.go b/pkg/preprocessing/chat_completions/cgo_functions.go index c2600462f..63818b46f 100644 --- a/pkg/preprocessing/chat_completions/cgo_functions.go +++ b/pkg/preprocessing/chat_completions/cgo_functions.go @@ -32,7 +32,6 @@ import ( "github.com/llm-d/llm-d-kv-cache/pkg/utils/logging" "sigs.k8s.io/controller-runtime/pkg/log" ) -import "github.com/daulet/tokenizers" // Conversation represents a single message in a conversation. type Conversation struct { @@ -98,8 +97,8 @@ func (req *EncodeRequest) DeepCopy() (*EncodeRequest, error) { } type EncodeResponse struct { - TokenIDs []uint32 `json:"input_ids"` - OffsetMappings []tokenizers.Offset `json:"offset_mapping"` + TokenIDs []uint32 `json:"input_ids"` + OffsetMappings []Offset `json:"offset_mapping"` } // ChatTemplatingProcessor is a processor that handles chat template rendering @@ -169,7 +168,7 @@ func (w *ChatTemplatingProcessor) ApplyChatTemplate(ctx context.Context, func (w *ChatTemplatingProcessor) Encode( ctx context.Context, req *EncodeRequest, -) ([]uint32, []tokenizers.Offset, error) { +) ([]uint32, []Offset, error) { traceLogger := log.FromContext(ctx).V(logging.TRACE).WithName("Encode") // Convert request to JSON reqJSON, err := json.Marshal(req) diff --git a/pkg/preprocessing/chat_completions/type.go b/pkg/preprocessing/chat_completions/type.go new file mode 100644 index 000000000..549ae7540 --- /dev/null +++ b/pkg/preprocessing/chat_completions/type.go @@ -0,0 +1,22 @@ +package preprocessing + +import "unsafe" + +// Tokenizer is a thin wrapper around an underlying tokenizer implementation. +// It holds an opaque handle to a tokenizer instance, which is provided and +// managed by external code and must remain valid for the lifetime of this +// memory itself. +type Tokenizer struct { + // tokenizer points to the underlying tokenizer instance. The memory it + // references is owned and managed by the caller or another subsystem and + // must outlive any use of this Tokenizer. This wrapper never takes + // ownership of the pointer and must not free it. + // tokenizer is an opaque handle to the underlying tokenizer implementation. + // It is stored as unsafe.Pointer because the actual object is managed outside + // of Go (for example by an external library via cgo) and its concrete type is + // not exposed here. This field must not be dereferenced directly in Go code. + tokenizer unsafe.Pointer +} + +// Offset represents a character offset range with [start, end] indices. +type Offset [2]uint diff --git a/pkg/tokenization/pool_test.go b/pkg/tokenization/pool_test.go index 88038cd2a..38dcd28bc 100644 --- a/pkg/tokenization/pool_test.go +++ b/pkg/tokenization/pool_test.go @@ -24,7 +24,6 @@ import ( "testing" "time" - "github.com/daulet/tokenizers" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" @@ -57,7 +56,7 @@ func (m *MockTokenizer) ApplyChatTemplate( return args.String(0), args.Error(1) } -func (m *MockTokenizer) Encode(req *preprocessing.EncodeRequest) ([]uint32, []tokenizers.Offset, error) { +func (m *MockTokenizer) Encode(req *preprocessing.EncodeRequest) ([]uint32, []preprocessing.Offset, error) { args := m.Called(req) tokenIface := args.Get(0) if tokenIface == nil { @@ -71,9 +70,9 @@ func (m *MockTokenizer) Encode(req *preprocessing.EncodeRequest) ([]uint32, []to if offsetIface == nil { return nil, nil, args.Error(2) } - offsets, ok := offsetIface.([]tokenizers.Offset) + offsets, ok := offsetIface.([]preprocessing.Offset) if !ok { - panic("MockTokenizer.Encode: expected []tokenizers.Offset from mock, got unexpected type") + panic("MockTokenizer.Encode: expected []preprocessing.Offset from mock, got unexpected type") } return tokens, offsets, args.Error(2) } @@ -87,7 +86,7 @@ type MockIndexer struct { mock.Mock } -func (m *MockIndexer) AddTokenization(prompt string, tokens []uint32, offsets []tokenizers.Offset) error { +func (m *MockIndexer) AddTokenization(prompt string, tokens []uint32, offsets []preprocessing.Offset) error { args := m.Called(prompt, tokens, offsets) return args.Error(0) } @@ -121,7 +120,7 @@ func TestPool_ProcessTask(t *testing.T) { // Setup specific mock return values expectedTokens := []uint32{12345, 67890, 11111} - expectedOffsets := []tokenizers.Offset{{0, 5}, {6, 11}} + expectedOffsets := []preprocessing.Offset{{0, 5}, {6, 11}} // Mock FindLongestContainedTokens to return low overlap ratio mockIndexer.On("FindLongestContainedTokens", task.Prompt).Return([]uint32{}, 0.0) diff --git a/pkg/tokenization/prefixstore/indexer.go b/pkg/tokenization/prefixstore/indexer.go index 8510868ee..fcc79f3a3 100644 --- a/pkg/tokenization/prefixstore/indexer.go +++ b/pkg/tokenization/prefixstore/indexer.go @@ -16,9 +16,7 @@ limitations under the License. package prefixstore -import ( - "github.com/daulet/tokenizers" -) +import preprocessing "github.com/llm-d/llm-d-kv-cache/pkg/preprocessing/chat_completions" // Config holds the configuration for the Indexer module. type Config struct { @@ -41,7 +39,7 @@ type Indexer interface { // indexer. // The function assumes tokens and offsets are of the same length. // The function assumes that tokens will not be mutated after the call. - AddTokenization(prompt string, tokens []uint32, offsets []tokenizers.Offset) error + AddTokenization(prompt string, tokens []uint32, offsets []preprocessing.Offset) error // FindLongestContainedTokens finds the sequence of contained tokens for // the longest matching prefix, along with the coverage ratio of the prompt. FindLongestContainedTokens(prompt string) ([]uint32, float64) diff --git a/pkg/tokenization/prefixstore/lru_store.go b/pkg/tokenization/prefixstore/lru_store.go index 6d0ac973b..2bc791f9f 100644 --- a/pkg/tokenization/prefixstore/lru_store.go +++ b/pkg/tokenization/prefixstore/lru_store.go @@ -22,8 +22,8 @@ import ( "sync" "github.com/cespare/xxhash/v2" - "github.com/daulet/tokenizers" lru "github.com/hashicorp/golang-lru/v2" + preprocessing "github.com/llm-d/llm-d-kv-cache/pkg/preprocessing/chat_completions" ) const ( @@ -91,7 +91,7 @@ func NewLRUTokenStore(config *Config) (Indexer, error) { // The function assumes tokens and offsets are of the same length. // The function assumes that tokens will not be mutated after the call. func (c *LRUTokenStore) AddTokenization(prompt string, tokens []uint32, - offsets []tokenizers.Offset, + offsets []preprocessing.Offset, ) error { if prompt == "" || len(tokens) == 0 { return nil diff --git a/pkg/tokenization/prefixstore/lru_store_test.go b/pkg/tokenization/prefixstore/lru_store_test.go index aa57d3b25..9715e80f8 100644 --- a/pkg/tokenization/prefixstore/lru_store_test.go +++ b/pkg/tokenization/prefixstore/lru_store_test.go @@ -20,7 +20,7 @@ import ( "strings" "testing" - "github.com/daulet/tokenizers" + preprocessing "github.com/llm-d/llm-d-kv-cache/pkg/preprocessing/chat_completions" "github.com/stretchr/testify/assert" ) @@ -34,7 +34,7 @@ func setupTestLRUTokenStore(t *testing.T, blockSize int) (*LRUTokenStore, string text := "The capital of France is Paris" tokens := []uint32{1, 2, 3, 4, 5, 6} - offsets := []tokenizers.Offset{ + offsets := []preprocessing.Offset{ {0, 3}, {4, 11}, {12, 14}, {15, 21}, {22, 24}, {25, 30}, } @@ -138,7 +138,7 @@ func TestLRUTokenStore_LRUEviction(t *testing.T) { {4, 5, 6}, {7, 8, 9}, } - offsets := [][]tokenizers.Offset{ + offsets := [][]preprocessing.Offset{ {{0, 5}, {6, 10}, {11, 15}}, {{0, 6}, {7, 12}, {13, 18}}, {{0, 6}, {7, 12}, {13, 18}}, diff --git a/pkg/tokenization/prefixstore/trie_store.go b/pkg/tokenization/prefixstore/trie_store.go index fed7beeda..4c777fd32 100644 --- a/pkg/tokenization/prefixstore/trie_store.go +++ b/pkg/tokenization/prefixstore/trie_store.go @@ -19,7 +19,7 @@ package prefixstore import ( "sync" - "github.com/daulet/tokenizers" + preprocessing "github.com/llm-d/llm-d-kv-cache/pkg/preprocessing/chat_completions" ) // containedTokenNode represents a node in the character-based Trie. @@ -57,7 +57,7 @@ var _ Indexer = &TrieTokenStore{} // The function assumes tokens and offsets are of the same length. // The function assumes that tokens will not be mutated after the call. func (t *TrieTokenStore) AddTokenization(prompt string, tokens []uint32, - offsets []tokenizers.Offset, + offsets []preprocessing.Offset, ) error { if prompt == "" || len(tokens) == 0 || len(tokens) != len(offsets) { return nil @@ -82,7 +82,7 @@ func NewContainedTokenTrie() *TrieTokenStore { // It iterates through characters and determines the last contained token at // each step. // Assumes the caller holds the Write Lock. -func (t *TrieTokenStore) addFullTokenization(prompt string, tokens []uint32, offsets []tokenizers.Offset) { +func (t *TrieTokenStore) addFullTokenization(prompt string, tokens []uint32, offsets []preprocessing.Offset) { node := t.root var lastFoundK int if len(tokens) > 0 { diff --git a/pkg/tokenization/tokenizer.go b/pkg/tokenization/tokenizer.go index 357c13c4a..f5d9b4d7f 100644 --- a/pkg/tokenization/tokenizer.go +++ b/pkg/tokenization/tokenizer.go @@ -25,7 +25,6 @@ import ( "strings" "time" - "github.com/daulet/tokenizers" "go.uber.org/multierr" "github.com/llm-d/llm-d-kv-cache/pkg/kvcache/metrics" @@ -36,7 +35,7 @@ import ( type Tokenizer interface { ApplyChatTemplate(string, *preprocessing.ApplyChatTemplateRequest) (string, error) // Encode tokenizes the input string and returns the token IDs and offsets. - Encode(*preprocessing.EncodeRequest) ([]uint32, []tokenizers.Offset, error) + Encode(*preprocessing.EncodeRequest) ([]uint32, []preprocessing.Offset, error) Type() string } @@ -351,7 +350,7 @@ func (t *LocalCachedTokenizer) ApplyChatTemplate( } // Encode converts a string into token IDs. -func (t *HFCachedTokenizer) Encode(req *preprocessing.EncodeRequest) ([]uint32, []tokenizers.Offset, error) { +func (t *HFCachedTokenizer) Encode(req *preprocessing.EncodeRequest) ([]uint32, []preprocessing.Offset, error) { ctx := context.TODO() req.IsLocal = false @@ -365,7 +364,7 @@ func (t *HFCachedTokenizer) Encode(req *preprocessing.EncodeRequest) ([]uint32, return tokens, offsets, nil } -func (t *LocalCachedTokenizer) Encode(req *preprocessing.EncodeRequest) ([]uint32, []tokenizers.Offset, error) { +func (t *LocalCachedTokenizer) Encode(req *preprocessing.EncodeRequest) ([]uint32, []preprocessing.Offset, error) { ctx := context.TODO() req.IsLocal = true @@ -450,7 +449,7 @@ func (c *CompositeTokenizer) ApplyChatTemplate( // 4. If all fail, returns all accumulated errors // // This enables prioritizing local tokenizers while maintaining HuggingFace as a fallback. -func (c *CompositeTokenizer) Encode(req *preprocessing.EncodeRequest) ([]uint32, []tokenizers.Offset, error) { +func (c *CompositeTokenizer) Encode(req *preprocessing.EncodeRequest) ([]uint32, []preprocessing.Offset, error) { var rErr error for _, tokenizer := range c.Tokenizers { copiedReq, err := req.DeepCopy() diff --git a/pkg/tokenization/tokenizer_test.go b/pkg/tokenization/tokenizer_test.go index 17cabc586..0e6ccedb7 100644 --- a/pkg/tokenization/tokenizer_test.go +++ b/pkg/tokenization/tokenizer_test.go @@ -23,7 +23,6 @@ import ( "path/filepath" "testing" - "github.com/daulet/tokenizers" preprocessing "github.com/llm-d/llm-d-kv-cache/pkg/preprocessing/chat_completions" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -42,11 +41,11 @@ func (d *DummyTokenizer) ApplyChatTemplate( return prompt, nil } -func (d *DummyTokenizer) Encode(*preprocessing.EncodeRequest) ([]uint32, []tokenizers.Offset, error) { +func (d *DummyTokenizer) Encode(*preprocessing.EncodeRequest) ([]uint32, []preprocessing.Offset, error) { if d.returnError { return nil, nil, fmt.Errorf("dummy tokenizer error") } - return []uint32{1, 2, 3}, []tokenizers.Offset{{0, 1}, {2, 3}, {4, 5}}, nil + return []uint32{1, 2, 3}, []preprocessing.Offset{{0, 1}, {2, 3}, {4, 5}}, nil } func (d *DummyTokenizer) Type() string { diff --git a/pkg/tokenization/uds_tokenizer.go b/pkg/tokenization/uds_tokenizer.go index 405d3a0bb..9970cfc75 100644 --- a/pkg/tokenization/uds_tokenizer.go +++ b/pkg/tokenization/uds_tokenizer.go @@ -29,7 +29,6 @@ import ( "strings" "time" - "github.com/daulet/tokenizers" preprocessing "github.com/llm-d/llm-d-kv-cache/pkg/preprocessing/chat_completions" "golang.org/x/net/http2" ) @@ -52,8 +51,8 @@ type UdsTokenizer struct { // TokenizedInput represents the response from the tokenize endpoint. type TokenizedInput struct { - InputIDs []uint32 `json:"input_ids"` - OffsetMapping []tokenizers.Offset `json:"offset_mapping"` + InputIDs []uint32 `json:"input_ids"` + OffsetMapping []preprocessing.Offset `json:"offset_mapping"` } const ( @@ -105,7 +104,7 @@ func NewUdsTokenizer(config *UdsTokenizerConfig) (Tokenizer, error) { } // Encode tokenizes the input string and returns the token IDs and offsets. -func (u *UdsTokenizer) Encode(req *preprocessing.EncodeRequest) ([]uint32, []tokenizers.Offset, error) { +func (u *UdsTokenizer) Encode(req *preprocessing.EncodeRequest) ([]uint32, []preprocessing.Offset, error) { httpReq, err := http.NewRequestWithContext( context.Background(), http.MethodPost, From e4be74e2cd86d8dd3a6fa9d10baea6f877617c81 Mon Sep 17 00:00:00 2001 From: HyunKyun Moon Date: Fri, 19 Dec 2025 06:45:13 +0000 Subject: [PATCH 3/5] vllm-cpu --- pkg/preprocessing/chat_completions/requirements.txt | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pkg/preprocessing/chat_completions/requirements.txt b/pkg/preprocessing/chat_completions/requirements.txt index 88f4bf039..6f48e6c24 100644 --- a/pkg/preprocessing/chat_completions/requirements.txt +++ b/pkg/preprocessing/chat_completions/requirements.txt @@ -1 +1,3 @@ -vllm>=0.11.0 \ No newline at end of file +--index-url https://download.pytorch.org/whl/cpu +--extra-index-url https://pypi.org/simple +vllm-cpu>=0.11.0 \ No newline at end of file From a0ec0747012d799a52d6a7d8f45c1f7f0d5faea2 Mon Sep 17 00:00:00 2001 From: Hyunkyun Moon Date: Fri, 19 Dec 2025 16:20:01 +0900 Subject: [PATCH 4/5] Update pkg/tokenization/tokenizer.go Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: Hyunkyun Moon --- pkg/tokenization/tokenizer.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/tokenization/tokenizer.go b/pkg/tokenization/tokenizer.go index f5d9b4d7f..159eded1c 100644 --- a/pkg/tokenization/tokenizer.go +++ b/pkg/tokenization/tokenizer.go @@ -338,7 +338,7 @@ func (t *LocalCachedTokenizer) ApplyChatTemplate( req.IsLocal = true path, ok := t.localTokenizerConfig.ModelTokenizerMap[req.Model] if !ok { - return "", fmt.Errorf("tokenizer for model %q not found", modelName) + return "", fmt.Errorf("tokenizer for model %q not found", req.Model) } req.Model = filepath.Dir(path) res, err := t.chatTemplateRenderer.ApplyChatTemplate(ctx, req) From b198108724269203b70ba18c1bb6a48cf8455bb7 Mon Sep 17 00:00:00 2001 From: HyunKyun Moon Date: Fri, 19 Dec 2025 07:39:17 +0000 Subject: [PATCH 5/5] apply copilot review --- pkg/preprocessing/chat_completions/cgo_functions.go | 1 - .../chat_completions/tokenizer_wrapper.py | 13 ++++--------- 2 files changed, 4 insertions(+), 10 deletions(-) diff --git a/pkg/preprocessing/chat_completions/cgo_functions.go b/pkg/preprocessing/chat_completions/cgo_functions.go index 63818b46f..599147598 100644 --- a/pkg/preprocessing/chat_completions/cgo_functions.go +++ b/pkg/preprocessing/chat_completions/cgo_functions.go @@ -49,7 +49,6 @@ type ChatTemplateRequest struct { // ApplyChatTemplateRequest represents the request to render a chat template. type ApplyChatTemplateRequest struct { - // `conversation` is the transformers name, but we use `messages` for consistency with OpenAI API. // The Python wrapper will handle converting this to a batched list if needed. ChatTemplateRequest Conversation []Conversation `json:"conversation"` diff --git a/pkg/preprocessing/chat_completions/tokenizer_wrapper.py b/pkg/preprocessing/chat_completions/tokenizer_wrapper.py index 7496f3816..c167a1dae 100644 --- a/pkg/preprocessing/chat_completions/tokenizer_wrapper.py +++ b/pkg/preprocessing/chat_completions/tokenizer_wrapper.py @@ -21,7 +21,6 @@ import logging import os import sys -from typing import Optional, Union from vllm.transformers_utils.tokenizer import get_tokenizer # Basic logging setup @@ -58,18 +57,17 @@ def apply_chat_template(request_json): - model (str): The model ID or path (HF model ID, local directory path, or path to tokenizer file). - revision (str, optional): Model revision. - token (str, optional): Hugging Face token for private models. - - conversations (list): List of conversation lists + - conversation (list): List of conversation lists - chat_template (str, optional): The template to use - tools (list, optional): Tool schemas - documents (list, optional): Document schemas - return_assistant_tokens_mask (bool, optional): Whether to return assistant tokens mask - continue_final_message (bool, optional): Whether to continue final message - add_generation_prompt (bool, optional): Whether to add generation prompt - - kwargs (dict, optional): Additional rendering variables - - isVLLM / - isSGLang enum - + - chat_template_kwargs (dict, optional): Additional rendering variables + Returns: - str: JSON string containing 'rendered_chats' and 'generation_indices' keys. + str: The rendered chat template as a string. """ try: @@ -118,9 +116,6 @@ def encode(request_json: str) -> str: - download_dir (str, optional): Directory to download the model. - text (str): The text to encode. - token (str, optional): Hugging Face token for private models. - - isVLLM (bool, optional): Whether to use VLLM tokenizer. - - isSGLang (bool, optional): Whether to use SG-Lang tokenizer. - - .... Returns: str: JSON string containing 'encoded_texts' key with list of token ID lists.