diff --git a/.gitignore b/.gitignore index a050864d9ddd..da5a337c4683 100644 --- a/.gitignore +++ b/.gitignore @@ -170,3 +170,6 @@ cython_debug/ # Python pickle files *.pkl + +# Sphinx documentation +_build/ diff --git a/README.md b/README.md index b4d4a0a2579a..372984e03871 100644 --- a/README.md +++ b/README.md @@ -17,8 +17,8 @@ Easy, fast, and cheap LLM serving for everyone --- *Latest News* 🔥 - -- [2023/06] We officially released vLLM! vLLM has powered [LMSYS Vicuna and Chatbot Arena](https://chat.lmsys.org) since mid April. Check out our [blog post](https://vllm.ai). +- [2023/06] Serving vLLM On any Cloud with SkyPilot. Check out a 1-click [example](https://github.com/skypilot-org/skypilot/blob/master/llm/vllm) to start the vLLM demo, and the [blog post](https://blog.skypilot.co/serving-llm-24x-faster-on-the-cloud-with-vllm-and-skypilot/) for the story behind vLLM development on the clouds. +- [2023/06] We officially released vLLM! FastChat-vLLM integration has powered [LMSYS Vicuna and Chatbot Arena](https://chat.lmsys.org) since mid-April. Check out our [blog post](https://vllm.ai). --- @@ -28,7 +28,7 @@ vLLM is fast with: - State-of-the-art serving throughput - Efficient management of attention key and value memory with **PagedAttention** -- Dynamic batching of incoming requests +- Continuous batching of incoming requests - Optimized CUDA kernels vLLM is flexible and easy to use with: @@ -42,7 +42,8 @@ vLLM is flexible and easy to use with: vLLM seamlessly supports many Huggingface models, including the following architectures: - GPT-2 (`gpt2`, `gpt2-xl`, etc.) -- GPTNeoX (`EleutherAI/gpt-neox-20b`, `databricks/dolly-v2-12b`, `stabilityai/stablelm-tuned-alpha-7b`, etc.) +- GPT BigCode (`bigcode/starcoder`, `bigcode/gpt_bigcode-santacoder`, etc.) +- GPT-NeoX (`EleutherAI/gpt-neox-20b`, `databricks/dolly-v2-12b`, `stabilityai/stablelm-tuned-alpha-7b`, etc.) - LLaMA (`lmsys/vicuna-13b-v1.3`, `young-geng/koala`, `openlm-research/open_llama_13b`, etc.) - OPT (`facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc.) diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py index 1a215bdd552c..48cdffb8afcc 100644 --- a/benchmarks/benchmark_latency.py +++ b/benchmarks/benchmark_latency.py @@ -17,6 +17,7 @@ def main(args: argparse.Namespace): # the engine will automatically process the request in multiple batches. 
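+    # Note on the new --tokenizer flag: it defaults to None, in which case the
+    # engine falls back to the model's own tokenizer (see EngineArgs.__post_init__).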
llm = LLM( model=args.model, + tokenizer=args.tokenizer, tensor_parallel_size=args.tensor_parallel_size, max_num_seqs=args.batch_size, max_num_batched_tokens=args.batch_size * args.input_len, @@ -63,6 +64,7 @@ def run_to_completion(profile: bool = False): description='Benchmark the latency of processing a single batch of ' 'requests till completion.') parser.add_argument('--model', type=str, default='facebook/opt-125m') + parser.add_argument('--tokenizer', type=str, default=None) parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1) parser.add_argument('--input-len', type=int, default=32) parser.add_argument('--output-len', type=int, default=128) diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index b8d824c7db74..b0705ec0fe80 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -24,20 +24,13 @@ import aiohttp import numpy as np -from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizerBase +from transformers import PreTrainedTokenizerBase +from vllm.transformers_utils.tokenizer import get_tokenizer # (prompt len, output len, latency) REQUEST_LATENCY: List[Tuple[int, int, float]] = [] -def get_tokenizer(model_name: str) -> PreTrainedTokenizerBase: - config = AutoConfig.from_pretrained(model_name) - if config.model_type == "llama": - # A workaround for potential protobuf errors. - model_name = "hf-internal-testing/llama-tokenizer" - return AutoTokenizer.from_pretrained(model_name) - - def sample_requests( dataset_path: str, num_requests: int, @@ -217,7 +210,7 @@ def main(args: argparse.Namespace): parser.add_argument("--backend", type=str, default="vllm", choices=["vllm", "tgi"]) parser.add_argument("--host", type=str, default="localhost") - parser.add_argument("--port", type=int, default=8001) + parser.add_argument("--port", type=int, default=8000) parser.add_argument("--dataset", type=str, required=True, help="Path to the dataset.") parser.add_argument("--tokenizer", type=str, required=True, diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index c40145abcc00..9ae5aa9b42da 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -6,23 +6,11 @@ from typing import List, Tuple import torch -from transformers import (AutoConfig, AutoTokenizer, AutoModelForCausalLM, - PreTrainedTokenizerBase) +from transformers import AutoModelForCausalLM, PreTrainedTokenizerBase from tqdm import tqdm from vllm import LLM, SamplingParams - - -def get_tokenizer(model_name: str) -> PreTrainedTokenizerBase: - config = AutoConfig.from_pretrained(model_name) - if config.model_type == "llama": - # A workaround for potential protobuf errors. - model_name = "hf-internal-testing/llama-tokenizer" - tokenizer = AutoTokenizer.from_pretrained(model_name) - # To enable padding in the HF backend. 
- tokenizer.pad_token = tokenizer.eos_token - return tokenizer - return AutoTokenizer.from_pretrained(model_name) +from vllm.transformers_utils.tokenizer import get_tokenizer def sample_requests( @@ -74,6 +62,7 @@ def sample_requests( def run_vllm( requests: List[Tuple[str, int, int]], model: str, + tokenizer: str, tensor_parallel_size: int, seed: int, n: int, @@ -81,6 +70,7 @@ def run_vllm( ) -> float: llm = LLM( model=model, + tokenizer=tokenizer, tensor_parallel_size=tensor_parallel_size, seed=seed, ) @@ -118,9 +108,10 @@ def run_hf( max_batch_size: int, ) -> float: assert not use_beam_search - tokenizer = get_tokenizer(model) - llm = AutoModelForCausalLM.from_pretrained( - model, torch_dtype=torch.float16) + llm = AutoModelForCausalLM.from_pretrained(model, torch_dtype=torch.float16) + if llm.config.model_type == "llama": + # To enable padding in the HF backend. + tokenizer.pad_token = tokenizer.eos_token llm = llm.cuda() pbar = tqdm(total=len(requests)) @@ -170,13 +161,13 @@ def main(args: argparse.Namespace): random.seed(args.seed) # Sample the requests. - tokenizer = get_tokenizer(args.model) + tokenizer = get_tokenizer(args.tokenizer) requests = sample_requests(args.dataset, args.num_prompts, tokenizer) if args.backend == "vllm": elapsed_time = run_vllm( - requests, args.model, args.tensor_parallel_size, args.seed, args.n, - args.use_beam_search) + requests, args.model, args.tokenizer, args.tensor_parallel_size, + args.seed, args.n, args.use_beam_search) elif args.backend == "hf": assert args.tensor_parallel_size == 1 elapsed_time = run_hf(requests, args.model, tokenizer, args.n, @@ -198,6 +189,7 @@ def main(args: argparse.Namespace): parser.add_argument("--dataset", type=str, required=True, help="Path to the dataset.") parser.add_argument("--model", type=str, default="facebook/opt-125m") + parser.add_argument("--tokenizer", type=str, default=None) parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1) parser.add_argument("--n", type=int, default=1, help="Number of generated sequences per prompt.") @@ -208,11 +200,14 @@ def main(args: argparse.Namespace): parser.add_argument("--hf-max-batch-size", type=int, default=None, help="Maximum batch size for HF backend.") args = parser.parse_args() + if args.backend == "vllm": if args.hf_max_batch_size is not None: raise ValueError("HF max batch size is only for HF backend.") elif args.backend == "hf": if args.hf_max_batch_size is None: raise ValueError("HF max batch size is required for HF backend.") + if args.tokenizer is None: + args.tokenizer = args.model main(args) diff --git a/benchmarks/launch_tgi_server.sh b/benchmarks/launch_tgi_server.sh index 828ca0219353..bdb25b78d85b 100755 --- a/benchmarks/launch_tgi_server.sh +++ b/benchmarks/launch_tgi_server.sh @@ -1,6 +1,6 @@ #!/bin/bash -PORT=8001 +PORT=8000 MODEL=$1 TOKENS=$2 diff --git a/docs/README.md b/docs/README.md index 9b164017d041..46488c9bb0b9 100644 --- a/docs/README.md +++ b/docs/README.md @@ -4,14 +4,14 @@ ```bash # Install dependencies. -pip -r requirements-docs.txt +pip install -r requirements-docs.txt # Build the docs. 
make clean make html ``` -## Open the docs with your brower +## Open the docs with your browser ```bash python -m http.server -d build/html/ diff --git a/docs/source/index.rst b/docs/source/index.rst index ab2e17a93b04..6fbfd01d72e3 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -29,7 +29,7 @@ vLLM is fast with: * State-of-the-art serving throughput * Efficient management of attention key and value memory with **PagedAttention** -* Dynamic batching of incoming requests +* Continuous batching of incoming requests * Optimized CUDA kernels vLLM is flexible and easy to use with: @@ -40,7 +40,11 @@ vLLM is flexible and easy to use with: * Streaming outputs * OpenAI-compatible API server -For more information, please refer to our `blog post `_. +For more information, check out the following: + +* `vLLM announcing blog post `_ (intro to PagedAttention) +* `How continuous batching enables 23x throughput in LLM inference while reducing p50 latency `_ by Cade Daniel et al. + Documentation @@ -53,6 +57,12 @@ Documentation getting_started/installation getting_started/quickstart +.. toctree:: + :maxdepth: 1 + :caption: Serving + + serving/distributed_serving + .. toctree:: :maxdepth: 1 :caption: Models diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index a9dad0fedf3e..b8f7f9e000b1 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -17,6 +17,9 @@ Alongside each architecture, we include some popular models that use it. * - :code:`GPT2LMHeadModel` - GPT-2 - :code:`gpt2`, :code:`gpt2-xl`, etc. + * - :code:`GPTBigCodeForCausalLM` + - StarCoder, SantaCoder, WizardCoder + - :code:`bigcode/starcoder`, :code:`bigcode/gpt_bigcode-santacoder`, :code:`WizardLM/WizardCoder-15B-V1.0`, etc. * - :code:`GPTNeoXForCausalLM` - GPT-NeoX, Pythia, OpenAssistant, Dolly V2, StableLM - :code:`EleutherAI/gpt-neox-20b`, :code:`EleutherAI/pythia-12b`, :code:`OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5`, :code:`databricks/dolly-v2-12b`, :code:`stabilityai/stablelm-tuned-alpha-7b`, etc. diff --git a/docs/source/serving/distributed_serving.rst b/docs/source/serving/distributed_serving.rst new file mode 100644 index 000000000000..4f36dca15d7d --- /dev/null +++ b/docs/source/serving/distributed_serving.rst @@ -0,0 +1,38 @@ +.. _distributed_serving: + +Distributed Inference and Serving +================================= + +vLLM supports distributed tensor-parallel inference and serving. Currently, we support `Megatron-LM's tensor parallel algorithm `_. We manage the distributed runtime with `Ray `_. To run distributed inference, install Ray with: + +.. code-block:: console + + $ pip install ray + +To run multi-GPU inference with the :code:`LLM` class, set the :code:`tensor_parallel_size` argument to the number of GPUs you want to use. For example, to run inference on 4 GPUs: + +.. code-block:: python + + from vllm import LLM + llm = LLM("facebook/opt-13b", tensor_parallel_size=4) + output = llm.generate("San Franciso is a") + +To run multi-GPU serving, pass in the :code:`--tensor-parallel-size` argument when starting the server. For example, to run API server on 4 GPUs: + +.. code-block:: console + + $ python -m vllm.entrypoints.api_server \ + $ --model facebook/opt-13b \ + $ --tensor-parallel-size 4 + +To scale vLLM beyond a single machine, start a `Ray runtime `_ via CLI before running vLLM: + +.. 
code-block:: console + + $ # On head node + $ ray start --head + + $ # On worker nodes + $ ray start --address= + +After that, you can run inference and serving on multiple machines by launching the vLLM process on the head node by setting :code:`tensor_parallel_size` to the number of GPUs to be the total number of GPUs across all machines. \ No newline at end of file diff --git a/examples/llm_engine_example.py b/examples/llm_engine_example.py index 932a1151553f..f95080ac610a 100644 --- a/examples/llm_engine_example.py +++ b/examples/llm_engine_example.py @@ -30,7 +30,7 @@ def main(args: argparse.Namespace): request_outputs = engine.step() for request_output in request_outputs: - if request_output.finished(): + if request_output.finished: print(request_output) if not (engine.has_unfinished_requests() or test_prompts): diff --git a/setup.py b/setup.py index 0aa666e7b679..2644c03b1637 100644 --- a/setup.py +++ b/setup.py @@ -20,10 +20,9 @@ CXX_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"] NVCC_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"] -if not torch.cuda.is_available(): +if CUDA_HOME is None: raise RuntimeError( - f"Cannot find CUDA at CUDA_HOME: {CUDA_HOME}. " - "CUDA must be available in order to build the package.") + f"Cannot find CUDA_HOME. CUDA must be available in order to build the package.") def get_nvcc_cuda_version(cuda_dir: str) -> Version: diff --git a/vllm/__init__.py b/vllm/__init__.py index 15ea2dc29f18..0c5ea9698576 100644 --- a/vllm/__init__.py +++ b/vllm/__init__.py @@ -6,7 +6,7 @@ from vllm.outputs import CompletionOutput, RequestOutput from vllm.sampling_params import SamplingParams -__version__ = "0.1.0" +__version__ = "0.1.1" __all__ = [ "LLM", diff --git a/vllm/config.py b/vllm/config.py index fd9550d524d7..c2eba07d4e24 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -16,6 +16,9 @@ class ModelConfig: Args: model: Name or path of the huggingface model to use. + tokenizer: Name or path of the huggingface tokenizer to use. + tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if + available, and "slow" will always use the slow tokenizer. download_dir: Directory to download and load the weights, default to the default cache directory of huggingface. use_np_weights: Save a numpy copy of model weights for faster loading. @@ -30,6 +33,8 @@ class ModelConfig: def __init__( self, model: str, + tokenizer: str, + tokenizer_mode: str, download_dir: Optional[str], use_np_weights: bool, use_dummy_weights: bool, @@ -37,6 +42,8 @@ def __init__( seed: int, ) -> None: self.model = model + self.tokenizer = tokenizer + self.tokenizer_mode = tokenizer_mode self.download_dir = download_dir self.use_np_weights = use_np_weights self.use_dummy_weights = use_dummy_weights @@ -44,6 +51,15 @@ def __init__( self.hf_config: PretrainedConfig = AutoConfig.from_pretrained(model) self.dtype = _get_and_verify_dtype(self.hf_config, dtype) + self._verify_tokenizer_mode() + + def _verify_tokenizer_mode(self) -> None: + tokenizer_mode = self.tokenizer_mode.lower() + if tokenizer_mode not in ["auto", "slow"]: + raise ValueError( + f"Unknown tokenizer mode: {self.tokenizer_mode}. Must be " + "either 'auto' or 'slow'.") + self.tokenizer_mode = tokenizer_mode def verify_with_parallel_config( self, @@ -170,14 +186,18 @@ class SchedulerConfig: a single iteration. max_num_seqs: Maximum number of sequences to be processed in a single iteration. + max_seq_len: Maximum length of a sequence (including prompt + and generated text). 
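+            Prompts longer than this limit are ignored by the scheduler
+            (marked FINISHED_IGNORED) instead of being scheduled.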
""" def __init__( self, max_num_batched_tokens: int, max_num_seqs: int, + max_seq_len: int ) -> None: self.max_num_batched_tokens = max_num_batched_tokens self.max_num_seqs = max_num_seqs + self.max_seq_len = max_seq_len _STR_DTYPE_TO_TORCH_DTYPE = { diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 500c5ddd943d..7160c4f43fd7 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -102,11 +102,12 @@ def has_unfinished_seqs(self) -> bool: def get_num_unfinished_seq_groups(self) -> int: return len(self.waiting) + len(self.running) + len(self.swapped) - def _schedule(self) -> Tuple[SchedulerOutputs, List[str]]: + def _schedule(self) -> Tuple[SchedulerOutputs, List[str], List[SequenceGroup]]: # Blocks that need to be swaped or copied before model execution. blocks_to_swap_in: Dict[int, int] = {} blocks_to_swap_out: Dict[int, int] = {} blocks_to_copy: Dict[int, List[int]] = {} + ignored_seq_groups: List[SequenceGroup] = [] # Fix the current time. now = time.time() @@ -156,7 +157,9 @@ def _schedule(self) -> Tuple[SchedulerOutputs, List[str]]: # The total number of sequences in the RUNNING state should not # exceed the maximum number of sequences. num_new_seqs = seq_group.num_seqs(status=SequenceStatus.SWAPPED) - num_curr_seqs = len(self.running) + num_curr_seqs = sum( + seq_group.num_seqs(status=SequenceStatus.RUNNING) + for seq_group in self.running) if num_curr_seqs + num_new_seqs > self.scheduler_config.max_num_seqs: break @@ -185,12 +188,24 @@ def _schedule(self) -> Tuple[SchedulerOutputs, List[str]]: # If the sequence group has been preempted in this step, stop. if seq_group in preempted: break + + num_prompt_tokens = seq_group.get_seqs()[0].get_len() + if num_prompt_tokens >= self.scheduler_config.max_seq_len: + logger.warn( + f"Input prompt ({num_prompt_tokens} tokens) is too long" + " and exceeds limit of " + f"{self.scheduler_config.max_seq_len}") + for seq in seq_group.get_seqs(): + seq.status = SequenceStatus.FINISHED_IGNORED + ignored_seq_groups.append(seq_group) + self.waiting.pop(0) + break + # If the sequence group cannot be allocated, stop. if not self.block_manager.can_allocate(seq_group): break # If the number of batched tokens exceeds the limit, stop. - num_prompt_tokens = seq_group.get_seqs()[0].get_len() if (num_batched_tokens + num_prompt_tokens > self.scheduler_config.max_num_batched_tokens): break @@ -198,7 +213,9 @@ def _schedule(self) -> Tuple[SchedulerOutputs, List[str]]: # The total number of sequences in the RUNNING state should not # exceed the maximum number of sequences. num_new_seqs = seq_group.num_seqs(status=SequenceStatus.WAITING) - num_curr_seqs = len(self.running) + num_curr_seqs = sum( + seq_group.num_seqs(status=SequenceStatus.RUNNING) + for seq_group in self.running) if num_curr_seqs + num_new_seqs > self.scheduler_config.max_num_seqs: break @@ -214,7 +231,7 @@ def _schedule(self) -> Tuple[SchedulerOutputs, List[str]]: blocks_to_copy=blocks_to_copy, ) if not self.log_stats: - return scheduler_outputs, prompt_group_ids + return scheduler_outputs, prompt_group_ids, ignored_seq_groups # TODO(woosuk): Move the below code to the engine. 
now = time.time() @@ -254,13 +271,13 @@ def _schedule(self) -> Tuple[SchedulerOutputs, List[str]]: f"Pending: {len(self.waiting)} reqs, " f"GPU KV cache usage: {gpu_cache_usage * 100:.1f}%, " f"CPU KV cache usage: {cpu_cache_usage * 100:.1f}%") - return scheduler_outputs, prompt_group_ids + return scheduler_outputs, prompt_group_ids, ignored_seq_groups - def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]: + def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs, List[SequenceGroup]]: # Schedule sequence groups. # This function call changes the internal states of the scheduler # such as self.running, self.swapped, and self.waiting. - scheduler_outputs, prompt_group_ids = self._schedule() + scheduler_outputs, prompt_group_ids, ignored_seq_groups = self._schedule() # Create input data structures. seq_group_metadata_list: List[SequenceGroupMetadata] = [] @@ -282,7 +299,7 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]: block_tables=block_tables, ) seq_group_metadata_list.append(seq_group_metadata) - return seq_group_metadata_list, scheduler_outputs + return seq_group_metadata_list, scheduler_outputs, ignored_seq_groups def update( self, diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 10e6070b42c7..e32ef005d0d2 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -11,6 +11,8 @@ class EngineArgs: """Arguments for vLLM engine.""" model: str + tokenizer: Optional[str] = None + tokenizer_mode: str = "auto" download_dir: Optional[str] = None use_np_weights: bool = False use_dummy_weights: bool = False @@ -27,6 +29,8 @@ class EngineArgs: disable_log_stats: bool = False def __post_init__(self): + if self.tokenizer is None: + self.tokenizer = self.model self.max_num_seqs = min(self.max_num_seqs, self.max_num_batched_tokens) @staticmethod @@ -37,6 +41,14 @@ def add_cli_args( # Model arguments parser.add_argument('--model', type=str, default='facebook/opt-125m', help='name or path of the huggingface model to use') + parser.add_argument('--tokenizer', type=str, default=EngineArgs.tokenizer, + help='name or path of the huggingface tokenizer to use') + parser.add_argument('--tokenizer-mode', type=str, + default=EngineArgs.tokenizer_mode, + choices=['auto', 'slow'], + help='tokenizer mode. "auto" will use the fast ' + 'tokenizer if available, and "slow" will ' + 'always use the slow tokenizer.') parser.add_argument('--download-dir', type=str, default=EngineArgs.download_dir, help='directory to download and load the weights, ' @@ -104,15 +116,19 @@ def create_engine_configs( ) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig]: # Initialize the configs. 
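+        # self.tokenizer is never None here: __post_init__ falls back to the
+        # model name when no tokenizer is given explicitly.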
model_config = ModelConfig( - self.model, self.download_dir, self.use_np_weights, - self.use_dummy_weights, self.dtype, self.seed) + self.model, self.tokenizer, self.tokenizer_mode, self.download_dir, + self.use_np_weights, self.use_dummy_weights, self.dtype, self.seed) cache_config = CacheConfig(self.block_size, self.gpu_memory_utilization, self.swap_space) parallel_config = ParallelConfig(self.pipeline_parallel_size, self.tensor_parallel_size, self.worker_use_ray) + max_seq_len = min( + self.max_num_batched_tokens, + getattr(model_config.hf_config, "max_position_embeddings", + float("inf"))) scheduler_config = SchedulerConfig(self.max_num_batched_tokens, - self.max_num_seqs) + self.max_num_seqs, max_seq_len) return model_config, cache_config, parallel_config, scheduler_config diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index e74f6b418fb2..1b0016e5831c 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -154,7 +154,7 @@ async def generate( yield request_output # Once finished, release the resources of the sequence group. - if request_output.finished(): + if request_output.finished: if self.log_requests: logger.info(f"Finished request {request_id}.") diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 3668dd7ee37f..c9b4215d07c2 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -6,11 +6,12 @@ from vllm.core.scheduler import Scheduler from vllm.engine.arg_utils import EngineArgs from vllm.engine.ray_utils import DeviceID, initialize_cluster, ray -from vllm.engine.tokenizer_utils import detokenize_incrementally, get_tokenizer from vllm.logger import init_logger from vllm.outputs import RequestOutput from vllm.sampling_params import SamplingParams from vllm.sequence import Sequence, SequenceGroup, SequenceStatus +from vllm.transformers_utils.tokenizer import (detokenize_incrementally, + get_tokenizer) from vllm.utils import Counter from vllm.worker.worker import Worker @@ -59,6 +60,8 @@ def __init__( logger.info( "Initializing an LLM engine with config: " f"model={model_config.model!r}, " + f"tokenizer={model_config.tokenizer!r}, " + f"tokenizer_mode={model_config.tokenizer_mode}, " f"dtype={model_config.dtype}, " f"use_dummy_weights={model_config.use_dummy_weights}, " f"download_dir={model_config.download_dir!r}, " @@ -75,7 +78,8 @@ def __init__( self.log_stats = log_stats self._verify_args() - self.tokenizer = get_tokenizer(model_config.model) + self.tokenizer = get_tokenizer(model_config.tokenizer, + model_config.tokenizer_mode) self.seq_counter = Counter() # Create the parallel GPU workers. @@ -127,6 +131,12 @@ def _init_cache(self) -> None: # FIXME(woosuk): Change to debug log. logger.info(f'# GPU blocks: {num_gpu_blocks}, ' f'# CPU blocks: {num_cpu_blocks}') + + if num_gpu_blocks <= 0: + raise ValueError("No available memory for the cache blocks. " + "Try increasing `gpu_memory_utilization` when " + "initializing the engine.") + self.cache_config.num_gpu_blocks = num_gpu_blocks self.cache_config.num_cpu_blocks = num_cpu_blocks @@ -216,8 +226,8 @@ def step(self) -> List[RequestOutput]: and updates the scheduler with the model outputs. Finally, it decodes the sequences and returns the newly generated results. 
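+        Prompts that the scheduler ignored because they exceed max_seq_len are
+        also returned here, as already-finished requests.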
""" - seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule() - if (not seq_group_metadata_list) and scheduler_outputs.is_empty(): + seq_group_metadata_list, scheduler_outputs, ignored_seq_groups = self.scheduler.schedule() + if (not seq_group_metadata_list) and scheduler_outputs.is_empty() and (not ignored_seq_groups): # Nothing to do. return [] @@ -241,7 +251,7 @@ def step(self) -> List[RequestOutput]: # Create the outputs. request_outputs: List[RequestOutput] = [] - for seq_group in seq_groups: + for seq_group in seq_groups + ignored_seq_groups: request_output = RequestOutput.from_seq_group(seq_group) request_outputs.append(request_output) return request_outputs @@ -278,6 +288,12 @@ def _stop_sequences(self, seq_groups: List[SequenceGroup]) -> None: if stopped: continue + # Check if the sequence has reached max_seq_len. + if (seq.get_len() >= + self.scheduler.scheduler_config.max_seq_len): + self.scheduler.free_seq( + seq, SequenceStatus.FINISHED_LENGTH_CAPPED) + continue # Check if the sequence has reached max_tokens. if seq.get_output_len() == sampling_params.max_tokens: self.scheduler.free_seq( diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 19510d9050a7..a78e00df0abd 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -25,6 +25,9 @@ class LLM: Args: model: The name or path of a HuggingFace Transformers model. + tokenizer: The name or path of a HuggingFace Transformers tokenizer. + tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer + if available, and "slow" will always use the slow tokenizer. tensor_parallel_size: The number of GPUs to use for distributed execution with tensor parallelism. dtype: The data type for the model weights and activations. Currently, @@ -38,6 +41,8 @@ class LLM: def __init__( self, model: str, + tokenizer: Optional[str] = None, + tokenizer_mode: str = "auto", tensor_parallel_size: int = 1, dtype: str = "auto", seed: int = 0, @@ -47,6 +52,8 @@ def __init__( kwargs["disable_log_stats"] = True engine_args = EngineArgs( model=model, + tokenizer=tokenizer, + tokenizer_mode=tokenizer_mode, tensor_parallel_size=tensor_parallel_size, dtype=dtype, seed=seed, @@ -60,6 +67,12 @@ def get_tokenizer( ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: return self.llm_engine.tokenizer + def set_tokenizer( + self, + tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], + ) -> None: + self.llm_engine.tokenizer = tokenizer + def generate( self, prompts: Optional[Union[str, List[str]]] = None, @@ -133,7 +146,7 @@ def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]: while self.llm_engine.has_unfinished_requests(): step_outputs = self.llm_engine.step() for output in step_outputs: - if output.finished(): + if output.finished: outputs.append(output) if use_tqdm: pbar.update(1) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index bb417eaf3c2e..a354ab0d5a05 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -15,7 +15,6 @@ from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_llm_engine import AsyncLLMEngine -from vllm.engine.tokenizer_utils import get_tokenizer from vllm.entrypoints.openai.protocol import ( CompletionRequest, CompletionResponse, CompletionResponseChoice, CompletionResponseStreamChoice, CompletionStreamResponse, ErrorResponse, @@ -23,6 +22,7 @@ from vllm.logger import init_logger from vllm.outputs import RequestOutput from vllm.sampling_params import 
SamplingParams +from vllm.transformers_utils.tokenizer import get_tokenizer from vllm.utils import random_uuid TIMEOUT_KEEP_ALIVE = 5 # seconds @@ -313,7 +313,7 @@ async def fake_stream_generator() -> AsyncGenerator[str, None]: engine = AsyncLLMEngine.from_engine_args(engine_args) # A separate tokenizer to map token IDs to strings. - tokenizer = get_tokenizer(args.model) + tokenizer = get_tokenizer(engine_args.tokenizer, engine_args.tokenizer_mode) uvicorn.run(app, host=args.host, port=args.port, log_level="info", timeout_keep_alive=TIMEOUT_KEEP_ALIVE) diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py index a0b3f6ff653a..ce41cfe0ad11 100644 --- a/vllm/model_executor/layers/activation.py +++ b/vllm/model_executor/layers/activation.py @@ -8,6 +8,7 @@ "gelu": nn.GELU(), "gelu_new": nn.GELU(approximate="tanh"), # NOTE: This may introduce small rounding errors. "gelu_fast": nn.GELU(approximate="tanh"), # NOTE: This may introduce small rounding errors. + "gelu_pytorch_tanh": nn.GELU(approximate="tanh"), # NOTE: This may introduce small rounding errors. "relu": nn.ReLU(), } diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index fe4b4094a5db..2ac73aea9d35 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -11,6 +11,7 @@ from vllm.sampling_params import SamplingParams from vllm.sequence import SequenceOutputs +_SAMPLING_EPS = 1e-5 class Sampler(nn.Module): """Samples the next tokens from the model's outputs. @@ -74,7 +75,7 @@ def forward( # Apply top-p and top-k truncation. top_ps, top_ks = _get_top_p_top_k(input_metadata, self.vocab_size) assert len(top_ps) == len(top_ks) == probs.shape[0] - if any(p < 1.0 for p in top_ps) or any(k != self.vocab_size for k in top_ks): + if any(p < 1.0 - _SAMPLING_EPS for p in top_ps) or any(k != self.vocab_size for k in top_ks): probs = _apply_top_p_top_k(probs, top_ps, top_ks) # Sample the next tokens. @@ -152,7 +153,7 @@ def _apply_penalties( continue p = presence_penalties[i] f = frequency_penalties[i] - if p == 0.0 and f == 0.0: + if p < _SAMPLING_EPS and f < _SAMPLING_EPS: continue indices.append(i) @@ -190,7 +191,7 @@ def _get_temperatures( for i, seq_group in enumerate(input_metadata.seq_groups): seq_ids, sampling_params = seq_group temperature = sampling_params.temperature - if temperature == 0.0: + if temperature < _SAMPLING_EPS: # NOTE: Zero temperature means deterministic sampling # (i.e., greedy sampling or beam search). # Set the temperature to 1 to avoid division by zero. @@ -286,7 +287,7 @@ def _sample_from_prompt( beam_width = sampling_params.best_of _, next_token_ids = torch.topk(prob, beam_width) next_token_ids = next_token_ids.tolist() - elif sampling_params.temperature == 0.0: + elif sampling_params.temperature < _SAMPLING_EPS: # Greedy sampling. assert sampling_params.best_of == 1 next_token_id = torch.argmax(prob) @@ -343,7 +344,7 @@ def _sample_from_generation_tokens( parent_seq_ids = [beam_outputs[seq_id][0] for seq_id in seq_ids] next_token_ids = [beam_outputs[seq_id][1] for seq_id in seq_ids] - elif sampling_params.temperature == 0.0: + elif sampling_params.temperature < _SAMPLING_EPS: # Greedy sampling. 
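+        # Comparing against _SAMPLING_EPS (rather than == 0.0) treats near-zero
+        # temperatures as zero, avoiding an exact floating-point equality check.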
assert len(seq_ids) == 1 next_token_id = torch.argmax(probs, dim=-1) diff --git a/vllm/model_executor/model_loader.py b/vllm/model_executor/model_loader.py index ec5b75d08d8b..78a8abc5b0ba 100644 --- a/vllm/model_executor/model_loader.py +++ b/vllm/model_executor/model_loader.py @@ -6,15 +6,18 @@ from transformers import PretrainedConfig from vllm.config import ModelConfig -from vllm.model_executor.models import (GPT2LMHeadModel, GPTNeoXForCausalLM, - LlamaForCausalLM, OPTForCausalLM) +from vllm.model_executor.models import (CodeGenForCausalLM, GPT2LMHeadModel, GPTBigCodeForCausalLM, + GPTNeoXForCausalLM, LlamaForCausalLM, OPTForCausalLM) from vllm.model_executor.weight_utils import initialize_dummy_weights # TODO(woosuk): Lazy-load the model classes. _MODEL_REGISTRY = { + "CodeGenForCausalLM": CodeGenForCausalLM, "GPT2LMHeadModel": GPT2LMHeadModel, + "GPTBigCodeForCausalLM": GPTBigCodeForCausalLM, "GPTNeoXForCausalLM": GPTNeoXForCausalLM, "LlamaForCausalLM": LlamaForCausalLM, + "LLaMAForCausalLM": LlamaForCausalLM, "OPTForCausalLM": OPTForCausalLM, } diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index 098b7d448b60..ae1de86a52b0 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -1,11 +1,16 @@ -from vllm.model_executor.models.gpt_neox import GPTNeoXForCausalLM +from vllm.model_executor.models.codegen import CodeGenForCausalLM from vllm.model_executor.models.gpt2 import GPT2LMHeadModel +from vllm.model_executor.models.gpt_bigcode import GPTBigCodeForCausalLM +from vllm.model_executor.models.gpt_neox import GPTNeoXForCausalLM from vllm.model_executor.models.llama import LlamaForCausalLM from vllm.model_executor.models.opt import OPTForCausalLM + __all__ = [ + "CodeGenForCausalLM", "GPT2LMHeadModel", + "GPTBigCodeForCausalLM", "GPTNeoXForCausalLM", "LlamaForCausalLM", "OPTForCausalLM", diff --git a/vllm/model_executor/models/codegen.py b/vllm/model_executor/models/codegen.py new file mode 100644 index 000000000000..b5dd305962f8 --- /dev/null +++ b/vllm/model_executor/models/codegen.py @@ -0,0 +1,315 @@ +# coding=utf-8 +# Adapted from https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt_neox/modeling_gpt_neox.py +# Copyright 2023 The vLLM team, Michael Feil +# Copyright 2022 EleutherAI The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only CodeGen model compatible with HuggingFace weights. +The input of the model is flattened to a 1D tensor of tokens. The model uses +InputMetadata to extract the original 2D shape of the input. 
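+
+Attention follows the GPT-J layout: rotary position embeddings over separate
+q/k/v projections. The HF checkpoint stores a fused qkv_proj, which
+load_weights permutes and splits back into q_proj, k_proj, and v_proj.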
+""" +from typing import Dict, List, Optional, Tuple + +import torch +from torch import nn +from transformers import CodeGenConfig + +from vllm.model_executor.input_metadata import InputMetadata +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.attention import PagedAttentionWithRoPE +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.weight_utils import (hf_model_weights_iterator, + load_tensor_parallel_weights) +from vllm.model_executor.parallel_utils.parallel_state import ( + get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) +from vllm.model_executor.parallel_utils.tensor_parallel import ( + VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear) +from vllm.sequence import SequenceOutputs +import numpy as np + + +KVCache = Tuple[torch.Tensor, torch.Tensor] + +class CodegenGPTJStyleAttention(nn.Module): + + def __init__(self, config: CodeGenConfig): + super().__init__() + self.total_num_heads = config.num_attention_heads + self.hidden_size = config.hidden_size + self.head_size = self.hidden_size // self.total_num_heads + + tensor_model_parallel_world_size = get_tensor_model_parallel_world_size() + assert self.head_size * tensor_model_parallel_world_size != self.hidden_size + + self.q_proj = ColumnParallelLinear(config.hidden_size, + config.hidden_size, + bias=False, + gather_output=False, + perform_initialization=False) + + self.k_proj = ColumnParallelLinear(config.hidden_size, + config.hidden_size, + bias=False, + gather_output=False, + perform_initialization=False) + + self.v_proj = ColumnParallelLinear(config.hidden_size, + config.hidden_size, + bias=False, + gather_output=False, + perform_initialization=False) + + self.out_projection = RowParallelLinear(config.hidden_size, + config.hidden_size, + bias=False, + input_is_parallel=False, + perform_initialization=False) + + scaling = self.head_size ** -0.5 + rotary_dim = config.rotary_dim + self.attn = PagedAttentionWithRoPE(self.num_heads, self.head_size, + scaling, rotary_dim) + + def forward( + self, + position_ids: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + cache_event: Optional[torch.cuda.Event], + ) -> torch.Tensor: + q, _ = self.q_proj(hidden_states) + k, _ = self.k_proj(hidden_states) + v, _ = self.v_proj(hidden_states) + + k_cache, v_cache = kv_cache + attn_output = self.attn( + position_ids, q, k, v, k_cache, v_cache, input_metadata, cache_event) + attn_output, _ = self.out_projection(attn_output) + return attn_output + + +class CodeGenMLP(nn.Module): + + def __init__(self, intermediate_size: int, config: CodeGenConfig): + super().__init__() + hidden_size = config.n_embd + self.fc_in = ColumnParallelLinear(hidden_size, + intermediate_size, + gather_output=False, + perform_initialization=False) + self.fc_out = RowParallelLinear(intermediate_size, + hidden_size, + input_is_parallel=True, + perform_initialization=False) + + self.act = get_act_fn(config.activation_function) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.fc_in(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states, _ = self.fc_out(hidden_states) + return hidden_states + + +class CodeGenBlock(nn.Module): + + def __init__(self, config: CodeGenConfig): + super().__init__() + inner_dim = config.n_inner if config.n_inner is not None else 4 * config.n_embd + + self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) + self.attn = 
CodegenGPTJStyleAttention(config) + self.mlp = CodeGenMLP(inner_dim, config) + + def forward( + self, + position_ids: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + cache_event: Optional[torch.cuda.Event], + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.ln_1(hidden_states) + attn_output = self.attn( + position_ids=position_ids, + hidden_states=hidden_states, + kv_cache=kv_cache, + input_metadata=input_metadata, + cache_event=cache_event, + ) + feed_forward_hidden_states = self.mlp(hidden_states) + hidden_states = attn_output + residual + feed_forward_hidden_states + return hidden_states + + +class CodeGenModel(nn.Module): + + def __init__(self, config: CodeGenConfig): + super().__init__() + self.config = config + + self.embed_dim = config.n_embd + + vocab_size = ((config.vocab_size + 63) // 64) * 64 + self.wte = VocabParallelEmbedding(vocab_size, + self.embed_dim, + perform_initialization=False) + self.h = nn.ModuleList( + [GPTJBlock(config) for _ in range(config.n_layer)]) + self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) + + def forward( + self, + position_ids: torch.Tensor, + input_ids: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + cache_events: Optional[List[torch.cuda.Event]], + ) -> torch.Tensor: + inputs_embeds = self.wte(input_ids) + hidden_states = inputs_embeds + for i in range(len(self.h)): + if cache_events is None: + cache_event = None + else: + cache_event = cache_events[i] + layer = self.h[i] + hidden_states = layer( + position_ids, + hidden_states, + kv_caches[i], + input_metadata, + cache_event + ) + + hidden_states = self.ln_f(hidden_states) + return hidden_states + +class CodeGenForCausalLM(nn.Module): + + def __init__(self, config): + super().__init__() + self.config = config + self.transformer = CodeGenModel(config) + self.lm_head = ColumnParallelLinear(config.n_embd, + config.vocab_size, + gather_output=False, + perform_initialization=False) + + self.sampler = Sampler(config.vocab_size) + + + + def forward( + self, + input_ids: torch.Tensor, + token_type_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + cache_events: Optional[List[torch.cuda.Event]], + ) -> Dict[int, SequenceOutputs]: + hidden_states = self.transformer( + input_ids, token_type_ids, positions, kv_caches, input_metadata, cache_events) + next_tokens = self.sampler( + self.lm_head.weight, hidden_states, input_metadata) + return next_tokens + + _column_parallel_weights = ["self.wte.weight", "lm_head.weight", "fc_in.weight", "fc_in.bias"] + _row_parallel_weights = ["out_projection.weight", "fc_out.weight"] + + def load_weights(self, model_name_or_path: str, + cache_dir: Optional[str] = None, + use_np_cache: bool = False): + tensor_model_parallel_rank = get_tensor_model_parallel_rank() + state_dict = self.state_dict() + + mp_num = 4 + if hasattr(self.config, "head_dim") and self.config.head_dim in [128, 256]: + # models forked from "Salesforce/codegen2-1B" and "Salesforce/codegen2-3_7B" + # use a special setting of mp_num=8, all other using 4 + # these model.config's use a special setting of head_dim + mp_num = 8 + base_permutation = np.arange(0, mp_num * 3).reshape(-1, 3).T.flatten().tolist() + embed_dim = self.config.n_embd + local_dim = embed_dim // mp_num + permutation = np.concatenate( + [np.arange(i * local_dim, (i + 1) * local_dim) for i in base_permutation] + ) + permutation = torch.from_numpy(permutation) + + 
for name, loaded_weight in hf_model_weights_iterator( + model_name_or_path, cache_dir, use_np_cache): + if ("attn.bias" in name or "attn.masked_bias" in name): + continue + param = state_dict[name] + if "qkv_proj" in name: + shard_size = param.shape[0] + loaded_weight = loaded_weight[shard_size * tensor_model_parallel_rank + :shard_size * (tensor_model_parallel_rank + 1)] + + num_heads = self.config.num_attention_heads + hidden_size = self.config.hidden_size + head_size = hidden_size // num_heads + + # CodeGen to GPT-J -> permute and split + loaded_weight = loaded_weight[permutation, :] + qw, vw, kw = torch.split(loaded_weight, embed_dim, dim=0) + # continue GPT-J-style + + for name_p, weight in [("q_proj",qw),("k_proj",kw),("v_proj",vw)]: + # TODO: check if this works. + new_name = name.replace('qkv_proj', name_p) + if 'qkv_proj.weight' in name: + weight = weight.view(-1, 3, head_size, hidden_size) + weight = weight.transpose(0, 1) + weight = weight.reshape(-1, hidden_size) + elif 'qkv_proj.bias' in name: + weight = weight.view(-1, 3, head_size) + weight = weight.transpose(0, 1) + weight = weight.reshape(-1) + else: + raise ValueError(f"Unexpected weight name: {name}") + self._replace(state_dict, weight, new_name) + param = state_dict[name] + load_tensor_parallel_weights(param, weight, name.replace('qkv_proj', name_p), + self._column_parallel_weights, + self._row_parallel_weights, + tensor_model_parallel_rank) + state_dict.pop(name) + else: + load_tensor_parallel_weights(param, loaded_weight, name, + self._column_parallel_weights, + self._row_parallel_weights, + tensor_model_parallel_rank) + @staticmethod + def _replace(state_dict, weights, model_name): + state_dict[model_name].copy_(weights.detach()) + +if __name__ == "__main__": + from vllm import LLM, SamplingParams + prompts = [ + "def hello", + "def bye(name:" + ] + sampling_params = SamplingParams(temperature=0.8, top_p=0.95) + llm = LLM(model="Salesforce/codegen-350M-mono") + outputs = llm.generate(prompts, sampling_params) + + # Print the outputs. 
+ for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") \ No newline at end of file diff --git a/vllm/model_executor/models/codegen_template.py b/vllm/model_executor/models/codegen_template.py new file mode 100644 index 000000000000..622b7cde2fa3 --- /dev/null +++ b/vllm/model_executor/models/codegen_template.py @@ -0,0 +1,82 @@ +try: + @register_loader("CodeGenConfig") + class CodeGenLoader(ModelLoader): + @property + def architecture_name(self): + return "CodeGenForCausalLM" + + def get_model_spec(self, model): + spec = transformer_spec.TransformerDecoderModelSpec.from_config( + model.config.n_layer, + model.config.n_head, + pre_norm=True, + activation=_SUPPORTED_ACTIVATIONS[model.config.activation_function], + rotary_dim=model.config.rotary_dim, + rotary_interleave=False, + parallel_residual=True, + shared_layer_norm=True, + ) + + mp_num = 4 + if hasattr(model.config, "head_dim") and model.config.head_dim in [128, 256]: + # models forked from "Salesforce/codegen2-1B" and "Salesforce/codegen2-3_7B" + # use a special setting of mp_num=8, all other using 4 + # these model.config's use a special setting of head_dim + mp_num = 8 + + self.set_decoder( + spec.decoder, + model.transformer, + model.config.rotary_dim, + model.config.n_head, + model.config.n_embd, + mp_num=mp_num, + ) + self.set_linear(spec.decoder.projection, model.lm_head) + return spec + + def set_vocabulary(self, spec, tokens): + spec.register_vocabulary(tokens) + + def set_config(self, config, model, tokenizer): + config.bos_token = tokenizer.bos_token + config.eos_token = tokenizer.eos_token + config.unk_token = tokenizer.unk_token + + def set_decoder(self, spec, module, rotary_dim, num_heads, embed_dim, mp_num): + spec.scale_embeddings = False + self.set_embeddings(spec.embeddings, module.wte) + self.set_layer_norm(spec.layer_norm, module.ln_f) + + base_permutation = np.arange(0, mp_num * 3).reshape(-1, 3).T.flatten().tolist() + local_dim = embed_dim // mp_num + permutation = np.concatenate( + [np.arange(i * local_dim, (i + 1) * local_dim) for i in base_permutation] + ) + + for layer_spec, layer in zip(spec.layer, module.h): + self.set_layer_norm(layer_spec.shared_layer_norm, layer.ln_1) + # [start convert CodeGen to GPT-J format] + # numpy conversion, adapted from torch code in + # see https://github.com/fauxpilot/fauxpilot/blob/fb4073a9078dd001ebeb7dfefb8cb2ecc8a88f4b/converter/codegen_gptj_convert.py # noqa + qkv_proj = layer.attn.qkv_proj.weight.numpy() + + # GPT-J and CodeGen slice up the qkv projection slightly differently. + # the following permutation brings Codegen 'qkv_proj' + # in GPT-J order of qw, vw, kw + # we permute the *rows* here because the computation is xA.T + new_qkv_proj = qkv_proj[permutation, :] + # the name QKV is misleading here; they are actually stored in QVK + qw, vw, kw = np.array_split(new_qkv_proj, 3, axis=0) + # [end convert CodeGen to GPT-J.] 
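+            # e.g. with mp_num = 4, base_permutation is
+            # [0, 3, 6, 9, 1, 4, 7, 10, 2, 5, 8, 11], so the twelve row blocks
+            # [q0 v0 k0 | q1 v1 k1 | q2 v2 k2 | q3 v3 k3] were regrouped into
+            # [q0..q3 | v0..v3 | k0..k3] before the split above into qw, vw, kw.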
+ + qw = utils.permute_for_sliced_rotary(qw, num_heads, rotary_dim) + kw = utils.permute_for_sliced_rotary(kw, num_heads, rotary_dim) + + layer_spec.self_attention.linear[0].weight = np.concatenate((qw, kw, vw)) + self.set_linear(layer_spec.self_attention.linear[1], layer.attn.out_proj) + + self.set_linear(layer_spec.ffn.linear_0, layer.mlp.fc_in) + self.set_linear(layer_spec.ffn.linear_1, layer.mlp.fc_out) +except: + pass \ No newline at end of file diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py index 9537daa235ae..6c2890015c8c 100644 --- a/vllm/model_executor/models/gpt2.py +++ b/vllm/model_executor/models/gpt2.py @@ -228,11 +228,13 @@ def load_weights(self, model_name_or_path: str, # GPT-2 ties the weights of the embedding layer and the final # linear layer. continue - if ".attn.bias" in name: + if ".attn.bias" in name or ".attn.masked_bias" in name: # Skip attention mask. # NOTE: "c_attn.bias" should not be skipped. continue - name = "transformer." + name + + if not name.startswith("transformer."): + name = "transformer." + name # The HF's GPT-2 implementation uses Conv1D instead of Linear. # Because of this, we need to transpose the weights. diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py new file mode 100644 index 000000000000..d900b0f9bcfb --- /dev/null +++ b/vllm/model_executor/models/gpt_bigcode.py @@ -0,0 +1,303 @@ +# coding=utf-8 +# Adapted from https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py +# Copyright 2023 The vLLM team. +# Copyright 2023 CTranslate2, and Michael Feil +# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only GPTBigCode model compatible with HuggingFace weights. + +The input of the model is flattened to a 1D tensor of tokens. The model uses +InputMetadata to extract the original 2D shape of the input. 
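+
+The HF checkpoints use multi-query attention (a single shared key/value head).
+Since vLLM does not yet support MQA natively, load_weights replicates the K/V
+rows across all heads so the standard multi-head PagedAttention can be used.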
+""" +from typing import Dict, List, Optional, Tuple + +import torch +from torch import nn +import numpy as np +from transformers import GPTBigCodeConfig + +from vllm.model_executor.input_metadata import InputMetadata +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.attention import PagedAttention +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.weight_utils import (hf_model_weights_iterator, + load_tensor_parallel_weights) +from vllm.model_executor.parallel_utils.parallel_state import ( + get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) +from vllm.model_executor.parallel_utils.tensor_parallel import ( + VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear) +from vllm.sequence import SequenceOutputs + +KVCache = Tuple[torch.Tensor, torch.Tensor] + + +class GPTBigCodeAttention(nn.Module): + + def __init__(self, config: GPTBigCodeConfig): + super().__init__() + self.hidden_size = config.hidden_size + total_num_heads = config.num_attention_heads + tensor_model_parallel_world_size = get_tensor_model_parallel_world_size() + assert total_num_heads % tensor_model_parallel_world_size == 0 + self.num_heads = total_num_heads // tensor_model_parallel_world_size + self.head_dim = self.hidden_size // total_num_heads + self.scale = self.head_dim ** -0.5 + + self.c_attn = ColumnParallelLinear(self.hidden_size, 3 * self.hidden_size, + bias=True, gather_output=False, + perform_initialization=False) + self.c_proj = RowParallelLinear(self.hidden_size, self.hidden_size, + bias=True, input_is_parallel=True, + perform_initialization=False) + self.attn = PagedAttention(self.num_heads, self.head_dim, + scale=self.scale) + + def forward( + self, + hidden_states: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + cache_event: Optional[torch.cuda.Event], + ) -> torch.Tensor: + qkv, _ = self.c_attn(hidden_states) + q, k, v = qkv.chunk(chunks=3, dim=-1) + key_cache, value_cache = kv_cache + attn_output = self.attn( + q, k, v, key_cache, value_cache, input_metadata, cache_event) + attn_output, _ = self.c_proj(attn_output) + return attn_output + + +class GPTBigMLP(nn.Module): + + def __init__( + self, + intermediate_size: int, + config: GPTBigCodeConfig, + ): + super().__init__() + hidden_size = config.hidden_size + self.c_fc = ColumnParallelLinear(hidden_size, intermediate_size, + bias=True, gather_output=False, + perform_initialization=False) + self.c_proj = RowParallelLinear(intermediate_size, hidden_size, + bias=True, input_is_parallel=True, + perform_initialization=False) + self.act = get_act_fn(config.activation_function) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.c_fc(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states, _ = self.c_proj(hidden_states) + return hidden_states + + +class GPTBigCodeBlock(nn.Module): + + def __init__(self, config: GPTBigCodeConfig): + super().__init__() + hidden_size = config.hidden_size + inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size + + self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + self.attn = GPTBigCodeAttention(config) + self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + self.mlp = GPTBigMLP(inner_dim, config) + + def forward( + self, + hidden_states: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + cache_event: Optional[torch.cuda.Event], + ) -> torch.Tensor: + residual = hidden_states 
+ hidden_states = self.ln_1(hidden_states) + attn_output = self.attn( + hidden_states=hidden_states, + kv_cache=kv_cache, + input_metadata=input_metadata, + cache_event=cache_event, + ) + # residual connection + hidden_states = attn_output + residual + + residual = hidden_states + hidden_states = self.ln_2(hidden_states) + feed_forward_hidden_states = self.mlp(hidden_states) + # residual connection + hidden_states = residual + feed_forward_hidden_states + return hidden_states + + +class GPTBigCodeModel(nn.Module): + + def __init__(self, config: GPTBigCodeConfig): + super().__init__() + self.config = config + assert config.add_cross_attention == False + + self.embed_dim = config.hidden_size + + # Optimization: While the vocab size of GPT-2 is 50257, we extend it + # to 50304 in order to make it divisible by 64. + # This improves performance since GPUs are faster if the dimension + # is divisible by 64. In addition, it allows us to shard the embedding + # layer across 2, 4, 8, or more GPUs. + vocab_size = ((config.vocab_size + 63) // 64) * 64 + self.wte = VocabParallelEmbedding(vocab_size, self.embed_dim) + self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) + self.h = nn.ModuleList( + [GPTBigCodeBlock(config) for _ in range(config.num_hidden_layers)]) + self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + cache_events: Optional[List[torch.cuda.Event]], + ) -> torch.Tensor: + inputs_embeds = self.wte(input_ids) + position_embeds = self.wpe(position_ids) + hidden_states = inputs_embeds + position_embeds + + for i in range(len(self.h)): + if cache_events is None: + cache_event = None + else: + cache_event = cache_events[i] + layer = self.h[i] + hidden_states = layer( + hidden_states, kv_caches[i], input_metadata, cache_event) + + hidden_states = self.ln_f(hidden_states) + return hidden_states + + +class GPTBigCodeForCausalLM(nn.Module): + + def __init__(self, config: GPTBigCodeConfig): + super().__init__() + self.config = config + self.transformer = GPTBigCodeModel(config) + # TODO(zhuohan): create a new weight after implementing pipeline + # parallelism + self.lm_head_weight = self.transformer.wte.weight + self.sampler = Sampler(config.vocab_size) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + cache_events: Optional[List[torch.cuda.Event]], + ) -> Dict[int, SequenceOutputs]: + hidden_states = self.transformer( + input_ids, positions, kv_caches, input_metadata, cache_events) + next_tokens = self.sampler( + self.lm_head_weight, hidden_states, input_metadata) + return next_tokens + + _column_parallel_weights = ["wte.weight", "c_fc.weight", "c_fc.bias"] + _row_parallel_weights = ["c_proj.weight"] + + def load_weights(self, model_name_or_path: str, + cache_dir: Optional[str] = None, + use_np_cache: bool = False): + tensor_model_parallel_world_size = get_tensor_model_parallel_world_size() + tensor_model_parallel_rank = get_tensor_model_parallel_rank() + state_dict = self.state_dict() + + for name, loaded_weight in hf_model_weights_iterator( + model_name_or_path, cache_dir, use_np_cache): + if "lm_head.weight" in name: + # GPT-2 ties the weights of the embedding layer and the final + # linear layer. + continue + if ".attn.bias" in name: + # Skip attention mask. + # NOTE: "c_attn.bias" should not be skipped. 
+ continue + + param = state_dict[name] + + if not name.startswith("transformer."): + name = "transformer." + name + + if name == "transformer.wte.weight": + # Consider padding in the vocab size. + padded_vocab_size = param.shape[0] * tensor_model_parallel_world_size + num_extra_rows = padded_vocab_size - self.config.vocab_size + extra_rows = torch.empty(num_extra_rows, loaded_weight.shape[1]) + extra_rows = extra_rows.to(loaded_weight) + loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0) + + def _expand_mqa_mha(qkv_array, n_head, head_dim): + """manipulates along axis=0 from MQA to MHA + inputs: qkv_array.shape=((n_heads + 2) * head_dim, hidden_dim) + with n_heads for q, then 1 for k, 1 for 1 v, times head dim + return: qkv_array.shape=(3 * n_heads * head_dim, hidden_dim) + + TODO: this function is no longer needed once vllm supports MQA. + """ + qkv_array = qkv_array.numpy() + + dims_q = n_head * head_dim + q, k, v = np.split(qkv_array, (dims_q, dims_q + head_dim), axis=0) + # q is fine, but k & v have not replicated shape along the first axis + # as long as MQA is not nativly supported, increase memory and replicated + # (head_dim, hidden_dim) to (n_heads * head_dim, hidden_dim) + if k.ndim == 2 and v.ndim == 2: + replication = (n_head, 1) # weights + else: + replication = n_head # biases + # replicate n_head times for q, v + k, v = np.tile(k, replication), np.tile(v, replication) + # concat q, k, v along the first axis (n_heads * head_dim, hidden_dim) + # to (3 * n_heads * head_dim, hidden_dim) + qkv_array = np.concatenate((q, k, v), axis=0) + return torch.from_numpy(qkv_array) + + # For the fused QKV linear layer, manually shard the weights. + if "c_attn" in name: + # GPT-2's fused QKV has the shape of [3 * num_heads * head_size, hidden_size]. + # When tensor parallelism is used, we shard the weights along the head dimension. + total_num_heads = self.config.num_attention_heads + hidden_size = self.config.hidden_size + head_size = hidden_size // total_num_heads + num_heads = total_num_heads // tensor_model_parallel_world_size + head_start = tensor_model_parallel_rank * num_heads + head_end = (tensor_model_parallel_rank + 1) * num_heads + + if name.endswith(".weight"): + loaded_weight = _expand_mqa_mha(loaded_weight, n_head=total_num_heads, head_dim=head_size) + loaded_weight = loaded_weight.view(3, total_num_heads, head_size, hidden_size) + loaded_weight = loaded_weight[:, head_start:head_end, :, :] + loaded_weight = loaded_weight.reshape(-1, hidden_size) + elif name.endswith(".bias"): + loaded_weight = _expand_mqa_mha(loaded_weight, n_head=total_num_heads, head_dim=head_size) + loaded_weight = loaded_weight.view(3, total_num_heads, head_size) + loaded_weight = loaded_weight[:, head_start:head_end, :] + loaded_weight = loaded_weight.reshape(-1) + else: + raise ValueError(f"Unexpected parameter name {name}") + + load_tensor_parallel_weights(param, loaded_weight, name, + self._column_parallel_weights, + self._row_parallel_weights, + tensor_model_parallel_rank) diff --git a/vllm/model_executor/models/gpt_j.py b/vllm/model_executor/models/gpt_j.py new file mode 100644 index 000000000000..ea097cea3d49 --- /dev/null +++ b/vllm/model_executor/models/gpt_j.py @@ -0,0 +1,260 @@ +# coding=utf-8 +# Adapted from https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt_neox/modeling_gpt_neox.py +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI The HuggingFace Inc. team. All rights reserved. 
diff --git a/vllm/model_executor/models/gpt_j.py b/vllm/model_executor/models/gpt_j.py
new file mode 100644
index 000000000000..ea097cea3d49
--- /dev/null
+++ b/vllm/model_executor/models/gpt_j.py
@@ -0,0 +1,260 @@
+# coding=utf-8
+# Adapted from https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gptj/modeling_gptj.py
+# Copyright 2023 The vLLM team.
+# Copyright 2022 EleutherAI and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Inference-only GPT-J model compatible with HuggingFace weights.
+
+The input of the model is flattened to a 1D tensor of tokens. The model uses
+InputMetadata to extract the original 2D shape of the input.
+"""
+from typing import Dict, List, Optional, Tuple
+
+import torch
+from torch import nn
+from transformers import GPTJConfig
+
+from vllm.model_executor.input_metadata import InputMetadata
+from vllm.model_executor.layers.activation import get_act_fn
+from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
+from vllm.model_executor.layers.sampler import Sampler
+from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
+                                              load_tensor_parallel_weights)
+from vllm.model_executor.parallel_utils.parallel_state import (
+    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
+from vllm.model_executor.parallel_utils.tensor_parallel import (
+    VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
+from vllm.sequence import SequenceOutputs
+
+KVCache = Tuple[torch.Tensor, torch.Tensor]
+
+
+class GPTJAttention(nn.Module):
+
+    def __init__(self, config: GPTJConfig):
+        super().__init__()
+        self.total_num_heads = config.num_attention_heads
+        self.hidden_size = config.hidden_size
+        self.head_size = self.hidden_size // self.total_num_heads
+
+        tensor_model_parallel_world_size = (
+            get_tensor_model_parallel_world_size())
+        assert self.total_num_heads % tensor_model_parallel_world_size == 0
+        self.num_heads = (self.total_num_heads //
+                          tensor_model_parallel_world_size)
+
+        self.q_proj = ColumnParallelLinear(config.hidden_size,
+                                           config.hidden_size,
+                                           bias=False,
+                                           gather_output=False,
+                                           perform_initialization=False)
+        self.k_proj = ColumnParallelLinear(config.hidden_size,
+                                           config.hidden_size,
+                                           bias=False,
+                                           gather_output=False,
+                                           perform_initialization=False)
+        self.v_proj = ColumnParallelLinear(config.hidden_size,
+                                           config.hidden_size,
+                                           bias=False,
+                                           gather_output=False,
+                                           perform_initialization=False)
+        # Named `out_proj` to match the HuggingFace GPT-J checkpoint keys.
+        self.out_proj = RowParallelLinear(config.hidden_size,
+                                          config.hidden_size,
+                                          bias=False,
+                                          input_is_parallel=True,
+                                          perform_initialization=False)
+
+        scaling = self.head_size ** -0.5
+        rotary_dim = config.rotary_dim
+        self.attn = PagedAttentionWithRoPE(self.num_heads, self.head_size,
+                                           scaling, rotary_dim)
+
+    def forward(
+        self,
+        position_ids: torch.Tensor,
+        hidden_states: torch.Tensor,
+        kv_cache: KVCache,
+        input_metadata: InputMetadata,
+        cache_event: Optional[torch.cuda.Event],
+    ) -> torch.Tensor:
+        q, _ = self.q_proj(hidden_states)
+        k, _ = self.k_proj(hidden_states)
+        v, _ = self.v_proj(hidden_states)
+
+        k_cache, v_cache = kv_cache
+        attn_output = self.attn(
+            position_ids, q, k, v, k_cache, v_cache, input_metadata,
+            cache_event)
+        attn_output, _ = self.out_proj(attn_output)
+        return attn_output
+
+
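+# NOTE: under tensor parallelism each rank owns
+# `num_heads = total_num_heads // world_size` attention heads, so the q/k/v
+# activations on a rank have shape (num_tokens, num_heads * head_size) and
+# `out_proj` reduces the partial outputs across ranks. Rotary position
+# embeddings are applied inside PagedAttentionWithRoPE; as in the original
+# GPT-J model, only the first `rotary_dim` dimensions of each head are
+# expected to be rotated.
+
+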
+class GPTJMLP(nn.Module):
+
+    def __init__(self, intermediate_size: int, config: GPTJConfig):
+        super().__init__()
+        hidden_size = config.n_embd
+        self.fc_in = ColumnParallelLinear(hidden_size,
+                                          intermediate_size,
+                                          gather_output=False,
+                                          perform_initialization=False)
+        self.fc_out = RowParallelLinear(intermediate_size,
+                                        hidden_size,
+                                        input_is_parallel=True,
+                                        perform_initialization=False)
+        self.act = get_act_fn(config.activation_function)
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        hidden_states, _ = self.fc_in(hidden_states)
+        hidden_states = self.act(hidden_states)
+        hidden_states, _ = self.fc_out(hidden_states)
+        return hidden_states
+
+
+class GPTJBlock(nn.Module):
+
+    def __init__(self, config: GPTJConfig):
+        super().__init__()
+        inner_dim = (config.n_inner if config.n_inner is not None
+                     else 4 * config.n_embd)
+        self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
+        self.attn = GPTJAttention(config)
+        self.mlp = GPTJMLP(inner_dim, config)
+
+    def forward(
+        self,
+        position_ids: torch.Tensor,
+        hidden_states: torch.Tensor,
+        kv_cache: KVCache,
+        input_metadata: InputMetadata,
+        cache_event: Optional[torch.cuda.Event],
+    ) -> torch.Tensor:
+        residual = hidden_states
+        hidden_states = self.ln_1(hidden_states)
+        attn_output = self.attn(
+            position_ids=position_ids,
+            hidden_states=hidden_states,
+            kv_cache=kv_cache,
+            input_metadata=input_metadata,
+            cache_event=cache_event,
+        )
+        # GPT-J uses a parallel residual: the attention and MLP branches both
+        # read the same layer-normed input and are summed with the residual.
+        feed_forward_hidden_states = self.mlp(hidden_states)
+        hidden_states = attn_output + feed_forward_hidden_states + residual
+        return hidden_states
+
+
+class GPTJModel(nn.Module):
+
+    def __init__(self, config: GPTJConfig):
+        super().__init__()
+        self.config = config
+        self.embed_dim = config.n_embd
+        # GPT-J's vocabulary size (50400) is already padded by the checkpoint
+        # authors and divides evenly across common tensor-parallel sizes, so
+        # no extra padding is applied here.
+        self.wte = VocabParallelEmbedding(config.vocab_size,
+                                          self.embed_dim,
+                                          perform_initialization=False)
+        self.h = nn.ModuleList(
+            [GPTJBlock(config) for _ in range(config.n_layer)])
+        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        position_ids: torch.Tensor,
+        kv_caches: List[KVCache],
+        input_metadata: InputMetadata,
+        cache_events: Optional[List[torch.cuda.Event]],
+    ) -> torch.Tensor:
+        inputs_embeds = self.wte(input_ids)
+        hidden_states = inputs_embeds
+        for i in range(len(self.h)):
+            if cache_events is None:
+                cache_event = None
+            else:
+                cache_event = cache_events[i]
+            layer = self.h[i]
+            hidden_states = layer(
+                position_ids,
+                hidden_states,
+                kv_caches[i],
+                input_metadata,
+                cache_event,
+            )
+
+        hidden_states = self.ln_f(hidden_states)
+        return hidden_states
+
+
+class GPTJForCausalLM(nn.Module):
+
+    def __init__(self, config: GPTJConfig):
+        super().__init__()
+        self.config = config
+        self.transformer = GPTJModel(config)
+        self.lm_head = ColumnParallelLinear(config.n_embd,
+                                            config.vocab_size,
+                                            gather_output=False,
+                                            perform_initialization=False)
+
+        self.sampler = Sampler(config.vocab_size)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[KVCache],
+        input_metadata: InputMetadata,
+        cache_events: Optional[List[torch.cuda.Event]],
+    ) -> Dict[int, SequenceOutputs]:
+        hidden_states = self.transformer(
+            input_ids, positions, kv_caches, input_metadata, cache_events)
+        next_tokens = self.sampler(
+            self.lm_head.weight, hidden_states, input_metadata)
+        return next_tokens
+
+    _column_parallel_weights = ["wte.weight", "lm_head.weight",
+                                "lm_head.bias", "fc_in.weight", "fc_in.bias"]
+    _row_parallel_weights = ["out_proj.weight", "fc_out.weight"]
+
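+    # NOTE: `load_tensor_parallel_weights` (vllm/model_executor/weight_utils.py)
+    # slices any parameter whose name matches `_column_parallel_weights` along
+    # dim 0 (the output dimension) and any name in `_row_parallel_weights`
+    # along dim 1 (the input dimension); all other parameters must already
+    # match the per-rank shape when they reach it.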
+    def load_weights(self, model_name_or_path: str,
+                     cache_dir: Optional[str] = None,
+                     use_np_cache: bool = False):
+        tensor_model_parallel_rank = get_tensor_model_parallel_rank()
+        state_dict = self.state_dict()
+        for name, loaded_weight in hf_model_weights_iterator(
+                model_name_or_path, cache_dir, use_np_cache):
+            if "attn.bias" in name or "attn.masked_bias" in name:
+                # Skip the causal-mask buffers; they are not parameters.
+                continue
+            param = state_dict[name]
+            if "q_proj" in name or "k_proj" in name or "v_proj" in name:
+                # The q/k/v projections are ColumnParallelLinear layers, so each
+                # rank keeps a contiguous block of rows (i.e. whole heads).
+                shard_size = param.shape[0]
+                loaded_weight = loaded_weight[
+                    shard_size * tensor_model_parallel_rank
+                    :shard_size * (tensor_model_parallel_rank + 1)]
+            load_tensor_parallel_weights(param, loaded_weight, name,
+                                         self._column_parallel_weights,
+                                         self._row_parallel_weights,
+                                         tensor_model_parallel_rank)
\ No newline at end of file
diff --git a/vllm/model_executor/models/gpt_j_template.py b/vllm/model_executor/models/gpt_j_template.py
new file mode 100644
index 000000000000..ad42b303b2a6
--- /dev/null
+++ b/vllm/model_executor/models/gpt_j_template.py
@@ -0,0 +1,58 @@
+# NOTE: This loader spec is adapted from CTranslate2's Transformers converter
+# and is kept here only as a reference for the GPT-J weight layout. The
+# imports below assume CTranslate2 is installed; if it is not, this module is
+# a no-op.
+try:
+    import numpy as np
+
+    from ctranslate2.converters import utils
+    from ctranslate2.converters.transformers import (
+        _SUPPORTED_ACTIVATIONS, ModelLoader, register_loader)
+    from ctranslate2.specs import transformer_spec
+
+    @register_loader("GPTJConfig")
+    class GPTJLoader(ModelLoader):
+        @property
+        def architecture_name(self):
+            return "GPTJForCausalLM"
+
+        def get_model_spec(self, model):
+            spec = transformer_spec.TransformerDecoderModelSpec.from_config(
+                model.config.n_layer,
+                model.config.n_head,
+                pre_norm=True,
+                activation=_SUPPORTED_ACTIVATIONS[
+                    model.config.activation_function],
+                rotary_dim=model.config.rotary_dim,
+                rotary_interleave=False,
+                parallel_residual=True,
+                shared_layer_norm=True,
+            )
+
+            self.set_decoder(
+                spec.decoder,
+                model.transformer,
+                model.config.rotary_dim,
+                model.config.n_head,
+            )
+            self.set_linear(spec.decoder.projection, model.lm_head)
+            return spec
+
+        def set_vocabulary(self, spec, tokens):
+            spec.register_vocabulary(tokens)
+
+        def set_config(self, config, model, tokenizer):
+            config.bos_token = tokenizer.bos_token
+            config.eos_token = tokenizer.eos_token
+            config.unk_token = tokenizer.unk_token
+
+        def set_decoder(self, spec, module, rotary_dim, num_heads):
+            spec.scale_embeddings = False
+            self.set_embeddings(spec.embeddings, module.wte)
+            self.set_layer_norm(spec.layer_norm, module.ln_f)
+
+            for layer_spec, layer in zip(spec.layer, module.h):
+                self.set_layer_norm(layer_spec.shared_layer_norm, layer.ln_1)
+
+                qw = layer.attn.q_proj.weight.numpy()
+                kw = layer.attn.k_proj.weight.numpy()
+                vw = layer.attn.v_proj.weight.numpy()
+
+                qw = utils.permute_for_sliced_rotary(qw, num_heads, rotary_dim)
+                kw = utils.permute_for_sliced_rotary(kw, num_heads, rotary_dim)
+
+                layer_spec.self_attention.linear[0].weight = np.concatenate(
+                    (qw, kw, vw))
+                self.set_linear(layer_spec.self_attention.linear[1],
+                                layer.attn.out_proj)
+
+                self.set_linear(layer_spec.ffn.linear_0, layer.mlp.fc_in)
+                self.set_linear(layer_spec.ffn.linear_1, layer.mlp.fc_out)
+except ImportError:
+    # CTranslate2 is an optional dependency for this reference template.
+    pass
\ No newline at end of file
diff --git a/vllm/model_executor/weight_utils.py b/vllm/model_executor/weight_utils.py
index cef67fde587f..a7a002435da4 100644
---
a/vllm/model_executor/weight_utils.py +++ b/vllm/model_executor/weight_utils.py @@ -98,7 +98,9 @@ def load_tensor_parallel_weights( shard_size * tensor_model_parallel_rank :shard_size * (tensor_model_parallel_rank + 1)] break - assert param.shape == loaded_weight.shape + assert param.shape == loaded_weight.shape, ( + f"{param_name} shape mismatch between model and checkpoint: " + f"{param.shape} != {loaded_weight.shape}") param.data.copy_(loaded_weight) diff --git a/vllm/outputs.py b/vllm/outputs.py index fea521d12924..384ca020985d 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -53,6 +53,7 @@ class RequestOutput: prompt: The prompt string of the request. prompt_token_ids: The token IDs of the prompt. outputs: The output sequences of the request. + finished: Whether the whole request is finished. """ def __init__( self, @@ -60,11 +61,13 @@ def __init__( prompt: str, prompt_token_ids: List[int], outputs: List[CompletionOutput], + finished: bool, ) -> None: self.request_id = request_id self.prompt = prompt self.prompt_token_ids = prompt_token_ids self.outputs = outputs + self.finished = finished @classmethod def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput": @@ -95,13 +98,13 @@ def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput": # Every sequence in the sequence group should have the same prompt. prompt = top_n_seqs[0].prompt prompt_token_ids = top_n_seqs[0].data.prompt_token_ids - return cls(seq_group.request_id, prompt, prompt_token_ids, outputs) + finished = seq_group.is_finished() + return cls(seq_group.request_id, prompt, prompt_token_ids, outputs, + finished) def __repr__(self) -> str: return (f"RequestOutput(request_id={self.request_id}, " f"prompt={self.prompt!r}, " f"prompt_token_ids={self.prompt_token_ids}, " - f"outputs={self.outputs})") - - def finished(self) -> bool: - return all(output.finished() for output in self.outputs) + f"outputs={self.outputs}, " + f"finished={self.finished})") diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index b7623ff58df6..dfe918995c92 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -1,6 +1,7 @@ """Sampling parameters for text generation.""" from typing import List, Optional, Union +_SAMPLING_EPS = 1e-5 class SamplingParams: """Sampling parameters for text generation. @@ -71,7 +72,7 @@ def __init__( self._verify_args() if self.use_beam_search: self._verity_beam_search() - elif self.temperature == 0.0: + elif self.temperature < _SAMPLING_EPS: # Zero temperature means greedy sampling. self._verify_greedy_sampling() @@ -106,9 +107,9 @@ def _verity_beam_search(self) -> None: if self.best_of == 1: raise ValueError("best_of must be greater than 1 when using beam " f"search. Got {self.best_of}.") - if self.temperature > 0.0: + if self.temperature > _SAMPLING_EPS: raise ValueError("temperature must be 0 when using beam search.") - if self.top_p < 1.0: + if self.top_p < 1.0 - _SAMPLING_EPS: raise ValueError("top_p must be 1 when using beam search.") if self.top_k != -1: raise ValueError("top_k must be -1 when using beam search.") @@ -117,7 +118,7 @@ def _verify_greedy_sampling(self) -> None: if self.best_of > 1: raise ValueError("best_of must be 1 when using greedy sampling." 
f"Got {self.best_of}.") - if self.top_p < 1.0: + if self.top_p < 1.0 - _SAMPLING_EPS: raise ValueError("top_p must be 1 when using greedy sampling.") if self.top_k != -1: raise ValueError("top_k must be -1 when using greedy sampling.") diff --git a/vllm/sequence.py b/vllm/sequence.py index 5fe847297cd5..083e6d03c330 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -13,6 +13,7 @@ class SequenceStatus(enum.Enum): FINISHED_STOPPED = enum.auto() FINISHED_LENGTH_CAPPED = enum.auto() FINISHED_ABORTED = enum.auto() + FINISHED_IGNORED = enum.auto() @staticmethod def is_finished(status: "SequenceStatus") -> bool: @@ -20,6 +21,7 @@ def is_finished(status: "SequenceStatus") -> bool: SequenceStatus.FINISHED_STOPPED, SequenceStatus.FINISHED_LENGTH_CAPPED, SequenceStatus.FINISHED_ABORTED, + SequenceStatus.FINISHED_IGNORED ] @staticmethod @@ -30,6 +32,8 @@ def get_finished_reason(status: "SequenceStatus") -> Union[str, None]: finish_reason = "length" elif status == SequenceStatus.FINISHED_ABORTED: finish_reason = "abort" + elif status == SequenceStatus.FINISHED_IGNORED: + finish_reason = "length" else: finish_reason = None return finish_reason diff --git a/vllm/transformers_utils/__init__.py b/vllm/transformers_utils/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/vllm/engine/tokenizer_utils.py b/vllm/transformers_utils/tokenizer.py similarity index 62% rename from vllm/engine/tokenizer_utils.py rename to vllm/transformers_utils/tokenizer.py index 1a0115e64c27..9971e8c7bb3c 100644 --- a/vllm/engine/tokenizer_utils.py +++ b/vllm/transformers_utils/tokenizer.py @@ -1,46 +1,51 @@ from typing import List, Tuple, Union -from transformers import (AutoConfig, AutoTokenizer, PreTrainedTokenizer, +from transformers import (AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast) from vllm.logger import init_logger logger = init_logger(__name__) -_MODEL_TYPES_WITH_SLOW_TOKENIZER = [] +# A fast LLaMA tokenizer with the pre-processed `tokenizer.json` file. +_FAST_LLAMA_TOKENIZER = "hf-internal-testing/llama-tokenizer" def get_tokenizer( - model_name: str, + tokenizer_name: str, + tokenizer_mode: str = "auto", *args, **kwargs, ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: """Gets a tokenizer for the given model name via Huggingface.""" - config = AutoConfig.from_pretrained(model_name) - if "open_llama" in model_name: - kwargs["use_fast"] = False - logger.info( - "OpenLLaMA models do not support the fast tokenizer. " - "Using the slow tokenizer instead.") - elif config.model_type == "llama" and getattr(kwargs, "use_fast", True): - # LLaMA fast tokenizer causes protobuf errors in some environments. - # However, we found that the below LLaMA fast tokenizer works well in - # most environments. - model_name = "hf-internal-testing/llama-tokenizer" - logger.info( - f"Using the LLaMA fast tokenizer in '{model_name}' to avoid " - "potential protobuf errors.") - elif config.model_type in _MODEL_TYPES_WITH_SLOW_TOKENIZER: - if getattr(kwargs, "use_fast", False) == True: + if tokenizer_mode == "slow": + if kwargs.get("use_fast", False): raise ValueError( - f"Cannot use the fast tokenizer for {config.model_type} due to " - "bugs in the fast tokenizer.") - logger.info( - f"Using the slow tokenizer for {config.model_type} due to bugs in " - "the fast tokenizer. 
This could potentially lead to performance " - "degradation.") + "Cannot use the fast tokenizer in slow tokenizer mode.") kwargs["use_fast"] = False - return AutoTokenizer.from_pretrained(model_name, *args, **kwargs) + + if "llama" in tokenizer_name.lower() and kwargs.get("use_fast", True): + logger.info( + "For some LLaMA-based models, initializing the fast tokenizer may " + "take a long time. To eliminate the initialization time, consider " + f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original " + "tokenizer.") + try: + tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, *args, + **kwargs) + except TypeError as e: + # The LLaMA tokenizer causes a protobuf error in some environments. + err_msg = ( + "Failed to load the tokenizer. If you are using a LLaMA-based " + f"model, use '{_FAST_LLAMA_TOKENIZER}' instead of the original " + "tokenizer.") + raise RuntimeError(err_msg) from e + + if not isinstance(tokenizer, PreTrainedTokenizerFast): + logger.warning( + "Using a slow tokenizer. This might cause a significant " + "slowdown. Consider using a fast tokenizer instead.") + return tokenizer def detokenize_incrementally( diff --git a/vllm/utils.py b/vllm/utils.py index 85fe18778244..eb686b64eafb 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -1,4 +1,5 @@ import enum +from platform import uname import uuid import psutil @@ -36,3 +37,7 @@ def get_cpu_memory() -> int: def random_uuid() -> str: return str(uuid.uuid4().hex) + +def in_wsl() -> bool: + # Reference: https://github.com/microsoft/WSL/issues/4071 + return "microsoft" in " ".join(uname()).lower() diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py index 30b4ec7d79cf..eb99bf759e83 100644 --- a/vllm/worker/cache_engine.py +++ b/vllm/worker/cache_engine.py @@ -5,6 +5,10 @@ from vllm import cache_ops from vllm.config import CacheConfig, ModelConfig, ParallelConfig +from vllm.logger import init_logger +from vllm.utils import in_wsl + +logger = init_logger(__name__) KVCache = Tuple[torch.Tensor, torch.Tensor] @@ -85,16 +89,22 @@ def allocate_cpu_cache(self) -> List[KVCache]: cpu_cache: List[KVCache] = [] key_block_shape = self.get_key_block_shape() value_block_shape = self.get_value_block_shape() + pin_memory = not in_wsl() + if not pin_memory: + # Pinning memory in WSL is not supported. + # https://docs.nvidia.com/cuda/wsl-user-guide/index.html#known-limitations-for-linux-cuda-applications + logger.warn("Using 'pin_memory=False' as WSL is detected. " + "This may slow down the performance.") for _ in range(self.num_layers): key_blocks = torch.empty( size=(self.num_cpu_blocks, *key_block_shape), dtype=self.dtype, - pin_memory=True, + pin_memory=pin_memory, ) value_blocks = torch.empty( size=(self.num_cpu_blocks, *value_block_shape), dtype=self.dtype, - pin_memory=True, + pin_memory=pin_memory, ) cpu_cache.append((key_blocks, value_blocks)) return cpu_cache diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 2beafc670418..b97011206ee2 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -113,6 +113,8 @@ def profile_num_available_blocks( num_gpu_blocks = int((total_gpu_memory * gpu_memory_utilization - peak_memory) // cache_block_size) num_cpu_blocks = int(cpu_swap_space // cache_block_size) + num_gpu_blocks = max(num_gpu_blocks, 0) + num_cpu_blocks = max(num_cpu_blocks, 0) torch.cuda.empty_cache() # Reset the seed to ensure that the random state is not affected by