From da11137703b3c823d6b36e47de51f6276665c89d Mon Sep 17 00:00:00 2001 From: lvhan028 Date: Fri, 18 Oct 2024 15:39:33 +0800 Subject: [PATCH 1/3] Copy sglang/bench_serving.py to lmdeploy as serving benchmark script --- benchmark/profile_restful_api.py | 265 ------- benchmark/profile_serving.py | 1110 ++++++++++++++++++++++++++++++ 2 files changed, 1110 insertions(+), 265 deletions(-) delete mode 100644 benchmark/profile_restful_api.py create mode 100644 benchmark/profile_serving.py diff --git a/benchmark/profile_restful_api.py b/benchmark/profile_restful_api.py deleted file mode 100644 index 524b302906..0000000000 --- a/benchmark/profile_restful_api.py +++ /dev/null @@ -1,265 +0,0 @@ -import csv -import json -import random -import time -from queue import Queue -from threading import Thread -from typing import List, Optional, Tuple - -import fire -import numpy as np -from tqdm import tqdm -from transformers import AutoTokenizer - -from lmdeploy.serve.openai.api_client import APIClient - - -def sample_requests( - dataset_path: str, - num_requests: int, - tokenizer: AutoTokenizer, -) -> List[Tuple[str, int, int]]: - # Load the dataset. - with open(dataset_path) as f: - dataset = json.load(f) - # Filter out the conversations with less than 2 turns. - dataset = [data for data in dataset if len(data['conversations']) >= 2] - # Only keep the first two turns of each conversation. - dataset = [(data['conversations'][0]['value'], - data['conversations'][1]['value']) for data in dataset] - - # pre-sample to avoid go through all the dataset - dataset = random.sample(dataset, max(int(num_requests * 1.2), 1000)) - - # Tokenize the prompts and completions. - prompts = [prompt for prompt, _ in dataset] - prompt_token_ids = tokenizer(prompts).input_ids - completions = [completion for _, completion in dataset] - completion_token_ids = tokenizer(completions).input_ids - tokenized_dataset = [] - for i in range(len(dataset)): - output_len = len(completion_token_ids[i]) - tokenized_dataset.append((prompts[i], prompt_token_ids[i], output_len)) - - # Filter out too long sequences. - filtered_dataset: List[Tuple[str, int, int]] = [] - for prompt, prompt_token_ids, output_len in tokenized_dataset: - prompt_len = len(prompt_token_ids) - if prompt_len < 4 or output_len < 4: - # Prune too short sequences. - continue - if prompt_len > 1024 or prompt_len + output_len > 2048: - # Prune too long sequences. - continue - filtered_dataset.append((prompt, prompt_len, output_len)) - - # Sample the requests. 
- sampled_requests = random.sample(filtered_dataset, num_requests) - return sampled_requests - - -class Engine: - - def __init__(self, - server_addr: str, - tokenzier_path: str, - temperature: float = 0.8, - top_p: float = 1.0, - csv: str = '', - api_key: Optional[str] = None, - model_name: Optional[str] = None, - **kwargs): - self.tokenizer = AutoTokenizer.from_pretrained(tokenzier_path, - trust_remote_code=True) - self.server_addr = server_addr - self.temperature = temperature - self.top_p = top_p - self.csv = csv - self.api_key = api_key - client = APIClient(self.server_addr, api_key=self.api_key) - if model_name is None: - self.model_name = client.available_models[0] - print(f'using model: {self.model_name}\n') - else: - self.model_name = model_name - self.pbar = None - - def _inference(self, req_queue: Queue, res_queue: Queue, session_id: int, - stream_output: bool): - - stats = [] - client = APIClient(self.server_addr, api_key=self.api_key) - - for prompt, input_seqlen, output_seqlen in iter( - req_queue.get, [None, None, None]): - timestamps = [] - timestamps.append(time.perf_counter()) - for output in client.chat_completions_v1( - model=self.model_name, - messages=prompt, - temperature=self.temperature, - top_p=self.top_p, - n=1, - max_tokens=output_seqlen, - stream=stream_output, - session_id=session_id, - ignore_eos=True): - timestamps.append(time.perf_counter()) - - first_token_latency = np.round(timestamps[1] - timestamps[0], 3) - token_latency = np.round(timestamps[-1] - timestamps[0], 3) - # assert output.pop('finish_reason') == 'length', \ - # f'Error. session_id({session_id}) request {output_seqlen} ' \ - # f'tokens, but `finish_reason` is not `length`' - total_tokens = input_seqlen + output_seqlen - stats.append([ - first_token_latency, output_seqlen, output_seqlen, - total_tokens, token_latency - ]) - self.pbar.update(1) - - res_queue.put((session_id, stats)) - - def process_request(self, - requests, - concurrency: int = 1, - stream_output: bool = False): - res_queue = Queue() - req_queue = Queue() - threads = [] - - self.pbar = tqdm(total=len(requests)) - - # feed request to q - for req in requests: - req_queue.put(req) - for i in range(concurrency): - req_queue.put([None, None, None]) - - start = time.time() - - # start threads - for i in range(concurrency): - t = Thread(target=self._inference, - args=(req_queue, res_queue, i, stream_output)) - t.start() - threads.append(t) - - # wait for finish - for t in threads: - t.join() - - elapsed_time = time.time() - start - - stats = [] - while not res_queue.empty(): - session_id, _stats = res_queue.get() - if len(_stats) != 0: - stats.append(np.array(_stats)) - - stats = np.concatenate(stats).reshape(-1, 5) - - first_token_latency_min = np.min(stats[:, 0], axis=0) - first_token_latency_max = np.max(stats[:, 0], axis=0) - first_token_latency_ave = np.mean(stats[:, 0], axis=0) - completion_tokens = np.sum(stats[:, 1], axis=0) - request_output_tokens = np.sum(stats[:, 2], axis=0) - total_tokens = np.sum(stats[:, 3], axis=0) - prompt_tokens = total_tokens - completion_tokens - completion_token_throughput = completion_tokens / elapsed_time - total_token_throughput = total_tokens / elapsed_time - rps = len(requests) / elapsed_time - rpm = rps * 60 - - if (np.abs(stats[:, 1] - stats[:, 2]) <= 1).min() is False: - print(f'Did not generate requested number of tokens. 
' - f'Request {request_output_tokens:.0f}, ' - f'but got {completion_tokens:.0f}') - - print(f'\n{"-" * 50}\nconcurrency: {concurrency}\n' - f'elapsed_time: {elapsed_time:.3f}s\n') - if stream_output: - print(f'first_token latency(min, max, ave): ' - f'{first_token_latency_min:.3f}s, ' - f'{first_token_latency_max:.3f}s, ' - f'{first_token_latency_ave:.3f}s\n') - print( - f'number of prompt tokens: {prompt_tokens:.0f}\n' - f'number of completion tokens: {completion_tokens:.0f}\n' - f'token throughput (completion token): {completion_token_throughput:.3f} token/s\n' # noqa - f'token throughput (prompt + completion token): {total_token_throughput:.3f} token/s\n' # noqa - f'RPS (request per second): {rps:.3f} req/s\n' - f'RPM (request per minute): {rpm:.3f} req/min\n' - f'{"-" * 50}\n') - - if self.csv: - with open(self.csv, 'w') as csvfile: - writer = csv.writer(csvfile) - writer.writerow([ - 'batch', 'num_prompts', 'RPS', 'RPM', 'FTL(ave)(s)', - 'FTL(min)(s)', 'FTL(max)(s)', 'throughput(out tok/s)', - 'throughput(total tok/s)' - ]) - writer.writerow([ - concurrency, - len(requests), f'{rps:.3f}', f'{rpm:.3f}', - f'{first_token_latency_ave:.3f}' if stream_output else '-', - f'{first_token_latency_min:.3f}' if stream_output else '-', - f'{first_token_latency_max:.3f}' if stream_output else '-', - f'{completion_token_throughput:.3f}', - f'{total_token_throughput:.3f}' - ]) - - -def main(server_addr: str, - tokenizer_path: str, - dataset: str, - api_key: Optional[str] = None, - model_name: Optional[str] = None, - concurrency: int = 128, - num_prompts: int = 5000, - top_p: float = 1.0, - temperature: float = 1.0, - stream_output: bool = False, - csv: str = './profile_api_server.csv', - seed: int = 0): - """Benchmark the request througput of api server. - - Args: - server_addr (str): http url of api_server with format http://0.0.0.0:0 - tokenizer_path (str): Path to the tokenizer model in localhost - dataset (str): Path to the dataset - concurrency (int, optional): Number of working threads to process the sampled prompts. - Defaults to 128. - num_prompts (int, optional): Number of prompts to process. Defaults to 5000. - top_p (float, optional): the set of most probable tokens with - probabilities that add up to top_p or higher - are kept for generation. Defaults to 1.0. - temperature (float, optional): The value used to modulate the next token probabilities. - Defaults to 1.0. - stream_output (bool, optional): Indicator for streaming output. Defaults to False. - csv (str, optional): The path to save the result. - seed (int, optional): Seed used in sampling prompts from dataset. Defaults to 0. 
- """ # noqa - if not server_addr.startswith('http://'): - print(f'[WARNING] server_addr of the api_server should ' - f'start with "http://", but got "{server_addr}"') - server_addr = 'http://' + server_addr.strip() - - random.seed(seed) - - engine = Engine(server_addr, - tokenizer_path, - top_p=top_p, - temperature=temperature, - csv=csv, - api_key=api_key, - model_name=model_name) - - requests = sample_requests(dataset, num_prompts, engine.tokenizer) - - engine.process_request(requests, concurrency, stream_output) - - -if __name__ == '__main__': - fire.Fire(main) diff --git a/benchmark/profile_serving.py b/benchmark/profile_serving.py new file mode 100644 index 0000000000..963c00d490 --- /dev/null +++ b/benchmark/profile_serving.py @@ -0,0 +1,1110 @@ +# Modify from https://github.com/sgl-project/sglang/blob/main/python/sglang/bench_serving.py # noqa +# Adapted from https://github.com/vllm-project/vllm/blob/6366efc67b0aedd2c1721c14385370e50b297fb3/benchmarks/backend_request_func.py # noqa +# Adapted from https://github.com/vllm-project/vllm/blob/6366efc67b0aedd2c1721c14385370e50b297fb3/benchmarks/benchmark_serving.py # noqa +"""Benchmark online serving with dynamic requests. + +Usage: +python3 -m sglang.bench_serving --backend sglang --num-prompt 10 + +python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompts 3000 --random-input 1024 --random-output 1024 --random-range-ratio 0.5 +python3 -m sglang.bench_serving --backend sglang --dataset-name random --request-rate-range 1,2,4,8,16,32 --random-input 4096 --random-output 1024 --random-range-ratio 0.125 --multi +""" # noqa + +import argparse +import asyncio +import json +import os +import random +import resource +import sys +import time +import traceback +import warnings +from argparse import ArgumentParser +from dataclasses import dataclass, field +from datetime import datetime +from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union + +import aiohttp +import numpy as np +import requests +from tqdm.asyncio import tqdm +from transformers import (AutoTokenizer, PreTrainedTokenizer, + PreTrainedTokenizerBase, PreTrainedTokenizerFast) + +AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60) + +global args + + +@dataclass +class RequestFuncInput: + prompt: str + api_url: str + prompt_len: int + output_len: int + model: str + extra_request_body: Dict[str, Any] + + +@dataclass +class RequestFuncOutput: + generated_text: str = '' + success: bool = False + latency: float = 0.0 + ttft: float = 0.0 # Time to first token + itl: List[float] = field( + default_factory=list) # List of inter-token latencies + prompt_len: int = 0 + error: str = '' + output_len: int = 0 + + +def remove_prefix(text: str, prefix: str) -> str: + return text[len(prefix):] if text.startswith(prefix) else text + + +# trt llm not support ignore_eos +# https://github.com/triton-inference-server/tensorrtllm_backend/issues/505 +async def async_request_trt_llm( + request_func_input: RequestFuncInput, + pbar: Optional[tqdm] = None, +) -> RequestFuncOutput: + api_url = request_func_input.api_url + assert api_url.endswith('generate_stream') + + async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: + payload = { + 'accumulate_tokens': True, + 'text_input': request_func_input.prompt, + 'temperature': 0.000001, + 'top_p': 1.0, + 'max_tokens': request_func_input.output_len, + 'stream': True, + 'min_length': request_func_input.output_len, + 'end_id': 1048576, + **request_func_input.extra_request_body, + } + if 
args.disable_ignore_eos: + del payload['min_length'] + del payload['end_id'] + output = RequestFuncOutput() + output.prompt_len = request_func_input.prompt_len + + ttft = 0.0 + st = time.perf_counter() + most_recent_timestamp = st + try: + async with session.post(url=api_url, json=payload) as response: + if response.status == 200: + async for chunk_bytes in response.content: + chunk_bytes = chunk_bytes.strip() + if not chunk_bytes: + continue + + chunk = remove_prefix(chunk_bytes.decode('utf-8'), + 'data:') + + data = json.loads(chunk) + output.generated_text += data['text_output'] + timestamp = time.perf_counter() + # First token + if ttft == 0.0: + ttft = time.perf_counter() - st + output.ttft = ttft + + # Decoding phase + else: + output.itl.append(timestamp - + most_recent_timestamp) + + most_recent_timestamp = timestamp + + output.latency = most_recent_timestamp - st + output.success = True + output.output_len = request_func_input.output_len + + else: + output.error = response.reason or '' + output.success = False + except Exception: + output.success = False + exc_info = sys.exc_info() + output.error = ''.join(traceback.format_exception(*exc_info)) + + if pbar: + pbar.update(1) + return output + + +# set ignore_eos True by default +async def async_request_openai_completions( + request_func_input: RequestFuncInput, + pbar: Optional[tqdm] = None, +) -> RequestFuncOutput: + api_url = request_func_input.api_url + assert api_url.endswith( + 'completions' + ), "OpenAI Completions API URL must end with 'completions'." + + prompt = request_func_input.prompt + + async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: + payload = { + 'model': request_func_input.model, + 'prompt': prompt, + 'temperature': 0.0, + 'best_of': 1, + 'max_tokens': request_func_input.output_len, + 'stream': not args.disable_stream, + 'ignore_eos': not args.disable_ignore_eos, + **request_func_input.extra_request_body, + } + headers = { + 'Authorization': f"Bearer {os.environ.get('OPENAI_API_KEY')}" + } + + output = RequestFuncOutput() + output.prompt_len = request_func_input.prompt_len + + generated_text = '' + ttft = 0.0 + st = time.perf_counter() + most_recent_timestamp = st + try: + async with session.post(url=api_url, json=payload, + headers=headers) as response: + if response.status == 200: + async for chunk_bytes in response.content: + chunk_bytes = chunk_bytes.strip() + if not chunk_bytes: + continue + + chunk = remove_prefix(chunk_bytes.decode('utf-8'), + 'data: ') + latency = time.perf_counter() - st + if chunk == '[DONE]': + pass + else: + data = json.loads(chunk) + + # NOTE: Some completion API might have a last + # usage summary response without a token so we + # want to check a token was generated + if data['choices'][0]['text']: + timestamp = time.perf_counter() + # First token + if ttft == 0.0: + ttft = time.perf_counter() - st + output.ttft = ttft + + # Decoding phase + else: + output.itl.append(timestamp - + most_recent_timestamp) + + most_recent_timestamp = timestamp + generated_text += data['choices'][0]['text'] + + output.generated_text = generated_text + output.success = True + output.latency = latency + output.output_len = request_func_input.output_len + else: + output.error = response.reason or '' + output.success = False + except Exception: + output.success = False + exc_info = sys.exc_info() + output.error = ''.join(traceback.format_exception(*exc_info)) + + if pbar: + pbar.update(1) + return output + + +async def async_request_sglang_generate( + request_func_input: 
RequestFuncInput, + pbar: Optional[tqdm] = None, +) -> RequestFuncOutput: + api_url = request_func_input.api_url + prompt = request_func_input.prompt + + async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: + payload = { + 'text': prompt, + 'sampling_params': { + 'temperature': 0.0, + 'max_new_tokens': request_func_input.output_len, + 'ignore_eos': not args.disable_ignore_eos, + }, + 'stream': not args.disable_stream, + **request_func_input.extra_request_body, + } + headers = {} + + output = RequestFuncOutput() + output.prompt_len = request_func_input.prompt_len + + generated_text = '' + ttft = 0.0 + st = time.perf_counter() + most_recent_timestamp = st + try: + async with session.post(url=api_url, json=payload, + headers=headers) as response: + if response.status == 200: + async for chunk_bytes in response.content: + chunk_bytes = chunk_bytes.strip() + if not chunk_bytes: + continue + # print(chunk_bytes) + + chunk = remove_prefix(chunk_bytes.decode('utf-8'), + 'data: ') + latency = time.perf_counter() - st + if chunk == '[DONE]': + pass + else: + data = json.loads(chunk) + + # NOTE: Some completion API might have a last + # usage summary response without a token so we + # want to check a token was generated + if data['text']: + timestamp = time.perf_counter() + # First token + if ttft == 0.0: + ttft = time.perf_counter() - st + output.ttft = ttft + + # Decoding phase + else: + output.itl.append(timestamp - + most_recent_timestamp) + + most_recent_timestamp = timestamp + generated_text = data['text'] + + output.generated_text = generated_text + output.success = True + output.latency = latency + output.output_len = request_func_input.output_len + else: + output.error = response.reason or '' + output.success = False + except Exception: + output.success = False + exc_info = sys.exc_info() + output.error = ''.join(traceback.format_exception(*exc_info)) + + if pbar: + pbar.update(1) + return output + + +async def async_request_gserver( + request_func_input: RequestFuncInput, + pbar: Optional[tqdm] = None, +) -> RequestFuncOutput: + raise NotImplementedError() + + +def get_model(pretrained_model_name_or_path: str) -> str: + if os.getenv('SGLANG_USE_MODELSCOPE', 'False').lower() == 'true': + import huggingface_hub.constants + from modelscope import snapshot_download + + model_path = snapshot_download( + model_id=pretrained_model_name_or_path, + local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, + ignore_file_pattern=['.*.pt', '.*.safetensors', '.*.bin'], + ) + + return model_path + return pretrained_model_name_or_path + + +def get_tokenizer( + pretrained_model_name_or_path: str, +) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: + if pretrained_model_name_or_path.endswith( + '.json') or pretrained_model_name_or_path.endswith('.model'): + from sglang.srt.hf_transformers_utils import get_tokenizer + + return get_tokenizer(pretrained_model_name_or_path) + + if pretrained_model_name_or_path is not None and not os.path.exists( + pretrained_model_name_or_path): + pretrained_model_name_or_path = get_model( + pretrained_model_name_or_path) + return AutoTokenizer.from_pretrained(pretrained_model_name_or_path, + trust_remote_code=True) + + +ASYNC_REQUEST_FUNCS = { + 'sglang': async_request_sglang_generate, + 'sglang-native': async_request_sglang_generate, + 'sglang-oai': async_request_openai_completions, + 'vllm': async_request_openai_completions, + 'lmdeploy': async_request_openai_completions, + 'trt': async_request_trt_llm, + 'gserver': async_request_gserver, +} + + 
+@dataclass +class BenchmarkMetrics: + completed: int + total_input: int + total_output: int + total_output_retokenized: int + request_throughput: float + input_throughput: float + output_throughput: float + output_throughput_retokenized: float + mean_ttft_ms: float + median_ttft_ms: float + std_ttft_ms: float + p99_ttft_ms: float + mean_tpot_ms: float + median_tpot_ms: float + std_tpot_ms: float + p99_tpot_ms: float + mean_itl_ms: float + median_itl_ms: float + std_itl_ms: float + p99_itl_ms: float + mean_e2e_latency_ms: float + median_e2e_latency_ms: float + + +SHAREGPT_URL = 'https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json' # noqa + + +def download_and_cache_file(url: str, filename: Optional[str] = None): + """Read and cache a file from a url.""" + if filename is None: + filename = os.path.join('/tmp', url.split('/')[-1]) + + # Check if the cache file already exists + if os.path.exists(filename): + return filename + + print(f'Downloading from {url} to {filename}') + + # Stream the response to show the progress bar + response = requests.get(url, stream=True) + response.raise_for_status() # Check for request errors + + # Total size of the file in bytes + total_size = int(response.headers.get('content-length', 0)) + chunk_size = 1024 # Download in chunks of 1KB + + # Use tqdm to display the progress bar + with open(filename, 'wb') as f, tqdm( + desc=filename, + total=total_size, + unit='B', + unit_scale=True, + unit_divisor=1024, + ) as bar: + for chunk in response.iter_content(chunk_size=chunk_size): + f.write(chunk) + bar.update(len(chunk)) + + return filename + + +def sample_sharegpt_requests( + dataset_path: str, + num_requests: int, + tokenizer: PreTrainedTokenizerBase, + fixed_output_len: Optional[int] = None, +) -> List[Tuple[str, int, int]]: + if fixed_output_len is not None and fixed_output_len < 4: + raise ValueError('output_len too small') + + # Download sharegpt if necessary + if not os.path.isfile(dataset_path): + dataset_path = download_and_cache_file(SHAREGPT_URL) + + # Load the dataset. + with open(dataset_path) as f: + dataset = json.load(f) + # Filter out the conversations with less than 2 turns. + dataset = [data for data in dataset if len(data['conversations']) >= 2] + # Only keep the first two turns of each conversation. + dataset = [(data['conversations'][0]['value'], + data['conversations'][1]['value']) for data in dataset] + + # Shuffle the dataset. + random.shuffle(dataset) + + # Filter out sequences that are too long or too short + filtered_dataset: List[Tuple[str, int, int]] = [] + for i in range(len(dataset)): + if len(filtered_dataset) == num_requests: + break + + # Tokenize the prompts and completions. + prompt = dataset[i][0] + prompt_token_ids = tokenizer.encode(prompt) + completion = dataset[i][1] + completion_token_ids = tokenizer.encode(completion) + prompt_len = len(prompt_token_ids) + output_len = (len(completion_token_ids) + if fixed_output_len is None else fixed_output_len) + if prompt_len < 4 or output_len < 4: + # Prune too short sequences. + continue + if prompt_len > 1024 or (prompt_len + output_len > 2048 + and fixed_output_len is None): + # Prune too long sequences. 
+ continue + filtered_dataset.append((prompt, prompt_len, output_len)) + + print(f'#Input tokens: {np.sum([x[1] for x in filtered_dataset])}') + print(f'#Output tokens: {np.sum([x[2] for x in filtered_dataset])}') + return filtered_dataset + + +def sample_random_requests( + input_len: int, + output_len: int, + num_prompts: int, + range_ratio: float, + tokenizer: PreTrainedTokenizerBase, + dataset_path: str, +) -> List[Tuple[str, int, int]]: + + input_lens = np.random.randint( + max(int(input_len * range_ratio), 1), + input_len + 1, + size=num_prompts, + ) + output_lens = np.random.randint( + int(output_len * range_ratio), + output_len + 1, + size=num_prompts, + ) + + if True: + # Sample token ids from ShareGPT and repeat/truncate them to + # satisfy the input_lens + + # Download sharegpt if necessary + if not os.path.isfile(dataset_path): + dataset_path = download_and_cache_file(SHAREGPT_URL) + + # Load the dataset. + with open(dataset_path) as f: + dataset = json.load(f) + # Filter out the conversations with less than 2 turns. + dataset = [data for data in dataset if len(data['conversations']) >= 2] + # Only keep the first two turns of each conversation. + dataset = [(data['conversations'][0]['value'], + data['conversations'][1]['value']) for data in dataset] + + # Shuffle the dataset. + random.shuffle(dataset) + + # Filter out sequences that are too long or too short + input_requests: List[Tuple[str, int, int]] = [] + for i in range(num_prompts): + # Tokenize the prompts and completions. + prompt = dataset[i][0] + prompt_token_ids = tokenizer.encode(prompt) + prompt_len = len(prompt_token_ids) + + if prompt_len > input_lens[i]: + input_ids = prompt_token_ids[:input_lens[i]] + else: + ratio = (input_lens[i] + prompt_len - 1) // prompt_len + input_ids = (prompt_token_ids * ratio)[:input_lens[i]] + prompt = tokenizer.decode(input_ids) + input_requests.append( + (prompt, int(input_lens[i]), int(output_lens[i]))) + else: + # Sample token ids from random integers. + # This can cause some NaN issues. + offsets = np.random.randint(0, tokenizer.vocab_size, size=num_prompts) + input_requests = [] + for i in range(num_prompts): + prompt = tokenizer.decode([ + (offsets[i] + i + j) % tokenizer.vocab_size + for j in range(input_lens[i]) + ]) + input_requests.append( + (prompt, int(input_lens[i]), int(output_lens[i]))) + + print(f'#Input tokens: {np.sum(input_lens)}') + print(f'#Output tokens: {np.sum(output_lens)}') + return input_requests + + +async def get_request( + input_requests: List[Tuple[str, int, int]], + request_rate: float, +) -> AsyncGenerator[Tuple[str, int, int], None]: + input_requests = iter(input_requests) + for request in input_requests: + yield request + + if request_rate == float('inf'): + # If the request rate is infinity, then we don't need to wait. + continue + + # Sample the request interval from the exponential distribution. + interval = np.random.exponential(1.0 / request_rate) + # The next request will be sent after the interval. 
+ await asyncio.sleep(interval) + + +def calculate_metrics( + input_requests: List[Tuple[str, int, int]], + outputs: List[RequestFuncOutput], + dur_s: float, + tokenizer: PreTrainedTokenizerBase, + backend: str, +) -> Tuple[BenchmarkMetrics, List[int]]: + output_lens: List[int] = [] + retokenized_output_lens: List[int] = [] + total_input = 0 + completed = 0 + itls: List[float] = [] + tpots: List[float] = [] + ttfts: List[float] = [] + e2e_latencies: List[float] = [] + for i in range(len(outputs)): + if outputs[i].success: + output_len = outputs[i].output_len + output_lens.append(output_len) + retokenized_output_len = len( + tokenizer.encode(outputs[i].generated_text, + add_special_tokens=False)) + retokenized_output_lens.append(retokenized_output_len) + total_input += input_requests[i][1] + if output_len > 1: + tpots.append( + (outputs[i].latency - outputs[i].ttft) / (output_len - 1)) + itls += outputs[i].itl + ttfts.append(outputs[i].ttft) + + e2e_latencies.append(outputs[i].latency) + + completed += 1 + else: + output_lens.append(0) + retokenized_output_lens.append(0) + + if completed == 0: + warnings.warn( + 'All requests failed. This is likely due to a misconfiguration ' + 'on the benchmark arguments.', + stacklevel=2, + ) + metrics = BenchmarkMetrics( + completed=completed, + total_input=total_input, + total_output=sum(output_lens), + total_output_retokenized=sum(retokenized_output_lens), + request_throughput=completed / dur_s, + input_throughput=total_input / dur_s, + output_throughput=sum(output_lens) / dur_s, + output_throughput_retokenized=sum(retokenized_output_lens) / dur_s, + mean_ttft_ms=np.mean(ttfts or 0) * + 1000, # ttfts is empty if streaming is not supported by backend + median_ttft_ms=np.median(ttfts or 0) * 1000, + std_ttft_ms=np.std(ttfts or 0) * 1000, + p99_ttft_ms=np.percentile(ttfts or 0, 99) * 1000, + mean_tpot_ms=np.mean(tpots or 0) * 1000, + median_tpot_ms=np.median(tpots or 0) * 1000, + std_tpot_ms=np.std(tpots or 0) * 1000, + p99_tpot_ms=np.percentile(tpots or 0, 99) * 1000, + mean_itl_ms=np.mean(itls or 0) * 1000, + median_itl_ms=np.median(itls or 0) * 1000, + std_itl_ms=np.std(itls or 0) * 1000, + p99_itl_ms=np.percentile(itls or 0, 99) * 1000, + mean_e2e_latency_ms=np.mean(e2e_latencies) * 1000, + median_e2e_latency_ms=np.median(e2e_latencies) * 1000, + ) + + return metrics, output_lens + + +async def benchmark( + backend: str, + api_url: str, + model_id: str, + tokenizer: PreTrainedTokenizerBase, + input_requests: List[Tuple[str, int, int]], + request_rate: float, + disable_tqdm: bool, + extra_request_body: Dict[str, Any], +): + if backend in ASYNC_REQUEST_FUNCS: + request_func = ASYNC_REQUEST_FUNCS[backend] + else: + raise ValueError(f'Unknown backend: {backend}') + + print('Starting initial single prompt test run...') + test_prompt, test_prompt_len, test_output_len = input_requests[0] + test_input = RequestFuncInput( + model=model_id, + prompt=test_prompt, + api_url=api_url, + prompt_len=test_prompt_len, + output_len=test_output_len, + extra_request_body=extra_request_body, + ) + test_output = await request_func(request_func_input=test_input) + if not test_output.success: + raise ValueError( + 'Initial test run failed - Please make sure benchmark arguments ' + f'are correctly specified. Error: {test_output.error}') + else: + print('Initial test run completed. 
Starting main benchmark run...') + + time.sleep(1.5) + + pbar = None if disable_tqdm else tqdm(total=len(input_requests)) + + benchmark_start_time = time.perf_counter() + tasks: List[asyncio.Task] = [] + async for request in get_request(input_requests, request_rate): + prompt, prompt_len, output_len = request + request_func_input = RequestFuncInput( + model=model_id, + prompt=prompt, + api_url=api_url, + prompt_len=prompt_len, + output_len=output_len, + extra_request_body=extra_request_body, + ) + tasks.append( + asyncio.create_task( + request_func(request_func_input=request_func_input, + pbar=pbar))) + outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks) + + if pbar is not None: + pbar.close() + + benchmark_duration = time.perf_counter() - benchmark_start_time + + metrics, output_lens = calculate_metrics( + input_requests=input_requests, + outputs=outputs, + dur_s=benchmark_duration, + tokenizer=tokenizer, + backend=backend, + ) + + print('\n{s:{c}^{n}}'.format(s=' Serving Benchmark Result ', n=50, c='=')) + print('{:<40} {:<10}'.format('Backend:', backend)) + print('{:<40} {:<10}'.format('Traffic request rate:', request_rate)) + print('{:<40} {:<10}'.format('Successful requests:', metrics.completed)) + print('{:<40} {:<10.2f}'.format('Benchmark duration (s):', + benchmark_duration)) + print('{:<40} {:<10}'.format('Total input tokens:', metrics.total_input)) + print('{:<40} {:<10}'.format('Total generated tokens:', + metrics.total_output)) + print('{:<40} {:<10}'.format('Total generated tokens (retokenized):', + metrics.total_output_retokenized)) + print('{:<40} {:<10.2f}'.format('Request throughput (req/s):', + metrics.request_throughput)) + print('{:<40} {:<10.2f}'.format('Input token throughput (tok/s):', + metrics.input_throughput)) + print('{:<40} {:<10.2f}'.format('Output token throughput (tok/s):', + metrics.output_throughput)) + print('{s:{c}^{n}}'.format(s='End-to-End Latency', n=50, c='-')) + print('{:<40} {:<10.2f}'.format('Mean E2E Latency (ms):', + metrics.mean_e2e_latency_ms)) + print('{:<40} {:<10.2f}'.format('Median E2E Latency (ms):', + metrics.median_e2e_latency_ms)) + print('{s:{c}^{n}}'.format(s='Time to First Token', n=50, c='-')) + print('{:<40} {:<10.2f}'.format('Mean TTFT (ms):', metrics.mean_ttft_ms)) + print('{:<40} {:<10.2f}'.format('Median TTFT (ms):', + metrics.median_ttft_ms)) + print('{:<40} {:<10.2f}'.format('P99 TTFT (ms):', metrics.p99_ttft_ms)) + print('{s:{c}^{n}}'.format(s='Time per Output Token (excl. 
1st token)', + n=50, + c='-')) + print('{:<40} {:<10.2f}'.format('Mean TPOT (ms):', metrics.mean_tpot_ms)) + print('{:<40} {:<10.2f}'.format('Median TPOT (ms):', + metrics.median_tpot_ms)) + print('{:<40} {:<10.2f}'.format('P99 TPOT (ms):', metrics.p99_tpot_ms)) + print('{s:{c}^{n}}'.format(s='Inter-token Latency', n=50, c='-')) + print('{:<40} {:<10.2f}'.format('Mean ITL (ms):', metrics.mean_itl_ms)) + print('{:<40} {:<10.2f}'.format('Median ITL (ms):', metrics.median_itl_ms)) + print('{:<40} {:<10.2f}'.format('P99 ITL (ms):', metrics.p99_itl_ms)) + print('=' * 50) + + if (metrics.median_ttft_ms is not None and metrics.mean_itl_ms is not None + and metrics.output_throughput is not None): + result = { + 'backend': args.backend, + 'dataset_name': args.dataset_name, + 'request_rate': request_rate, + 'total_input_tokens': metrics.total_input, + 'total_output_tokens': metrics.total_output, + 'total_output_tokens_retokenized': + metrics.total_output_retokenized, + 'mean_e2e_latency_ms': metrics.mean_e2e_latency_ms, + 'median_e2e_latency_ms': metrics.median_e2e_latency_ms, + 'median_ttft_ms': metrics.median_ttft_ms, + 'median_itl_ms': metrics.median_itl_ms, + 'output_throughput': metrics.output_throughput, + 'sharegpt_output_len': args.sharegpt_output_len, + 'random_input_len': args.random_input_len, + 'random_output_len': args.random_output_len, + 'random_range_ratio': args.random_range_ratio, + 'duration': benchmark_duration, + 'completed': metrics.completed, + } + else: + print(f'Error running benchmark for request rate: {request_rate}') + print('-' * 30) + + # Determine output file name + if args.output_file: + output_file_name = args.output_file + else: + now = datetime.now().strftime('%m%d') + if args.dataset_name == 'random': + output_file_name = f'{args.backend}_{now}_{args.num_prompts}_{args.random_input_len}_{args.random_output_len}.jsonl' # noqa + else: + output_file_name = f'{args.backend}_{now}_{args.num_prompts}_sharegpt.jsonl' # noqa + + # Append results to a JSONL file + with open(output_file_name, 'a') as file: + file.write(json.dumps(result) + '\n') + + result = { + 'duration': benchmark_duration, + 'completed': metrics.completed, + 'total_input_tokens': metrics.total_input, + 'total_output_tokens': metrics.total_output, + 'total_output_tokens_retokenized': metrics.total_output_retokenized, + 'request_throughput': metrics.request_throughput, + 'input_throughput': metrics.input_throughput, + 'output_throughput': metrics.output_throughput, + 'mean_ttft_ms': metrics.mean_ttft_ms, + 'median_ttft_ms': metrics.median_ttft_ms, + 'std_ttft_ms': metrics.std_ttft_ms, + 'p99_ttft_ms': metrics.p99_ttft_ms, + 'mean_tpot_ms': metrics.mean_tpot_ms, + 'median_tpot_ms': metrics.median_tpot_ms, + 'std_tpot_ms': metrics.std_tpot_ms, + 'p99_tpot_ms': metrics.p99_tpot_ms, + 'mean_itl_ms': metrics.mean_itl_ms, + 'median_itl_ms': metrics.median_itl_ms, + 'std_itl_ms': metrics.std_itl_ms, + 'p99_itl_ms': metrics.p99_itl_ms, + 'input_lens': [output.prompt_len for output in outputs], + 'output_lens': output_lens, + 'ttfts': [output.ttft for output in outputs], + 'itls': [output.itl for output in outputs], + 'generated_texts': [output.generated_text for output in outputs], + 'errors': [output.error for output in outputs], + 'mean_e2e_latency_ms': metrics.mean_e2e_latency_ms, + 'median_e2e_latency_ms': metrics.median_e2e_latency_ms, + } + return result + + +def parse_request_rate_range(request_rate_range): + if len(request_rate_range.split(',')) == 3: + start, stop, step = map(int, 
request_rate_range.split(',')) + return list(range(start, stop, step)) + else: + return list(map(int, request_rate_range.split(','))) + + +def check_chat_template(model_path): + try: + tokenizer = AutoTokenizer.from_pretrained(model_path, + trust_remote_code=True) + return 'chat_template' in tokenizer.init_kwargs + except Exception as e: + print(f'Fail to load tokenizer config with error={e}') + return False + + +def run_benchmark(args_: argparse.Namespace): + global args + args = args_ + + # Set global environments + set_ulimit() + random.seed(args.seed) + np.random.seed(args.seed) + + extra_request_body = {} + if args.extra_request_body: + extra_request_body = json.loads(args.extra_request_body) + + # Set url + if args.port is None: + args.port = { + 'sglang': 30000, + 'sglang-native': 30000, + 'sglang-oai': 30000, + 'lmdeploy': 23333, + 'vllm': 8000, + 'trt': 8000, + 'gserver': 9988, + }.get(args.backend, 30000) + + model_url = (f'{args.base_url}/v1/models' if args.base_url else + f'http://{args.host}:{args.port}/v1/models') + + if args.backend in ['sglang', 'sglang-native']: + api_url = (f'{args.base_url}/generate' if args.base_url else + f'http://{args.host}:{args.port}/generate') + elif args.backend in ['sglang-oai', 'vllm', 'lmdeploy']: + api_url = (f'{args.base_url}/v1/completions' if args.base_url else + f'http://{args.host}:{args.port}/v1/completions') + elif args.backend == 'trt': + api_url = ( + f'{args.base_url}/v2/models/ensemble/generate_stream' + if args.base_url else + f'http://{args.host}:{args.port}/v2/models/ensemble/generate_stream' # noqa + ) + if args.model is None: + print('Please provide a model using `--model` when using ' + '`trt` backend.') + sys.exit(1) + elif args.backend == 'gserver': + api_url = args.base_url if args.base_url else \ + f'{args.host}:{args.port}' + args.model = args.model or 'default' + + # Get model name + if args.model is None: + try: + response = requests.get(model_url) + model_list = response.json().get('data', []) + args.model = model_list[0]['id'] if model_list else None + except Exception as e: + print(f'Failed to fetch model from {model_url}. Error: {e}') + print('Please specify the correct host and port using ' + '`--host` and `--port`.') + sys.exit(1) + + if args.model is None: + print('No model specified or found. 
Please provide a model ' + 'using `--model`.') + sys.exit(1) + + if not check_chat_template(args.model): + print('\nWARNING It is recommended to use the `Chat` or `Instruct` ' + 'model for benchmarking.\n' + 'Because when the tokenizer counts the output tokens, if ' + 'there is gibberish, it might count incorrectly.\n') + + print(f'{args}\n') + + # Read dataset + backend = args.backend + model_id = args.model + tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model + + tokenizer = get_tokenizer(tokenizer_id) + + if args.dataset_name == 'sharegpt': + assert args.random_input_len is None and args.random_output_len is None + input_requests = sample_sharegpt_requests( + dataset_path=args.dataset_path, + num_requests=args.num_prompts, + tokenizer=tokenizer, + fixed_output_len=args.sharegpt_output_len, + ) + elif args.dataset_name == 'random': + assert args.random_input_len is not None and \ + args.random_output_len is not None + input_requests = sample_random_requests( + input_len=args.random_input_len, + output_len=args.random_output_len, + num_prompts=args.num_prompts, + range_ratio=args.random_range_ratio, + tokenizer=tokenizer, + dataset_path=args.dataset_path, + ) + else: + raise ValueError(f'Unknown dataset: {args.dataset_name}') + + if not args.multi: + return asyncio.run( + benchmark( + backend=backend, + api_url=api_url, + model_id=model_id, + tokenizer=tokenizer, + input_requests=input_requests, + request_rate=args.request_rate, + disable_tqdm=args.disable_tqdm, + extra_request_body=extra_request_body, + )) + else: + # Benchmark multiple rps. + # TODO: use a fixed duration to compute num_prompts + request_rates = parse_request_rate_range(args.request_rate_range) + + for rate in request_rates: + asyncio.run( + benchmark( + backend=backend, + api_url=api_url, + model_id=model_id, + tokenizer=tokenizer, + input_requests=input_requests, + request_rate=rate, + disable_tqdm=args.disable_tqdm, + extra_request_body=extra_request_body, + )) + + +def set_ulimit(target_soft_limit=65535): + resource_type = resource.RLIMIT_NOFILE + current_soft, current_hard = resource.getrlimit(resource_type) + + if current_soft < target_soft_limit: + try: + resource.setrlimit(resource_type, + (target_soft_limit, current_hard)) + except ValueError as e: + print(f'Fail to set RLIMIT_NOFILE: {e}') + + +if __name__ == '__main__': + parser = ArgumentParser( + description='Benchmark the online serving throughput.') + parser.add_argument( + '--backend', + type=str, + choices=list(ASYNC_REQUEST_FUNCS.keys()), + default='sglang', + help='Must specify a backend, depending on the LLM Inference Engine.', + ) + parser.add_argument( + '--base-url', + type=str, + default=None, + help='Server or API base url if not using http host and port.', + ) + parser.add_argument('--host', + type=str, + default='0.0.0.0', + help='Default host is 0.0.0.0.') + parser.add_argument( + '--port', + type=int, + help='If not set, the default port is configured according to its ' + 'default value for different LLM Inference Engines.', + ) + parser.add_argument( + '--dataset-name', + type=str, + default='sharegpt', + choices=['sharegpt', 'random'], + help='Name of the dataset to benchmark on.', + ) + parser.add_argument('--dataset-path', + type=str, + default='', + help='Path to the dataset.') + parser.add_argument( + '--model', + type=str, + help='Name or path of the model. 
If not set, the default model will ' + 'request /v1/models for conf.', + ) + parser.add_argument( + '--tokenizer', + type=str, + help='Name or path of the tokenizer. If not set, using the model ' + 'conf.', + ) + parser.add_argument( + '--num-prompts', + type=int, + default=1000, + help='Number of prompts to process. Default is 1000.', + ) + parser.add_argument( + '--sharegpt-output-len', + type=int, + default=None, + help='Output length for each request. Overrides the output length ' + 'from the ShareGPT dataset.', + ) + parser.add_argument( + '--random-input-len', + type=int, + help='Number of input tokens per request, used only for random ' + 'dataset.', + ) + parser.add_argument( + '--random-output-len', + type=int, + help='Number of output tokens per request, used only for random ' + 'dataset.', + ) + parser.add_argument( + '--random-range-ratio', + type=float, + default=0.0, + help='Range of sampled ratio of input/output length, ' + 'used only for random dataset.', + ) + parser.add_argument( + '--request-rate', + type=float, + default=float('inf'), + help='Number of requests per second. If this is inf, then all the ' + 'requests are sent at time 0. Otherwise, we use Poisson process to ' + 'synthesize the request arrival times. Default is inf.', + ) + parser.add_argument('--seed', type=int, default=1, help='The random seed.') + parser.add_argument( + '--multi', + action='store_true', + help='Use request rate range rather than single value.', + ) + parser.add_argument( + '--request-rate-range', + type=str, + default='2,34,2', + help='Range of request rates in the format start,stop,step. Default ' + 'is 2,34,2. It also supports a list of request rates, requiring ' + 'the parameters to not equal three.', + ) + parser.add_argument('--output-file', + type=str, + help='Output JSONL file name.') + parser.add_argument( + '--disable-tqdm', + action='store_true', + help='Specify to disable tqdm progress bar.', + ) + parser.add_argument( + '--disable-stream', + action='store_true', + help='Disable streaming mode.', + ) + parser.add_argument( + '--disable-ignore-eos', + action='store_true', + help='Disable ignoring EOS.', + ) + parser.add_argument( + '--extra-request-body', + metavar='{"key1": "value1", "key2": "value2"}', + type=str, + help='Append given JSON object to the request payload. 
You can use ' + 'this to specify additional generate params like sampling params.', + ) + args = parser.parse_args() + run_benchmark(args) From f28cf4ede1ab7402476daa837ae56ed21f8eca9c Mon Sep 17 00:00:00 2001 From: lvhan028 Date: Fri, 18 Oct 2024 15:59:38 +0800 Subject: [PATCH 2/3] rollback filename --- benchmark/{profile_serving.py => profile_restful_api.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename benchmark/{profile_serving.py => profile_restful_api.py} (100%) diff --git a/benchmark/profile_serving.py b/benchmark/profile_restful_api.py similarity index 100% rename from benchmark/profile_serving.py rename to benchmark/profile_restful_api.py From b919e1d74732c5fa907af027c4c923ca8cb6548b Mon Sep 17 00:00:00 2001 From: lvhan028 Date: Fri, 18 Oct 2024 20:00:41 +0800 Subject: [PATCH 3/3] update --- .github/workflows/stable.yml | 12 +++--- autotest/utils/benchmark_utils.py | 2 +- benchmark/README.md | 20 +--------- docs/en/advance/debug_turbomind.md | 2 +- docs/en/benchmark/profile_api_server.md | 43 ++-------------------- docs/zh_cn/advance/debug_turbomind.md | 2 +- docs/zh_cn/benchmark/profile_api_server.md | 42 ++------------------- 7 files changed, 18 insertions(+), 105 deletions(-) diff --git a/.github/workflows/stable.yml b/.github/workflows/stable.yml index 85daed8e2b..98faf2ffa4 100644 --- a/.github/workflows/stable.yml +++ b/.github/workflows/stable.yml @@ -143,12 +143,12 @@ jobs: opencompass .github/scripts/eval_stable_subject_config.py --reuse --dump-eval-details --work-dir ${{env.REPORT_DIR}}-subject-3 - name: Test lmdeploy - restful api run: | - python3 benchmark/profile_restful_api.py localhost:23344 /nvme/qa_test_models/${{matrix.model}} /nvme/qa_test_models/datasets/ShareGPT_V3_unfiltered_cleaned_split.json --stream-output True --num-prompts 10000 --csv ${{env.REPORT_DIR}}/stable.csv > ${{env.REPORT_DIR}}/stable.log - python3 /nvme/qa_test_models/offline_pkg/profile_restful_api_internal.py localhost:23344 /nvme/qa_test_models/${{matrix.model}} /nvme/qa_test_models/datasets/Mixed.json --stream-output True --num-prompts 100000 --csv ${{env.REPORT_DIR}}/stable-internal-1.csv > ${{env.REPORT_DIR}}/stable-internal-1.log - python3 /nvme/qa_test_models/offline_pkg/profile_restful_api_internal.py localhost:23344 /nvme/qa_test_models/${{matrix.model}} /nvme/qa_test_models/datasets/Mixed.json --stream-output True --num-prompts 100000 --csv ${{env.REPORT_DIR}}/stable-internal-2.csv > ${{env.REPORT_DIR}}/stable-internal-2.log - python3 /nvme/qa_test_models/offline_pkg/profile_restful_api_internal.py localhost:23344 /nvme/qa_test_models/${{matrix.model}} /nvme/qa_test_models/datasets/Mixed.json --stream-output True --num-prompts 100000 --csv ${{env.REPORT_DIR}}/stable-internal-3.csv > ${{env.REPORT_DIR}}/stable-internal-3.log - python3 /nvme/qa_test_models/offline_pkg/profile_restful_api_internal.py localhost:23344 /nvme/qa_test_models/${{matrix.model}} /nvme/qa_test_models/datasets/Mixed.json --stream-output True --num-prompts 100000 --csv ${{env.REPORT_DIR}}/stable-internal-2.csv > ${{env.REPORT_DIR}}/stable-internal-4.log - python3 /nvme/qa_test_models/offline_pkg/profile_restful_api_internal.py localhost:23344 /nvme/qa_test_models/${{matrix.model}} /nvme/qa_test_models/datasets/Mixed.json --stream-output True --num-prompts 100000 --csv ${{env.REPORT_DIR}}/stable-internal-3.csv > ${{env.REPORT_DIR}}/stable-internal-5.log + python3 benchmark/profile_restful_api.py --port 23344 --dataset-path /nvme/qa_test_models/datasets/ShareGPT_V3_unfiltered_cleaned_split.json 
--num-prompts 10000 > ${{env.REPORT_DIR}}/stable.log + python3 /nvme/qa_test_models/offline_pkg/profile_restful_api_internal.py localhost:23344 /nvme/qa_test_models/${{matrix.model}} /nvme/qa_test_models/datasets/Mixed.json --num-prompts 100000 > ${{env.REPORT_DIR}}/stable-internal-1.log + python3 /nvme/qa_test_models/offline_pkg/profile_restful_api_internal.py localhost:23344 /nvme/qa_test_models/${{matrix.model}} /nvme/qa_test_models/datasets/Mixed.json --num-prompts 100000 > ${{env.REPORT_DIR}}/stable-internal-2.log + python3 /nvme/qa_test_models/offline_pkg/profile_restful_api_internal.py localhost:23344 /nvme/qa_test_models/${{matrix.model}} /nvme/qa_test_models/datasets/Mixed.json --num-prompts 100000 > ${{env.REPORT_DIR}}/stable-internal-3.log + python3 /nvme/qa_test_models/offline_pkg/profile_restful_api_internal.py localhost:23344 /nvme/qa_test_models/${{matrix.model}} /nvme/qa_test_models/datasets/Mixed.json --num-prompts 100000 > ${{env.REPORT_DIR}}/stable-internal-4.log + python3 /nvme/qa_test_models/offline_pkg/profile_restful_api_internal.py localhost:23344 /nvme/qa_test_models/${{matrix.model}} /nvme/qa_test_models/datasets/Mixed.json --stream-output True --num-prompts 100000 > ${{env.REPORT_DIR}}/stable-internal-5.log - name: Attach result if: always() run: | diff --git a/autotest/utils/benchmark_utils.py b/autotest/utils/benchmark_utils.py index 3da375ccb5..24eb6c8f1c 100644 --- a/autotest/utils/benchmark_utils.py +++ b/autotest/utils/benchmark_utils.py @@ -168,7 +168,7 @@ def restful_test(config, if not health_check(http_url): return False, 'server not start' - command = f'python3 benchmark/profile_restful_api.py localhost:{port} {model_path} {dataset_path} --stream-output True ' # noqa: F401, E501 + command = f'python3 benchmark/profile_restful_api.py --port {port} --tokenizer {model_path} --dataset-path {dataset_path}' # noqa: F401, E501 if is_smoke: command += ' --num-prompts 200' else: diff --git a/benchmark/README.md b/benchmark/README.md index 057d38bb11..9e56768640 100644 --- a/benchmark/README.md +++ b/benchmark/README.md @@ -33,20 +33,6 @@ python profile_generation.py \ --concurrency 1 8 --prompt-tokens 1 512 --completion-tokens 2048 512 ``` -## profile serving - -Tools above profile models with Python API. `profile_serving.py` is used to do benchmark on serving. - -```bash -wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json - -python profile_serving.py \ - ${TritonServerAddress} \ - /path/to/tokenizer \ # ends with .model for most models. Otherwise, please pass model_path/triton_models/tokenizer. - ShareGPT_V3_unfiltered_cleaned_split.json \ - --concurrency 64 -``` - ## profile restful api `profile_restful_api.py` is used to do benchmark on api server. @@ -54,9 +40,5 @@ python profile_serving.py \ ```bash wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json -python profile_restful_api.py \ - ${ServerAddress} \ - /path/to/tokenizer \ # ends with .model for most models. Otherwise, please pass model_path/triton_models/tokenizer. 
- ShareGPT_V3_unfiltered_cleaned_split.json \ - --concurrency 64 +python3 profile_restful_api.py --backend lmdeploy --dataset-path ./ShareGPT_V3_unfiltered_cleaned_split.json ``` diff --git a/docs/en/advance/debug_turbomind.md b/docs/en/advance/debug_turbomind.md index 91733ce2a5..d38b548a95 100644 --- a/docs/en/advance/debug_turbomind.md +++ b/docs/en/advance/debug_turbomind.md @@ -129,7 +129,7 @@ Reading symbols from python3... # (Optional) Use https://github.com/InternLM/lmdeploy/blob/main/benchmark/profile_restful_api.py to send a request -python3 profile_restful_api.py --server_addr 127.0.0.1:23333 --tokenizer_path /workdir/Llama-2-13b-chat-hf --dataset /workdir/ShareGPT_V3_unfiltered_cleaned_split.json --concurrency 1 --num_prompts 1 +python3 profile_restful_api.py --backend lmdeploy --dataset-path /workdir/ShareGPT_V3_unfiltered_cleaned_split.json --num_prompts 1 ```` ## Using GDB diff --git a/docs/en/benchmark/profile_api_server.md b/docs/en/benchmark/profile_api_server.md index 07dfc49007..c8b626af36 100644 --- a/docs/en/benchmark/profile_api_server.md +++ b/docs/en/benchmark/profile_api_server.md @@ -1,52 +1,17 @@ # Profile API Server -The way to profiling `api_server` performance is similar to the method for [profiling throughput](./profile_throughput.md). The difference is `api_server` should be launched successfully before testing. - -The profiling script is `profile_restful_api.py`. Before running it, please install the lmdeploy precompiled package, download the script and the test dataset: +Before benchmarking the api_server, please install the lmdeploy precompiled package and download the script and the test dataset: ```shell pip install lmdeploy git clone --depth=1 https://github.com/InternLM/lmdeploy cd lmdeploy/benchmark -wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json -``` - -## Metrics - -LMDeploy records the performance metrics like first token latency, token throughput (tokens/s) and request throughput (RPM) - -`first_token_latency` is only reported in the case of streaming inference. - -The formula for calculating `token throughput` is: - -$$ -TokenThroughput = Number\\ of\\ generated\\ tokens/TotalTime -$$ - -And the formula for calculating `request throughput` is: - -$$ -RPM(request\\ per\\ minute)=Number\\ of\\ prompts/TotalTime * 60 -$$ - -Total time includes prefill time. - -## Profile - -In this section, we take [internlm/internlm-7b](https://huggingface.co/internlm/internlm-7b) as an example to show the benchmark procedure. - -### Launch api_server - -```shell -lmdeploy serve api_server internlm/internlm-7b ``` -If you would like to change the server's port or other parameters, such as inference engine, max batch size and etc., please run `lmdeploy serve api_server -h` or read [this](../llm/api_server.md) guide to get the detailed explanation. - -### Profile +Launch the server first (you may refer [here](../llm/api_server.md) for guide) and run the following command: ```shell -python3 profile_restful_api.py http://0.0.0.0:23333 internlm/internlm-7b ./ShareGPT_V3_unfiltered_cleaned_split.json +python3 benchmark/profile_restful_api.py --backend lmdeploy --num-prompts 5000 --dataset-path ShareGPT_V3_unfiltered_cleaned_split.json ``` -For detailed argument specification of `profile_restful_api.py`, such as request concurrency, sampling parameters an so on, please run the help command `python3 profile_restful_api.py -h`. 
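Note on the reported metrics: the new script replaces the old "Metrics" section with a richer report. Aggregate request and token throughput are simply the corresponding totals divided by the benchmark duration, while the latency columns (TTFT, TPOT, ITL, end-to-end) are derived per request. The snippet below is a simplified sketch that mirrors `calculate_metrics()` in the new `profile_restful_api.py`; the numbers are illustrative only.

```python
# Illustrative per-request latency breakdown, mirroring calculate_metrics()
# in profile_restful_api.py. The values below are made up.
latency = 2.0      # end-to-end latency of one request, in seconds
ttft = 0.25        # time to first token, in seconds
output_len = 128   # tokens generated for this request
itl = [0.014] * (output_len - 1)  # inter-token latencies collected while streaming

# TPOT (time per output token) excludes the first token.
tpot_ms = (latency - ttft) / (output_len - 1) * 1000   # ~13.78 ms
ttft_ms = ttft * 1000                                  # 250.00 ms
mean_itl_ms = sum(itl) / len(itl) * 1000               # ~14.00 ms

print(f'TTFT {ttft_ms:.2f} ms, TPOT {tpot_ms:.2f} ms, mean ITL {mean_itl_ms:.2f} ms')
```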
+For detailed argument specification of `profile_restful_api.py`, please run the help command `python3 benchmark/profile_restful_api.py -h`. diff --git a/docs/zh_cn/advance/debug_turbomind.md b/docs/zh_cn/advance/debug_turbomind.md index 3c3b75421d..7c00e2f9d6 100644 --- a/docs/zh_cn/advance/debug_turbomind.md +++ b/docs/zh_cn/advance/debug_turbomind.md @@ -129,7 +129,7 @@ Reading symbols from python3... # (可选) 使用 https://github.com/InternLM/lmdeploy/blob/main/benchmark/profile_restful_api.py 发送请求 -python3 profile_restful_api.py --server_addr 127.0.0.1:23333 --tokenizer_path /workdir/Llama-2-13b-chat-hf --dataset /workdir/ShareGPT_V3_unfiltered_cleaned_split.json --concurrency 1 --num_prompts 1 +python3 profile_restful_api.py --backend lmdeploy --dataset-path ./ShareGPT_V3_unfiltered_cleaned_split.json ```` ## 使用 GDB diff --git a/docs/zh_cn/benchmark/profile_api_server.md b/docs/zh_cn/benchmark/profile_api_server.md index c872820040..687b22ad19 100644 --- a/docs/zh_cn/benchmark/profile_api_server.md +++ b/docs/zh_cn/benchmark/profile_api_server.md @@ -1,8 +1,6 @@ # api_server 性能测试 -api_server 的测试方式与[求吞吐量测试方法](./profile_throughput.md)类似。不同的是,在测试前,需要先启动 api_server,然后再通过测试脚本发送请求进行测试。 - -测试脚本是 `profile_restful_api.py`。测试之前,请安装 lmdeploy 预编译包,并下载评测脚本和测试数据集。 +测试之前,请安装 lmdeploy 预编译包,并下载测试脚本和数据。 ```shell pip install lmdeploy @@ -11,42 +9,10 @@ cd lmdeploy/benchmark wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json ``` -## 测量指标 - -LMDeploy 统计首token延时(first_token_latency)、token吞吐量(tokens/s)和请求吞吐量(RPM)。 - -`first_token_latency` 只有在流式推理的情况下才会输出。 - -token吞吐量的计算公式为: - -$$ -吞吐量 = 生成的token数量 / 总时间 -$$ - -请求吞吐量的计算公式为: - -$$ -吞吐量 = 请求数量 / 总时间 -$$ - -总时间包括 prefill 时间 - -## 测量方法 - -我们以 [internlm/internlm-7b](https://huggingface.co/internlm/internlm-7b) 为例,展示 api_server 的性能测试流程 - -### 启动服务 - -```shell -lmdeploy serve api_server internlm/internlm-7b -``` - -如果你想改变 server 的端口,或者诸如推理引擎、最大批处理值等参数,请运行 `lmdeploy serve api_server -h` 或者阅读[这篇文档](../llm/api_server.md),查看详细的参数说明。 - -### 测速 +然后,启动模型服务(可以参考[这里](../llm/api_server.md))。接着,使用下面的命令: ```shell -python3 profile_restful_api.py http://0.0.0.0:23333 internlm/internlm-7b ./ShareGPT_V3_unfiltered_cleaned_split.json +python3 profile_restful_api.py --backend lmdeploy --dataset-path ShareGPT_V3_unfiltered_cleaned_split.json ``` -关于 `profile_restful_api.py` 脚本中的参数,比如请求并发数、采样参数等等,可以通过运行命令 `python3 profile_restful_api.py -h` 查阅。 +关于 `profile_restful_api.py`的帮助信息,可以通过`python3 profile_restful_api.py -h`查阅