Refactor: Replace daulet/tokenizers with vLLM #221

Closed
hyeongyun0916 wants to merge 5 commits into llm-d:main from moreh-dev:vllm
Conversation

@hyeongyun0916
Collaborator

This PR refactors the tokenization system to use vLLM's tokenizer wrapper instead of the daulet/tokenizers library. The changes streamline the tokenization pipeline by removing the external tokenizer dependency and consolidating chat template rendering and encoding through vLLM's unified API.

https://llm-d.slack.com/archives/C0A0SU5J68Y/p1764153758005369

Contributor

Copilot AI left a comment

Pull request overview

This PR refactors the tokenization system by replacing the daulet/tokenizers library with vLLM's tokenizer wrapper. The changes unify chat template rendering and token encoding through vLLM's API, eliminating the need for separate tokenizer downloads and simplifying the build process.

Key Changes:

  • Removed daulet/tokenizers dependency and replaced with vLLM (>=0.11.0) via Python CGO bindings
  • Refactored tokenizer interface: RenderChatTemplate → ApplyChatTemplate, updated Encode signature
  • Consolidated Python wrappers: render_jinja_template_wrapper.py → tokenizer_wrapper.py using vLLM's get_tokenizer
  • Updated build configuration to remove tokenizer binary downloads from Makefile and Dockerfile

Reviewed changes

Copilot reviewed 27 out of 29 changed files in this pull request and generated 8 comments.

Show a summary per file
| File | Description |
| --- | --- |
| go.mod, go.sum | Removed daulet/tokenizers dependency, added new transitive deps from controller-runtime |
| pkg/preprocessing/chat_completions/requirements.txt | Replaced transformers/torch deps with vllm>=0.11.0 |
| pkg/tokenization/tokenizer.go | Refactored HF/Local tokenizers to use vLLM via CGO; removed provider abstractions |
| pkg/preprocessing/chat_completions/tokenizer_wrapper.py | New Python wrapper using vLLM's get_tokenizer for unified tokenization |
| pkg/preprocessing/chat_completions/cgo_functions.* | Updated CGO bindings: renamed functions from render_jinja_template/get_model_chat_template to apply_chat_template/encode |
| pkg/tokenization/pool.go | Updated to use new ApplyChatTemplate method and EncodeRequest struct |
| tests/e2e/redis_mock/e2e_test.go | Updated test setup to use composite tokenizer; removed SetTokenizer calls |
| Makefile, Dockerfile | Removed tokenizer binary download steps and simplified CGO flags |

Comment thread pkg/tokenization/tokenizer.go Outdated
Comment thread pkg/tokenization/tokenizer.go
Comment thread pkg/preprocessing/chat_completions/tokenizer_wrapper.py
Comment thread pkg/preprocessing/chat_completions/tokenizer_wrapper.py Outdated
Comment thread pkg/tokenization/pool.go
Comment thread pkg/preprocessing/chat_completions/cgo_functions.go Outdated
Comment thread pkg/tokenization/tokenizer_test.go
Comment thread pkg/preprocessing/chat_completions/tokenizer_wrapper.py Outdated
@vMaroon
Member

vMaroon commented Dec 18, 2025

Thanks @hyeongyun0916 - will review ASAP.

On the CI issue, could you try adding a cleanup step? https://github.com/marketplace/actions/free-disk-space-ubuntu
Feel free to defer.

hyeongyun0916 and others added 3 commits December 19, 2025 06:45
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Signed-off-by: Hyunkyun Moon <mhg5303@gmail.com>
Comment thread pkg/tokenization/pool.go
log.Log.Error(err, "failed to render chat template", "modelName", task.ModelName)
return err
}
addSpecialTokens = false
Collaborator Author

@hyeongyun0916
Collaborator Author

hyeongyun0916 commented Dec 19, 2025

> On the CI issue, could you try adding a cleanup step? https://github.com/marketplace/actions/free-disk-space-ubuntu
> Feel free to defer.

Fixed the CI failure by replacing the vllm package with the CPU-only version to save space.

tokenizer = get_tokenizer(model_name, trust_remote_code=True, revision=revision, download_dir=download_dir)
_tokenizer_cache[cache_key] = tokenizer

request["tokenize"] = False
Collaborator Author

For now, I've set tokenize to False to support completions instead of chat completions. It's worth considering unifying these into a single process later on.

Member

I haven't reviewed this in depth yet, but the fix that removes BOS duplicates would make it work for chat, right?

Collaborator Author

@hyeongyun0916 hyeongyun0916 Dec 20, 2025

Regarding the BOS duplicates, that part is specifically related to this discussion: #221 (comment)

As for the request["tokenize"] = False setting:

  • Currently, Completion only performs encode.
  • ChatCompletion performs applyChatTemplate followed by encode.

To ensure both follow the same encoding logic, I’ve configured it so that in this stage, we only perform applyChatTemplate without the actual encoding. This allows the subsequent encoding step to be handled consistently.
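
The two paths described above can be sketched as follows. This is a minimal illustration under stated assumptions, not the PR's code: `ToyTokenizer`, its whitespace vocabulary, the `<s>` BOS marker, and the helper names `tokenize_completion` and `tokenize_chat` are hypothetical. Only the `tokenize` and `add_special_tokens` parameter names follow the Hugging Face-style interface discussed in this thread.

```python
from typing import Dict, List

BOS = "<s>"  # hypothetical beginning-of-sequence marker


class ToyTokenizer:
    """Hypothetical stand-in for a real tokenizer; splits on whitespace."""

    def __init__(self) -> None:
        self.vocab: Dict[str, int] = {BOS: 0}

    def _id(self, piece: str) -> int:
        # Assign ids on first sight so the sketch needs no vocabulary file.
        return self.vocab.setdefault(piece, len(self.vocab))

    def apply_chat_template(self, messages: List[Dict[str, str]], tokenize: bool = False) -> str:
        # Render only. The template text itself already contains the BOS marker.
        assert not tokenize, "this sketch only renders text"
        turns = " ".join(f"[{m['role']}] {m['content']}" for m in messages)
        return f"{BOS} {turns}"

    def encode(self, text: str, add_special_tokens: bool = True) -> List[int]:
        ids = [self._id(piece) for piece in text.split()]
        # With add_special_tokens=True the encoder prepends its own BOS id.
        return [self._id(BOS)] + ids if add_special_tokens else ids


def tokenize_completion(tok: ToyTokenizer, prompt: str) -> List[int]:
    # Completion path: encode only.
    return tok.encode(prompt, add_special_tokens=True)


def tokenize_chat(tok: ToyTokenizer, messages: List[Dict[str, str]]) -> List[int]:
    # Chat path: render without tokenizing, then reuse the same encode step.
    rendered = tok.apply_chat_template(messages, tokenize=False)
    # The rendered text already carries BOS, so do not add it a second time.
    return tok.encode(rendered, add_special_tokens=False)
```

Encoding the rendered chat text with `add_special_tokens=True` would yield two leading BOS ids in this toy setup, which is consistent with the `addSpecialTokens = false` line in the pool.go excerpt earlier in the conversation.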

Collaborator

@sagearc sagearc left a comment

Thanks for the PR! Note that the llmd kv cache now supports tokenizer initialization at startup for a single model. Unless we intend to support multiple tokenizer revisions simultaneously, I think we can simplify this by using a single tokenizer instance, ideally initialized at startup to catch unhealthy system or tokenizer configurations early.

Also, what do you think about splitting this PR into multiple, more focused ones? For example:

  • Add support for vLLM tokenizers (without removing daulet tokenizers yet)
  • Switch chat templating to use vLLM instead of the transformers library
  • Remove deprecated tokenizers (if needed)
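
A single tokenizer initialized at startup, as suggested above, might look like the sketch below. It is only an illustration: `TokenizerService` and the injected `load_tokenizer` callable are hypothetical names, and the loader is passed in so the example does not depend on vLLM or on any particular model.

```python
from typing import Callable, List, Protocol


class Tokenizer(Protocol):
    def encode(self, text: str) -> List[int]: ...


class TokenizerService:
    """Holds exactly one tokenizer, loaded eagerly at construction time."""

    def __init__(self, model_name: str, load_tokenizer: Callable[[str], Tokenizer]) -> None:
        self.model_name = model_name
        # Load now, so a bad model name or missing files fail at startup
        # instead of on the first request.
        self._tokenizer = load_tokenizer(model_name)
        # Smoke-test the instance before the service accepts traffic.
        if not isinstance(self._tokenizer.encode("health check"), list):
            raise RuntimeError(f"tokenizer for {model_name!r} returned an unexpected result")

    def encode(self, text: str) -> List[int]:
        return self._tokenizer.encode(text)
```

Because loading happens in the constructor, any exception raised by the loader surfaces when the process starts, which is the early-failure behavior the comment asks for.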

Comment on lines +137 to +144
lock = _get_tokenizer_cache_lock()
with lock:
    cache_key = f"{model_name}:{revision or 'main'}:{is_local}"
    tokenizer = _tokenizer_cache.get(cache_key)
    if tokenizer is None:
        os.environ["HF_TOKEN"] = token
        tokenizer = get_tokenizer(model_name, trust_remote_code=True, revision=revision, download_dir=download_dir)
        _tokenizer_cache[cache_key] = tokenizer
Collaborator

Do we need to acquire the cache lock on every access, even after the tokenizer is initialized and in memory? If tokenizer.encode / tokenizer(...) is thread-safe here, as in the daulet tokenizers that rely on stateless HF tokenizers, we could likely restrict locking to the initialization path only and avoid holding the lock during encode calls.
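
Restricting the lock to the initialization path, as asked above, is commonly done with a double-checked pattern. The sketch below is an assumption-laden illustration rather than the wrapper's code: `InitOnlyLockedCache` and `get_or_create` are hypothetical names, the lock-free read relies on CPython's dictionary reads and writes being atomic under the GIL, and it presumes the cached object is itself safe to use from several threads, which this thread has not yet confirmed for the tokenizer.

```python
import threading
from typing import Callable, Dict, Hashable


class InitOnlyLockedCache:
    """Cache that takes its lock only while an entry is being created."""

    def __init__(self) -> None:
        self._items: Dict[Hashable, object] = {}
        self._lock = threading.Lock()

    def get_or_create(self, key: Hashable, factory: Callable[[], object]) -> object:
        # Fast path: no lock once the entry exists.
        item = self._items.get(key)
        if item is None:
            with self._lock:
                # Re-check under the lock; another thread may have won the race.
                item = self._items.get(key)
                if item is None:
                    item = factory()
                    self._items[key] = item
        return item
```

Callers would fetch the tokenizer through `get_or_create` and then call encode outside any lock. Note that a factory returning `None` would be re-run on every call in this sketch.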

Collaborator Author

@hyeongyun0916 hyeongyun0916 Dec 22, 2025

I agree with your point regarding the lock. I initially kept the cache lock as it was in the existing implementation to avoid any sudden structural changes.

Since I'm planning to split this PR as you suggested, would it be better to remove the lock for the encode and chatTemplate in the new PR? I'd like to hear your thoughts on this before I start refactoring.

Collaborator

I believe this should work. I think in most cases vllm relies on the underlying transformers tokenizer, not sure about any side effects.

Are you sure it’s thread-safe? It would probably be best to run a benchmark before and after the change to confirm there’s no regression.

Collaborator Author

Regarding thread-safety, get_tokenizer in vLLM already utilizes get_lock internally to prevent race conditions during the download and initialization phase. For the subsequent calls, I believe the underlying tokenizer is stateless, but as you suggested, I will run a concurrency benchmark to confirm there are no regressions or race conditions when the lock is removed.
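
A before/after concurrency check of the kind proposed here could be structured roughly as below. Everything in it is a placeholder: `toy_encode` is a hypothetical stateless encoder standing in for the real tokenizer call, and `run_benchmark` and `with_lock` are illustrative helper names. The harness compares threaded results against a single-threaded reference, so it checks correctness as well as elapsed time.

```python
import threading
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, List, Tuple

Encode = Callable[[str], List[int]]


def toy_encode(text: str) -> List[int]:
    # Hypothetical stateless encoder standing in for tokenizer.encode.
    return [len(word) for word in text.split()]


def with_lock(encode: Encode) -> Encode:
    # Wrap an encoder so that every call holds one shared lock.
    lock = threading.Lock()

    def locked(text: str) -> List[int]:
        with lock:
            return encode(text)

    return locked


def run_benchmark(encode: Encode, prompts: List[str], workers: int = 8) -> Tuple[float, List[List[int]]]:
    # Encode all prompts on a thread pool; return elapsed seconds and results.
    start = time.perf_counter()
    with ThreadPoolExecutor(max_workers=workers) as pool:
        results = list(pool.map(encode, prompts))  # map preserves input order
    return time.perf_counter() - start, results


if __name__ == "__main__":
    prompts = [f"prompt number {i} " * 50 for i in range(2000)]
    expected = [toy_encode(p) for p in prompts]
    locked_s, locked_out = run_benchmark(with_lock(toy_encode), prompts)
    free_s, free_out = run_benchmark(toy_encode, prompts)
    assert locked_out == expected and free_out == expected
    print(f"locked: {locked_s:.3f}s  lock-free: {free_s:.3f}s")
```

With a pure-Python encoder like this one, the GIL serializes the work either way, so the timings say little. The comparison only becomes meaningful once `toy_encode` is swapped for the real tokenizer call.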

@hyeongyun0916
Collaborator Author

hyeongyun0916 commented Dec 22, 2025

> Add support for vLLM tokenizers (without removing daulet tokenizers yet)

Regarding the first step, do you mean migrating only the encode functionality to vLLM first, without removing daulet yet?

@hyeongyun0916
Collaborator Author

> Note that the llmd kv cache now supports tokenizer initialization at startup for a single model

I noticed that in the existing code, daulet seems to be initialized as a single tokenizer, but wrapper.py doesn't appear to follow that same pattern. That is why I implemented it this way for now.

As I refactor and split the PR, I have a few questions to align with your direction:

  1. Would it be better to update wrapper.py to use a single tokenizer instance initialized at startup as well?
  2. If we stick to a single model instance, should the system return an error if a request for a different model is received?
  3. Should we also restrict processTokenizerFile to initialize only for that specific model?

I would appreciate your guidance on these points so I can incorporate them into the split PRs.

@sagearc
Collaborator

sagearc commented Dec 22, 2025

> Regarding the first step, do you mean migrating only the encode functionality to vLLM first, without removing daulet yet?

Yeah, maybe even without any support for chat template rendering in that tokenizer (e.g., returning an “unimplemented” error), or by using the existing Transformers chat template instead. Does that make sense?

@sagearc
Collaborator

sagearc commented Dec 22, 2025

> I noticed that in the existing code, daulet seems to be initialized as a single tokenizer, but wrapper.py doesn't appear to follow that same pattern. That is why I implemented it this way for now.
>
> As I refactor and split the PR, I have a few questions to align with your direction:
>
>   1. Would it be better to update wrapper.py to use a single tokenizer instance initialized at startup as well?
>   2. If we stick to a single model instance, should the system return an error if a request for a different model is received?
>   3. Should we also restrict processTokenizerFile to initialize only for that specific model?
>
> I would appreciate your guidance on these points so I can incorporate them into the split PRs.

  1. I believe so; this aligns with the bullets suggested in Section 2.2 of the LoRA support RFC.
  2. We shouldn’t return an error, since vLLM reuses the model field in the request to specify the LoRA adapter name (assuming a single base model). In this case, the model value in the request may differ from the base model name and instead represent the LoRA adapter. See Section 1.1 of the RFC above.
  3. This is possible, but I think we should focus first on the core changes required to migrate to vLLM tokenization and chat template rendering. Restricting processTokenizerFile could be a follow-up improvement later.
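
The second answer above can be illustrated with a small routing helper. This is a hypothetical sketch, not code from the PR or the RFC: `tokenizer_model_for` and the logger name are invented, and it assumes that a LoRA adapter is tokenized with its base model's tokenizer, which the thread implies but does not state outright.

```python
import logging

logger = logging.getLogger("tokenizer_routing")  # hypothetical logger name


def tokenizer_model_for(request_model: str, base_model: str) -> str:
    """Return the model whose tokenizer should serve this request."""
    if request_model != base_model:
        # Assumption: with a single base model, any other value in the
        # request's model field names a LoRA adapter, so this is not an error.
        logger.debug("treating %r as an adapter of %r", request_model, base_model)
    return base_model
```

Under this assumption, every request resolves to the one tokenizer loaded at startup, regardless of what the model field contains.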

@hyeongyun0916
Collaborator Author

I understand your points. I'll split this PR as discussed and apply the changes you suggested. Thanks for the guidance!

@vMaroon
Member

vMaroon commented Dec 26, 2025

Thank you @hyeongyun0916 @sagearc

@hyeongyun0916
Collaborator Author

hyeongyun0916 commented Dec 27, 2025

> Also, what do you think about splitting this PR into multiple, more focused ones? For example:
>
>   • Add support for vLLM tokenizers (without removing daulet tokenizers yet)
>   • Switch chat templating to use vLLM instead of the transformers library
>   • Remove deprecated tokenizers (if needed)

Regarding the PR split, I’d like to slightly adjust the order and proceed as follows:

@hyeongyun0916
Collaborator Author

Switch to using the vLLM tokenizer instead of the daulet tokenizer.

#254

@hyeongyun0916
Collaborator Author

> Regarding the PR split, I’d like to slightly adjust the order and proceed as follows:

Actually, since the tasks intended for this PR have already been merged through separate, more focused PRs, I'm going to close this one now.

Thanks for the great suggestion on splitting them up—it definitely made the process much smoother!

guygir pushed a commit to guygir/llm-d-kv-cache-manager that referenced this pull request Apr 20, 2026
….mk (llm-d#221)

- fixed typos
- added gitaction for typos

Signed-off-by: Jooho Lee <jlee@redhat.com>
4 participants