Remove threadsafe #2907

Merged · 17 commits · Jan 3, 2025
83 changes: 44 additions & 39 deletions benchmark/profile_generation.py
@@ -1,11 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import asyncio
import csv
import os
import time
from dataclasses import dataclass
from queue import Queue
from threading import Thread
from typing import List, Union

import numpy as np
@@ -24,8 +23,9 @@
os.environ['TM_LOG_LEVEL'] = 'ERROR'


def infer(model, session_id: int, input_ids: List,
gen_config: GenerationConfig, test_round: int, que: Queue):
async def infer(model, session_id: int, input_ids: List,
gen_config: GenerationConfig, test_round: int,
que: asyncio.Queue):
if session_id == 1:
pbar = tqdm(total=test_round)
chatbot = model.create_instance()
@@ -47,12 +47,12 @@ def infer(model, session_id: int, input_ids: List,
The time elapsed in this iteration, `now - prev`, is set as the latency of the first token of
the 5 tokens, i.e. `token_latency_stats[0]`, and `token_latency_stats[1:4]` is set to 0
""" # noqa: E501
for outputs in chatbot.stream_infer(session_id,
input_ids,
gen_config=gen_config,
sequence_start=True,
sequence_end=True,
stream_output=True):
async for outputs in chatbot.async_stream_infer(session_id,
input_ids,
gen_config=gen_config,
sequence_start=True,
sequence_end=True,
stream_output=True):
n_token = outputs.num_token
now = time.perf_counter()
if n_prev_token != n_token:
@@ -61,47 +61,50 @@ def infer(model, session_id: int, input_ids: List,
prev = now
# for pytorch engine to restart a session
if hasattr(chatbot, 'end'):
chatbot.end(session_id)
await chatbot.async_end(session_id)
if session_id == 1:
pbar.update(1)

assert output_seqlen <= n_token <= output_seqlen + 1, \
f'Error. session_id({session_id}) request {output_seqlen} ' \
f'tokens, but generate {n_token} tokens'
stats.append(token_latency_stats[:output_seqlen])
que.put((session_id, stats))
await que.put((session_id, stats))


def warmup(model, concurrency: int, input_ids: List[int], warmup_round: int,
gen_config: GenerationConfig):
gen_config: GenerationConfig, event_loop: asyncio.BaseEventLoop):
if not warmup_round:
return

print('start to warmup ...')

def _infer(model, session_id):
async def _infer(model, session_id):
chatbot = model.create_instance()
for _ in range(warmup_round):
for _ in chatbot.stream_infer(session_id,
input_ids=input_ids,
sequence_start=True,
sequence_end=True,
ignore_eos=True,
gen_config=gen_config):
async for _ in chatbot.async_stream_infer(session_id,
input_ids=input_ids,
sequence_start=True,
sequence_end=True,
ignore_eos=True,
gen_config=gen_config):
continue
# for pytorch engine to restart a session
if hasattr(chatbot, 'end'):
chatbot.end(session_id)
await chatbot.async_end(session_id)

_start = time.perf_counter()
procs = []

# start threads
tasks = []
for i in range(concurrency):
proc = Thread(target=_infer, args=(model, i + 1), daemon=True)
procs.append(proc)
proc.start()
task = _infer(model, i + 1)
tasks.append(task)

async def _gather_tasks(tasks):
return await asyncio.gather(*tasks)

for proc in procs:
proc.join()
event_loop.run_until_complete(_gather_tasks(tasks))

_end = time.perf_counter()
print(f'end warmup, elapsed time: {round(_end - _start, 2)}s')
@@ -125,31 +128,34 @@ def profile_throughput(model_path: str, concurrency: int, input_seqlen: int,
from lmdeploy.pytorch.engine import Engine
tm_model = Engine(model_path, engine_config)

event_loop = asyncio.new_event_loop()
asyncio.set_event_loop(event_loop)

# make up a dummy `input_ids` with the length of `input_seqlen` exactly
assert input_seqlen > 0, 'input_seqlen should > 0'
input_ids = np.random.randint(low=0, high=101, size=input_seqlen).tolist()
warmup(tm_model, concurrency, input_ids, warmup_round, gen_config)
warmup(tm_model, concurrency, input_ids, warmup_round, gen_config,
event_loop)

que = Queue()
procs = []
que = asyncio.Queue()
_start = time.perf_counter()

tasks = []
for i in range(concurrency):
proc = Thread(target=infer,
args=(tm_model, i + 1, input_ids, gen_config, test_round,
que))
procs.append(proc)
proc.start()
task = infer(tm_model, i + 1, input_ids, gen_config, test_round, que)
tasks.append(task)

async def _gather_tasks(tasks):
return await asyncio.gather(*tasks)

for proc in procs:
proc.join()
event_loop.run_until_complete(_gather_tasks(tasks))

_end = time.perf_counter()
elapsed_time = _end - _start

token_latency_stats = []
while not que.empty():
_, _stats = que.get()
_, _stats = que.get_nowait()
token_latency_stats += _stats

# The shape is [concurrency*test_round, output_seqlen]
@@ -426,7 +432,6 @@ def main():
block_size=args.cache_block_seq_len,
session_len=session_len,
tp=args.tp,
thread_safe=True,
eager_mode=args.eager_mode,
enable_prefix_caching=args.enable_prefix_caching,
dtype=args.dtype,
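Taken together, the benchmark changes follow one pattern: each worker `Thread` that pushed results into a `queue.Queue` becomes a coroutine that awaits the engine's async API and pushes into an `asyncio.Queue`, and all coroutines are gathered on a single event loop. Below is a minimal, self-contained sketch of that pattern; `fake_infer` is a hypothetical placeholder standing in for the `async_stream_infer` call, not part of the benchmark.

```python
import asyncio
import time


async def fake_infer(session_id: int, que: asyncio.Queue):
    """Placeholder worker: pretend to generate, then report a latency."""
    start = time.perf_counter()
    await asyncio.sleep(0.01)  # stands in for iterating async_stream_infer
    await que.put((session_id, time.perf_counter() - start))


def run_concurrent(concurrency: int = 4):
    async def _gather():
        que: asyncio.Queue = asyncio.Queue()
        tasks = [fake_infer(i + 1, que) for i in range(concurrency)]
        await asyncio.gather(*tasks)
        # Drain the queue after all workers finish; only completed items remain.
        results = []
        while not que.empty():
            results.append(que.get_nowait())
        return results

    for session_id, latency in asyncio.run(_gather()):
        print(f'session {session_id}: {latency:.4f}s')


if __name__ == '__main__':
    run_concurrent()
```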
78 changes: 78 additions & 0 deletions docs/en/advance/pytorch_multithread.md
@@ -0,0 +1,78 @@
# PyTorchEngine Multithread

We have removed the `thread_safe` mode from PytorchEngine since [PR2907](https://github.com/InternLM/lmdeploy/pull/2907). We encourage users to achieve high concurrency with the **service API** or **coroutines** whenever possible, for example:

```python
import asyncio
from lmdeploy import pipeline, PytorchEngineConfig

event_loop = asyncio.new_event_loop()
asyncio.set_event_loop(event_loop)

model_path = 'Llama-3.2-1B-Instruct'
pipe = pipeline(model_path, backend_config=PytorchEngineConfig())

async def _gather_output():
    tasks = [
        pipe.async_batch_infer('Hakuna Matata'),
        pipe.async_batch_infer('giraffes are heartless creatures'),
    ]
    return await asyncio.gather(*tasks)

output = asyncio.run(_gather_output())
print(output[0].text)
print(output[1].text)
```

If you do need multithreading, it is easy to wrap the coroutine API as shown below:
Collaborator comment: @lzhangzz After PR #2968, can users use the pipeline API in multithreading?


```python
import threading
from queue import Queue
import asyncio
from lmdeploy import pipeline, PytorchEngineConfig

model_path = 'Llama-3.2-1B-Instruct'


async def _batch_infer(inque: Queue, outque: Queue, pipe):
    # Poll the thread-safe input queue without blocking the event loop,
    # run inference, and push the result to the output queue.
    while True:
        if inque.empty():
            await asyncio.sleep(0)
            continue

        input = inque.get_nowait()
        output = await pipe.async_batch_infer(input)
        outque.put(output)


def server(inques, outques):
    # A single server thread owns the event loop and the pipeline.
    event_loop = asyncio.new_event_loop()
    asyncio.set_event_loop(event_loop)
    pipe = pipeline(model_path, backend_config=PytorchEngineConfig())
    for inque, outque in zip(inques, outques):
        event_loop.create_task(_batch_infer(inque, outque, pipe))
    event_loop.run_forever()

def client(inque, outque, message):
    # Each client thread sends one prompt and blocks until its result arrives.
    inque.put(message)
    print(outque.get().text)


inques = [Queue(), Queue()]
outques = [Queue(), Queue()]

t_server = threading.Thread(target=server, args=(inques, outques))
t_client0 = threading.Thread(target=client, args=(inques[0], outques[0], 'Hakuna Matata'))
t_client1 = threading.Thread(target=client, args=(inques[1], outques[1], 'giraffes are heartless creatures'))

t_server.start()
t_client0.start()
t_client1.start()

t_client0.join()
t_client1.join()
```

> \[!WARNING\]
> This is NOT recommended, as multithreading introduces additional overhead, leading to unstable inference performance.
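
If you still need to call the pipeline from threads, a lighter option is to run a single event loop in one background thread and submit coroutines to it from any other thread with `asyncio.run_coroutine_threadsafe`. The sketch below is not part of this PR; it assumes the same `pipe.async_batch_infer` coroutine used in the examples above, and the `_server` / `infer_from_any_thread` helpers are illustrative names.

```python
import asyncio
import threading
from concurrent.futures import Future
from lmdeploy import pipeline, PytorchEngineConfig

model_path = 'Llama-3.2-1B-Instruct'
_ready: Future = Future()


def _server():
    # One background thread owns both the event loop and the pipeline.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    pipe = pipeline(model_path, backend_config=PytorchEngineConfig())
    _ready.set_result((loop, pipe))
    loop.run_forever()


threading.Thread(target=_server, daemon=True).start()
loop, pipe = _ready.result()


def infer_from_any_thread(prompt: str):
    # Submit the coroutine to the server loop and block until it completes.
    future = asyncio.run_coroutine_threadsafe(pipe.async_batch_infer(prompt), loop)
    return future.result()


print(infer_from_any_thread('Hakuna Matata').text)
```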
1 change: 1 addition & 0 deletions docs/en/index.rst
Original file line number Diff line number Diff line change
@@ -103,6 +103,7 @@ Documentation
advance/chat_template.md
advance/debug_turbomind.md
advance/structed_output.md
advance/pytorch_multithread.md

.. toctree::
:maxdepth: 1
78 changes: 78 additions & 0 deletions docs/zh_cn/advance/pytorch_multithread.md
@@ -0,0 +1,78 @@
# PyTorchEngine Multithreaded Inference

Since [PR2907](https://github.com/InternLM/lmdeploy/pull/2907), we have removed the `thread_safe` mode of PytorchEngine so that the engine can run more efficiently. We encourage users to achieve high concurrency with the **service API** or **coroutines** whenever possible, for example:

```python
import asyncio
from lmdeploy import pipeline, PytorchEngineConfig

event_loop = asyncio.new_event_loop()
asyncio.set_event_loop(event_loop)

model_path = 'Llama-3.2-1B-Instruct'
pipe = pipeline(model_path, backend_config=PytorchEngineConfig())

async def _gather_output():
    tasks = [
        pipe.async_batch_infer('Hakuna Matata'),
        pipe.async_batch_infer('giraffes are heartless creatures'),
    ]
    return await asyncio.gather(*tasks)

output = asyncio.run(_gather_output())
print(output[0].text)
print(output[1].text)
```

If you really do need multithreaded inference, a simple wrapper can achieve a similar effect:

```python
import threading
from queue import Queue
import asyncio
from lmdeploy import pipeline, PytorchEngineConfig

model_path = 'Llama-3.2-1B-Instruct'


async def _batch_infer(inque: Queue, outque: Queue, pipe):
    while True:
        if inque.empty():
            await asyncio.sleep(0)
            continue

        input = inque.get_nowait()
        output = await pipe.async_batch_infer(input)
        outque.put(output)


def server(inques, outques):
    event_loop = asyncio.new_event_loop()
    asyncio.set_event_loop(event_loop)
    pipe = pipeline(model_path, backend_config=PytorchEngineConfig())
    for inque, outque in zip(inques, outques):
        event_loop.create_task(_batch_infer(inque, outque, pipe))
    event_loop.run_forever()

def client(inque, outque, message):
    inque.put(message)
    print(outque.get().text)


inques = [Queue(), Queue()]
outques = [Queue(), Queue()]

t_server = threading.Thread(target=server, args=(inques, outques))
t_client0 = threading.Thread(target=client, args=(inques[0], outques[0], 'Hakuna Matata'))
t_client1 = threading.Thread(target=client, args=(inques[1], outques[1], 'giraffes are heartless creatures'))

t_server.start()
t_client0.start()
t_client1.start()

t_client0.join()
t_client1.join()
```

> \[!WARNING\]
> This approach is not recommended, as multithreading introduces extra overhead and makes inference performance unstable.
1 change: 1 addition & 0 deletions docs/zh_cn/index.rst
@@ -104,6 +104,7 @@ The LMDeploy toolbox provides the following core features:
advance/chat_template.md
advance/debug_turbomind.md
advance/structed_output.md
advance/pytorch_multithread.md

.. toctree::
:maxdepth: 1