
[Bugfix] Fix RuntimeError: Already borrowed by adding thread-safe Hugging Face fast-tokenizer wrappers#41181

Open

yzong-rh wants to merge 1 commit into vllm-project:main from yzong-rh:yzong-rh/thread-safe-tok

Conversation

@yzong-rh
Contributor

@yzong-rh yzong-rh commented Apr 29, 2026

Purpose

Adds a thread-safe Hugging Face fast-tokenizer wrapper to fix the RuntimeError: Already borrowed concurrency issue reported in #40949.

  • Keeps a per-thread deepcopy of the Rust tokenizer backend.

Also fixes concurrency issues with the bad_words sampling param and removes the need for a deepcopy in the multimodal processor.
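
The core idea, sketched below with illustrative names (this is not the PR's actual code): each thread lazily deep-copies the tokenizer once and reuses its private copy, so the Rust backend is never borrowed by two threads at once.

import copy
import threading

from transformers import PreTrainedTokenizerFast


class ThreadSafeTokenizer:
    """Per-thread-deepcopy wrapper (hypothetical sketch, not the PR's code)."""

    def __init__(self, tokenizer: PreTrainedTokenizerFast):
        self._template = tokenizer
        # thread id -> private copy; plain dict insertion is atomic under the GIL
        self._copies: dict[int, PreTrainedTokenizerFast] = {}

    def _local(self) -> PreTrainedTokenizerFast:
        ident = threading.get_ident()
        tok = self._copies.get(ident)
        if tok is None:
            # Pay the deepcopy cost once per thread; every later call reuses it.
            tok = copy.deepcopy(self._template)
            self._copies[ident] = tok
        return tok

    def encode(self, text: str, **kwargs) -> list[int]:
        return self._local().encode(text, **kwargs)

    def decode(self, token_ids: list[int], **kwargs) -> str:
        return self._local().decode(token_ids, **kwargs)

    def __getattr__(self, name: str):
        # Delegate everything else to the calling thread's private copy.
        return getattr(self._local(), name)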

Test Plan

pytest tests/models/multimodal/processing/test_common.py 

benchmark.py

python benchmark.py --prompt-length 512 --iterations 5000 --warmup 5000 --mixed
vllm serve Qwen/Qwen3-4B-Instruct-2507-FP8 --renderer_num_workers 4 --api-server-count=4

vllm serve deepseek-ai/DeepSeek-V4-Flash --renderer_num_workers 4   --api-server-count=4 --trust-remote-code --tensor-parallel-size=2 --max-model-len 4096 --kv-cache-dtype fp8

vllm serve deepseek-ai/DeepSeek-OCR --renderer_num_workers 4 --mm-processor-cache-gb 0 --api-server-count=4

vllm serve Qwen/Qwen-VL-Chat --renderer_num_workers 4 --mm-processor-cache-gb 0 --api-server-count=4 --trust-remote-code --hf-overrides '{"architectures": ["QwenVLForConditionalGeneration"]}'

stress_send.py

python stress_send.py -n 5000 -c 500 [--mm] [--bad-word]
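
For context, a minimal standalone reproducer of the underlying race might look like this (hypothetical script, separate from the files above); alternating truncation settings forces a mutable borrow of the shared Rust backend, so concurrent calls can raise RuntimeError: Already borrowed:

from concurrent.futures import ThreadPoolExecutor

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2", use_fast=True)


def work(i: int) -> None:
    # Alternating max_length makes each call reconfigure truncation on the
    # shared Rust backend, which requires a mutable borrow.
    tok("hello world " * 64, truncation=True, max_length=32 if i % 2 else 64)


with ThreadPoolExecutor(max_workers=8) as pool:
    # On an unpatched fast tokenizer this may fail with
    # "RuntimeError: Already borrowed" within a few hundred iterations.
    list(pool.map(work, range(1000)))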

Test Result

Unit tests pass.

Benchmark:

Loading tokenizer from meta-llama/Llama-3.1-8B-Instruct …
Prompt: 522 tokens, 2236 chars
Config: iterations=5000  warmup=5000  threads=[1, 2, 8]  mixed=True  truncation_max_length=1044


=== 1 thread(s) ===
  raw (no wrapper)                mean=0.487 ms  median=0.486 ms  p99=0.507 ms  total=2433.1 ms  wall=2434.7 ms  n=5000
  lock wrapper                    mean=0.483 ms  median=0.482 ms  p99=0.500 ms  total=2414.5 ms  wall=2415.7 ms  n=5000
  copy wrapper (threading.local)  mean=0.530 ms  median=0.527 ms  p99=0.594 ms  total=2652.4 ms  wall=2673.2 ms  n=5000
  copy wrapper (dict)             mean=0.529 ms  median=0.528 ms  p99=0.552 ms  total=2644.6 ms  wall=2646.1 ms  n=5000
  queue wrapper                   mean=0.531 ms  median=0.529 ms  p99=0.585 ms  total=2655.4 ms  wall=2657.1 ms  n=5000

=== 2 thread(s) ===
  raw (no wrapper)                failures=2449
  lock wrapper                    mean=0.816 ms  median=0.517 ms  p99=0.758 ms  total=4080.9 ms  wall=2573.6 ms  n=5000
  copy wrapper (threading.local)  mean=0.561 ms  median=0.554 ms  p99=0.679 ms  total=2805.8 ms  wall=1440.9 ms  n=5000
  copy wrapper (dict)             mean=0.533 ms  median=0.532 ms  p99=0.592 ms  total=2667.1 ms  wall=1335.7 ms  n=5000
  queue wrapper                   mean=0.566 ms  median=0.557 ms  p99=0.758 ms  total=2828.0 ms  wall=1425.4 ms  n=5000

=== 8 thread(s) ===
  raw (no wrapper)                failures=2467
  lock wrapper                    mean=2.397 ms  median=0.492 ms  p99=0.574 ms  total=11982.8 ms  wall=2508.3 ms  n=5000
  copy wrapper (threading.local)  mean=0.677 ms  median=0.540 ms  p99=2.235 ms  total=3384.8 ms  wall=498.6 ms  n=5000
  copy wrapper (dict)             mean=0.572 ms  median=0.564 ms  p99=0.721 ms  total=2858.2 ms  wall=359.5 ms  n=5000
  queue wrapper                   mean=0.549 ms  median=0.545 ms  p99=0.636 ms  total=2742.9 ms  wall=344.6 ms  n=5000
Qwen3-4B-Instruct-2507-FP8

Before:

Sending 5000 text-only requests (concurrency=500) to http://localhost:8000
Model: Qwen/Qwen3-4B-Instruct-2507-FP8  mode=text
==================================================
Requests : 5000/5000 ok, 0 errors
Wall time: 2.61s  (1916.1 req/s)
Latency  : p50=0.296s  p99=0.475s

After:

Sending 5000 text-only requests (concurrency=500) to http://localhost:8000 bad_words=['hello world']
Model: Qwen/Qwen3-4B-Instruct-2507-FP8  mode=text
==================================================
Requests : 5000/5000 ok, 0 errors
Wall time: 2.56s  (1951.1 req/s)
Latency  : p50=0.302s  p99=0.459s
DeepSeek-V4-Flash

Before:

Sending 5000 text-only requests (concurrency=500) to http://localhost:8000
Model: deepseek-ai/DeepSeek-V4-Flash  mode=text
==================================================
Requests : 5000/5000 ok, 0 errors
Wall time: 13.53s  (369.5 req/s)
Latency  : p50=1.335s  p99=2.478s

After:

Sending 5000 text-only requests (concurrency=500) to http://localhost:8000 bad_words=['hello world']
Model: deepseek-ai/DeepSeek-V4-Flash  mode=text
==================================================
Requests : 5000/5000 ok, 0 errors
Wall time: 12.95s  (386.1 req/s)
Latency  : p50=1.340s  p99=2.374s
Qwen-VL-Chat

Before:

Sending 5000 multimodal requests (concurrency=500) to http://localhost:8000
Model: Qwen/Qwen-VL-Chat  mode=multimodal
==================================================
Requests : 5000/5000 ok, 0 errors
Wall time: 10.30s  (485.6 req/s)
Latency  : p50=1.088s  p99=1.420s

After:

Sending 5000 multimodal requests (concurrency=500) to http://localhost:8000 bad_words=['hello world']
Model: Qwen/Qwen-VL-Chat  mode=multimodal
==================================================
Requests : 5000/5000 ok, 0 errors
Wall time: 10.27s  (487.0 req/s)
Latency  : p50=1.108s  p99=1.448s
DeepSeek-OCR

Before:

Sending 5000 multimodal requests (concurrency=500) to http://localhost:8000
Model: deepseek-ai/DeepSeek-OCR  mode=multimodal
==================================================
Requests : 5000/5000 ok, 0 errors
Wall time: 11.14s  (448.9 req/s)
Latency  : p50=1.063s  p99=1.565s

After:

Sending 5000 multimodal requests (concurrency=500) to http://localhost:8000 bad_words=['hello world']
Model: deepseek-ai/DeepSeek-OCR  mode=multimodal
==================================================
Requests : 5000/5000 ok, 0 errors
Wall time: 10.31s  (484.9 req/s)
Latency  : p50=0.985s  p99=1.288s

AI Assistance

Made with Cursor

cc @sfeng33 @bbrowning


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.

@mergify mergify Bot added multi-modality Related to multi-modality (#4194) bug Something isn't working labels Apr 29, 2026
@yzong-rh
Contributor Author

yzong-rh commented Apr 29, 2026

Note that I only added this to the main HF tokenizer path. Before marking the PR ready, it needs to be added to the DeepSeek and other tokenizer paths too.

Updated the other tokenizer paths that use the HF fast tokenizer.

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request implements thread-safe wrappers for Hugging Face tokenizers, providing both lock-based and thread-local copy implementations controlled by an environment variable. It also removes previous deepcopy workarounds and retry logic for tokenizer access. Feedback identifies a regression in default thread safety for multimodal models, a missing batch_decode override in the lock-based wrapper, and the need to remove commented-out code. Additionally, it is recommended to use the VLLM_ prefix for the new environment variable and replace print statements with the project's logging system.

Comment thread vllm/multimodal/processing/context.py Outdated
Comment thread vllm/renderers/base.py
Comment thread vllm/tokenizers/hf.py Outdated
Comment thread vllm/tokenizers/hf.py Outdated
Comment thread vllm/tokenizers/hf.py Outdated
Comment thread vllm/tokenizers/hf.py Outdated
Comment on lines +33 to +38
if not isinstance(tokenizer, PreTrainedTokenizerFast):
    return tokenizer

thread_safe_tokenizer = copy.copy(tokenizer)
lock = threading.RLock()

Collaborator

We have confirmed that having each thread use its own deepcopy(tokenizer) can resolve the "RuntimeError: Already borrowed", so there is no need to add a lock to protect it.

Contributor Author

get_thread_safe_hf_tokenizer_w_lock doesn't deepcopy; it only uses a lock. I left it in for performance comparison and will need to remove it before readying.

The deepcopy version is in get_thread_safe_hf_tokenizer_w_copy. This one doesn't use lock.

Both solve the already borrowed issue.
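
For reference, the lock variant could look roughly like this sketch (inferred from the discussion; the real get_thread_safe_hf_tokenizer_w_lock may differ):

import copy
import threading

from transformers import PreTrainedTokenizerFast


def get_thread_safe_hf_tokenizer_w_lock(
    tokenizer: PreTrainedTokenizerFast,
) -> PreTrainedTokenizerFast:
    wrapped = copy.copy(tokenizer)  # shallow copy shares the Rust backend
    lock = threading.RLock()

    def locked(method):
        def inner(*args, **kwargs):
            # Serialize all backend access: only one thread may borrow
            # the Rust tokenizer at a time.
            with lock:
                return method(*args, **kwargs)
        return inner

    for name in ("encode", "decode", "batch_decode"):
        object.__setattr__(wrapped, name, locked(getattr(tokenizer, name)))
    return wrapped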

Contributor Author

@yzong-rh yzong-rh Apr 29, 2026

I didn't see a huge perf difference between them even with --api-server-count=4. I need more profiling, but if perf is similar I actually lean more towards the lock implementation, as it handles mutation better (less unexpected behavior if the tokenizer is mutated). Wdyt?

Member

@DarkLight1337 DarkLight1337 Apr 29, 2026

I think we should avoid using a lock in the case where there is only one thread. Using deepcopy would be cleaner in terms of code, IMO.

Collaborator

@noooop noooop Apr 29, 2026

#34789 (comment)

In the short term, (maybe) due to the Python GIL, setting --renderer_num_workers > 1 and using a thread pool provides almost no performance improvement, so setting it to 1 is reasonable.

But in the long term, we still want to improve performance through multithreading — for example, by upgrading to Python 3.14, or by finding some way to bypass the GIL. So personally, I still want to keep the --renderer_num_workers parameter.

Because we don't know exactly what is happening during the preprocessing stage, it's very difficult to optimize further. So, next month I plan to first add better observability for the entrypoints. Perhaps this will help us better understand what is actually going on.

PTAL #39979

Collaborator

@noooop noooop Apr 29, 2026

Summarizing what we know so far:

  1. Offload blocking tokenizer and preprocessing/postprocessing ops to a shared thread pool to unblock the event loop: #34789, #39763.

  2. (Maybe) due to the Python GIL, setting --renderer_num_workers > 1 and using a thread pool provides almost no performance improvement: #34789 (comment).

  3. Per https://github.com/noooop/snippet/blob/main/benchmarks/thread_pool/, a thread pool can accelerate torch.mm and tokenizer.encode.

  4. The Hugging Face fast tokenizer is not thread-safe, which may cause RuntimeError: Already borrowed; deepcopy(tokenizer) can resolve this: #36557.

  5. Using deepcopy(tokenizer) on every call in a multithreaded environment will certainly introduce some overhead, so we need a tokenizer pool: https://github.com/vllm-project/vllm/pull/28458/changes#r3159044061

Collaborator

@noooop noooop Apr 29, 2026

https://github.com/noooop/snippet/blob/main/benchmarks/thread_pool/preprocessing_tokenizer_encode.py

from concurrent.futures import ProcessPoolExecutor
tokenizer_encode n_workers: 1, e2e: 12.22839603600005
tokenizer_encode n_workers: 2, e2e: 6.675321902999713
tokenizer_encode n_workers: 4, e2e: 3.770936963999702
tokenizer_encode n_workers: 8, e2e: 2.363701287999902
tokenizer_encode n_workers: 16, e2e: 2.231783658000495

from concurrent.futures import ThreadPoolExecutor
tokenizer_encode n_workers: 1, e2e: 11.007459864000339
tokenizer_encode n_workers: 2, e2e: 5.999560164000286
tokenizer_encode n_workers: 4, e2e: 3.4470128889997795
tokenizer_encode n_workers: 8, e2e: 2.0952387249999447
tokenizer_encode n_workers: 16, e2e: 1.9648324209993007


ThreadPoolExecutor + deepcopy
tokenizer_encode n_workers: 1, e2e: 29.1953536609999
tokenizer_encode n_workers: 2, e2e: 18.48894192500029
tokenizer_encode n_workers: 4, e2e: 18.999746747000245
tokenizer_encode n_workers: 8, e2e: 18.900613595999857
tokenizer_encode n_workers: 16, e2e: 19.041412347000005

I believe the tokenizer benefits from ThreadPoolExecutor, but we cannot afford to use deepcopy on every execution, as it would introduce significant overhead. So we need to add a tokenizer pool to avoid repeated deepcopy.
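
Such a pool could be as simple as the following sketch (a hypothetical helper built on queue.Queue, which is itself thread-safe; the deepcopy cost is paid once up front rather than per call):

import copy
import queue
from contextlib import contextmanager

from transformers import PreTrainedTokenizerFast


class TokenizerPool:
    """Fixed-size pool of deep copies (illustrative, not code from this PR)."""

    def __init__(self, tokenizer: PreTrainedTokenizerFast, size: int = 8):
        self._q: queue.Queue[PreTrainedTokenizerFast] = queue.Queue()
        self._q.put(tokenizer)  # the original serves as the first member
        for _ in range(size - 1):
            self._q.put(copy.deepcopy(tokenizer))

    @contextmanager
    def borrow(self):
        tok = self._q.get()  # blocks if all copies are checked out
        try:
            yield tok
        finally:
            self._q.put(tok)


# Usage: each caller holds a copy exclusively, so the Rust backend is
# never borrowed from two threads at once.
# with pool.borrow() as tok:
#     ids = tok.encode("hello world")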

Signed-off-by: Yifan Zong <yzong@redhat.com>
@yzong-rh yzong-rh force-pushed the yzong-rh/thread-safe-tok branch from 38224d7 to 347112f on April 30, 2026 00:36
@yzong-rh yzong-rh marked this pull request as ready for review April 30, 2026 02:13
@claude claude Bot left a comment

Claude Code Review

This pull request is from a fork — automated review is disabled. A repository maintainer can comment @claude review to run a one-time review.

Comment thread vllm/tokenizers/hf.py
backend_tokenizer = thread_safe_tokenizer._tokenizer

# Concurrent dict insertion is safe here thanks to the GIL.
thread_local = {threading.get_ident(): copy.deepcopy(backend_tokenizer)}
Contributor Author

We can switch to threading.local() easily when we no longer rely on the GIL.

threading.local() uses a small lock internally (according to AI) and does not scale as well in microbenchmarks.
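
For comparison, the threading.local() variant would look something like this sketch (assumed module-level helper, not code from the PR):

import copy
import threading

from tokenizers import Tokenizer  # Hugging Face Rust tokenizer type

_tls = threading.local()


def local_backend(template: Tokenizer) -> Tokenizer:
    # Each thread gets its own attribute namespace on _tls, so the first
    # access in a thread deep-copies once and later calls reuse that copy.
    backend = getattr(_tls, "backend", None)
    if backend is None:
        backend = copy.deepcopy(template)
        _tls.backend = backend
    return backend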

Collaborator

Can we have a little less hackiness?


Contributor Author

Should have read your previous comment more carefully. Queue() is quite clean and performant.

My only concern is that FastIncrementalDetokenizer (which runs on the main thread) uses the Rust backend directly. I'll try either patching _tokenizer's methods or making FastIncrementalDetokenizer go through the Python tokenizer.

def __init__(self, tokenizer: PreTrainedTokenizerFast, request: EngineCoreRequest):
    super().__init__(request)
    sampling_params = request.sampling_params
    assert sampling_params is not None
    self.request_id = request.request_id
    self.skip_special_tokens = sampling_params.skip_special_tokens
    self.tokenizer: Tokenizer = tokenizer._tokenizer

Collaborator

You can ignore FastIncrementalDetokenizer for now. As I understand it, it runs on the core and only runs in a single thread for now.

Comment thread vllm/tokenizers/hf.py
@DarkLight1337
Member

Can you rerun some of the benchmarks with this updated code?

@yzong-rh
Contributor Author

yzong-rh commented Apr 30, 2026

Can you rerun some of the benchmarks with this updated code?

All the benchmarks are updated as of April 29. They were all rerun when I readied the PR.

Main branch: 0ab67c0
2 x B200

To be fair, though, there is quite a lot of run-to-run variance in the E2E runs. I mainly looked for perf regressions. The micro-benchmark is much more repeatable.
