Remove threadsafe (#2907)

* remove threadsafe

* optimize performance

* 22.4

* 22.5

* delete jsonl

* add docs

* fix link

* rst

* remove sleep req step

* remove scheduler sleep

* fix ut

* recovery async engine
grimoire authored Jan 3, 2025
1 parent 9e593e7 commit aabc90d
Showing 14 changed files with 523 additions and 664 deletions.
83 changes: 44 additions & 39 deletions benchmark/profile_generation.py
@@ -1,11 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import asyncio
import csv
import os
import time
from dataclasses import dataclass
from queue import Queue
from threading import Thread
from typing import List, Union

import numpy as np
@@ -24,8 +23,9 @@
os.environ['TM_LOG_LEVEL'] = 'ERROR'


def infer(model, session_id: int, input_ids: List,
gen_config: GenerationConfig, test_round: int, que: Queue):
async def infer(model, session_id: int, input_ids: List,
gen_config: GenerationConfig, test_round: int,
que: asyncio.Queue):
if session_id == 1:
pbar = tqdm(total=test_round)
chatbot = model.create_instance()
Expand All @@ -47,12 +47,12 @@ def infer(model, session_id: int, input_ids: List,
The time elapsing in this iteration `now-prev` is set to the latency of first token of
the 5 tokens, i.e. `token_latency_stats[0]`, and `token_latency_stats[1:4]` is set 0`
""" # noqa: E501
for outputs in chatbot.stream_infer(session_id,
input_ids,
gen_config=gen_config,
sequence_start=True,
sequence_end=True,
stream_output=True):
async for outputs in chatbot.async_stream_infer(session_id,
input_ids,
gen_config=gen_config,
sequence_start=True,
sequence_end=True,
stream_output=True):
n_token = outputs.num_token
now = time.perf_counter()
if n_prev_token != n_token:
@@ -61,47 +61,50 @@
prev = now
# for pytorch engine to restart a session
if hasattr(chatbot, 'end'):
chatbot.end(session_id)
await chatbot.async_end(session_id)
if session_id == 1:
pbar.update(1)

assert output_seqlen <= n_token <= output_seqlen + 1, \
f'Error. session_id({session_id}) request {output_seqlen} ' \
f'tokens, but generate {n_token} tokens'
stats.append(token_latency_stats[:output_seqlen])
que.put((session_id, stats))
await que.put((session_id, stats))


def warmup(model, concurrency: int, input_ids: List[int], warmup_round: int,
gen_config: GenerationConfig):
gen_config: GenerationConfig, event_loop: asyncio.BaseEventLoop):
if not warmup_round:
return

print('start to warmup ...')

def _infer(model, session_id):
async def _infer(model, session_id):
chatbot = model.create_instance()
for _ in range(warmup_round):
for _ in chatbot.stream_infer(session_id,
input_ids=input_ids,
sequence_start=True,
sequence_end=True,
ignore_eos=True,
gen_config=gen_config):
async for _ in chatbot.async_stream_infer(session_id,
input_ids=input_ids,
sequence_start=True,
sequence_end=True,
ignore_eos=True,
gen_config=gen_config):
continue
# for pytorch engine to restart a session
if hasattr(chatbot, 'end'):
chatbot.end(session_id)
await chatbot.async_end(session_id)

_start = time.perf_counter()
procs = []

# start threads
tasks = []
for i in range(concurrency):
proc = Thread(target=_infer, args=(model, i + 1), daemon=True)
procs.append(proc)
proc.start()
task = _infer(model, i + 1)
tasks.append(task)

async def _gather_tasks(tasks):
return await asyncio.gather(*tasks)

for proc in procs:
proc.join()
event_loop.run_until_complete(_gather_tasks(tasks))

_end = time.perf_counter()
print(f'end warmup, elapsed time: {round(_end - _start, 2)}s')
@@ -125,31 +128,34 @@ def profile_throughput(model_path: str, concurrency: int, input_seqlen: int,
from lmdeploy.pytorch.engine import Engine
tm_model = Engine(model_path, engine_config)

event_loop = asyncio.new_event_loop()
asyncio.set_event_loop(event_loop)

# make up a dummy `input_ids` with the length of `input_seqlen` exactly
assert input_seqlen > 0, 'input_seqlen should > 0'
input_ids = np.random.randint(low=0, high=101, size=input_seqlen).tolist()
warmup(tm_model, concurrency, input_ids, warmup_round, gen_config)
warmup(tm_model, concurrency, input_ids, warmup_round, gen_config,
event_loop)

que = Queue()
procs = []
que = asyncio.Queue()
_start = time.perf_counter()

tasks = []
for i in range(concurrency):
proc = Thread(target=infer,
args=(tm_model, i + 1, input_ids, gen_config, test_round,
que))
procs.append(proc)
proc.start()
task = infer(tm_model, i + 1, input_ids, gen_config, test_round, que)
tasks.append(task)

async def _gather_tasks(tasks):
return await asyncio.gather(*tasks)

for proc in procs:
proc.join()
event_loop.run_until_complete(_gather_tasks(tasks))

_end = time.perf_counter()
elapsed_time = _end - _start

token_latency_stats = []
while not que.empty():
_, _stats = que.get()
_, _stats = que.get_nowait()
token_latency_stats += _stats

# The shape is [concurrency*test_round, output_seqlen]
@@ -426,7 +432,6 @@ def main():
block_size=args.cache_block_seq_len,
session_len=session_len,
tp=args.tp,
thread_safe=True,
eager_mode=args.eager_mode,
enable_prefix_caching=args.enable_prefix_caching,
dtype=args.dtype,
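
Every hunk in this file follows the same migration: worker `Thread` objects that were started and joined are replaced by coroutines gathered on an explicit event loop, and the thread-safe `queue.Queue` becomes an `asyncio.Queue` drained with `get_nowait()`. A minimal, self-contained sketch of that pattern (illustrative names only; it does not call any lmdeploy API):

```python
import asyncio
import time


async def infer_worker(session_id: int, que: asyncio.Queue):
    # Stand-in for the per-session inference coroutine in the benchmark.
    await asyncio.sleep(0.01)
    await que.put((session_id, time.perf_counter()))


def run(concurrency: int = 4):
    # One explicit event loop, mirroring the updated profile_generation.py flow.
    event_loop = asyncio.new_event_loop()
    asyncio.set_event_loop(event_loop)

    que = asyncio.Queue()
    tasks = [infer_worker(i + 1, que) for i in range(concurrency)]

    async def _gather_tasks(tasks):
        return await asyncio.gather(*tasks)

    event_loop.run_until_complete(_gather_tasks(tasks))

    # Drain results without blocking, as the benchmark does after gathering.
    while not que.empty():
        session_id, stamp = que.get_nowait()
        print(session_id, stamp)


if __name__ == '__main__':
    run()
```
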
78 changes: 78 additions & 0 deletions docs/en/advance/pytorch_multithread.md
@@ -0,0 +1,78 @@
# PyTorchEngine Multithreading

The `thread_safe` mode has been removed from PytorchEngine since [PR 2907](https://github.com/InternLM/lmdeploy/pull/2907). We encourage users to achieve high concurrency through the **service API** or **coroutines** whenever possible, for example with coroutines:

```python
import asyncio
from lmdeploy import pipeline, PytorchEngineConfig

event_loop = asyncio.new_event_loop()
asyncio.set_event_loop(event_loop)

model_path = 'Llama-3.2-1B-Instruct'
pipe = pipeline(model_path, backend_config=PytorchEngineConfig())

async def _gather_output():
tasks = [
pipe.async_batch_infer('Hakuna Matata'),
pipe.async_batch_infer('giraffes are heartless creatures'),
]
return await asyncio.gather(*tasks)

output = asyncio.run(_gather_output())
print(output[0].text)
print(output[1].text)
```
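
The other recommended route is the service API, which keeps all requests inside a single server process and lets the server schedule the concurrency. Below is a hedged sketch, assuming an OpenAI-compatible server has already been started with `lmdeploy serve api_server`; the port, API key, and served model name are assumptions, not values from this commit:

```python
# Sketch only: assumes an OpenAI-compatible lmdeploy api_server is already
# running locally, e.g. started with `lmdeploy serve api_server <model_path>`.
import asyncio

from openai import AsyncOpenAI  # pip install openai

# The base_url, api_key, and model name below are assumptions for illustration.
client = AsyncOpenAI(base_url='http://0.0.0.0:23333/v1', api_key='none')


async def ask(prompt: str) -> str:
    resp = await client.chat.completions.create(
        model='Llama-3.2-1B-Instruct',
        messages=[{'role': 'user', 'content': prompt}],
    )
    return resp.choices[0].message.content


async def _gather_output():
    return await asyncio.gather(
        ask('Hakuna Matata'),
        ask('giraffes are heartless creatures'),
    )

for text in asyncio.run(_gather_output()):
    print(text)
```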

If you really do need multithreading, it is easy to wrap the coroutine API as shown below: the engine and its event loop live on a dedicated server thread, and client threads exchange requests and results with it only through thread-safe queues.

```python
import threading
from queue import Queue
import asyncio
from lmdeploy import pipeline, PytorchEngineConfig

model_path = 'Llama-3.2-1B-Instruct'


async def _batch_infer(inque: Queue, outque: Queue, pipe):
while True:
if inque.empty():
await asyncio.sleep(0)
continue

input = inque.get_nowait()
output = await pipe.async_batch_infer(input)
outque.put(output)


def server(inques, outques):
event_loop = asyncio.new_event_loop()
asyncio.set_event_loop(event_loop)
pipe = pipeline(model_path, backend_config=PytorchEngineConfig())
for inque, outque in zip(inques, outques):
event_loop.create_task(_batch_infer(inque, outque, pipe))
event_loop.run_forever()

def client(inque, outque, message):
inque.put(message)
print(outque.get().text)


inques = [Queue(), Queue()]
outques = [Queue(), Queue()]

t_server = threading.Thread(target=server, args=(inques, outques))
t_client0 = threading.Thread(target=client, args=(inques[0], outques[0], 'Hakuna Matata'))
t_client1 = threading.Thread(target=client, args=(inques[1], outques[1], 'giraffes are heartless creatures'))

t_server.start()
t_client0.start()
t_client1.start()

t_client0.join()
t_client1.join()
```

> \[!WARNING\]
> This is NOT recommended, as multithreading introduces additional overhead, leading to unstable inference performance.
1 change: 1 addition & 0 deletions docs/en/index.rst
@@ -103,6 +103,7 @@ Documentation
advance/chat_template.md
advance/debug_turbomind.md
advance/structed_output.md
advance/pytorch_multithread.md

.. toctree::
:maxdepth: 1
78 changes: 78 additions & 0 deletions docs/zh_cn/advance/pytorch_multithread.md
@@ -0,0 +1,78 @@
# PyTorchEngine Multithreaded Inference

Since [PR 2907](https://github.com/InternLM/lmdeploy/pull/2907), we have removed the `thread_safe` mode of PytorchEngine so that the engine can run more efficiently. We encourage users to achieve high concurrency through the **service API** or **coroutines** whenever possible, for example:

```python
import asyncio
from lmdeploy import pipeline, PytorchEngineConfig

event_loop = asyncio.new_event_loop()
asyncio.set_event_loop(event_loop)

model_path = 'Llama-3.2-1B-Instruct'
pipe = pipeline(model_path, backend_config=PytorchEngineConfig())

async def _gather_output():
tasks = [
pipe.async_batch_infer('Hakuna Matata'),
pipe.async_batch_infer('giraffes are heartless creatures'),
]
return await asyncio.gather(*tasks)

output = asyncio.run(_gather_output())
print(output[0].text)
print(output[1].text)
```

If you really do need multithreaded inference, a simple wrapper like the one below achieves a similar effect.

```python
import threading
from queue import Queue
import asyncio
from lmdeploy import pipeline, PytorchEngineConfig

model_path = 'Llama-3.2-1B-Instruct'


async def _batch_infer(inque: Queue, outque: Queue, pipe):
while True:
if inque.empty():
await asyncio.sleep(0)
continue

input = inque.get_nowait()
output = await pipe.async_batch_infer(input)
outque.put(output)


def server(inques, outques):
event_loop = asyncio.new_event_loop()
asyncio.set_event_loop(event_loop)
pipe = pipeline(model_path, backend_config=PytorchEngineConfig())
for inque, outque in zip(inques, outques):
event_loop.create_task(_batch_infer(inque, outque, pipe))
event_loop.run_forever()

def client(inque, outque, message):
inque.put(message)
print(outque.get().text)


inques = [Queue(), Queue()]
outques = [Queue(), Queue()]

t_server = threading.Thread(target=server, args=(inques, outques))
t_client0 = threading.Thread(target=client, args=(inques[0], outques[0], 'Hakuna Matata'))
t_client1 = threading.Thread(target=client, args=(inques[1], outques[1], 'giraffes are heartless creatures'))

t_server.start()
t_client0.start()
t_client1.start()

t_client0.join()
t_client1.join()
```

> \[!WARNING\]
> This is not recommended: multithreading introduces additional overhead and makes inference performance unstable.
1 change: 1 addition & 0 deletions docs/zh_cn/index.rst
@@ -104,6 +104,7 @@ The LMDeploy toolbox provides the following core features:
advance/chat_template.md
advance/debug_turbomind.md
advance/structed_output.md
advance/pytorch_multithread.md

.. toctree::
:maxdepth: 1