diff --git a/.buildkite/run-benchmarks.sh b/.buildkite/run-benchmarks.sh index 75e9cf6a6579..cbf6dda677c5 100644 --- a/.buildkite/run-benchmarks.sh +++ b/.buildkite/run-benchmarks.sh @@ -54,7 +54,7 @@ tail -n 24 benchmark_serving.txt >> benchmark_results.md # last 24 lines echo '```' >> benchmark_results.md # if the agent binary is not found, skip uploading the results, exit 0 -if [ ! -f buildkite-agent ]; then +if [ ! -f /usr/bin/buildkite-agent ]; then exit 0 fi diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index b48ef31bc416..6b12d19ba611 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -37,6 +37,7 @@ steps: working_dir: "/vllm-workspace/tests" num_gpus: 2 commands: + - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_chunked_prefill_distributed.py diff --git a/.github/scripts/run-tests b/.github/scripts/run-tests index 9e18f96f339d..e64ea401b16c 100755 --- a/.github/scripts/run-tests +++ b/.github/scripts/run-tests @@ -113,6 +113,8 @@ do # need to be run with specific options if [[ "${TEST}" == *"kernels"* || "${TEST}" == *"samplers"* ]]; then CUDA_VISIBLE_DEVICES=0,1 pytest ${CC_PYTEST_FLAGS} --junitxml=${RESULT_XML} ${TEST} || LOCAL_SUCCESS=$? + elif [[ "${TEST}" == *"distributed/test_same_node"* ]]; then + VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 ${TEST} || LOCAL_SUCCESS=$? elif [[ "${TEST}" == *"distributed"* ]]; then CUDA_VISIBLE_DEVICES=0,1 pytest ${CC_PYTEST_FLAGS} --junitxml=${RESULT_XML} ${TEST} || LOCAL_SUCCESS=$? elif [[ "${TEST}" == *"test_models_logprobs"* ]]; then diff --git a/Dockerfile.neuron b/Dockerfile.neuron index fe42b4ef393f..010f23a14301 100644 --- a/Dockerfile.neuron +++ b/Dockerfile.neuron @@ -28,7 +28,7 @@ COPY ./requirements-neuron.txt /app/vllm/requirements-neuron.txt RUN cd /app/vllm \ && python3 -m pip install -U -r requirements-neuron.txt -ENV VLLM_BUILD_WITH_NEURON 1 +ENV VLLM_TARGET_DEVICE neuron RUN cd /app/vllm \ && pip install -e . \ && cd .. diff --git a/docs/source/automatic_prefix_caching/apc.rst b/docs/source/automatic_prefix_caching/apc.rst new file mode 100644 index 000000000000..0d70c74689bf --- /dev/null +++ b/docs/source/automatic_prefix_caching/apc.rst @@ -0,0 +1,110 @@ +.. _apc: + +Introduction +============ + +What is Automatic Prefix Caching +-------------------------------- + +Automatic Prefix Caching (APC in short) caches the KV cache of existing queries, so that a new query can directly reuse the KV cache if it shares the same prefix with one of the existing queries, allowing the new query to skip the computation of the shared part. + + +.. note:: + + Technical details on how vLLM implements APC are in the next page. + + + +Enabling APC in vLLM +-------------------- + +Set ``enable_prefix_caching=True`` in vLLM engine to enable APC. Here is an example: + +.. code-block:: python + + import time + from vllm import LLM, SamplingParams + + + # A prompt containing a large markdown table. The table is randomly generated by GPT-4. + LONG_PROMPT = "You are a helpful assistant in recognizes the content of tables in markdown format. 
Here is a table as follows.\n# Table\n" + """ + | ID | Name | Age | Occupation | Country | Email | Phone Number | Address | + |-----|---------------|-----|---------------|---------------|------------------------|----------------|------------------------------| + | 1 | John Doe | 29 | Engineer | USA | john.doe@example.com | 555-1234 | 123 Elm St, Springfield, IL | + | 2 | Jane Smith | 34 | Doctor | Canada | jane.smith@example.com | 555-5678 | 456 Oak St, Toronto, ON | + | 3 | Alice Johnson | 27 | Teacher | UK | alice.j@example.com | 555-8765 | 789 Pine St, London, UK | + | 4 | Bob Brown | 45 | Artist | Australia | bob.b@example.com | 555-4321 | 321 Maple St, Sydney, NSW | + | 5 | Carol White | 31 | Scientist | New Zealand | carol.w@example.com | 555-6789 | 654 Birch St, Wellington, NZ | + | 6 | Dave Green | 28 | Lawyer | Ireland | dave.g@example.com | 555-3456 | 987 Cedar St, Dublin, IE | + | 7 | Emma Black | 40 | Musician | USA | emma.b@example.com | 555-1111 | 246 Ash St, New York, NY | + | 8 | Frank Blue | 37 | Chef | Canada | frank.b@example.com | 555-2222 | 135 Spruce St, Vancouver, BC | + | 9 | Grace Yellow | 50 | Engineer | UK | grace.y@example.com | 555-3333 | 864 Fir St, Manchester, UK | + | 10 | Henry Violet | 32 | Artist | Australia | henry.v@example.com | 555-4444 | 753 Willow St, Melbourne, VIC| + | 11 | Irene Orange | 26 | Scientist | New Zealand | irene.o@example.com | 555-5555 | 912 Poplar St, Auckland, NZ | + | 12 | Jack Indigo | 38 | Teacher | Ireland | jack.i@example.com | 555-6666 | 159 Elm St, Cork, IE | + | 13 | Karen Red | 41 | Lawyer | USA | karen.r@example.com | 555-7777 | 357 Cedar St, Boston, MA | + | 14 | Leo Brown | 30 | Chef | Canada | leo.b@example.com | 555-8888 | 246 Oak St, Calgary, AB | + | 15 | Mia Green | 33 | Musician | UK | mia.g@example.com | 555-9999 | 975 Pine St, Edinburgh, UK | + | 16 | Noah Yellow | 29 | Doctor | Australia | noah.y@example.com | 555-0000 | 864 Birch St, Brisbane, QLD | + | 17 | Olivia Blue | 35 | Engineer | New Zealand | olivia.b@example.com | 555-1212 | 753 Maple St, Hamilton, NZ | + | 18 | Peter Black | 42 | Artist | Ireland | peter.b@example.com | 555-3434 | 912 Fir St, Limerick, IE | + | 19 | Quinn White | 28 | Scientist | USA | quinn.w@example.com | 555-5656 | 159 Willow St, Seattle, WA | + | 20 | Rachel Red | 31 | Teacher | Canada | rachel.r@example.com | 555-7878 | 357 Poplar St, Ottawa, ON | + | 21 | Steve Green | 44 | Lawyer | UK | steve.g@example.com | 555-9090 | 753 Elm St, Birmingham, UK | + | 22 | Tina Blue | 36 | Musician | Australia | tina.b@example.com | 555-1213 | 864 Cedar St, Perth, WA | + | 23 | Umar Black | 39 | Chef | New Zealand | umar.b@example.com | 555-3435 | 975 Spruce St, Christchurch, NZ| + | 24 | Victor Yellow | 43 | Engineer | Ireland | victor.y@example.com | 555-5657 | 246 Willow St, Galway, IE | + | 25 | Wendy Orange | 27 | Artist | USA | wendy.o@example.com | 555-7879 | 135 Elm St, Denver, CO | + | 26 | Xavier Green | 34 | Scientist | Canada | xavier.g@example.com | 555-9091 | 357 Oak St, Montreal, QC | + | 27 | Yara Red | 41 | Teacher | UK | yara.r@example.com | 555-1214 | 975 Pine St, Leeds, UK | + | 28 | Zack Blue | 30 | Lawyer | Australia | zack.b@example.com | 555-3436 | 135 Birch St, Adelaide, SA | + | 29 | Amy White | 33 | Musician | New Zealand | amy.w@example.com | 555-5658 | 159 Maple St, Wellington, NZ | + | 30 | Ben Black | 38 | Chef | Ireland | ben.b@example.com | 555-7870 | 246 Fir St, Waterford, IE | + """ + + + def get_generation_time(llm, sampling_params, prompts): + # time the 
generation + start_time = time.time() + output = llm.generate(prompts, sampling_params=sampling_params) + end_time = time.time() + # print the output and generation time + print(f"Output: {output[0].outputs[0].text}") + print(f"Generation time: {end_time - start_time} seconds.") + + + # set enable_prefix_caching=True to enable APC + llm = LLM( + model='lmsys/longchat-13b-16k', + enable_prefix_caching=True + ) + + sampling_params = SamplingParams(temperature=0, max_tokens=100) + + # Querying the age of John Doe + get_generation_time( + llm, + sampling_params, + LONG_PROMPT + "Question: what is the age of John Doe? Your answer: The age of John Doe is ", + ) + + # Querying the age of Zack Blue + # This query will be faster since vLLM avoids computing the KV cache of LONG_PROMPT again. + get_generation_time( + llm, + sampling_params, + LONG_PROMPT + "Question: what is the age of Zack Blue? Your answer: The age of Zack Blue is ", + ) +
+Example workloads +----------------- + +We describe two example workloads where APC can provide a huge performance benefit: + +- Long document query, where the user repeatedly queries the same long document (e.g. a software manual or an annual report) with different questions. In this case, instead of processing the long document again and again, APC allows vLLM to process the document *only once*, and all future requests can avoid recomputing it by reusing its KV cache. This allows vLLM to serve future requests with much higher throughput and much lower latency. +- Multi-round conversation, where the user may chat with the application multiple times in the same chat session. In this case, instead of processing the whole chat history again and again, APC allows vLLM to reuse the processing results of the chat history across all future rounds of conversation, allowing vLLM to serve future requests with much higher throughput and much lower latency. + +
+Limits +------ +APC in general does not reduce the performance of vLLM. That said, APC only reduces the time spent processing the queries (the prefill phase) and does not reduce the time spent generating new tokens (the decoding phase). So APC does not bring a performance gain when vLLM spends most of its time generating answers to the queries (e.g. when the answers are long), or when new queries do not share a prefix with any of the existing queries (so that the computation cannot be reused).
diff --git a/docs/source/automatic_prefix_caching/details.md b/docs/source/automatic_prefix_caching/details.md new file mode 100644 index 000000000000..2d3214e28ed9 --- /dev/null +++ b/docs/source/automatic_prefix_caching/details.md @@ -0,0 +1,43 @@ +# Implementation + +The core idea of PagedAttention is to partition the KV cache of each request into KV Blocks. Each block contains the attention keys and values for a fixed number of tokens. The PagedAttention algorithm allows these blocks to be stored in non-contiguous physical memory so that we can eliminate memory fragmentation by allocating the memory on demand. + +To automatically cache the KV cache, we utilize the following key observation: Each KV block can be uniquely identified by the tokens within the block together with the tokens in the prefix before the block.
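As a rough illustration of this observation (a minimal sketch only; the `BLOCK_SIZE` constant and `block_hashes` helper are hypothetical and do not mirror vLLM's internal code), each full block can be keyed by hashing every token up to and including that block, so two requests that share a prefix produce identical keys for their shared blocks:

```python
# Illustrative sketch only -- not vLLM's actual implementation.
from typing import List

BLOCK_SIZE = 4  # tokens per KV block (the real block size is configurable)


def block_hashes(token_ids: List[int]) -> List[int]:
    """Key each full KV block by its own tokens plus all prefix tokens before it."""
    keys = []
    for end in range(BLOCK_SIZE, len(token_ids) + 1, BLOCK_SIZE):
        # hash(prefix tokens + block tokens): every token up to the block's end
        keys.append(hash(tuple(token_ids[:end])))
    return keys


# Two requests sharing their first 8 tokens get identical keys for the first
# two blocks, so those blocks could be served from the same physical KV block.
a = block_hashes([1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 9])
b = block_hashes([1, 2, 3, 4, 5, 6, 7, 8, 7, 7, 7, 7])
assert a[:2] == b[:2] and a[2] != b[2]
```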
+ +``` + Block 1 Block 2 Block 3 + [A gentle breeze stirred] [the leaves as children] [laughed in the distance] +Block 1: |<--- block tokens ---->| +Block 2: |<------- prefix ------>| |<--- block tokens --->| +Block 3: |<------------------ prefix -------------------->| |<--- block tokens ---->| +``` + + +In the diagram above, the KV cache in the first block can be uniquely identified with the tokens “A gentle breeze stirred”. The third block can be uniquely identified with the tokens in the block “laughed in the distance”, along with the prefix tokens “A gentle breeze stirred the leaves as children”. Therefore, we can build the following one-to-one mapping: + +``` +hash(prefix tokens + block tokens) <--> KV Block +``` + +With this mapping, we can add another level of indirection in vLLM’s KV cache management. Previously, each sequence in vLLM maintained a mapping from its logical KV blocks to physical blocks. To achieve automatic caching of KV blocks, we map the logical KV blocks to their hash values and maintain a global hash table of all the physical blocks. In this way, all the KV blocks sharing the same hash value (e.g., shared prefix blocks across two requests) can be mapped to the same physical block and share the memory space. + + +This design achieves automatic prefix caching without the need to maintain a tree structure among the KV blocks. More specifically, all of the blocks are independent of each other and can be allocated and freed individually, which enables us to manage the KV cache in the same way as an ordinary cache in an operating system. + + +# Generalized Caching Policy + +Keeping all the KV blocks in a hash table enables vLLM to cache KV blocks from earlier requests to save memory and accelerate the computation of future requests. For example, if a new request shares the system prompt with the previous request, the KV cache of the shared prompt can directly be used for the new request without recomputation. However, the total KV cache space is limited and we have to decide which KV blocks to keep or evict when the cache is full. + +Managing the KV cache with a hash table allows us to implement flexible caching policies. As an example, in the current vLLM we implement the following eviction policy: + +* When there are no free blocks left, we evict a KV block with a reference count (i.e., the number of current requests using the block) of 0. +* If there are multiple blocks with a reference count of 0, we prioritize evicting the least recently used (LRU) block. +* If there are multiple blocks whose last access times are the same, we prioritize evicting the block that is at the end of the longest prefix (i.e., has the maximum number of blocks before it). + +Note that this eviction policy effectively implements the same policy as [RadixAttention](https://lmsys.org/blog/2024-01-17-sglang/) when applied to models with full attention, which prioritizes evicting leaf nodes in the prefix tree that have a reference count of zero and were least recently used. + +However, the hash-based KV cache management gives us the flexibility to handle more complicated serving scenarios and to implement more sophisticated eviction policies beyond the one above: + +- Multi-LoRA serving. When serving requests for multiple LoRA adapters, we can simply let the hash of each KV block also include the LoRA ID the request is querying, which enables caching for all adapters.
In this way, we can jointly manage the KV blocks for different adapters, which simplifies the system implementation and improves the global cache hit rate and efficiency. +- Multi-modal models. When the user input includes more than just discrete tokens, we can use different hashing methods to handle the caching of inputs of different modalities. For example, we can use perceptual hashing for images to cache similar input images.
diff --git a/docs/source/conf.py b/docs/source/conf.py index ee0f6c53bd1b..ca26dcec4bb5 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -92,7 +92,7 @@ def setup(app): "vllm._C", "PIL", "numpy", - 'triton' + 'triton', "tqdm", "tensorizer", ]
diff --git a/docs/source/getting_started/debugging.rst b/docs/source/getting_started/debugging.rst new file mode 100644 index 000000000000..3e4d0362e3a0 --- /dev/null +++ b/docs/source/getting_started/debugging.rst @@ -0,0 +1,42 @@ +.. _debugging: + +Debugging Tips +=============== + +Debugging hang/crash issues +--------------------------- + +When a vLLM instance hangs or crashes, it is very difficult to debug the issue. Before assuming a bug, keep in mind that vLLM may simply be doing something that genuinely takes a long time: + +- Downloading a model: is the model already downloaded to your disk? If not, vLLM will download it from the internet, which can take a long time. Be sure to check your internet connection. It is better to download the model first using the `huggingface cli `_ and then use the local path to the model. This way, you can isolate the issue. +- Loading the model from disk: if the model is large, it can take a long time to load it from disk. Pay attention to where you store the model. Some clusters have filesystems shared across nodes, e.g. a distributed or network filesystem, which can be slow. It is better to store the model on a local disk. In addition, keep an eye on CPU memory usage: when the model is too large, it can consume a lot of CPU memory, which may slow down the operating system because it has to frequently swap pages between disk and memory. +- Tensor parallel inference: if the model is too large to fit on a single GPU, you might want to use tensor parallelism to split the model across multiple GPUs. In that case, every process will read the whole model and split it into chunks, which makes the disk reading time even longer (proportional to the tensor parallel size). You can convert the model checkpoint to a sharded checkpoint using `the provided script `_. The conversion process might take some time, but afterwards you can load the sharded checkpoint much faster. The model loading time should then remain constant regardless of the tensor parallel size. + +If you have already taken care of the above issues and the vLLM instance still hangs, with CPU and GPU utilization near zero, it is likely stuck somewhere. Here are some tips to help debug the issue: + +- Set the environment variable ``export VLLM_LOGGING_LEVEL=DEBUG`` to turn on more logging. +- Set the environment variable ``export CUDA_LAUNCH_BLOCKING=1`` to know exactly which CUDA kernel is causing the trouble. +- Set the environment variable ``export NCCL_DEBUG=TRACE`` to turn on more logging for NCCL. +- Set the environment variable ``export VLLM_TRACE_FUNCTION=1`` so that all function calls in vLLM are recorded. Inspect these log files to tell which function crashes or hangs. **Note: it will generate a lot of logs and slow down the system.
Only use it for debugging purposes.** + +With more logging, hopefully you can find the root cause of the issue. + +Here are some common issues that can cause hangs: + +- The network setup is incorrect. The vLLM instance cannot get the correct IP address. You can find the log such as ``DEBUG 06-10 21:32:17 parallel_state.py:88] world_size=8 rank=0 local_rank=0 distributed_init_method=tcp://xxx.xxx.xxx.xxx:54641 backend=nccl``. The IP address should be the correct one. If not, override the IP address by setting the environment variable ``export VLLM_HOST_IP=your_ip_address``. +- Hardware/driver setup is incorrect. GPU communication cannot be established. You can run a sanity check script below to see if the GPU communication is working correctly. + +.. code-block:: python + + # save it as `test.py`` , and run it with `NCCL_DEBUG=TRACE torchrun --nproc-per-node=8 test.py` + # adjust `--nproc-per-node` to the number of GPUs you want to use. + import torch + import torch.distributed as dist + dist.init_process_group(backend="nccl") + data = torch.FloatTensor([1,] * 128).to(f"cuda:{dist.get_rank()}") + dist.all_reduce(data, op=dist.ReduceOp.SUM) + torch.cuda.synchronize() + value = data.mean().item() + assert value == dist.get_world_size() + +If the problem persists, feel free to open an `issue `_ on GitHub, with a detailed description of the issue, your environment, and the logs. diff --git a/docs/source/index.rst b/docs/source/index.rst index fad3c3b05b0c..807251d02974 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -66,6 +66,7 @@ Documentation getting_started/neuron-installation getting_started/cpu-installation getting_started/quickstart + getting_started/debugging getting_started/examples/examples_index .. toctree:: @@ -89,6 +90,7 @@ Documentation models/engine_args models/lora models/vlm + models/spec_decode models/performance .. toctree:: @@ -96,13 +98,20 @@ Documentation :caption: Quantization quantization/auto_awq + quantization/fp8 quantization/fp8_e5m2_kvcache quantization/fp8_e4m3_kvcache .. toctree:: :maxdepth: 1 + :caption: Automatic Prefix Caching + + automatic_prefix_caching/apc + automatic_prefix_caching/details + +.. toctree:: :caption: Developer Documentation - + dev/sampling_params dev/offline_inference/offline_index dev/engine/engine_index diff --git a/docs/source/models/spec_decode.rst b/docs/source/models/spec_decode.rst new file mode 100644 index 000000000000..9fb62397b9aa --- /dev/null +++ b/docs/source/models/spec_decode.rst @@ -0,0 +1,77 @@ +.. _spec_decode: + +Speculative decoding in vLLM +============================ + +.. warning:: + Please note that speculative decoding in vLLM is not yet optimized and does + not usually yield inter-token latency reductions for all prompt datasets or sampling parameters. The work + to optimize it is ongoing and can be followed in `this issue. `_ + +This document shows how to use `Speculative Decoding `_ with vLLM. +Speculative decoding is a technique which improves inter-token latency in memory-bound LLM inference. + +Speculating with a draft model +------------------------------ + +The following code configures vLLM to use speculative decoding with a draft model, speculating 5 tokens at a time. + +.. 
code-block:: python + + from vllm import LLM, SamplingParams + + prompts = [ + "The future of AI is", + ] + sampling_params = SamplingParams(temperature=0.8, top_p=0.95) + + llm = LLM( + model="facebook/opt-6.7b", + tensor_parallel_size=1, + speculative_model="facebook/opt-125m", + num_speculative_tokens=5, + use_v2_block_manager=True, + ) + outputs = llm.generate(prompts, sampling_params) + + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + +Speculating by matching n-grams in the prompt +--------------------------------------------- + +The following code configures vLLM to use speculative decoding where proposals are generated by +matching n-grams in the prompt. For more information read `this thread. `_ + +.. code-block:: python + + from vllm import LLM, SamplingParams + + prompts = [ + "The future of AI is", + ] + sampling_params = SamplingParams(temperature=0.8, top_p=0.95) + + llm = LLM( + model="facebook/opt-6.7b", + tensor_parallel_size=1, + speculative_model="[ngram]", + num_speculative_tokens=5, + ngram_prompt_lookup_max=4, + use_v2_block_manager=True, + ) + outputs = llm.generate(prompts, sampling_params) + + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + +Resources for vLLM contributors +------------------------------- +* `A Hacker's Guide to Speculative Decoding in vLLM `_ +* `What is Lookahead Scheduling in vLLM? `_ +* `Information on batch expansion. `_ +* `Dynamic speculative decoding `_ diff --git a/docs/source/models/vlm.rst b/docs/source/models/vlm.rst index 33aa8246b2e6..5ab4157cb358 100644 --- a/docs/source/models/vlm.rst +++ b/docs/source/models/vlm.rst @@ -20,7 +20,8 @@ The following :ref:`engine arguments ` are specific to VLMs: Currently, the support for vision language models on vLLM has the following limitations: * Only single image input is supported per text prompt. - * Dynamic ``image_input_shape`` is not supported: the input image will be resized to the static ``image_input_shape``. This means model output might not exactly match the huggingface implementation. + * Dynamic ``image_input_shape`` is not supported: the input image will be resized to the static ``image_input_shape``. This means model output might not exactly match the HuggingFace implementation. + We are continuously improving user & developer experience for VLMs. Please raise an issue on GitHub if you have any feedback or feature requests. Offline Batched Inference diff --git a/docs/source/quantization/fp8.rst b/docs/source/quantization/fp8.rst new file mode 100644 index 000000000000..312a564595cc --- /dev/null +++ b/docs/source/quantization/fp8.rst @@ -0,0 +1,206 @@ +.. _fp8: + +FP8 +================== + +vLLM supports FP8 (8-bit floating point) computation using hardware acceleration on GPUs such as Nvidia H100 and AMD MI300x. Currently, only Hopper and Ada Lovelace GPUs are supported. Quantization of models with FP8 allows for a 2x reduction in model memory requirements and up to a 1.6x improvement in throughput with minimal impact on accuracy. + +Please visit the HF collection of `quantized FP8 checkpoints of popular LLMs ready to use with vLLM `_. + +The FP8 types typically supported in hardware have two distinct representations, each useful in different scenarios: + +- **E4M3**: Consists of 1 sign bit, 4 exponent bits, and 3 bits of mantissa. 
It can store values up to +/-448 and ``nan``. +- **E5M2**: Consists of 1 sign bit, 5 exponent bits, and 2 bits of mantissa. It can store values up to +/-57344, +/- ``inf``, and ``nan``. The tradeoff for the increased dynamic range is lower precision of the stored values. + +Quick Start with Online Dynamic Quantization +-------------------------------------------- + +Dynamic quantization of an original precision BF16/FP16 model to FP8 can be achieved with vLLM without any calibration data required. You can enable the feature by specifying ``--quantization="fp8"`` in the command line or setting ``quantization="fp8"`` in the LLM constructor. + +In this mode, all Linear modules (except for the final ``lm_head``) have their weights quantized down to FP8_E4M3 precision with a per-tensor scale. Activations have their minimum and maximum values calculated during each forward pass to provide a dynamic per-tensor scale for high accuracy. As a result, latency improvements are limited in this mode. + +.. code-block:: python + + from vllm import LLM + model = LLM("facebook/opt-125m", quantization="fp8") + # INFO 06-10 17:55:42 model_runner.py:157] Loading model weights took 0.1550 GB + result = model.generate("Hello, my name is") + +.. warning:: + + Currently, we load the model at original precision before quantizing down to 8 bits, so you need enough memory to load the whole model. + +Offline Quantization +-------------------- + +For offline quantization to FP8, please install the `AutoFP8 library `_. + +.. code-block:: bash + + git clone https://github.com/neuralmagic/AutoFP8.git + pip install -e AutoFP8 + +This package introduces the ``AutoFP8ForCausalLM`` and ``BaseQuantizeConfig`` objects for managing how your model will be compressed. + +Offline Quantization with Dynamic Activation Scaling Factors +------------------------------------------------------------ + +You can use AutoFP8 to produce checkpoints with their weights quantized to FP8 ahead of time and let vLLM handle calculating dynamic scales for the activations at runtime for maximum accuracy. You can enable this with the ``activation_scheme="dynamic"`` argument. + +.. warning:: + + Please note that although this mode doesn't give you better performance, it reduces memory footprint compared to online quantization. + +.. code-block:: python + + from auto_fp8 import AutoFP8ForCausalLM, BaseQuantizeConfig + + pretrained_model_dir = "meta-llama/Meta-Llama-3-8B-Instruct" + quantized_model_dir = "Meta-Llama-3-8B-Instruct-FP8-Dynamic" + + # Define quantization config with dynamic activation scales + quantize_config = BaseQuantizeConfig(quant_method="fp8", activation_scheme="dynamic") + # For dynamic activation scales, there is no need for calibration examples + examples = [] + + # Load the model, quantize, and save checkpoint + model = AutoFP8ForCausalLM.from_pretrained(pretrained_model_dir, quantize_config) + model.quantize(examples) + model.save_quantized(quantized_model_dir) + +In the output of the above script, you should be able to see that the Linear modules have been replaced with quantized FP8DynamicLinear modules in the model definition. +Note that the ``lm_head`` Linear module at the end is currently skipped by default. + +..
code-block:: text + + LlamaForCausalLM( + (model): LlamaModel( + (embed_tokens): Embedding(128256, 4096) + (layers): ModuleList( + (0-31): 32 x LlamaDecoderLayer( + (self_attn): LlamaSdpaAttention( + (q_proj): FP8DynamicLinear() + (k_proj): FP8DynamicLinear() + (v_proj): FP8DynamicLinear() + (o_proj): FP8DynamicLinear() + (rotary_emb): LlamaRotaryEmbedding() + ) + (mlp): LlamaMLP( + (gate_proj): FP8DynamicLinear() + (up_proj): FP8DynamicLinear() + (down_proj): FP8DynamicLinear() + (act_fn): SiLU() + ) + (input_layernorm): LlamaRMSNorm() + (post_attention_layernorm): LlamaRMSNorm() + ) + ) + (norm): LlamaRMSNorm() + ) + (lm_head): Linear(in_features=4096, out_features=128256, bias=False) + ) + Saving the model to Meta-Llama-3-8B-Instruct-FP8-Dynamic + +Your model checkpoint with quantized weights should be available at ``Meta-Llama-3-8B-Instruct-FP8/``. +We can see that the weights are smaller than the original BF16 precision. + +.. code-block:: bash + + ls -lh Meta-Llama-3-8B-Instruct-FP8-Dynamic/ + total 8.5G + -rw-rw-r-- 1 user user 869 Jun 7 14:43 config.json + -rw-rw-r-- 1 user user 194 Jun 7 14:43 generation_config.json + -rw-rw-r-- 1 user user 4.7G Jun 7 14:43 model-00001-of-00002.safetensors + -rw-rw-r-- 1 user user 3.9G Jun 7 14:43 model-00002-of-00002.safetensors + -rw-rw-r-- 1 user user 43K Jun 7 14:43 model.safetensors.index.json + -rw-rw-r-- 1 user user 296 Jun 7 14:43 special_tokens_map.json + -rw-rw-r-- 1 user user 50K Jun 7 14:43 tokenizer_config.json + -rw-rw-r-- 1 user user 8.7M Jun 7 14:43 tokenizer.json + +Finally, you can load the quantized model checkpoint directly in vLLM. + +.. code-block:: python + + from vllm import LLM + model = LLM(model="Meta-Llama-3-8B-Instruct-FP8-Dynamic/") + # INFO 06-10 21:15:41 model_runner.py:159] Loading model weights took 8.4596 GB + result = model.generate("Hello, my name is") + +Offline Quantization with Static Activation Scaling Factors +----------------------------------------------------------- + +For the best inference performance, you can use AutoFP8 with calibration data to produce per-tensor static scales for both the weights and activations by enabling the ``activation_scheme="static"`` argument. + +.. code-block:: python + + from datasets import load_dataset + from transformers import AutoTokenizer + from auto_fp8 import AutoFP8ForCausalLM, BaseQuantizeConfig + + pretrained_model_dir = "meta-llama/Meta-Llama-3-8B-Instruct" + quantized_model_dir = "Meta-Llama-3-8B-Instruct-FP8" + + tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir, use_fast=True) + tokenizer.pad_token = tokenizer.eos_token + + # Load and tokenize 512 dataset samples for calibration of activation scales + ds = load_dataset("mgoin/ultrachat_2k", split="train_sft").select(range(512)) + examples = [tokenizer.apply_chat_template(batch["messages"], tokenize=False) for batch in ds] + examples = tokenizer(examples, padding=True, truncation=True, return_tensors="pt").to("cuda") + + # Define quantization config with static activation scales + quantize_config = BaseQuantizeConfig(quant_method="fp8", activation_scheme="static") + + # Load the model, quantize, and save checkpoint + model = AutoFP8ForCausalLM.from_pretrained(pretrained_model_dir, quantize_config) + model.quantize(examples) + model.save_quantized(quantized_model_dir) + +Your model checkpoint with quantized weights and activations should be available at ``Meta-Llama-3-8B-Instruct-FP8/``. +Finally, you can load the quantized model checkpoint directly in vLLM. + +.. 
code-block:: python + + from vllm import LLM + model = LLM(model="Meta-Llama-3-8B-Instruct-FP8/") + # INFO 06-10 21:15:41 model_runner.py:159] Loading model weights took 8.4596 GB + result = model.generate("Hello, my name is") + +FP8 checkpoint structure explanation +----------------------------------------------------------- + +Here we detail the structure of FP8 checkpoints. + +The following must be present in the model's ``config.json``: + +.. code-block:: text + + "quantization_config": { + "quant_method": "fp8", + "activation_scheme": "static" or "dynamic" + } + + +Each quantized layer in the state_dict will have these tensors: + +* If the config has ``"activation_scheme": "static"``: + +.. code-block:: text + + model.layers.0.mlp.down_proj.weight < F8_E4M3 + model.layers.0.mlp.down_proj.input_scale < F32 + model.layers.0.mlp.down_proj.weight_scale < F32 + +* If the config has ``"activation_scheme": "dynamic"``: + +.. code-block:: text + + model.layers.0.mlp.down_proj.weight < F8_E4M3 + model.layers.0.mlp.down_proj.weight_scale < F32 + + +Additionally, there can be `FP8 kv-cache scaling factors `_ contained within quantized checkpoints, specified through the ``.kv_scale`` parameter present on the Attention Module, such as: + +.. code-block:: text + + model.layers.0.self_attn.kv_scale < F32
diff --git a/docs/source/serving/distributed_serving.rst b/docs/source/serving/distributed_serving.rst index 4f36dca15d7d..b0c45dbf7026 100644 --- a/docs/source/serving/distributed_serving.rst +++ b/docs/source/serving/distributed_serving.rst @@ -3,11 +3,9 @@ Distributed Inference and Serving ================================= -vLLM supports distributed tensor-parallel inference and serving. Currently, we support `Megatron-LM's tensor parallel algorithm `_. We manage the distributed runtime with `Ray `_. To run distributed inference, install Ray with: +vLLM supports distributed tensor-parallel inference and serving. Currently, we support `Megatron-LM's tensor parallel algorithm `_. We manage the distributed runtime with either `Ray `_ or Python native multiprocessing. Multiprocessing can be used when deploying on a single node; multi-node inference currently requires Ray. -.. code-block:: console - - $ pip install ray +Multiprocessing will be used by default when not running in a Ray placement group and when there are sufficient GPUs available on the same node for the configured :code:`tensor_parallel_size`; otherwise Ray will be used. This default can be overridden via the :code:`LLM` class :code:`distributed_executor_backend` argument or the :code:`--distributed-executor-backend` API server argument. Set it to :code:`mp` for multiprocessing or :code:`ray` for Ray. Ray does not need to be installed for the multiprocessing case. To run multi-GPU inference with the :code:`LLM` class, set the :code:`tensor_parallel_size` argument to the number of GPUs you want to use. For example, to run inference on 4 GPUs: @@ -25,10 +23,12 @@ To run multi-GPU serving, pass in the :code:`--tensor-parallel-size` argument wh $ --model facebook/opt-13b \ $ --tensor-parallel-size 4 -To scale vLLM beyond a single machine, start a `Ray runtime `_ via CLI before running vLLM: +To scale vLLM beyond a single machine, install and start a `Ray runtime `_ via CLI before running vLLM: ..
code-block:: console + $ pip install ray + $ # On head node $ ray start --head diff --git a/setup.py b/setup.py index 7a1a0d37ebc3..61ed64d7791c 100644 --- a/setup.py +++ b/setup.py @@ -223,7 +223,7 @@ def _is_neuron() -> bool: subprocess.run(["neuron-ls"], capture_output=True, check=True) except (FileNotFoundError, PermissionError, subprocess.CalledProcessError): torch_neuronx_installed = False - return torch_neuronx_installed or envs.VLLM_BUILD_WITH_NEURON + return torch_neuronx_installed or VLLM_TARGET_DEVICE == "neuron" def _is_cpu() -> bool: diff --git a/tests/async_engine/test_openapi_server_ray.py b/tests/async_engine/test_openapi_server_ray.py index c4434301201c..35d8808b7a69 100644 --- a/tests/async_engine/test_openapi_server_ray.py +++ b/tests/async_engine/test_openapi_server_ray.py @@ -55,9 +55,8 @@ async def test_single_completion(server, client: openai.AsyncOpenAI): temperature=0.0) assert completion.id is not None - assert completion.choices is not None and len(completion.choices) == 1 - assert completion.choices[0].text is not None and len( - completion.choices[0].text) >= 5 + assert len(completion.choices) == 1 + assert len(completion.choices[0].text) >= 5 assert completion.choices[0].finish_reason == "length" assert completion.usage == openai.types.CompletionUsage( completion_tokens=5, prompt_tokens=6, total_tokens=11) @@ -69,8 +68,7 @@ async def test_single_completion(server, client: openai.AsyncOpenAI): max_tokens=5, temperature=0.0, ) - assert completion.choices[0].text is not None and len( - completion.choices[0].text) >= 5 + assert len(completion.choices[0].text) >= 5 @pytest.mark.asyncio @@ -90,15 +88,14 @@ async def test_single_chat_session(server, client: openai.AsyncOpenAI): logprobs=True, top_logprobs=5) assert chat_completion.id is not None - assert chat_completion.choices is not None and len( - chat_completion.choices) == 1 - assert chat_completion.choices[0].message is not None - assert chat_completion.choices[0].logprobs is not None - assert chat_completion.choices[0].logprobs.content[ - 0].top_logprobs is not None - assert len( - chat_completion.choices[0].logprobs.content[0].top_logprobs) == 5 - message = chat_completion.choices[0].message + assert len(chat_completion.choices) == 1 + + choice = chat_completion.choices[0] + assert choice.finish_reason == "length" + assert chat_completion.usage == openai.types.CompletionUsage( + completion_tokens=10, prompt_tokens=13, total_tokens=23) + + message = choice.message assert message.content is not None and len(message.content) >= 10 assert message.role == "assistant" messages.append({"role": "assistant", "content": message.content}) diff --git a/tests/distributed/test_same_node.py b/tests/distributed/test_same_node.py new file mode 100644 index 000000000000..4880bab79069 --- /dev/null +++ b/tests/distributed/test_same_node.py @@ -0,0 +1,11 @@ +import os + +import torch + +from vllm.distributed.parallel_state import is_in_the_same_node + +torch.distributed.init_process_group(backend="gloo") +test_result = is_in_the_same_node(torch.distributed.group.WORLD) + +expected = os.environ.get("VLLM_TEST_SAME_HOST", "1") == "1" +assert test_result == expected, f"Expected {expected}, got {test_result}" diff --git a/tests/entrypoints/test_openai_server.py b/tests/entrypoints/test_openai_server.py index 06bd0bf15c84..ac895e04dc1f 100644 --- a/tests/entrypoints/test_openai_server.py +++ b/tests/entrypoints/test_openai_server.py @@ -166,9 +166,10 @@ async def test_single_completion(server, client: openai.AsyncOpenAI, assert 
completion.id is not None assert completion.choices is not None and len(completion.choices) == 1 - assert completion.choices[0].text is not None and len( - completion.choices[0].text) >= 5 - assert completion.choices[0].finish_reason == "length" + + choice = completion.choices[0] + assert len(choice.text) >= 5 + assert choice.finish_reason == "length" assert completion.usage == openai.types.CompletionUsage( completion_tokens=5, prompt_tokens=6, total_tokens=11) @@ -179,8 +180,7 @@ async def test_single_completion(server, client: openai.AsyncOpenAI, max_tokens=5, temperature=0.0, ) - assert completion.choices[0].text is not None and len( - completion.choices[0].text) >= 5 + assert len(completion.choices[0].text) >= 5 @pytest.mark.asyncio @@ -205,9 +205,9 @@ async def test_no_logprobs(server, client: openai.AsyncOpenAI, @pytest.mark.asyncio @pytest.mark.parametrize( - # first test base model, then test loras + # just test 1 lora hereafter "model_name", - [MODEL_NAME, "zephyr-lora", "zephyr-lora2"], + [MODEL_NAME, "zephyr-lora"], ) async def test_zero_logprobs(server, client: openai.AsyncOpenAI, model_name: str): @@ -263,7 +263,9 @@ async def test_too_many_completion_logprobs(server, client: openai.AsyncOpenAI, prompt=[0, 0, 0, 0, 0], max_tokens=5, temperature=0.0, - logprobs=6, + # vLLM has higher default max_logprobs (20 instead of 5) to support + # both Completion API and Chat Completion API + logprobs=21, ) ... with pytest.raises( @@ -273,7 +275,9 @@ async def test_too_many_completion_logprobs(server, client: openai.AsyncOpenAI, prompt=[0, 0, 0, 0, 0], max_tokens=5, temperature=0.0, - logprobs=6, + # vLLM has higher default max_logprobs (20 instead of 5) to support + # both Completion API and Chat Completion API + logprobs=30, stream=True, ) async for chunk in stream: @@ -286,55 +290,7 @@ async def test_too_many_completion_logprobs(server, client: openai.AsyncOpenAI, max_tokens=5, temperature=0.0, ) - completion = completion.choices[0].text - assert completion is not None and len(completion) >= 0 - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - # just test 1 lora hereafter - "model_name", - [MODEL_NAME, "zephyr-lora"], -) -async def test_single_chat_session(server, client: openai.AsyncOpenAI, - model_name: str): - messages = [{ - "role": "system", - "content": "you are a helpful assistant" - }, { - "role": "user", - "content": "what is 1+1?" 
- }] - - # test single completion - chat_completion = await client.chat.completions.create(model=model_name, - messages=messages, - max_tokens=10, - logprobs=True, - top_logprobs=5) - assert chat_completion.id is not None - assert chat_completion.choices is not None and len( - chat_completion.choices) == 1 - assert chat_completion.choices[0].message is not None - assert chat_completion.choices[0].logprobs is not None - assert chat_completion.choices[0].logprobs.content[ - 0].top_logprobs is not None - assert len( - chat_completion.choices[0].logprobs.content[0].top_logprobs) == 5 - message = chat_completion.choices[0].message - assert message.content is not None and len(message.content) >= 10 - assert message.role == "assistant" - messages.append({"role": "assistant", "content": message.content}) - - # test multi-turn dialogue - messages.append({"role": "user", "content": "express your result in json"}) - chat_completion = await client.chat.completions.create( - model=model_name, - messages=messages, - max_tokens=10, - ) - message = chat_completion.choices[0].message - assert message.content is not None and len(message.content) >= 0 + assert len(completion.choices[0].text) >= 0 @pytest.mark.asyncio @@ -389,7 +345,7 @@ async def test_zero_logprobs_chat(server, client: openai.AsyncOpenAI, choice = chat_completion.choices[0] assert choice.logprobs is not None assert choice.logprobs.content is not None - assert len(choice.logprobs.content[0].top_logprobs) <= 1 + assert len(choice.logprobs.content[0].top_logprobs) == 0 @pytest.mark.asyncio @@ -417,11 +373,14 @@ async def test_some_logprobs_chat(server, client: openai.AsyncOpenAI, choice = chat_completion.choices[0] assert choice.logprobs is not None assert choice.logprobs.content is not None - assert len(choice.logprobs.content[0].top_logprobs) <= 6 + assert len(choice.logprobs.content[0].top_logprobs) == 5 @pytest.mark.asyncio -@pytest.mark.parametrize("model_name", [MODEL_NAME]) +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME, "zephyr-lora"], +) async def test_too_many_chat_logprobs(server, client: openai.AsyncOpenAI, model_name: str): messages = [{ @@ -462,7 +421,51 @@ async def test_too_many_chat_logprobs(server, client: openai.AsyncOpenAI, @pytest.mark.asyncio @pytest.mark.parametrize( - # just test 1 lora hereafter + "model_name", + [MODEL_NAME, "zephyr-lora"], +) +async def test_single_chat_session(server, client: openai.AsyncOpenAI, + model_name: str): + messages = [{ + "role": "system", + "content": "you are a helpful assistant" + }, { + "role": "user", + "content": "what is 1+1?" 
+ }] + + # test single completion + chat_completion = await client.chat.completions.create(model=model_name, + messages=messages, + max_tokens=10, + logprobs=True, + top_logprobs=5) + assert chat_completion.id is not None + assert len(chat_completion.choices) == 1 + + choice = chat_completion.choices[0] + assert choice.finish_reason == "length" + assert chat_completion.usage == openai.types.CompletionUsage( + completion_tokens=10, prompt_tokens=37, total_tokens=47) + + message = choice.message + assert message.content is not None and len(message.content) >= 10 + assert message.role == "assistant" + messages.append({"role": "assistant", "content": message.content}) + + # test multi-turn dialogue + messages.append({"role": "user", "content": "express your result in json"}) + chat_completion = await client.chat.completions.create( + model=model_name, + messages=messages, + max_tokens=10, + ) + message = chat_completion.choices[0].message + assert message.content is not None and len(message.content) >= 0 + + +@pytest.mark.asyncio +@pytest.mark.parametrize( "model_name", [MODEL_NAME, "zephyr-lora"], ) @@ -748,8 +751,7 @@ async def test_logits_bias(server, client: openai.AsyncOpenAI): logit_bias={str(token_id): 100}, seed=42, ) - assert completion.choices[0].text is not None and len( - completion.choices[0].text) >= 5 + assert len(completion.choices[0].text) >= 5 response_tokens = tokenizer(completion.choices[0].text, add_special_tokens=False)["input_ids"] expected_tokens = tokenizer(tokenizer.decode([token_id] * 5), @@ -796,9 +798,8 @@ async def test_guided_json_completion(server, client: openai.AsyncOpenAI, guided_decoding_backend=guided_decoding_backend)) assert completion.id is not None - assert completion.choices is not None and len(completion.choices) == 3 + assert len(completion.choices) == 3 for i in range(3): - assert completion.choices[i].text is not None output_json = json.loads(completion.choices[i].text) jsonschema.validate(instance=output_json, schema=TEST_SCHEMA) @@ -865,9 +866,8 @@ async def test_guided_regex_completion(server, client: openai.AsyncOpenAI, guided_decoding_backend=guided_decoding_backend)) assert completion.id is not None - assert completion.choices is not None and len(completion.choices) == 3 + assert len(completion.choices) == 3 for i in range(3): - assert completion.choices[i].text is not None assert re.fullmatch(TEST_REGEX, completion.choices[i].text) is not None @@ -924,7 +924,7 @@ async def test_guided_choice_completion(server, client: openai.AsyncOpenAI, guided_decoding_backend=guided_decoding_backend)) assert completion.id is not None - assert completion.choices is not None and len(completion.choices) == 2 + assert len(completion.choices) == 2 for i in range(2): assert completion.choices[i].text in TEST_CHOICE @@ -1026,12 +1026,14 @@ async def test_guided_choice_chat_logprobs(server, client: openai.AsyncOpenAI, top_logprobs=5, extra_body=dict(guided_choice=TEST_CHOICE, guided_decoding_backend=guided_decoding_backend)) + + assert chat_completion.choices[0].logprobs is not None + assert chat_completion.choices[0].logprobs.content is not None top_logprobs = chat_completion.choices[0].logprobs.content[0].top_logprobs # -9999.0 is the minimum logprob returned by OpenAI - assert all( - isinstance(token.logprob, float) and token.logprob >= -9999.0 - for token in top_logprobs) + for item in top_logprobs: + assert item.logprob >= -9999.0, f"Failed (top_logprobs={top_logprobs})" @pytest.mark.asyncio @@ -1233,6 +1235,8 @@ async def 
test_response_format_json_object(server, client: openai.AsyncOpenAI): response_format={"type": "json_object"}) content = resp.choices[0].message.content + assert content is not None + loaded = json.loads(content) assert loaded == {"result": 2}, loaded @@ -1360,8 +1364,7 @@ async def test_echo_logprob_completion(server, client: openai.AsyncOpenAI, prompt_text = tokenizer.decode(prompt) if isinstance(prompt, list) else prompt - assert (completion.choices[0].text is not None - and re.search(r"^" + prompt_text, completion.choices[0].text)) + assert re.search(r"^" + prompt_text, completion.choices[0].text) logprobs = completion.choices[0].logprobs assert logprobs is not None assert len(logprobs.text_offset) > 5 @@ -1402,32 +1405,32 @@ async def test_long_seed(server, client: openai.AsyncOpenAI): ) async def test_single_embedding(embedding_server, client: openai.AsyncOpenAI, model_name: str): - input = [ + input_texts = [ "The chef prepared a delicious meal.", ] # test single embedding embeddings = await client.embeddings.create( model=model_name, - input=input, + input=input_texts, encoding_format="float", ) assert embeddings.id is not None - assert embeddings.data is not None and len(embeddings.data) == 1 + assert len(embeddings.data) == 1 assert len(embeddings.data[0].embedding) == 4096 assert embeddings.usage.completion_tokens == 0 assert embeddings.usage.prompt_tokens == 9 assert embeddings.usage.total_tokens == 9 # test using token IDs - input = [1, 1, 1, 1, 1] + input_tokens = [1, 1, 1, 1, 1] embeddings = await client.embeddings.create( model=model_name, - input=input, + input=input_tokens, encoding_format="float", ) assert embeddings.id is not None - assert embeddings.data is not None and len(embeddings.data) == 1 + assert len(embeddings.data) == 1 assert len(embeddings.data[0].embedding) == 4096 assert embeddings.usage.completion_tokens == 0 assert embeddings.usage.prompt_tokens == 5 @@ -1442,29 +1445,29 @@ async def test_single_embedding(embedding_server, client: openai.AsyncOpenAI, async def test_batch_embedding(embedding_server, client: openai.AsyncOpenAI, model_name: str): # test List[str] - inputs = [ + input_texts = [ "The cat sat on the mat.", "A feline was resting on a rug.", "Stars twinkle brightly in the night sky." ] embeddings = await client.embeddings.create( model=model_name, - input=inputs, + input=input_texts, encoding_format="float", ) assert embeddings.id is not None - assert embeddings.data is not None and len(embeddings.data) == 3 + assert len(embeddings.data) == 3 assert len(embeddings.data[0].embedding) == 4096 # test List[List[int]] - inputs = [[4, 5, 7, 9, 20], [15, 29, 499], [24, 24, 24, 24, 24], - [25, 32, 64, 77]] + input_tokens = [[4, 5, 7, 9, 20], [15, 29, 499], [24, 24, 24, 24, 24], + [25, 32, 64, 77]] embeddings = await client.embeddings.create( model=model_name, - input=inputs, + input=input_tokens, encoding_format="float", ) assert embeddings.id is not None - assert embeddings.data is not None and len(embeddings.data) == 4 + assert len(embeddings.data) == 4 assert len(embeddings.data[0].embedding) == 4096 assert embeddings.usage.completion_tokens == 0 assert embeddings.usage.prompt_tokens == 17 diff --git a/tests/models/test_compressed.py b/tests/models/test_compressed.py index cac6addafe5e..7bd9ae940984 100644 --- a/tests/models/test_compressed.py +++ b/tests/models/test_compressed.py @@ -6,6 +6,8 @@ Run `pytest tests/models/test_compressed.py`. 
""" +import gc + import pytest from tests.models.utils import check_logprobs_close @@ -40,6 +42,8 @@ def test_models( example_prompts, max_tokens, num_logprobs) del sparse_model + gc.collect() + dense_model = vllm_runner(model_name=model_name, sparsity=None, dtype=dtype, diff --git a/tests/spec_decode/e2e/conftest.py b/tests/spec_decode/e2e/conftest.py index 1d060e265848..f8a6de54653c 100644 --- a/tests/spec_decode/e2e/conftest.py +++ b/tests/spec_decode/e2e/conftest.py @@ -77,7 +77,11 @@ def __init__( swap_space=swap_space, enforce_eager=enforce_eager, max_seq_len_to_capture=max_seq_len_to_capture, + # For now use ray for the distributed back-end, since + # we rely on the use of engine_use_ray=True to avoid + # reinitializing CUDA in the same process (driver worker) engine_use_ray=True, + distributed_executor_backend="ray", disable_custom_all_reduce=disable_custom_all_reduce, **kwargs, ) diff --git a/tests/tensorizer_loader/test_tensorizer.py b/tests/tensorizer_loader/test_tensorizer.py index a95aa84978a4..70d789e97c12 100644 --- a/tests/tensorizer_loader/test_tensorizer.py +++ b/tests/tensorizer_loader/test_tensorizer.py @@ -214,9 +214,8 @@ def test_openai_apiserver_with_tensorizer(vllm_runner, tmp_path): temperature=0.0) assert completion.id is not None - assert completion.choices is not None and len(completion.choices) == 1 - assert completion.choices[0].text is not None and len( - completion.choices[0].text) >= 5 + assert len(completion.choices) == 1 + assert len(completion.choices[0].text) >= 5 assert completion.choices[0].finish_reason == "length" assert completion.usage == openai.types.CompletionUsage( completion_tokens=5, prompt_tokens=6, total_tokens=11) diff --git a/tests/test_config.py b/tests/test_config.py index 7cbdaeca9c4d..6c8af9d7966b 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -63,8 +63,9 @@ def test_get_sliding_window(): assert mistral_model_config.get_sliding_window() == TEST_SLIDING_WINDOW -def test_rope_scaling(): +def test_rope_customization(): TEST_ROPE_SCALING = {"type": "dynamic", "factor": 2.0} + TEST_ROPE_THETA = 16_000_000.0 LONGCHAT_ROPE_SCALING = {"type": "linear", "factor": 8.0} llama_model_config = ModelConfig( @@ -76,6 +77,7 @@ def test_rope_scaling(): seed=0, ) assert getattr(llama_model_config.hf_config, "rope_scaling", None) is None + assert getattr(llama_model_config.hf_config, "rope_theta", None) == 500_000 assert llama_model_config.max_model_len == 8192 llama_model_config = ModelConfig( @@ -86,9 +88,12 @@ def test_rope_scaling(): dtype="float16", seed=0, rope_scaling=TEST_ROPE_SCALING, + rope_theta=TEST_ROPE_THETA, ) assert getattr(llama_model_config.hf_config, "rope_scaling", None) == TEST_ROPE_SCALING + assert getattr(llama_model_config.hf_config, "rope_theta", + None) == TEST_ROPE_THETA assert llama_model_config.max_model_len == 16384 longchat_model_config = ModelConfig( diff --git a/vllm/config.py b/vllm/config.py index 72ddc8fda805..4453d0fc9851 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -93,6 +93,7 @@ def __init__( revision: Optional[str] = None, code_revision: Optional[str] = None, rope_scaling: Optional[dict] = None, + rope_theta: Optional[float] = None, tokenizer_revision: Optional[str] = None, max_model_len: Optional[int] = None, quantization: Optional[str] = None, @@ -102,7 +103,7 @@ def __init__( enforce_eager: bool = False, max_context_len_to_capture: Optional[int] = None, max_seq_len_to_capture: Optional[int] = None, - max_logprobs: int = 5, + max_logprobs: int = 20, disable_sliding_window: bool = False, 
skip_tokenizer_init: bool = False, served_model_name: Optional[Union[str, List[str]]] = None, @@ -115,6 +116,7 @@ def __init__( self.revision = revision self.code_revision = code_revision self.rope_scaling = rope_scaling + self.rope_theta = rope_theta # The tokenizer version is consistent with the model version by default. if tokenizer_revision is None: self.tokenizer_revision = revision @@ -136,7 +138,7 @@ def __init__( self.skip_tokenizer_init = skip_tokenizer_init self.hf_config = get_config(self.model, trust_remote_code, revision, - code_revision, rope_scaling) + code_revision, rope_scaling, rope_theta) self.hf_text_config = get_hf_text_config(self.hf_config) self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype) self.max_model_len = _get_and_verify_max_len( @@ -633,9 +635,25 @@ def __init__( f"'{self.distributed_executor_backend}'.") if self.distributed_executor_backend is None and self.world_size > 1: + # We use multiprocessing by default if world_size fits on the + # current node and we aren't in a ray placement group. + from torch.cuda import device_count + from vllm.executor import ray_utils + backend = "mp" ray_found = ray_utils.ray is not None - self.distributed_executor_backend = "ray" if ray_found else "mp" + if device_count() < self.world_size: + if not ray_found: + raise ValueError("Unable to load Ray which is " + "required for multi-node inference") + backend = "ray" + elif ray_found: + from ray.util import get_current_placement_group + if self.placement_group or get_current_placement_group(): + backend = "ray" + self.distributed_executor_backend = backend + logger.info("Defaulting to use %s for distributed inference", + backend) self._verify_args() diff --git a/vllm/distributed/device_communicators/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py index 4a0e19bc0c15..bbc2284f8a36 100644 --- a/vllm/distributed/device_communicators/custom_all_reduce.py +++ b/vllm/distributed/device_communicators/custom_all_reduce.py @@ -10,7 +10,7 @@ from vllm.distributed.device_communicators.custom_all_reduce_utils import ( gpu_p2p_access_check) from vllm.distributed.parallel_state import ( - get_local_rank, get_tensor_model_parallel_cpu_group) + get_local_rank, get_tensor_model_parallel_cpu_group, is_in_the_same_node) from vllm.logger import init_logger try: @@ -113,6 +113,13 @@ def __init__(self, assert dist.get_backend(group) != dist.Backend.NCCL, ( "CustomAllreduce should be attached to a non-NCCL group.") + if not is_in_the_same_node(group): + # No need to initialize custom allreduce for multi-node case. + logger.warning( + "Custom allreduce is disabled because this process group" + " spans across nodes.") + return + rank = dist.get_rank(group=self.group) world_size = dist.get_world_size(group=self.group) if world_size == 1: diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 0ebd7a15eab9..b6d1eeff0978 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -3,6 +3,8 @@ # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 
"""Tensor and pipeline parallel groups.""" +import contextlib +from multiprocessing import resource_tracker, shared_memory from typing import List, Optional import torch @@ -376,3 +378,68 @@ def destroy_model_parallel(): _PP_DEVICE_GROUP = None global _PP_GLOBAL_RANKS _PP_GLOBAL_RANKS = None + + +def is_in_the_same_node(pg: ProcessGroup): + """ + This is a collective operation that checks if all processes in the group + are in the same node. It tests if all processes are attached to the same + memory system (shared access to shared memory). + """ + assert torch.distributed.get_backend( + pg) != torch.distributed.Backend.NCCL, ( + "is_in_the_same_node should be tested with a non-NCCL group.") + # local rank inside the group + rank = torch.distributed.get_rank(group=pg) + world_size = torch.distributed.get_world_size(group=pg) + + # local tensor in each process to store the result + is_in_the_same_node = torch.tensor([0] * world_size, dtype=torch.int32) + + # global ranks of the processes in the group + ranks = torch.distributed.get_process_group_ranks(pg) + + magic_message = b"magic_message" + shm = None + + try: + with contextlib.suppress(OSError): + if rank == 0: + # create a shared memory segment + shm = shared_memory.SharedMemory(create=True, size=128) + shm.buf[:len(magic_message)] = magic_message + torch.distributed.broadcast_object_list([shm.name], + src=ranks[0], + group=pg) + is_in_the_same_node[0] = 1 + else: + # try to open the shared memory segment + recv = [None] + torch.distributed.broadcast_object_list(recv, + src=ranks[0], + group=pg) + name = recv[0] + shm = shared_memory.SharedMemory(name=name) + if shm.buf[:len(magic_message)] == magic_message: + is_in_the_same_node[rank] = 1 + except Exception as e: + logger.error("Error ignored in is_in_the_same_node: %s", e) + finally: + if shm: + shm.close() + + torch.distributed.barrier(group=pg) + + # clean up the shared memory segment + with contextlib.suppress(OSError): + if rank == 0: + if shm: + shm.unlink() + else: + if shm: + # fix to https://stackoverflow.com/q/62748654/9191338 + resource_tracker.unregister( + shm._name, "shared_memory") # type: ignore[attr-defined] + torch.distributed.all_reduce(is_in_the_same_node, group=pg) + + return is_in_the_same_node.sum().item() == world_size diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 5295c3db32a0..e3276e1a4b58 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -50,11 +50,12 @@ class EngineArgs: gpu_memory_utilization: float = 0.90 max_num_batched_tokens: Optional[int] = None max_num_seqs: int = 256 - max_logprobs: int = 5 # OpenAI default value + max_logprobs: int = 20 # Default value for OpenAI Chat Completions API disable_log_stats: bool = False revision: Optional[str] = None code_revision: Optional[str] = None rope_scaling: Optional[dict] = None + rope_theta: Optional[float] = None tokenizer_revision: Optional[str] = None quantization: Optional[str] = None # UPSTREAM SYNC: keep sparsity argument @@ -72,7 +73,7 @@ class EngineArgs: fully_sharded_loras: bool = False lora_extra_vocab_size: int = 256 long_lora_scaling_factors: Optional[Tuple[float]] = None - lora_dtype = 'auto' + lora_dtype: str = 'auto' max_cpu_loras: Optional[int] = None device: str = 'auto' ray_workers_use_nsight: bool = False @@ -415,6 +416,12 @@ def add_cli_args( type=json.loads, help='RoPE scaling configuration in JSON format. ' 'For example, {"type":"dynamic","factor":2.0}') + parser.add_argument('--rope-theta', + default=None, + type=float, + help='RoPE theta. 
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 5295c3db32a0..e3276e1a4b58 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -50,11 +50,12 @@ class EngineArgs:
     gpu_memory_utilization: float = 0.90
     max_num_batched_tokens: Optional[int] = None
     max_num_seqs: int = 256
-    max_logprobs: int = 5  # OpenAI default value
+    max_logprobs: int = 20  # Default value for OpenAI Chat Completions API
     disable_log_stats: bool = False
     revision: Optional[str] = None
     code_revision: Optional[str] = None
     rope_scaling: Optional[dict] = None
+    rope_theta: Optional[float] = None
     tokenizer_revision: Optional[str] = None
     quantization: Optional[str] = None
     # UPSTREAM SYNC: keep sparsity argument
@@ -72,7 +73,7 @@ class EngineArgs:
     fully_sharded_loras: bool = False
     lora_extra_vocab_size: int = 256
     long_lora_scaling_factors: Optional[Tuple[float]] = None
-    lora_dtype = 'auto'
+    lora_dtype: str = 'auto'
     max_cpu_loras: Optional[int] = None
     device: str = 'auto'
     ray_workers_use_nsight: bool = False
@@ -415,6 +416,12 @@ def add_cli_args(
             type=json.loads,
             help='RoPE scaling configuration in JSON format. '
             'For example, {"type":"dynamic","factor":2.0}')
+        parser.add_argument('--rope-theta',
+                            default=None,
+                            type=float,
+                            help='RoPE theta. Use with `rope_scaling`. In '
+                            'some cases, changing the RoPE theta improves the '
+                            'performance of the scaled model.')
         parser.add_argument('--enforce-eager',
                             action='store_true',
                             help='Always use eager-mode PyTorch. If False, '
@@ -645,6 +652,7 @@ def create_engine_config(self, ) -> EngineConfig:
             revision=self.revision,
             code_revision=self.code_revision,
             rope_scaling=self.rope_scaling,
+            rope_theta=self.rope_theta,
             tokenizer_revision=self.tokenizer_revision,
             max_model_len=self.max_model_len,
             quantization=self.quantization,
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 30e963f0d407..a1c77e04a515 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -162,7 +162,7 @@ def __init__(
             "Initializing an LLM engine (v%s) with config: "
             "model=%r, speculative_config=%r, tokenizer=%r, "
             "skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, "
-            "rope_scaling=%r, tokenizer_revision=%s, "
+            "rope_scaling=%r, rope_theta=%r, tokenizer_revision=%s, "
             "trust_remote_code=%s, dtype=%s, max_seq_len=%d, "
             "download_dir=%r, load_format=%s, tensor_parallel_size=%d, "
             "disable_custom_all_reduce=%s, quantization=%s, sparsity=%s, "
@@ -177,6 +177,7 @@ def __init__(
             model_config.tokenizer_mode,
             model_config.revision,
             model_config.rope_scaling,
+            model_config.rope_theta,
             model_config.tokenizer_revision,
             model_config.trust_remote_code,
             model_config.dtype,
diff --git a/vllm/engine/output_processor/multi_step.py b/vllm/engine/output_processor/multi_step.py
index 761e4ddd8271..8512ff83e41c 100644
--- a/vllm/engine/output_processor/multi_step.py
+++ b/vllm/engine/output_processor/multi_step.py
@@ -78,7 +78,7 @@ def process_outputs(self, sequence_group: SequenceGroup,

         # Since there's only one sequence per sequence group, we can take the
         # first sample.
-        samples = [outputs[step].samples[0] for step in range(len(outputs))]
+        samples = [output.samples[0] for output in outputs]

         # -1 means the output token is not valid (eg. due to spec decode
         # rejecting tokens).
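The `--rope-theta` option threaded through EngineArgs above is intended to be combined with `--rope-scaling`. A rough usage sketch, assuming a build that includes this change; the model name and numeric values are placeholders:

# Rough usage sketch; assumes a vLLM build containing the rope_theta change.
from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Llama-2-7b-hf",                # placeholder model
    rope_scaling={"type": "dynamic", "factor": 2.0},
    rope_theta=1000000.0,                            # overrides the HF config value
    max_model_len=8192,
)
out = llm.generate(["Hello"], SamplingParams(max_tokens=8))
print(out[0].outputs[0].text)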
diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py
index 9424ccc959d1..3b56ad63f375 100644
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -322,9 +322,9 @@ def check_logprobs(cls, data):
             raise ValueError(
                 "when using `top_logprobs`, `logprobs` must be set to true."
             )
-        elif not 0 <= data["top_logprobs"] <= 20:
+        elif data["top_logprobs"] < 0:
             raise ValueError(
-                "`top_logprobs` must be a value in the interval [0, 20].")
+                "`top_logprobs` must be a non-negative value.")
         return data

@@ -478,9 +478,8 @@ def check_guided_decoding_count(cls, data):
     @classmethod
     def check_logprobs(cls, data):
         if "logprobs" in data and data[
-                "logprobs"] is not None and not 0 <= data["logprobs"] <= 5:
-            raise ValueError(("if passed, `logprobs` must be a value",
-                              " in the interval [0, 5]."))
+                "logprobs"] is not None and not data["logprobs"] >= 0:
+            raise ValueError("if passed, `logprobs` must be a positive value.")
         return data

     @model_validator(mode="before")
@@ -514,7 +513,8 @@ class CompletionLogProbs(OpenAIBaseModel):
     text_offset: List[int] = Field(default_factory=list)
     token_logprobs: List[Optional[float]] = Field(default_factory=list)
     tokens: List[str] = Field(default_factory=list)
-    top_logprobs: Optional[List[Optional[Dict[str, float]]]] = None
+    top_logprobs: List[Optional[Dict[str,
+                                     float]]] = Field(default_factory=list)


 class CompletionResponseChoice(OpenAIBaseModel):
@@ -613,7 +613,7 @@ class ChatCompletionResponseChoice(OpenAIBaseModel):
     index: int
     message: ChatMessage
     logprobs: Optional[ChatCompletionLogProbs] = None
-    finish_reason: Optional[Literal["stop", "length", "tool_calls"]] = None
+    finish_reason: Optional[str] = None
     stop_reason: Optional[Union[int, str]] = None

@@ -636,7 +636,7 @@ class ChatCompletionResponseStreamChoice(OpenAIBaseModel):
     index: int
     delta: DeltaMessage
     logprobs: Optional[ChatCompletionLogProbs] = None
-    finish_reason: Optional[Literal["stop", "length", "tool_calls"]] = None
+    finish_reason: Optional[str] = None
     stop_reason: Optional[Union[int, str]] = None
diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py
index dae60e4ec99f..7cd434fe0d27 100644
--- a/vllm/entrypoints/openai/serving_chat.py
+++ b/vllm/entrypoints/openai/serving_chat.py
@@ -373,13 +373,15 @@ async def chat_completion_stream_generator(
                         continue

                     delta_token_ids = output.token_ids[previous_num_tokens[i]:]
-                    top_logprobs = output.logprobs[
+                    out_logprobs = output.logprobs[
                         previous_num_tokens[i]:] if output.logprobs else None

-                    if request.logprobs:
+                    if request.logprobs and request.top_logprobs is not None:
+                        assert out_logprobs is not None, (
+                            "Did not output logprobs")
                         logprobs = self._create_chat_logprobs(
                             token_ids=delta_token_ids,
-                            top_logprobs=top_logprobs,
+                            top_logprobs=out_logprobs,
                             num_output_top_logprobs=request.top_logprobs,
                         )
                     else:
@@ -490,12 +492,13 @@ async def chat_completion_full_generator(
         role = self.get_chat_request_role(request)
         for output in final_res.outputs:
             token_ids = output.token_ids
-            top_logprobs = output.logprobs
+            out_logprobs = output.logprobs

-            if request.logprobs:
+            if request.logprobs and request.top_logprobs is not None:
+                assert out_logprobs is not None, "Did not output logprobs"
                 logprobs = self._create_chat_logprobs(
                     token_ids=token_ids,
-                    top_logprobs=top_logprobs,
+                    top_logprobs=out_logprobs,
                     num_output_top_logprobs=request.top_logprobs,
                 )
             else:
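With the relaxed `top_logprobs` validation and the higher `max_logprobs` default above, a Chat Completions request may ask for more than five alternatives per token. A rough client-side sketch; the base URL, API key, and model name are placeholders for a locally running OpenAI-compatible vLLM server:

# Rough client-side sketch; server address and model name are placeholders.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
resp = client.chat.completions.create(
    model="meta-llama/Llama-2-7b-hf",   # placeholder model name
    messages=[{"role": "user", "content": "Hi"}],
    logprobs=True,
    top_logprobs=20,  # values above 5 are accepted with the new defaults
)
print(resp.choices[0].logprobs)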
diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py
index c3c40f2b97d1..64671e21a724 100644
--- a/vllm/entrypoints/openai/serving_completion.py
+++ b/vllm/entrypoints/openai/serving_completion.py
@@ -8,6 +8,7 @@

 from vllm.config import ModelConfig
 from vllm.engine.async_llm_engine import AsyncLLMEngine
+# yapf conflicts with isort for this block
 # yapf: disable
 from vllm.entrypoints.openai.protocol import (CompletionLogProbs,
                                               CompletionRequest,
@@ -16,7 +17,6 @@
                                               CompletionResponseStreamChoice,
                                               CompletionStreamResponse,
                                               UsageInfo)
-# yapf: enable
 from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
                                                     OpenAIServing)
 from vllm.logger import init_logger
@@ -221,7 +221,7 @@ async def completion_stream_generator(
                     # only return the prompt
                     delta_text = res.prompt
                     delta_token_ids = res.prompt_token_ids
-                    top_logprobs = res.prompt_logprobs
+                    out_logprobs = res.prompt_logprobs
                     has_echoed[i] = True
                 elif (request.echo and request.max_tokens > 0
                       and not has_echoed[i]):
@@ -229,7 +229,7 @@ async def completion_stream_generator(
                     delta_text = res.prompt + output.text
                     delta_token_ids = (res.prompt_token_ids +
                                        output.token_ids)
-                    top_logprobs = res.prompt_logprobs + (output.logprobs
+                    out_logprobs = res.prompt_logprobs + (output.logprobs
                                                           or [])
                     has_echoed[i] = True
                 else:
@@ -237,13 +237,15 @@ async def completion_stream_generator(
                     delta_text = output.text[len(previous_texts[i]):]
                     delta_token_ids = output.token_ids[
                         previous_num_tokens[i]:]
-                    top_logprobs = output.logprobs[previous_num_tokens[
+                    out_logprobs = output.logprobs[previous_num_tokens[
                         i]:] if output.logprobs else None

                 if request.logprobs is not None:
+                    assert out_logprobs is not None, (
+                        "Did not output logprobs")
                     logprobs = self._create_completion_logprobs(
                         token_ids=delta_token_ids,
-                        top_logprobs=top_logprobs,
+                        top_logprobs=out_logprobs,
                         num_output_top_logprobs=request.logprobs,
                         initial_text_offset=len(previous_texts[i]),
                     )
@@ -325,25 +327,23 @@ def request_output_to_completion_response(
             assert request.max_tokens is not None
             if request.echo and request.max_tokens == 0:
                 token_ids = prompt_token_ids
-                top_logprobs = prompt_logprobs
+                out_logprobs = prompt_logprobs
                 output_text = prompt_text
             elif request.echo and request.max_tokens > 0:
                 token_ids = prompt_token_ids + output.token_ids
-                top_logprobs = (prompt_logprobs + output.logprobs
+                out_logprobs = (prompt_logprobs + output.logprobs
                                 if request.logprobs is not None else None)
                 output_text = prompt_text + output.text
             else:
                 token_ids = output.token_ids
-                top_logprobs = output.logprobs
+                out_logprobs = output.logprobs
                 output_text = output.text

             if request.logprobs is not None:
-                assert top_logprobs is not None, (
-                    "top_logprobs must be provided when logprobs "
-                    "is requested")
+                assert out_logprobs is not None, "Did not output logprobs"
                 logprobs = self._create_completion_logprobs(
                     token_ids=token_ids,
-                    top_logprobs=top_logprobs,
+                    top_logprobs=out_logprobs,
                     num_output_top_logprobs=request.logprobs,
                 )
             else:
diff --git a/vllm/envs.py b/vllm/envs.py
index b140aa6d658e..f0513b9af276 100644
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -33,7 +33,6 @@
     VLLM_TARGET_DEVICE: str = "cuda"
     MAX_JOBS: Optional[str] = None
     NVCC_THREADS: Optional[str] = None
-    VLLM_BUILD_WITH_NEURON: bool = False
     VLLM_USE_PRECOMPILED: bool = False
     VLLM_INSTALL_PUNICA_KERNELS: bool = False
    CMAKE_BUILD_TYPE: Optional[str] = None
@@ -63,10 +62,6 @@
     "NVCC_THREADS":
     lambda: os.getenv("NVCC_THREADS", None),

-    # If set, vllm will build with Neuron support
-    "VLLM_BUILD_WITH_NEURON":
-    lambda: bool(os.environ.get("VLLM_BUILD_WITH_NEURON", False)),
-
     # If set, vllm will use precompiled binaries (*.so)
     "VLLM_USE_PRECOMPILED":
     lambda: bool(os.environ.get("VLLM_USE_PRECOMPILED")),
diff --git a/vllm/executor/multiproc_gpu_executor.py b/vllm/executor/multiproc_gpu_executor.py
index bd1cac2ab9b5..99c9e52034cc 100644
--- a/vllm/executor/multiproc_gpu_executor.py
+++ b/vllm/executor/multiproc_gpu_executor.py
@@ -19,10 +19,6 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
     """Python multiprocessing-based multi-GPU executor"""

     def _init_executor(self) -> None:
-        assert (
-            not self.speculative_config
-        ), "Speculative decoding not yet supported for MultiProcGPU backend."
-
         # Create the parallel GPU workers.
         world_size = self.parallel_config.tensor_parallel_size

@@ -46,6 +42,7 @@ def _init_executor(self) -> None:

         if world_size == 1:
             self.workers = []
+            self.worker_monitor = None
         else:
             result_handler = ResultHandler()
             self.workers = [
@@ -127,7 +124,8 @@ def _run_workers(

     def check_health(self) -> None:
         """Raises an error if engine is unhealthy."""
-        if not self.worker_monitor.is_alive():
+        if self.worker_monitor is not None and not self.worker_monitor.is_alive(
+        ):
             raise RuntimeError("Worker processes are not running")

     def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
diff --git a/vllm/executor/multiproc_worker_utils.py b/vllm/executor/multiproc_worker_utils.py
index 62887533f5c2..28c8e8699f08 100644
--- a/vllm/executor/multiproc_worker_utils.py
+++ b/vllm/executor/multiproc_worker_utils.py
@@ -65,10 +65,11 @@ def _set_future_result(future: Union[ResultFuture, asyncio.Future],
         future.set_result(result)
         return
     loop = future.get_loop()
-    if result.exception is not None:
-        loop.call_soon_threadsafe(future.set_exception, result.exception)
-    else:
-        loop.call_soon_threadsafe(future.set_result, result.value)
+    if not loop.is_closed():
+        if result.exception is not None:
+            loop.call_soon_threadsafe(future.set_exception, result.exception)
+        else:
+            loop.call_soon_threadsafe(future.set_result, result.value)


 class ResultHandler(threading.Thread):
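The `loop.is_closed()` guard added to `_set_future_result` above avoids calling `call_soon_threadsafe` on an event loop that has already been torn down, in which case the result is simply dropped. A minimal illustration of the failure mode being guarded against; the snippet is illustrative only:

# Minimal illustration: call_soon_threadsafe() raises once the owning loop
# is closed, so the guard skips result delivery instead.
import asyncio

loop = asyncio.new_event_loop()
future = loop.create_future()
loop.close()

if not future.get_loop().is_closed():
    future.get_loop().call_soon_threadsafe(future.set_result, 42)
else:
    print("loop already closed; skipping result delivery")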
diff --git a/vllm/model_executor/layers/rejection_sampler.py b/vllm/model_executor/layers/rejection_sampler.py
index 1f2ab7e2870c..a80703155c0b 100644
--- a/vllm/model_executor/layers/rejection_sampler.py
+++ b/vllm/model_executor/layers/rejection_sampler.py
@@ -306,8 +306,10 @@ def _create_output(

         # Fill in the first k columns of the output tensor using masks and data
         # tensors.
-        output[:, :k] = torch.where(accepted_mask, draft_token_ids,
-                                    -torch.ones_like(draft_token_ids))
+        torch.where(accepted_mask,
+                    draft_token_ids,
+                    -torch.ones_like(draft_token_ids),
+                    out=output)

         # Fill the last column.
         # We check output directly as accepted may have True values inconsistent
diff --git a/vllm/spec_decode/batch_expansion.py b/vllm/spec_decode/batch_expansion.py
index 7792f3a3425c..1bde042086f0 100644
--- a/vllm/spec_decode/batch_expansion.py
+++ b/vllm/spec_decode/batch_expansion.py
@@ -80,7 +80,7 @@ def score_proposals(

         target_sampler_output = self._scorer_worker.execute_model(
             execute_model_req=execute_model_req.clone(
-                seq_group_metadata_list=target_seq_group_metadata_list, ))
+                seq_group_metadata_list=target_seq_group_metadata_list))

         assert len(target_sampler_output) == 1, "expected single-step output"
         target_sampler_output = target_sampler_output[0]
@@ -140,8 +140,7 @@ def _expand_batch(
                 num_scoring_tokens)

     def _contract_batch(
-            self, contracted_bs: int,
-            target_sampler_output: List[SamplerOutput],
+            self, contracted_bs: int, target_sampler_output: SamplerOutput,
             proposals: SpeculativeProposals, num_scoring_tokens: int,
             non_spec_indices: List[int], spec_indices: List[int],
             k: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
@@ -167,30 +166,16 @@ def _contract_batch(
         non_spec_expanded_bs, _ = non_spec_target_token_ids.shape
         spec_expanded_bs = expanded_batch_size - non_spec_expanded_bs

-        target_token_ids = target_token_ids.squeeze().reshape(
-            spec_expanded_bs, k + 1)
-        target_probs = target_probs.squeeze().reshape(spec_expanded_bs, k + 1,
-                                                      self._vocab_size)
-        target_logprobs = target_logprobs.squeeze().reshape(
-            spec_expanded_bs, k + 1, self._vocab_size)
-
-        all_tokens = torch.full(size=(contracted_bs, k + 1),
-                                fill_value=-1,
-                                device=self._device,
-                                dtype=torch.long)
-        all_probs = torch.zeros(contracted_bs,
-                                k + 1,
-                                self._vocab_size,
-                                device=self._device,
-                                dtype=torch.float32)
-        all_logprobs = torch.full(size=(
-            contracted_bs,
-            k + 1,
-            self._vocab_size,
-        ),
-                                  fill_value=-float("inf"),
-                                  device=self._device,
-                                  dtype=torch.float32)
+        target_token_ids = target_token_ids.reshape(spec_expanded_bs, k + 1)
+        target_probs = target_probs.reshape(*target_token_ids.shape,
+                                            self._vocab_size)
+        target_logprobs = target_logprobs.reshape(target_probs.shape)
+
+        all_tokens = target_token_ids.new_full(size=(contracted_bs, k + 1),
+                                               fill_value=-1)
+        all_probs = target_probs.new_zeros(*all_tokens.shape, self._vocab_size)
+        all_logprobs = target_logprobs.new_full(size=all_probs.shape,
+                                                fill_value=-float("inf"))

         if non_spec_indices:
             all_tokens[non_spec_indices, :1] = non_spec_target_token_ids
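The `_contract_batch` rewrite above leans on `Tensor.new_full`/`new_zeros`, which inherit dtype and device from the source tensor, so the explicit `dtype=`/`device=` arguments can be dropped. A small self-contained illustration of that pattern (CPU tensors, placeholder shapes):

# Small illustration of new_full()/new_zeros(): allocations inherit dtype and
# device from the tensor they are created from.
import torch

target_token_ids = torch.arange(6, dtype=torch.long).reshape(2, 3)
all_tokens = target_token_ids.new_full(size=(4, 3), fill_value=-1)
assert all_tokens.dtype == target_token_ids.dtype
assert all_tokens.device == target_token_ids.device

target_probs = torch.rand(2, 3, 5)
all_probs = target_probs.new_zeros(*all_tokens.shape, 5)
assert all_probs.shape == (4, 3, 5) and all_probs.dtype == target_probs.dtype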
""" assert "speculative_config" in kwargs - speculative_config = kwargs.get("speculative_config") + speculative_config: SpeculativeConfig = kwargs.get("speculative_config") assert speculative_config is not None target_worker = Worker(*args, **kwargs) @@ -109,12 +110,11 @@ def create_worker( logger.info("Configuring SpecDecodeWorker with proposer=%s", type(proposer_worker)) - return SpecDecodeWorker( - proposer_worker, - scorer_worker, - disable_by_batch_size=disable_by_batch_size, - rejection_sampler=RejectionSampler( - disable_bonus_tokens=disable_bonus_tokens, )) + return SpecDecodeWorker(proposer_worker, + scorer_worker, + disable_by_batch_size=disable_by_batch_size, + rejection_sampler=RejectionSampler( + disable_bonus_tokens=disable_bonus_tokens)) def __init__( self, diff --git a/vllm/spec_decode/top1_proposer.py b/vllm/spec_decode/top1_proposer.py index fdef2833a399..278db94bfc0d 100644 --- a/vllm/spec_decode/top1_proposer.py +++ b/vllm/spec_decode/top1_proposer.py @@ -148,7 +148,8 @@ def _split_by_proposal_len( nonzero_proposal_len_indices, ) - def _remove_no_proposal_seqs(self, proposal_lens, maybe_sampler_output, + @staticmethod + def _remove_no_proposal_seqs(proposal_lens, maybe_sampler_output, nonzero_proposal_len_indices, transposed): """Remove sequences from nonzero_proposal_len_indices and reset their proposal_len to 0 the draft worker does not provide a proposal @@ -207,7 +208,7 @@ def _merge_outputs( self, batch_size: int, proposal_len: int, - maybe_sampler_output: Optional[SamplerOutput], + maybe_sampler_output: Optional[List[SamplerOutput]], proposal_lens: List[int], nonzero_proposal_len_indices: List[int], sampler_transposed: bool, @@ -218,25 +219,19 @@ def _merge_outputs( if maybe_sampler_output is None: # If no speculative tokens, the sampler output will be None. # In this case we return empty proposals. - proposal_tokens = torch.full( - size=( - batch_size, - proposal_len, - ), - fill_value=-1, - dtype=torch.long, - device=self._device, - ) - proposal_probs = torch.zeros( - batch_size, - proposal_len, - self._vocab_size, - dtype=torch.float32, - device=self._device, - ) - proposal_lens_tensor = torch.zeros(len(proposal_lens), - dtype=torch.long, - device=self._device) + proposal_tokens = torch.tensor(-1, + dtype=torch.long, + device=self._device).expand( + batch_size, proposal_len) + proposal_probs = torch.tensor(0, + dtype=torch.float32, + device=self._device).expand( + batch_size, proposal_len, + self._vocab_size) + proposal_lens_tensor = torch.tensor(0, + dtype=torch.long, + device=self._device).expand( + len(proposal_lens)) return proposal_tokens, proposal_probs, proposal_lens_tensor sampler_output = maybe_sampler_output @@ -246,18 +241,14 @@ def _merge_outputs( # Now, reformat the output GPU tensors such that each sequence has # a proposal. the proposal can be empty, e.g. 
diff --git a/vllm/spec_decode/top1_proposer.py b/vllm/spec_decode/top1_proposer.py
index fdef2833a399..278db94bfc0d 100644
--- a/vllm/spec_decode/top1_proposer.py
+++ b/vllm/spec_decode/top1_proposer.py
@@ -148,7 +148,8 @@ def _split_by_proposal_len(
             nonzero_proposal_len_indices,
         )

-    def _remove_no_proposal_seqs(self, proposal_lens, maybe_sampler_output,
+    @staticmethod
+    def _remove_no_proposal_seqs(proposal_lens, maybe_sampler_output,
                                  nonzero_proposal_len_indices, transposed):
         """Remove sequences from nonzero_proposal_len_indices and reset
         their proposal_len to 0 the draft worker does not provide a proposal
@@ -207,7 +208,7 @@ def _merge_outputs(
         self,
         batch_size: int,
         proposal_len: int,
-        maybe_sampler_output: Optional[SamplerOutput],
+        maybe_sampler_output: Optional[List[SamplerOutput]],
         proposal_lens: List[int],
         nonzero_proposal_len_indices: List[int],
         sampler_transposed: bool,
@@ -218,25 +219,19 @@
         if maybe_sampler_output is None:
             # If no speculative tokens, the sampler output will be None.
             # In this case we return empty proposals.
-            proposal_tokens = torch.full(
-                size=(
-                    batch_size,
-                    proposal_len,
-                ),
-                fill_value=-1,
-                dtype=torch.long,
-                device=self._device,
-            )
-            proposal_probs = torch.zeros(
-                batch_size,
-                proposal_len,
-                self._vocab_size,
-                dtype=torch.float32,
-                device=self._device,
-            )
-            proposal_lens_tensor = torch.zeros(len(proposal_lens),
-                                               dtype=torch.long,
-                                               device=self._device)
+            proposal_tokens = torch.tensor(-1,
+                                           dtype=torch.long,
+                                           device=self._device).expand(
+                                               batch_size, proposal_len)
+            proposal_probs = torch.tensor(0,
+                                          dtype=torch.float32,
+                                          device=self._device).expand(
+                                              batch_size, proposal_len,
+                                              self._vocab_size)
+            proposal_lens_tensor = torch.tensor(0,
+                                                dtype=torch.long,
+                                                device=self._device).expand(
+                                                    len(proposal_lens))
             return proposal_tokens, proposal_probs, proposal_lens_tensor

         sampler_output = maybe_sampler_output
@@ -246,18 +241,14 @@
         # Now, reformat the output GPU tensors such that each sequence has
         # a proposal. the proposal can be empty, e.g. [-1, -1, -1]
-        entire_proposal_tokens = torch.full(
+        entire_proposal_tokens = proposal_tokens.new_full(
             size=(batch_size, *proposal_tokens.shape[1:]),
             fill_value=-1,
-            dtype=torch.long,
-            device=self._device,
         )
         entire_proposal_tokens[nonzero_proposal_len_indices] = proposal_tokens
-        entire_proposal_probs = torch.zeros(
+        entire_proposal_probs = proposal_probs.new_zeros(
             batch_size,
             *proposal_probs.shape[1:],
-            dtype=torch.float32,
-            device=self._device,
         )
         entire_proposal_probs[nonzero_proposal_len_indices] = proposal_probs
diff --git a/vllm/spec_decode/util.py b/vllm/spec_decode/util.py
index 4dc6c49eb58d..60ed9d39eb8d 100644
--- a/vllm/spec_decode/util.py
+++ b/vllm/spec_decode/util.py
@@ -1,12 +1,11 @@
 from contextlib import contextmanager
-from itertools import chain
 from typing import Dict, List, Tuple

 import torch

 from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
                            SamplerOutput, SequenceGroupMetadata,
-                           SequenceGroupOutput, SequenceOutput)
+                           SequenceOutput)

 SeqId = int

@@ -16,11 +15,7 @@ def get_all_seq_ids(
     """Given a list of SequenceGroupMetadata, create a list of all
     sequence ids.
     """
-    return list(
-        chain.from_iterable([
-            seq_group_metadata.seq_data.keys()
-            for seq_group_metadata in seq_group_metadata_list
-        ]))
+    return [seq_id for sg in seq_group_metadata_list for seq_id in sg.seq_data]


 def get_all_num_logprobs(
@@ -68,7 +63,7 @@ def create_sequence_group_output(
     seq_id: SeqId,
     topk_token_ids: List[int],
     topk_logprobs: List[float],
-) -> SequenceGroupOutput:
+) -> CompletionSequenceGroupOutput:
     """Create a SequenceGroupOutput given the sampling results.

     Args:
diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py
index 970645987885..ada84018212a 100644
--- a/vllm/transformers_utils/config.py
+++ b/vllm/transformers_utils/config.py
@@ -1,4 +1,4 @@
-from typing import Dict, Optional
+from typing import Dict, Optional, Type

 from transformers import PretrainedConfig

@@ -9,7 +9,7 @@

 logger = init_logger(__name__)

-_CONFIG_REGISTRY: Dict[str, PretrainedConfig] = {
+_CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
     "chatglm": ChatGLMConfig,
     "dbrx": DbrxConfig,
     "mpt": MPTConfig,
@@ -23,7 +23,8 @@ def get_config(model: str,
               trust_remote_code: bool,
               revision: Optional[str] = None,
               code_revision: Optional[str] = None,
-               rope_scaling: Optional[dict] = None) -> PretrainedConfig:
+               rope_scaling: Optional[dict] = None,
+               rope_theta: Optional[float] = None) -> PretrainedConfig:
     try:
         if VLLM_USE_MODELSCOPE:
             from modelscope import AutoConfig
@@ -50,10 +51,12 @@ def get_config(model: str,
         config = config_class.from_pretrained(model,
                                               revision=revision,
                                               code_revision=code_revision)
-    if rope_scaling is not None:
-        logger.info("Updating rope_scaling from %r to %r",
-                    getattr(config, "rope_scaling", None), rope_scaling)
-        config.update({"rope_scaling": rope_scaling})
+    for key, value in [("rope_scaling", rope_scaling),
+                       ("rope_theta", rope_theta)]:
+        if value is not None:
+            logger.info("Updating %s from %r to %r", key,
+                        getattr(config, key, None), value)
+            config.update({key: value})
     return config

@@ -68,4 +71,4 @@ def get_hf_text_config(config: PretrainedConfig):
         assert hasattr(config.text_config, "num_attention_heads")
         return config.text_config
     else:
-        return config
\ No newline at end of file
+        return config
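The empty-proposal branch in `_merge_outputs` above replaces full allocations with `expand()`ed views of a scalar tensor; the broadcast result has stride 0 in every dimension, so it is backed by a single element rather than a batch-sized buffer. A small illustration with placeholder sizes:

# Small illustration of the expand() trick: a scalar broadcast into a
# (batch_size, proposal_len) view is not a copy, just a zero-stride view.
import torch

batch_size, proposal_len = 4, 3
proposal_tokens = torch.tensor(-1, dtype=torch.long).expand(batch_size,
                                                            proposal_len)
print(proposal_tokens.shape)     # torch.Size([4, 3])
print(proposal_tokens.stride())  # (0, 0) -> a broadcast view, not a copy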
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index 7879a5de5b7b..99b12293a024 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -527,16 +527,6 @@ def _prepare_model_input(
             )
             assert max_query_len > 0, ("query_lens: {}".format(query_lens))

-        context_lens_tensor = torch.tensor(context_lens,
-                                           dtype=torch.int,
-                                           device=self.device)
-        query_lens_tensor = torch.tensor(query_lens,
-                                         dtype=torch.long,
-                                         device=self.device)
-        query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
-                                      dtype=torch.int32,
-                                      device=self.device)
-
         seq_lens_tensor = torch.tensor(seq_lens,
                                        dtype=torch.int,
                                        device=self.device)
@@ -544,11 +534,6 @@ def _prepare_model_input(
                                     dtype=torch.int32,
                                     device=self.device)

-        torch.cumsum(query_lens_tensor,
-                     dim=0,
-                     dtype=query_start_loc.dtype,
-                     out=query_start_loc[1:])
-
         torch.cumsum(seq_lens_tensor,
                      dim=0,
                      dtype=seq_start_loc.dtype,
@@ -601,6 +586,21 @@ def _prepare_model_input(
                 seq_start_loc=seq_start_loc,
                 data_type=kv_cache_dtype)
         else:
+            context_lens_tensor = torch.tensor(context_lens,
+                                               dtype=torch.int,
+                                               device=self.device)
+            query_lens_tensor = torch.tensor(query_lens,
+                                             dtype=torch.long,
+                                             device=self.device)
+            query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
+                                          dtype=torch.int32,
+                                          device=self.device)
+
+            torch.cumsum(query_lens_tensor,
+                         dim=0,
+                         dtype=query_start_loc.dtype,
+                         out=query_start_loc[1:])
+
             attn_metadata = self.attn_backend.make_metadata(
                 num_prefills=num_prefills,
                 slot_mapping=slot_mapping_tensor,