From 521dcfa2299e5f1494ab5cb3fc936baf9a76d800 Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Sun, 26 Mar 2023 12:41:47 +0800 Subject: [PATCH 01/13] modify current frontend & server to prepare for real frontend --- cacheflow/master/frontend.py | 26 ++------ cacheflow/master/scheduler.py | 22 ++++--- cacheflow/sampling_params.py | 13 ++++ cacheflow/utils.py | 1 + server.py | 114 ++++++---------------------------- 5 files changed, 52 insertions(+), 124 deletions(-) diff --git a/cacheflow/master/frontend.py b/cacheflow/master/frontend.py index cfa17684fd56..cb188387db2c 100644 --- a/cacheflow/master/frontend.py +++ b/cacheflow/master/frontend.py @@ -22,30 +22,16 @@ def __init__( self.seq_counter = Counter() self.inputs: List[Tuple[SequenceGroup, SamplingParams]] = [] + def add_eos_token(self, sampling_params: SamplingParams) -> SamplingParams: + # Stop generation when we see an EOS token. + sampling_params.stop_token_ids.add(self.tokenizer.eos_token_id) + return sampling_params + def query( self, prompt: str, - n: int = 1, - temperature: float = 1.0, - top_p: float = 1.0, - use_beam_search: bool = False, - stop_token_ids: Set[int] = set(), - max_num_steps: int = 16, # From OpenAI API. - num_logprobs: int = 0, - context_window_size: Optional[int] = None, + sampling_params: SamplingParams, ) -> None: - # Stop when we see an EOS token. - stop_token_ids.add(self.tokenizer.eos_token_id) - sampling_params = SamplingParams( - n=n, - temperature=temperature, - top_p=top_p, - use_beam_search=use_beam_search, - stop_token_ids=stop_token_ids, - max_num_steps=max_num_steps, - num_logprobs=num_logprobs, - context_window_size=context_window_size, - ) token_ids = self.tokenizer.encode(prompt) self._add_query(token_ids, sampling_params) diff --git a/cacheflow/master/scheduler.py b/cacheflow/master/scheduler.py index 3931d92684f3..cd7a34d7584d 100644 --- a/cacheflow/master/scheduler.py +++ b/cacheflow/master/scheduler.py @@ -1,7 +1,6 @@ from typing import Dict, List from cacheflow.master.block_manager import BlockSpaceManager -from cacheflow.master.frontend import Frontend from cacheflow.sampling_params import SamplingParams from cacheflow.sequence import Sequence from cacheflow.sequence import SequenceGroup @@ -14,14 +13,12 @@ class Scheduler: def __init__( self, - frontend: Frontend, controllers: List, block_size: int, num_gpu_blocks: int, num_cpu_blocks: int, max_num_batched_tokens: int, ) -> None: - self.frontend = frontend self.controllers = controllers self.block_size = block_size self.num_gpu_blocks = num_gpu_blocks @@ -46,10 +43,15 @@ def __init__( self.swapped: List[SequenceGroup] = [] # Pending sequence groups (FIFO). self.pending: List[SequenceGroup] = [] + # Finished sequence groups. + self.finished: List[SequenceGroup] = [] - def _fetch_inputs(self) -> None: - inputs = self.frontend.get_inputs() - for seq_group, sampling_params in inputs: + def add_sequence_groups( + self, + sequence_groups: List[SequenceGroup, SamplingParams], + ) -> None: + # Add sequence groups to the pending queue. + for seq_group, sampling_params in sequence_groups: self.pending.append(seq_group) self.sampling_params[seq_group.group_id] = sampling_params @@ -158,7 +160,6 @@ def step(self) -> None: # 3. Join new sequences if possible. # NOTE: Here we implicitly assume FCFS scheduling. # TODO(woosuk): Add a batching policy to control the batch size. 
- self._fetch_inputs() if not self.swapped: for i, seq_group in enumerate(self.pending): num_prompt_tokens = seq_group.seqs[0].get_len() @@ -277,4 +278,9 @@ def _return(self, seq_group: SequenceGroup) -> None: group_id = seq_group.group_id del self.num_steps[group_id] del self.sampling_params[group_id] - self.frontend.print_response(seq_group) + self.finished.append(seq_group) + + def get_finished(self) -> List[SequenceGroup]: + finished = self.finished + self.finished = [] + return finished diff --git a/cacheflow/sampling_params.py b/cacheflow/sampling_params.py index 5f446198bd67..e7f118bba951 100644 --- a/cacheflow/sampling_params.py +++ b/cacheflow/sampling_params.py @@ -69,3 +69,16 @@ def __repr__(self) -> str: f'max_num_steps={self.max_num_steps}, ' f'num_logprobs={self.num_logprobs}, ' f'context_window_size={self.context_window_size})') + + @classmethod + def from_dict(cls, d: dict) -> 'SamplingParams': + return cls( + n=d.get('n', 1), + temperature=d.get('temperature', 1.0), + top_p=d.get('top_p', 1.0), + use_beam_search=d.get('use_beam_search', False), + stop_token_ids=set(d.get('stop_token_ids', set())), + max_num_steps=d.get('max_num_steps', 16), + num_logprobs=d.get('num_logprobs', 0), + context_window_size=d.get('context_window_size', None), + ) diff --git a/cacheflow/utils.py b/cacheflow/utils.py index db8eb8aaba4c..e987449fd9dd 100644 --- a/cacheflow/utils.py +++ b/cacheflow/utils.py @@ -26,6 +26,7 @@ def __next__(self) -> int: def reset(self) -> None: self.counter = 0 + def set_random_seed(seed: int): random.seed(seed) np.random.seed(seed) diff --git a/server.py b/server.py index 5838f439f53e..e200f9c6bc13 100644 --- a/server.py +++ b/server.py @@ -1,80 +1,13 @@ import argparse -import random -from typing import List, Tuple, Dict - -import ray +from typing import List from cacheflow.master.frontend import Frontend from cacheflow.master.scheduler import Scheduler +from cacheflow.master.server_utils import (initialize_ray_cluster, + add_server_arguments) from cacheflow.models import get_memory_analyzer -from cacheflow.worker.controller import Controller, DeviceID - - -def initialize_ray_cluster( - address: str = 'auto', - pipeline_parallel_size: int = 1, - tensor_parallel_size: int = 1, -) -> Tuple[int, int, str, List[List[DeviceID]]]: - # Connect to a ray cluster. - ray.init(address=address) - - # Assume we have a uniform cluster that each node has the same number of - # GPUs for now. - valid_node_resources = [] - num_devices_per_node = None - for node in ray.nodes(): - if (not node['Alive']) or node['Resources']['GPU'] <= 0: - continue - if num_devices_per_node is None: - num_devices_per_node = node['Resources']['GPU'] - else: - assert num_devices_per_node == node['Resources']['GPU'], ( - "The number of GPUs per node is not uniform.") - for key in node['Resources']: - if key.startswith('node:'): - valid_node_resources.append(key) - - num_nodes = len(valid_node_resources) - - assert (pipeline_parallel_size * tensor_parallel_size - <= num_nodes * num_devices_per_node), ( - "The number of required GPUs exceeds the total number of " - "available GPUs.") - if tensor_parallel_size >= num_devices_per_node: - assert tensor_parallel_size % num_devices_per_node == 0, ( - "The number of tensor parallelism is not divisible by the " - "number of GPUs per node.") - else: - assert num_devices_per_node % tensor_parallel_size == 0, ( - "The number of GPUs per node is not divisible by the number " - "of tensor parallelism.") - - # Assign GPUs to pipeline stages. 
- rank = 0 - current_node_id = 0 - current_device_id = 0 - distributed_init_method = None - all_stage_devices = [] - - for i in range(pipeline_parallel_size): - stage_devices = [] - for j in range(tensor_parallel_size): - node_resource = valid_node_resources[current_node_id] - stage_devices.append((rank, node_resource, current_device_id)) - if distributed_init_method is None: - ip = node_resource.split("node:")[-1] - port = random.randint(10000, 20000) - distributed_init_method = f"tcp://{ip}:{port}" - rank += 1 - current_device_id += 1 - if current_device_id >= num_devices_per_node: - current_node_id += 1 - current_device_id = 0 - all_stage_devices.append(stage_devices) - - return (num_nodes, num_devices_per_node, distributed_init_method, - all_stage_devices) - +from cacheflow.worker.controller import Controller +from cacheflow.sampling_params import SamplingParams def main(args: argparse.Namespace): # TODO(zhuohan): Support pipeline parallelism. @@ -121,15 +54,8 @@ def main(args: argparse.Namespace): ) controllers.append(controller) - # Create a frontend. - frontend = Frontend( - model_name=args.model, - block_size=args.block_size, - ) - # Create a scheduler. scheduler = Scheduler( - frontend=frontend, controllers=controllers, block_size=args.block_size, num_gpu_blocks=num_gpu_blocks, @@ -141,6 +67,12 @@ def main(args: argparse.Namespace): controllers[i].set_next(controllers[i + 1]) controllers[-1].set_next(scheduler) + # Create a frontend. + frontend = Frontend( + model_name=args.model, + block_size=args.block_size, + ) + # Test the following inputs. test_inputs = [ ('Ion Stoica is a', {'n': 4, 'use_beam_search': True, 'temperature': 0.0}), @@ -149,30 +81,20 @@ def main(args: argparse.Namespace): ] while True: if test_inputs: - text, sampling_params = test_inputs.pop(0) + text, sampling_params_dict = test_inputs.pop(0) + sampling_params = SamplingParams.from_dict(sampling_params_dict) + sampling_params = frontend.add_eos_token(sampling_params) frontend.query(text, **sampling_params) + scheduler.add_sequence_groups(frontend.get_inputs()) scheduler.step() + for seq_group in scheduler.get_finished(): + frontend.print_response(seq_group) if not (scheduler.pending or scheduler.running or test_inputs): break if __name__ == '__main__': parser = argparse.ArgumentParser(description='CacheFlow server') - # Model arguments - parser.add_argument('--model', type=str, default='facebook/opt-125m', help='model name') - parser.add_argument('--model-path', type=str, default='~/.cacheflow/model_weights', - help='model path to download and load the weights') - # Parallel arguments - parser.add_argument('--pipeline-parallel-size', type=int, default=1, help='number of pipeline stages') - parser.add_argument('--tensor-parallel-size', type=int, default=1, help='number of tensor parallel replicas') - # KV cache arguments - parser.add_argument('--block-size', type=int, default=8, choices=[8, 16], help='token block size') - # NOTE(woosuk): If FlashAttention is used, the float data type is not supported. - parser.add_argument('--dtype', type=str, default='half', choices=['half', 'float'], help='data type') - # TODO(woosuk): Support fine-grained seeds (e.g., seed per request). 
- parser.add_argument('--seed', type=int, default=0, help='random seed') - parser.add_argument('--swap-space', type=int, default=20, help='CPU swap space size (GiB) per GPU') - parser.add_argument('--max-batch-size', type=int, default=2560, help='maximum number of batched tokens') + parser = add_server_arguments(parser) args = parser.parse_args() - main(args) From 3e8c1320bec89533de9aa272deda47a1209b54eb Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Sun, 26 Mar 2023 12:45:43 +0800 Subject: [PATCH 02/13] add server_utils --- cacheflow/master/server_utils.py | 91 ++++++++++++++++++++++++++++++++ 1 file changed, 91 insertions(+) create mode 100644 cacheflow/master/server_utils.py diff --git a/cacheflow/master/server_utils.py b/cacheflow/master/server_utils.py new file mode 100644 index 000000000000..0e684abaabd9 --- /dev/null +++ b/cacheflow/master/server_utils.py @@ -0,0 +1,91 @@ +from typing import List, Tuple +import random + +import ray + +from cacheflow.worker.controller import DeviceID + + +def initialize_ray_cluster( + address: str = 'auto', + pipeline_parallel_size: int = 1, + tensor_parallel_size: int = 1, +) -> Tuple[int, int, str, List[List[DeviceID]]]: + # Connect to a ray cluster. + ray.init(address=address) + + # Assume we have a uniform cluster that each node has the same number of + # GPUs for now. + valid_node_resources = [] + num_devices_per_node = None + for node in ray.nodes(): + if (not node['Alive']) or node['Resources']['GPU'] <= 0: + continue + if num_devices_per_node is None: + num_devices_per_node = node['Resources']['GPU'] + else: + assert num_devices_per_node == node['Resources']['GPU'], ( + "The number of GPUs per node is not uniform.") + for key in node['Resources']: + if key.startswith('node:'): + valid_node_resources.append(key) + + num_nodes = len(valid_node_resources) + + assert (pipeline_parallel_size * tensor_parallel_size + <= num_nodes * num_devices_per_node), ( + "The number of required GPUs exceeds the total number of " + "available GPUs.") + if tensor_parallel_size >= num_devices_per_node: + assert tensor_parallel_size % num_devices_per_node == 0, ( + "The number of tensor parallelism is not divisible by the " + "number of GPUs per node.") + else: + assert num_devices_per_node % tensor_parallel_size == 0, ( + "The number of GPUs per node is not divisible by the number " + "of tensor parallelism.") + + # Assign GPUs to pipeline stages. 
+ rank = 0 + current_node_id = 0 + current_device_id = 0 + distributed_init_method = None + all_stage_devices = [] + + for i in range(pipeline_parallel_size): + stage_devices = [] + for j in range(tensor_parallel_size): + node_resource = valid_node_resources[current_node_id] + stage_devices.append((rank, node_resource, current_device_id)) + if distributed_init_method is None: + ip = node_resource.split("node:")[-1] + port = random.randint(10000, 20000) + distributed_init_method = f"tcp://{ip}:{port}" + rank += 1 + current_device_id += 1 + if current_device_id >= num_devices_per_node: + current_node_id += 1 + current_device_id = 0 + all_stage_devices.append(stage_devices) + + return (num_nodes, num_devices_per_node, distributed_init_method, + all_stage_devices) + + +def add_server_arguments(parser): + # Model arguments + parser.add_argument('--model', type=str, default='facebook/opt-125m', help='model name') + parser.add_argument('--model-path', type=str, default='~/.cacheflow/model_weights', + help='model path to download and load the weights') + # Parallel arguments + parser.add_argument('--pipeline-parallel-size', type=int, default=1, help='number of pipeline stages') + parser.add_argument('--tensor-parallel-size', type=int, default=1, help='number of tensor parallel replicas') + # KV cache arguments + parser.add_argument('--block-size', type=int, default=8, choices=[8, 16], help='token block size') + # NOTE(woosuk): If FlashAttention is used, the float data type is not supported. + parser.add_argument('--dtype', type=str, default='half', choices=['half', 'float'], help='data type') + # TODO(woosuk): Support fine-grained seeds (e.g., seed per request). + parser.add_argument('--seed', type=int, default=0, help='random seed') + parser.add_argument('--swap-space', type=int, default=20, help='CPU swap space size (GiB) per GPU') + parser.add_argument('--max-batch-size', type=int, default=2560, help='maximum number of batched tokens') + return parser \ No newline at end of file From 8873415eb655db8a53d06970d93b1fb2bbfcd4f3 Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Sun, 26 Mar 2023 04:47:43 +0000 Subject: [PATCH 03/13] fix small bugs --- cacheflow/master/scheduler.py | 4 ++-- server.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/cacheflow/master/scheduler.py b/cacheflow/master/scheduler.py index cd7a34d7584d..1ade007b5f22 100644 --- a/cacheflow/master/scheduler.py +++ b/cacheflow/master/scheduler.py @@ -1,4 +1,4 @@ -from typing import Dict, List +from typing import Dict, List, Tuple from cacheflow.master.block_manager import BlockSpaceManager from cacheflow.sampling_params import SamplingParams @@ -48,7 +48,7 @@ def __init__( def add_sequence_groups( self, - sequence_groups: List[SequenceGroup, SamplingParams], + sequence_groups: List[Tuple[SequenceGroup, SamplingParams]], ) -> None: # Add sequence groups to the pending queue. 
for seq_group, sampling_params in sequence_groups: diff --git a/server.py b/server.py index e200f9c6bc13..e4df9ee470ff 100644 --- a/server.py +++ b/server.py @@ -84,7 +84,7 @@ def main(args: argparse.Namespace): text, sampling_params_dict = test_inputs.pop(0) sampling_params = SamplingParams.from_dict(sampling_params_dict) sampling_params = frontend.add_eos_token(sampling_params) - frontend.query(text, **sampling_params) + frontend.query(text, sampling_params) scheduler.add_sequence_groups(frontend.get_inputs()) scheduler.step() for seq_group in scheduler.get_finished(): From 71c2b934b2fe9311ec077e6f540f8cc30938941e Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Sun, 26 Mar 2023 22:03:50 +0800 Subject: [PATCH 04/13] add a new server class --- .../master/{server_utils.py => server.py} | 79 ++++++++++++++++++- server.py | 77 ++++-------------- 2 files changed, 92 insertions(+), 64 deletions(-) rename cacheflow/master/{server_utils.py => server.py} (56%) diff --git a/cacheflow/master/server_utils.py b/cacheflow/master/server.py similarity index 56% rename from cacheflow/master/server_utils.py rename to cacheflow/master/server.py index 0e684abaabd9..dc31f2bc55c4 100644 --- a/cacheflow/master/server_utils.py +++ b/cacheflow/master/server.py @@ -1,9 +1,84 @@ +import argparse from typing import List, Tuple import random import ray -from cacheflow.worker.controller import DeviceID +from cacheflow.master.scheduler import Scheduler +from cacheflow.models import get_memory_analyzer +from cacheflow.worker.controller import Controller, DeviceID + + +class Server: + def __init__( + self, + model: str, + model_path: str, + pipeline_parallel_size: int, + tensor_parallel_size: int, + block_size: int, + dtype: str, + seed: int, + swap_space: int, + max_batch_size: int, + ): + # TODO(zhuohan): Support pipeline parallelism. + assert pipeline_parallel_size == 1, ( + 'Pipeline parallelism is not supported yet.') + + (self.num_nodes, self.num_devices_per_node, distributed_init_method, + all_stage_devices) = ( + initialize_ray_cluster( + pipeline_parallel_size=pipeline_parallel_size, + tensor_parallel_size=tensor_parallel_size)) + + self.world_size = pipeline_parallel_size * tensor_parallel_size + + self.memory_analyzer = get_memory_analyzer( + model_name=model, + block_size=block_size, + dtype=dtype, + tensor_parallel_size=tensor_parallel_size, + ) + self.num_gpu_blocks = self.memory_analyzer.get_max_num_gpu_blocks( + max_num_batched_tokens=max_batch_size) + self.num_cpu_blocks = self.memory_analyzer.get_max_num_cpu_blocks( + swap_space=swap_space) + print(f'# GPU blocks: {self.num_gpu_blocks}, ' + f'# CPU blocks: {self.num_cpu_blocks}') + + # Create a controller for each pipeline stage. + self.controllers: List[Controller] = [] + for i in range(pipeline_parallel_size): + controller = Controller( + stage_id=i, + stage_devices=all_stage_devices[i], + world_size=self.world_size, + pipeline_parallel_size=pipeline_parallel_size, + tensor_parallel_size=tensor_parallel_size, + distributed_init_method=distributed_init_method, + model_name=model, + block_size=block_size, + num_gpu_blocks=self.num_gpu_blocks, + num_cpu_blocks=self.num_cpu_blocks, + dtype=dtype, + seed=seed, + model_path=model_path, + ) + self.controllers.append(controller) + + # Create a scheduler. + self.scheduler = Scheduler( + controllers=self.controllers, + block_size=block_size, + num_gpu_blocks=self.num_gpu_blocks, + num_cpu_blocks=self.num_cpu_blocks, + max_num_batched_tokens=max_batch_size, + ) + # Connect the controllers. 
+ for i in range(len(self.controllers) - 1): + self.controllers[i].set_next(self.controllers[i + 1]) + self.controllers[-1].set_next(self.scheduler) def initialize_ray_cluster( @@ -72,7 +147,7 @@ def initialize_ray_cluster( all_stage_devices) -def add_server_arguments(parser): +def add_server_arguments(parser: argparse.ArgumentParser): # Model arguments parser.add_argument('--model', type=str, default='facebook/opt-125m', help='model name') parser.add_argument('--model-path', type=str, default='~/.cacheflow/model_weights', diff --git a/server.py b/server.py index e4df9ee470ff..9384efc59524 100644 --- a/server.py +++ b/server.py @@ -2,70 +2,22 @@ from typing import List from cacheflow.master.frontend import Frontend -from cacheflow.master.scheduler import Scheduler -from cacheflow.master.server_utils import (initialize_ray_cluster, - add_server_arguments) -from cacheflow.models import get_memory_analyzer -from cacheflow.worker.controller import Controller +from cacheflow.master.server import Server, add_server_arguments from cacheflow.sampling_params import SamplingParams def main(args: argparse.Namespace): - # TODO(zhuohan): Support pipeline parallelism. - assert args.pipeline_parallel_size == 1, ( - 'Pipeline parallelism is not supported yet.') - - (num_nodes, num_devices_per_node, distributed_init_method, - all_stage_devices) = ( - initialize_ray_cluster( - pipeline_parallel_size=args.pipeline_parallel_size, - tensor_parallel_size=args.tensor_parallel_size)) - - world_size = args.pipeline_parallel_size * args.tensor_parallel_size - - memory_analyzer = get_memory_analyzer( - model_name=args.model, - block_size=args.block_size, - dtype=args.dtype, + # Create a server. + server = Server( + model=args.model, + model_path=args.model_path, + pipeline_parallel_size=args.pipeline_parallel_size, tensor_parallel_size=args.tensor_parallel_size, - ) - num_gpu_blocks = memory_analyzer.get_max_num_gpu_blocks( - max_num_batched_tokens=args.max_batch_size) - num_cpu_blocks = memory_analyzer.get_max_num_cpu_blocks( - swap_space=args.swap_space) - print(f'# GPU blocks: {num_gpu_blocks}, # CPU blocks: {num_cpu_blocks}') - - # Create a controller for each pipeline stage. - controllers: List[Controller] = [] - for i in range(args.pipeline_parallel_size): - controller = Controller( - stage_id=i, - stage_devices=all_stage_devices[i], - world_size=world_size, - pipeline_parallel_size=args.pipeline_parallel_size, - tensor_parallel_size=args.tensor_parallel_size, - distributed_init_method=distributed_init_method, - model_name=args.model, - block_size=args.block_size, - num_gpu_blocks=num_gpu_blocks, - num_cpu_blocks=num_cpu_blocks, - dtype=args.dtype, - seed=args.seed, - model_path=args.model_path, - ) - controllers.append(controller) - - # Create a scheduler. - scheduler = Scheduler( - controllers=controllers, block_size=args.block_size, - num_gpu_blocks=num_gpu_blocks, - num_cpu_blocks=num_cpu_blocks, - max_num_batched_tokens=args.max_batch_size, + dtype=args.dtype, + seed=args.seed, + swap_space=args.swap_space, + max_batch_size=args.max_batch_size, ) - # Connect the controllers. - for i in range(len(controllers) - 1): - controllers[i].set_next(controllers[i + 1]) - controllers[-1].set_next(scheduler) # Create a frontend. 
frontend = Frontend( @@ -85,11 +37,12 @@ def main(args: argparse.Namespace): sampling_params = SamplingParams.from_dict(sampling_params_dict) sampling_params = frontend.add_eos_token(sampling_params) frontend.query(text, sampling_params) - scheduler.add_sequence_groups(frontend.get_inputs()) - scheduler.step() - for seq_group in scheduler.get_finished(): + server.scheduler.add_sequence_groups(frontend.get_inputs()) + server.scheduler.step() + for seq_group in server.scheduler.get_finished(): frontend.print_response(seq_group) - if not (scheduler.pending or scheduler.running or test_inputs): + if not (server.scheduler.pending or server.scheduler.running or + test_inputs): break From d262ac92f03743d07b4a3881d9ee295a3574d4b1 Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Mon, 27 Mar 2023 11:53:07 +0800 Subject: [PATCH 05/13] initial implementation of fastapi frontend --- cacheflow/http_frontend/fastapi_frontend.py | 144 ++++++++++++++++++++ cacheflow/master/frontend.py | 3 +- cacheflow/master/scheduler.py | 18 +-- cacheflow/master/server.py | 32 +++-- playground/http_client.py | 20 +++ playground/streaming_fastapi_worker.py | 40 ++++++ server.py | 30 +++- 7 files changed, 256 insertions(+), 31 deletions(-) create mode 100644 cacheflow/http_frontend/fastapi_frontend.py create mode 100644 playground/http_client.py create mode 100644 playground/streaming_fastapi_worker.py diff --git a/cacheflow/http_frontend/fastapi_frontend.py b/cacheflow/http_frontend/fastapi_frontend.py new file mode 100644 index 000000000000..ca174c9d0b2b --- /dev/null +++ b/cacheflow/http_frontend/fastapi_frontend.py @@ -0,0 +1,144 @@ +import argparse +import asyncio +import time +from typing import List, Dict +import json + +import ray +from transformers import AutoTokenizer +from fastapi import FastAPI, Request +from fastapi.responses import StreamingResponse +import uvicorn + +from cacheflow.sampling_params import SamplingParams +from cacheflow.sequence import Sequence, SequenceGroup +from cacheflow.master.server import (Server, add_server_arguments, + initialize_ray_cluster) +from cacheflow.worker.controller import DeviceID +from cacheflow.utils import Counter + +app = FastAPI() + +class FastAPIFrontend: + def __init__( + self, + model: str, + model_path: str, + pipeline_parallel_size: int, + tensor_parallel_size: int, + block_size: int, + dtype: str, + seed: int, + swap_space: int, + max_batch_size: int, + num_nodes: int, + num_devices_per_node: int, + distributed_init_method: str, + all_stage_devices: List[List[DeviceID]], + ): + self.block_size = block_size + + self.tokenizer = AutoTokenizer.from_pretrained(model) + self.seq_group_counter = Counter() + self.seq_counter = Counter() + remote_server_class = ray.remote(num_cpus=0)(Server) + self.server = remote_server_class( + model=model, + model_path=model_path, + pipeline_parallel_size=pipeline_parallel_size, + tensor_parallel_size=tensor_parallel_size, + block_size=block_size, + dtype=dtype, + seed=seed, + swap_space=swap_space, + max_batch_size=max_batch_size, + num_nodes=num_nodes, + num_devices_per_node=num_devices_per_node, + distributed_init_method=distributed_init_method, + all_stage_devices=all_stage_devices, + ) + + self.running_seq_groups: Dict[int, SequenceGroup] = {} + self.sequence_group_events: Dict[int, asyncio.Event] = {} + self.is_server_running = False + + async def server_step(self): + self.is_server_running = True + updated_seq_groups = await self.server.step.remote() + self.is_server_running = False + for seq_group in updated_seq_groups: + 
group_id = seq_group.group_id + self.running_seq_groups[group_id] = seq_group + self.sequence_group_events[group_id].set() + + async def generate(self, request_dict: Dict): + prompt = request_dict["prompt"] + sampling_params = SamplingParams.from_dict(request_dict) + sampling_params.stop_token_ids.add(self.tokenizer.eos_token_id) + token_ids = self.tokenizer.encode(prompt) + seqs: List[Sequence] = [] + for _ in range(sampling_params.n): + seq_id = next(self.seq_counter) + seq = Sequence(seq_id, token_ids, block_size=self.block_size) + seqs.append(seq) + + group_id = next(self.seq_group_counter) + seq_group = SequenceGroup(group_id, seqs) + group_event = asyncio.Event() + self.sequence_group_events[group_id] = group_event + await self.server.add_sequence_groups.remote([seq_group, sampling_params]) + while True: + if not self.is_server_running: + await self.server_step() + await group_event.wait() + seq_group = self.running_seq_groups[group_id] + all_outputs = [] + for seq in seq_group.seqs: + token_ids = seq.get_token_ids() + output = self.tokenizer.decode(token_ids, skip_special_tokens=True) + all_outputs.append(output) + ret = { + "text": all_outputs, + "error": 0, + } + yield (json.dumps(ret) + "\0").encode("utf-8") + if seq_group.is_finished(): + break + + +@app.post("/generate") +async def generate_stream(request: Request): + request_dict = await request.json() + return StreamingResponse(frontend.generate(request_dict)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="localhost") + parser.add_argument("--port", type=int, default=10002) + parser = add_server_arguments(parser) + args = parser.parse_args() + + # TODO(zhuohan): Support pipeline parallelism. + assert args.pipeline_parallel_size == 1, ( + 'Pipeline parallelism is not supported yet.') + + (num_nodes, num_devices_per_node, distributed_init_method, + all_stage_devices) = ( + initialize_ray_cluster( + pipeline_parallel_size=args.pipeline_parallel_size, + tensor_parallel_size=args.tensor_parallel_size)) + + frontend = FastAPIFrontend( + model=args.model, + model_path=args.model_path, + pipeline_parallel_size=args.pipeline_parallel_size, + tensor_parallel_size=args.tensor_parallel_size, + block_size=args.block_size, + dtype=args.dtype, + seed=args.seed, + swap_space=args.swap_space, + max_batch_size=args.max_batch_size, + ) + + uvicorn.run(app, host=args.host, port=args.port, log_level="info") diff --git a/cacheflow/master/frontend.py b/cacheflow/master/frontend.py index cb188387db2c..8134993e300a 100644 --- a/cacheflow/master/frontend.py +++ b/cacheflow/master/frontend.py @@ -3,8 +3,7 @@ from transformers import AutoTokenizer from cacheflow.sampling_params import SamplingParams -from cacheflow.sequence import Sequence -from cacheflow.sequence import SequenceGroup +from cacheflow.sequence import Sequence, SequenceGroup from cacheflow.utils import Counter diff --git a/cacheflow/master/scheduler.py b/cacheflow/master/scheduler.py index 1ade007b5f22..c71d0768bcb2 100644 --- a/cacheflow/master/scheduler.py +++ b/cacheflow/master/scheduler.py @@ -43,8 +43,6 @@ def __init__( self.swapped: List[SequenceGroup] = [] # Pending sequence groups (FIFO). self.pending: List[SequenceGroup] = [] - # Finished sequence groups. 
- self.finished: List[SequenceGroup] = [] def add_sequence_groups( self, @@ -106,7 +104,7 @@ def _swap_out( seq.status = SequenceStatus.SWAPPED self.swapped.append(seq_group) - def step(self) -> None: + def step(self) -> List[SequenceGroup]: # Blocks that need to be swaped or copied before model execution. blocks_to_swap_in: Dict[int, int] = {} blocks_to_swap_out: Dict[int, int] = {} @@ -177,6 +175,8 @@ def step(self) -> None: # 4. Create input data structures. input_seq_groups: List[SequenceGroupInputs] = [] + updated_seq_groups: List[SequenceGroup] = self.running.copy() + for seq_group in self.running: group_id = seq_group.group_id num_steps = self.num_steps[group_id] @@ -220,6 +220,8 @@ def step(self) -> None: blocks_to_copy, ) + return updated_seq_groups + def post_step( self, seq_outputs: Dict[int, SequenceOutputs], @@ -269,18 +271,12 @@ def post_step( running: List[SequenceGroup] = [] for seq_group in self.running: if seq_group.is_finished(): - self._return(seq_group) + self._free_seq_group(seq_group) else: running.append(seq_group) self.running = running - def _return(self, seq_group: SequenceGroup) -> None: + def _free_seq_group(self, seq_group: SequenceGroup) -> None: group_id = seq_group.group_id del self.num_steps[group_id] del self.sampling_params[group_id] - self.finished.append(seq_group) - - def get_finished(self) -> List[SequenceGroup]: - finished = self.finished - self.finished = [] - return finished diff --git a/cacheflow/master/server.py b/cacheflow/master/server.py index dc31f2bc55c4..a63e68d724fb 100644 --- a/cacheflow/master/server.py +++ b/cacheflow/master/server.py @@ -7,7 +7,8 @@ from cacheflow.master.scheduler import Scheduler from cacheflow.models import get_memory_analyzer from cacheflow.worker.controller import Controller, DeviceID - +from cacheflow.sequence import SequenceGroup +from cacheflow.sampling_params import SamplingParams class Server: def __init__( @@ -21,17 +22,13 @@ def __init__( seed: int, swap_space: int, max_batch_size: int, + num_nodes: int, + num_devices_per_node: int, + distributed_init_method: str, + all_stage_devices: List[List[DeviceID]], ): - # TODO(zhuohan): Support pipeline parallelism. - assert pipeline_parallel_size == 1, ( - 'Pipeline parallelism is not supported yet.') - - (self.num_nodes, self.num_devices_per_node, distributed_init_method, - all_stage_devices) = ( - initialize_ray_cluster( - pipeline_parallel_size=pipeline_parallel_size, - tensor_parallel_size=tensor_parallel_size)) - + self.num_nodes = num_nodes + self.num_devices_per_node = num_devices_per_node self.world_size = pipeline_parallel_size * tensor_parallel_size self.memory_analyzer = get_memory_analyzer( @@ -80,6 +77,19 @@ def __init__( self.controllers[i].set_next(self.controllers[i + 1]) self.controllers[-1].set_next(self.scheduler) + def add_sequence_groups( + self, + sequence_groups: List[Tuple[SequenceGroup, SamplingParams]] + ): + self.scheduler.add_sequence_groups(sequence_groups) + + def step(self): + return self.scheduler.step() + + def has_unfinished_requests(self): + return (len(self.scheduler.pending) > 0 or + len(self.scheduler.running) > 0) + def initialize_ray_cluster( address: str = 'auto', diff --git a/playground/http_client.py b/playground/http_client.py new file mode 100644 index 000000000000..ac13ac62c4b6 --- /dev/null +++ b/playground/http_client.py @@ -0,0 +1,20 @@ +import requests +import json + +def http_bot(): + prompt = "How are you? I'm fine." 
+ + headers = {"User-Agent": "Test Client"} + pload = { + "prompt": prompt, + } + response = requests.post("http://localhost:10002", headers=headers, json=pload, stream=True) + + for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"): + if chunk: + data = json.loads(chunk.decode("utf-8")) + output = data["text"] + yield output + +for h in http_bot(): + print(h, end="", flush=True) \ No newline at end of file diff --git a/playground/streaming_fastapi_worker.py b/playground/streaming_fastapi_worker.py new file mode 100644 index 000000000000..8ab087d109e6 --- /dev/null +++ b/playground/streaming_fastapi_worker.py @@ -0,0 +1,40 @@ +import argparse +import asyncio +import time +from typing import Union +import json + +from fastapi import FastAPI, Request +from fastapi.responses import StreamingResponse +import uvicorn + + +app = FastAPI() + + +async def text_streamer(args): + context = args["prompt"] + words = context.split(" ") + for word in words: + await asyncio.sleep(1) + print("word:", word) + ret = { + "text": word + " ", + "error": 0, + } + yield (json.dumps(ret) + "\0").encode("utf-8") + + +@app.post("/") +async def read_root(request: Request): + args = await request.json() + return StreamingResponse(text_streamer(args)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="localhost") + parser.add_argument("--port", type=int, default=10002) + args = parser.parse_args() + + uvicorn.run(app, host=args.host, port=args.port, log_level="info") diff --git a/server.py b/server.py index 9384efc59524..f556c2bc01ff 100644 --- a/server.py +++ b/server.py @@ -2,10 +2,22 @@ from typing import List from cacheflow.master.frontend import Frontend -from cacheflow.master.server import Server, add_server_arguments +from cacheflow.master.server import (Server, add_server_arguments, + initialize_ray_cluster) from cacheflow.sampling_params import SamplingParams + def main(args: argparse.Namespace): + # TODO(zhuohan): Support pipeline parallelism. + assert args.pipeline_parallel_size == 1, ( + 'Pipeline parallelism is not supported yet.') + + (num_nodes, num_devices_per_node, distributed_init_method, + all_stage_devices) = ( + initialize_ray_cluster( + pipeline_parallel_size=args.pipeline_parallel_size, + tensor_parallel_size=args.tensor_parallel_size)) + # Create a server. server = Server( model=args.model, @@ -17,6 +29,10 @@ def main(args: argparse.Namespace): seed=args.seed, swap_space=args.swap_space, max_batch_size=args.max_batch_size, + num_nodes=num_nodes, + num_devices_per_node=num_devices_per_node, + distributed_init_method=distributed_init_method, + all_stage_devices=all_stage_devices, ) # Create a frontend. 
@@ -37,12 +53,12 @@ def main(args: argparse.Namespace): sampling_params = SamplingParams.from_dict(sampling_params_dict) sampling_params = frontend.add_eos_token(sampling_params) frontend.query(text, sampling_params) - server.scheduler.add_sequence_groups(frontend.get_inputs()) - server.scheduler.step() - for seq_group in server.scheduler.get_finished(): - frontend.print_response(seq_group) - if not (server.scheduler.pending or server.scheduler.running or - test_inputs): + server.add_sequence_groups(frontend.get_inputs()) + updated_seq_groups = server.step() + for seq_group in updated_seq_groups: + if seq_group.is_finished(): + frontend.print_response(seq_group) + if not (server.has_unfinished_requests() or test_inputs): break From ea063255e33c54cc6d0b1ca9d3d2031e9ec7bf37 Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Mon, 27 Mar 2023 06:18:09 +0000 Subject: [PATCH 06/13] fix memory bugs and add test client --- cacheflow/http_frontend/fastapi_frontend.py | 12 +++++++++--- cacheflow/http_frontend/test_client.py | 20 ++++++++++++++++++++ cacheflow/master/server.py | 4 ++++ cacheflow/models/memory_analyzer.py | 18 +++++++++--------- cacheflow/models/model_utils.py | 5 ++++- cacheflow/models/utils.py | 11 ----------- cacheflow/utils.py | 9 +++++++++ server.py | 4 +++- 8 files changed, 58 insertions(+), 25 deletions(-) create mode 100644 cacheflow/http_frontend/test_client.py diff --git a/cacheflow/http_frontend/fastapi_frontend.py b/cacheflow/http_frontend/fastapi_frontend.py index ca174c9d0b2b..e0b447349385 100644 --- a/cacheflow/http_frontend/fastapi_frontend.py +++ b/cacheflow/http_frontend/fastapi_frontend.py @@ -15,7 +15,7 @@ from cacheflow.master.server import (Server, add_server_arguments, initialize_ray_cluster) from cacheflow.worker.controller import DeviceID -from cacheflow.utils import Counter +from cacheflow.utils import Counter, get_gpu_memory, get_cpu_memory app = FastAPI() @@ -42,7 +42,7 @@ def __init__( self.seq_group_counter = Counter() self.seq_counter = Counter() remote_server_class = ray.remote(num_cpus=0)(Server) - self.server = remote_server_class( + self.server = remote_server_class.remote( model=model, model_path=model_path, pipeline_parallel_size=pipeline_parallel_size, @@ -56,6 +56,8 @@ def __init__( num_devices_per_node=num_devices_per_node, distributed_init_method=distributed_init_method, all_stage_devices=all_stage_devices, + gpu_memory=get_gpu_memory(), + cpu_memory=get_cpu_memory(), ) self.running_seq_groups: Dict[int, SequenceGroup] = {} @@ -86,7 +88,7 @@ async def generate(self, request_dict: Dict): seq_group = SequenceGroup(group_id, seqs) group_event = asyncio.Event() self.sequence_group_events[group_id] = group_event - await self.server.add_sequence_groups.remote([seq_group, sampling_params]) + await self.server.add_sequence_groups.remote([(seq_group, sampling_params)]) while True: if not self.is_server_running: await self.server_step() @@ -139,6 +141,10 @@ async def generate_stream(request: Request): seed=args.seed, swap_space=args.swap_space, max_batch_size=args.max_batch_size, + num_nodes=num_nodes, + num_devices_per_node=num_devices_per_node, + distributed_init_method=distributed_init_method, + all_stage_devices=all_stage_devices, ) uvicorn.run(app, host=args.host, port=args.port, log_level="info") diff --git a/cacheflow/http_frontend/test_client.py b/cacheflow/http_frontend/test_client.py new file mode 100644 index 000000000000..40ae15b2e368 --- /dev/null +++ b/cacheflow/http_frontend/test_client.py @@ -0,0 +1,20 @@ +import requests +import json + +def 
http_bot(): + prompt = "The future of cloud computing is" + + headers = {"User-Agent": "Test Client"} + pload = { + "prompt": prompt, + } + response = requests.post("http://localhost:10002/generate", headers=headers, json=pload, stream=True) + + for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"): + if chunk: + data = json.loads(chunk.decode("utf-8")) + output = data["text"] + yield output + +for h in http_bot(): + print(h, flush=True) \ No newline at end of file diff --git a/cacheflow/master/server.py b/cacheflow/master/server.py index a63e68d724fb..0f09cda98a59 100644 --- a/cacheflow/master/server.py +++ b/cacheflow/master/server.py @@ -26,6 +26,8 @@ def __init__( num_devices_per_node: int, distributed_init_method: str, all_stage_devices: List[List[DeviceID]], + gpu_memory: int, + cpu_memory: int, ): self.num_nodes = num_nodes self.num_devices_per_node = num_devices_per_node @@ -35,6 +37,8 @@ def __init__( model_name=model, block_size=block_size, dtype=dtype, + gpu_memory=gpu_memory, + cpu_memory=cpu_memory, tensor_parallel_size=tensor_parallel_size, ) self.num_gpu_blocks = self.memory_analyzer.get_max_num_gpu_blocks( diff --git a/cacheflow/models/memory_analyzer.py b/cacheflow/models/memory_analyzer.py index 69675588c3c4..45d6a36b9022 100644 --- a/cacheflow/models/memory_analyzer.py +++ b/cacheflow/models/memory_analyzer.py @@ -1,9 +1,7 @@ import torch from transformers import AutoConfig -from cacheflow.models.utils import get_cpu_memory from cacheflow.models.utils import get_dtype_size -from cacheflow.models.utils import get_gpu_memory _GiB = 1 << 30 @@ -31,11 +29,15 @@ def __init__( model_name: str, block_size: int, dtype: torch.dtype, + gpu_memory: int, + cpu_memory: int, tensor_parallel_size: int, ) -> None: self.model_name = model_name self.block_size = block_size self.dtype = dtype + self.gpu_memory = gpu_memory + self.cpu_memory = cpu_memory self.tensor_parallel_size = tensor_parallel_size config = AutoConfig.from_pretrained(model_name) @@ -106,8 +108,7 @@ def get_max_num_gpu_blocks( memory_utilization: float = 0.95, ) -> int: # NOTE(woosuk): This assumes that the machine has homogeneous GPUs. - gpu_memory = get_gpu_memory() - usable_memory = int(memory_utilization * gpu_memory) + usable_memory = int(memory_utilization * self.gpu_memory) param_size = self._get_param_size() act_size = self._get_max_act_size(max_num_batched_tokens) @@ -122,16 +123,15 @@ def get_max_num_cpu_blocks( swap_space: int, ) -> int: swap_space = swap_space * _GiB - cpu_memory = get_cpu_memory() - if swap_space > 0.8 * cpu_memory: + if swap_space > 0.8 * self.cpu_memory: raise ValueError(f'The swap space ({swap_space / _GiB:.2f} GiB) ' 'takes more than 80% of the available memory ' - f'({cpu_memory / _GiB:.2f} GiB).' + f'({self.cpu_memory / _GiB:.2f} GiB).' 'Please check the swap space size.') - if swap_space > 0.5 * cpu_memory: + if swap_space > 0.5 * self.cpu_memory: print(f'WARNING: The swap space ({swap_space / _GiB:.2f} GiB) ' 'takes more than 50% of the available memory ' - f'({cpu_memory / _GiB:.2f} GiB).' + f'({self.cpu_memory / _GiB:.2f} GiB).' 
'This may slow the system performance.') max_num_blocks = swap_space // self._get_cache_block_size() return max_num_blocks diff --git a/cacheflow/models/model_utils.py b/cacheflow/models/model_utils.py index b1fdacea075a..3a2a6b2b5a35 100644 --- a/cacheflow/models/model_utils.py +++ b/cacheflow/models/model_utils.py @@ -44,11 +44,14 @@ def get_memory_analyzer( model_name: str, block_size: int, dtype: Union[torch.dtype, str], + gpu_memory: int, + cpu_memory: int, tensor_parallel_size: int = 1, ) -> CacheFlowMemoryAnalyzer: torch_dtype = get_torch_dtype(dtype) for model_class, memory_analyzer in _MEMORY_ANALYZERS.items(): if model_class in model_name: return memory_analyzer( - model_name, block_size, torch_dtype, tensor_parallel_size) + model_name, block_size, torch_dtype, gpu_memory, cpu_memory, + tensor_parallel_size) raise ValueError(f'Unsupported model name: {model_name}') diff --git a/cacheflow/models/utils.py b/cacheflow/models/utils.py index cdad84f03831..84e7fbce6ccd 100644 --- a/cacheflow/models/utils.py +++ b/cacheflow/models/utils.py @@ -1,9 +1,5 @@ from typing import Union -import random - -import numpy as np -import psutil import torch _STR_DTYPE_TO_TORCH_DTYPE = { @@ -26,10 +22,3 @@ def get_dtype_size(dtype: Union[torch.dtype, str]) -> int: torch_dtype = get_torch_dtype(dtype) return torch.tensor([], dtype=torch_dtype).element_size() - -def get_gpu_memory(gpu: int = 0) -> int: - return torch.cuda.get_device_properties(gpu).total_memory - - -def get_cpu_memory() -> int: - return psutil.virtual_memory().total diff --git a/cacheflow/utils.py b/cacheflow/utils.py index e987449fd9dd..725a4d19a54f 100644 --- a/cacheflow/utils.py +++ b/cacheflow/utils.py @@ -1,5 +1,6 @@ import enum import random +import psutil import numpy as np import torch @@ -36,3 +37,11 @@ def set_random_seed(seed: int): if model_parallel_is_initialized(): model_parallel_cuda_manual_seed(seed) + + +def get_gpu_memory(gpu: int = 0) -> int: + return torch.cuda.get_device_properties(gpu).total_memory + + +def get_cpu_memory() -> int: + return psutil.virtual_memory().total diff --git a/server.py b/server.py index f556c2bc01ff..2c6fe285fa65 100644 --- a/server.py +++ b/server.py @@ -5,7 +5,7 @@ from cacheflow.master.server import (Server, add_server_arguments, initialize_ray_cluster) from cacheflow.sampling_params import SamplingParams - +from cacheflow.utils import get_gpu_memory, get_cpu_memory def main(args: argparse.Namespace): # TODO(zhuohan): Support pipeline parallelism. @@ -33,6 +33,8 @@ def main(args: argparse.Namespace): num_devices_per_node=num_devices_per_node, distributed_init_method=distributed_init_method, all_stage_devices=all_stage_devices, + gpu_memory=get_gpu_memory(), + cpu_memory=get_cpu_memory(), ) # Create a frontend. 
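Taken together, patches 01-06 split request handling from scheduling: a frontend turns prompts into (SequenceGroup, SamplingParams) pairs, the Server owns the controllers and the Scheduler, and the caller drives generation through add_sequence_groups() and step(). The sketch below condenses that flow as it stands after patch 06; Server construction and the Ray/argument plumbing are elided, and the Frontend class shown here is renamed to SimpleFrontend in patch 07.

```python
# Condensed sketch of the frontend/server split after patches 01-06.
# `server` is assumed to be a fully constructed Server; its constructor
# arguments (parallelism, block size, memory sizes, Ray devices) are elided.
from cacheflow.master.frontend import Frontend
from cacheflow.sampling_params import SamplingParams

frontend = Frontend(model_name='facebook/opt-125m', block_size=8)

# 1. Build SamplingParams (from_dict fills unspecified fields with defaults)
#    and add the tokenizer's EOS token to the stop set.
params = frontend.add_eos_token(
    SamplingParams.from_dict({'n': 1, 'temperature': 0.8}))
# 2. Tokenize the prompt and enqueue it as a SequenceGroup.
frontend.query('Ion Stoica is a', params)
# 3. Hand the queued (SequenceGroup, SamplingParams) pairs to the server
#    and step the scheduler until everything finishes.
server.add_sequence_groups(frontend.get_inputs())
while server.has_unfinished_requests():
    for seq_group in server.step():   # step() returns the updated groups
        if seq_group.is_finished():
            frontend.print_response(seq_group)
```
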
From 03f09f0722829d6b0a6ad4114b36f235f75837b3 Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Mon, 27 Mar 2023 09:09:33 +0000 Subject: [PATCH 07/13] Rename and small fixes --- .../{test_client.py => test_cli_client.py} | 11 +++++++---- cacheflow/master/server.py | 2 ++ cacheflow/master/{frontend.py => simple_frontend.py} | 2 +- server.py | 4 ++-- 4 files changed, 12 insertions(+), 7 deletions(-) rename cacheflow/http_frontend/{test_client.py => test_cli_client.py} (71%) rename cacheflow/master/{frontend.py => simple_frontend.py} (98%) diff --git a/cacheflow/http_frontend/test_client.py b/cacheflow/http_frontend/test_cli_client.py similarity index 71% rename from cacheflow/http_frontend/test_client.py rename to cacheflow/http_frontend/test_cli_client.py index 40ae15b2e368..217f8088645a 100644 --- a/cacheflow/http_frontend/test_client.py +++ b/cacheflow/http_frontend/test_cli_client.py @@ -1,12 +1,15 @@ import requests import json -def http_bot(): - prompt = "The future of cloud computing is" +def http_request(): + prompt = "Ion Stoica is a" headers = {"User-Agent": "Test Client"} pload = { "prompt": prompt, + "n": 4, + "use_beam_search": True, + "temperature": 0.0, } response = requests.post("http://localhost:10002/generate", headers=headers, json=pload, stream=True) @@ -16,5 +19,5 @@ def http_bot(): output = data["text"] yield output -for h in http_bot(): - print(h, flush=True) \ No newline at end of file +for h in http_request(): + print(h, flush=True) diff --git a/cacheflow/master/server.py b/cacheflow/master/server.py index 0f09cda98a59..69d3708d1455 100644 --- a/cacheflow/master/server.py +++ b/cacheflow/master/server.py @@ -81,6 +81,8 @@ def __init__( self.controllers[i].set_next(self.controllers[i + 1]) self.controllers[-1].set_next(self.scheduler) + print("Server initialized.") + def add_sequence_groups( self, sequence_groups: List[Tuple[SequenceGroup, SamplingParams]] diff --git a/cacheflow/master/frontend.py b/cacheflow/master/simple_frontend.py similarity index 98% rename from cacheflow/master/frontend.py rename to cacheflow/master/simple_frontend.py index 8134993e300a..3e3fa5987b25 100644 --- a/cacheflow/master/frontend.py +++ b/cacheflow/master/simple_frontend.py @@ -7,7 +7,7 @@ from cacheflow.utils import Counter -class Frontend: +class SimpleFrontend: def __init__( self, diff --git a/server.py b/server.py index 2c6fe285fa65..a7dd927dd2f0 100644 --- a/server.py +++ b/server.py @@ -1,7 +1,7 @@ import argparse from typing import List -from cacheflow.master.frontend import Frontend +from cacheflow.master.simple_frontend import SimpleFrontend from cacheflow.master.server import (Server, add_server_arguments, initialize_ray_cluster) from cacheflow.sampling_params import SamplingParams @@ -38,7 +38,7 @@ def main(args: argparse.Namespace): ) # Create a frontend. - frontend = Frontend( + frontend = SimpleFrontend( model_name=args.model, block_size=args.block_size, ) From f045128f52337e19a94ca6ed96100e1807c16be2 Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Mon, 27 Mar 2023 09:17:55 +0000 Subject: [PATCH 08/13] Rename and update readme --- README.md | 27 +++++++++++++++++++++++++-- server.py => simple_server.py | 2 +- 2 files changed, 26 insertions(+), 3 deletions(-) rename server.py => simple_server.py (97%) diff --git a/README.md b/README.md index 7008d8a204cb..2745a02a1786 100644 --- a/README.md +++ b/README.md @@ -8,9 +8,32 @@ pip install flash-attn # This may take up to 10 mins. pip install -e . 
``` -## Run +## Test simple server ```bash ray start --head -python server.py [--tensor-parallel-size ] +python simple_server.py +``` + +The detailed arguments for `simple_server.py` can be found by +```bash +python simple_server.py --help +``` + +## FastAPI server + +Install the following additional dependencies: +```bash +pip install fastapi uvicorn +``` + +To start the server: +```bash +ray start --head +python -m cacheflow.http_frontend.fastapi_frontend +``` + +To test the server: +```bash +python -m cacheflow.http_frontend.test_cli_client ``` diff --git a/server.py b/simple_server.py similarity index 97% rename from server.py rename to simple_server.py index a7dd927dd2f0..4d6aa93b97b8 100644 --- a/server.py +++ b/simple_server.py @@ -65,7 +65,7 @@ def main(args: argparse.Namespace): if __name__ == '__main__': - parser = argparse.ArgumentParser(description='CacheFlow server') + parser = argparse.ArgumentParser(description='CacheFlow simple server.') parser = add_server_arguments(parser) args = parser.parse_args() main(args) From 6f735faf6bf618040992ec85cc9bb06ac70e209f Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Mon, 27 Mar 2023 09:19:48 +0000 Subject: [PATCH 09/13] fix --- cacheflow/master/server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cacheflow/master/server.py b/cacheflow/master/server.py index 69d3708d1455..01c779384eb1 100644 --- a/cacheflow/master/server.py +++ b/cacheflow/master/server.py @@ -179,4 +179,4 @@ def add_server_arguments(parser: argparse.ArgumentParser): parser.add_argument('--seed', type=int, default=0, help='random seed') parser.add_argument('--swap-space', type=int, default=20, help='CPU swap space size (GiB) per GPU') parser.add_argument('--max-batch-size', type=int, default=2560, help='maximum number of batched tokens') - return parser \ No newline at end of file + return parser From f1666e84aa5c9e62c89d12a2bc34a2c7859cf15e Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Mon, 27 Mar 2023 13:16:23 +0000 Subject: [PATCH 10/13] fix api and add gradio webserver --- cacheflow/http_frontend/fastapi_frontend.py | 4 +- cacheflow/http_frontend/gradio_webserver.py | 43 +++++++++++++++++++++ cacheflow/master/server.py | 2 - 3 files changed, 46 insertions(+), 3 deletions(-) create mode 100644 cacheflow/http_frontend/gradio_webserver.py diff --git a/cacheflow/http_frontend/fastapi_frontend.py b/cacheflow/http_frontend/fastapi_frontend.py index e0b447349385..dff7f7526ac6 100644 --- a/cacheflow/http_frontend/fastapi_frontend.py +++ b/cacheflow/http_frontend/fastapi_frontend.py @@ -92,7 +92,9 @@ async def generate(self, request_dict: Dict): while True: if not self.is_server_running: await self.server_step() - await group_event.wait() + # Wait for new output. Add a 1s timeout to prevent dead lock. 
+ await asyncio.wait_for(group_event.wait(), timeout=1) + group_event.clear() seq_group = self.running_seq_groups[group_id] all_outputs = [] for seq in seq_group.seqs: diff --git a/cacheflow/http_frontend/gradio_webserver.py b/cacheflow/http_frontend/gradio_webserver.py new file mode 100644 index 000000000000..9407b6e35555 --- /dev/null +++ b/cacheflow/http_frontend/gradio_webserver.py @@ -0,0 +1,43 @@ +import argparse +import json +import time + +import gradio as gr +import requests + + +def http_bot(prompt): + headers = {"User-Agent": "Cacheflow Client"} + pload = { + "prompt": prompt, + "max_num_steps": 128, + } + response = requests.post(args.model_url, headers=headers, json=pload, stream=True) + + for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"): + if chunk: + data = json.loads(chunk.decode("utf-8")) + output = data["text"][0] + yield output + + +def build_demo(): + with gr.Blocks() as demo: + gr.Markdown( + "# Cacheflow demo (OPT-13B)\n" + ) + inputbox = gr.Textbox(label="Input", placeholder="Enter text and press ENTER")# .style(container=False) + outputbox = gr.Textbox(label="Output", placeholder="Generated result from the model") + inputbox.submit(http_bot, [inputbox], [outputbox]) + return demo + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="localhost") + parser.add_argument("--port", type=int, default=10003) + parser.add_argument("--model-url", type=str, default="http://localhost:10002/generate") + args = parser.parse_args() + + demo = build_demo() + demo.queue(concurrency_count=100).launch(server_name=args.host, server_port=args.port) \ No newline at end of file diff --git a/cacheflow/master/server.py b/cacheflow/master/server.py index 01c779384eb1..9891dce84a4a 100644 --- a/cacheflow/master/server.py +++ b/cacheflow/master/server.py @@ -81,8 +81,6 @@ def __init__( self.controllers[i].set_next(self.controllers[i + 1]) self.controllers[-1].set_next(self.scheduler) - print("Server initialized.") - def add_sequence_groups( self, sequence_groups: List[Tuple[SequenceGroup, SamplingParams]] From 9b2972feec1950857344511b19579c14331fce8e Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Mon, 27 Mar 2023 13:22:17 +0000 Subject: [PATCH 11/13] Modify readme --- README.md | 16 +++++++++++++++- cacheflow/http_frontend/gradio_webserver.py | 2 +- 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 2745a02a1786..1898fbe57f67 100644 --- a/README.md +++ b/README.md @@ -15,7 +15,7 @@ ray start --head python simple_server.py ``` -The detailed arguments for `simple_server.py` can be found by +The detailed arguments for `simple_server.py` can be found by: ```bash python simple_server.py --help ``` @@ -37,3 +37,17 @@ To test the server: ```bash python -m cacheflow.http_frontend.test_cli_client ``` + +## GradIO web server + +Install the following additional dependencies: +```bash +pip install gradio +``` + +Start the server: +```bash +python -m cacheflow.http_frontend.fastapi_frontend +# At another terminal +python -m cacheflow.http_frontend.gradio_webserver +``` diff --git a/cacheflow/http_frontend/gradio_webserver.py b/cacheflow/http_frontend/gradio_webserver.py index 9407b6e35555..290496da3edf 100644 --- a/cacheflow/http_frontend/gradio_webserver.py +++ b/cacheflow/http_frontend/gradio_webserver.py @@ -24,7 +24,7 @@ def http_bot(prompt): def build_demo(): with gr.Blocks() as demo: gr.Markdown( - "# Cacheflow demo (OPT-13B)\n" + "# Cacheflow demo\n" ) 
inputbox = gr.Textbox(label="Input", placeholder="Enter text and press ENTER")# .style(container=False) outputbox = gr.Textbox(label="Output", placeholder="Generated result from the model") From 9d23237e1a334f5953fe2f465ddaa5479969d542 Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Mon, 27 Mar 2023 14:10:05 +0000 Subject: [PATCH 12/13] Update README --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 1898fbe57f67..41fd7107a116 100644 --- a/README.md +++ b/README.md @@ -38,7 +38,7 @@ To test the server: python -m cacheflow.http_frontend.test_cli_client ``` -## GradIO web server +## Gradio web server Install the following additional dependencies: ```bash From 739f599282a238188f2b703b3c0b06923aa102c1 Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Wed, 29 Mar 2023 06:47:52 +0000 Subject: [PATCH 13/13] Address review comments. --- cacheflow/master/server.py | 4 ++-- cacheflow/sampling_params.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/cacheflow/master/server.py b/cacheflow/master/server.py index 9891dce84a4a..92b9858c375f 100644 --- a/cacheflow/master/server.py +++ b/cacheflow/master/server.py @@ -91,8 +91,8 @@ def step(self): return self.scheduler.step() def has_unfinished_requests(self): - return (len(self.scheduler.pending) > 0 or - len(self.scheduler.running) > 0) + return (self.scheduler.pending or self.scheduler.running or + self.scheduler.swapped) def initialize_ray_cluster( diff --git a/cacheflow/sampling_params.py b/cacheflow/sampling_params.py index e7f118bba951..4daeaa486e56 100644 --- a/cacheflow/sampling_params.py +++ b/cacheflow/sampling_params.py @@ -1,4 +1,4 @@ -from typing import Optional, Set +from typing import Optional, Set, Dict class SamplingParams: @@ -71,7 +71,7 @@ def __repr__(self) -> str: f'context_window_size={self.context_window_size})') @classmethod - def from_dict(cls, d: dict) -> 'SamplingParams': + def from_dict(cls, d: Dict) -> 'SamplingParams': return cls( n=d.get('n', 1), temperature=d.get('temperature', 1.0),
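
A note on the HTTP interface that the FastAPI frontend and both example clients settle on: the client POSTs JSON containing "prompt" plus any SamplingParams fields to /generate, and the server streams NUL-delimited JSON objects whose "text" field carries the full decoded output so far for each of the n sequences (not a delta), so a client that wants incremental output diffs consecutive chunks itself. A small sketch, assuming the default host/port from fastapi_frontend.py (localhost:10002) and a single returned sequence:

```python
import json

import requests


def stream_completion(prompt: str, **sampling_params):
    """Yield only the newly generated text of sequence 0 as it streams in."""
    pload = {"prompt": prompt, **sampling_params}
    response = requests.post("http://localhost:10002/generate",
                             json=pload, stream=True)
    seen = ""
    for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False,
                                     delimiter=b"\0"):
        if not chunk:
            continue
        data = json.loads(chunk.decode("utf-8"))
        text = data["text"][0]        # each chunk holds the full text so far
        yield text[len(seen):]        # emit only the newly generated suffix
        seen = text


for delta in stream_completion("Ion Stoica is a", max_num_steps=32):
    print(delta, end="", flush=True)
```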