From be99da036786acf10a208300e0dcddbb81c91f32 Mon Sep 17 00:00:00 2001 From: HyunKyun Moon Date: Sat, 27 Dec 2025 08:11:56 +0000 Subject: [PATCH 01/12] apply chat template Signed-off-by: HyunKyun Moon --- examples/kv_events/online/main.go | 25 +- examples/testdata/data.go | 2 +- go.mod | 2 +- pkg/kvcache/indexer.go | 2 +- .../chat_completions/cgo_functions.c | 275 ++++++++-------- .../chat_completions/cgo_functions.go | 143 +++------ .../chat_completions/cgo_functions.h | 19 +- .../chat_completions/cgo_functions_test.go | 303 +++++++----------- .../render_jinja_template_wrapper.py | 259 --------------- .../chat_completions/requirements.txt | 10 +- .../chat_completions/tokenizer_wrapper.py | 153 +++++++++ pkg/tokenization/pool.go | 10 +- pkg/tokenization/pool_test.go | 18 +- pkg/tokenization/tokenizer.go | 141 ++++---- pkg/tokenization/tokenizer_test.go | 16 +- pkg/tokenization/uds_tokenizer.go | 10 +- tests/e2e/redis_mock/e2e_suite_test.go | 4 +- tests/e2e/redis_mock/e2e_test.go | 147 ++++----- 18 files changed, 659 insertions(+), 880 deletions(-) delete mode 100644 pkg/preprocessing/chat_completions/render_jinja_template_wrapper.py create mode 100644 pkg/preprocessing/chat_completions/tokenizer_wrapper.py diff --git a/examples/kv_events/online/main.go b/examples/kv_events/online/main.go index 87f09cc96..f141664aa 100644 --- a/examples/kv_events/online/main.go +++ b/examples/kv_events/online/main.go @@ -62,7 +62,7 @@ const ( // ChatCompletionsRequest holds the fields needed for chat-completions rendering. type ChatCompletionsRequest struct { Model string `json:"model"` - *preprocessing.RenderJinjaTemplateRequest + *preprocessing.ApplyChatTemplateRequest } func main() { @@ -320,35 +320,18 @@ func setupUnifiedHTTPEndpoints( logger.Info("Created ChatCompletions", "req", req) - // Get chat template for the model if not provided - if req.ChatTemplate == "" { - templateReq := preprocessing.FetchChatTemplateRequest{ - Model: req.Model, - Token: os.Getenv(envHFToken), - } - - var err error - req.ChatTemplate, req.ChatTemplateKWArgs, err = chatTemplatingProcessor.FetchChatTemplate(ctx, templateReq) - if err != nil { - http.Error(w, fmt.Sprintf("Failed to get chat template: %v", err), http.StatusInternalServerError) - return - } - } - - response, err := chatTemplatingProcessor.RenderChatTemplate(ctx, req.RenderJinjaTemplateRequest) + renderedPrompt, err := chatTemplatingProcessor.ApplyChatTemplate(ctx, req.ApplyChatTemplateRequest) if err != nil { http.Error(w, fmt.Sprintf("Failed to render chat template: %v", err), http.StatusInternalServerError) return } // Use KV-cache to score the rendered template - if len(response.RenderedChats) == 0 { - http.Error(w, "No rendered chats found in response", http.StatusInternalServerError) + if len(renderedPrompt) == 0 { + http.Error(w, "rendered prompt is empty", http.StatusInternalServerError) return } - renderedPrompt := response.RenderedChats[0] - // Get score pods, err := kvCacheIndexer.GetPodScores(ctx, nil, renderedPrompt, req.Model, nil) if err != nil { diff --git a/examples/testdata/data.go b/examples/testdata/data.go index 1806b972c..cdc5e6f35 100644 --- a/examples/testdata/data.go +++ b/examples/testdata/data.go @@ -24,7 +24,7 @@ const ( ModelName = "bert-base-uncased" ) -var RenderReq *preprocessing.RenderJinjaTemplateRequest = nil +var RenderReq *preprocessing.ApplyChatTemplateRequest = nil //go:embed prompt.txt var Prompt string diff --git a/go.mod b/go.mod index 1940cbcbd..2b87f9ec5 100644 --- a/go.mod +++ b/go.mod @@ -18,6 +18,7 @@ require ( github.com/stretchr/testify v1.10.0 github.com/vmihailenco/msgpack/v5 v5.4.1 go.uber.org/multierr v1.11.0 + go.uber.org/zap v1.27.0 golang.org/x/net v0.38.0 google.golang.org/grpc v1.68.1 google.golang.org/protobuf v1.36.5 @@ -53,7 +54,6 @@ require ( github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect github.com/x448/float16 v0.8.4 // indirect github.com/yuin/gopher-lua v1.1.1 // indirect - go.uber.org/zap v1.27.0 // indirect golang.org/x/oauth2 v0.27.0 // indirect golang.org/x/sys v0.35.0 // indirect golang.org/x/term v0.30.0 // indirect diff --git a/pkg/kvcache/indexer.go b/pkg/kvcache/indexer.go index d0642c483..ebf5b5fe8 100644 --- a/pkg/kvcache/indexer.go +++ b/pkg/kvcache/indexer.go @@ -134,7 +134,7 @@ func (k *Indexer) KVBlockIndex() kvblock.Index { // relevant. // // The function returns a map of pod identifiers to scores. -func (k *Indexer) GetPodScores(ctx context.Context, renderReq *preprocessing.RenderJinjaTemplateRequest, prompt, modelName string, +func (k *Indexer) GetPodScores(ctx context.Context, renderReq *preprocessing.ApplyChatTemplateRequest, prompt, modelName string, podIdentifiers []string, ) (map[string]float64, error) { traceLogger := log.FromContext(ctx).V(logging.TRACE).WithName("kvcache.GetPodScores") diff --git a/pkg/preprocessing/chat_completions/cgo_functions.c b/pkg/preprocessing/chat_completions/cgo_functions.c index 9eef79151..b0b27a82d 100644 --- a/pkg/preprocessing/chat_completions/cgo_functions.c +++ b/pkg/preprocessing/chat_completions/cgo_functions.c @@ -14,20 +14,21 @@ See the License for the specific language governing permissions and limitations under the License. */ -#include // for getpid() and usleep() +#include // for getpid() and usleep() +#include #include "cgo_functions.h" // Global variables for caching PyObject* g_chat_template_module = NULL; -PyObject* g_render_jinja_template_func = NULL; -PyObject* g_get_model_chat_template_func = NULL; +PyObject* g_load_tokenizer_with_cache_func = NULL; +PyObject* g_apply_chat_template_func = NULL; int g_initialized = 0; int g_python_initialized = 0; // Process-level global initialization tracking static int g_process_initialized = 0; -static int g_finalized = 0; +static int g_finalized = 0; static pid_t g_init_pid = 0; // Thread safety for initialization @@ -42,7 +43,8 @@ int Py_InitializeGo() { if (g_process_initialized) { if (g_init_pid != getpid()) { printf("[C] Py_InitializeGo WARNING - Different PID trying to initialize (init_pid: %d, current_pid: %d)\n", g_init_pid, getpid()); - } else { + } + else { printf("[C] Py_InitializeGo - Already initialized in this process (PID: %d)\n", getpid()); } return 0; @@ -95,24 +97,25 @@ void Py_FinalizeGo() { printf("[C] Py_FinalizeGo - Already finalized, skipping\n"); return; } - + // Mark as finalized first to prevent race conditions g_finalized = 1; - + // Clean up module references safely - if (g_render_jinja_template_func) { - Py_DECREF(g_render_jinja_template_func); - g_render_jinja_template_func = NULL; + if (g_load_tokenizer_with_cache_func) { + Py_DECREF(g_load_tokenizer_with_cache_func); + g_load_tokenizer_with_cache_func = NULL; } - if (g_get_model_chat_template_func) { - Py_DECREF(g_get_model_chat_template_func); - g_get_model_chat_template_func = NULL; + + if (g_apply_chat_template_func) { + Py_DECREF(g_apply_chat_template_func); + g_apply_chat_template_func = NULL; } if (g_chat_template_module) { Py_DECREF(g_chat_template_module); g_chat_template_module = NULL; } - + // Reset state without finalizing Python // Python will be cleaned up when the process exits g_python_initialized = 0; @@ -149,7 +152,7 @@ const char* PyUnicode_AsGoString(PyObject* obj) { // Initialize the cached module and functions (call once at startup) int Py_InitChatTemplateModule() { - + // Thread-safe initialization check if (g_init_lock == NULL) { g_init_lock = PyThread_allocate_lock(); @@ -158,38 +161,38 @@ int Py_InitChatTemplateModule() { return -1; } } - + PyThread_acquire_lock(g_init_lock, NOWAIT_LOCK); - + // Check if already initialized if (g_initialized) { printf("[C] Py_InitChatTemplateModule - Already initialized globally, returning\n"); PyThread_release_lock(g_init_lock); return 0; } - + // Ensure Python is initialized if (!g_python_initialized) { printf("[C] Py_InitChatTemplateModule ERROR - Python not initialized\n"); PyThread_release_lock(g_init_lock); return -1; } - + // Acquire GIL for module initialization PyGILState_STATE gil_state = PyGILState_Ensure(); - - + + // Import the chat template wrapper module AFTER setting up the path - g_chat_template_module = PyImport_ImportModule("render_jinja_template_wrapper"); + g_chat_template_module = PyImport_ImportModule("tokenizer_wrapper"); if (!g_chat_template_module) { - printf("[C] Py_InitChatTemplateModule ERROR - Failed to import render_jinja_template_wrapper module\n"); + printf("[C] Py_InitChatTemplateModule ERROR - Failed to import tokenizer_wrapper module\n"); PyErr_Print(); PyGILState_Release(gil_state); PyThread_release_lock(g_init_lock); return -1; } - + // Get the module dictionary PyObject* module_dict = PyModule_GetDict(g_chat_template_module); if (!module_dict) { @@ -198,30 +201,30 @@ int Py_InitChatTemplateModule() { PyThread_release_lock(g_init_lock); return -1; } - - // Get the render_jinja_template function - g_render_jinja_template_func = PyDict_GetItemString(module_dict, "render_jinja_template"); - if (!g_render_jinja_template_func || !PyCallable_Check(g_render_jinja_template_func)) { - printf("[C] Py_InitChatTemplateModule ERROR - render_jinja_template function not found or not callable\n"); + + // Get the load_tokenizer_with_cache function + g_load_tokenizer_with_cache_func = PyDict_GetItemString(module_dict, "load_tokenizer_with_cache"); + if (!g_load_tokenizer_with_cache_func || !PyCallable_Check(g_load_tokenizer_with_cache_func)) { + printf("[C] Py_InitChatTemplateModule ERROR - load_tokenizer_with_cache function not found or not callable\n"); PyGILState_Release(gil_state); PyThread_release_lock(g_init_lock); return -1; } - Py_INCREF(g_render_jinja_template_func); // Keep a reference - - // Get the get_model_chat_template function - g_get_model_chat_template_func = PyDict_GetItemString(module_dict, "get_model_chat_template"); - if (!g_get_model_chat_template_func || !PyCallable_Check(g_get_model_chat_template_func)) { - printf("[C] Py_InitChatTemplateModule ERROR - get_model_chat_template function not found or not callable\n"); + Py_INCREF(g_load_tokenizer_with_cache_func); // Keep a reference + + // Get the apply_chat_template function + g_apply_chat_template_func = PyDict_GetItemString(module_dict, "apply_chat_template"); + if (!g_apply_chat_template_func || !PyCallable_Check(g_apply_chat_template_func)) { + printf("[C] Py_InitChatTemplateModule ERROR - apply_chat_template function not found or not callable\n"); PyGILState_Release(gil_state); PyThread_release_lock(g_init_lock); return -1; } - Py_INCREF(g_get_model_chat_template_func); // Keep a reference - + Py_INCREF(g_apply_chat_template_func); // Keep a reference + // Release GIL PyGILState_Release(gil_state); - + g_initialized = 1; PyThread_release_lock(g_init_lock); return 0; @@ -229,38 +232,50 @@ int Py_InitChatTemplateModule() { -// Call the cached render_jinja_template function -char* Py_CallRenderJinjaTemplate(const char* json_request) { - // Try direct call first (fast path) - char* result = Py_CallRenderJinjaTemplateInternal(json_request); - if (result != NULL) { - return result; // Success on first try - } - - // If failed, just return NULL (no retry, no reload) - return NULL; +// Call the cached load_tokenizer_with_cache function +bool Py_CallLoadTokenizerWithCache(const char* json_request) { + return Py_CallLoadTokenizerWithCacheInternal(json_request); } // Internal function that does the actual work -char* Py_CallRenderJinjaTemplateInternal(const char* json_request) { - // Check if Python interpreter is still valid - if (!Py_IsInitialized()) { - printf("[C] Py_CallRenderJinjaTemplateInternal ERROR - Python interpreter not initialized\n"); +bool Py_CallLoadTokenizerWithCacheInternal(const char* json_request) { + // Check if Python interpreter is initialized + if (!g_python_initialized) { + printf("[C] Py_CallLoadTokenizerWithCacheInternal ERROR - Python not initialized\n"); + fflush(stdout); return NULL; } - - // Simple validation + + // Validate cached function + if (!g_load_tokenizer_with_cache_func) { + printf("[C] Py_CallLoadTokenizerWithCacheInternal ERROR - Cached function is NULL\n"); + fflush(stdout); + return NULL; + } + + // Validate that the cached function is still a valid Python object + fflush(stdout); + if (!PyCallable_Check(g_load_tokenizer_with_cache_func)) { + printf("[C] Py_CallLoadTokenizerWithCacheInternal ERROR - Cached function is not callable (corrupted?)\n"); + fflush(stdout); + return NULL; + } + + // Validate input if (!json_request) { - printf("[C] Py_CallRenderJinjaTemplateInternal ERROR - Input is NULL\n"); + printf("[C] Py_CallLoadTokenizerWithCacheInternal ERROR - Input is NULL\n"); + fflush(stdout); return NULL; } - + // Acquire GIL for Python operations - PyGILState_STATE gil_state = PyGILState_Ensure(); + PyGILState_STATE gil_state = PyGILState_Ensure(); + // Create Python string from JSON request PyObject* py_json = PyUnicode_FromString(json_request); if (!py_json) { - printf("[C] Py_CallRenderJinjaTemplateInternal ERROR - Failed to create Python string\n"); + printf("[C] Py_CallLoadTokenizerWithCacheInternal ERROR - Failed to create Python string\n"); + fflush(stdout); PyGILState_Release(gil_state); return NULL; } @@ -268,134 +283,114 @@ char* Py_CallRenderJinjaTemplateInternal(const char* json_request) { // Create arguments tuple PyObject* args = PyTuple_Pack(1, py_json); if (!args) { - printf("[C] Py_CallRenderJinjaTemplateInternal ERROR - Failed to create args tuple\n"); + printf("[C] Py_CallLoadTokenizerWithCacheInternal ERROR - Failed to create args tuple\n"); + fflush(stdout); Py_DECREF(py_json); PyGILState_Release(gil_state); return NULL; - } + } // Call the cached function - PyObject* py_result = PyObject_CallObject(g_render_jinja_template_func, args); - + PyObject* py_result = PyObject_CallObject(g_load_tokenizer_with_cache_func, args); + // Clean up args Py_DECREF(args); Py_DECREF(py_json); - - char* cresult = NULL; - if (py_result) { - // Convert to C string - const char* s = PyUnicode_AsUTF8(py_result); - if (s) { - cresult = strdup(s); - } else { - printf("[C] Py_CallRenderJinjaTemplateInternal ERROR - Failed to convert result to C string\n"); - } - Py_DECREF(py_result); - } else { - printf("[C] Py_CallRenderJinjaTemplateInternal ERROR - Python function returned NULL\n"); + + bool cresult = true; + if (!py_result) { + printf("[C] Py_CallLoadTokenizerWithCacheInternal ERROR - Python function returned NULL\n"); + fflush(stdout); PyErr_Print(); fflush(stderr); + cresult = false; } - + // Release GIL PyGILState_Release(gil_state); - + return cresult; } -// Call the cached get_model_chat_template function -char* Py_CallGetModelChatTemplate(const char* json_request) { +// Call the cached apply_chat_template function +char* Py_CallApplyChatTemplate(const char* json_request) { // Try direct call first (fast path) - char* result = Py_CallGetModelChatTemplateInternal(json_request); + char* result = Py_CallApplyChatTemplateInternal(json_request); if (result != NULL) { return result; // Success on first try } - + // If failed, just return NULL (no retry, no reload) return NULL; } // Internal function that does the actual work -char* Py_CallGetModelChatTemplateInternal(const char* json_request) { - // Check if Python is initialized - if (!g_python_initialized) { - printf("[C] Py_CallGetModelChatTemplateInternal ERROR - Python not initialized\n"); - fflush(stdout); - return NULL; - } - - // Validate cached function - if (!g_get_model_chat_template_func) { - printf("[C] Py_CallGetModelChatTemplateInternal ERROR - Cached function is NULL\n"); - fflush(stdout); - return NULL; - } - - // Validate that the cached function is still a valid Python object - fflush(stdout); - if (!PyCallable_Check(g_get_model_chat_template_func)) { - printf("[C] Py_CallGetModelChatTemplateInternal ERROR - Cached function is not callable (corrupted?)\n"); +char* Py_CallApplyChatTemplateInternal(const char* json_request) { + // Check if Python interpreter is still valid + if (!Py_IsInitialized()) { + printf("[C] Py_CallApplyChatTemplateInternal ERROR - Python interpreter not initialized\n"); fflush(stdout); return NULL; } - - // Validate input + + // Simple validation if (!json_request) { - printf("[C] Py_CallGetModelChatTemplateInternal ERROR - Input is NULL\n"); + printf("[C] Py_CallApplyChatTemplateInternal ERROR - Input is NULL\n"); fflush(stdout); return NULL; } - + // Acquire GIL for Python operations PyGILState_STATE gil_state = PyGILState_Ensure(); - // Create Python string from JSON request PyObject* py_json = PyUnicode_FromString(json_request); if (!py_json) { - printf("[C] Py_CallGetModelChatTemplateInternal ERROR - Failed to create Python string\n"); + printf("[C] Py_CallApplyChatTemplateInternal ERROR - Failed to create Python string\n"); fflush(stdout); PyGILState_Release(gil_state); return NULL; } - + // Create arguments tuple PyObject* args = PyTuple_Pack(1, py_json); if (!args) { - printf("[C] Py_CallGetModelChatTemplateInternal ERROR - Failed to create args tuple\n"); + printf("[C] Py_CallApplyChatTemplateInternal ERROR - Failed to create args tuple\n"); fflush(stdout); Py_DECREF(py_json); PyGILState_Release(gil_state); return NULL; } - + // Call the cached function - PyObject* py_result = PyObject_CallObject(g_get_model_chat_template_func, args); - + PyObject* py_result = PyObject_CallObject(g_apply_chat_template_func, args); + // Clean up args Py_DECREF(args); Py_DECREF(py_json); - + char* cresult = NULL; if (py_result) { // Convert to C string const char* s = PyUnicode_AsUTF8(py_result); if (s) { cresult = strdup(s); - } else { - printf("[C] Py_CallGetModelChatTemplateInternal ERROR - Failed to convert result to C string\n"); + } + else { + printf("[C] Py_CallApplyChatTemplateInternal ERROR - Failed to convert result to C string\n"); fflush(stdout); } Py_DECREF(py_result); - } else { - printf("[C] Py_CallGetModelChatTemplateInternal ERROR - Python function returned NULL\n"); + } + else { + printf("[C] Py_CallApplyChatTemplateInternal ERROR - Python function returned NULL\n"); fflush(stdout); PyErr_Print(); fflush(stderr); } - + // Release GIL PyGILState_Release(gil_state); - + return cresult; } @@ -405,9 +400,9 @@ char* Py_ClearCaches() { printf("[C] Py_ClearCaches ERROR - Module not initialized\n"); return NULL; } - + PyGILState_STATE gil_state = PyGILState_Ensure(); - + // Call the clear_caches function PyObject* clear_caches_func = PyDict_GetItemString(PyModule_GetDict(g_chat_template_module), "clear_caches"); if (!clear_caches_func || !PyCallable_Check(clear_caches_func)) { @@ -415,7 +410,7 @@ char* Py_ClearCaches() { PyGILState_Release(gil_state); return NULL; } - + PyObject* result = PyObject_CallObject(clear_caches_func, NULL); if (!result) { printf("[C] Py_ClearCaches ERROR - Failed to call clear_caches function\n"); @@ -423,7 +418,7 @@ char* Py_ClearCaches() { PyGILState_Release(gil_state); return NULL; } - + // Convert result to C string const char* result_str = PyUnicode_AsUTF8(result); if (!result_str) { @@ -432,11 +427,11 @@ char* Py_ClearCaches() { PyGILState_Release(gil_state); return NULL; } - + char* c_result = strdup(result_str); Py_DECREF(result); PyGILState_Release(gil_state); - + return c_result; } @@ -444,50 +439,50 @@ char* Py_ClearCaches() { void Py_CleanupChatTemplateModule() { if (g_initialized && Py_IsInitialized()) { PyGILState_STATE state = PyGILState_Ensure(); - Py_XDECREF(g_render_jinja_template_func); - Py_XDECREF(g_get_model_chat_template_func); + Py_XDECREF(g_load_tokenizer_with_cache_func); + Py_XDECREF(g_apply_chat_template_func); Py_XDECREF(g_chat_template_module); - g_render_jinja_template_func = NULL; - g_get_model_chat_template_func = NULL; + g_load_tokenizer_with_cache_func = NULL; + g_apply_chat_template_func = NULL; g_chat_template_module = NULL; g_initialized = 0; PyGILState_Release(state); } -} +} // Re-initialize Python interpreter state -int Py_ReinitializeGo() { +int Py_ReinitializeGo() { // Reset global flags g_initialized = 0; g_python_initialized = 0; g_process_initialized = 0; - + // Clean up cached objects - if (g_render_jinja_template_func) { - Py_DECREF(g_render_jinja_template_func); - g_render_jinja_template_func = NULL; + if (g_load_tokenizer_with_cache_func) { + Py_DECREF(g_load_tokenizer_with_cache_func); + g_load_tokenizer_with_cache_func = NULL; } - if (g_get_model_chat_template_func) { - Py_DECREF(g_get_model_chat_template_func); - g_get_model_chat_template_func = NULL; + if (g_apply_chat_template_func) { + Py_DECREF(g_apply_chat_template_func); + g_apply_chat_template_func = NULL; } if (g_chat_template_module) { Py_DECREF(g_chat_template_module); g_chat_template_module = NULL; } - + // Re-initialize int result = Py_InitializeGo(); if (result != 0) { printf("[C] Py_ReinitializeGo ERROR - Failed to re-initialize Python\n"); return result; } - + result = Py_InitChatTemplateModule(); if (result != 0) { printf("[C] Py_ReinitializeGo ERROR - Failed to re-initialize chat template module\n"); return result; } - + return 0; } diff --git a/pkg/preprocessing/chat_completions/cgo_functions.go b/pkg/preprocessing/chat_completions/cgo_functions.go index 68752d1c1..57b7c09aa 100644 --- a/pkg/preprocessing/chat_completions/cgo_functions.go +++ b/pkg/preprocessing/chat_completions/cgo_functions.go @@ -33,33 +33,41 @@ import ( "sigs.k8s.io/controller-runtime/pkg/log" ) -// ChatMessage represents a single message in a conversation. -type ChatMessage struct { +type LoadTokenizerWithCacheRequest struct { + IsLocal bool `json:"is_local,omitempty"` + DownloadDir string `json:"download_dir,omitempty"` + Model string `json:"model"` + Revision string `json:"revision,omitempty"` + Token string `json:"token,omitempty"` +} + +// Conversation represents a single message in a conversation. +type Conversation struct { Role string `json:"role"` Content string `json:"content"` } -// RenderJinjaTemplateRequest represents the request to render a chat template. -type RenderJinjaTemplateRequest struct { - // `conversations` is the transformers name, but we use `messages` for consistency with OpenAI API. +// ApplyChatTemplateRequest represents the request to render a chat template. +type ApplyChatTemplateRequest struct { // The Python wrapper will handle converting this to a batched list if needed. - Conversations []ChatMessage `json:"messages"` - Tools []interface{} `json:"tools,omitempty"` - Documents []interface{} `json:"documents,omitempty"` - ChatTemplate string `json:"chat_template,omitempty"` - ReturnAssistantTokensMask bool `json:"return_assistant_tokens_mask,omitempty"` - ContinueFinalMessage bool `json:"continue_final_message,omitempty"` - AddGenerationPrompt bool `json:"add_generation_prompt,omitempty"` - ChatTemplateKWArgs map[string]interface{} `json:"chat_template_kwargs,omitempty"` + LoadTokenizerWithCacheRequest LoadTokenizerWithCacheRequest `json:"load_tokenizer_with_cache_request,omitempty"` + Conversation []Conversation `json:"conversation"` + Tools []interface{} `json:"tools,omitempty"` + Documents []interface{} `json:"documents,omitempty"` + ChatTemplate string `json:"chat_template,omitempty"` + ReturnAssistantTokensMask bool `json:"return_assistant_tokens_mask,omitempty"` + ContinueFinalMessage bool `json:"continue_final_message,omitempty"` + AddGenerationPrompt bool `json:"add_generation_prompt,omitempty"` + ChatTemplateKWArgs map[string]interface{} `json:"chat_template_kwargs,omitempty"` } -// DeepCopy creates a deep copy of the RenderJinjaTemplateRequest. -func (req *RenderJinjaTemplateRequest) DeepCopy() (*RenderJinjaTemplateRequest, error) { +// DeepCopy creates a deep copy of the ApplyChatTemplateRequest. +func (req *ApplyChatTemplateRequest) DeepCopy() (*ApplyChatTemplateRequest, error) { b, err := json.Marshal(req) if err != nil { return nil, err } - var out RenderJinjaTemplateRequest + var out ApplyChatTemplateRequest err = json.Unmarshal(b, &out) if err != nil { return nil, err @@ -67,35 +75,9 @@ func (req *RenderJinjaTemplateRequest) DeepCopy() (*RenderJinjaTemplateRequest, return &out, nil } -// RenderJinjaTemplateResponse represents the response from rendering a chat template. -type RenderJinjaTemplateResponse struct { - RenderedChats []string `json:"rendered_chats"` - GenerationIndices [][][]int `json:"generation_indices"` -} - -// FetchChatTemplateRequest represents the request to fetch a chat template. -// This is needed if the fields are not set in the `RenderJinjaTemplateRequest`. -// When called, it will fetch the `chat_template` from the tokenizer. -// If the tokenizer is not present, it will be fetched from HuggingFace using -// the `token` if provided. -type FetchChatTemplateRequest struct { - Model string `json:"model"` - ChatTemplate string `json:"chat_template,omitempty"` - Tools []interface{} `json:"tools,omitempty"` - Revision string `json:"revision,omitempty"` - Token string `json:"token,omitempty"` - IsLocalPath bool `json:"is_local_path,omitempty"` -} - -// FetchChatTemplateResponse represents the response from fetching a chat template. -type FetchChatTemplateResponse struct { - ChatTemplate string `json:"chat_template,omitempty"` - ChatTemplateKWArgs map[string]interface{} `json:"chat_template_kwargs,omitempty"` -} - // ChatTemplatingProcessor is a processor that handles chat template rendering // using a cached Python function. Once the Python interpreter is initialized, -// it caches the `transformers` function `render_jinja_template` for rendering +// it caches the `vllm` function `apply_chat_template` for rendering // chat templates. It also provides a method to fetch chat templates from the // tokenizer or HuggingFace if the tokenizer is not present. type ChatTemplatingProcessor struct{} @@ -128,76 +110,57 @@ func (w *ChatTemplatingProcessor) Finalize() { C.Py_FinalizeGo() } -// RenderChatTemplate renders a chat template using the cached Python function. -// It calls the Python `transformers` function `render_jinja_template` with the provided request. -// -//nolint:gocritic // hugeParam: req is passed by value intentionally for immutability, but can consider using pointer. -func (w *ChatTemplatingProcessor) RenderChatTemplate(ctx context.Context, - req *RenderJinjaTemplateRequest, -) (*RenderJinjaTemplateResponse, error) { - traceLogger := log.FromContext(ctx).V(logging.TRACE).WithName("RenderChatTemplate") +// Load Tokenzier. +func (w *ChatTemplatingProcessor) LoadTokenizerWithCache( + ctx context.Context, + req *LoadTokenizerWithCacheRequest, +) error { + traceLogger := log.FromContext(ctx).V(logging.TRACE).WithName("LoadTokenizer") if req == nil { traceLogger.Error(nil, "Received nil request") - return nil, fmt.Errorf("received nil request") + return fmt.Errorf("received nil request") } - // Convert request to JSON reqJSON, err := json.Marshal(req) if err != nil { traceLogger.Error(err, "Failed to marshal request") - return nil, fmt.Errorf("failed to marshal request: %w", err) + return fmt.Errorf("failed to marshal request: %w", err) } // Call the cached Python function - cResult := C.Py_CallRenderJinjaTemplate(C.CString(string(reqJSON))) - if cResult == nil { - traceLogger.Error(nil, "C function returned nil") - return nil, fmt.Errorf("python render_jinja_template failed") + cResult := C.Py_CallLoadTokenizerWithCache(C.CString(string(reqJSON))) + if !cResult { + traceLogger.Error(nil, "C function returned false") + return fmt.Errorf("python load tokenizer failed") } - defer C.free(unsafe.Pointer(cResult)) - resultJSON := C.GoString(cResult) - - // Parse the response - var response RenderJinjaTemplateResponse - if err := json.Unmarshal([]byte(resultJSON), &response); err != nil { - traceLogger.Error(err, "Failed to unmarshal response") - return nil, fmt.Errorf("failed to unmarshal response: %w", err) - } - - return &response, nil + return nil } -// FetchChatTemplate fetches the model chat template using the cached Python function. -// -//nolint:gocritic // hugeParam: req is passed by value intentionally for immutability, but can consider using pointer. -func (w *ChatTemplatingProcessor) FetchChatTemplate( - ctx context.Context, - req FetchChatTemplateRequest, -) (string, map[string]interface{}, error) { - traceLogger := log.FromContext(ctx).V(logging.TRACE).WithName("FetchChatTemplate") +// ApplyChatTemplate renders a chat template using the cached Python function. +// It calls the Python `vllm` function `apply_chat_template` with the provided request. +func (w *ChatTemplatingProcessor) ApplyChatTemplate(ctx context.Context, + req *ApplyChatTemplateRequest, +) (string, error) { + traceLogger := log.FromContext(ctx).V(logging.TRACE).WithName("ApplyChatTemplate") + + if req == nil { + traceLogger.Error(nil, "Received nil request") + return "", fmt.Errorf("received nil request") + } - // Convert request to JSON reqJSON, err := json.Marshal(req) if err != nil { traceLogger.Error(err, "Failed to marshal request") - return "", nil, fmt.Errorf("failed to marshal request: %w", err) + return "", fmt.Errorf("failed to marshal request: %w", err) } // Call the cached Python function - cResult := C.Py_CallGetModelChatTemplate(C.CString(string(reqJSON))) + cResult := C.Py_CallApplyChatTemplate(C.CString(string(reqJSON))) if cResult == nil { traceLogger.Error(nil, "C function returned nil") - return "", nil, fmt.Errorf("python get_model_chat_template failed") + return "", fmt.Errorf("python apply_chat_template failed") } defer C.free(unsafe.Pointer(cResult)) - resultJSON := C.GoString(cResult) - - // Parse the response - var response FetchChatTemplateResponse - if err := json.Unmarshal([]byte(resultJSON), &response); err != nil { - traceLogger.Error(err, "Failed to unmarshal response") - return "", nil, fmt.Errorf("failed to unmarshal response: %w", err) - } - return response.ChatTemplate, response.ChatTemplateKWArgs, nil + return C.GoString(cResult), nil } // ClearCaches clears all caches for testing purposes. diff --git a/pkg/preprocessing/chat_completions/cgo_functions.h b/pkg/preprocessing/chat_completions/cgo_functions.h index 91b454c2f..45ba7bf2a 100644 --- a/pkg/preprocessing/chat_completions/cgo_functions.h +++ b/pkg/preprocessing/chat_completions/cgo_functions.h @@ -20,6 +20,7 @@ limitations under the License. #include #include #include +#include // === FUNCTION DECLARATIONS === @@ -48,23 +49,23 @@ const char* PyUnicode_AsGoString(PyObject* obj); // Global variables to hold cached module and functions extern PyObject* g_chat_template_module; -extern PyObject* g_render_jinja_template_func; -extern PyObject* g_get_model_chat_template_func; +extern PyObject* g_load_tokenizer_with_cache_func; +extern PyObject* g_apply_chat_template_func; // Initialize the cached module and functions (call once at startup) int Py_InitChatTemplateModule(); -// Call the cached render_jinja_template function -char* Py_CallRenderJinjaTemplate(const char* json_request); +// Call the cached load_tokenizer_with_cache function +bool Py_CallLoadTokenizerWithCache(const char* json_request); // Internal function that does the actual work -char* Py_CallRenderJinjaTemplateInternal(const char* json_request); +bool Py_CallLoadTokenizerWithCacheInternal(const char* json_request); -// Call the cached get_model_chat_template function -char* Py_CallGetModelChatTemplate(const char* json_request); +// Call the cached apply_chat_template function +char* Py_CallApplyChatTemplate(const char* json_request); // Internal function that does the actual work -char* Py_CallGetModelChatTemplateInternal(const char* json_request); +char* Py_CallApplyChatTemplateInternal(const char* json_request); // Clear all caches for testing purposes char* Py_ClearCaches(void); @@ -75,4 +76,4 @@ void Py_CleanupChatTemplateModule(); // Re-initialize Python interpreter state int Py_ReinitializeGo(); -#endif // CGO_FUNCTIONS_H \ No newline at end of file +#endif // CGO_FUNCTIONS_H \ No newline at end of file diff --git a/pkg/preprocessing/chat_completions/cgo_functions_test.go b/pkg/preprocessing/chat_completions/cgo_functions_test.go index 19a9048f9..d626c0a25 100644 --- a/pkg/preprocessing/chat_completions/cgo_functions_test.go +++ b/pkg/preprocessing/chat_completions/cgo_functions_test.go @@ -35,7 +35,6 @@ import ( var ( globalWrapper *preprocessing.ChatTemplatingProcessor globalWrapperOnce sync.Once - globalWrapperMu sync.Mutex ) // getGlobalWrapper returns a singleton wrapper instance. @@ -50,8 +49,8 @@ func getGlobalWrapper() *preprocessing.ChatTemplatingProcessor { return globalWrapper } -// TestGetModelChatTemplate tests the get_model_chat_template function. -func TestGetModelChatTemplate(t *testing.T) { +// TestLoadTokenizerWithCache tests the load_tokenizer_with_cache function. +func TestLoadTokenizerWithCache(t *testing.T) { wrapper := getGlobalWrapper() // Clear caches to ensure accurate timing measurements @@ -79,7 +78,7 @@ func TestGetModelChatTemplate(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - request := preprocessing.FetchChatTemplateRequest{ + request := &preprocessing.LoadTokenizerWithCacheRequest{ Model: tt.modelName, Revision: tt.revision, Token: tt.token, @@ -87,18 +86,14 @@ func TestGetModelChatTemplate(t *testing.T) { // Profile the function call start := time.Now() - template, templateVars, err := wrapper.FetchChatTemplate(context.Background(), request) + err := wrapper.LoadTokenizerWithCache(context.Background(), request) duration := time.Since(start) // Log performance - t.Logf("Model: %s, Duration: %v, ChatTemplate length: %d", tt.modelName, duration, len(template)) - + t.Logf("Model: %s, Duration: %v", tt.modelName, duration) if tt.expectTemplate { // Models that should have templates - require.NoError(t, err, "FetchChatTemplate should not return an error") - assert.NotEmpty(t, template, "ChatTemplate should not be empty") - assert.NotNil(t, templateVars, "ChatTemplate vars should not be nil") - assert.Contains(t, template, "messages", "ChatTemplate should contain messages") + require.NoError(t, err, "LoadTokenizerWithCache should not return an error") } else { // Models that don't have chat templates if err != nil { @@ -112,8 +107,8 @@ func TestGetModelChatTemplate(t *testing.T) { } } -// TestRenderJinjaTemplate tests the render_jinja_template function. -func TestRenderJinjaTemplate(t *testing.T) { +// TestApplyChatTemplate tests the render_jinja_template function. +func TestApplyChatTemplate(t *testing.T) { wrapper := getGlobalWrapper() // Clear caches to ensure accurate timing measurements @@ -140,12 +135,12 @@ func TestRenderJinjaTemplate(t *testing.T) { tests := []struct { name string template string - messages []preprocessing.ChatMessage + messages []preprocessing.Conversation }{ { name: "Simple ChatTemplate", template: simpleTemplate, - messages: []preprocessing.ChatMessage{ + messages: []preprocessing.Conversation{ {Role: "user", Content: "Hello"}, {Role: "assistant", Content: "Hi there!"}, }, @@ -153,7 +148,7 @@ func TestRenderJinjaTemplate(t *testing.T) { { name: "Complex ChatTemplate with System Message", template: complexTemplate, - messages: []preprocessing.ChatMessage{ + messages: []preprocessing.Conversation{ {Role: "system", Content: "You are a helpful AI assistant."}, {Role: "user", Content: "What is the weather like?"}, {Role: "assistant", Content: "I don't have access to real-time weather data."}, @@ -162,7 +157,7 @@ func TestRenderJinjaTemplate(t *testing.T) { { name: "Complex ChatTemplate without System Message", template: complexTemplate, - messages: []preprocessing.ChatMessage{ + messages: []preprocessing.Conversation{ {Role: "user", Content: "Tell me a joke"}, {Role: "assistant", Content: "Why don't scientists trust atoms? Because they make up everything!"}, }, @@ -171,26 +166,28 @@ func TestRenderJinjaTemplate(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - request := &preprocessing.RenderJinjaTemplateRequest{ - Conversations: tt.messages, - ChatTemplate: tt.template, + request := &preprocessing.ApplyChatTemplateRequest{ + LoadTokenizerWithCacheRequest: preprocessing.LoadTokenizerWithCacheRequest{ + Model: "ibm-granite/granite-3.3-8b-instruct", + IsLocal: true, + }, + Conversation: tt.messages, + ChatTemplate: tt.template, } // Profile the function call start := time.Now() - response, err := wrapper.RenderChatTemplate(context.Background(), request) + rendered, err := wrapper.ApplyChatTemplate(context.Background(), request) duration := time.Since(start) // Assertions - require.NoError(t, err, "RenderChatTemplate should not return an error") - assert.NotNil(t, response, "Response should not be nil") - assert.NotEmpty(t, response.RenderedChats, "Rendered chats should not be empty") + require.NoError(t, err, "ApplyChatTemplate should not return an error") + assert.NotEmpty(t, rendered, "rendered should not be empty") // Log performance - t.Logf("ChatTemplate: %s, Duration: %v, Rendered length: %d", tt.name, duration, len(response.RenderedChats[0])) + t.Logf("ChatTemplate: %s, Duration: %v, Rendered length: %d", tt.name, duration, len(rendered)) // Verify rendered content - rendered := response.RenderedChats[0] for _, message := range tt.messages { // For complex templates, the role might not be explicitly shown in output // but the content should always be present @@ -205,8 +202,8 @@ func TestRenderJinjaTemplate(t *testing.T) { } } -// TestTemplateCaching tests the caching functionality. -func TestTemplateCaching(t *testing.T) { +// TestLoadTokenizerCaching tests the caching functionality. +func TestLoadTokenizerCaching(t *testing.T) { wrapper := getGlobalWrapper() // Clear all caches to ensure we start with a clean state @@ -214,28 +211,25 @@ func TestTemplateCaching(t *testing.T) { require.NoError(t, err, "Failed to clear caches") modelName := "ibm-granite/granite-3.3-8b-instruct" - request := preprocessing.FetchChatTemplateRequest{ - Model: modelName, + request := &preprocessing.LoadTokenizerWithCacheRequest{ + Model: modelName, + IsLocal: false, } // First call - should be cache miss t.Log("=== First call (Cache MISS) ===") start := time.Now() - template1, vars1, err := wrapper.FetchChatTemplate(context.Background(), request) + err = wrapper.LoadTokenizerWithCache(context.Background(), request) duration1 := time.Since(start) require.NoError(t, err, "First call should not return an error") // Second call - should be cache hit t.Log("=== Second call (Cache HIT) ===") start = time.Now() - template2, vars2, err := wrapper.FetchChatTemplate(context.Background(), request) + err = wrapper.LoadTokenizerWithCache(context.Background(), request) duration2 := time.Since(start) require.NoError(t, err, "Second call should not return an error") - // Verify results are identical - assert.Equal(t, template1, template2, "Cached and non-cached results should be identical") - assert.Equal(t, vars1, vars2, "Cached and non-cached vars should be identical") - // Verify performance improvement t.Logf("First call duration: %v, Second call duration: %v, Speedup: %.1fx", duration1, duration2, float64(duration1)/float64(duration2)) @@ -255,13 +249,13 @@ func TestChatCompletionsIntegration(t *testing.T) { tests := []struct { name string modelName string - conversation []preprocessing.ChatMessage + conversation []preprocessing.Conversation description string }{ { name: "Simple Conversation", modelName: "ibm-granite/granite-3.3-8b-instruct", - conversation: []preprocessing.ChatMessage{ + conversation: []preprocessing.Conversation{ {Role: "user", Content: "What is the capital of France?"}, {Role: "assistant", Content: "The capital of France is Paris."}, }, @@ -270,7 +264,7 @@ func TestChatCompletionsIntegration(t *testing.T) { { name: "Multi-turn Conversation", modelName: "microsoft/DialoGPT-medium", - conversation: []preprocessing.ChatMessage{ + conversation: []preprocessing.Conversation{ {Role: "user", Content: "Hello, how are you?"}, {Role: "assistant", Content: "I'm doing well, thank you! How can I help you today?"}, {Role: "user", Content: "Can you tell me about machine learning?"}, @@ -282,7 +276,7 @@ func TestChatCompletionsIntegration(t *testing.T) { { name: "System Message Conversation", modelName: "ibm-granite/granite-3.3-8b-instruct", - conversation: []preprocessing.ChatMessage{ + conversation: []preprocessing.Conversation{ {Role: "system", Content: "You are a helpful AI assistant specialized in coding."}, {Role: "user", Content: "Write a Python function to calculate fibonacci numbers."}, {Role: "assistant", Content: "Here's a Python function to calculate fibonacci numbers:\n" + @@ -293,7 +287,7 @@ func TestChatCompletionsIntegration(t *testing.T) { { name: "Simple Conversation (Repeated)", modelName: "ibm-granite/granite-3.3-8b-instruct", - conversation: []preprocessing.ChatMessage{ + conversation: []preprocessing.Conversation{ {Role: "user", Content: "What is the capital of France?"}, {Role: "assistant", Content: "The capital of France is Paris."}, }, @@ -305,31 +299,19 @@ func TestChatCompletionsIntegration(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Logf("Testing: %s - %s", tt.name, tt.description) - // Step 1: Get the model's chat template + // Step 1: Render the conversation using the template start := time.Now() - templateRequest := preprocessing.FetchChatTemplateRequest{ - Model: tt.modelName, + renderRequest := &preprocessing.ApplyChatTemplateRequest{ + LoadTokenizerWithCacheRequest: preprocessing.LoadTokenizerWithCacheRequest{ + Model: tt.modelName, + }, + Conversation: tt.conversation, } - template, templateVars, err := wrapper.FetchChatTemplate(context.Background(), templateRequest) - templateDuration := time.Since(start) - require.NoError(t, err, "Failed to get model chat template") - assert.NotEmpty(t, template, "ChatTemplate should not be empty") - - // Step 2: Render the conversation using the template - start = time.Now() - renderRequest := &preprocessing.RenderJinjaTemplateRequest{ - Conversations: tt.conversation, - ChatTemplate: template, - ChatTemplateKWArgs: templateVars, - } - response, err := wrapper.RenderChatTemplate(context.Background(), renderRequest) + rendered, err := wrapper.ApplyChatTemplate(context.Background(), renderRequest) renderDuration := time.Since(start) require.NoError(t, err, "Failed to render chat template") - assert.NotNil(t, response, "Response should not be nil") - assert.NotEmpty(t, response.RenderedChats, "Rendered chats should not be empty") // Step 3: Verify the rendered output - rendered := response.RenderedChats[0] assert.NotEmpty(t, rendered, "Rendered chat should not be empty") // Verify all conversation messages are present in the rendered output @@ -338,8 +320,7 @@ func TestChatCompletionsIntegration(t *testing.T) { } // Log performance metrics - t.Logf("ChatTemplate fetch duration: %v, Render duration: %v, Total duration: %v", - templateDuration, renderDuration, templateDuration+renderDuration) + t.Logf("ChatTemplate Render duration: %v", renderDuration) }) } } @@ -379,7 +360,7 @@ func TestLongChatCompletions(t *testing.T) { require.NoError(t, err, "Failed to clear caches") // Create a long conversation - longConversation := []preprocessing.ChatMessage{ + longConversation := []preprocessing.Conversation{ {Role: "system", Content: "You are an expert software engineer with deep knowledge of Go, Python, " + "and system design. " + "Provide detailed, accurate responses."}, @@ -408,35 +389,25 @@ func TestLongChatCompletions(t *testing.T) { modelName := "ibm-granite/granite-3.3-8b-instruct" t.Run("Long Conversation Processing", func(t *testing.T) { - // Get template - start := time.Now() - templateRequest := preprocessing.FetchChatTemplateRequest{ - Model: modelName, - } - template, templateVars, err := wrapper.FetchChatTemplate(context.Background(), templateRequest) - templateDuration := time.Since(start) - require.NoError(t, err, "Failed to get model chat template") - // Render long conversation - start = time.Now() - renderRequest := &preprocessing.RenderJinjaTemplateRequest{ - Conversations: longConversation, - ChatTemplate: template, - ChatTemplateKWArgs: templateVars, + start := time.Now() + renderRequest := &preprocessing.ApplyChatTemplateRequest{ + LoadTokenizerWithCacheRequest: preprocessing.LoadTokenizerWithCacheRequest{ + Model: modelName, + }, + Conversation: longConversation, } - response, err := wrapper.RenderChatTemplate(context.Background(), renderRequest) + rendered, err := wrapper.ApplyChatTemplate(context.Background(), renderRequest) renderDuration := time.Since(start) require.NoError(t, err, "Failed to render long conversation") // Verify results - rendered := response.RenderedChats[0] assert.NotEmpty(t, rendered, "Long conversation should render successfully") assert.Greater(t, len(rendered), 1000, "Long conversation should produce substantial output") // Performance metrics - t.Logf("ChatTemplate fetch: %v, Long conversation render: %v, Total processing time: %v", - templateDuration, renderDuration, templateDuration+renderDuration) + t.Logf("ChatTemplate Long conversation render: %v", renderDuration) // Verify all messages are present for _, message := range longConversation { @@ -446,15 +417,15 @@ func TestLongChatCompletions(t *testing.T) { }) } -// BenchmarkGetModelChatTemplate benchmarks the template fetching performance. -func BenchmarkGetModelChatTemplate(b *testing.B) { +// BenchmarkLoadTokenizerWithCache benchmarks the template fetching performance. +func BenchmarkLoadTokenizerWithCache(b *testing.B) { wrapper := getGlobalWrapper() // Clear caches to ensure accurate timing measurements err := preprocessing.ClearCaches(context.Background()) require.NoError(b, err, "Failed to clear caches") - request := preprocessing.FetchChatTemplateRequest{ + request := &preprocessing.LoadTokenizerWithCacheRequest{ Model: "ibm-granite/granite-3.3-8b-instruct", } @@ -465,7 +436,7 @@ func BenchmarkGetModelChatTemplate(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { start := time.Now() - _, _, err := wrapper.FetchChatTemplate(context.Background(), request) + err := wrapper.LoadTokenizerWithCache(context.Background(), request) require.NoError(b, err, "Benchmark should not return errors") iterTime := time.Since(start) @@ -489,28 +460,22 @@ func BenchmarkGetModelChatTemplate(b *testing.B) { b.ReportMetric(float64(warmAvg.Nanoseconds()), "ns/op_warm") } -// BenchmarkRenderJinjaTemplate benchmarks the template rendering performance. -func BenchmarkRenderJinjaTemplate(b *testing.B) { +// BenchmarkApplyChatTemplate benchmarks the template rendering performance. +func BenchmarkApplyChatTemplate(b *testing.B) { wrapper := getGlobalWrapper() // Clear caches to ensure accurate timing measurements err := preprocessing.ClearCaches(context.Background()) require.NoError(b, err, "Failed to clear caches") - // Get template first - templateRequest := preprocessing.FetchChatTemplateRequest{ - Model: "ibm-granite/granite-3.3-8b-instruct", - } - template, templateVars, err := wrapper.FetchChatTemplate(context.Background(), templateRequest) - require.NoError(b, err, "Failed to get template for benchmark") - - request := &preprocessing.RenderJinjaTemplateRequest{ - Conversations: []preprocessing.ChatMessage{ + request := &preprocessing.ApplyChatTemplateRequest{ + LoadTokenizerWithCacheRequest: preprocessing.LoadTokenizerWithCacheRequest{ + Model: "ibm-granite/granite-3.3-8b-instruct", + }, + Conversation: []preprocessing.Conversation{ {Role: "user", Content: "Hello"}, {Role: "assistant", Content: "Hi there!"}, }, - ChatTemplate: template, - ChatTemplateKWArgs: templateVars, } // Track first iteration time and total time @@ -520,7 +485,7 @@ func BenchmarkRenderJinjaTemplate(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { start := time.Now() - _, err := wrapper.RenderChatTemplate(context.Background(), request) + _, err := wrapper.ApplyChatTemplate(context.Background(), request) require.NoError(b, err, "Benchmark should not return errors") iterTime := time.Since(start) @@ -596,8 +561,11 @@ func runVLLMValidationTest(t *testing.T, modelName, expectedVLLMOutput string) { wrapper := getGlobalWrapper() // Test case based on the provided vLLM request - request := &preprocessing.RenderJinjaTemplateRequest{ - Conversations: []preprocessing.ChatMessage{ + request := &preprocessing.ApplyChatTemplateRequest{ + LoadTokenizerWithCacheRequest: preprocessing.LoadTokenizerWithCacheRequest{ + Model: modelName, + }, + Conversation: []preprocessing.Conversation{ {Role: "user", Content: "What is the weather in Paris?"}, {Role: "assistant", Content: "Let me check that for you."}, }, @@ -614,29 +582,11 @@ func runVLLMValidationTest(t *testing.T, modelName, expectedVLLMOutput string) { }, } - // Step 1: Get the chat template from the specified model - templateRequest := preprocessing.FetchChatTemplateRequest{ - Model: modelName, - } - template, templateVars, err := wrapper.FetchChatTemplate(context.Background(), templateRequest) - require.NoError(t, err, "Failed to get chat template") - assert.NotEmpty(t, template, "ChatTemplate should not be empty") - - // Step 2: Update the request with the actual template and template variables - request.ChatTemplate = template - if templateVars != nil { - // Use the template variables from the model (contains special tokens like eos_token) - request.ChatTemplateKWArgs = templateVars - } - - // Step 3: Render the conversation with the template - response, err := wrapper.RenderChatTemplate(context.Background(), request) + // Step 1: Render the conversation with the template + renderedOutput, err := wrapper.ApplyChatTemplate(context.Background(), request) require.NoError(t, err, "Failed to render chat template") - require.Len(t, response.RenderedChats, 1, "Should have one rendered chat") - - renderedOutput := response.RenderedChats[0] - // Step 4: Compare results with flexible date handling + // Step 2: Compare results with flexible date handling compareVLLMOutput(t, renderedOutput, expectedVLLMOutput) } @@ -679,74 +629,58 @@ func compareVLLMOutput(t *testing.T, renderedOutput, expectedVLLMOutput string) t.Fail() // Mark test as failed } -// TestFetchChatTemplateLocalPath tests fetching chat templates from local paths. -func TestFetchChatTemplateLocalPath(t *testing.T) { +// TestLoadTokenizerWithCacheLocalPath tests fetching chat templates from local paths. +func TestLoadTokenizerWithCacheLocalPath(t *testing.T) { wrapper := getGlobalWrapper() // Get the path to the test model tokenizer // The testdata directory is in pkg/tokenization/testdata testModelPath := "../../tokenization/testdata/test-model" - request := preprocessing.FetchChatTemplateRequest{ - Model: testModelPath, - IsLocalPath: true, + request := &preprocessing.LoadTokenizerWithCacheRequest{ + Model: testModelPath, + IsLocal: true, } // Fetch the chat template - template, templateVars, err := wrapper.FetchChatTemplate(context.Background(), request) + err := wrapper.LoadTokenizerWithCache(context.Background(), request) // Assertions - require.NoError(t, err, "FetchChatTemplate should not return an error for local path") - assert.NotEmpty(t, template, "ChatTemplate should not be empty") - assert.NotNil(t, templateVars, "ChatTemplate vars should not be nil") - - // Verify the template contains expected content - assert.Contains(t, template, "messages", "ChatTemplate should contain messages variable") - t.Logf("Fetched local template: %s", template) - t.Logf("Template vars: %+v", templateVars) + require.NoError(t, err, "LoadTokenizerWithCache should not return an error for local path") } -// TestRenderChatTemplateWithLocalTemplate tests rendering with a locally fetched template. -func TestRenderChatTemplateWithLocalTemplate(t *testing.T) { +// TestApplyChatTemplateWithLocalTemplate tests rendering with a locally fetched template. +func TestApplyChatTemplateWithLocalTemplate(t *testing.T) { wrapper := getGlobalWrapper() // Get the path to the test model tokenizer testModelPath := "../../tokenization/testdata/test-model" - // First, fetch the template - fetchRequest := preprocessing.FetchChatTemplateRequest{ - Model: testModelPath, - IsLocalPath: true, - } - - template, templateVars, err := wrapper.FetchChatTemplate(context.Background(), fetchRequest) - require.NoError(t, err, "FetchChatTemplate should not return an error") - // Now render a conversation using the fetched template - renderRequest := &preprocessing.RenderJinjaTemplateRequest{ - Conversations: []preprocessing.ChatMessage{ + renderRequest := &preprocessing.ApplyChatTemplateRequest{ + LoadTokenizerWithCacheRequest: preprocessing.LoadTokenizerWithCacheRequest{ + Model: testModelPath, + IsLocal: true, + }, + + Conversation: []preprocessing.Conversation{ {Role: "user", Content: "Hello from local tokenizer!"}, {Role: "assistant", Content: "Hi! I'm using a locally loaded template."}, }, - ChatTemplate: template, - ChatTemplateKWArgs: templateVars, } - response, err := wrapper.RenderChatTemplate(context.Background(), renderRequest) - require.NoError(t, err, "RenderChatTemplate should not return an error") - assert.NotNil(t, response, "Response should not be nil") - assert.NotEmpty(t, response.RenderedChats, "Rendered chats should not be empty") + rendered, err := wrapper.ApplyChatTemplate(context.Background(), renderRequest) + require.NoError(t, err, "ApplyChatTemplate should not return an error") // Verify the rendered content - rendered := response.RenderedChats[0] assert.Contains(t, rendered, "Hello from local tokenizer!", "Rendered content should contain user message") assert.Contains(t, rendered, "Hi! I'm using a locally loaded template.", "Rendered content should contain assistant message") t.Logf("Rendered chat with local template:\n%s", rendered) } -// TestFetchChatTemplateLocalPathCaching tests that local templates are cached properly. -func TestFetchChatTemplateLocalPathCaching(t *testing.T) { +// TestLoadTokenizerWithCacheLocalPathCaching tests that local templates are cached properly. +func TestLoadTokenizerWithCacheLocalPathCaching(t *testing.T) { wrapper := getGlobalWrapper() // Clear caches first @@ -754,75 +688,62 @@ func TestFetchChatTemplateLocalPathCaching(t *testing.T) { require.NoError(t, err, "Failed to clear caches") testModelPath := "../../tokenization/testdata/test-model" - request := preprocessing.FetchChatTemplateRequest{ - Model: testModelPath, - IsLocalPath: true, + request := &preprocessing.LoadTokenizerWithCacheRequest{ + Model: testModelPath, + IsLocal: true, } // First call - cache miss start := time.Now() - template1, vars1, err := wrapper.FetchChatTemplate(context.Background(), request) + err = wrapper.LoadTokenizerWithCache(context.Background(), request) duration1 := time.Since(start) require.NoError(t, err, "First call should not return an error") // Second call - cache hit start = time.Now() - template2, vars2, err := wrapper.FetchChatTemplate(context.Background(), request) + err = wrapper.LoadTokenizerWithCache(context.Background(), request) duration2 := time.Since(start) require.NoError(t, err, "Second call should not return an error") - // Verify results are identical - assert.Equal(t, template1, template2, "Cached and non-cached templates should be identical") - assert.Equal(t, vars1, vars2, "Cached and non-cached vars should be identical") - // Cache hit should be faster t.Logf("First call (cache miss): %v, Second call (cache hit): %v, Speedup: %.1fx", duration1, duration2, float64(duration1)/float64(duration2)) assert.Less(t, duration2, duration1, "Cache hit should be faster than cache miss") } -// TestFetchChatTemplateLocalPathWithFile tests loading from a specific tokenizer.json file path. -func TestFetchChatTemplateLocalPathWithFile(t *testing.T) { +// TestLoadTokenizerWithCacheLocalPathWithFile tests loading from a specific tokenizer.json file path. +func TestLoadTokenizerWithCacheLocalPathWithFile(t *testing.T) { wrapper := getGlobalWrapper() // Test with the full path to tokenizer.json //nolint:gosec // This is a test file path, not a credential testTokenizerPath := "../../tokenization/testdata/test-model/tokenizer.json" - request := preprocessing.FetchChatTemplateRequest{ - Model: testTokenizerPath, - IsLocalPath: true, + request := &preprocessing.LoadTokenizerWithCacheRequest{ + Model: testTokenizerPath, + IsLocal: true, } // Fetch the chat template - should extract directory and load from there - template, templateVars, err := wrapper.FetchChatTemplate(context.Background(), request) + err := wrapper.LoadTokenizerWithCache(context.Background(), request) + require.NoError(t, err, "LoadTokenizerWithCache should handle file path and extract directory") - // Assertions - require.NoError(t, err, "FetchChatTemplate should handle file path and extract directory") - assert.NotEmpty(t, template, "ChatTemplate should not be empty") - assert.NotNil(t, templateVars, "ChatTemplate vars should not be nil") - assert.Contains(t, template, "messages", "ChatTemplate should contain messages variable") - - t.Logf("Fetched template from file path: %s", template) + t.Logf("Loaded tokenizer from file path: %s", testTokenizerPath) } -// TestFetchChatTemplateLocalPathNonExistent tests error handling for non-existent local paths. -func TestFetchChatTemplateLocalPathNonExistent(t *testing.T) { +// TestLoadTokenizerWithCacheLocalPathNonExistent tests error handling for non-existent local paths. +func TestLoadTokenizerWithCacheLocalPathNonExistent(t *testing.T) { wrapper := getGlobalWrapper() - request := preprocessing.FetchChatTemplateRequest{ - Model: "/non/existent/path", - IsLocalPath: true, + request := &preprocessing.LoadTokenizerWithCacheRequest{ + Model: "/non/existent/path", + IsLocal: true, } // This should return an error - template, templateVars, err := wrapper.FetchChatTemplate(context.Background(), request) - + err := wrapper.LoadTokenizerWithCache(context.Background(), request) // Assertions - assert.Error(t, err, "FetchChatTemplate should return an error for non-existent path") - assert.Empty(t, template, "ChatTemplate should be empty on error") - assert.Nil(t, templateVars, "ChatTemplate vars should be nil on error") - + assert.Error(t, err, "LoadTokenizerWithCache should return an error for non-existent path") t.Logf("Expected error for non-existent path: %v", err) } diff --git a/pkg/preprocessing/chat_completions/render_jinja_template_wrapper.py b/pkg/preprocessing/chat_completions/render_jinja_template_wrapper.py deleted file mode 100644 index b87fbbefd..000000000 --- a/pkg/preprocessing/chat_completions/render_jinja_template_wrapper.py +++ /dev/null @@ -1,259 +0,0 @@ -# Copyright 2025 The llm-d Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -#!/usr/bin/env python3 -""" -Standalone wrapper for render_jinja_template function from transformers. -""" - -import json -import logging -import sys -from typing import Optional, Union - -# Import core functions from transformers - moved to function level to avoid import errors -TRANSFORMERS_AVAILABLE = None # Will be set when first needed - -def _ensure_transformers_available(): - """Ensure transformers is available, importing it if needed.""" - global TRANSFORMERS_AVAILABLE - if TRANSFORMERS_AVAILABLE is None: - try: - print("[Python] Attempting to import transformers...") - from transformers.utils.chat_template_utils import render_jinja_template as transformers_render_jinja_template, get_json_schema - from transformers import AutoTokenizer - print("[Python] Successfully imported transformers!") - TRANSFORMERS_AVAILABLE = True - return True - except ImportError as e: - print(f"[Python] Failed to import transformers: {e}") - print("[Python] Ensure the 'transformers' library is installed in the Python environment.") - TRANSFORMERS_AVAILABLE = False - return False - return TRANSFORMERS_AVAILABLE - -# Basic logging setup -logger = logging.getLogger(__name__) - -# Module-level cache for templates -_template_cache = {} -_cache_lock = None - -def _get_cache_lock(): - """Get or create a threading lock for cache access.""" - global _cache_lock - if _cache_lock is None: - import threading - _cache_lock = threading.Lock() - return _cache_lock - - -def _collect_template_vars(tokenizer): - """Collect extra rendering variables from a tokenizer.""" - kwargs = {} - for k in ["bos_token", "eos_token", "eot_token", "pad_token", "unk_token", "sep_token", "additional_special_tokens"]: - v = getattr(tokenizer, k, None) - if v is not None: - kwargs[k] = v - return kwargs - - -def clear_caches(): - """Clear all caches for testing purposes.""" - lock = _get_cache_lock() - with lock: - global _template_cache - _template_cache.clear() - return "Caches cleared" - - -def render_jinja_template(request_json): - """ - Render a chat template using the transformers library. - This function is aligned with the Go cgo_functions.go structs. - - Args: - request_json (str): JSON string containing the request parameters: - - conversations (list): List of conversation lists - - chat_template (str, optional): The template to use - - tools (list, optional): Tool schemas - - documents (list, optional): Document schemas - - return_assistant_tokens_mask (bool, optional): Whether to return assistant tokens mask - - continue_final_message (bool, optional): Whether to continue final message - - add_generation_prompt (bool, optional): Whether to add generation prompt - - kwargs (dict, optional): Additional rendering variables - Returns: - str: JSON string containing 'rendered_chats' and 'generation_indices' keys. - """ - if not _ensure_transformers_available(): - raise ImportError("transformers library is required for render_jinja_template") - - # Import the modules we need - from transformers.utils.chat_template_utils import render_jinja_template as transformers_render_jinja_template - - # Parse the JSON request - request = json.loads(request_json) - - # Align Go's `messages` field with transformers' `conversations` parameter. - if 'messages' in request: - request['conversations'] = [request.pop('messages')] # wrap to match expected format - - try: - # Get template_vars and spread them as individual arguments - template_vars = request.pop('chat_template_kwargs', {}) - request.update(template_vars) - - rendered_chats, generation_indices = transformers_render_jinja_template(**request) - - except Exception as e: - raise - - # Return as JSON string, aligning with the Go response struct. - result = json.dumps({ - "rendered_chats": rendered_chats, - "generation_indices": generation_indices - }) - return result - - -def get_model_chat_template(request_json): - """ - Load a tokenizer from Hugging Face Hub or local path and return its chat template string and required variables. - Args: - request_json (str): JSON string containing the request parameters: - - model (str): The model ID or path (HF model ID, local directory path, or path to tokenizer file). - - chat_template (str, optional): The template name or string to use. - - tools (list[dict], optional): Tool schemas to pass. - - revision (str, optional): Model revision. - - token (str, optional): Hugging Face token for private models. - - is_local_path (bool, optional): Whether the model is a local path (default: False). - Returns: - str: JSON string containing 'template' and 'kwargs' keys, aligning with the Go response struct. - """ - if not _ensure_transformers_available(): - print("[Python] get_model_chat_template ERROR - Transformers not available") - raise ImportError("transformers library is required for get_model_chat_template") - - # Parse the JSON request - request = json.loads(request_json) - - model_name = request.get("model") - chat_template = request.get("chat_template") - tools = request.get("tools") - revision = request.get("revision") - token = request.get("token") - is_local_path = request.get("is_local_path", False) - - if not model_name: - print("[Python] get_model_chat_template ERROR - model_name is required") - raise ValueError("model_name is required in request") - - # Create cache key - cache_key = f"{model_name}:{revision or 'main'}:{token or 'none'}:{is_local_path}" - - # Check cache first - lock = _get_cache_lock() - with lock: - if cache_key in _template_cache: - cached_result = _template_cache[cache_key] - # If a specific chat_template was requested, override the cached template - if chat_template is not None: - cached_result["template"] = chat_template - return json.dumps(cached_result) - - # Import the modules we need - from transformers import AutoTokenizer - import os - - # Determine if we're loading from local path or HuggingFace - if is_local_path: - # For local paths, model_name can be either a directory containing tokenizer files - # or a path to a specific tokenizer file. Ensure we extract the directory if needed. - if os.path.isfile(model_name): - # If it's a file path (tokenizer.json), get the directory - tokenizer_dir = os.path.dirname(model_name) - else: - # If it's already a directory, use it directly - tokenizer_dir = model_name - - print(f"[Python] Loading tokenizer from local path: {tokenizer_dir}") - tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir, local_files_only=True, trust_remote_code=True) - else: - # Load from Hugging Face - print(f"[Python] Loading tokenizer from HuggingFace: {model_name}") - tokenizer = AutoTokenizer.from_pretrained(model_name, revision=revision, token=token, trust_remote_code=True) - - template = tokenizer.chat_template if chat_template is None else chat_template - - # Collect special tokens - template_vars = _collect_template_vars(tokenizer) - - # Cache the result, aligning with the Go response struct. - result = {"chat_template": template, "chat_template_kwargs": template_vars} - with lock: - _template_cache[cache_key] = result.copy() # Cache a copy to avoid reference issues - - return json.dumps(result) - - -def main(): - """Example usage and testing function.""" - if not _ensure_transformers_available(): - print("Error: transformers library is required but not available") - print("Please install transformers: pip install transformers") - return - - if len(sys.argv) < 2: - print("Usage: python render_jinja_template_wrapper.py [conversation_json]") - print("Example:") - print('python render_jinja_template_wrapper.py "{% for message in messages %}{{ message.role }}: {{ message.content }}\\n{% endfor %}"') - return - - chat_template = sys.argv[1] - - # Default conversation if none provided - conversation = [ - {"role": "user", "content": "Hello!"}, - {"role": "assistant", "content": "Hi there! How can I help you today?"} - ] - - if len(sys.argv) > 2: - try: - conversation = json.loads(sys.argv[2]) - except json.JSONDecodeError: - print("Error: Invalid JSON for conversation") - return - - try: - # Construct the request JSON string similar to how Go would - request_str = json.dumps({ - "conversations": [conversation], - "chat_template": chat_template - }) - response_str = render_jinja_template(request_str) - response_data = json.loads(response_str) - - rendered = response_data['rendered_chats'] - generation_indices = response_data['generation_indices'] - - print("Rendered chat:") - print(rendered[0]) - if generation_indices and len(generation_indices) > 0 and generation_indices[0]: - print(f"Generation indices: {generation_indices[0]}") - except Exception as e: - print(f"Error: {e}") - - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/pkg/preprocessing/chat_completions/requirements.txt b/pkg/preprocessing/chat_completions/requirements.txt index 649c31de8..6f48e6c24 100644 --- a/pkg/preprocessing/chat_completions/requirements.txt +++ b/pkg/preprocessing/chat_completions/requirements.txt @@ -1,7 +1,3 @@ ---extra-index-url https://download.pytorch.org/whl/cpu - -packaging==24.2 -pillow==11.2.1 -torch==2.5.1 -transformers>=4.53.0,<4.57.2 -jinja2>=2.11 +--index-url https://download.pytorch.org/whl/cpu +--extra-index-url https://pypi.org/simple +vllm-cpu>=0.11.0 \ No newline at end of file diff --git a/pkg/preprocessing/chat_completions/tokenizer_wrapper.py b/pkg/preprocessing/chat_completions/tokenizer_wrapper.py new file mode 100644 index 000000000..4ccd022f9 --- /dev/null +++ b/pkg/preprocessing/chat_completions/tokenizer_wrapper.py @@ -0,0 +1,153 @@ +# Copyright 2025 The llm-d Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +#!/usr/bin/env python3 +""" +Standalone wrapper for tokenizer from vllm. +""" + +import json +import logging +import os +import sys +from vllm.transformers_utils.tokenizer import get_tokenizer + +# Basic logging setup +logger = logging.getLogger(__name__) + +_tokenizer_cache = {} + +def clear_caches(): + """Clear the tokenizer cache for testing purposes.""" + _tokenizer_cache.clear() + return "Tokenizer caches cleared" + +def apply_chat_template(request_json): + """ + Render a chat template using the transformers library. + This function is aligned with the Go cgo_functions.go structs. + + Args: + request_json (str): JSON string containing the request parameters: + - is_local (bool, optional): Whether the model is local. + - model (str): The model ID or path (HF model ID, local directory path, or path to tokenizer file). + - revision (str, optional): Model revision. + - token (str, optional): Hugging Face token for private models. + - conversation (list): List of conversation lists + - chat_template (str, optional): The template to use + - tools (list, optional): Tool schemas + - documents (list, optional): Document schemas + - return_assistant_tokens_mask (bool, optional): Whether to return assistant tokens mask + - continue_final_message (bool, optional): Whether to continue final message + - add_generation_prompt (bool, optional): Whether to add generation prompt + - chat_template_kwargs (dict, optional): Additional rendering variables + + Returns: + str: The rendered chat template as a string. + """ + + try: + # Parse the JSON request + request = json.loads(request_json) + tokenizer_request = request.pop("load_tokenizer_with_cache_request", request) + tokenizer = load_tokenizer_with_cache(json.dumps(tokenizer_request)) + + # Get template_vars and spread them as individual arguments + template_vars = request.pop('chat_template_kwargs', {}) + request.update(template_vars) + + request["tokenize"] = False + return tokenizer.apply_chat_template(**request) + + except Exception as e: + raise RuntimeError(f"Error applying chat template: {e}") from e + +def load_tokenizer_with_cache(request_json): + """ + Initialize and cache the tokenizer based on the request. + Args: + request_json (str): JSON string containing the request parameters: + - is_local (bool, optional): Whether the model is local. + - model (str): The model ID or path (HF model ID, local directory path, or path to tokenizer file). + - revision (str, optional): Model revision. + - token (str, optional): Hugging Face token for private models. + - download_dir (str, optional): Directory to download the model. + Returns: + tokenizer: The initialized tokenizer object. + """ + # Parse the JSON request + request = json.loads(request_json) + + try: + model_name = request.pop("model") + revision = request.get("revision", None) + is_local = request.pop("is_local", False) + token = request.pop("token", "") + download_dir = request.pop("download_dir", None) + + if is_local and os.path.isfile(model_name): + # If it's a file path (tokenizer.json), get the directory + model_name = os.path.dirname(model_name) + + cache_key = f"{model_name}:{revision or 'main'}:{is_local}" + tokenizer = _tokenizer_cache.get(cache_key) + if not tokenizer is None: + return tokenizer + os.environ["HF_TOKEN"] = token + tokenizer = get_tokenizer(model_name, trust_remote_code=True, revision=revision, download_dir=download_dir) + _tokenizer_cache[cache_key] = tokenizer + return tokenizer + except Exception as e: + raise RuntimeError(f"Error initializing tokenizer: {e}") from e + +def main(): + """Example usage and testing function.""" + + if len(sys.argv) < 2: + print("Usage: python tokenizer_wrapper.py [conversation_json]") + print("Example:") + print('python tokenizer_wrapper.py "{% for message in messages %}{{ message.role }}: {{ message.content }}\\n{% endfor %}"') + return + + chat_template = sys.argv[1] + + # Default conversation if none provided + conversation = [ + {"role": "user", "content": "Hello!"}, + {"role": "assistant", "content": "Hi there! How can I help you today?"} + ] + + if len(sys.argv) > 2: + try: + conversation = json.loads(sys.argv[2]) + except json.JSONDecodeError: + print("Error: Invalid JSON for conversation") + return + + try: + # Construct the request JSON string similar to how Go would + request_str = json.dumps({ + "model": "facebook/opt-125m", + "conversation": [conversation], + "chat_template": chat_template + }) + response = apply_chat_template(request_str) + + print("Rendered chat:") + print(response[0]) + except Exception as e: + print(f"Error: {e}") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/pkg/tokenization/pool.go b/pkg/tokenization/pool.go index be5fee01e..f28750021 100644 --- a/pkg/tokenization/pool.go +++ b/pkg/tokenization/pool.go @@ -69,7 +69,7 @@ type tokenizationResponse struct { // Task represents a unit of work for tokenizing a prompt. type Task struct { - RenderReq *preprocessing.RenderJinjaTemplateRequest + RenderReq *preprocessing.ApplyChatTemplateRequest Prompt string ModelName string ResultCh chan<- tokenizationResponse // nil => fire-and-forget @@ -151,7 +151,7 @@ func (pool *Pool) EnqueueTokenization(prompt string) { } // Tokenize queues a task and blocks until the final result is available. -func (pool *Pool) Tokenize(renderReq *preprocessing.RenderJinjaTemplateRequest, prompt string) []uint32 { +func (pool *Pool) Tokenize(renderReq *preprocessing.ApplyChatTemplateRequest, prompt string) []uint32 { resultCh := make(chan tokenizationResponse, 1) pool.queue.Add(Task{ RenderReq: renderReq, @@ -214,20 +214,22 @@ func (pool *Pool) workerLoop(_ int) { // processTask tokenizes the prompt and updates the indexer. // It sends exactly one response (success or error) if ResultCh is provided. func (pool *Pool) processTask(task Task) error { + addSpecialToken := true if task.RenderReq != nil { var err error - task.Prompt, err = pool.tokenizer.RenderChatTemplate(pool.modelName, task.RenderReq) + task.Prompt, err = pool.tokenizer.ApplyChatTemplate(pool.modelName, task.RenderReq) if err != nil { log.Log.Error(err, "failed to render chat template") return err } + addSpecialToken = false } tokenIDs, overlapRatio := pool.indexer.FindLongestContainedTokens(task.Prompt) // if the overlap ratio is low, get the full tokenization if overlapRatio < pool.minPrefixOverlapRatio { - tokens, offsets, err := pool.tokenizer.Encode(task.Prompt, pool.modelName) + tokens, offsets, err := pool.tokenizer.Encode(task.Prompt, pool.modelName, addSpecialToken) if err != nil { log.Log.Error(err, "failed to encode tokens", "prompt", task.Prompt) return err diff --git a/pkg/tokenization/pool_test.go b/pkg/tokenization/pool_test.go index 1756dd9f7..2cfc15ad4 100644 --- a/pkg/tokenization/pool_test.go +++ b/pkg/tokenization/pool_test.go @@ -51,15 +51,15 @@ type MockTokenizer struct { mock.Mock } -func (m *MockTokenizer) RenderChatTemplate( - prompt string, renderReq *preprocessing.RenderJinjaTemplateRequest, +func (m *MockTokenizer) ApplyChatTemplate( + prompt string, renderReq *preprocessing.ApplyChatTemplateRequest, ) (string, error) { args := m.Called(prompt, renderReq) return args.String(0), args.Error(1) } -func (m *MockTokenizer) Encode(input, modelName string) ([]uint32, []tokenizers.Offset, error) { - args := m.Called(input, modelName) +func (m *MockTokenizer) Encode(input, modelName string, addSpecialToken bool) ([]uint32, []tokenizers.Offset, error) { + args := m.Called(input, modelName, addSpecialToken) return args.Get(0).([]uint32), args.Get(1).([]tokenizers.Offset), args.Error(2) //nolint:errcheck // return mocked values } @@ -107,7 +107,7 @@ func TestPool_ProcessTask(t *testing.T) { // Mock FindLongestContainedTokens to return low overlap ratio mockIndexer.On("FindLongestContainedTokens", task.Prompt).Return([]uint32{}, 0.0) - mockTokenizer.On("Encode", task.Prompt, testModelName).Return(expectedTokens, expectedOffsets, nil) + mockTokenizer.On("Encode", task.Prompt, testModelName, true).Return(expectedTokens, expectedOffsets, nil) // Verify that indexer receives exactly the same tokens and offsets that tokenizer returned mockIndexer.On("AddTokenization", task.Prompt, expectedTokens, expectedOffsets).Return(nil) @@ -130,7 +130,7 @@ func TestPool_WorkerLoop(t *testing.T) { "successful task processing": { setupMocks: func(mi *MockIndexer, mt *MockTokenizer) { mi.On("FindLongestContainedTokens", "test prompt").Return([]uint32{}, 0.0) - mt.On("Encode", "test prompt", testModelName).Return([]uint32{1, 2, 3}, []tokenizers.Offset{{0, 4}}, nil) + mt.On("Encode", "test prompt", testModelName, true).Return([]uint32{1, 2, 3}, []tokenizers.Offset{{0, 4}}, nil) mi.On("AddTokenization", "test prompt", []uint32{1, 2, 3}, []tokenizers.Offset{{0, 4}}).Return(nil) }, genTasks: func() ([]Task, chan tokenizationResponse) { @@ -141,7 +141,7 @@ func TestPool_WorkerLoop(t *testing.T) { "task with result channel": { setupMocks: func(mi *MockIndexer, mt *MockTokenizer) { mi.On("FindLongestContainedTokens", "test with channel").Return([]uint32{}, 0.0) - mt.On("Encode", "test with channel", testModelName).Return([]uint32{10, 20, 30}, []tokenizers.Offset{{0, 4}}, nil) + mt.On("Encode", "test with channel", testModelName, true).Return([]uint32{10, 20, 30}, []tokenizers.Offset{{0, 4}}, nil) mi.On("AddTokenization", "test with channel", []uint32{10, 20, 30}, []tokenizers.Offset{{0, 4}}).Return(nil) }, genTasks: func() ([]Task, chan tokenizationResponse) { @@ -176,7 +176,7 @@ func TestPool_WorkerLoop(t *testing.T) { offsets := []tokenizers.Offset{{0, 6}} mi.On("FindLongestContainedTokens", prompt).Return([]uint32{}, 0.0).Once() - mt.On("Encode", prompt, testModelName).Return(tokens, offsets, nil).Once() + mt.On("Encode", prompt, testModelName, true).Return(tokens, offsets, nil).Once() mi.On("AddTokenization", prompt, tokens, offsets).Return(nil).Once() } }, @@ -198,7 +198,7 @@ func TestPool_WorkerLoop(t *testing.T) { setupMocks: func(mi *MockIndexer, mt *MockTokenizer) { // Mock will fail every time, causing retries mi.On("FindLongestContainedTokens", "failing prompt").Return([]uint32{}, 0.0) - mt.On("Encode", "failing prompt", testModelName).Return( + mt.On("Encode", "failing prompt", testModelName, true).Return( []uint32{}, []tokenizers.Offset{}, assert.AnError) }, genTasks: func() ([]Task, chan tokenizationResponse) { diff --git a/pkg/tokenization/tokenizer.go b/pkg/tokenization/tokenizer.go index 9339de27a..a5fdae49a 100644 --- a/pkg/tokenization/tokenizer.go +++ b/pkg/tokenization/tokenizer.go @@ -34,9 +34,9 @@ import ( // Tokenizer interface defines the methods for tokenization. type Tokenizer interface { - RenderChatTemplate(string, *preprocessing.RenderJinjaTemplateRequest) (string, error) + ApplyChatTemplate(string, *preprocessing.ApplyChatTemplateRequest) (string, error) // Encode tokenizes the input string and returns the token IDs and offsets. - Encode(input, modelName string) ([]uint32, []tokenizers.Offset, error) + Encode(input, modelName string, addSpecialToken bool) ([]uint32, []tokenizers.Offset, error) Type() string } @@ -258,8 +258,6 @@ func parseHFCacheModelName(dirName string) (string, bool) { type tokenizerProvider interface { get(modelName string) (*tokenizers.Tokenizer, error) - - getFetchChatTemplateRequest(modelName string) (preprocessing.FetchChatTemplateRequest, error) } // CachedTokenizer implements the Tokenizer interface for a specific model. @@ -271,9 +269,18 @@ type CachedTokenizer struct { chatTemplateRenderer *preprocessing.ChatTemplatingProcessor } +type HFCachedTokenizer struct { + CachedTokenizer + hfTokenizerConfig *HFTokenizerConfig +} +type LocalCachedTokenizer struct { + CachedTokenizer + localTokenizerConfig *LocalTokenizerConfig +} + // NewCachedHFTokenizer creates a new instance of CachedTokenizer downloading tokenizer configs from HuggingFace with // the provided configuration. -func NewCachedHFTokenizer(modelID string, config *HFTokenizerConfig) (Tokenizer, error) { +func NewCachedHFTokenizer(modelID string, config *HFTokenizerConfig) (*HFCachedTokenizer, error) { tokenizerProvider := newHFTokenizerProvider(config) tokenizer, err := tokenizerProvider.get(modelID) if err != nil { @@ -286,10 +293,23 @@ func NewCachedHFTokenizer(modelID string, config *HFTokenizerConfig) (Tokenizer, return nil, fmt.Errorf("failed to initialize chat templater: %w", err) } - return &CachedTokenizer{ - tokenizer: tokenizer, - tokenizerProvider: tokenizerProvider, - chatTemplateRenderer: chatTemplateRenderer, + ctx := context.TODO() + if err := chatTemplateRenderer.LoadTokenizerWithCache(ctx, &preprocessing.LoadTokenizerWithCacheRequest{ + IsLocal: false, + Model: modelID, + DownloadDir: config.TokenizersCacheDir, + Token: config.HuggingFaceToken, + }); err != nil { + return nil, fmt.Errorf("failed to load tokenizer with cache: %w", err) + } + + return &HFCachedTokenizer{ + CachedTokenizer: CachedTokenizer{ + tokenizer: tokenizer, + tokenizerProvider: tokenizerProvider, + chatTemplateRenderer: chatTemplateRenderer, + }, + hfTokenizerConfig: config, }, nil } @@ -302,7 +322,7 @@ func NewCachedHFTokenizer(modelID string, config *HFTokenizerConfig) (Tokenizer, // - Reducing startup latency by avoiding downloads // // The tokenizer is initialized for a specific model at creation time. -func NewCachedLocalTokenizer(modelName string, config LocalTokenizerConfig) (Tokenizer, error) { +func NewCachedLocalTokenizer(modelName string, config LocalTokenizerConfig) (*LocalCachedTokenizer, error) { if err := discoverLocalTokenizerMap(&config); err != nil { return nil, fmt.Errorf("failed to discover local tokenizer map: %w", err) } @@ -321,48 +341,73 @@ func NewCachedLocalTokenizer(modelName string, config LocalTokenizerConfig) (Tok return nil, fmt.Errorf("failed to initialize chat templater: %w", err) } - return &CachedTokenizer{ - tokenizer: tokenizer, - tokenizerProvider: tokenizerProvider, - chatTemplateRenderer: chatTemplater, + path, ok := config.ModelTokenizerMap[modelName] + if !ok { + return nil, fmt.Errorf("tokenizer for model %q not found", modelName) + } + + ctx := context.TODO() + if err := chatTemplater.LoadTokenizerWithCache(ctx, &preprocessing.LoadTokenizerWithCacheRequest{ + IsLocal: true, + Model: path, + }); err != nil { + return nil, fmt.Errorf("failed to load tokenizer with cache: %w", err) + } + + return &LocalCachedTokenizer{ + CachedTokenizer: CachedTokenizer{ + tokenizer: tokenizer, + tokenizerProvider: tokenizerProvider, + chatTemplateRenderer: chatTemplater, + }, + localTokenizerConfig: &config, }, nil } -func (t *CachedTokenizer) RenderChatTemplate( - modelName string, renderReq *preprocessing.RenderJinjaTemplateRequest, +func (t *LocalCachedTokenizer) ApplyChatTemplate( + modelName string, req *preprocessing.ApplyChatTemplateRequest, ) (string, error) { ctx := context.TODO() - if renderReq.ChatTemplate == "" { - req, err := t.tokenizerProvider.getFetchChatTemplateRequest(modelName) - if err != nil { - return "", fmt.Errorf("failed to create fetch chat template request: %w", err) - } - renderReq.ChatTemplate, renderReq.ChatTemplateKWArgs, err = t.chatTemplateRenderer.FetchChatTemplate( - ctx, req, - ) - if err != nil { - return "", fmt.Errorf("failed to fetch chat template: %w", err) - } + req.LoadTokenizerWithCacheRequest.IsLocal = true + path, ok := t.localTokenizerConfig.ModelTokenizerMap[modelName] + if !ok { + return "", fmt.Errorf("tokenizer for model %q not found", modelName) + } + req.LoadTokenizerWithCacheRequest.Model = filepath.Dir(path) + res, err := t.chatTemplateRenderer.ApplyChatTemplate(ctx, req) + if err != nil { + return "", fmt.Errorf("failed to render chat template: %w", err) } - res, err := t.chatTemplateRenderer.RenderChatTemplate(ctx, renderReq) + return res, nil +} + +func (t *HFCachedTokenizer) ApplyChatTemplate( + modelName string, req *preprocessing.ApplyChatTemplateRequest, +) (string, error) { + ctx := context.TODO() + + req.LoadTokenizerWithCacheRequest.IsLocal = false + req.LoadTokenizerWithCacheRequest.DownloadDir = t.hfTokenizerConfig.TokenizersCacheDir + req.LoadTokenizerWithCacheRequest.Token = t.hfTokenizerConfig.HuggingFaceToken + res, err := t.chatTemplateRenderer.ApplyChatTemplate(ctx, req) if err != nil { return "", fmt.Errorf("failed to render chat template: %w", err) } - return res.RenderedChats[0], nil + return res, nil } // Encode converts a string into token IDs. // The modelName parameter is ignored since this tokenizer is bound to a specific model. -func (t *CachedTokenizer) Encode(input, _ string) ([]uint32, []tokenizers.Offset, error) { +func (t *CachedTokenizer) Encode(input, _ string, addSpecialToken bool) ([]uint32, []tokenizers.Offset, error) { encodeOptions := []tokenizers.EncodeOption{ tokenizers.WithReturnTypeIDs(), tokenizers.WithReturnOffsets(), } - resp := t.tokenizer.EncodeWithOptions(input, false, encodeOptions...) + resp := t.tokenizer.EncodeWithOptions(input, addSpecialToken, encodeOptions...) return resp.IDs, resp.Offsets, nil } @@ -411,14 +456,6 @@ func (p *hfTokenizerProvider) get(modelName string) (*tokenizers.Tokenizer, erro return tokenizers.FromPretrained(modelName, p.cfgOpt) } -func (p *hfTokenizerProvider) getFetchChatTemplateRequest(modelName string) (preprocessing.FetchChatTemplateRequest, error) { - return preprocessing.FetchChatTemplateRequest{ - Model: modelName, - Token: p.authToken, - IsLocalPath: false, - }, nil -} - // localTokenizerProvider implements tokenizerProvider by loading tokenizers from local files. // It looks up the tokenizer file path in the configuration mapping and loads it from disk. type localTokenizerProvider struct { @@ -436,20 +473,6 @@ func (p *localTokenizerProvider) get(modelName string) (*tokenizers.Tokenizer, e return tokenizers.FromFile(path) } -func (p *localTokenizerProvider) getFetchChatTemplateRequest(modelName string) (preprocessing.FetchChatTemplateRequest, error) { - req := preprocessing.FetchChatTemplateRequest{ - IsLocalPath: true, - } - - path, ok := p.cfg.ModelTokenizerMap[modelName] - if !ok { - return req, fmt.Errorf("tokenizer for model %q not found", modelName) - } - req.Model = filepath.Dir(path) - - return req, nil -} - // CompositeTokenizer implements the Tokenizer interface with a fallback mechanism. // It tries each tokenizer in order until one succeeds. This allows for graceful // fallback from local tokenizers to HuggingFace tokenizers. @@ -471,18 +494,18 @@ type CompositeTokenizer struct { Tokenizers []Tokenizer } -func (c *CompositeTokenizer) RenderChatTemplate( - modelName string, renderReq *preprocessing.RenderJinjaTemplateRequest, +func (c *CompositeTokenizer) ApplyChatTemplate( + modelName string, req *preprocessing.ApplyChatTemplateRequest, ) (string, error) { var rErr error for _, tokenizer := range c.Tokenizers { - copiedRenderReq, err := renderReq.DeepCopy() + copiedReq, err := req.DeepCopy() if err != nil { rErr = multierr.Append(rErr, fmt.Errorf("failed to copy render request: %w", err)) continue } start := time.Now() - rendered, err := tokenizer.RenderChatTemplate(modelName, copiedRenderReq) + rendered, err := tokenizer.ApplyChatTemplate(modelName, copiedReq) metrics.RenderChatTemplateLatency.WithLabelValues(tokenizer.Type()).Observe(time.Since(start).Seconds()) if err != nil { rErr = multierr.Append(rErr, err) @@ -503,11 +526,11 @@ func (c *CompositeTokenizer) RenderChatTemplate( // 4. If all fail, returns all accumulated errors // // This enables prioritizing local tokenizers while maintaining HuggingFace as a fallback. -func (c *CompositeTokenizer) Encode(input, modelName string) ([]uint32, []tokenizers.Offset, error) { +func (c *CompositeTokenizer) Encode(input, modelName string, addSpecialToken bool) ([]uint32, []tokenizers.Offset, error) { var rErr error for _, tokenizer := range c.Tokenizers { start := time.Now() - ids, offsets, err := tokenizer.Encode(input, modelName) + ids, offsets, err := tokenizer.Encode(input, modelName, addSpecialToken) metrics.TokenizationLatency.WithLabelValues(tokenizer.Type()).Observe(time.Since(start).Seconds()) if err != nil { rErr = multierr.Append(rErr, err) diff --git a/pkg/tokenization/tokenizer_test.go b/pkg/tokenization/tokenizer_test.go index 8822b6f03..48a7e247a 100644 --- a/pkg/tokenization/tokenizer_test.go +++ b/pkg/tokenization/tokenizer_test.go @@ -36,13 +36,13 @@ type DummyTokenizer struct { returnError bool } -func (d *DummyTokenizer) RenderChatTemplate( - prompt string, renderReq *preprocessing.RenderJinjaTemplateRequest, +func (d *DummyTokenizer) ApplyChatTemplate( + prompt string, renderReq *preprocessing.ApplyChatTemplateRequest, ) (string, error) { return prompt, nil } -func (d *DummyTokenizer) Encode(input, modelName string) ([]uint32, []tokenizers.Offset, error) { +func (d *DummyTokenizer) Encode(input, modelName string, addSpecialToken bool) ([]uint32, []tokenizers.Offset, error) { if d.returnError { return nil, nil, fmt.Errorf("dummy tokenizer error") } @@ -81,7 +81,7 @@ func TestCachedHFTokenizer_Encode(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - tokenIds, offsets, err := tokenizer.Encode(tt.input, testModelName) + tokenIds, offsets, err := tokenizer.Encode(tt.input, testModelName, true) assert.NoError(t, err) assert.GreaterOrEqual(t, len(tokenIds), 0) @@ -105,11 +105,11 @@ func TestCachedHFTokenizer_CacheTokenizer(t *testing.T) { input := "test input" // First call - loads tokenizer - tokenIds1, offsets1, err1 := tokenizer.Encode(input, testModelName) + tokenIds1, offsets1, err1 := tokenizer.Encode(input, testModelName, true) require.NoError(t, err1) // Second call - should use cached tokenizer - tokenIds2, offsets2, err2 := tokenizer.Encode(input, testModelName) + tokenIds2, offsets2, err2 := tokenizer.Encode(input, testModelName, true) require.NoError(t, err2) // Results should be identical @@ -161,7 +161,7 @@ func TestCachedLocalTokenizer_Encode(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - tokenIds, offsets, err := tokenizer.Encode(tt.input, tt.modelName) + tokenIds, offsets, err := tokenizer.Encode(tt.input, tt.modelName, true) assert.NoError(t, err) assert.GreaterOrEqual(t, len(tokenIds), 0) @@ -211,7 +211,7 @@ func TestCompositeTokenizer_FallbackBehavior(t *testing.T) { Tokenizers: []Tokenizer{dummyTokenizer, hfTokenizer}, } - tokenIds, offsets, err := composite.Encode("hello world", testModelName) + tokenIds, offsets, err := composite.Encode("hello world", testModelName, true) assert.NoError(t, err) assert.GreaterOrEqual(t, len(tokenIds), 0) assert.Equal(t, len(tokenIds), len(offsets)) diff --git a/pkg/tokenization/uds_tokenizer.go b/pkg/tokenization/uds_tokenizer.go index d9cc9e1ba..5d29eb5b1 100644 --- a/pkg/tokenization/uds_tokenizer.go +++ b/pkg/tokenization/uds_tokenizer.go @@ -105,7 +105,7 @@ func NewUdsTokenizer(config *UdsTokenizerConfig) (Tokenizer, error) { } // Encode tokenizes the input string and returns the token IDs and offsets. -func (u *UdsTokenizer) Encode(input, modelName string) ([]uint32, []tokenizers.Offset, error) { +func (u *UdsTokenizer) Encode(input, modelName string, _ bool) ([]uint32, []tokenizers.Offset, error) { req, err := http.NewRequestWithContext( context.Background(), http.MethodPost, @@ -129,11 +129,11 @@ func (u *UdsTokenizer) Encode(input, modelName string) ([]uint32, []tokenizers.O return tokenized.InputIDs, tokenized.OffsetMapping, nil } -// RenderChatTemplate renders a chat template using the UDS tokenizer service. -func (u *UdsTokenizer) RenderChatTemplate( - _ string, renderReq *preprocessing.RenderJinjaTemplateRequest, +// ApplyChatTemplate renders a chat template using the UDS tokenizer service. +func (u *UdsTokenizer) ApplyChatTemplate( + _ string, renderReq *preprocessing.ApplyChatTemplateRequest, ) (string, error) { - messagesBytes, err := json.Marshal(renderReq.Conversations) + messagesBytes, err := json.Marshal(renderReq.Conversation) if err != nil { return "", fmt.Errorf("failed to marshal chat-completions messages: %w", err) } diff --git a/tests/e2e/redis_mock/e2e_suite_test.go b/tests/e2e/redis_mock/e2e_suite_test.go index 09adc4341..3f54fa21d 100644 --- a/tests/e2e/redis_mock/e2e_suite_test.go +++ b/tests/e2e/redis_mock/e2e_suite_test.go @@ -94,9 +94,9 @@ func (s *KVCacheSuite) SetupTest() { // //nolint:nonamedreturns // named returns keep gocritic unnamedResult satisfied while allowing compact return func (s *KVCacheSuite) promptToEngineAndRequestKeys( - prompt, model string, + prompt, model string, addSpecialToken bool, ) (engineKeys, requestKeys []kvblock.Key) { - tokens, _, err := s.tokenizer.Encode(prompt, model) + tokens, _, err := s.tokenizer.Encode(prompt, model, addSpecialToken) s.Require().NoError(err) requestKeys = s.tokensProcessor.TokensToKVBlockKeys(nil, tokens, model) diff --git a/tests/e2e/redis_mock/e2e_test.go b/tests/e2e/redis_mock/e2e_test.go index d3f3f9899..94316fdd3 100644 --- a/tests/e2e/redis_mock/e2e_test.go +++ b/tests/e2e/redis_mock/e2e_test.go @@ -28,6 +28,11 @@ import ( "github.com/llm-d/llm-d-kv-cache/pkg/tokenization" ) +const ( + localTestModelDir = "testdata/test-model" + localLlama3ModelDir = "testdata/local-llama3" +) + // ChatMessage represents a single message in a conversation. type ChatMessage struct { Role string `json:"role"` @@ -54,11 +59,11 @@ type GetChatTemplateRequest struct { Token string `json:"token,omitempty"` } -// convertToPreprocessingChatMessages converts e2e ChatMessage to preprocessing ChatMessage. -func convertToPreprocessingChatMessages(messages []ChatMessage) []preprocessing.ChatMessage { - result := make([]preprocessing.ChatMessage, len(messages)) +// convertToPreprocessingConversation converts e2e ChatMessage to preprocessing Conversation. +func convertToPreprocessingConversation(messages []ChatMessage) []preprocessing.Conversation { + result := make([]preprocessing.Conversation, len(messages)) for i, msg := range messages { - result[i] = preprocessing.ChatMessage{ + result[i] = preprocessing.Conversation{ Role: msg.Role, Content: msg.Content, } @@ -111,7 +116,7 @@ func (s *KVCacheSuite) TestCacheHit() { prompt := "lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua." fakePodList := []string{s.Pod1IP} - engineKeys, requestKeys := s.promptToEngineAndRequestKeys(prompt, defaultModelName) + engineKeys, requestKeys := s.promptToEngineAndRequestKeys(prompt, defaultModelName, true) s.addEntriesToIndex(engineKeys, requestKeys, fakePodList) pods, err := s.indexer.GetPodScores(s.ctx, nil, prompt, defaultModelName, fakePodList) @@ -139,7 +144,7 @@ func (s *KVCacheSuite) TestPrefixReduction() { midPrompt := "lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua." shortPrompt := "lorem ipsum dolor sit amet, consectetur adipiscing elit." - fullPromptEngineKeys, fullPromptRequestKeys := s.promptToEngineAndRequestKeys(fullPrompt, defaultModelName) + fullPromptEngineKeys, fullPromptRequestKeys := s.promptToEngineAndRequestKeys(fullPrompt, defaultModelName, true) fakePodList := []string{s.Pod1IP} // Test 1: Full prompt (no match expected) @@ -163,7 +168,7 @@ func (s *KVCacheSuite) TestPrefixReduction() { s.Len(pods, len(fakePodList), "expected pod scores length to match candidate pods") s.T().Logf("Received pod scores: %+v", pods) - _, shortPromptRequestKeys := s.promptToEngineAndRequestKeys(shortPrompt, defaultModelName) + _, shortPromptRequestKeys := s.promptToEngineAndRequestKeys(shortPrompt, defaultModelName, true) s.Equal(int(pods[s.Pod1IP]), len(shortPromptRequestKeys), "all short-prompt block keys should have been indexed") } @@ -183,7 +188,7 @@ func (s *KVCacheSuite) TestPrefixExpansion() { s.T().Logf("Received pod scores: %+v", pods) s.Empty(pods, "expected no pod scores") - shortPromptEngineKeys, shortPromptRequestKeys := s.promptToEngineAndRequestKeys(shortPrompt, modelName) + shortPromptEngineKeys, shortPromptRequestKeys := s.promptToEngineAndRequestKeys(shortPrompt, modelName, true) s.addEntriesToIndex(shortPromptEngineKeys, shortPromptRequestKeys, fakePodList) // Test 2: mid prompt @@ -193,7 +198,7 @@ func (s *KVCacheSuite) TestPrefixExpansion() { s.T().Logf("Received pod scores: %+v", pods) s.Equal(int(pods[s.Pod1IP]), len(shortPromptRequestKeys), "expected pod score to equal number of short prompt block keys") - midPromptEngineKeys, midPromptRequestKeys := s.promptToEngineAndRequestKeys(midPrompt, modelName) + midPromptEngineKeys, midPromptRequestKeys := s.promptToEngineAndRequestKeys(midPrompt, modelName, true) s.addEntriesToIndex(midPromptEngineKeys, midPromptRequestKeys, fakePodList) // Test 3: full prompt @@ -223,7 +228,7 @@ func (s *KVCacheSuite) TestLongPrefixExpansion() { s.Empty(pods, "expected no pod scores") // Add entries to the index for the short prompt - shortPromptEngineKeys, shortPromptRequestKeys := s.promptToEngineAndRequestKeys(shortPrompt, modelName) + shortPromptEngineKeys, shortPromptRequestKeys := s.promptToEngineAndRequestKeys(shortPrompt, modelName, true) s.addEntriesToIndex(shortPromptEngineKeys, shortPromptRequestKeys, fakePodList) // Test 2: mid prompt (should return partial match if indexer picks it up) @@ -233,7 +238,7 @@ func (s *KVCacheSuite) TestLongPrefixExpansion() { s.True(len(pods) > 0, "expected at least one pod score for mid prompt") // Add entries to the index for the mid prompt - midPromptEngineKeys, midPromptRequestKeys := s.promptToEngineAndRequestKeys(midPrompt, modelName) + midPromptEngineKeys, midPromptRequestKeys := s.promptToEngineAndRequestKeys(midPrompt, modelName, true) s.addEntriesToIndex(midPromptEngineKeys, midPromptRequestKeys, fakePodList) // Test 3: long prompt (should return higher score) @@ -281,7 +286,7 @@ func (s *KVCacheSuite) TestChatCompletionsE2E() { s.Require().NotEmpty(flattenedPrompt, "Flattened prompt should not be empty") // Step 4: Use the flattened prompt for KV-cache lookup (similar to TestBasicE2E). - engineKeys, requestKeys := s.promptToEngineAndRequestKeys(flattenedPrompt, "ibm-granite/granite-3.3-8b-instruct") + engineKeys, requestKeys := s.promptToEngineAndRequestKeys(flattenedPrompt, "ibm-granite/granite-3.3-8b-instruct", true) fakePodList := []string{s.Pod1IP} // First lookup - should return no scores initially. @@ -355,7 +360,7 @@ func (s *KVCacheSuite) TestLongChatCompletionsE2E() { s.Require().Greater(len(flattenedPrompt), 1000, "Long conversation should produce substantial output") // Step 4: Test KV-cache with the long flattened prompt. - engineKeys, requestKeys := s.promptToEngineAndRequestKeys(flattenedPrompt, "ibm-granite/granite-3.3-8b-instruct") + engineKeys, requestKeys := s.promptToEngineAndRequestKeys(flattenedPrompt, "ibm-granite/granite-3.3-8b-instruct", true) fakePodList := []string{s.Pod1IP} // First lookup. @@ -395,14 +400,14 @@ func (s *KVCacheSuite) TestCacheHitWithLocalTokenizer() { fakePodList := []string{s.Pod1IP} // Tokenize using local tokenizer - tokens, offsets, err := localTokenizer.Encode(prompt, modelName) + tokens, offsets, err := localTokenizer.Encode(prompt, modelName, true) s.Require().NoError(err) s.Require().NotEmpty(tokens) s.Require().Equal(len(tokens), len(offsets), "tokens and offsets should have same length") s.T().Logf("Local tokenizer produced %d tokens for prompt", len(tokens)) // Convert tokens to KV block keys - engineKeys, requestKeys := s.promptToEngineAndRequestKeys(prompt, modelName) + engineKeys, requestKeys := s.promptToEngineAndRequestKeys(prompt, modelName, true) // Add entries to the index - this verifies the local tokenizer produces valid block keys s.addEntriesToIndex(engineKeys, requestKeys, fakePodList) @@ -415,7 +420,7 @@ func (s *KVCacheSuite) TestCacheHitWithLocalTokenizer() { s.T().Logf("GetPodScores returned score: %v", pods[s.Pod1IP]) // Also verify that tokenizing the same prompt again produces same block keys - tokens2, _, err := localTokenizer.Encode(prompt, modelName) + tokens2, _, err := localTokenizer.Encode(prompt, modelName, true) s.Require().NoError(err) requestKeys2 := s.tokensProcessor.TokensToKVBlockKeys(nil, tokens2, modelName) s.Require().Equal(requestKeys, requestKeys2, "same prompt should produce same block keys") @@ -434,12 +439,8 @@ func (s *KVCacheSuite) TestHFCacheStructureDiscoveryE2E() { testModelPath := filepath.Join(tmpDir, "models--test-org--test-model", "snapshots", "abc123") require.NoError(s.T(), os.MkdirAll(testModelPath, 0o755)) - // Copy the test tokenizer file - srcTokenizer := "testdata/test-model/tokenizer.json" - dstTokenizer := filepath.Join(testModelPath, "tokenizer.json") - srcData, err := os.ReadFile(srcTokenizer) - require.NoError(s.T(), err) - require.NoError(s.T(), os.WriteFile(dstTokenizer, srcData, 0o600)) + // Copy the test tokenizer + require.NoError(s.T(), os.CopyFS(testModelPath, os.DirFS(localTestModelDir))) // Create tokenizer config with auto-discovery config := tokenization.LocalTokenizerConfig{ @@ -457,20 +458,20 @@ func (s *KVCacheSuite) TestHFCacheStructureDiscoveryE2E() { fakePodList := []string{s.Pod1IP} // Tokenize using the auto-discovered HF cache tokenizer - tokens, offsets, err := localTokenizer.Encode(prompt, modelName) + tokens, offsets, err := localTokenizer.Encode(prompt, modelName, true) s.Require().NoError(err) s.Require().NotEmpty(tokens) s.Require().Equal(len(tokens), len(offsets), "tokens and offsets should have same length") s.T().Logf("HF cache auto-discovery produced %d tokens for model %q", len(tokens), modelName) // Convert tokens to KV block keys using promptToEngineAndRequestKeys with local tokenizer - engineKeys1, requestKeys := s.promptToEngineAndRequestKeys(prompt, modelName) + engineKeys1, requestKeys := s.promptToEngineAndRequestKeys(prompt, modelName, true) // Add entries to the index s.addEntriesToIndex(engineKeys1, requestKeys, fakePodList) // Verify retrieval - tokens2, _, err := localTokenizer.Encode(prompt, modelName) + tokens2, _, err := localTokenizer.Encode(prompt, modelName, true) s.Require().NoError(err) requestKeys2 := s.tokensProcessor.TokensToKVBlockKeys(nil, tokens2, modelName) s.Require().Equal(requestKeys, requestKeys2, "same prompt should produce same block keys") @@ -488,12 +489,12 @@ func (s *KVCacheSuite) TestLocalTokenizerChatTemplateE2E() { }{ { name: "test-model", - modelDir: "testdata/test-model", + modelDir: localTestModelDir, modelName: "test-model", }, { name: "local-llama3", - modelDir: "testdata/local-llama3", + modelDir: localLlama3ModelDir, modelName: "local-llama3", }, } @@ -523,10 +524,10 @@ func (s *KVCacheSuite) TestLocalTokenizerChatTemplateE2E() { // Step 1: Render the conversation into a flattened prompt using local chat template // This tests the full integration: Go -> CGO -> Python -> Local Tokenizer - renderReq := &preprocessing.RenderJinjaTemplateRequest{ - Conversations: convertToPreprocessingChatMessages(conversation), + renderReq := &preprocessing.ApplyChatTemplateRequest{ + Conversation: convertToPreprocessingConversation(conversation), } - renderedPrompt, err := localTokenizer.RenderChatTemplate(tc.modelName, renderReq) + renderedPrompt, err := localTokenizer.ApplyChatTemplate(tc.modelName, renderReq) s.Require().NoError(err, "RenderChatTemplate should succeed with local tokenizer") s.Require().NotEmpty(renderedPrompt, "Rendered prompt should not be empty") s.T().Logf("Rendered prompt from local template:\n%s", renderedPrompt) @@ -537,14 +538,14 @@ func (s *KVCacheSuite) TestLocalTokenizerChatTemplateE2E() { s.Require().Contains(renderedPrompt, "Give me an example", "rendered prompt should contain second user message") // Step 2: Tokenize the rendered prompt using the same local tokenizer - tokens, offsets, err := localTokenizer.Encode(renderedPrompt, tc.modelName) + tokens, offsets, err := localTokenizer.Encode(renderedPrompt, tc.modelName, true) s.Require().NoError(err, "Encode should succeed") s.Require().NotEmpty(tokens, "Tokens should not be empty") s.Require().Equal(len(tokens), len(offsets), "Tokens and offsets should have same length") s.T().Logf("Local tokenizer produced %d tokens from rendered chat template", len(tokens)) // Step 3: Convert tokens to KV block keys - engineKeys, requestKeys := s.promptToEngineAndRequestKeys(renderedPrompt, tc.modelName) + engineKeys, requestKeys := s.promptToEngineAndRequestKeys(renderedPrompt, tc.modelName, true) s.T().Logf("Generated %d KV block keys from rendered conversation", len(requestKeys)) // Step 4: Add to index and verify retrieval (full KV-cache flow) @@ -558,14 +559,14 @@ func (s *KVCacheSuite) TestLocalTokenizerChatTemplateE2E() { s.T().Logf("GetPodScores returned score: %v for rendered chat template", pods[s.Pod1IP]) // Also verify by rendering and tokenizing the same conversation again - renderReq2 := &preprocessing.RenderJinjaTemplateRequest{ - Conversations: convertToPreprocessingChatMessages(conversation), + renderReq2 := &preprocessing.ApplyChatTemplateRequest{ + Conversation: convertToPreprocessingConversation(conversation), } - renderedPrompt2, err := localTokenizer.RenderChatTemplate(tc.modelName, renderReq2) + renderedPrompt2, err := localTokenizer.ApplyChatTemplate(tc.modelName, renderReq2) s.Require().NoError(err) s.Require().Equal(renderedPrompt, renderedPrompt2, "Same conversation should render identically") - tokens2, _, err := localTokenizer.Encode(renderedPrompt2, tc.modelName) + tokens2, _, err := localTokenizer.Encode(renderedPrompt2, tc.modelName, true) s.Require().NoError(err) requestKeys2 := s.tokensProcessor.TokensToKVBlockKeys(nil, tokens2, tc.modelName) s.Require().Equal(requestKeys, requestKeys2, "Same conversation should produce same block keys") @@ -584,12 +585,12 @@ func (s *KVCacheSuite) TestLocalTokenizerChatTemplateMultiTurnE2E() { }{ { name: "test-model", - modelDir: "testdata/test-model", + modelDir: localTestModelDir, modelName: "test-model", }, { name: "local-llama3", - modelDir: "testdata/local-llama3", + modelDir: localLlama3ModelDir, modelName: "local-llama3", }, } @@ -618,15 +619,15 @@ func (s *KVCacheSuite) TestLocalTokenizerChatTemplateMultiTurnE2E() { } // Render and cache the short conversation - shortReq := &preprocessing.RenderJinjaTemplateRequest{ - Conversations: convertToPreprocessingChatMessages(shortConversation), + shortReq := &preprocessing.ApplyChatTemplateRequest{ + Conversation: convertToPreprocessingConversation(shortConversation), } - shortPrompt, err := localTokenizer.RenderChatTemplate(tc.modelName, shortReq) + shortPrompt, err := localTokenizer.ApplyChatTemplate(tc.modelName, shortReq) s.Require().NoError(err) s.T().Logf("Short prompt length: %d chars", len(shortPrompt)) - shortTokens, _, err := localTokenizer.Encode(shortPrompt, tc.modelName) + shortTokens, _, err := localTokenizer.Encode(shortPrompt, tc.modelName, true) s.Require().NoError(err) - shortEngineKeys, shortRequestKeys := s.promptToEngineAndRequestKeys(shortPrompt, tc.modelName) + shortEngineKeys, shortRequestKeys := s.promptToEngineAndRequestKeys(shortPrompt, tc.modelName, true) s.addEntriesToIndex(shortEngineKeys, shortRequestKeys, fakePodList) s.T().Logf("Short conversation: %d tokens, %d block keys", len(shortTokens), len(shortRequestKeys)) @@ -645,18 +646,18 @@ func (s *KVCacheSuite) TestLocalTokenizerChatTemplateMultiTurnE2E() { } // Render and test the extended conversation - extendedReq := &preprocessing.RenderJinjaTemplateRequest{ - Conversations: convertToPreprocessingChatMessages(extendedConversation), + extendedReq := &preprocessing.ApplyChatTemplateRequest{ + Conversation: convertToPreprocessingConversation(extendedConversation), } - extendedPrompt, err := localTokenizer.RenderChatTemplate(tc.modelName, extendedReq) + extendedPrompt, err := localTokenizer.ApplyChatTemplate(tc.modelName, extendedReq) s.Require().NoError(err) s.T().Logf("Extended prompt: %q (length: %d)", extendedPrompt, len(extendedPrompt)) s.Require().Greater(len(extendedPrompt), len(shortPrompt), "Extended conversation should be longer") - extendedTokens, _, err := localTokenizer.Encode(extendedPrompt, tc.modelName) + extendedTokens, _, err := localTokenizer.Encode(extendedPrompt, tc.modelName, true) s.Require().NoError(err) - extendedEngineKeys, extendedRequestKeys := s.promptToEngineAndRequestKeys(extendedPrompt, tc.modelName) + extendedEngineKeys, extendedRequestKeys := s.promptToEngineAndRequestKeys(extendedPrompt, tc.modelName, true) s.T().Logf("Extended conversation: %d tokens, %d block keys", len(extendedTokens), len(extendedRequestKeys)) // Some tokenizers use fixed-length encoding with padding (e.g., 512 tokens) @@ -703,12 +704,12 @@ func (s *KVCacheSuite) TestLocalVsHFChatTemplateConsistency() { }{ { name: "test-model", - modelDir: "testdata/test-model", + modelDir: localTestModelDir, modelName: "test-model", }, { name: "local-llama3", - modelDir: "testdata/local-llama3", + modelDir: localLlama3ModelDir, modelName: "local-llama3", }, } @@ -743,20 +744,20 @@ func (s *KVCacheSuite) TestLocalVsHFChatTemplateConsistency() { } // Render with local tokenizer - req1 := &preprocessing.RenderJinjaTemplateRequest{ - Conversations: convertToPreprocessingChatMessages(conversation), + req1 := &preprocessing.ApplyChatTemplateRequest{ + Conversation: convertToPreprocessingConversation(conversation), } - localRendered, err := localTokenizer.RenderChatTemplate(tc.modelName, req1) + localRendered, err := localTokenizer.ApplyChatTemplate(tc.modelName, req1) s.Require().NoError(err) s.Require().NotEmpty(localRendered) // Tokenize with local tokenizer - localTokens, _, err := localTokenizer.Encode(localRendered, tc.modelName) + localTokens, _, err := localTokenizer.Encode(localRendered, tc.modelName, true) s.Require().NoError(err) s.T().Logf("Local tokenizer: rendered=%d chars, tokens=%d", len(localRendered), len(localTokens)) // Add to index and verify with GetPodScores - engineKeys, requestKeys := s.promptToEngineAndRequestKeys(localRendered, tc.modelName) + engineKeys, requestKeys := s.promptToEngineAndRequestKeys(localRendered, tc.modelName, true) fakePodList := []string{s.Pod1IP} s.addEntriesToIndex(engineKeys, requestKeys, fakePodList) @@ -767,16 +768,16 @@ func (s *KVCacheSuite) TestLocalVsHFChatTemplateConsistency() { s.T().Logf("GetPodScores returned score: %v", pods[s.Pod1IP]) // Render the same conversation again to test caching and consistency - req2 := &preprocessing.RenderJinjaTemplateRequest{ - Conversations: convertToPreprocessingChatMessages(conversation), + req2 := &preprocessing.ApplyChatTemplateRequest{ + Conversation: convertToPreprocessingConversation(conversation), } - localRendered2, err := localTokenizer.RenderChatTemplate(tc.modelName, req2) + localRendered2, err := localTokenizer.ApplyChatTemplate(tc.modelName, req2) s.Require().NoError(err) s.Require().Equal(localRendered, localRendered2, "Rendering the same conversation twice should produce identical output (tests caching)") // Tokenize again - localTokens2, _, err := localTokenizer.Encode(localRendered2, tc.modelName) + localTokens2, _, err := localTokenizer.Encode(localRendered2, tc.modelName, true) s.Require().NoError(err) s.Require().Equal(localTokens, localTokens2, "Tokenizing the same prompt twice should produce identical tokens") @@ -789,7 +790,7 @@ func (s *KVCacheSuite) TestLocalVsHFChatTemplateConsistency() { // TestLocalTokenizerChatTemplateErrorHandling tests error cases for local chat templates. func (s *KVCacheSuite) TestLocalTokenizerChatTemplateErrorHandling() { modelName := "test-model" - testModelDir, err := filepath.Abs("testdata/test-model") + testModelDir, err := filepath.Abs(localTestModelDir) s.Require().NoError(err) localTokenizer, err := tokenization.NewCachedLocalTokenizer(modelName, tokenization.LocalTokenizerConfig{ @@ -806,19 +807,19 @@ func (s *KVCacheSuite) TestLocalTokenizerChatTemplateErrorHandling() { } // Test 1: Non-existent model - reqNonExistent := &preprocessing.RenderJinjaTemplateRequest{ - Conversations: convertToPreprocessingChatMessages(conversation), + reqNonExistent := &preprocessing.ApplyChatTemplateRequest{ + Conversation: convertToPreprocessingConversation(conversation), } - _, err = localTokenizer.RenderChatTemplate("non-existent-model", reqNonExistent) + _, err = localTokenizer.ApplyChatTemplate("non-existent-model", reqNonExistent) s.Require().Error(err, "Should return error for non-existent model") s.T().Logf("Expected error for non-existent model: %v", err) // Test 2: Empty conversation emptyConversation := []ChatMessage{} - reqEmpty := &preprocessing.RenderJinjaTemplateRequest{ - Conversations: convertToPreprocessingChatMessages(emptyConversation), + reqEmpty := &preprocessing.ApplyChatTemplateRequest{ + Conversation: convertToPreprocessingConversation(emptyConversation), } - rendered, err := localTokenizer.RenderChatTemplate("test-model", reqEmpty) + rendered, err := localTokenizer.ApplyChatTemplate("test-model", reqEmpty) // This might succeed with empty output or fail depending on template // Either is acceptable behavior if err == nil { @@ -839,12 +840,12 @@ func (s *KVCacheSuite) TestLocalTokenizerChatTemplateLongConversation() { }{ { name: "test-model", - modelDir: "testdata/test-model", + modelDir: localTestModelDir, modelName: "test-model", }, { name: "local-llama3", - modelDir: "testdata/local-llama3", + modelDir: localLlama3ModelDir, modelName: "local-llama3", }, } @@ -879,24 +880,24 @@ func (s *KVCacheSuite) TestLocalTokenizerChatTemplateLongConversation() { } // Render the long conversation - reqLong := &preprocessing.RenderJinjaTemplateRequest{ - Conversations: convertToPreprocessingChatMessages(longConversation), + reqLong := &preprocessing.ApplyChatTemplateRequest{ + Conversation: convertToPreprocessingConversation(longConversation), } - renderedPrompt, err := localTokenizer.RenderChatTemplate(tc.modelName, reqLong) + renderedPrompt, err := localTokenizer.ApplyChatTemplate(tc.modelName, reqLong) s.Require().NoError(err) s.Require().NotEmpty(renderedPrompt) s.Require().Greater(len(renderedPrompt), 1000, "Long conversation should produce substantial output") s.T().Logf("Long conversation rendered to %d characters", len(renderedPrompt)) // Tokenize - tokens, offsets, err := localTokenizer.Encode(renderedPrompt, tc.modelName) + tokens, offsets, err := localTokenizer.Encode(renderedPrompt, tc.modelName, true) s.Require().NoError(err) s.Require().NotEmpty(tokens) s.Require().Equal(len(tokens), len(offsets)) s.T().Logf("Long conversation produced %d tokens", len(tokens)) // Convert to block keys - engineKeys, requestKeys := s.promptToEngineAndRequestKeys(renderedPrompt, tc.modelName) + engineKeys, requestKeys := s.promptToEngineAndRequestKeys(renderedPrompt, tc.modelName, true) s.Require().NotEmpty(requestKeys) s.T().Logf("Generated %d block keys from long conversation", len(requestKeys)) From 74289d80ad9daf25dac58bc0ea4bfc932f5fcd06 Mon Sep 17 00:00:00 2001 From: HyunKyun Moon Date: Sat, 27 Dec 2025 13:02:33 +0000 Subject: [PATCH 02/12] lint Signed-off-by: HyunKyun Moon --- examples/kv_events/online/main.go | 2 +- pkg/kvcache/indexer.go | 3 +- pkg/tokenization/pool.go | 11 ++-- pkg/tokenization/pool_test.go | 9 +-- pkg/tokenization/tokenizer.go | 6 +- pkg/tokenization/tokenizer_test.go | 34 +++++++---- pkg/tokenization/uds_tokenizer.go | 2 +- tests/e2e/redis_mock/e2e_suite_test.go | 7 ++- tests/e2e/redis_mock/e2e_test.go | 83 ++++++++++++++------------ 9 files changed, 87 insertions(+), 70 deletions(-) diff --git a/examples/kv_events/online/main.go b/examples/kv_events/online/main.go index f141664aa..edf5cc469 100644 --- a/examples/kv_events/online/main.go +++ b/examples/kv_events/online/main.go @@ -327,7 +327,7 @@ func setupUnifiedHTTPEndpoints( } // Use KV-cache to score the rendered template - if len(renderedPrompt) == 0 { + if renderedPrompt == "" { http.Error(w, "rendered prompt is empty", http.StatusInternalServerError) return } diff --git a/pkg/kvcache/indexer.go b/pkg/kvcache/indexer.go index ebf5b5fe8..e266738a4 100644 --- a/pkg/kvcache/indexer.go +++ b/pkg/kvcache/indexer.go @@ -102,7 +102,8 @@ func NewKVCacheIndexer(ctx context.Context, config *Config) (*Indexer, error) { return nil, fmt.Errorf("failed to create KVBlockScorer: %w", err) } - tokenizersPool, err := tokenization.NewTokenizationPool(config.TokenizersPoolConfig, tokensIndexer) + tokenizersPool, err := tokenization.NewTokenizationPool(ctx, + config.TokenizersPoolConfig, tokensIndexer) if err != nil { return nil, fmt.Errorf("failed to create tokenizers pool: %w", err) } diff --git a/pkg/tokenization/pool.go b/pkg/tokenization/pool.go index f28750021..e6ebe2601 100644 --- a/pkg/tokenization/pool.go +++ b/pkg/tokenization/pool.go @@ -94,7 +94,7 @@ type Pool struct { // NewTokenizationPool initializes a TokenizationPool with the specified number // of workers and the provided Indexer. -func NewTokenizationPool(config *Config, store prefixstore.Indexer) (*Pool, error) { +func NewTokenizationPool(ctx context.Context, config *Config, store prefixstore.Indexer) (*Pool, error) { if config == nil || config.ModelName == "" { return nil, fmt.Errorf("config and config.ModelName cannot be nil or empty") } @@ -108,7 +108,8 @@ func NewTokenizationPool(config *Config, store prefixstore.Indexer) (*Pool, erro tokenizers := make([]Tokenizer, 0, 3) if config.LocalTokenizerConfig.IsEnabled() { - localTokenizer, err := NewCachedLocalTokenizer(config.ModelName, *config.LocalTokenizerConfig) + localTokenizer, err := NewCachedLocalTokenizer(ctx, + config.ModelName, *config.LocalTokenizerConfig) if err != nil { return nil, fmt.Errorf("failed to create local tokenizer: %w", err) } @@ -116,7 +117,8 @@ func NewTokenizationPool(config *Config, store prefixstore.Indexer) (*Pool, erro } if config.UdsTokenizerConfig.IsEnabled() { - udsTokenizer, err := NewUdsTokenizer(config.UdsTokenizerConfig) + udsTokenizer, err := NewUdsTokenizer(ctx, + config.UdsTokenizerConfig) if err != nil { return nil, fmt.Errorf("failed to create UDS tokenizer: %w", err) } @@ -124,7 +126,8 @@ func NewTokenizationPool(config *Config, store prefixstore.Indexer) (*Pool, erro } if config.HFTokenizerConfig.IsEnabled() { - hfTokenizer, err := NewCachedHFTokenizer(config.ModelName, config.HFTokenizerConfig) + hfTokenizer, err := NewCachedHFTokenizer(ctx, + config.ModelName, config.HFTokenizerConfig) if err != nil { return nil, fmt.Errorf("failed to create HuggingFace tokenizer: %w", err) } diff --git a/pkg/tokenization/pool_test.go b/pkg/tokenization/pool_test.go index 2cfc15ad4..201bb462a 100644 --- a/pkg/tokenization/pool_test.go +++ b/pkg/tokenization/pool_test.go @@ -283,13 +283,13 @@ func TestPool_RunIntegration(t *testing.T) { MinPrefixOverlapRatio: defaultMinPrefixOverlapRatio, } - pool, err := NewTokenizationPool(config, mockIndexer) - require.NoError(t, err) - // Create context for the pool ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() + pool, err := NewTokenizationPool(ctx, config, mockIndexer) + require.NoError(t, err) + for _, prompt := range prompts { pool.EnqueueTokenization(prompt) } @@ -336,7 +336,8 @@ func setupStressTest(b *testing.B, modelName string) *Pool { inMemoryIndexer, err := prefixstore.NewLRUTokenStore(nil) require.NoError(b, err) - pool, err := NewTokenizationPool(config, inMemoryIndexer) + pool, err := NewTokenizationPool(context.Background(), + config, inMemoryIndexer) require.NoError(b, err) return pool } diff --git a/pkg/tokenization/tokenizer.go b/pkg/tokenization/tokenizer.go index a5fdae49a..1389aec51 100644 --- a/pkg/tokenization/tokenizer.go +++ b/pkg/tokenization/tokenizer.go @@ -280,7 +280,7 @@ type LocalCachedTokenizer struct { // NewCachedHFTokenizer creates a new instance of CachedTokenizer downloading tokenizer configs from HuggingFace with // the provided configuration. -func NewCachedHFTokenizer(modelID string, config *HFTokenizerConfig) (*HFCachedTokenizer, error) { +func NewCachedHFTokenizer(ctx context.Context, modelID string, config *HFTokenizerConfig) (*HFCachedTokenizer, error) { tokenizerProvider := newHFTokenizerProvider(config) tokenizer, err := tokenizerProvider.get(modelID) if err != nil { @@ -293,7 +293,6 @@ func NewCachedHFTokenizer(modelID string, config *HFTokenizerConfig) (*HFCachedT return nil, fmt.Errorf("failed to initialize chat templater: %w", err) } - ctx := context.TODO() if err := chatTemplateRenderer.LoadTokenizerWithCache(ctx, &preprocessing.LoadTokenizerWithCacheRequest{ IsLocal: false, Model: modelID, @@ -322,7 +321,7 @@ func NewCachedHFTokenizer(modelID string, config *HFTokenizerConfig) (*HFCachedT // - Reducing startup latency by avoiding downloads // // The tokenizer is initialized for a specific model at creation time. -func NewCachedLocalTokenizer(modelName string, config LocalTokenizerConfig) (*LocalCachedTokenizer, error) { +func NewCachedLocalTokenizer(ctx context.Context, modelName string, config LocalTokenizerConfig) (*LocalCachedTokenizer, error) { if err := discoverLocalTokenizerMap(&config); err != nil { return nil, fmt.Errorf("failed to discover local tokenizer map: %w", err) } @@ -346,7 +345,6 @@ func NewCachedLocalTokenizer(modelName string, config LocalTokenizerConfig) (*Lo return nil, fmt.Errorf("tokenizer for model %q not found", modelName) } - ctx := context.TODO() if err := chatTemplater.LoadTokenizerWithCache(ctx, &preprocessing.LoadTokenizerWithCacheRequest{ IsLocal: true, Model: path, diff --git a/pkg/tokenization/tokenizer_test.go b/pkg/tokenization/tokenizer_test.go index 48a7e247a..ece170b2c 100644 --- a/pkg/tokenization/tokenizer_test.go +++ b/pkg/tokenization/tokenizer_test.go @@ -18,6 +18,7 @@ limitations under the License. package tokenization import ( + "context" "fmt" "os" "path/filepath" @@ -61,7 +62,8 @@ func TestCachedHFTokenizer_Encode(t *testing.T) { config := &HFTokenizerConfig{ TokenizersCacheDir: t.TempDir(), } - tokenizer, err := NewCachedHFTokenizer(testModelName, config) + tokenizer, err := NewCachedHFTokenizer(context.Background(), + testModelName, config) require.NoError(t, err) require.NotNil(t, tokenizer) @@ -95,9 +97,10 @@ func TestCachedHFTokenizer_CacheTokenizer(t *testing.T) { t.Skip("Skipping tokenizer integration test in short mode") } - tokenizer, err := NewCachedHFTokenizer(testModelName, &HFTokenizerConfig{ - TokenizersCacheDir: t.TempDir(), - }) + tokenizer, err := NewCachedHFTokenizer(context.Background(), + testModelName, &HFTokenizerConfig{ + TokenizersCacheDir: t.TempDir(), + }) require.NoError(t, err) require.NotNil(t, tokenizer) @@ -122,9 +125,10 @@ func TestCachedHFTokenizer_InvalidModel(t *testing.T) { t.Skip("Skipping tokenizer integration test in short mode") } - tokenizer, err := NewCachedHFTokenizer("non-existent/model", &HFTokenizerConfig{ - TokenizersCacheDir: t.TempDir(), - }) + tokenizer, err := NewCachedHFTokenizer(context.Background(), + "non-existent/model", &HFTokenizerConfig{ + TokenizersCacheDir: t.TempDir(), + }) // Assert that an error occurred and tokenizer is nil assert.Error(t, err) @@ -138,7 +142,8 @@ func TestCachedLocalTokenizer_Encode(t *testing.T) { modelName: "testdata/test-model/tokenizer.json", }, } - tokenizer, err := NewCachedLocalTokenizer(modelName, config) + tokenizer, err := NewCachedLocalTokenizer(context.Background(), + modelName, config) require.NoError(t, err) require.NotNil(t, tokenizer) @@ -178,7 +183,8 @@ func TestCachedLocalTokenizer_InvalidModel(t *testing.T) { modelName: "testdata/test-model/tokenizer.json", }, } - tokenizer, err := NewCachedLocalTokenizer(invalidModelName, config) + tokenizer, err := NewCachedLocalTokenizer(context.Background(), + invalidModelName, config) require.Error(t, err) require.Nil(t, tokenizer) } @@ -190,7 +196,8 @@ func TestCachedLocalTokenizer_InvalidPath(t *testing.T) { modelName: "testdata/non-existent/tokenizer.json", }, } - tokenizer, err := NewCachedLocalTokenizer(modelName, config) + tokenizer, err := NewCachedLocalTokenizer(context.Background(), + modelName, config) require.Error(t, err) require.Nil(t, tokenizer) } @@ -201,9 +208,10 @@ func TestCompositeTokenizer_FallbackBehavior(t *testing.T) { } dummyTokenizer := &DummyTokenizer{returnError: true} - hfTokenizer, err := NewCachedHFTokenizer(testModelName, &HFTokenizerConfig{ - TokenizersCacheDir: t.TempDir(), - }) + hfTokenizer, err := NewCachedHFTokenizer(context.Background(), + testModelName, &HFTokenizerConfig{ + TokenizersCacheDir: t.TempDir(), + }) require.NoError(t, err) diff --git a/pkg/tokenization/uds_tokenizer.go b/pkg/tokenization/uds_tokenizer.go index 5d29eb5b1..83ddef82d 100644 --- a/pkg/tokenization/uds_tokenizer.go +++ b/pkg/tokenization/uds_tokenizer.go @@ -69,7 +69,7 @@ const ( ) // NewUdsTokenizer creates a new UDS-based tokenizer client with connection pooling. -func NewUdsTokenizer(config *UdsTokenizerConfig) (Tokenizer, error) { +func NewUdsTokenizer(_ context.Context, config *UdsTokenizerConfig) (Tokenizer, error) { dialer := &net.Dialer{ Timeout: 30 * time.Second, KeepAlive: 30 * time.Second, diff --git a/tests/e2e/redis_mock/e2e_suite_test.go b/tests/e2e/redis_mock/e2e_suite_test.go index 3f54fa21d..0632731a8 100644 --- a/tests/e2e/redis_mock/e2e_suite_test.go +++ b/tests/e2e/redis_mock/e2e_suite_test.go @@ -72,7 +72,8 @@ func (s *KVCacheSuite) SetupTest() { s.config.PrefixStoreConfig.BlockSize = 4 s.config.TokenProcessorConfig.BlockSize = 4 - hfTokenizer, err := tokenization.NewCachedHFTokenizer(defaultModelName, s.config.TokenizersPoolConfig.HFTokenizerConfig) + hfTokenizer, err := tokenization.NewCachedHFTokenizer(context.Background(), + defaultModelName, s.config.TokenizersPoolConfig.HFTokenizerConfig) s.Require().NoError(err) // Use composite tokenizer: try local first, then fall back to HF @@ -94,9 +95,9 @@ func (s *KVCacheSuite) SetupTest() { // //nolint:nonamedreturns // named returns keep gocritic unnamedResult satisfied while allowing compact return func (s *KVCacheSuite) promptToEngineAndRequestKeys( - prompt, model string, addSpecialToken bool, + prompt, model string, ) (engineKeys, requestKeys []kvblock.Key) { - tokens, _, err := s.tokenizer.Encode(prompt, model, addSpecialToken) + tokens, _, err := s.tokenizer.Encode(prompt, model, true) s.Require().NoError(err) requestKeys = s.tokensProcessor.TokensToKVBlockKeys(nil, tokens, model) diff --git a/tests/e2e/redis_mock/e2e_test.go b/tests/e2e/redis_mock/e2e_test.go index 94316fdd3..8b9b35eaf 100644 --- a/tests/e2e/redis_mock/e2e_test.go +++ b/tests/e2e/redis_mock/e2e_test.go @@ -18,6 +18,7 @@ limitations under the License. package e2e import ( + "context" "os" "path/filepath" "strings" @@ -116,7 +117,7 @@ func (s *KVCacheSuite) TestCacheHit() { prompt := "lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua." fakePodList := []string{s.Pod1IP} - engineKeys, requestKeys := s.promptToEngineAndRequestKeys(prompt, defaultModelName, true) + engineKeys, requestKeys := s.promptToEngineAndRequestKeys(prompt, defaultModelName) s.addEntriesToIndex(engineKeys, requestKeys, fakePodList) pods, err := s.indexer.GetPodScores(s.ctx, nil, prompt, defaultModelName, fakePodList) @@ -144,7 +145,7 @@ func (s *KVCacheSuite) TestPrefixReduction() { midPrompt := "lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua." shortPrompt := "lorem ipsum dolor sit amet, consectetur adipiscing elit." - fullPromptEngineKeys, fullPromptRequestKeys := s.promptToEngineAndRequestKeys(fullPrompt, defaultModelName, true) + fullPromptEngineKeys, fullPromptRequestKeys := s.promptToEngineAndRequestKeys(fullPrompt, defaultModelName) fakePodList := []string{s.Pod1IP} // Test 1: Full prompt (no match expected) @@ -168,7 +169,7 @@ func (s *KVCacheSuite) TestPrefixReduction() { s.Len(pods, len(fakePodList), "expected pod scores length to match candidate pods") s.T().Logf("Received pod scores: %+v", pods) - _, shortPromptRequestKeys := s.promptToEngineAndRequestKeys(shortPrompt, defaultModelName, true) + _, shortPromptRequestKeys := s.promptToEngineAndRequestKeys(shortPrompt, defaultModelName) s.Equal(int(pods[s.Pod1IP]), len(shortPromptRequestKeys), "all short-prompt block keys should have been indexed") } @@ -188,7 +189,7 @@ func (s *KVCacheSuite) TestPrefixExpansion() { s.T().Logf("Received pod scores: %+v", pods) s.Empty(pods, "expected no pod scores") - shortPromptEngineKeys, shortPromptRequestKeys := s.promptToEngineAndRequestKeys(shortPrompt, modelName, true) + shortPromptEngineKeys, shortPromptRequestKeys := s.promptToEngineAndRequestKeys(shortPrompt, modelName) s.addEntriesToIndex(shortPromptEngineKeys, shortPromptRequestKeys, fakePodList) // Test 2: mid prompt @@ -198,7 +199,7 @@ func (s *KVCacheSuite) TestPrefixExpansion() { s.T().Logf("Received pod scores: %+v", pods) s.Equal(int(pods[s.Pod1IP]), len(shortPromptRequestKeys), "expected pod score to equal number of short prompt block keys") - midPromptEngineKeys, midPromptRequestKeys := s.promptToEngineAndRequestKeys(midPrompt, modelName, true) + midPromptEngineKeys, midPromptRequestKeys := s.promptToEngineAndRequestKeys(midPrompt, modelName) s.addEntriesToIndex(midPromptEngineKeys, midPromptRequestKeys, fakePodList) // Test 3: full prompt @@ -228,7 +229,7 @@ func (s *KVCacheSuite) TestLongPrefixExpansion() { s.Empty(pods, "expected no pod scores") // Add entries to the index for the short prompt - shortPromptEngineKeys, shortPromptRequestKeys := s.promptToEngineAndRequestKeys(shortPrompt, modelName, true) + shortPromptEngineKeys, shortPromptRequestKeys := s.promptToEngineAndRequestKeys(shortPrompt, modelName) s.addEntriesToIndex(shortPromptEngineKeys, shortPromptRequestKeys, fakePodList) // Test 2: mid prompt (should return partial match if indexer picks it up) @@ -238,7 +239,7 @@ func (s *KVCacheSuite) TestLongPrefixExpansion() { s.True(len(pods) > 0, "expected at least one pod score for mid prompt") // Add entries to the index for the mid prompt - midPromptEngineKeys, midPromptRequestKeys := s.promptToEngineAndRequestKeys(midPrompt, modelName, true) + midPromptEngineKeys, midPromptRequestKeys := s.promptToEngineAndRequestKeys(midPrompt, modelName) s.addEntriesToIndex(midPromptEngineKeys, midPromptRequestKeys, fakePodList) // Test 3: long prompt (should return higher score) @@ -286,7 +287,7 @@ func (s *KVCacheSuite) TestChatCompletionsE2E() { s.Require().NotEmpty(flattenedPrompt, "Flattened prompt should not be empty") // Step 4: Use the flattened prompt for KV-cache lookup (similar to TestBasicE2E). - engineKeys, requestKeys := s.promptToEngineAndRequestKeys(flattenedPrompt, "ibm-granite/granite-3.3-8b-instruct", true) + engineKeys, requestKeys := s.promptToEngineAndRequestKeys(flattenedPrompt, "ibm-granite/granite-3.3-8b-instruct") fakePodList := []string{s.Pod1IP} // First lookup - should return no scores initially. @@ -360,7 +361,7 @@ func (s *KVCacheSuite) TestLongChatCompletionsE2E() { s.Require().Greater(len(flattenedPrompt), 1000, "Long conversation should produce substantial output") // Step 4: Test KV-cache with the long flattened prompt. - engineKeys, requestKeys := s.promptToEngineAndRequestKeys(flattenedPrompt, "ibm-granite/granite-3.3-8b-instruct", true) + engineKeys, requestKeys := s.promptToEngineAndRequestKeys(flattenedPrompt, "ibm-granite/granite-3.3-8b-instruct") fakePodList := []string{s.Pod1IP} // First lookup. @@ -386,7 +387,7 @@ func (s *KVCacheSuite) TestLongChatCompletionsE2E() { func (s *KVCacheSuite) TestCacheHitWithLocalTokenizer() { // Create a local tokenizer using the testdata modelName := "test-model" - localTokenizer, err := tokenization.NewCachedLocalTokenizer(modelName, tokenization.LocalTokenizerConfig{ + localTokenizer, err := tokenization.NewCachedLocalTokenizer(context.Background(), modelName, tokenization.LocalTokenizerConfig{ ModelTokenizerMap: map[string]string{ modelName: "testdata/test-model/tokenizer.json", }, @@ -407,7 +408,7 @@ func (s *KVCacheSuite) TestCacheHitWithLocalTokenizer() { s.T().Logf("Local tokenizer produced %d tokens for prompt", len(tokens)) // Convert tokens to KV block keys - engineKeys, requestKeys := s.promptToEngineAndRequestKeys(prompt, modelName, true) + engineKeys, requestKeys := s.promptToEngineAndRequestKeys(prompt, modelName) // Add entries to the index - this verifies the local tokenizer produces valid block keys s.addEntriesToIndex(engineKeys, requestKeys, fakePodList) @@ -448,7 +449,7 @@ func (s *KVCacheSuite) TestHFCacheStructureDiscoveryE2E() { AutoDiscoveryTokenizerFileName: "tokenizer.json", } - localTokenizer, err := tokenization.NewCachedLocalTokenizer(modelName, config) + localTokenizer, err := tokenization.NewCachedLocalTokenizer(context.Background(), modelName, config) s.Require().NoError(err) s.Require().NotNil(localTokenizer) @@ -465,7 +466,7 @@ func (s *KVCacheSuite) TestHFCacheStructureDiscoveryE2E() { s.T().Logf("HF cache auto-discovery produced %d tokens for model %q", len(tokens), modelName) // Convert tokens to KV block keys using promptToEngineAndRequestKeys with local tokenizer - engineKeys1, requestKeys := s.promptToEngineAndRequestKeys(prompt, modelName, true) + engineKeys1, requestKeys := s.promptToEngineAndRequestKeys(prompt, modelName) // Add entries to the index s.addEntriesToIndex(engineKeys1, requestKeys, fakePodList) @@ -505,11 +506,12 @@ func (s *KVCacheSuite) TestLocalTokenizerChatTemplateE2E() { testModelDir, err := filepath.Abs(tc.modelDir) s.Require().NoError(err) - localTokenizer, err := tokenization.NewCachedLocalTokenizer(tc.modelName, tokenization.LocalTokenizerConfig{ - ModelTokenizerMap: map[string]string{ - tc.modelName: filepath.Join(testModelDir, "tokenizer.json"), - }, - }) + localTokenizer, err := tokenization.NewCachedLocalTokenizer( + context.Background(), tc.modelName, tokenization.LocalTokenizerConfig{ + ModelTokenizerMap: map[string]string{ + tc.modelName: filepath.Join(testModelDir, "tokenizer.json"), + }, + }) s.Require().NoError(err) s.Require().NotNil(localTokenizer) @@ -545,7 +547,7 @@ func (s *KVCacheSuite) TestLocalTokenizerChatTemplateE2E() { s.T().Logf("Local tokenizer produced %d tokens from rendered chat template", len(tokens)) // Step 3: Convert tokens to KV block keys - engineKeys, requestKeys := s.promptToEngineAndRequestKeys(renderedPrompt, tc.modelName, true) + engineKeys, requestKeys := s.promptToEngineAndRequestKeys(renderedPrompt, tc.modelName) s.T().Logf("Generated %d KV block keys from rendered conversation", len(requestKeys)) // Step 4: Add to index and verify retrieval (full KV-cache flow) @@ -600,11 +602,12 @@ func (s *KVCacheSuite) TestLocalTokenizerChatTemplateMultiTurnE2E() { testModelDir, err := filepath.Abs(tc.modelDir) s.Require().NoError(err) - localTokenizer, err := tokenization.NewCachedLocalTokenizer(tc.modelName, tokenization.LocalTokenizerConfig{ - ModelTokenizerMap: map[string]string{ - tc.modelName: filepath.Join(testModelDir, "tokenizer.json"), - }, - }) + localTokenizer, err := tokenization.NewCachedLocalTokenizer( + context.Background(), tc.modelName, tokenization.LocalTokenizerConfig{ + ModelTokenizerMap: map[string]string{ + tc.modelName: filepath.Join(testModelDir, "tokenizer.json"), + }, + }) s.Require().NoError(err) s.SetTokenizer(localTokenizer, tc.modelName) @@ -627,7 +630,7 @@ func (s *KVCacheSuite) TestLocalTokenizerChatTemplateMultiTurnE2E() { s.T().Logf("Short prompt length: %d chars", len(shortPrompt)) shortTokens, _, err := localTokenizer.Encode(shortPrompt, tc.modelName, true) s.Require().NoError(err) - shortEngineKeys, shortRequestKeys := s.promptToEngineAndRequestKeys(shortPrompt, tc.modelName, true) + shortEngineKeys, shortRequestKeys := s.promptToEngineAndRequestKeys(shortPrompt, tc.modelName) s.addEntriesToIndex(shortEngineKeys, shortRequestKeys, fakePodList) s.T().Logf("Short conversation: %d tokens, %d block keys", len(shortTokens), len(shortRequestKeys)) @@ -657,7 +660,7 @@ func (s *KVCacheSuite) TestLocalTokenizerChatTemplateMultiTurnE2E() { extendedTokens, _, err := localTokenizer.Encode(extendedPrompt, tc.modelName, true) s.Require().NoError(err) - extendedEngineKeys, extendedRequestKeys := s.promptToEngineAndRequestKeys(extendedPrompt, tc.modelName, true) + extendedEngineKeys, extendedRequestKeys := s.promptToEngineAndRequestKeys(extendedPrompt, tc.modelName) s.T().Logf("Extended conversation: %d tokens, %d block keys", len(extendedTokens), len(extendedRequestKeys)) // Some tokenizers use fixed-length encoding with padding (e.g., 512 tokens) @@ -729,11 +732,12 @@ func (s *KVCacheSuite) TestLocalVsHFChatTemplateConsistency() { s.Require().FileExists(filepath.Join(testModelDir, "config.json"), "config.json should exist") s.Require().FileExists(filepath.Join(testModelDir, "tokenizer.json"), "tokenizer.json should exist") - localTokenizer, err := tokenization.NewCachedLocalTokenizer(tc.modelName, tokenization.LocalTokenizerConfig{ - ModelTokenizerMap: map[string]string{ - tc.modelName: filepath.Join(testModelDir, "tokenizer.json"), - }, - }) + localTokenizer, err := tokenization.NewCachedLocalTokenizer( + context.Background(), tc.modelName, tokenization.LocalTokenizerConfig{ + ModelTokenizerMap: map[string]string{ + tc.modelName: filepath.Join(testModelDir, "tokenizer.json"), + }, + }) s.Require().NoError(err) s.SetTokenizer(localTokenizer, tc.modelName) @@ -757,7 +761,7 @@ func (s *KVCacheSuite) TestLocalVsHFChatTemplateConsistency() { s.T().Logf("Local tokenizer: rendered=%d chars, tokens=%d", len(localRendered), len(localTokens)) // Add to index and verify with GetPodScores - engineKeys, requestKeys := s.promptToEngineAndRequestKeys(localRendered, tc.modelName, true) + engineKeys, requestKeys := s.promptToEngineAndRequestKeys(localRendered, tc.modelName) fakePodList := []string{s.Pod1IP} s.addEntriesToIndex(engineKeys, requestKeys, fakePodList) @@ -793,7 +797,7 @@ func (s *KVCacheSuite) TestLocalTokenizerChatTemplateErrorHandling() { testModelDir, err := filepath.Abs(localTestModelDir) s.Require().NoError(err) - localTokenizer, err := tokenization.NewCachedLocalTokenizer(modelName, tokenization.LocalTokenizerConfig{ + localTokenizer, err := tokenization.NewCachedLocalTokenizer(context.Background(), modelName, tokenization.LocalTokenizerConfig{ ModelTokenizerMap: map[string]string{ modelName: filepath.Join(testModelDir, "tokenizer.json"), }, @@ -855,11 +859,12 @@ func (s *KVCacheSuite) TestLocalTokenizerChatTemplateLongConversation() { testModelDir, err := filepath.Abs(tc.modelDir) s.Require().NoError(err) - localTokenizer, err := tokenization.NewCachedLocalTokenizer(tc.modelName, tokenization.LocalTokenizerConfig{ - ModelTokenizerMap: map[string]string{ - tc.modelName: filepath.Join(testModelDir, "tokenizer.json"), - }, - }) + localTokenizer, err := tokenization.NewCachedLocalTokenizer( + context.Background(), tc.modelName, tokenization.LocalTokenizerConfig{ + ModelTokenizerMap: map[string]string{ + tc.modelName: filepath.Join(testModelDir, "tokenizer.json"), + }, + }) s.Require().NoError(err) s.SetTokenizer(localTokenizer, tc.modelName) @@ -897,7 +902,7 @@ func (s *KVCacheSuite) TestLocalTokenizerChatTemplateLongConversation() { s.T().Logf("Long conversation produced %d tokens", len(tokens)) // Convert to block keys - engineKeys, requestKeys := s.promptToEngineAndRequestKeys(renderedPrompt, tc.modelName, true) + engineKeys, requestKeys := s.promptToEngineAndRequestKeys(renderedPrompt, tc.modelName) s.Require().NotEmpty(requestKeys) s.T().Logf("Generated %d block keys from long conversation", len(requestKeys)) From d494e14c7d06727d4ce10ffa178ba123e5c8d5e0 Mon Sep 17 00:00:00 2001 From: HyunKyun Moon Date: Sat, 27 Dec 2025 14:15:19 +0000 Subject: [PATCH 03/12] apply copilot review Signed-off-by: HyunKyun Moon --- .../chat_completions/cgo_functions.c | 15 +-- .../chat_completions/cgo_functions.go | 4 +- .../chat_completions/cgo_functions_test.go | 94 +++++++++++-------- .../chat_completions/tokenizer_wrapper.py | 12 +-- pkg/tokenization/tokenizer.go | 1 + tests/e2e/redis_mock/e2e_test.go | 4 +- 6 files changed, 74 insertions(+), 56 deletions(-) diff --git a/pkg/preprocessing/chat_completions/cgo_functions.c b/pkg/preprocessing/chat_completions/cgo_functions.c index b0b27a82d..38bab5acd 100644 --- a/pkg/preprocessing/chat_completions/cgo_functions.c +++ b/pkg/preprocessing/chat_completions/cgo_functions.c @@ -243,14 +243,14 @@ bool Py_CallLoadTokenizerWithCacheInternal(const char* json_request) { if (!g_python_initialized) { printf("[C] Py_CallLoadTokenizerWithCacheInternal ERROR - Python not initialized\n"); fflush(stdout); - return NULL; + return false; } // Validate cached function if (!g_load_tokenizer_with_cache_func) { printf("[C] Py_CallLoadTokenizerWithCacheInternal ERROR - Cached function is NULL\n"); fflush(stdout); - return NULL; + return false; } // Validate that the cached function is still a valid Python object @@ -258,14 +258,14 @@ bool Py_CallLoadTokenizerWithCacheInternal(const char* json_request) { if (!PyCallable_Check(g_load_tokenizer_with_cache_func)) { printf("[C] Py_CallLoadTokenizerWithCacheInternal ERROR - Cached function is not callable (corrupted?)\n"); fflush(stdout); - return NULL; + return false; } // Validate input if (!json_request) { printf("[C] Py_CallLoadTokenizerWithCacheInternal ERROR - Input is NULL\n"); fflush(stdout); - return NULL; + return false; } // Acquire GIL for Python operations @@ -277,7 +277,7 @@ bool Py_CallLoadTokenizerWithCacheInternal(const char* json_request) { printf("[C] Py_CallLoadTokenizerWithCacheInternal ERROR - Failed to create Python string\n"); fflush(stdout); PyGILState_Release(gil_state); - return NULL; + return false; } // Create arguments tuple @@ -287,7 +287,7 @@ bool Py_CallLoadTokenizerWithCacheInternal(const char* json_request) { fflush(stdout); Py_DECREF(py_json); PyGILState_Release(gil_state); - return NULL; + return false; } // Call the cached function @@ -305,6 +305,9 @@ bool Py_CallLoadTokenizerWithCacheInternal(const char* json_request) { fflush(stderr); cresult = false; } + else { + Py_DECREF(py_result); + } // Release GIL PyGILState_Release(gil_state); diff --git a/pkg/preprocessing/chat_completions/cgo_functions.go b/pkg/preprocessing/chat_completions/cgo_functions.go index 57b7c09aa..d29ec81d9 100644 --- a/pkg/preprocessing/chat_completions/cgo_functions.go +++ b/pkg/preprocessing/chat_completions/cgo_functions.go @@ -51,7 +51,7 @@ type Conversation struct { type ApplyChatTemplateRequest struct { // The Python wrapper will handle converting this to a batched list if needed. LoadTokenizerWithCacheRequest LoadTokenizerWithCacheRequest `json:"load_tokenizer_with_cache_request,omitempty"` - Conversation []Conversation `json:"conversation"` + Conversation [][]Conversation `json:"conversation"` Tools []interface{} `json:"tools,omitempty"` Documents []interface{} `json:"documents,omitempty"` ChatTemplate string `json:"chat_template,omitempty"` @@ -110,7 +110,7 @@ func (w *ChatTemplatingProcessor) Finalize() { C.Py_FinalizeGo() } -// Load Tokenzier. +// LoadTokenizerWithCache loads a tokenizer with caching using the cached Python function. func (w *ChatTemplatingProcessor) LoadTokenizerWithCache( ctx context.Context, req *LoadTokenizerWithCacheRequest, diff --git a/pkg/preprocessing/chat_completions/cgo_functions_test.go b/pkg/preprocessing/chat_completions/cgo_functions_test.go index d626c0a25..673e67c39 100644 --- a/pkg/preprocessing/chat_completions/cgo_functions_test.go +++ b/pkg/preprocessing/chat_completions/cgo_functions_test.go @@ -135,31 +135,37 @@ func TestApplyChatTemplate(t *testing.T) { tests := []struct { name string template string - messages []preprocessing.Conversation + messages [][]preprocessing.Conversation }{ { name: "Simple ChatTemplate", template: simpleTemplate, - messages: []preprocessing.Conversation{ - {Role: "user", Content: "Hello"}, - {Role: "assistant", Content: "Hi there!"}, + messages: [][]preprocessing.Conversation{ + { + {Role: "user", Content: "Hello"}, + {Role: "assistant", Content: "Hi there!"}, + }, }, }, { name: "Complex ChatTemplate with System Message", template: complexTemplate, - messages: []preprocessing.Conversation{ - {Role: "system", Content: "You are a helpful AI assistant."}, - {Role: "user", Content: "What is the weather like?"}, - {Role: "assistant", Content: "I don't have access to real-time weather data."}, + messages: [][]preprocessing.Conversation{ + { + {Role: "system", Content: "You are a helpful AI assistant."}, + {Role: "user", Content: "What is the weather like?"}, + {Role: "assistant", Content: "I don't have access to real-time weather data."}, + }, }, }, { name: "Complex ChatTemplate without System Message", template: complexTemplate, - messages: []preprocessing.Conversation{ - {Role: "user", Content: "Tell me a joke"}, - {Role: "assistant", Content: "Why don't scientists trust atoms? Because they make up everything!"}, + messages: [][]preprocessing.Conversation{ + { + {Role: "user", Content: "Tell me a joke"}, + {Role: "assistant", Content: "Why don't scientists trust atoms? Because they make up everything!"}, + }, }, }, } @@ -188,7 +194,7 @@ func TestApplyChatTemplate(t *testing.T) { t.Logf("ChatTemplate: %s, Duration: %v, Rendered length: %d", tt.name, duration, len(rendered)) // Verify rendered content - for _, message := range tt.messages { + for _, message := range tt.messages[0] { // For complex templates, the role might not be explicitly shown in output // but the content should always be present assert.Contains(t, rendered, message.Content, "Rendered content should contain message content") @@ -249,47 +255,55 @@ func TestChatCompletionsIntegration(t *testing.T) { tests := []struct { name string modelName string - conversation []preprocessing.Conversation + conversation [][]preprocessing.Conversation description string }{ { name: "Simple Conversation", modelName: "ibm-granite/granite-3.3-8b-instruct", - conversation: []preprocessing.Conversation{ - {Role: "user", Content: "What is the capital of France?"}, - {Role: "assistant", Content: "The capital of France is Paris."}, + conversation: [][]preprocessing.Conversation{ + { + {Role: "user", Content: "What is the capital of France?"}, + {Role: "assistant", Content: "The capital of France is Paris."}, + }, }, description: "Basic question and answer conversation", }, { name: "Multi-turn Conversation", modelName: "microsoft/DialoGPT-medium", - conversation: []preprocessing.Conversation{ - {Role: "user", Content: "Hello, how are you?"}, - {Role: "assistant", Content: "I'm doing well, thank you! How can I help you today?"}, - {Role: "user", Content: "Can you tell me about machine learning?"}, - {Role: "assistant", Content: "Machine learning is a subset of artificial intelligence " + - "that enables computers to learn and make decisions from data without being explicitly programmed."}, + conversation: [][]preprocessing.Conversation{ + { + {Role: "user", Content: "Hello, how are you?"}, + {Role: "assistant", Content: "I'm doing well, thank you! How can I help you today?"}, + {Role: "user", Content: "Can you tell me about machine learning?"}, + {Role: "assistant", Content: "Machine learning is a subset of artificial intelligence " + + "that enables computers to learn and make decisions from data without being explicitly programmed."}, + }, }, description: "Multi-turn conversation with follow-up questions", }, { name: "System Message Conversation", modelName: "ibm-granite/granite-3.3-8b-instruct", - conversation: []preprocessing.Conversation{ - {Role: "system", Content: "You are a helpful AI assistant specialized in coding."}, - {Role: "user", Content: "Write a Python function to calculate fibonacci numbers."}, - {Role: "assistant", Content: "Here's a Python function to calculate fibonacci numbers:\n" + - "def fibonacci(n):\n if n <= 1:\n return n\n return fibonacci(n-1) + fibonacci(n-2)"}, + conversation: [][]preprocessing.Conversation{ + { + {Role: "system", Content: "You are a helpful AI assistant specialized in coding."}, + {Role: "user", Content: "Write a Python function to calculate fibonacci numbers."}, + {Role: "assistant", Content: "Here's a Python function to calculate fibonacci numbers:\n" + + "def fibonacci(n):\n if n <= 1:\n return n\n return fibonacci(n-1) + fibonacci(n-2)"}, + }, }, description: "Conversation with system message and code generation", }, { name: "Simple Conversation (Repeated)", modelName: "ibm-granite/granite-3.3-8b-instruct", - conversation: []preprocessing.Conversation{ - {Role: "user", Content: "What is the capital of France?"}, - {Role: "assistant", Content: "The capital of France is Paris."}, + conversation: [][]preprocessing.Conversation{ + { + {Role: "user", Content: "What is the capital of France?"}, + {Role: "assistant", Content: "The capital of France is Paris."}, + }, }, description: "Basic question and answer conversation (repeated to test render caching)", }, @@ -315,7 +329,7 @@ func TestChatCompletionsIntegration(t *testing.T) { assert.NotEmpty(t, rendered, "Rendered chat should not be empty") // Verify all conversation messages are present in the rendered output - for _, message := range tt.conversation { + for _, message := range tt.conversation[0] { assert.Contains(t, rendered, message.Content, "Rendered content should contain message content") } @@ -360,7 +374,7 @@ func TestLongChatCompletions(t *testing.T) { require.NoError(t, err, "Failed to clear caches") // Create a long conversation - longConversation := []preprocessing.Conversation{ + longConversation := [][]preprocessing.Conversation{{ {Role: "system", Content: "You are an expert software engineer with deep knowledge of Go, Python, " + "and system design. " + "Provide detailed, accurate responses."}, @@ -384,7 +398,7 @@ func TestLongChatCompletions(t *testing.T) { "involves logging all mutations to disk before applying them to memory. " + "For recovery, you can replay the log to reconstruct the cache state. You might also want to " + "implement periodic snapshots for faster recovery."}, - } + }} modelName := "ibm-granite/granite-3.3-8b-instruct" @@ -410,7 +424,7 @@ func TestLongChatCompletions(t *testing.T) { t.Logf("ChatTemplate Long conversation render: %v", renderDuration) // Verify all messages are present - for _, message := range longConversation { + for _, message := range longConversation[0] { assert.Contains(t, rendered, message.Content, "All message content should be present in rendered output") } @@ -472,10 +486,10 @@ func BenchmarkApplyChatTemplate(b *testing.B) { LoadTokenizerWithCacheRequest: preprocessing.LoadTokenizerWithCacheRequest{ Model: "ibm-granite/granite-3.3-8b-instruct", }, - Conversation: []preprocessing.Conversation{ + Conversation: [][]preprocessing.Conversation{{ {Role: "user", Content: "Hello"}, {Role: "assistant", Content: "Hi there!"}, - }, + }}, } // Track first iteration time and total time @@ -565,10 +579,10 @@ func runVLLMValidationTest(t *testing.T, modelName, expectedVLLMOutput string) { LoadTokenizerWithCacheRequest: preprocessing.LoadTokenizerWithCacheRequest{ Model: modelName, }, - Conversation: []preprocessing.Conversation{ + Conversation: [][]preprocessing.Conversation{{ {Role: "user", Content: "What is the weather in Paris?"}, {Role: "assistant", Content: "Let me check that for you."}, - }, + }}, Documents: []interface{}{ map[string]interface{}{ "title": "Paris Weather Report", @@ -663,10 +677,10 @@ func TestApplyChatTemplateWithLocalTemplate(t *testing.T) { IsLocal: true, }, - Conversation: []preprocessing.Conversation{ + Conversation: [][]preprocessing.Conversation{{ {Role: "user", Content: "Hello from local tokenizer!"}, {Role: "assistant", Content: "Hi! I'm using a locally loaded template."}, - }, + }}, } rendered, err := wrapper.ApplyChatTemplate(context.Background(), renderRequest) diff --git a/pkg/preprocessing/chat_completions/tokenizer_wrapper.py b/pkg/preprocessing/chat_completions/tokenizer_wrapper.py index 4ccd022f9..8f7bf7314 100644 --- a/pkg/preprocessing/chat_completions/tokenizer_wrapper.py +++ b/pkg/preprocessing/chat_completions/tokenizer_wrapper.py @@ -35,7 +35,7 @@ def clear_caches(): def apply_chat_template(request_json): """ - Render a chat template using the transformers library. + Render a chat template using the vllm library. This function is aligned with the Go cgo_functions.go structs. Args: @@ -68,11 +68,11 @@ def apply_chat_template(request_json): request.update(template_vars) request["tokenize"] = False - return tokenizer.apply_chat_template(**request) + return tokenizer.apply_chat_template(**request)[0] except Exception as e: raise RuntimeError(f"Error applying chat template: {e}") from e - + def load_tokenizer_with_cache(request_json): """ Initialize and cache the tokenizer based on the request. @@ -102,8 +102,8 @@ def load_tokenizer_with_cache(request_json): cache_key = f"{model_name}:{revision or 'main'}:{is_local}" tokenizer = _tokenizer_cache.get(cache_key) - if not tokenizer is None: - return tokenizer + if tokenizer is not None: + return tokenizer os.environ["HF_TOKEN"] = token tokenizer = get_tokenizer(model_name, trust_remote_code=True, revision=revision, download_dir=download_dir) _tokenizer_cache[cache_key] = tokenizer @@ -145,7 +145,7 @@ def main(): response = apply_chat_template(request_str) print("Rendered chat:") - print(response[0]) + print(response) except Exception as e: print(f"Error: {e}") diff --git a/pkg/tokenization/tokenizer.go b/pkg/tokenization/tokenizer.go index 1389aec51..7e7a93b3b 100644 --- a/pkg/tokenization/tokenizer.go +++ b/pkg/tokenization/tokenizer.go @@ -389,6 +389,7 @@ func (t *HFCachedTokenizer) ApplyChatTemplate( req.LoadTokenizerWithCacheRequest.IsLocal = false req.LoadTokenizerWithCacheRequest.DownloadDir = t.hfTokenizerConfig.TokenizersCacheDir req.LoadTokenizerWithCacheRequest.Token = t.hfTokenizerConfig.HuggingFaceToken + req.LoadTokenizerWithCacheRequest.Model = modelName res, err := t.chatTemplateRenderer.ApplyChatTemplate(ctx, req) if err != nil { return "", fmt.Errorf("failed to render chat template: %w", err) diff --git a/tests/e2e/redis_mock/e2e_test.go b/tests/e2e/redis_mock/e2e_test.go index 8b9b35eaf..cba6619dd 100644 --- a/tests/e2e/redis_mock/e2e_test.go +++ b/tests/e2e/redis_mock/e2e_test.go @@ -61,7 +61,7 @@ type GetChatTemplateRequest struct { } // convertToPreprocessingConversation converts e2e ChatMessage to preprocessing Conversation. -func convertToPreprocessingConversation(messages []ChatMessage) []preprocessing.Conversation { +func convertToPreprocessingConversation(messages []ChatMessage) [][]preprocessing.Conversation { result := make([]preprocessing.Conversation, len(messages)) for i, msg := range messages { result[i] = preprocessing.Conversation{ @@ -69,7 +69,7 @@ func convertToPreprocessingConversation(messages []ChatMessage) []preprocessing. Content: msg.Content, } } - return result + return [][]preprocessing.Conversation{result} } // MockChatTemplateWrapper provides a mock implementation for testing. From a69810e5ed671e87e2679b8d4ac67d40d77e11b2 Mon Sep 17 00:00:00 2001 From: HyunKyun Moon Date: Sun, 28 Dec 2025 02:40:06 +0900 Subject: [PATCH 04/12] add example_usage Signed-off-by: HyunKyun Moon --- .../chat_completions/tokenizer_wrapper.py | 28 +++++++++++++++---- 1 file changed, 22 insertions(+), 6 deletions(-) diff --git a/pkg/preprocessing/chat_completions/tokenizer_wrapper.py b/pkg/preprocessing/chat_completions/tokenizer_wrapper.py index 8f7bf7314..ccd8af75f 100644 --- a/pkg/preprocessing/chat_completions/tokenizer_wrapper.py +++ b/pkg/preprocessing/chat_completions/tokenizer_wrapper.py @@ -40,10 +40,12 @@ def apply_chat_template(request_json): Args: request_json (str): JSON string containing the request parameters: - - is_local (bool, optional): Whether the model is local. - - model (str): The model ID or path (HF model ID, local directory path, or path to tokenizer file). - - revision (str, optional): Model revision. - - token (str, optional): Hugging Face token for private models. + - load_tokenizer_with_cache_request (dict): Parameters for loading the tokenizer: + - is_local (bool, optional): Whether the model is local. + - model (str): The model ID or path (HF model ID, local directory path, or path to tokenizer file). + - revision (str, optional): Model revision. + - token (str, optional): Hugging Face token for private models. + - download_dir (str, optional): Directory to download the model. - conversation (list): List of conversation lists - chat_template (str, optional): The template to use - tools (list, optional): Tool schemas @@ -60,7 +62,7 @@ def apply_chat_template(request_json): try: # Parse the JSON request request = json.loads(request_json) - tokenizer_request = request.pop("load_tokenizer_with_cache_request", request) + tokenizer_request = request.pop("load_tokenizer_with_cache_request") tokenizer = load_tokenizer_with_cache(json.dumps(tokenizer_request)) # Get template_vars and spread them as individual arguments @@ -111,6 +113,17 @@ def load_tokenizer_with_cache(request_json): except Exception as e: raise RuntimeError(f"Error initializing tokenizer: {e}") from e +def example_usage(): + """Example usage of apply_chat_template function.""" + request_str = json.dumps({ + "load_tokenizer_with_cache_request": { + "is_local": False, + "model": "ibm-granite/granite-3.3-8b-instruct", + }, + "conversation": [ [{"role": "system", "content": "You are a helpful assistant."}] , [{"role": "user", "content": "who are you?"}] ], + }) + print(apply_chat_template(request_str)) + def main(): """Example usage and testing function.""" @@ -138,7 +151,10 @@ def main(): try: # Construct the request JSON string similar to how Go would request_str = json.dumps({ - "model": "facebook/opt-125m", + "load_tokenizer_with_cache_request": { + "is_local": True, + "model": "facebook/opt-125m", + }, "conversation": [conversation], "chat_template": chat_template }) From 2fa20430df063d0283665fd7899b78a48cdab248 Mon Sep 17 00:00:00 2001 From: Hyunkyun Moon Date: Sat, 3 Jan 2026 01:47:36 +0900 Subject: [PATCH 05/12] Update pkg/preprocessing/chat_completions/requirements.txt Co-authored-by: Edoardo Vacchi Signed-off-by: Hyunkyun Moon --- pkg/preprocessing/chat_completions/requirements.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pkg/preprocessing/chat_completions/requirements.txt b/pkg/preprocessing/chat_completions/requirements.txt index 6f48e6c24..26c952f34 100644 --- a/pkg/preprocessing/chat_completions/requirements.txt +++ b/pkg/preprocessing/chat_completions/requirements.txt @@ -1,3 +1,4 @@ --index-url https://download.pytorch.org/whl/cpu --extra-index-url https://pypi.org/simple -vllm-cpu>=0.11.0 \ No newline at end of file +vllm-cpu>=0.11.0; sys_platform != 'darwin' +vllm @ git+https://github.com/vllm-project/vllm.git@v0.11.0; sys_platform == 'darwin' \ No newline at end of file From 029be0944b3d805387c96128147a748c62ee5c1c Mon Sep 17 00:00:00 2001 From: HyunKyun Moon Date: Mon, 5 Jan 2026 06:30:30 +0000 Subject: [PATCH 06/12] request with key Signed-off-by: HyunKyun Moon --- .gitignore | 5 +- .../chat_completions/cgo_functions.c | 106 +++++----- .../chat_completions/cgo_functions.go | 44 ++-- .../chat_completions/cgo_functions.h | 8 +- .../chat_completions/cgo_functions_test.go | 200 ++++++++++-------- .../chat_completions/tokenizer_wrapper.py | 83 +++++--- pkg/tokenization/tokenizer.go | 43 ++-- services/uds_tokenizer/server.py | 190 ++++++++++------- tests/e2e/redis_mock/e2e_test.go | 23 +- 9 files changed, 386 insertions(+), 316 deletions(-) diff --git a/.gitignore b/.gitignore index 5a2cede74..f1a3a0a2c 100644 --- a/.gitignore +++ b/.gitignore @@ -66,4 +66,7 @@ _cgo_* /hack/tools # Tokenizer binaries -/lib \ No newline at end of file +/lib + +# uds tokenizer default model path +services/uds_tokenizer/models \ No newline at end of file diff --git a/pkg/preprocessing/chat_completions/cgo_functions.c b/pkg/preprocessing/chat_completions/cgo_functions.c index 38bab5acd..16d71b96a 100644 --- a/pkg/preprocessing/chat_completions/cgo_functions.c +++ b/pkg/preprocessing/chat_completions/cgo_functions.c @@ -15,13 +15,12 @@ limitations under the License. */ #include // for getpid() and usleep() -#include #include "cgo_functions.h" // Global variables for caching PyObject* g_chat_template_module = NULL; -PyObject* g_load_tokenizer_with_cache_func = NULL; +PyObject* g_get_or_create_tokenizer_key_func = NULL; PyObject* g_apply_chat_template_func = NULL; int g_initialized = 0; int g_python_initialized = 0; @@ -102,9 +101,9 @@ void Py_FinalizeGo() { g_finalized = 1; // Clean up module references safely - if (g_load_tokenizer_with_cache_func) { - Py_DECREF(g_load_tokenizer_with_cache_func); - g_load_tokenizer_with_cache_func = NULL; + if (g_get_or_create_tokenizer_key_func) { + Py_DECREF(g_get_or_create_tokenizer_key_func); + g_get_or_create_tokenizer_key_func = NULL; } if (g_apply_chat_template_func) { @@ -202,15 +201,15 @@ int Py_InitChatTemplateModule() { return -1; } - // Get the load_tokenizer_with_cache function - g_load_tokenizer_with_cache_func = PyDict_GetItemString(module_dict, "load_tokenizer_with_cache"); - if (!g_load_tokenizer_with_cache_func || !PyCallable_Check(g_load_tokenizer_with_cache_func)) { - printf("[C] Py_InitChatTemplateModule ERROR - load_tokenizer_with_cache function not found or not callable\n"); + // Get the get_or_create_tokenizer_key function + g_get_or_create_tokenizer_key_func = PyDict_GetItemString(module_dict, "get_or_create_tokenizer_key"); + if (!g_get_or_create_tokenizer_key_func || !PyCallable_Check(g_get_or_create_tokenizer_key_func)) { + printf("[C] Py_InitChatTemplateModule ERROR - get_or_create_tokenizer_key function not found or not callable\n"); PyGILState_Release(gil_state); PyThread_release_lock(g_init_lock); return -1; } - Py_INCREF(g_load_tokenizer_with_cache_func); // Keep a reference + Py_INCREF(g_get_or_create_tokenizer_key_func); // Keep a reference // Get the apply_chat_template function g_apply_chat_template_func = PyDict_GetItemString(module_dict, "apply_chat_template"); @@ -232,81 +231,80 @@ int Py_InitChatTemplateModule() { -// Call the cached load_tokenizer_with_cache function -bool Py_CallLoadTokenizerWithCache(const char* json_request) { - return Py_CallLoadTokenizerWithCacheInternal(json_request); -} - -// Internal function that does the actual work -bool Py_CallLoadTokenizerWithCacheInternal(const char* json_request) { - // Check if Python interpreter is initialized - if (!g_python_initialized) { - printf("[C] Py_CallLoadTokenizerWithCacheInternal ERROR - Python not initialized\n"); - fflush(stdout); - return false; +// Call the cached get_or_create_tokenizer_key function +char* Py_CallGetOrCreateTokenizerKey(const char* json_request) { + // Try direct call first (fast path) + char* result = Py_CallGetOrCreateTokenizerKeyInternal(json_request); + if (result != NULL) { + return result; // Success on first try } - // Validate cached function - if (!g_load_tokenizer_with_cache_func) { - printf("[C] Py_CallLoadTokenizerWithCacheInternal ERROR - Cached function is NULL\n"); - fflush(stdout); - return false; - } + // If failed, just return NULL (no retry, no reload) + return NULL; +} - // Validate that the cached function is still a valid Python object - fflush(stdout); - if (!PyCallable_Check(g_load_tokenizer_with_cache_func)) { - printf("[C] Py_CallLoadTokenizerWithCacheInternal ERROR - Cached function is not callable (corrupted?)\n"); +// Internal function that does the actual work +char* Py_CallGetOrCreateTokenizerKeyInternal(const char* json_request) { + // Check if Python interpreter is still valid + if (!Py_IsInitialized()) { + printf("[C] Py_CallGetOrCreateTokenizerKeyInternal ERROR - Python interpreter not initialized\n"); fflush(stdout); - return false; + return NULL; } - // Validate input + // Simple validation if (!json_request) { - printf("[C] Py_CallLoadTokenizerWithCacheInternal ERROR - Input is NULL\n"); + printf("[C] Py_CallGetOrCreateTokenizerKeyInternal ERROR - Input is NULL\n"); fflush(stdout); - return false; + return NULL; } // Acquire GIL for Python operations PyGILState_STATE gil_state = PyGILState_Ensure(); - // Create Python string from JSON request PyObject* py_json = PyUnicode_FromString(json_request); if (!py_json) { - printf("[C] Py_CallLoadTokenizerWithCacheInternal ERROR - Failed to create Python string\n"); + printf("[C] Py_CallGetOrCreateTokenizerKeyInternal ERROR - Failed to create Python string\n"); fflush(stdout); PyGILState_Release(gil_state); - return false; + return NULL; } // Create arguments tuple PyObject* args = PyTuple_Pack(1, py_json); if (!args) { - printf("[C] Py_CallLoadTokenizerWithCacheInternal ERROR - Failed to create args tuple\n"); + printf("[C] Py_CallGetOrCreateTokenizerKeyInternal ERROR - Failed to create args tuple\n"); fflush(stdout); Py_DECREF(py_json); PyGILState_Release(gil_state); - return false; + return NULL; } // Call the cached function - PyObject* py_result = PyObject_CallObject(g_load_tokenizer_with_cache_func, args); + PyObject* py_result = PyObject_CallObject(g_get_or_create_tokenizer_key_func, args); // Clean up args Py_DECREF(args); Py_DECREF(py_json); - bool cresult = true; - if (!py_result) { - printf("[C] Py_CallLoadTokenizerWithCacheInternal ERROR - Python function returned NULL\n"); + char* cresult = NULL; + if (py_result) { + // Convert to C string + const char* s = PyUnicode_AsUTF8(py_result); + if (s) { + cresult = strdup(s); + } + else { + printf("[C] Py_CallGetOrCreateTokenizerKeyInternal ERROR - Failed to convert result to C string\n"); + fflush(stdout); + } + Py_DECREF(py_result); + } + else { + printf("[C] Py_CallGetOrCreateTokenizerKeyInternal ERROR - Python function returned NULL\n"); fflush(stdout); PyErr_Print(); fflush(stderr); - cresult = false; - } - else { - Py_DECREF(py_result); } // Release GIL @@ -442,10 +440,10 @@ char* Py_ClearCaches() { void Py_CleanupChatTemplateModule() { if (g_initialized && Py_IsInitialized()) { PyGILState_STATE state = PyGILState_Ensure(); - Py_XDECREF(g_load_tokenizer_with_cache_func); + Py_XDECREF(g_get_or_create_tokenizer_key_func); Py_XDECREF(g_apply_chat_template_func); Py_XDECREF(g_chat_template_module); - g_load_tokenizer_with_cache_func = NULL; + g_get_or_create_tokenizer_key_func = NULL; g_apply_chat_template_func = NULL; g_chat_template_module = NULL; g_initialized = 0; @@ -461,9 +459,9 @@ int Py_ReinitializeGo() { g_process_initialized = 0; // Clean up cached objects - if (g_load_tokenizer_with_cache_func) { - Py_DECREF(g_load_tokenizer_with_cache_func); - g_load_tokenizer_with_cache_func = NULL; + if (g_get_or_create_tokenizer_key_func) { + Py_DECREF(g_get_or_create_tokenizer_key_func); + g_get_or_create_tokenizer_key_func = NULL; } if (g_apply_chat_template_func) { Py_DECREF(g_apply_chat_template_func); diff --git a/pkg/preprocessing/chat_completions/cgo_functions.go b/pkg/preprocessing/chat_completions/cgo_functions.go index d29ec81d9..30bef3f0f 100644 --- a/pkg/preprocessing/chat_completions/cgo_functions.go +++ b/pkg/preprocessing/chat_completions/cgo_functions.go @@ -33,7 +33,7 @@ import ( "sigs.k8s.io/controller-runtime/pkg/log" ) -type LoadTokenizerWithCacheRequest struct { +type GetOrCreateTokenizerKeyRequest struct { IsLocal bool `json:"is_local,omitempty"` DownloadDir string `json:"download_dir,omitempty"` Model string `json:"model"` @@ -50,15 +50,15 @@ type Conversation struct { // ApplyChatTemplateRequest represents the request to render a chat template. type ApplyChatTemplateRequest struct { // The Python wrapper will handle converting this to a batched list if needed. - LoadTokenizerWithCacheRequest LoadTokenizerWithCacheRequest `json:"load_tokenizer_with_cache_request,omitempty"` - Conversation [][]Conversation `json:"conversation"` - Tools []interface{} `json:"tools,omitempty"` - Documents []interface{} `json:"documents,omitempty"` - ChatTemplate string `json:"chat_template,omitempty"` - ReturnAssistantTokensMask bool `json:"return_assistant_tokens_mask,omitempty"` - ContinueFinalMessage bool `json:"continue_final_message,omitempty"` - AddGenerationPrompt bool `json:"add_generation_prompt,omitempty"` - ChatTemplateKWArgs map[string]interface{} `json:"chat_template_kwargs,omitempty"` + Key string `json:"key"` + Conversation [][]Conversation `json:"conversation"` + Tools []interface{} `json:"tools,omitempty"` + Documents []interface{} `json:"documents,omitempty"` + ChatTemplate string `json:"chat_template,omitempty"` + ReturnAssistantTokensMask bool `json:"return_assistant_tokens_mask,omitempty"` + ContinueFinalMessage bool `json:"continue_final_message,omitempty"` + AddGenerationPrompt bool `json:"add_generation_prompt,omitempty"` + ChatTemplateKWArgs map[string]interface{} `json:"chat_template_kwargs,omitempty"` } // DeepCopy creates a deep copy of the ApplyChatTemplateRequest. @@ -110,29 +110,31 @@ func (w *ChatTemplatingProcessor) Finalize() { C.Py_FinalizeGo() } -// LoadTokenizerWithCache loads a tokenizer with caching using the cached Python function. -func (w *ChatTemplatingProcessor) LoadTokenizerWithCache( +// GetOrCreateTokenizerKey returns the cache key for the tokenizer specified in the request. +func (w *ChatTemplatingProcessor) GetOrCreateTokenizerKey( ctx context.Context, - req *LoadTokenizerWithCacheRequest, -) error { + req *GetOrCreateTokenizerKeyRequest, +) (string, error) { traceLogger := log.FromContext(ctx).V(logging.TRACE).WithName("LoadTokenizer") if req == nil { traceLogger.Error(nil, "Received nil request") - return fmt.Errorf("received nil request") + return "", fmt.Errorf("received nil request") } // Convert request to JSON reqJSON, err := json.Marshal(req) if err != nil { traceLogger.Error(err, "Failed to marshal request") - return fmt.Errorf("failed to marshal request: %w", err) + return "", fmt.Errorf("failed to marshal request: %w", err) } // Call the cached Python function - cResult := C.Py_CallLoadTokenizerWithCache(C.CString(string(reqJSON))) - if !cResult { - traceLogger.Error(nil, "C function returned false") - return fmt.Errorf("python load tokenizer failed") + cResult := C.Py_CallGetOrCreateTokenizerKey(C.CString(string(reqJSON))) + if cResult == nil { + traceLogger.Error(nil, "C function returned nil") + return "", fmt.Errorf("python get_or_create_tokenizer_key failed") } - return nil + defer C.free(unsafe.Pointer(cResult)) + + return C.GoString(cResult), nil } // ApplyChatTemplate renders a chat template using the cached Python function. diff --git a/pkg/preprocessing/chat_completions/cgo_functions.h b/pkg/preprocessing/chat_completions/cgo_functions.h index 45ba7bf2a..91ef0e49b 100644 --- a/pkg/preprocessing/chat_completions/cgo_functions.h +++ b/pkg/preprocessing/chat_completions/cgo_functions.h @@ -49,17 +49,17 @@ const char* PyUnicode_AsGoString(PyObject* obj); // Global variables to hold cached module and functions extern PyObject* g_chat_template_module; -extern PyObject* g_load_tokenizer_with_cache_func; +extern PyObject* g_get_or_create_tokenizer_key_func; extern PyObject* g_apply_chat_template_func; // Initialize the cached module and functions (call once at startup) int Py_InitChatTemplateModule(); -// Call the cached load_tokenizer_with_cache function -bool Py_CallLoadTokenizerWithCache(const char* json_request); +// Call the cached get_or_create_tokenizer_key function +char* Py_CallGetOrCreateTokenizerKey(const char* json_request); // Internal function that does the actual work -bool Py_CallLoadTokenizerWithCacheInternal(const char* json_request); +char* Py_CallGetOrCreateTokenizerKeyInternal(const char* json_request); // Call the cached apply_chat_template function char* Py_CallApplyChatTemplate(const char* json_request); diff --git a/pkg/preprocessing/chat_completions/cgo_functions_test.go b/pkg/preprocessing/chat_completions/cgo_functions_test.go index 673e67c39..db0fc2174 100644 --- a/pkg/preprocessing/chat_completions/cgo_functions_test.go +++ b/pkg/preprocessing/chat_completions/cgo_functions_test.go @@ -49,8 +49,8 @@ func getGlobalWrapper() *preprocessing.ChatTemplatingProcessor { return globalWrapper } -// TestLoadTokenizerWithCache tests the load_tokenizer_with_cache function. -func TestLoadTokenizerWithCache(t *testing.T) { +// TestGetOrCreateTokenizerKey tests the get_or_create_tokenizer_key function. +func TestGetOrCreateTokenizerKey(t *testing.T) { wrapper := getGlobalWrapper() // Clear caches to ensure accurate timing measurements @@ -78,7 +78,7 @@ func TestLoadTokenizerWithCache(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - request := &preprocessing.LoadTokenizerWithCacheRequest{ + request := &preprocessing.GetOrCreateTokenizerKeyRequest{ Model: tt.modelName, Revision: tt.revision, Token: tt.token, @@ -86,14 +86,14 @@ func TestLoadTokenizerWithCache(t *testing.T) { // Profile the function call start := time.Now() - err := wrapper.LoadTokenizerWithCache(context.Background(), request) + _, err := wrapper.GetOrCreateTokenizerKey(context.Background(), request) duration := time.Since(start) // Log performance t.Logf("Model: %s, Duration: %v", tt.modelName, duration) if tt.expectTemplate { // Models that should have templates - require.NoError(t, err, "LoadTokenizerWithCache should not return an error") + require.NoError(t, err, "GetOrCreateTokenizerKey should not return an error") } else { // Models that don't have chat templates if err != nil { @@ -172,18 +172,20 @@ func TestApplyChatTemplate(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - request := &preprocessing.ApplyChatTemplateRequest{ - LoadTokenizerWithCacheRequest: preprocessing.LoadTokenizerWithCacheRequest{ - Model: "ibm-granite/granite-3.3-8b-instruct", - IsLocal: true, - }, - Conversation: tt.messages, - ChatTemplate: tt.template, - } + ctx := context.Background() + key, err := wrapper.GetOrCreateTokenizerKey(ctx, &preprocessing.GetOrCreateTokenizerKeyRequest{ + Model: "ibm-granite/granite-3.3-8b-instruct", + IsLocal: true, + }) + require.NoError(t, err, "Failed to get tokenizer key") // Profile the function call start := time.Now() - rendered, err := wrapper.ApplyChatTemplate(context.Background(), request) + rendered, err := wrapper.ApplyChatTemplate(ctx, &preprocessing.ApplyChatTemplateRequest{ + Key: key, + Conversation: tt.messages, + ChatTemplate: tt.template, + }) duration := time.Since(start) // Assertions @@ -208,8 +210,8 @@ func TestApplyChatTemplate(t *testing.T) { } } -// TestLoadTokenizerCaching tests the caching functionality. -func TestLoadTokenizerCaching(t *testing.T) { +// TestGetOrCreateTokenizerKeyCaching tests the caching functionality. +func TestGetOrCreateTokenizerKeyCaching(t *testing.T) { wrapper := getGlobalWrapper() // Clear all caches to ensure we start with a clean state @@ -217,7 +219,7 @@ func TestLoadTokenizerCaching(t *testing.T) { require.NoError(t, err, "Failed to clear caches") modelName := "ibm-granite/granite-3.3-8b-instruct" - request := &preprocessing.LoadTokenizerWithCacheRequest{ + request := &preprocessing.GetOrCreateTokenizerKeyRequest{ Model: modelName, IsLocal: false, } @@ -225,17 +227,20 @@ func TestLoadTokenizerCaching(t *testing.T) { // First call - should be cache miss t.Log("=== First call (Cache MISS) ===") start := time.Now() - err = wrapper.LoadTokenizerWithCache(context.Background(), request) + key1, err := wrapper.GetOrCreateTokenizerKey(context.Background(), request) duration1 := time.Since(start) require.NoError(t, err, "First call should not return an error") // Second call - should be cache hit t.Log("=== Second call (Cache HIT) ===") start = time.Now() - err = wrapper.LoadTokenizerWithCache(context.Background(), request) + key2, err := wrapper.GetOrCreateTokenizerKey(context.Background(), request) duration2 := time.Since(start) require.NoError(t, err, "Second call should not return an error") + // Verify that both calls returned the same key + assert.Equal(t, key1, key2, "Both calls should return the same tokenizer key") + // Verify performance improvement t.Logf("First call duration: %v, Second call duration: %v, Speedup: %.1fx", duration1, duration2, float64(duration1)/float64(duration2)) @@ -313,15 +318,20 @@ func TestChatCompletionsIntegration(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Logf("Testing: %s - %s", tt.name, tt.description) - // Step 1: Render the conversation using the template + // step 1: get tokenizer key + ctx := context.Background() + key, err := wrapper.GetOrCreateTokenizerKey(ctx, &preprocessing.GetOrCreateTokenizerKeyRequest{ + Model: tt.modelName, + IsLocal: false, + }) + require.NoError(t, err, "Failed to get tokenizer key") + + // Step 2: Render the conversation with tokenizer key start := time.Now() - renderRequest := &preprocessing.ApplyChatTemplateRequest{ - LoadTokenizerWithCacheRequest: preprocessing.LoadTokenizerWithCacheRequest{ - Model: tt.modelName, - }, + rendered, err := wrapper.ApplyChatTemplate(context.Background(), &preprocessing.ApplyChatTemplateRequest{ + Key: key, Conversation: tt.conversation, - } - rendered, err := wrapper.ApplyChatTemplate(context.Background(), renderRequest) + }) renderDuration := time.Since(start) require.NoError(t, err, "Failed to render chat template") @@ -369,8 +379,10 @@ func TestVLLMValidation(t *testing.T) { func TestLongChatCompletions(t *testing.T) { wrapper := getGlobalWrapper() + ctx := context.Background() + // Clear caches to ensure accurate timing measurements - err := preprocessing.ClearCaches(context.Background()) + err := preprocessing.ClearCaches(ctx) require.NoError(t, err, "Failed to clear caches") // Create a long conversation @@ -405,13 +417,15 @@ func TestLongChatCompletions(t *testing.T) { t.Run("Long Conversation Processing", func(t *testing.T) { // Render long conversation start := time.Now() - renderRequest := &preprocessing.ApplyChatTemplateRequest{ - LoadTokenizerWithCacheRequest: preprocessing.LoadTokenizerWithCacheRequest{ - Model: modelName, - }, + key, err := wrapper.GetOrCreateTokenizerKey(ctx, &preprocessing.GetOrCreateTokenizerKeyRequest{ + Model: modelName, + IsLocal: false, + }) + require.NoError(t, err, "Failed to get tokenizer key") + rendered, err := wrapper.ApplyChatTemplate(ctx, &preprocessing.ApplyChatTemplateRequest{ + Key: key, Conversation: longConversation, - } - rendered, err := wrapper.ApplyChatTemplate(context.Background(), renderRequest) + }) renderDuration := time.Since(start) require.NoError(t, err, "Failed to render long conversation") @@ -431,15 +445,15 @@ func TestLongChatCompletions(t *testing.T) { }) } -// BenchmarkLoadTokenizerWithCache benchmarks the template fetching performance. -func BenchmarkLoadTokenizerWithCache(b *testing.B) { +// BenchmarkGetOrCreateTokenizerKey benchmarks the template fetching performance. +func BenchmarkGetOrCreateTokenizerKey(b *testing.B) { wrapper := getGlobalWrapper() // Clear caches to ensure accurate timing measurements err := preprocessing.ClearCaches(context.Background()) require.NoError(b, err, "Failed to clear caches") - request := &preprocessing.LoadTokenizerWithCacheRequest{ + request := &preprocessing.GetOrCreateTokenizerKeyRequest{ Model: "ibm-granite/granite-3.3-8b-instruct", } @@ -450,7 +464,7 @@ func BenchmarkLoadTokenizerWithCache(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { start := time.Now() - err := wrapper.LoadTokenizerWithCache(context.Background(), request) + _, err := wrapper.GetOrCreateTokenizerKey(context.Background(), request) require.NoError(b, err, "Benchmark should not return errors") iterTime := time.Since(start) @@ -478,19 +492,17 @@ func BenchmarkLoadTokenizerWithCache(b *testing.B) { func BenchmarkApplyChatTemplate(b *testing.B) { wrapper := getGlobalWrapper() + ctx := context.Background() + // Clear caches to ensure accurate timing measurements - err := preprocessing.ClearCaches(context.Background()) + err := preprocessing.ClearCaches(ctx) require.NoError(b, err, "Failed to clear caches") - request := &preprocessing.ApplyChatTemplateRequest{ - LoadTokenizerWithCacheRequest: preprocessing.LoadTokenizerWithCacheRequest{ - Model: "ibm-granite/granite-3.3-8b-instruct", - }, - Conversation: [][]preprocessing.Conversation{{ - {Role: "user", Content: "Hello"}, - {Role: "assistant", Content: "Hi there!"}, - }}, - } + key, err := wrapper.GetOrCreateTokenizerKey(ctx, &preprocessing.GetOrCreateTokenizerKeyRequest{ + Model: "ibm-granite/granite-3.3-8b-instruct", + IsLocal: false, + }) + require.NoError(b, err, "Failed to get tokenizer key") // Track first iteration time and total time var firstIterationTime time.Duration @@ -499,7 +511,13 @@ func BenchmarkApplyChatTemplate(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { start := time.Now() - _, err := wrapper.ApplyChatTemplate(context.Background(), request) + _, err := wrapper.ApplyChatTemplate(ctx, &preprocessing.ApplyChatTemplateRequest{ + Key: key, + Conversation: [][]preprocessing.Conversation{{ + {Role: "user", Content: "Hello"}, + {Role: "assistant", Content: "Hi there!"}, + }}, + }) require.NoError(b, err, "Benchmark should not return errors") iterTime := time.Since(start) @@ -574,11 +592,18 @@ func runVLLMValidationTest(t *testing.T, modelName, expectedVLLMOutput string) { t.Helper() wrapper := getGlobalWrapper() - // Test case based on the provided vLLM request - request := &preprocessing.ApplyChatTemplateRequest{ - LoadTokenizerWithCacheRequest: preprocessing.LoadTokenizerWithCacheRequest{ - Model: modelName, - }, + ctx := context.Background() + + // Step 1: Get tokenizer key + key, err := wrapper.GetOrCreateTokenizerKey(ctx, &preprocessing.GetOrCreateTokenizerKeyRequest{ + Model: modelName, + IsLocal: false, + }) + require.NoError(t, err, "Failed to get tokenizer key") + + // Step 2: Render the conversation with the tokenizer key + renderedOutput, err := wrapper.ApplyChatTemplate(ctx, &preprocessing.ApplyChatTemplateRequest{ + Key: key, Conversation: [][]preprocessing.Conversation{{ {Role: "user", Content: "What is the weather in Paris?"}, {Role: "assistant", Content: "Let me check that for you."}, @@ -594,13 +619,10 @@ func runVLLMValidationTest(t *testing.T, modelName, expectedVLLMOutput string) { "max_tokens": 10, "temperature": 0.0, }, - } - - // Step 1: Render the conversation with the template - renderedOutput, err := wrapper.ApplyChatTemplate(context.Background(), request) + }) require.NoError(t, err, "Failed to render chat template") - // Step 2: Compare results with flexible date handling + // Step 3: Compare results with flexible date handling compareVLLMOutput(t, renderedOutput, expectedVLLMOutput) } @@ -643,24 +665,25 @@ func compareVLLMOutput(t *testing.T, renderedOutput, expectedVLLMOutput string) t.Fail() // Mark test as failed } -// TestLoadTokenizerWithCacheLocalPath tests fetching chat templates from local paths. -func TestLoadTokenizerWithCacheLocalPath(t *testing.T) { +// TestGetOrCreateTokenizerKeyLocalPath tests fetching chat templates from local paths. +func TestGetOrCreateTokenizerKeyLocalPath(t *testing.T) { wrapper := getGlobalWrapper() // Get the path to the test model tokenizer // The testdata directory is in pkg/tokenization/testdata testModelPath := "../../tokenization/testdata/test-model" - request := &preprocessing.LoadTokenizerWithCacheRequest{ + request := &preprocessing.GetOrCreateTokenizerKeyRequest{ Model: testModelPath, IsLocal: true, } // Fetch the chat template - err := wrapper.LoadTokenizerWithCache(context.Background(), request) + key, err := wrapper.GetOrCreateTokenizerKey(context.Background(), request) // Assertions - require.NoError(t, err, "LoadTokenizerWithCache should not return an error for local path") + require.NoError(t, err, "GetOrCreateTokenizerKey should not return an error for local path") + assert.NotEmpty(t, key, "Returned tokenizer key should not be empty") } // TestApplyChatTemplateWithLocalTemplate tests rendering with a locally fetched template. @@ -669,14 +692,18 @@ func TestApplyChatTemplateWithLocalTemplate(t *testing.T) { // Get the path to the test model tokenizer testModelPath := "../../tokenization/testdata/test-model" + request := &preprocessing.GetOrCreateTokenizerKeyRequest{ + Model: testModelPath, + IsLocal: true, + } + + // get tokenizer key + key, err := wrapper.GetOrCreateTokenizerKey(context.Background(), request) + require.NoError(t, err, "GetOrCreateTokenizerKey should not return an error for local path") // Now render a conversation using the fetched template renderRequest := &preprocessing.ApplyChatTemplateRequest{ - LoadTokenizerWithCacheRequest: preprocessing.LoadTokenizerWithCacheRequest{ - Model: testModelPath, - IsLocal: true, - }, - + Key: key, Conversation: [][]preprocessing.Conversation{{ {Role: "user", Content: "Hello from local tokenizer!"}, {Role: "assistant", Content: "Hi! I'm using a locally loaded template."}, @@ -693,8 +720,8 @@ func TestApplyChatTemplateWithLocalTemplate(t *testing.T) { t.Logf("Rendered chat with local template:\n%s", rendered) } -// TestLoadTokenizerWithCacheLocalPathCaching tests that local templates are cached properly. -func TestLoadTokenizerWithCacheLocalPathCaching(t *testing.T) { +// TestGetOrCreateTokenizerKeyLocalPathCaching tests that local templates are cached properly. +func TestGetOrCreateTokenizerKeyLocalPathCaching(t *testing.T) { wrapper := getGlobalWrapper() // Clear caches first @@ -702,62 +729,67 @@ func TestLoadTokenizerWithCacheLocalPathCaching(t *testing.T) { require.NoError(t, err, "Failed to clear caches") testModelPath := "../../tokenization/testdata/test-model" - request := &preprocessing.LoadTokenizerWithCacheRequest{ + request := &preprocessing.GetOrCreateTokenizerKeyRequest{ Model: testModelPath, IsLocal: true, } // First call - cache miss start := time.Now() - err = wrapper.LoadTokenizerWithCache(context.Background(), request) + key1, err := wrapper.GetOrCreateTokenizerKey(context.Background(), request) duration1 := time.Since(start) require.NoError(t, err, "First call should not return an error") // Second call - cache hit start = time.Now() - err = wrapper.LoadTokenizerWithCache(context.Background(), request) + key2, err := wrapper.GetOrCreateTokenizerKey(context.Background(), request) duration2 := time.Since(start) require.NoError(t, err, "Second call should not return an error") + // Verify that both calls returned the same key + assert.Equal(t, key1, key2, "Both calls should return the same tokenizer key") + // Cache hit should be faster t.Logf("First call (cache miss): %v, Second call (cache hit): %v, Speedup: %.1fx", duration1, duration2, float64(duration1)/float64(duration2)) assert.Less(t, duration2, duration1, "Cache hit should be faster than cache miss") } -// TestLoadTokenizerWithCacheLocalPathWithFile tests loading from a specific tokenizer.json file path. -func TestLoadTokenizerWithCacheLocalPathWithFile(t *testing.T) { +// TestGetOrCreateTokenizerKeyLocalPathWithFile tests loading from a specific tokenizer.json file path. +func TestGetOrCreateTokenizerKeyLocalPathWithFile(t *testing.T) { wrapper := getGlobalWrapper() // Test with the full path to tokenizer.json //nolint:gosec // This is a test file path, not a credential testTokenizerPath := "../../tokenization/testdata/test-model/tokenizer.json" - request := &preprocessing.LoadTokenizerWithCacheRequest{ + request := &preprocessing.GetOrCreateTokenizerKeyRequest{ Model: testTokenizerPath, IsLocal: true, } - // Fetch the chat template - should extract directory and load from there - err := wrapper.LoadTokenizerWithCache(context.Background(), request) - require.NoError(t, err, "LoadTokenizerWithCache should handle file path and extract directory") + // Get the tokenizer key + key, err := wrapper.GetOrCreateTokenizerKey(context.Background(), request) + require.NoError(t, err, "GetOrCreateTokenizerKey should handle file path and extract directory") + assert.NotEmpty(t, key, "Returned tokenizer key should not be empty") t.Logf("Loaded tokenizer from file path: %s", testTokenizerPath) } -// TestLoadTokenizerWithCacheLocalPathNonExistent tests error handling for non-existent local paths. -func TestLoadTokenizerWithCacheLocalPathNonExistent(t *testing.T) { +// TestGetOrCreateTokenizerKeyLocalPathNonExistent tests error handling for non-existent local paths. +func TestGetOrCreateTokenizerKeyLocalPathNonExistent(t *testing.T) { wrapper := getGlobalWrapper() - request := &preprocessing.LoadTokenizerWithCacheRequest{ + request := &preprocessing.GetOrCreateTokenizerKeyRequest{ Model: "/non/existent/path", IsLocal: true, } // This should return an error - err := wrapper.LoadTokenizerWithCache(context.Background(), request) + key, err := wrapper.GetOrCreateTokenizerKey(context.Background(), request) // Assertions - assert.Error(t, err, "LoadTokenizerWithCache should return an error for non-existent path") + assert.Error(t, err, "GetOrCreateTokenizerKey should return an error for non-existent path") + assert.Empty(t, key, "Returned tokenizer key should be empty for non-existent path") t.Logf("Expected error for non-existent path: %v", err) } diff --git a/pkg/preprocessing/chat_completions/tokenizer_wrapper.py b/pkg/preprocessing/chat_completions/tokenizer_wrapper.py index ccd8af75f..28577882c 100644 --- a/pkg/preprocessing/chat_completions/tokenizer_wrapper.py +++ b/pkg/preprocessing/chat_completions/tokenizer_wrapper.py @@ -21,6 +21,7 @@ import logging import os import sys + from vllm.transformers_utils.tokenizer import get_tokenizer # Basic logging setup @@ -28,11 +29,13 @@ _tokenizer_cache = {} + def clear_caches(): """Clear the tokenizer cache for testing purposes.""" _tokenizer_cache.clear() return "Tokenizer caches cleared" + def apply_chat_template(request_json): """ Render a chat template using the vllm library. @@ -40,12 +43,7 @@ def apply_chat_template(request_json): Args: request_json (str): JSON string containing the request parameters: - - load_tokenizer_with_cache_request (dict): Parameters for loading the tokenizer: - - is_local (bool, optional): Whether the model is local. - - model (str): The model ID or path (HF model ID, local directory path, or path to tokenizer file). - - revision (str, optional): Model revision. - - token (str, optional): Hugging Face token for private models. - - download_dir (str, optional): Directory to download the model. + - key (str): The tokenizer cache key - conversation (list): List of conversation lists - chat_template (str, optional): The template to use - tools (list, optional): Tool schemas @@ -62,8 +60,11 @@ def apply_chat_template(request_json): try: # Parse the JSON request request = json.loads(request_json) - tokenizer_request = request.pop("load_tokenizer_with_cache_request") - tokenizer = load_tokenizer_with_cache(json.dumps(tokenizer_request)) + key = request.pop("key") + print("mhg", key, flush=True) + tokenizer = _tokenizer_cache.get(key) + if tokenizer is None: + raise RuntimeError(f"Tokenizer with key {key} not found in cache") # Get template_vars and spread them as individual arguments template_vars = request.pop('chat_template_kwargs', {}) @@ -75,9 +76,12 @@ def apply_chat_template(request_json): except Exception as e: raise RuntimeError(f"Error applying chat template: {e}") from e -def load_tokenizer_with_cache(request_json): + +def get_or_create_tokenizer_key(request_json): """ - Initialize and cache the tokenizer based on the request. + Return the cache key for the tokenizer specified in the request. + If the tokenizer is not already cached, initialize and cache it first. + Args: request_json (str): JSON string containing the request parameters: - is_local (bool, optional): Whether the model is local. @@ -86,7 +90,7 @@ def load_tokenizer_with_cache(request_json): - token (str, optional): Hugging Face token for private models. - download_dir (str, optional): Directory to download the model. Returns: - tokenizer: The initialized tokenizer object. + str: The cache key for the initialized tokenizer. """ # Parse the JSON request request = json.loads(request_json) @@ -102,44 +106,66 @@ def load_tokenizer_with_cache(request_json): # If it's a file path (tokenizer.json), get the directory model_name = os.path.dirname(model_name) - cache_key = f"{model_name}:{revision or 'main'}:{is_local}" - tokenizer = _tokenizer_cache.get(cache_key) + key = f"{model_name}:{revision or 'main'}:{is_local}" + tokenizer = _tokenizer_cache.get(key) if tokenizer is not None: - return tokenizer + return key os.environ["HF_TOKEN"] = token - tokenizer = get_tokenizer(model_name, trust_remote_code=True, revision=revision, download_dir=download_dir) - _tokenizer_cache[cache_key] = tokenizer - return tokenizer + tokenizer = get_tokenizer(model_name, + trust_remote_code=True, + revision=revision, + download_dir=download_dir) + _tokenizer_cache[key] = tokenizer + return key except Exception as e: raise RuntimeError(f"Error initializing tokenizer: {e}") from e + def example_usage(): """Example usage of apply_chat_template function.""" - request_str = json.dumps({ - "load_tokenizer_with_cache_request": { + key = get_or_create_tokenizer_key( + json.dumps({ "is_local": False, "model": "ibm-granite/granite-3.3-8b-instruct", - }, - "conversation": [ [{"role": "system", "content": "You are a helpful assistant."}] , [{"role": "user", "content": "who are you?"}] ], + })) + request_str = json.dumps({ + "key": + key, + "conversation": [[{ + "role": "system", + "content": "You are a helpful assistant." + }], [{ + "role": "user", + "content": "who are you?" + }]], }) print(apply_chat_template(request_str)) + del _tokenizer_cache[key] + def main(): """Example usage and testing function.""" if len(sys.argv) < 2: - print("Usage: python tokenizer_wrapper.py [conversation_json]") + print( + "Usage: python tokenizer_wrapper.py [conversation_json]" + ) print("Example:") - print('python tokenizer_wrapper.py "{% for message in messages %}{{ message.role }}: {{ message.content }}\\n{% endfor %}"') + print( + 'python tokenizer_wrapper.py "{% for message in messages %}{{ message.role }}: {{ message.content }}\\n{% endfor %}"' + ) return chat_template = sys.argv[1] # Default conversation if none provided - conversation = [ - {"role": "user", "content": "Hello!"}, - {"role": "assistant", "content": "Hi there! How can I help you today?"} - ] + conversation = [{ + "role": "user", + "content": "Hello!" + }, { + "role": "assistant", + "content": "Hi there! How can I help you today?" + }] if len(sys.argv) > 2: try: @@ -165,5 +191,6 @@ def main(): except Exception as e: print(f"Error: {e}") + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/pkg/tokenization/tokenizer.go b/pkg/tokenization/tokenizer.go index 7e7a93b3b..4b9cd0eee 100644 --- a/pkg/tokenization/tokenizer.go +++ b/pkg/tokenization/tokenizer.go @@ -267,6 +267,7 @@ type CachedTokenizer struct { tokenizer *tokenizers.Tokenizer tokenizerProvider tokenizerProvider chatTemplateRenderer *preprocessing.ChatTemplatingProcessor + tokenizerCacheKey string } type HFCachedTokenizer struct { @@ -293,12 +294,13 @@ func NewCachedHFTokenizer(ctx context.Context, modelID string, config *HFTokeniz return nil, fmt.Errorf("failed to initialize chat templater: %w", err) } - if err := chatTemplateRenderer.LoadTokenizerWithCache(ctx, &preprocessing.LoadTokenizerWithCacheRequest{ + tokenizerCacheKey, err := chatTemplateRenderer.GetOrCreateTokenizerKey(ctx, &preprocessing.GetOrCreateTokenizerKeyRequest{ IsLocal: false, Model: modelID, DownloadDir: config.TokenizersCacheDir, Token: config.HuggingFaceToken, - }); err != nil { + }) + if err != nil { return nil, fmt.Errorf("failed to load tokenizer with cache: %w", err) } @@ -307,6 +309,7 @@ func NewCachedHFTokenizer(ctx context.Context, modelID string, config *HFTokeniz tokenizer: tokenizer, tokenizerProvider: tokenizerProvider, chatTemplateRenderer: chatTemplateRenderer, + tokenizerCacheKey: tokenizerCacheKey, }, hfTokenizerConfig: config, }, nil @@ -345,11 +348,12 @@ func NewCachedLocalTokenizer(ctx context.Context, modelName string, config Local return nil, fmt.Errorf("tokenizer for model %q not found", modelName) } - if err := chatTemplater.LoadTokenizerWithCache(ctx, &preprocessing.LoadTokenizerWithCacheRequest{ + tokenizerCacheKey, err := chatTemplater.GetOrCreateTokenizerKey(ctx, &preprocessing.GetOrCreateTokenizerKeyRequest{ IsLocal: true, Model: path, - }); err != nil { - return nil, fmt.Errorf("failed to load tokenizer with cache: %w", err) + }) + if err != nil { + return nil, fmt.Errorf("failed to get or create tokenizer key with cache: %w", err) } return &LocalCachedTokenizer{ @@ -357,39 +361,18 @@ func NewCachedLocalTokenizer(ctx context.Context, modelName string, config Local tokenizer: tokenizer, tokenizerProvider: tokenizerProvider, chatTemplateRenderer: chatTemplater, + tokenizerCacheKey: tokenizerCacheKey, }, localTokenizerConfig: &config, }, nil } -func (t *LocalCachedTokenizer) ApplyChatTemplate( - modelName string, req *preprocessing.ApplyChatTemplateRequest, -) (string, error) { - ctx := context.TODO() - - req.LoadTokenizerWithCacheRequest.IsLocal = true - path, ok := t.localTokenizerConfig.ModelTokenizerMap[modelName] - if !ok { - return "", fmt.Errorf("tokenizer for model %q not found", modelName) - } - req.LoadTokenizerWithCacheRequest.Model = filepath.Dir(path) - res, err := t.chatTemplateRenderer.ApplyChatTemplate(ctx, req) - if err != nil { - return "", fmt.Errorf("failed to render chat template: %w", err) - } - - return res, nil -} - -func (t *HFCachedTokenizer) ApplyChatTemplate( - modelName string, req *preprocessing.ApplyChatTemplateRequest, +func (t *CachedTokenizer) ApplyChatTemplate( + _ string, req *preprocessing.ApplyChatTemplateRequest, ) (string, error) { ctx := context.TODO() - req.LoadTokenizerWithCacheRequest.IsLocal = false - req.LoadTokenizerWithCacheRequest.DownloadDir = t.hfTokenizerConfig.TokenizersCacheDir - req.LoadTokenizerWithCacheRequest.Token = t.hfTokenizerConfig.HuggingFaceToken - req.LoadTokenizerWithCacheRequest.Model = modelName + req.Key = t.tokenizerCacheKey res, err := t.chatTemplateRenderer.ApplyChatTemplate(ctx, req) if err != nil { return "", fmt.Errorf("failed to render chat template: %w", err) diff --git a/services/uds_tokenizer/server.py b/services/uds_tokenizer/server.py index 161ba101c..6649a8de1 100644 --- a/services/uds_tokenizer/server.py +++ b/services/uds_tokenizer/server.py @@ -12,24 +12,22 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Web server for tokenizer service.""" -import os +import asyncio import json import logging -import asyncio +import os import signal -from typing import Dict, Any + from aiohttp import web -from tokenizer_service.tokenizer import TokenizerService, TokenizerConfig +from tokenizer_service.tokenizer import TokenizerConfig, TokenizerService # Configure logging LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper() logging.basicConfig( level=getattr(logging, LOG_LEVEL), - format='%(asctime)s [%(levelname)s] [%(name)s] %(message)s' -) + format='%(asctime)s [%(levelname)s] [%(name)s] %(message)s') # Try to use uvloop for better performance try: @@ -61,17 +59,19 @@ def initialize_tokenizer(): try: # Parse ADD_SPECIAL_TOKENS environment variable add_special_tokens_env = os.getenv("ADD_SPECIAL_TOKENS") - if add_special_tokens_env is None or add_special_tokens_env.lower() == "none": + if add_special_tokens_env is None or add_special_tokens_env.lower( + ) == "none": add_special_tokens = None # Use tokenizer's default behavior else: add_special_tokens = add_special_tokens_env.lower() == "true" - + current_config = TokenizerConfig( model=os.getenv("MODEL", "Qwen/Qwen3-0.6B"), add_special_tokens=add_special_tokens, - enable_thinking=os.getenv("ENABLE_THINKING", "false").lower() == "true", - add_generation_prompt=os.getenv("ADD_GENERATION_PROMPT", "true").lower() == "true" - ) + enable_thinking=os.getenv("ENABLE_THINKING", + "false").lower() == "true", + add_generation_prompt=os.getenv("ADD_GENERATION_PROMPT", + "true").lower() == "true") tokenizer_service = TokenizerService(current_config) tokenizer_ready = True logging.info("Tokenizer initialized successfully") @@ -85,32 +85,39 @@ async def template_handler(request): logging.info("Handling chat template request") try: body = await request.read() - + try: messages = json.loads(body.decode('utf-8')) except UnicodeDecodeError as e: logging.error(f"Invalid UTF-8 encoding: {e}") return web.json_response( - {"status": "error", "message": f"Invalid UTF-8 encoding: {e}"}, - status=400 - ) + { + "status": "error", + "message": f"Invalid UTF-8 encoding: {e}" + }, + status=400) except json.JSONDecodeError as e: logging.error(f"Invalid JSON: {e}") return web.json_response( - {"status": "error", "message": f"Invalid JSON: {e}"}, - status=400 - ) + { + "status": "error", + "message": f"Invalid JSON: {e}" + }, + status=400) prompt = tokenizer_service.apply_template(messages) + prompt = prompt[0] logging.info(f"Generated prompt: {prompt[:100]}...") return web.Response(text=prompt, content_type='text/plain') - + except Exception as e: logging.error(f"Processing error: {e}", exc_info=True) return web.json_response( - {"status": "error", "message": f"Processing failed: {e}"}, - status=500 - ) + { + "status": "error", + "message": f"Processing failed: {e}" + }, + status=500) async def tokenize_handler(request): @@ -118,37 +125,42 @@ async def tokenize_handler(request): logging.info("Handling tokenize request") try: body = await request.read() - + prompt = body.decode('utf-8') logging.info(f"Prompt to tokenize: {prompt[:100]}...") - + loop = asyncio.get_running_loop() - batch_encoding = await loop.run_in_executor(None, tokenizer_service.tokenize_and_process, prompt) + batch_encoding = await loop.run_in_executor( + None, tokenizer_service.tokenize_and_process, prompt) serializable_data = { key: value.tolist() if hasattr(value, "tolist") else value for key, value in batch_encoding.items() } response = json.dumps(serializable_data) return web.Response(text=response, content_type='application/json') - + except Exception as e: logging.error(f"Processing error: {e}", exc_info=True) return web.json_response( - {"status": "error", "message": f"Processing failed: {e}"}, - status=500 - ) + { + "status": "error", + "message": f"Processing failed: {e}" + }, + status=500) async def health_handler(request): """Health check endpoint""" if not tokenizer_ready: - return web.json_response({ - "status": "unhealthy", - "service": "tokenizer-service", - "reason": "tokenizer not ready", - "timestamp": asyncio.get_event_loop().time() - }, status=503) - + return web.json_response( + { + "status": "unhealthy", + "service": "tokenizer-service", + "reason": "tokenizer not ready", + "timestamp": asyncio.get_event_loop().time() + }, + status=503) + return web.json_response({ "status": "healthy", "service": "tokenizer-service", @@ -173,16 +185,18 @@ async def update_config_handler(request): try: body = await request.read() new_config_data = json.loads(body.decode('utf-8')) - + updated_config = TokenizerConfig( model=new_config_data.get("model", current_config.model), - add_special_tokens=new_config_data.get("add_special_tokens", current_config.add_special_tokens), - enable_thinking=new_config_data.get("enable_thinking", current_config.enable_thinking), - add_generation_prompt=new_config_data.get("add_generation_prompt", current_config.add_generation_prompt) - ) - + add_special_tokens=new_config_data.get( + "add_special_tokens", current_config.add_special_tokens), + enable_thinking=new_config_data.get( + "enable_thinking", current_config.enable_thinking), + add_generation_prompt=new_config_data.get( + "add_generation_prompt", current_config.add_generation_prompt)) + tokenizer_ready = False - + # Reinitialize tokenizer service try: tokenizer_service = TokenizerService(updated_config) @@ -190,24 +204,34 @@ async def update_config_handler(request): tokenizer_ready = True logging.info(f"Configuration updated: {new_config_data}") return web.json_response({ - "status": "success", - "message": "Configuration updated successfully" + "status": + "success", + "message": + "Configuration updated successfully" }) except Exception as e: # If initialization fails, restore previous configuration - tokenizer_ready = True - logging.error(f"Failed to initialize tokenizer with new config: {e}", exc_info=True) - return web.json_response({ - "status": "error", - "message": f"Failed to initialize tokenizer with new config: {e}" - }, status=500) - + tokenizer_ready = True + logging.error( + f"Failed to initialize tokenizer with new config: {e}", + exc_info=True) + return web.json_response( + { + "status": + "error", + "message": + f"Failed to initialize tokenizer with new config: {e}" + }, + status=500) + except Exception as e: logging.error(f"Config update error: {e}", exc_info=True) - return web.json_response({ - "status": "error", - "message": f"Config update failed: {e}" - }, status=500) + return web.json_response( + { + "status": "error", + "message": f"Config update failed: {e}" + }, + status=500) def create_app(): @@ -237,23 +261,23 @@ async def cleanup(): """Clean up resources""" global server_runner, server_site, probe_runner, probe_site logging.info("Cleaning up resources") - + if probe_site: await probe_site.stop() logging.info("Probe site stopped") - + if probe_runner: await probe_runner.cleanup() logging.info("Probe runner cleaned up") - + if server_site: await server_site.stop() logging.info("Server site stopped") - + if server_runner: await server_runner.cleanup() logging.info("Server runner cleaned up") - + if os.path.exists(UDS_SOCKET_PATH): os.remove(UDS_SOCKET_PATH) logging.info(f"Socket file {UDS_SOCKET_PATH} removed") @@ -262,46 +286,48 @@ async def cleanup(): async def run_server(): """Run the server""" global server_runner, server_site, probe_runner, probe_site, shutdown_event - + # Initialize tokenizer try: initialize_tokenizer() except Exception as e: logging.error(f"Failed to initialize tokenizer, exiting: {e}") return - + # Remove old socket file if it exists if os.path.exists(UDS_SOCKET_PATH): os.remove(UDS_SOCKET_PATH) - + # Create dedicated directory and set permissions os.makedirs(os.path.dirname(UDS_SOCKET_PATH), mode=0o700, exist_ok=True) - + # Create main application (UDS) app = create_app() app.on_shutdown.append(shutdown_handler) - + server_runner = web.AppRunner(app) await server_runner.setup() server_site = web.UnixSite(server_runner, UDS_SOCKET_PATH) - + # Create probe application (TCP socket) probe_app = create_probe_app() probe_runner = web.AppRunner(probe_app) await probe_runner.setup() - + probe_site = web.TCPSite(probe_runner, "0.0.0.0", PROBE_PORT) - + # Set up signal handling shutdown_event = asyncio.Event() loop = asyncio.get_running_loop() + def signal_handler(): logging.info("Received signal, initiating shutdown...") shutdown_event.set() + loop.add_signal_handler(signal.SIGTERM, signal_handler) loop.add_signal_handler(signal.SIGINT, signal_handler) - + try: await server_site.start() await probe_site.start() @@ -319,38 +345,40 @@ def signal_handler(): async def create_app_for_gunicorn(): """Create application for Gunicorn""" global tokenizer_service - + # Use a lock file to synchronize tokenizer initialization across workers lock_file_path = "/tmp/tokenizer_init.lock" - + if tokenizer_service is None: import fcntl - + # Ensure lock file exists open(lock_file_path, 'a').close() - + # Open lock file lock_file = open(lock_file_path, 'r+') - + try: fcntl.flock(lock_file.fileno(), fcntl.LOCK_EX) logging.info("Acquired tokenizer initialization lock") - + if tokenizer_service is None: logging.info("Initializing tokenizer...") try: initialize_tokenizer() except Exception as e: - logging.error(f"Failed to initialize tokenizer in gunicorn mode: {e}") + logging.error( + f"Failed to initialize tokenizer in gunicorn mode: {e}" + ) raise else: logging.info("Tokenizer already initialized by another worker") - + finally: # Release the lock fcntl.flock(lock_file.fileno(), fcntl.LOCK_UN) lock_file.close() - + return create_app() @@ -358,4 +386,4 @@ async def create_app_for_gunicorn(): try: asyncio.run(run_server()) except KeyboardInterrupt: - pass \ No newline at end of file + pass diff --git a/tests/e2e/redis_mock/e2e_test.go b/tests/e2e/redis_mock/e2e_test.go index c833c63e5..3a775832f 100644 --- a/tests/e2e/redis_mock/e2e_test.go +++ b/tests/e2e/redis_mock/e2e_test.go @@ -797,6 +797,16 @@ func (s *KVCacheSuite) TestLocalTokenizerChatTemplateErrorHandling() { testModelDir, err := filepath.Abs(localTestModelDir) s.Require().NoError(err) + // Test 1: Non-existent model + _, err = tokenization.NewCachedLocalTokenizer(context.Background(), modelName, tokenization.LocalTokenizerConfig{ + ModelTokenizerMap: map[string]string{ + modelName: "non-existent-model", + }, + }) + s.Require().Error(err, "Should return error for non-existent model") + s.T().Logf("Expected error for non-existent model: %v", err) + + // Test 2: Empty conversation localTokenizer, err := tokenization.NewCachedLocalTokenizer(context.Background(), modelName, tokenization.LocalTokenizerConfig{ ModelTokenizerMap: map[string]string{ modelName: filepath.Join(testModelDir, "tokenizer.json"), @@ -806,19 +816,6 @@ func (s *KVCacheSuite) TestLocalTokenizerChatTemplateErrorHandling() { s.SetTokenizer(localTokenizer, modelName) - conversation := []ChatMessage{ - {Role: "user", Content: "Test"}, - } - - // Test 1: Non-existent model - reqNonExistent := &preprocessing.ApplyChatTemplateRequest{ - Conversation: convertToPreprocessingConversation(conversation), - } - _, err = localTokenizer.ApplyChatTemplate("non-existent-model", reqNonExistent) - s.Require().Error(err, "Should return error for non-existent model") - s.T().Logf("Expected error for non-existent model: %v", err) - - // Test 2: Empty conversation emptyConversation := []ChatMessage{} reqEmpty := &preprocessing.ApplyChatTemplateRequest{ Conversation: convertToPreprocessingConversation(emptyConversation), From e7b315c1ee3c0f6fb2a2a0b684b0f74e8f24e3ba Mon Sep 17 00:00:00 2001 From: HyunKyun Moon Date: Mon, 5 Jan 2026 09:31:39 +0000 Subject: [PATCH 07/12] edit data Signed-off-by: HyunKyun Moon --- examples/testdata/data.go | 8 ++++---- hack/verify-examples.sh | 0 2 files changed, 4 insertions(+), 4 deletions(-) mode change 100644 => 100755 hack/verify-examples.sh diff --git a/examples/testdata/data.go b/examples/testdata/data.go index bbc0faac8..bda51ff45 100644 --- a/examples/testdata/data.go +++ b/examples/testdata/data.go @@ -30,8 +30,8 @@ var RenderReq *preprocessing.ApplyChatTemplateRequest = nil var Prompt string var PromptHashes = []uint64{ - 5883650188907136581, - 4344014219501030587, - 8576040316208967329, - 13369611429964591057, + 3246512376769953277, + 2932514196368075983, + 6384763183060574933, + 13975137892230421288, } diff --git a/hack/verify-examples.sh b/hack/verify-examples.sh old mode 100644 new mode 100755 From c4161aa9904ff78031a410fc7e4a5cd49e6c5923 Mon Sep 17 00:00:00 2001 From: HyunKyun Moon Date: Tue, 6 Jan 2026 13:05:52 +0000 Subject: [PATCH 08/12] resolve CStrings memory leaks Signed-off-by: HyunKyun Moon --- pkg/preprocessing/chat_completions/cgo_functions.go | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/pkg/preprocessing/chat_completions/cgo_functions.go b/pkg/preprocessing/chat_completions/cgo_functions.go index 30bef3f0f..02a738ce3 100644 --- a/pkg/preprocessing/chat_completions/cgo_functions.go +++ b/pkg/preprocessing/chat_completions/cgo_functions.go @@ -127,7 +127,9 @@ func (w *ChatTemplatingProcessor) GetOrCreateTokenizerKey( return "", fmt.Errorf("failed to marshal request: %w", err) } // Call the cached Python function - cResult := C.Py_CallGetOrCreateTokenizerKey(C.CString(string(reqJSON))) + cJSONString := C.CString(string(reqJSON)) + defer C.free(unsafe.Pointer(cJSONString)) + cResult := C.Py_CallGetOrCreateTokenizerKey(cJSONString) if cResult == nil { traceLogger.Error(nil, "C function returned nil") return "", fmt.Errorf("python get_or_create_tokenizer_key failed") @@ -155,7 +157,9 @@ func (w *ChatTemplatingProcessor) ApplyChatTemplate(ctx context.Context, return "", fmt.Errorf("failed to marshal request: %w", err) } // Call the cached Python function - cResult := C.Py_CallApplyChatTemplate(C.CString(string(reqJSON))) + cJSONString := C.CString(string(reqJSON)) + defer C.free(unsafe.Pointer(cJSONString)) + cResult := C.Py_CallApplyChatTemplate(cJSONString) if cResult == nil { traceLogger.Error(nil, "C function returned nil") return "", fmt.Errorf("python apply_chat_template failed") From 7481fc6a5d1eb74e497c5d3555861144d27a52d0 Mon Sep 17 00:00:00 2001 From: HyunKyun Moon Date: Thu, 8 Jan 2026 13:42:18 +0000 Subject: [PATCH 09/12] apply review Signed-off-by: HyunKyun Moon --- .dockerignore | 4 +- .gitignore | 4 +- Dockerfile | 34 +++---- Makefile | 29 ++++-- .../chat_completions/requirements.txt | 4 - pkg/preprocessing/chat_completions/setup.sh | 96 +++++++++++++++++++ .../chat_completions/tokenizer_wrapper.py | 1 - 7 files changed, 141 insertions(+), 31 deletions(-) delete mode 100644 pkg/preprocessing/chat_completions/requirements.txt create mode 100755 pkg/preprocessing/chat_completions/setup.sh diff --git a/.dockerignore b/.dockerignore index 4eb2972cb..b781588e7 100644 --- a/.dockerignore +++ b/.dockerignore @@ -18,4 +18,6 @@ venv __pycache__ # Docker files -Dockerfile \ No newline at end of file +Dockerfile + +**/vllm_source \ No newline at end of file diff --git a/.gitignore b/.gitignore index f1a3a0a2c..d36f3bc47 100644 --- a/.gitignore +++ b/.gitignore @@ -69,4 +69,6 @@ _cgo_* /lib # uds tokenizer default model path -services/uds_tokenizer/models \ No newline at end of file +services/uds_tokenizer/models + +**/vllm_source \ No newline at end of file diff --git a/Dockerfile b/Dockerfile index afe3eb0b1..5b4f6bf5f 100644 --- a/Dockerfile +++ b/Dockerfile @@ -12,6 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. +FROM python:3.12-slim AS python-builder + +WORKDIR /workspace + +RUN apt-get update && apt-get install -y --no-install-recommends build-essential + +COPY Makefile Makefile +COPY pkg/preprocessing/chat_completions/ pkg/preprocessing/chat_completions/ +RUN make install-python-deps + # Build Stage: using Go 1.24.1 image FROM quay.io/projectquay/golang:1.24 AS builder ARG TARGETOS @@ -35,14 +45,7 @@ COPY go.sum go.sum # and so that source changes don't invalidate our downloaded layer RUN go mod download -# Copy only the requirements file. -COPY pkg/preprocessing/chat_completions/requirements.txt ./requirements.txt -# Install Python dependencies. This layer will be cached unless requirements.txt changes. -RUN python3.12 -m pip install --upgrade pip setuptools wheel && \ - python3.12 -m pip install -r ./requirements.txt - -# Copy the go source -COPY examples/kv_events examples/kv_events +# Copy the source code. COPY . . # HuggingFace tokenizer bindings @@ -51,6 +54,10 @@ ARG RELEASE_VERSION=v1.22.1 RUN curl -L https://github.com/daulet/tokenizers/releases/download/${RELEASE_VERSION}/libtokenizers.${TARGETOS}-${TARGETARCH}.tar.gz | tar -xz -C lib RUN ranlib lib/*.a +# Copy this project's own Python source code into the final image +COPY --from=python-builder /workspace/pkg/preprocessing/chat_completions /workspace/pkg/preprocessing/chat_completions +RUN make setup-venv +COPY --from=python-builder /workspace/build/venv/lib/python3.12/site-packages /workspace/build/venv/lib/python3.12/site-packages RUN make build # Use distroless as minimal base image to package the manager binary @@ -64,16 +71,9 @@ RUN dnf install -y 'https://dl.fedoraproject.org/pub/epel/epel-release-latest-9. dnf install -y zeromq libxcrypt-compat python3.12 python3.12-pip && \ dnf clean all - - -# Install Python dependencies in the final image. -COPY --from=builder /workspace/requirements.txt /tmp/requirements.txt -RUN python3.12 -m pip install --upgrade pip setuptools wheel && \ - python3.12 -m pip install --no-cache-dir -r /tmp/requirements.txt \ - && rm -rf /tmp/requirements.txt - # Copy this project's own Python source code into the final image -COPY --from=builder /workspace/pkg/preprocessing/chat_completions /app/pkg/preprocessing/chat_completions +COPY --from=python-builder /workspace/pkg/preprocessing/chat_completions /app/pkg/preprocessing/chat_completions +COPY --from=python-builder /usr/local/lib/python3.12/site-packages /usr/local/lib/python3.12/site-packages # Set the PYTHONPATH. This mirrors the Makefile's export, ensuring both this project's # Python code and the installed libraries (site-packages) are found at runtime. diff --git a/Makefile b/Makefile index 9b9a69550..10bacca9a 100644 --- a/Makefile +++ b/Makefile @@ -112,8 +112,8 @@ detect-python: ## Detects Python and prints the configuration. fi @printf "\033[33;1m==============================\033[0m\n" -.PHONY: install-python-deps -install-python-deps: detect-python ## Sets up the Python virtual environment and installs dependencies. +.PHONY: setup-venv +setup-venv: detect-python ## Sets up the Python virtual environment. @printf "\033[33;1m==== Setting up Python virtual environment in $(VENV_DIR) ====\033[0m\n" @if [ ! -f "$(VENV_BIN)/pip" ]; then \ echo "Creating virtual environment..."; \ @@ -124,12 +124,27 @@ install-python-deps: detect-python ## Sets up the Python virtual environment and exit 1; \ }; \ fi - @echo "Upgrading pip and installing dependencies..." + @echo "Upgrading pip..." @$(VENV_BIN)/pip install --upgrade pip - @$(VENV_BIN)/pip install -q -r pkg/preprocessing/chat_completions/requirements.txt - @echo "Verifying transformers installation..." - @$(VENV_BIN)/python -c "import transformers; print('✅ Transformers version ' + transformers.__version__ + ' installed.')" || { \ - echo "ERROR: transformers library not properly installed in venv."; \ + @echo "Python virtual environment setup complete." + +.PHONY: setup-venv +install-python-deps: setup-venv ## installs dependencies. + @printf "\033[33;1m==== Setting up Python virtual environment in $(VENV_DIR) ====\033[0m\n" + @if [ ! -f "$(VENV_BIN)/pip" ]; then \ + echo "Creating virtual environment..."; \ + $(PYTHON_EXE) -m venv $(VENV_DIR) || { \ + echo "ERROR: Failed to create virtual environment."; \ + echo "Your Python installation may be missing the 'venv' module."; \ + echo "Try: 'sudo apt install python$(PYTHON_VERSION)-venv' or 'sudo dnf install python$(PYTHON_VERSION)-devel'"; \ + exit 1; \ + }; \ + fi + @echo "Upgrading pip and installing dependencies..." + @PATH=$(VENV_BIN):$$PATH ./pkg/preprocessing/chat_completions/setup.sh + @echo "Verifying vllm installation..." + @$(VENV_BIN)/python -c "import vllm; print('✅ vllm version ' + vllm.__version__ + ' installed.')" || { \ + echo "ERROR: vllm library not properly installed in venv."; \ exit 1; \ } diff --git a/pkg/preprocessing/chat_completions/requirements.txt b/pkg/preprocessing/chat_completions/requirements.txt deleted file mode 100644 index 26c952f34..000000000 --- a/pkg/preprocessing/chat_completions/requirements.txt +++ /dev/null @@ -1,4 +0,0 @@ ---index-url https://download.pytorch.org/whl/cpu ---extra-index-url https://pypi.org/simple -vllm-cpu>=0.11.0; sys_platform != 'darwin' -vllm @ git+https://github.com/vllm-project/vllm.git@v0.11.0; sys_platform == 'darwin' \ No newline at end of file diff --git a/pkg/preprocessing/chat_completions/setup.sh b/pkg/preprocessing/chat_completions/setup.sh new file mode 100755 index 000000000..9a42f792e --- /dev/null +++ b/pkg/preprocessing/chat_completions/setup.sh @@ -0,0 +1,96 @@ +#!/bin/bash +set -e + +# 1. Skip if vllm is already installed +PYTHON_BIN=$(which python3 || which python) +if $PYTHON_BIN -c "import vllm" &> /dev/null; then + echo "[SKIP] vllm is already installed. Exiting." + exit 0 +fi + +# 2. Architecture check (Only Intel/AMD x86, ARM AArch64, Apple Silicon supported) +ARCH=$(uname -m) +OS=$(uname) +VLLM_REPO=https://github.com/vllm-project/vllm.git +VLLM_TAG=v0.11.1 + +if [[ "$ARCH" == "x86_64" ]]; then + ARCH_TYPE="x86_64" +elif [[ "$ARCH" == "aarch64" ]]; then + ARCH_TYPE="aarch64" +elif [[ "$ARCH" == "arm64" && "$OS" == "Darwin" ]]; then + ARCH_TYPE="apple_silicon" +else + echo "[ERROR] Only Intel/AMD x86_64, ARM AArch64 (aarch64), and Apple Silicon (arm64, macOS) are supported." + exit 1 +fi + +# 3. Check and install Python requirements (runtime) +REQUIRED_PKGS=(cmake wheel packaging ninja setuptools-scm numpy) +TO_INSTALL=() +for pkg in "${REQUIRED_PKGS[@]}"; do + # Try pip show, then fallback to checking if the binary exists in PATH + if ! $PYTHON_BIN -m pip show "$pkg" &> /dev/null; then + # Some packages like cmake, ninja may be installed as binaries + if ! command -v "$pkg" &> /dev/null; then + TO_INSTALL+=("$pkg") + fi + fi +done +$PYTHON_BIN -m pip install --upgrade pip +if [[ ${#TO_INSTALL[@]} -gt 0 ]]; then + $PYTHON_BIN -m pip install "cmake>=3.26" wheel packaging ninja "setuptools-scm>=8" numpy +else + echo "[SKIP] python runtime packages already installed." +fi + +# 4. Check and install build dependencies (system packages) per architecture +if [[ "$ARCH_TYPE" == "x86_64" || "$ARCH_TYPE" == "aarch64" ]]; then + SYS_PKGS=(git gcc-12 g++-12 libnuma-dev) + INSTALL_SYS_PKGS=() + for pkg in "${SYS_PKGS[@]}"; do + if ! dpkg -s "$pkg" &> /dev/null; then + INSTALL_SYS_PKGS+=("$pkg") + fi + done + if [[ ${#INSTALL_SYS_PKGS[@]} -gt 0 ]]; then + if command -v apt-get &> /dev/null; then + apt-get update + apt-get install -y "${INSTALL_SYS_PKGS[@]}" + elif command -v dnf &> /dev/null; then + dnf install -y "${INSTALL_SYS_PKGS[@]}" + elif command -v yum &> /dev/null; then + yum install -y "${INSTALL_SYS_PKGS[@]}" + else + echo "[ERROR] No supported package manager found (apt-get, dnf, yum). Please install build dependencies manually: ${SYS_PKGS[*]}" + exit 1 + fi + else + echo "[SKIP] gcc-12, g++-12, libnuma-dev already installed." + fi + # Ensure gcc-12 is set as the default gcc (Debian/Ubuntu only) + if command -v update-alternatives &> /dev/null && ! gcc --version | grep -q 'gcc-12'; then + update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave /usr/bin/g++ g++ /usr/bin/g++-12 + fi +fi + +# 5. Clone vllm source and install requirements/cpu.txt (common) +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +VLLM_SRC_DIR="$SCRIPT_DIR/vllm_source" +if [ ! -d "$VLLM_SRC_DIR" ]; then + git clone $VLLM_REPO "$VLLM_SRC_DIR" +fi +cd "$VLLM_SRC_DIR" +git fetch --tags +git checkout tags/$VLLM_TAG + +$PYTHON_BIN -m pip install -v -r requirements/cpu.txt --extra-index-url https://download.pytorch.org/whl/cpu + +# 6. Build wheel from source (actual build) +if [[ "$ARCH_TYPE" == "x86_64" || "$ARCH_TYPE" == "aarch64" ]]; then + VLLM_TARGET_DEVICE=cpu $PYTHON_BIN setup.py install +elif [[ "$ARCH_TYPE" == "apple_silicon" ]]; then + $PYTHON_BIN -m pip install -e . +fi + +echo "vLLM CPU build and installation completed." \ No newline at end of file diff --git a/pkg/preprocessing/chat_completions/tokenizer_wrapper.py b/pkg/preprocessing/chat_completions/tokenizer_wrapper.py index 28577882c..47dba6df0 100644 --- a/pkg/preprocessing/chat_completions/tokenizer_wrapper.py +++ b/pkg/preprocessing/chat_completions/tokenizer_wrapper.py @@ -61,7 +61,6 @@ def apply_chat_template(request_json): # Parse the JSON request request = json.loads(request_json) key = request.pop("key") - print("mhg", key, flush=True) tokenizer = _tokenizer_cache.get(key) if tokenizer is None: raise RuntimeError(f"Tokenizer with key {key} not found in cache") From 3c23b593992b303080c9163a0b519fbbe073e34f Mon Sep 17 00:00:00 2001 From: HyunKyun Moon Date: Thu, 8 Jan 2026 13:49:39 +0000 Subject: [PATCH 10/12] edit Signed-off-by: HyunKyun Moon --- .github/workflows/ci-examples.yaml | 2 +- .github/workflows/ci-pr-checks.yaml | 2 +- pkg/preprocessing/chat_completions/setup.sh | 2 ++ pkg/tokenization/pool.go | 2 ++ 4 files changed, 6 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci-examples.yaml b/.github/workflows/ci-examples.yaml index 981a878c5..3d823de57 100644 --- a/.github/workflows/ci-examples.yaml +++ b/.github/workflows/ci-examples.yaml @@ -50,4 +50,4 @@ jobs: run: chmod +x hack/verify-examples.sh - name: Run verify-examples.sh - run: ./hack/verify-examples.sh + run: sudo ./hack/verify-examples.sh diff --git a/.github/workflows/ci-pr-checks.yaml b/.github/workflows/ci-pr-checks.yaml index cdf3b515b..07682c893 100644 --- a/.github/workflows/ci-pr-checks.yaml +++ b/.github/workflows/ci-pr-checks.yaml @@ -75,7 +75,7 @@ jobs: - name: Run make build shell: bash run: | - make build + sudo make build - name: Run make test shell: bash diff --git a/pkg/preprocessing/chat_completions/setup.sh b/pkg/preprocessing/chat_completions/setup.sh index 9a42f792e..784e0099a 100755 --- a/pkg/preprocessing/chat_completions/setup.sh +++ b/pkg/preprocessing/chat_completions/setup.sh @@ -1,4 +1,6 @@ #!/bin/bash +# https://docs.vllm.ai/en/v0.8.4/getting_started/installation/cpu.html + set -e # 1. Skip if vllm is already installed diff --git a/pkg/tokenization/pool.go b/pkg/tokenization/pool.go index e6ebe2601..c07635ae7 100644 --- a/pkg/tokenization/pool.go +++ b/pkg/tokenization/pool.go @@ -217,6 +217,7 @@ func (pool *Pool) workerLoop(_ int) { // processTask tokenizes the prompt and updates the indexer. // It sends exactly one response (success or error) if ResultCh is provided. func (pool *Pool) processTask(task Task) error { + // https://github.com/vllm-project/vllm/blob/v0.11.2/vllm/entrypoints/openai/protocol.py#L1127 addSpecialToken := true if task.RenderReq != nil { var err error @@ -225,6 +226,7 @@ func (pool *Pool) processTask(task Task) error { log.Log.Error(err, "failed to render chat template") return err } + // https://github.com/vllm-project/vllm/blob/v0.11.2/vllm/entrypoints/openai/protocol.py#L613 addSpecialToken = false } From fff3be79c6074fdac1157944806cca402a2e672c Mon Sep 17 00:00:00 2001 From: HyunKyun Moon Date: Thu, 8 Jan 2026 13:58:41 +0000 Subject: [PATCH 11/12] edit Signed-off-by: HyunKyun Moon --- .github/workflows/ci-pr-checks.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci-pr-checks.yaml b/.github/workflows/ci-pr-checks.yaml index 07682c893..dc473b2ca 100644 --- a/.github/workflows/ci-pr-checks.yaml +++ b/.github/workflows/ci-pr-checks.yaml @@ -80,4 +80,4 @@ jobs: - name: Run make test shell: bash run: | - make test + sudo make test From 1e462fb6e47a921109d39ef64bea3a40896167c1 Mon Sep 17 00:00:00 2001 From: HyunKyun Moon Date: Thu, 8 Jan 2026 15:12:49 +0000 Subject: [PATCH 12/12] add path Signed-off-by: HyunKyun Moon --- .github/workflows/ci-pr-checks.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci-pr-checks.yaml b/.github/workflows/ci-pr-checks.yaml index dc473b2ca..288c31e81 100644 --- a/.github/workflows/ci-pr-checks.yaml +++ b/.github/workflows/ci-pr-checks.yaml @@ -80,4 +80,4 @@ jobs: - name: Run make test shell: bash run: | - sudo make test + sudo PATH="/root/.local/bin:$PATH" make test