diff --git a/docs/models/extensions/tensorizer.md b/docs/models/extensions/tensorizer.md index b6feb405c6ca..c160055d6fb5 100644 --- a/docs/models/extensions/tensorizer.md +++ b/docs/models/extensions/tensorizer.md @@ -6,11 +6,109 @@ title: Loading models with CoreWeave's Tensorizer vLLM supports loading models with [CoreWeave's Tensorizer](https://docs.coreweave.com/coreweave-machine-learning-and-ai/inference/tensorizer). vLLM model tensors that have been serialized to disk, an HTTP/HTTPS endpoint, or S3 endpoint can be deserialized at runtime extremely quickly directly to the GPU, resulting in significantly -shorter Pod startup times and CPU memory usage. Tensor encryption is also supported. +lower Pod startup times and CPU memory usage. Tensor encryption is also supported. -For more information on CoreWeave's Tensorizer, please refer to -[CoreWeave's Tensorizer documentation](https://github.com/coreweave/tensorizer). For more information on serializing a vLLM model, as well a general usage guide to using Tensorizer with vLLM, see -the [vLLM example script](https://docs.vllm.ai/en/latest/examples/tensorize_vllm_model.html). +vLLM fully integrates Tensorizer in its model loading machinery. The +following will give a brief overview on how to get started with using +Tensorizer on vLLM. + +## The basics +To load a model using Tensorizer, it first needs to be serialized by Tensorizer. +The example script in [examples/others/tensorize_vllm_model.py](https://github.com/vllm-project/vllm/blob/main/examples/others/tensorize_vllm_model.py) +takes care of this process. + +The `TensorizerConfig` class is used to customize Tensorizer's behaviour, +defined in [vllm/model_executor/model_loader/tensorizer.py](https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/model_loader/tensorizer.py). +It is passed to any serialization or deserialization operation. +When loading with Tensorizer using the vLLM +library rather than through a model-serving entrypoint, it gets passed to +the `LLM` entrypoint class directly. Here's an example of loading a model +saved at `"s3://my-bucket/vllm/facebook/opt-125m/v1/model.tensors"`: + +```python +from vllm import LLM +from vllm.model_executor.model_loader.tensorizer import TensorizerConfig + +path_to_tensors = "s3://my-bucket/vllm/facebook/opt-125m/v1/model.tensors" + +model_ref = "facebook/opt-125m" +tensorizer_config = TensorizerConfig( + tensorizer_uri=path_to_tensors +) + +llm = LLM( + model_ref, + load_format="tensorizer", + model_loader_extra_config=tensorizer_config, +) +``` + +However, the above code will not function until you have successfully +serialized the model tensors with Tensorizer to get the `model.tensors` +file shown. The following section walks through an end-to-end example +of serializing `facebook/opt-125m` with the example script, +and then loading it for inference. + +## Saving a vLLM model with Tensorizer +To save a model with Tensorizer, call the example script with the necessary +CLI arguments. The docstring for the script itself explains the CLI args +and how to use it properly in great detail, and we'll use one of the +examples from the docstring directly, assuming we want to save our model at +our S3 bucket example `s3://my-bucket`: + +```bash +python examples/others/tensorize_vllm_model.py \ + --model facebook/opt-125m \ + serialize \ + --serialized-directory s3://my-bucket \ + --suffix v1 +``` + +This saves the model tensors at +`s3://my-bucket/vllm/facebook/opt-125m/v1/model.tensors`, as well as all other +artifacts needed to load the model at that same directory. + +## Serving the model using Tensorizer +Once the model is serialized where you want it, all that is needed is to +pass that directory with the model artifacts to `vllm serve` in the case above, +one can simply do: + +```bash +vllm serve s3://my-bucket/vllm/facebook/opt-125m/v1 --load-format=tensorizer +``` + +Please note that object storage authentication is still required. +To authenticate, Tensorizer searches for a `~/.s3cfg` configuration file +([s3cmd](https://s3tools.org/kb/item14.htm)'s format), +or `~/.aws/config` and `~/.aws/credentials` files (`boto3` and the +`aws` CLI's format), but authentication can also be configured using +[any normal `boto3` environment variables](https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html#using-environment-variables), +such as `AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY`, +and `AWS_ENDPOINT_URL_S3`. + +If only the `model.tensors` file exists, you can load exclusively that file +by using the `--model-loader-extra-config` CLI parameter, which expects a +JSON string with all the kwargs for a `TensorizerConfig` configuration object. + +```bash +#!/bin/sh + +MODEL_LOADER_EXTRA_CONFIG='{ + "tensorizer_uri": "s3://my-bucket/vllm/facebook/opt-125m/v1/model.tensors", + "stream_kwargs": {"force_http": false}, + "deserialization_kwargs": {"verify_hash": true, "num_readers": 8} +}' + +vllm serve facebook/opt-125m \ + --load-format=tensorizer \ + --model-loader-extra-config="$MODEL_LOADER_EXTRA_CONFIG" +``` + +Note in this case, if the directory to the model artifacts at +`s3://my-bucket/vllm/facebook/opt-125m/v1/` doesn't have all necessary model +artifacts to load, you'll want to pass `facebook/opt-125m` as the model tag like +it was done in the example script above. In this case, vLLM will take care of +resolving the other model artifacts by pulling them from HuggingFace Hub. !!! note Note that to use this feature you will need to install `tensorizer` by running `pip install vllm[tensorizer]`. diff --git a/examples/others/tensorize_vllm_model.py b/examples/others/tensorize_vllm_model.py index 175777630833..f6e68fc5c813 100644 --- a/examples/others/tensorize_vllm_model.py +++ b/examples/others/tensorize_vllm_model.py @@ -14,8 +14,13 @@ TensorizerConfig, tensorize_lora_adapter, tensorize_vllm_model, + tensorizer_kwargs_arg ) from vllm.utils import FlexibleArgumentParser +import logging + +logger = logging.getLogger() + # yapf conflicts with isort for this docstring # yapf: disable @@ -140,7 +145,7 @@ def parse_args(): ) - subparsers = parser.add_subparsers(dest='command') + subparsers = parser.add_subparsers(dest='command', required=True) serialize_parser = subparsers.add_parser( 'serialize', help="Serialize a model to `--serialized-directory`") @@ -170,6 +175,14 @@ def parse_args(): "where `suffix` is given by `--suffix` or a random UUID if not " "provided.") + serialize_parser.add_argument( + "--serialization-kwargs", + type=tensorizer_kwargs_arg, + required=False, + help=("A JSON string containing additional keyword arguments to " + "pass to Tensorizer's `TensorSerializer` during " + "serialization.")) + serialize_parser.add_argument( "--keyfile", type=str, @@ -195,11 +208,27 @@ def parse_args(): help=("Path to a binary key to use to decrypt the model weights," " if the model was serialized with encryption")) + deserialize_parser.add_argument( + "--deserialization-kwargs", + type=tensorizer_kwargs_arg, + required=False, + help=("A JSON string containing additional keyword arguments to " + "pass to Tensorizer's `TensorDeserializer` during " + "deserialization.")) + TensorizerArgs.add_cli_args(deserialize_parser) return parser.parse_args() - +def merge_extra_config_with_tensorizer_config(extra_cfg: dict, + cfg: TensorizerConfig): + for k, v in extra_cfg.items(): + if hasattr(cfg, k): + setattr(cfg, k, v) + logger.info( + f"Updating TensorizerConfig with {k} from " + f"--model-loader-extra-config provided" + ) def deserialize(): if args.lora_path: @@ -266,13 +295,10 @@ def deserialize(): else: keyfile = None + extra_config = {} if args.model_loader_extra_config: - config = json.loads(args.model_loader_extra_config) - tensorizer_args = \ - TensorizerConfig(**config)._construct_tensorizer_args() - tensorizer_args.tensorizer_uri = args.path_to_tensors - else: - tensorizer_args = None + extra_config = json.loads(args.model_loader_extra_config) + if args.command == "serialize": eng_args_dict = {f.name: getattr(args, f.name) for f in @@ -293,21 +319,24 @@ def deserialize(): tensorizer_config = TensorizerConfig( tensorizer_uri=model_path, encryption_keyfile=keyfile, - **credentials) + serialization_kwargs=args.serialization_kwargs or {}, + **credentials + ) if args.lora_path: tensorizer_config.lora_dir = tensorizer_config.tensorizer_dir tensorize_lora_adapter(args.lora_path, tensorizer_config) + merge_extra_config_with_tensorizer_config(extra_config, tensorizer_config) tensorize_vllm_model(engine_args, tensorizer_config) elif args.command == "deserialize": - if not tensorizer_args: - tensorizer_config = TensorizerConfig( - tensorizer_uri=args.path_to_tensors, - encryption_keyfile = keyfile, - **credentials - ) - deserialize() - else: - raise ValueError("Either serialize or deserialize must be specified.") + tensorizer_config = TensorizerConfig( + tensorizer_uri=args.path_to_tensors, + encryption_keyfile=keyfile, + deserialization_kwargs=args.deserialization_kwargs or {}, + **credentials + ) + + merge_extra_config_with_tensorizer_config(extra_config, tensorizer_config) + deserialize() \ No newline at end of file diff --git a/tests/entrypoints/openai/test_tensorizer_entrypoint.py b/tests/entrypoints/openai/test_tensorizer_entrypoint.py index f1ab7223048d..61fc92316f9c 100644 --- a/tests/entrypoints/openai/test_tensorizer_entrypoint.py +++ b/tests/entrypoints/openai/test_tensorizer_entrypoint.py @@ -59,12 +59,18 @@ def tensorize_model_and_lora(tmp_dir, model_uri): def server(model_uri, tensorize_model_and_lora): model_loader_extra_config = { "tensorizer_uri": model_uri, + "stream_kwargs": { + "force_http": False, + }, + "deserialization_kwargs": { + "verify_hash": True, + "num_readers": 8, + } } ## Start OpenAI API server args = [ - "--load-format", "tensorizer", "--device", "cuda", - "--model-loader-extra-config", + "--load-format", "tensorizer", "--model-loader-extra-config", json.dumps(model_loader_extra_config), "--enable-lora" ] diff --git a/tests/lora/test_llama_tp.py b/tests/lora/test_llama_tp.py index 580992dea53d..13d7b4e7b7aa 100644 --- a/tests/lora/test_llama_tp.py +++ b/tests/lora/test_llama_tp.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 import subprocess import sys +import json from typing import Union import pytest @@ -210,7 +211,8 @@ def test_tp2_serialize_and_deserialize_lora(tmp_path, sql_lora_files, f"{VLLM_PATH}/examples/others/tensorize_vllm_model.py", "--model", MODEL_PATH, "--lora-path", lora_path, "--tensor-parallel-size", str(tp_size), "serialize", "--serialized-directory", - str(tmp_path), "--suffix", suffix + str(tmp_path), "--suffix", suffix, "--serialization-kwargs", + '{"limit_cpu_concurrency": 4}' ], check=True, capture_output=True, diff --git a/tests/tensorizer_loader/conftest.py b/tests/tensorizer_loader/conftest.py index ce8689f5b89c..3ec0d1dbfe3e 100644 --- a/tests/tensorizer_loader/conftest.py +++ b/tests/tensorizer_loader/conftest.py @@ -1,9 +1,25 @@ # SPDX-License-Identifier: Apache-2.0 +from typing import Callable + import pytest +import os +from vllm import LLM from vllm.distributed import cleanup_dist_env_and_memory from vllm.model_executor.model_loader.tensorizer import TensorizerConfig +from vllm.utils import get_distributed_init_method, get_open_port, get_ip + +from vllm.v1.engine.core import EngineCore +from vllm.v1.executor.abstract import UniProcExecutor +from vllm.v1.kv_cache_interface import KVCacheConfig +from vllm.v1.outputs import ModelRunnerOutput +from vllm.v1.worker.gpu_worker import Worker +from vllm.worker.worker_base import WorkerWrapperBase + +@pytest.fixture(autouse=True) +def allow_insecure_serialization(monkeypatch): + monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1") @pytest.fixture(autouse=True) def cleanup(): @@ -14,3 +30,54 @@ def cleanup(): def tensorizer_config(): config = TensorizerConfig(tensorizer_uri="vllm") return config + + +def assert_from_collective_rpc(engine: LLM, + closure: Callable, + closure_kwargs: dict): + res = engine.collective_rpc(method=closure, kwargs=closure_kwargs) + return all(res) + + +# This is an object pulled from tests/v1/engine/test_engine_core.py +# Modified to strip the `load_model` method from its `_init_executor` +# method. It's purely used as a dummy utility to run methods that test +# Tensorizer functionality +class DummyExecutor(UniProcExecutor): + + def _init_executor(self) -> None: + """Initialize the worker and load the model. + """ + self.driver_worker = WorkerWrapperBase( + vllm_config=self.vllm_config, + rpc_rank=0 + ) + distributed_init_method = get_distributed_init_method( + get_ip(), get_open_port() + ) + local_rank = 0 + # set local rank as the device index if specified + device_info = self.vllm_config.device_config.device.__str__().split( + ":" + ) + if len(device_info) > 1: + local_rank = int(device_info[1]) + rank = 0 + is_driver_worker = True + kwargs = dict( + vllm_config=self.vllm_config, + local_rank=local_rank, + rank=rank, + distributed_init_method=distributed_init_method, + is_driver_worker=is_driver_worker, + ) + self.collective_rpc("init_worker", args=([kwargs],)) + self.collective_rpc("init_device") + + @property + def max_concurrent_batches(self) -> int: + return 2 + + def shutdown(self): + if hasattr(self, 'thread_pool'): + self.thread_pool.shutdown(wait=False) \ No newline at end of file diff --git a/tests/tensorizer_loader/test_tensorizer.py b/tests/tensorizer_loader/test_tensorizer.py index b6286e148397..bf572b656af0 100644 --- a/tests/tensorizer_loader/test_tensorizer.py +++ b/tests/tensorizer_loader/test_tensorizer.py @@ -1,37 +1,63 @@ # SPDX-License-Identifier: Apache-2.0 - +import asyncio +import contextlib +import copy +import functools import gc +import json import os import pathlib import subprocess +import sys +from concurrent.futures import ThreadPoolExecutor, Future +from typing import Callable, Type, Union, Any from unittest.mock import MagicMock, patch +from dataclasses import dataclass import pytest import torch -from vllm import SamplingParams +from vllm import SamplingParams, LLM from vllm.engine.arg_utils import EngineArgs -# yapf conflicts with isort for this docstring + # yapf: disable -from vllm.model_executor.model_loader.tensorizer import (TensorizerConfig, - TensorSerializer, - is_vllm_tensorized, - load_with_tensorizer, - open_stream, - tensorize_vllm_model) +from vllm.model_executor.model_loader.tensorizer import ( + TensorizerConfig, + TensorSerializer, + is_vllm_tensorized, + load_with_tensorizer, + open_stream, + tensorize_vllm_model +) +from vllm.model_executor.model_loader.tensorizer_loader import ( + BLACKLISTED_TENSORIZER_ARGS) +import vllm.model_executor.model_loader.tensorizer +from .conftest import DummyExecutor # yapf: enable -from vllm.utils import PlaceholderModule +from vllm.utils import ( + PlaceholderModule, get_distributed_init_method, + get_open_port, get_ip, +) +from .conftest import assert_from_collective_rpc from ..utils import VLLM_PATH try: + import tensorizer from tensorizer import EncryptionParams except ImportError: tensorizer = PlaceholderModule("tensorizer") # type: ignore[assignment] EncryptionParams = tensorizer.placeholder_attr("EncryptionParams") + +class TensorizerCaughtError(Exception): + pass + + EXAMPLES_PATH = VLLM_PATH / "examples" +pytest_plugins = "pytest_asyncio", + prompts = [ "Hello, my name is", "The president of the United States is", @@ -46,6 +72,32 @@ os.path.dirname(__file__), "tensorize_vllm_model_for_testing.py") +def patch_init_and_catch_error(self, obj, method_name, expected_error: Type[Exception]): + original = getattr(obj, method_name, None) + if original is None: + raise ValueError("Method '{}' not found.".format(method_name)) + + def wrapper(*args, **kwargs): + try: + return original(*args, **kwargs) + except expected_error: + raise TensorizerCaughtError + + setattr(obj, method_name, wrapper) + + self.load_model() + + +def assert_specific_tensorizer_error_is_raised( + executor, + obj: Any, + method_name: str, + expected_error: Type[Exception], + ): + with pytest.raises(TensorizerCaughtError): + executor.collective_rpc(patch_init_and_catch_error, + args=(obj, method_name, expected_error,)) + def is_curl_installed(): try: subprocess.check_call(['curl', '--version']) @@ -264,7 +316,7 @@ def test_vllm_tensorized_model_has_same_outputs(vllm_runner, tmp_path): model_ref = "facebook/opt-125m" model_path = tmp_path / (model_ref + ".tensors") config = TensorizerConfig(tensorizer_uri=str(model_path)) - args = EngineArgs(model=model_ref, device="cuda") + args = EngineArgs(model=model_ref) with vllm_runner(model_ref) as vllm_model: outputs = vllm_model.generate(prompts, sampling_params) @@ -280,3 +332,232 @@ def test_vllm_tensorized_model_has_same_outputs(vllm_runner, tmp_path): # noqa: E501 assert outputs == deserialized_outputs + +def test_assert_serialization_kwargs_passed_to_tensor_serializer(tmp_path): + + serialization_params = { + "limit_cpu_concurrency": 2, + } + model_ref = "facebook/opt-125m" + model_path = tmp_path / (model_ref + ".tensors") + config = TensorizerConfig(tensorizer_uri=str(model_path), + serialization_kwargs=serialization_params) + llm = LLM( + model=model_ref, + ) + + + def serialization_test(self, *args, **kwargs): + # This is performed in the ephemeral worker process, so monkey-patching + # will actually work, and cleanup is guaranteed so don't + # need to reset things + + original_dict = serialization_params + to_compare = {} + + original = tensorizer.serialization.TensorSerializer.__init__ + + def tensorizer_serializer_wrapper(self, *args, **kwargs): + nonlocal to_compare + to_compare = kwargs.copy() + return original(self, *args, **kwargs) + + tensorizer.serialization.TensorSerializer.__init__ = tensorizer_serializer_wrapper + + tensorizer_config = TensorizerConfig(**kwargs["tensorizer_config"]) + self.save_tensorized_model( + tensorizer_config=tensorizer_config, ) + return to_compare | original_dict == to_compare + + kwargs = { + "tensorizer_config": config.to_dict() + } + + assert assert_from_collective_rpc(llm, serialization_test, kwargs) + + +def test_assert_deserialization_kwargs_passed_to_tensor_deserializer(tmp_path, capfd): + + expected_error = TypeError + + deserialization_kwargs = { + "num_readers": "bar", # illegal value + } + + serialization_params = { + "limit_cpu_concurrency": 2, + } + + model_ref = "facebook/opt-125m" + model_path = tmp_path / (model_ref + ".tensors") + config = TensorizerConfig(tensorizer_uri=str(model_path), + serialization_kwargs=serialization_params) + + args = EngineArgs(model=model_ref) + tensorize_vllm_model(args, config) + + loader_tc = TensorizerConfig( + tensorizer_uri=str(model_path), + deserialization_kwargs=deserialization_kwargs, + ) + + engine_args = EngineArgs( + model="facebook/opt-125m", + load_format = "tensorizer", + model_loader_extra_config=loader_tc.to_dict(),) + + vllm_config = engine_args.create_engine_config() + executor = DummyExecutor(vllm_config) + + assert_specific_tensorizer_error_is_raised(executor, + tensorizer.serialization.TensorDeserializer, + "__init__", + TypeError, + ) + +def test_assert_stream_kwargs_passed_to_tensor_deserializer(tmp_path, capfd): + + deserialization_kwargs = { + "num_readers": 1, + } + + serialization_params = { + "limit_cpu_concurrency": 2, + } + + model_ref = "facebook/opt-125m" + model_path = tmp_path / (model_ref + ".tensors") + config = TensorizerConfig(tensorizer_uri=str(model_path), + serialization_kwargs=serialization_params) + + args = EngineArgs(model=model_ref) + tensorize_vllm_model(args, config) + + stream_kwargs = { + "mode": "foo" + } + + + loader_tc = TensorizerConfig( + tensorizer_uri=str(model_path), + deserialization_kwargs=deserialization_kwargs, + stream_kwargs=stream_kwargs, + ) + + engine_args = EngineArgs( + model="facebook/opt-125m", + load_format = "tensorizer", + model_loader_extra_config=loader_tc.to_dict(),) + + vllm_config = engine_args.create_engine_config() + executor = DummyExecutor(vllm_config) + + assert_specific_tensorizer_error_is_raised( + executor, + vllm.model_executor.model_loader.tensorizer, + "open_stream", + ValueError, + ) + +@pytest.mark.asyncio +async def test_serialize_and_serve_entrypoints(tmp_path): + model_ref = "facebook/opt-125m" + + suffix = "test" + try: + result = subprocess.run([ + sys.executable, + f"{VLLM_PATH}/examples/others/tensorize_vllm_model.py", "--model", + model_ref, "serialize", "--serialized-directory", + str(tmp_path), "--suffix", suffix, "--serialization-kwargs", + '{"limit_cpu_concurrency": 4}' + ], + check=True, + capture_output=True, + text=True) + except subprocess.CalledProcessError as e: + print("Tensorizing failed.") + print("STDOUT:\n", e.stdout) + print("STDERR:\n", e.stderr) + raise + + assert "Successfully serialized" in result.stdout + + # Next, try to serve with vllm serve + model_uri = tmp_path / "vllm" / model_ref / suffix / "model.tensors" + + model_loader_extra_config = { + "tensorizer_uri": str(model_uri), + "stream_kwargs": { + "force_http": False, + }, + "deserialization_kwargs": { + "verify_hash": True, + "num_readers": 8, + } + } + + cmd = [ + "-m", + "vllm.entrypoints.cli.main", + "serve", + "--host", + "localhost", + "--load-format", + "tensorizer", + model_ref, + "--model-loader-extra-config", + json.dumps(model_loader_extra_config, indent=2) + ] + + proc = await asyncio.create_subprocess_exec( + sys.executable, + *cmd, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.STDOUT, + ) + + + try: + async with asyncio.timeout(180): + await proc.stdout.readuntil(b"Application startup complete.") + except asyncio.TimeoutError: + pytest.fail("Server did not start successfully") + finally: + proc.terminate() + await proc.communicate() + +@pytest.mark.parametrize("illegal_value", BLACKLISTED_TENSORIZER_ARGS) +def test_blacklisted_parameter_for_loading(tmp_path, vllm_runner, capfd, + illegal_value): + + serialization_params = { + "limit_cpu_concurrency": 2, + } + + model_ref = "facebook/opt-125m" + model_path = tmp_path / (model_ref + ".tensors") + config = TensorizerConfig(tensorizer_uri=str(model_path), + serialization_kwargs=serialization_params) + + args = EngineArgs(model=model_ref) + tensorize_vllm_model(args, config) + + loader_tc = { + "tensorizer_uri": str(model_path), + illegal_value: "foo" + } + + try: + vllm_runner( + model_ref, + load_format="tensorizer", + model_loader_extra_config=loader_tc, + ) + except RuntimeError: + out, err = capfd.readouterr() + combined_output = out + err + assert (f"ValueError: {illegal_value} is not an allowed " + f"Tensorizer argument.") in combined_output + + diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 442e4100fea1..a36bbf003ea1 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -54,7 +54,7 @@ def parse_type(return_type: Callable[[str], T]) -> Callable[[str], T]: def _parse_type(val: str) -> T: try: - if return_type is json.loads and not re.match("^{.*}$", val): + if return_type is json.loads and not re.match(r"(?s)^\s*{.*}\s*$", val): return cast(T, nullable_kvs(val)) return return_type(val) except ValueError as e: @@ -76,7 +76,7 @@ def _optional_type(val: str) -> Optional[T]: def union_dict_and_str(val: str) -> Optional[Union[str, dict[str, str]]]: - if not re.match("^{.*}$", val): + if not re.match(r"(?s)^\s*{.*}\s*$", val): return str(val) return optional_type(json.loads)(val) diff --git a/vllm/lora/models.py b/vllm/lora/models.py index af5cebdf2a8b..46adeebd408b 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -246,7 +246,7 @@ def check_unexpected_modules(modules: dict): tensorizer_args = tensorizer_config._construct_tensorizer_args() tensors = TensorDeserializer(lora_tensor_path, dtype=tensorizer_config.dtype, - **tensorizer_args.deserializer_params) + **tensorizer_args.deserialization_kwargs) check_unexpected_modules(tensors) elif os.path.isfile(lora_tensor_path): diff --git a/vllm/lora/peft_helper.py b/vllm/lora/peft_helper.py index 7d335e5f7fab..83bf65b2cf14 100644 --- a/vllm/lora/peft_helper.py +++ b/vllm/lora/peft_helper.py @@ -105,7 +105,7 @@ def from_local_dir( "adapter_config.json") with open_stream(lora_config_path, mode="rb", - **tensorizer_args.stream_params) as f: + **tensorizer_args.stream_kwargs) as f: config = json.load(f) logger.info("Successfully deserialized LoRA config from %s", diff --git a/vllm/model_executor/model_loader/tensorizer.py b/vllm/model_executor/model_loader/tensorizer.py index 4c4502284a6a..40b3387c7e02 100644 --- a/vllm/model_executor/model_loader/tensorizer.py +++ b/vllm/model_executor/model_loader/tensorizer.py @@ -35,10 +35,6 @@ from tensorizer.utils import (convert_bytes, get_mem_usage, no_init_or_tensor) - _read_stream, _write_stream = (partial( - open_stream, - mode=mode, - ) for mode in ("rb", "wb+")) except ImportError: tensorizer = PlaceholderModule("tensorizer") DecryptionParams = tensorizer.placeholder_attr("DecryptionParams") @@ -50,9 +46,6 @@ get_mem_usage = tensorizer.placeholder_attr("utils.get_mem_usage") no_init_or_tensor = tensorizer.placeholder_attr("utils.no_init_or_tensor") - _read_stream = tensorizer.placeholder_attr("_read_stream") - _write_stream = tensorizer.placeholder_attr("_write_stream") - __all__ = [ 'EncryptionParams', 'DecryptionParams', 'TensorDeserializer', 'TensorSerializer', 'open_stream', 'convert_bytes', 'get_mem_usage', @@ -61,6 +54,15 @@ logger = init_logger(__name__) +def tensorizer_kwargs_arg(value): + loaded = json.loads(value) + if not isinstance(loaded, dict): + raise argparse.ArgumentTypeError( + f"Not deserializable to dict: {value}. serialization_kwargs and " + f"deserialization_kwargs must be " + f"deserializable from a JSON string to a dictionary. " + ) + return loaded class MetaTensorMode(TorchDispatchMode): @@ -146,6 +148,9 @@ class TensorizerConfig: hf_config: Optional[PretrainedConfig] = None dtype: Optional[Union[str, torch.dtype]] = None lora_dir: Optional[str] = None + stream_kwargs: Optional[dict[str, Any]] = None + serialization_kwargs: Optional[dict[str, Any]] = None + deserialization_kwargs: Optional[dict[str, Any]] = None _is_sharded: bool = False def __post_init__(self): @@ -160,6 +165,10 @@ def __post_init__(self): "provided.") self.tensorizer_dir = os.path.dirname(self.tensorizer_uri) self.lora_dir = self.tensorizer_dir + if not self.serialization_kwargs: + self.serialization_kwargs = {} + if not self.deserialization_kwargs: + self.deserialization_kwargs = {} @classmethod def as_dict(cls, *args, **kwargs) -> dict[str, Any]: @@ -179,6 +188,9 @@ def _construct_tensorizer_args(self) -> "TensorizerArgs": "s3_access_key_id": self.s3_access_key_id, "s3_secret_access_key": self.s3_secret_access_key, "s3_endpoint": self.s3_endpoint, + "stream_kwargs": self.stream_kwargs, + "serialization_kwargs": self.serialization_kwargs, + "deserialization_kwargs": self.deserialization_kwargs, } return TensorizerArgs(**tensorizer_args) # type: ignore @@ -205,7 +217,7 @@ def open_stream(self, tensorizer_args: Optional["TensorizerArgs"] = None): tensorizer_args = self._construct_tensorizer_args() return open_stream(self.tensorizer_uri, - **tensorizer_args.stream_params) + **tensorizer_args.stream_kwargs) def load_with_tensorizer(tensorizer_config: TensorizerConfig, @@ -225,6 +237,9 @@ class TensorizerArgs: s3_access_key_id: Optional[str] = None s3_secret_access_key: Optional[str] = None s3_endpoint: Optional[str] = None + stream_kwargs: Optional[dict[str, Any]] = None + deserialization_kwargs: Optional[dict[str, Any]] = None + serialization_kwargs: Optional[dict[str, Any]] = None """ Args for the TensorizerAgent class. These are used to configure the behavior of the TensorDeserializer when loading tensors from a serialized model. @@ -243,11 +258,20 @@ class TensorizerArgs: inferred as vLLM models. verify_hash: If True, the hashes of each tensor will be verified against the hashes stored in the metadata. A `HashMismatchError` will be - raised if any of the hashes do not match. + raised if any of the hashes do not match. Passing parameters + to TensorizerSerializer and TensorDeserializer objects explicitly + in TensorizerArgs and TensorizerConfig is deprecated. This parameter + and others that are given to TensorDeserializer should now be + provided in the deserialization_kwargs dict. num_readers: Controls how many threads are allowed to read concurrently from the source file. Default is `None`, which will dynamically set - the number of readers based on the number of available + the number of readers based on the number of available. resources and model size. This greatly increases performance. + Passing parameters to TensorizerSerializer and TensorDeserializer + objects explicitly in TensorizerArgs and TensorizerConfig is + deprecated. This parameter and others that are given to + TensorDeserializer should now be provided in the + deserialization_kwargs dict. encryption_keyfile: File path to a binary file containing a binary key to use for decryption. `None` (the default) means no decryption. See the example script in @@ -258,6 +282,19 @@ class TensorizerArgs: be set via the S3_SECRET_ACCESS_KEY environment variable. s3_endpoint: The endpoint for the S3 bucket. Can also be set via the S3_ENDPOINT_URL environment variable. + stream_kwargs: Keyword arguments to pass to Tensorizer's + `stream_io.open_stream()` function, which the `TensorSerializer` and + `TensorDeserializer` objects use to work with files and streams. + See the Tensorizer page in vLLM's official documentation for more + information on available kwargs. + deserialization_kwargs: Additional keyword arguments to be passed + directly to Tensorizer's `TensorDeserializer`. See the Tensorizer + page in vLLM's official documentation for more information on + available kwargs. + serialization_kwargs: Additional keyword arguments to be passed + directly to Tensorizer's `TensorSerializer`. See the Tensorizer + page in vLLM's official documentation for more information on + available kwargs. """ def __post_init__(self): @@ -266,26 +303,29 @@ def __post_init__(self): self.s3_secret_access_key = (self.s3_secret_access_key or envs.S3_SECRET_ACCESS_KEY) self.s3_endpoint = self.s3_endpoint or envs.S3_ENDPOINT_URL - self.stream_params = { + + self.stream_kwargs = { "s3_access_key_id": self.s3_access_key_id, "s3_secret_access_key": self.s3_secret_access_key, "s3_endpoint": self.s3_endpoint, + **(self.stream_kwargs or {}) } - self.deserializer_params = { + self.deserialization_kwargs = { "verify_hash": self.verify_hash, "encryption": self.encryption_keyfile, - "num_readers": self.num_readers + "num_readers": self.num_readers, + **(self.deserialization_kwargs or {}) } if self.encryption_keyfile: with open_stream( self.encryption_keyfile, - **self.stream_params, + **self.stream_kwargs, ) as stream: key = stream.read() decryption_params = DecryptionParams.from_key(key) - self.deserializer_params['encryption'] = decryption_params + self.deserialization_kwargs['encryption'] = decryption_params @staticmethod def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: @@ -424,9 +464,9 @@ def deserialize(self): Deserialize the model using the TensorDeserializer. This method is specifically for vLLM models using tensorizer's plaid_mode. - The deserializer makes use of tensorizer_args.stream_params + The deserializer makes use of tensorizer_args.stream_kwargs to configure the behavior of the stream when loading tensors from a - serialized model. The deserializer_params are used to configure the + serialized model. The deserialization_kwargs are used to configure the behavior of the TensorDeserializer when loading tensors themselves. Documentation on these params can be found in TensorizerArgs @@ -435,14 +475,15 @@ def deserialize(self): """ before_mem = get_mem_usage() start = time.perf_counter() - with _read_stream( + with open_stream( self.tensorizer_config.tensorizer_uri, - **self.tensorizer_args.stream_params + mode="rb", + **self.tensorizer_args.stream_kwargs ) as stream, TensorDeserializer( stream, dtype=self.tensorizer_config.dtype, - device=f'cuda:{torch.cuda.current_device()}', - **self.tensorizer_args.deserializer_params) as deserializer: + device=torch.device("cuda", torch.cuda.current_device()), + **self.tensorizer_args.deserialization_kwargs) as deserializer: deserializer.load_into_module(self.model) end = time.perf_counter() @@ -472,9 +513,9 @@ def tensorizer_weights_iterator( "examples/others/tensorize_vllm_model.py example script " "for serializing vLLM models.") - deserializer_args = tensorizer_args.deserializer_params - stream_params = tensorizer_args.stream_params - stream = open_stream(tensorizer_args.tensorizer_uri, **stream_params) + deserializer_args = tensorizer_args.deserialization_kwargs + stream_kwargs = tensorizer_args.stream_kwargs + stream = open_stream(tensorizer_args.tensorizer_uri, **stream_kwargs) with TensorDeserializer(stream, **deserializer_args, device="cpu") as state: yield from state.items() @@ -495,8 +536,8 @@ def is_vllm_tensorized(tensorizer_config: "TensorizerConfig") -> bool: """ tensorizer_args = tensorizer_config._construct_tensorizer_args() deserializer = TensorDeserializer(open_stream( - tensorizer_args.tensorizer_uri, **tensorizer_args.stream_params), - **tensorizer_args.deserializer_params, + tensorizer_args.tensorizer_uri, **tensorizer_args.stream_kwargs), + **tensorizer_args.deserialization_kwargs, lazy_load=True) if tensorizer_config.vllm_tensorized: logger.warning( @@ -527,8 +568,8 @@ def serialize_vllm_model( from vllm.distributed import get_tensor_model_parallel_rank output_file = output_file % get_tensor_model_parallel_rank() - with _write_stream(output_file, **tensorizer_args.stream_params) as stream: - serializer = TensorSerializer(stream, encryption=encryption_params) + with open_stream(output_file, mode="wb+", **tensorizer_args.stream_kwargs) as stream: + serializer = TensorSerializer(stream, encryption=encryption_params, **tensorizer_config.serialization_kwargs) serializer.write_module(model) serializer.close() logger.info("Successfully serialized model to %s", str(output_file)) @@ -552,8 +593,9 @@ def tensorize_vllm_model(engine_args: EngineArgs, if generate_keyfile and (keyfile := tensorizer_config.encryption_keyfile) is not None: encryption_params = EncryptionParams.random() - with _write_stream( + with open_stream( keyfile, + mode="wb+", s3_access_key_id=tensorizer_config.s3_access_key_id, s3_secret_access_key=tensorizer_config.s3_secret_access_key, s3_endpoint=tensorizer_config.s3_endpoint, @@ -616,14 +658,14 @@ def tensorize_lora_adapter(lora_path: str, with open_stream(f"{tensorizer_config.lora_dir}/adapter_config.json", mode="wb+", - **tensorizer_args.stream_params) as f: + **tensorizer_args.stream_kwargs) as f: f.write(json.dumps(config).encode("utf-8")) lora_uri = (f"{tensorizer_config.lora_dir}" f"/adapter_model.tensors") with open_stream(lora_uri, mode="wb+", - **tensorizer_args.stream_params) as f: + **tensorizer_args.stream_kwargs) as f: serializer = TensorSerializer(f) serializer.write_state_dict(tensors) serializer.close() diff --git a/vllm/model_executor/model_loader/tensorizer_loader.py b/vllm/model_executor/model_loader/tensorizer_loader.py index 2afe2b59e2f9..0f493dc40fb7 100644 --- a/vllm/model_executor/model_loader/tensorizer_loader.py +++ b/vllm/model_executor/model_loader/tensorizer_loader.py @@ -19,6 +19,19 @@ logger = init_logger(__name__) +BLACKLISTED_TENSORIZER_ARGS = { + "device", # vLLM decides this + "dtype", # vLLM decides this + "mode", # Not meant to be configurable by the user +} + +def validate_config(config: dict): + for k, v in config.items(): + if v is not None and k in BLACKLISTED_TENSORIZER_ARGS: + raise ValueError( + f"{k} is not an allowed Tensorizer argument." + ) + class TensorizerLoader(BaseModelLoader): """Model loader using CoreWeave's tensorizer library.""" @@ -28,6 +41,7 @@ def __init__(self, load_config: LoadConfig): if isinstance(load_config.model_loader_extra_config, TensorizerConfig): self.tensorizer_config = load_config.model_loader_extra_config else: + validate_config(load_config.model_loader_extra_config) self.tensorizer_config = TensorizerConfig( **load_config.model_loader_extra_config)