diff --git a/examples/others/tensorize_vllm_model.py b/examples/others/tensorize_vllm_model.py index f6e68fc5c813..516b95a2cb00 100644 --- a/examples/others/tensorize_vllm_model.py +++ b/examples/others/tensorize_vllm_model.py @@ -3,6 +3,7 @@ import argparse import dataclasses import json +import logging import os import uuid @@ -14,10 +15,9 @@ TensorizerConfig, tensorize_lora_adapter, tensorize_vllm_model, - tensorizer_kwargs_arg + tensorizer_kwargs_arg, ) from vllm.utils import FlexibleArgumentParser -import logging logger = logging.getLogger() @@ -123,7 +123,7 @@ """ -def parse_args(): +def get_parser(): parser = FlexibleArgumentParser( description="An example script that can be used to serialize and " "deserialize vLLM models. These models " @@ -198,9 +198,17 @@ def parse_args(): deserialize_parser.add_argument( "--path-to-tensors", type=str, - required=True, + required=False, help="The local path or S3 URI to the model tensors to deserialize. ") + deserialize_parser.add_argument( + "--serialized-directory", + type=str, + required=False, + help="Directory with model artifacts for loading. Assumes a " + "model.tensors file exists therein. 
Can supersede " + "--path-to-tensors.") + deserialize_parser.add_argument( "--keyfile", type=str, @@ -218,7 +226,7 @@ def parse_args(): TensorizerArgs.add_cli_args(deserialize_parser) - return parser.parse_args() + return parser def merge_extra_config_with_tensorizer_config(extra_cfg: dict, cfg: TensorizerConfig): @@ -271,7 +279,8 @@ def deserialize(): if __name__ == '__main__': - args = parse_args() + parser = get_parser() + args = parser.parse_args() s3_access_key_id = (getattr(args, 's3_access_key_id', None) or os.environ.get("S3_ACCESS_KEY_ID", None)) @@ -300,6 +309,16 @@ def deserialize(): extra_config = json.loads(args.model_loader_extra_config) + tensorizer_dir = args.serialized_directory or extra_config.get("tensorizer_dir") + tensorizer_uri = args.path_to_tensors or extra_config.get("tensorizer_uri") + + if tensorizer_dir and tensorizer_uri: + parser.error("--serialized-directory and --path-to-tensors cannot both be provided") + + if not tensorizer_dir and not tensorizer_uri: + parser.error("Either --serialized-directory or --path-to-tensors must be provided") + + if args.command == "serialize": eng_args_dict = {f.name: getattr(args, f.name) for f in dataclasses.fields(EngineArgs)} @@ -308,7 +327,7 @@ def deserialize(): argparse.Namespace(**eng_args_dict) ) - input_dir = args.serialized_directory.rstrip('/') + input_dir = tensorizer_dir.rstrip('/') suffix = args.suffix if args.suffix else uuid.uuid4().hex base_path = f"{input_dir}/vllm/{model_ref}/{suffix}" if engine_args.tensor_parallel_size > 1: @@ -339,4 +358,4 @@ def deserialize(): ) merge_extra_config_with_tensorizer_config(extra_config, tensorizer_config) - deserialize() \ No newline at end of file + deserialize() diff --git a/tests/entrypoints/openai/test_tensorizer_entrypoint.py b/tests/entrypoints/openai/test_tensorizer_entrypoint.py index 61fc92316f9c..c26f84930ec6 100644 --- a/tests/entrypoints/openai/test_tensorizer_entrypoint.py +++ b/tests/entrypoints/openai/test_tensorizer_entrypoint.py @@ 
-1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 import gc import json +import os import tempfile import openai @@ -57,6 +58,12 @@ def tensorize_model_and_lora(tmp_dir, model_uri): @pytest.fixture(scope="module") def server(model_uri, tensorize_model_and_lora): + # In this case, model_uri is a directory with a model.tensors + # file and all necessary model artifacts, particularly a + # HF `config.json` file. In this case, Tensorizer can infer the + # `TensorizerConfig` so --model-loader-extra-config can be completely + # omitted. + model_loader_extra_config = { "tensorizer_uri": model_uri, "stream_kwargs": { @@ -70,11 +77,13 @@ def server(model_uri, tensorize_model_and_lora): ## Start OpenAI API server args = [ - "--load-format", "tensorizer", "--model-loader-extra-config", + "--load-format", "tensorizer", "--served-model-name", MODEL_NAME, + "--model-loader-extra-config", json.dumps(model_loader_extra_config), "--enable-lora" ] - with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: + model_dir = os.path.dirname(model_uri) + with RemoteOpenAIServer(model_dir, args) as remote_server: yield remote_server diff --git a/tests/lora/test_llama_tp.py b/tests/lora/test_llama_tp.py index 13d7b4e7b7aa..79ba9b3a0f00 100644 --- a/tests/lora/test_llama_tp.py +++ b/tests/lora/test_llama_tp.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 import subprocess import sys -import json from typing import Union import pytest @@ -238,7 +237,7 @@ def test_tp2_serialize_and_deserialize_lora(tmp_path, sql_lora_files, tensor_parallel_size=2, max_loras=2) - tensorizer_config_dict = tensorizer_config.to_dict() + tensorizer_config_dict = tensorizer_config.to_serializable() print("lora adapter created") assert do_sample(loaded_vllm_model, diff --git a/tests/tensorizer_loader/conftest.py b/tests/tensorizer_loader/conftest.py index 3ec0d1dbfe3e..d3b73ffd77fc 100644 --- a/tests/tensorizer_loader/conftest.py +++ b/tests/tensorizer_loader/conftest.py @@ -2,38 +2,60 @@ from typing 
import Callable import pytest -import os -from vllm import LLM +from vllm import LLM, EngineArgs from vllm.distributed import cleanup_dist_env_and_memory +from vllm.model_executor.model_loader import tensorizer as tensorizer_mod from vllm.model_executor.model_loader.tensorizer import TensorizerConfig -from vllm.utils import get_distributed_init_method, get_open_port, get_ip - -from vllm.v1.engine.core import EngineCore +from vllm.utils import get_distributed_init_method, get_ip, get_open_port from vllm.v1.executor.abstract import UniProcExecutor -from vllm.v1.kv_cache_interface import KVCacheConfig -from vllm.v1.outputs import ModelRunnerOutput -from vllm.v1.worker.gpu_worker import Worker from vllm.worker.worker_base import WorkerWrapperBase +MODEL_REF = "facebook/opt-125m" + + +@pytest.fixture() +def model_ref(): + return MODEL_REF + @pytest.fixture(autouse=True) def allow_insecure_serialization(monkeypatch): monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1") + @pytest.fixture(autouse=True) def cleanup(): cleanup_dist_env_and_memory(shutdown_ray=True) +@pytest.fixture() +def just_serialize_model_tensors(model_ref, monkeypatch, tmp_path): + + def noop(*args, **kwargs): + return None + + args = EngineArgs(model=model_ref) + tc = TensorizerConfig(tensorizer_uri=f"{tmp_path}/model.tensors") + + monkeypatch.setattr(tensorizer_mod, "serialize_extra_artifacts", noop) + + tensorizer_mod.tensorize_vllm_model(args, tc) + yield tmp_path + + @pytest.fixture(autouse=True) def tensorizer_config(): config = TensorizerConfig(tensorizer_uri="vllm") return config -def assert_from_collective_rpc(engine: LLM, - closure: Callable, +@pytest.fixture() +def model_path(model_ref, tmp_path): + yield tmp_path / model_ref / "model.tensors" + + +def assert_from_collective_rpc(engine: LLM, closure: Callable, closure_kwargs: dict): res = engine.collective_rpc(method=closure, kwargs=closure_kwargs) return all(res) @@ -48,18 +70,14 @@ class DummyExecutor(UniProcExecutor): def 
_init_executor(self) -> None: """Initialize the worker and load the model. """ - self.driver_worker = WorkerWrapperBase( - vllm_config=self.vllm_config, - rpc_rank=0 - ) + self.driver_worker = WorkerWrapperBase(vllm_config=self.vllm_config, + rpc_rank=0) distributed_init_method = get_distributed_init_method( - get_ip(), get_open_port() - ) + get_ip(), get_open_port()) local_rank = 0 # set local rank as the device index if specified device_info = self.vllm_config.device_config.device.__str__().split( - ":" - ) + ":") if len(device_info) > 1: local_rank = int(device_info[1]) rank = 0 @@ -71,7 +89,7 @@ def _init_executor(self) -> None: distributed_init_method=distributed_init_method, is_driver_worker=is_driver_worker, ) - self.collective_rpc("init_worker", args=([kwargs],)) + self.collective_rpc("init_worker", args=([kwargs], )) self.collective_rpc("init_device") @property @@ -80,4 +98,4 @@ def max_concurrent_batches(self) -> int: def shutdown(self): if hasattr(self, 'thread_pool'): - self.thread_pool.shutdown(wait=False) \ No newline at end of file + self.thread_pool.shutdown(wait=False) diff --git a/tests/tensorizer_loader/test_tensorizer.py b/tests/tensorizer_loader/test_tensorizer.py index bf572b656af0..f89ca5a17ad3 100644 --- a/tests/tensorizer_loader/test_tensorizer.py +++ b/tests/tensorizer_loader/test_tensorizer.py @@ -1,46 +1,34 @@ # SPDX-License-Identifier: Apache-2.0 import asyncio -import contextlib -import copy -import functools import gc import json import os import pathlib import subprocess import sys -from concurrent.futures import ThreadPoolExecutor, Future -from typing import Callable, Type, Union, Any +from typing import Any, Type from unittest.mock import MagicMock, patch -from dataclasses import dataclass import pytest import torch -from vllm import SamplingParams, LLM +import vllm.model_executor.model_loader.tensorizer +from vllm import LLM, SamplingParams from vllm.engine.arg_utils import EngineArgs - # yapf: disable -from 
vllm.model_executor.model_loader.tensorizer import ( - TensorizerConfig, - TensorSerializer, - is_vllm_tensorized, - load_with_tensorizer, - open_stream, - tensorize_vllm_model -) +from vllm.model_executor.model_loader.tensorizer import (TensorizerConfig, + TensorSerializer, + is_vllm_tensorized, + load_with_tensorizer, + open_stream, + tensorize_vllm_model) from vllm.model_executor.model_loader.tensorizer_loader import ( BLACKLISTED_TENSORIZER_ARGS) -import vllm.model_executor.model_loader.tensorizer -from .conftest import DummyExecutor # yapf: enable -from vllm.utils import ( - PlaceholderModule, get_distributed_init_method, - get_open_port, get_ip, -) -from .conftest import assert_from_collective_rpc +from vllm.utils import PlaceholderModule -from ..utils import VLLM_PATH +from ..utils import VLLM_PATH, RemoteOpenAIServer +from .conftest import DummyExecutor, assert_from_collective_rpc try: import tensorizer @@ -67,12 +55,9 @@ class TensorizerCaughtError(Exception): # Create a sampling params object. 
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, seed=0) -model_ref = "facebook/opt-125m" -tensorize_model_for_testing_script = os.path.join( - os.path.dirname(__file__), "tensorize_vllm_model_for_testing.py") - -def patch_init_and_catch_error(self, obj, method_name, expected_error: Type[Exception]): +def patch_init_and_catch_error(self, obj, method_name, + expected_error: Type[Exception]): original = getattr(obj, method_name, None) if original is None: raise ValueError("Method '{}' not found.".format(method_name)) @@ -93,10 +78,15 @@ def assert_specific_tensorizer_error_is_raised( obj: Any, method_name: str, expected_error: Type[Exception], - ): +): with pytest.raises(TensorizerCaughtError): executor.collective_rpc(patch_init_and_catch_error, - args=(obj, method_name, expected_error,)) + args=( + obj, + method_name, + expected_error, + )) + def is_curl_installed(): try: @@ -149,11 +139,10 @@ def test_can_deserialize_s3(vllm_runner): @pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed") def test_deserialized_encrypted_vllm_model_has_same_outputs( - vllm_runner, tmp_path): + model_ref, vllm_runner, tmp_path, model_path): args = EngineArgs(model=model_ref) with vllm_runner(model_ref) as vllm_model: - model_path = tmp_path / (model_ref + ".tensors") - key_path = tmp_path / (model_ref + ".key") + key_path = tmp_path / model_ref / "model.key" write_keyfile(key_path) outputs = vllm_model.generate(prompts, sampling_params) @@ -179,9 +168,9 @@ def test_deserialized_encrypted_vllm_model_has_same_outputs( def test_deserialized_hf_model_has_same_outputs(hf_runner, vllm_runner, - tmp_path): + tmp_path, model_ref, + model_path): with hf_runner(model_ref) as hf_model: - model_path = tmp_path / (model_ref + ".tensors") max_tokens = 50 outputs = hf_model.generate_greedy(prompts, max_tokens=max_tokens) with open_stream(model_path, "wb+") as stream: @@ -191,7 +180,7 @@ def test_deserialized_hf_model_has_same_outputs(hf_runner, vllm_runner, with 
vllm_runner(model_ref, load_format="tensorizer", model_loader_extra_config=TensorizerConfig( - tensorizer_uri=model_path, + tensorizer_uri=str(model_path), num_readers=1, )) as loaded_hf_model: deserialized_outputs = loaded_hf_model.generate_greedy( @@ -200,7 +189,7 @@ def test_deserialized_hf_model_has_same_outputs(hf_runner, vllm_runner, assert outputs == deserialized_outputs -def test_load_without_tensorizer_load_format(vllm_runner, capfd): +def test_load_without_tensorizer_load_format(vllm_runner, capfd, model_ref): model = None try: model = vllm_runner( @@ -218,7 +207,8 @@ def test_load_without_tensorizer_load_format(vllm_runner, capfd): torch.cuda.empty_cache() -def test_raise_value_error_on_invalid_load_format(vllm_runner, capfd): +def test_raise_value_error_on_invalid_load_format(vllm_runner, capfd, + model_ref): model = None try: model = vllm_runner( @@ -276,7 +266,7 @@ def test_deserialized_encrypted_vllm_model_with_tp_has_same_outputs( outputs = base_model.generate(prompts, sampling_params) # load model with two shards and serialize with encryption - model_path = str(tmp_path / (model_ref + "-%02d.tensors")) + model_path = str(tmp_path / model_ref / "model-%02d.tensors") key_path = tmp_path / (model_ref + ".key") tensorizer_config = TensorizerConfig( @@ -310,11 +300,10 @@ def test_deserialized_encrypted_vllm_model_with_tp_has_same_outputs( @pytest.mark.flaky(reruns=3) -def test_vllm_tensorized_model_has_same_outputs(vllm_runner, tmp_path): +def test_vllm_tensorized_model_has_same_outputs(model_ref, vllm_runner, + tmp_path, model_path): gc.collect() torch.cuda.empty_cache() - model_ref = "facebook/opt-125m" - model_path = tmp_path / (model_ref + ".tensors") config = TensorizerConfig(tensorizer_uri=str(model_path)) args = EngineArgs(model=model_ref) @@ -333,6 +322,30 @@ def test_vllm_tensorized_model_has_same_outputs(vllm_runner, tmp_path): assert outputs == deserialized_outputs + +def test_load_with_just_model_tensors(just_serialize_model_tensors, 
model_ref): + # For backwards compatibility, ensure Tensorizer can be still be loaded + # for inference by passing the model reference name, not a local/S3 dir, + # and the location of the model tensors + + model_dir = just_serialize_model_tensors + + extra_config = {"tensorizer_uri": f"{model_dir}/model.tensors"} + + ## Start OpenAI API server + args = [ + "--load-format", + "tensorizer", + "--model-loader-extra-config", + json.dumps(extra_config), + ] + + with RemoteOpenAIServer(model_ref, args): + # This test only concerns itself with being able to load the model + # and successfully initialize the server + pass + + def test_assert_serialization_kwargs_passed_to_tensor_serializer(tmp_path): serialization_params = { @@ -342,10 +355,7 @@ def test_assert_serialization_kwargs_passed_to_tensor_serializer(tmp_path): model_path = tmp_path / (model_ref + ".tensors") config = TensorizerConfig(tensorizer_uri=str(model_path), serialization_kwargs=serialization_params) - llm = LLM( - model=model_ref, - ) - + llm = LLM(model=model_ref, ) def serialization_test(self, *args, **kwargs): # This is performed in the ephemeral worker process, so monkey-patching @@ -365,23 +375,21 @@ def tensorizer_serializer_wrapper(self, *args, **kwargs): tensorizer.serialization.TensorSerializer.__init__ = tensorizer_serializer_wrapper tensorizer_config = TensorizerConfig(**kwargs["tensorizer_config"]) - self.save_tensorized_model( - tensorizer_config=tensorizer_config, ) + self.save_tensorized_model(tensorizer_config=tensorizer_config, ) return to_compare | original_dict == to_compare - kwargs = { - "tensorizer_config": config.to_dict() - } + kwargs = {"tensorizer_config": config.to_dict()} assert assert_from_collective_rpc(llm, serialization_test, kwargs) -def test_assert_deserialization_kwargs_passed_to_tensor_deserializer(tmp_path, capfd): +def test_assert_deserialization_kwargs_passed_to_tensor_deserializer( + tmp_path, capfd): expected_error = TypeError deserialization_kwargs = { - 
"num_readers": "bar", # illegal value + "num_readers": "bar", # illegal value } serialization_params = { @@ -403,17 +411,20 @@ def test_assert_deserialization_kwargs_passed_to_tensor_deserializer(tmp_path, c engine_args = EngineArgs( model="facebook/opt-125m", - load_format = "tensorizer", - model_loader_extra_config=loader_tc.to_dict(),) + load_format="tensorizer", + model_loader_extra_config=loader_tc.to_dict(), + ) vllm_config = engine_args.create_engine_config() executor = DummyExecutor(vllm_config) - assert_specific_tensorizer_error_is_raised(executor, - tensorizer.serialization.TensorDeserializer, - "__init__", - TypeError, - ) + assert_specific_tensorizer_error_is_raised( + executor, + tensorizer.serialization.TensorDeserializer, + "__init__", + TypeError, + ) + def test_assert_stream_kwargs_passed_to_tensor_deserializer(tmp_path, capfd): @@ -433,10 +444,7 @@ def test_assert_stream_kwargs_passed_to_tensor_deserializer(tmp_path, capfd): args = EngineArgs(model=model_ref) tensorize_vllm_model(args, config) - stream_kwargs = { - "mode": "foo" - } - + stream_kwargs = {"mode": "foo"} loader_tc = TensorizerConfig( tensorizer_uri=str(model_path), @@ -446,8 +454,9 @@ def test_assert_stream_kwargs_passed_to_tensor_deserializer(tmp_path, capfd): engine_args = EngineArgs( model="facebook/opt-125m", - load_format = "tensorizer", - model_loader_extra_config=loader_tc.to_dict(),) + load_format="tensorizer", + model_loader_extra_config=loader_tc.to_dict(), + ) vllm_config = engine_args.create_engine_config() executor = DummyExecutor(vllm_config) @@ -459,6 +468,7 @@ def test_assert_stream_kwargs_passed_to_tensor_deserializer(tmp_path, capfd): ValueError, ) + @pytest.mark.asyncio async def test_serialize_and_serve_entrypoints(tmp_path): model_ref = "facebook/opt-125m" @@ -498,14 +508,8 @@ async def test_serialize_and_serve_entrypoints(tmp_path): } cmd = [ - "-m", - "vllm.entrypoints.cli.main", - "serve", - "--host", - "localhost", - "--load-format", - "tensorizer", - 
model_ref, + "-m", "vllm.entrypoints.cli.main", "serve", "--host", "localhost", + "--load-format", "tensorizer", model_ref, "--model-loader-extra-config", json.dumps(model_loader_extra_config, indent=2) ] @@ -517,7 +521,6 @@ async def test_serialize_and_serve_entrypoints(tmp_path): stderr=asyncio.subprocess.STDOUT, ) - try: async with asyncio.timeout(180): await proc.stdout.readuntil(b"Application startup complete.") @@ -527,6 +530,7 @@ async def test_serialize_and_serve_entrypoints(tmp_path): proc.terminate() await proc.communicate() + @pytest.mark.parametrize("illegal_value", BLACKLISTED_TENSORIZER_ARGS) def test_blacklisted_parameter_for_loading(tmp_path, vllm_runner, capfd, illegal_value): @@ -543,10 +547,7 @@ def test_blacklisted_parameter_for_loading(tmp_path, vllm_runner, capfd, args = EngineArgs(model=model_ref) tensorize_vllm_model(args, config) - loader_tc = { - "tensorizer_uri": str(model_path), - illegal_value: "foo" - } + loader_tc = {"tensorizer_uri": str(model_path), illegal_value: "foo"} try: vllm_runner( @@ -559,5 +560,3 @@ def test_blacklisted_parameter_for_loading(tmp_path, vllm_runner, capfd, combined_output = out + err assert (f"ValueError: {illegal_value} is not an allowed " f"Tensorizer argument.") in combined_output - - diff --git a/vllm/config.py b/vllm/config.py index db35c848b33a..c68f22264f8a 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -636,7 +636,7 @@ def maybe_pull_model_tokenizer_for_s3(self, model: str, # If tokenizer is same as model, download to same directory if model == tokenizer: s3_model.pull_files( - model, ignore_pattern=["*.pt", "*.safetensors", "*.bin"]) + model, ignore_pattern=["*.pt", "*.safetensors", "*.bin", "*.tensors"]) self.tokenizer = s3_model.dir return @@ -644,7 +644,7 @@ def maybe_pull_model_tokenizer_for_s3(self, model: str, if is_s3(tokenizer): s3_tokenizer = S3Model() s3_tokenizer.pull_files( - model, ignore_pattern=["*.pt", "*.safetensors", "*.bin"]) + model, ignore_pattern=["*.pt", "*.safetensors", 
"*.bin", "*.tensors"]) self.tokenizer = s3_tokenizer.dir def _init_multimodal_config(self) -> Optional["MultiModalConfig"]: diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index a36bbf003ea1..558f973b1078 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -54,7 +54,8 @@ def parse_type(return_type: Callable[[str], T]) -> Callable[[str], T]: def _parse_type(val: str) -> T: try: - if return_type is json.loads and not re.match(r"(?s)^\s*{.*}\s*$", val): + if return_type is json.loads and not re.match( + r"(?s)^\s*{.*}\s*$", val): return cast(T, nullable_kvs(val)) return return_type(val) except ValueError as e: @@ -913,11 +914,33 @@ def create_model_config(self) -> ModelConfig: model_impl=self.model_impl, ) + def no_valid_tensorizer_args_in_model_loader_extra_config(self) -> bool: + + if self.model_loader_extra_config: + for allowed_to_pass in ["tensorizer_uri", "tensorizer_dir"]: + try: + logger.info("Got %s", self.model_loader_extra_config) + self.model_loader_extra_config[allowed_to_pass] + return False + except KeyError: + pass + return True + def create_load_config(self) -> LoadConfig: if self.quantization == "bitsandbytes": self.load_format = "bitsandbytes" + if self.load_format == "tensorizer" and self.no_valid_tensorizer_args_in_model_loader_extra_config( + ): + logger.info("Inferring Tensorizer args from %s", self.model) + self.model_loader_extra_config = {"tensorizer_dir": self.model} + else: + logger.info( + "Using Tensorizer args from --model-loader-extra-config. 
" + "Note that you can now simply pass the S3 directory in the " + "model tag instead of providing the JSON string.") + return LoadConfig( load_format=self.load_format, download_dir=self.download_dir, diff --git a/vllm/lora/models.py b/vllm/lora/models.py index 46adeebd408b..9bedc39327ef 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -244,9 +244,10 @@ def check_unexpected_modules(modules: dict): lora_tensor_path = os.path.join(tensorizer_config.tensorizer_dir, "adapter_model.tensors") tensorizer_args = tensorizer_config._construct_tensorizer_args() - tensors = TensorDeserializer(lora_tensor_path, - dtype=tensorizer_config.dtype, - **tensorizer_args.deserialization_kwargs) + tensors = TensorDeserializer( + lora_tensor_path, + dtype=tensorizer_config._model_cls_dtype, + **tensorizer_args.deserialization_kwargs) check_unexpected_modules(tensors) elif os.path.isfile(lora_tensor_path): diff --git a/vllm/model_executor/model_loader/tensorizer.py b/vllm/model_executor/model_loader/tensorizer.py index 40b3387c7e02..2ab2658cd353 100644 --- a/vllm/model_executor/model_loader/tensorizer.py +++ b/vllm/model_executor/model_loader/tensorizer.py @@ -4,18 +4,18 @@ import contextlib import contextvars import dataclasses -import io import json import os +import tempfile import threading import time -from collections.abc import Generator -from dataclasses import dataclass -from functools import partial -from typing import Any, BinaryIO, Optional, Union +from collections.abc import Generator, MutableMapping +from dataclasses import asdict, dataclass, field, fields +from typing import Any, Optional, Union import regex as re import torch +from huggingface_hub import snapshot_download from torch import nn from torch.utils._python_dispatch import TorchDispatchMode from transformers import PretrainedConfig @@ -54,16 +54,22 @@ logger = init_logger(__name__) + +def is_valid_deserialization_uri(uri: str) -> bool: + scheme = uri.lower().split("://")[0] + return scheme in {"s3", 
"http", "https"} or os.path.exists(uri) + + def tensorizer_kwargs_arg(value): loaded = json.loads(value) if not isinstance(loaded, dict): raise argparse.ArgumentTypeError( f"Not deserializable to dict: {value}. serialization_kwargs and " f"deserialization_kwargs must be " - f"deserializable from a JSON string to a dictionary. " - ) + f"deserializable from a JSON string to a dictionary. ") return loaded + class MetaTensorMode(TorchDispatchMode): def __torch_dispatch__(self, func, types, args=(), kwargs=None): @@ -135,64 +141,142 @@ def wrapper(*args, **kwargs): @dataclass -class TensorizerConfig: - tensorizer_uri: Union[str, None] = None - vllm_tensorized: Optional[bool] = False - verify_hash: Optional[bool] = False +class TensorizerConfig(MutableMapping): + tensorizer_uri: Optional[str] = None + tensorizer_dir: Optional[str] = None + vllm_tensorized: Optional[bool] = None + verify_hash: Optional[bool] = None num_readers: Optional[int] = None encryption_keyfile: Optional[str] = None s3_access_key_id: Optional[str] = None s3_secret_access_key: Optional[str] = None s3_endpoint: Optional[str] = None - model_class: Optional[type[torch.nn.Module]] = None - hf_config: Optional[PretrainedConfig] = None - dtype: Optional[Union[str, torch.dtype]] = None lora_dir: Optional[str] = None stream_kwargs: Optional[dict[str, Any]] = None serialization_kwargs: Optional[dict[str, Any]] = None deserialization_kwargs: Optional[dict[str, Any]] = None - _is_sharded: bool = False + _extra_serialization_attrs: Optional[dict[str, Any]] = field(init=False, + default=None) + _model_cls: Optional[type[torch.nn.Module]] = field(init=False, + default=None) + _hf_config: Optional[PretrainedConfig] = field(init=False, default=None) + _model_cls_dtype: Optional[Union[str, torch.dtype]] = field(init=False, + default=None) + _is_sharded: bool = field(init=False, default=None) + """ + Args for the TensorizerConfig class. 
These are used to configure the + behavior of model serialization and deserialization using Tensorizer. + + Args: + tensorizer_uri: Path to serialized model tensors. Can be a local file + path or a S3 URI. This is a required field unless lora_dir is + provided and the config is meant to be used for the + `tensorize_lora_adapter` function. Unless a `tensorizer_dir` or + `lora_dir` is passed to this object's initializer, this is a required + argument. + tensorizer_dir: Path to a directory containing serialized model tensors, + and all other potential model artifacts to load the model, such as + configs and tokenizer files. Can be passed instead of `tensorizer_uri` + where the `model.tensors` file will be assumed to be in this + directory. + vllm_tensorized: If True, indicates that the serialized model is a + vLLM model. This is used to determine the behavior of the + TensorDeserializer when loading tensors from a serialized model. + It is far faster to deserialize a vLLM model as it utilizes + tensorizer's optimized GPU loading. Note that this is now + deprecated, as serialized vLLM models are now automatically + inferred as vLLM models. + verify_hash: If True, the hashes of each tensor will be verified against + the hashes stored in the metadata. A `HashMismatchError` will be + raised if any of the hashes do not match. + num_readers: Controls how many threads are allowed to read concurrently + from the source file. Default is `None`, which will dynamically set + the number of readers based on the number of available + resources and model size. This greatly increases performance. + encryption_keyfile: File path to a binary file containing a + binary key to use for decryption. `None` (the default) means + no decryption. See the example script in + examples/others/tensorize_vllm_model.py. + s3_access_key_id: The access key for the S3 bucket. Can also be set via + the S3_ACCESS_KEY_ID environment variable. + s3_secret_access_key: The secret access key for the S3 bucket. 
Can also + be set via the S3_SECRET_ACCESS_KEY environment variable. + s3_endpoint: The endpoint for the S3 bucket. Can also be set via the + S3_ENDPOINT_URL environment variable. + lora_dir: Path to a directory containing LoRA adapter artifacts for + serialization or deserialization. When serializing LoRA adapters + this is the only necessary parameter to pass to this object's + initializer. + """ def __post_init__(self): # check if the configuration is for a sharded vLLM model self._is_sharded = isinstance(self.tensorizer_uri, str) \ and re.search(r'%0\dd', self.tensorizer_uri) is not None - if not self.tensorizer_uri and not self.lora_dir: - raise ValueError("tensorizer_uri must be provided.") - if not self.tensorizer_uri and self.lora_dir: - self.tensorizer_uri = f"{self.lora_dir}/adapter_model.tensors" - assert self.tensorizer_uri is not None, ("tensorizer_uri must be " - "provided.") - self.tensorizer_dir = os.path.dirname(self.tensorizer_uri) - self.lora_dir = self.tensorizer_dir + + if self.tensorizer_dir and self.tensorizer_uri: + raise ValueError( + "Either tensorizer_dir or tensorizer_uri must be provided, " + "not both.") + if self.tensorizer_dir and self.lora_dir: + raise ValueError( + "Only one of tensorizer_dir or lora_dir may be specified. " + "Use lora_dir exclusively when serializing LoRA adapters, " + "and tensorizer_dir otherwise.") + if not self.tensorizer_uri: + if self.lora_dir: + self.tensorizer_uri = f"{self.lora_dir}/adapter_model.tensors" + elif self.tensorizer_dir: + self.tensorizer_uri = f"{self.tensorizer_dir}/model.tensors" + else: + raise ValueError("Unable to resolve tensorizer_uri. 
" + "A valid tensorizer_uri or tensorizer_dir " + "must be provided for deserialization, and a " + "valid tensorizer_uri, tensorizer_uri, or " + "lora_dir for serialization.") + else: + self.tensorizer_dir = os.path.dirname(self.tensorizer_uri) + if not self.serialization_kwargs: self.serialization_kwargs = {} if not self.deserialization_kwargs: self.deserialization_kwargs = {} - @classmethod - def as_dict(cls, *args, **kwargs) -> dict[str, Any]: - cfg = TensorizerConfig(*args, **kwargs) - return dataclasses.asdict(cfg) + def to_serializable(self) -> dict[str, Any]: + # Due to TensorizerConfig needing to be msgpack-serializable, it needs + # support for morphing back and forth between itself and its dict + # representation + + # TensorizerConfig's representation as a dictionary is meant to be + # linked to TensorizerConfig in such a way that the following is + # technically initializable: + # TensorizerConfig(**my_tensorizer_cfg.to_serializable()) + + # This means the dict must not retain non-initializable parameters + # and post-init attribute states + + # Also don't want to retain private and unset parameters, so only retain + # not None values and public attributes + + raw_tc_dict = asdict(self) + blacklisted = [] - def to_dict(self) -> dict[str, Any]: - return dataclasses.asdict(self) + if "tensorizer_uri" in raw_tc_dict and "tensorizer_dir" in raw_tc_dict: + blacklisted.append("tensorizer_dir") + + if "tensorizer_dir" in raw_tc_dict and "lora_dir" in raw_tc_dict: + blacklisted.append("lora_dir") + + tc_dict = {} + for k, v in raw_tc_dict.items(): + if (k not in blacklisted and k not in tc_dict + and not k.startswith("_") and v is not None): + tc_dict[k] = v + + return tc_dict def _construct_tensorizer_args(self) -> "TensorizerArgs": - tensorizer_args = { - "tensorizer_uri": self.tensorizer_uri, - "vllm_tensorized": self.vllm_tensorized, - "verify_hash": self.verify_hash, - "num_readers": self.num_readers, - "encryption_keyfile": self.encryption_keyfile, - 
"s3_access_key_id": self.s3_access_key_id, - "s3_secret_access_key": self.s3_secret_access_key, - "s3_endpoint": self.s3_endpoint, - "stream_kwargs": self.stream_kwargs, - "serialization_kwargs": self.serialization_kwargs, - "deserialization_kwargs": self.deserialization_kwargs, - } - return TensorizerArgs(**tensorizer_args) # type: ignore + return TensorizerArgs(self) # type: ignore def verify_with_parallel_config( self, @@ -219,6 +303,21 @@ def open_stream(self, tensorizer_args: Optional["TensorizerArgs"] = None): return open_stream(self.tensorizer_uri, **tensorizer_args.stream_kwargs) + def __len__(self): + return len(fields(self)) + + def __iter__(self): + return (f.name for f in fields(self)) + + def __getitem__(self, item: str) -> Any: + return getattr(self, item) + + def __setitem__(self, key: str, value: Any) -> None: + setattr(self, key, value) + + def __delitem__(self, key, /): + delattr(self, key) + def load_with_tensorizer(tensorizer_config: TensorizerConfig, **extra_kwargs) -> nn.Module: @@ -226,101 +325,37 @@ def load_with_tensorizer(tensorizer_config: TensorizerConfig, return tensorizer.deserialize() -@dataclass class TensorizerArgs: - tensorizer_uri: Union[io.BufferedIOBase, io.RawIOBase, BinaryIO, str, - bytes, os.PathLike, int] - vllm_tensorized: Optional[bool] = False - verify_hash: Optional[bool] = False - num_readers: Optional[int] = None - encryption_keyfile: Optional[str] = None - s3_access_key_id: Optional[str] = None - s3_secret_access_key: Optional[str] = None - s3_endpoint: Optional[str] = None - stream_kwargs: Optional[dict[str, Any]] = None - deserialization_kwargs: Optional[dict[str, Any]] = None - serialization_kwargs: Optional[dict[str, Any]] = None - """ - Args for the TensorizerAgent class. These are used to configure the behavior - of the TensorDeserializer when loading tensors from a serialized model. - - Args: - tensorizer_uri: Path to serialized model tensors. Can be a local file - path or a S3 URI. 
This is a required field unless lora_dir is - provided and the config is meant to be used for the - `tensorize_lora_adapter` function. - vllm_tensorized: If True, indicates that the serialized model is a - vLLM model. This is used to determine the behavior of the - TensorDeserializer when loading tensors from a serialized model. - It is far faster to deserialize a vLLM model as it utilizes - tensorizer's optimized GPU loading. Note that this is now - deprecated, as serialized vLLM models are now automatically - inferred as vLLM models. - verify_hash: If True, the hashes of each tensor will be verified against - the hashes stored in the metadata. A `HashMismatchError` will be - raised if any of the hashes do not match. Passing parameters - to TensorizerSerializer and TensorDeserializer objects explicitly - in TensorizerArgs and TensorizerConfig is deprecated. This parameter - and others that are given to TensorDeserializer should now be - provided in the deserialization_kwargs dict. - num_readers: Controls how many threads are allowed to read concurrently - from the source file. Default is `None`, which will dynamically set - the number of readers based on the number of available. - resources and model size. This greatly increases performance. - Passing parameters to TensorizerSerializer and TensorDeserializer - objects explicitly in TensorizerArgs and TensorizerConfig is - deprecated. This parameter and others that are given to - TensorDeserializer should now be provided in the - deserialization_kwargs dict. - encryption_keyfile: File path to a binary file containing a - binary key to use for decryption. `None` (the default) means - no decryption. See the example script in - examples/others/tensorize_vllm_model.py. - s3_access_key_id: The access key for the S3 bucket. Can also be set via - the S3_ACCESS_KEY_ID environment variable. - s3_secret_access_key: The secret access key for the S3 bucket. Can also - be set via the S3_SECRET_ACCESS_KEY environment variable. 
- s3_endpoint: The endpoint for the S3 bucket. Can also be set via the - S3_ENDPOINT_URL environment variable. - stream_kwargs: Keyword arguments to pass to Tensorizer's - `stream_io.open_stream()` function, which the `TensorSerializer` and - `TensorDeserializer` objects use to work with files and streams. - See the Tensorizer page in vLLM's official documentation for more - information on available kwargs. - deserialization_kwargs: Additional keyword arguments to be passed - directly to Tensorizer's `TensorDeserializer`. See the Tensorizer - page in vLLM's official documentation for more information on - available kwargs. - serialization_kwargs: Additional keyword arguments to be passed - directly to Tensorizer's `TensorSerializer`. See the Tensorizer - page in vLLM's official documentation for more information on - available kwargs. - """ - def __post_init__(self): - self.file_obj = self.tensorizer_uri - self.s3_access_key_id = self.s3_access_key_id or envs.S3_ACCESS_KEY_ID - self.s3_secret_access_key = (self.s3_secret_access_key + def __init__(self, tensorizer_config: TensorizerConfig): + for k, v in tensorizer_config.items(): + setattr(self, k, v) + self.file_obj = tensorizer_config.tensorizer_uri + self.s3_access_key_id = tensorizer_config.s3_access_key_id or envs.S3_ACCESS_KEY_ID + self.s3_secret_access_key = (tensorizer_config.s3_secret_access_key or envs.S3_SECRET_ACCESS_KEY) - self.s3_endpoint = self.s3_endpoint or envs.S3_ENDPOINT_URL + self.s3_endpoint = tensorizer_config.s3_endpoint or envs.S3_ENDPOINT_URL self.stream_kwargs = { - "s3_access_key_id": self.s3_access_key_id, - "s3_secret_access_key": self.s3_secret_access_key, - "s3_endpoint": self.s3_endpoint, - **(self.stream_kwargs or {}) + "s3_access_key_id": tensorizer_config.s3_access_key_id, + "s3_secret_access_key": tensorizer_config.s3_secret_access_key, + "s3_endpoint": tensorizer_config.s3_endpoint, + **(tensorizer_config.stream_kwargs or {}) } self.deserialization_kwargs = { - "verify_hash": 
self.verify_hash, - "encryption": self.encryption_keyfile, - "num_readers": self.num_readers, - **(self.deserialization_kwargs or {}) + "verify_hash": + tensorizer_config.verify_hash, + "encryption": + tensorizer_config.encryption_keyfile, + "num_readers": + tensorizer_config.num_readers, + **(tensorizer_config.deserialization_kwargs or {}) } if self.encryption_keyfile: with open_stream( - self.encryption_keyfile, + tensorizer_config.encryption_keyfile, **self.stream_kwargs, ) as stream: key = stream.read() @@ -424,14 +459,14 @@ def __init__(self, tensorizer_config: TensorizerConfig, vllm_config): self.model = self._init_model() def _init_model(self): - assert self.tensorizer_config.hf_config is not None - model_args = self.tensorizer_config.hf_config - model_args.torch_dtype = self.tensorizer_config.dtype - assert self.tensorizer_config.model_class is not None + assert self.tensorizer_config._hf_config is not None + model_args = self.tensorizer_config._hf_config + model_args.torch_dtype = self.tensorizer_config._model_cls_dtype + assert self.tensorizer_config._model_cls is not None # TODO: Do we need to consider old-style model class? with meta_tensor_mode(), set_current_vllm_config(self.vllm_config, check_compile=True): - return self.tensorizer_config.model_class( + return self.tensorizer_config._model_cls( vllm_config=self.vllm_config) def _resize_lora_embeddings(self): @@ -473,6 +508,13 @@ def deserialize(self): Returns: nn.Module: The deserialized model. """ + if not is_valid_deserialization_uri( + self.tensorizer_config.tensorizer_uri): + raise ValueError( + f"{self.tensorizer_config.tensorizer_uri} is not a valid " + f"tensorizer URI. Please check that the URI is correct. 
" + f"It must either point to a local existing file, or have a " + f"S3, HTTP or HTTPS scheme.") before_mem = get_mem_usage() start = time.perf_counter() with open_stream( @@ -481,7 +523,7 @@ def deserialize(self): **self.tensorizer_args.stream_kwargs ) as stream, TensorDeserializer( stream, - dtype=self.tensorizer_config.dtype, + dtype=self.tensorizer_config._model_cls_dtype, device=torch.device("cuda", torch.cuda.current_device()), **self.tensorizer_args.deserialization_kwargs) as deserializer: deserializer.load_into_module(self.model) @@ -548,13 +590,36 @@ def is_vllm_tensorized(tensorizer_config: "TensorizerConfig") -> bool: return ".vllm_tensorized_marker" in deserializer +def serialize_extra_artifacts(tensorizer_args: TensorizerArgs, + served_model_name: str) -> None: + + with tempfile.TemporaryDirectory() as tmpdir: + snapshot_download(served_model_name, + local_dir=tmpdir, + ignore_patterns=[ + "*.pt", "*.safetensors", "*.bin", "*.cache", + "*.gitattributes", "*.md" + ]) + for artifact in os.scandir(tmpdir): + if not artifact.is_file(): + continue + with open(artifact.path, "rb") as f: + with _write_stream( + f"{tensorizer_args.tensorizer_dir}/{artifact.name}", + **tensorizer_args.stream_kwargs) as stream: + logger.info("Writing artifact %s", artifact.name) + stream.write(f.read()) + + def serialize_vllm_model( model: nn.Module, tensorizer_config: TensorizerConfig, + model_config: "ModelConfig", ) -> nn.Module: model.register_parameter( "vllm_tensorized_marker", nn.Parameter(torch.tensor((1, ), device="meta"), requires_grad=False)) + tensorizer_args = tensorizer_config._construct_tensorizer_args() encryption_params = None @@ -568,10 +633,16 @@ def serialize_vllm_model( from vllm.distributed import get_tensor_model_parallel_rank output_file = output_file % get_tensor_model_parallel_rank() - with open_stream(output_file, mode="wb+", **tensorizer_args.stream_kwargs) as stream: - serializer = TensorSerializer(stream, encryption=encryption_params, 
**tensorizer_config.serialization_kwargs) + with open_stream(output_file, mode="wb+", + **tensorizer_args.stream_kwargs) as stream: + serializer = TensorSerializer(stream, + encryption=encryption_params, + **tensorizer_config.serialization_kwargs) serializer.write_module(model) serializer.close() + + serialize_extra_artifacts(tensorizer_args, model_config.served_model_name) + logger.info("Successfully serialized model to %s", str(output_file)) return model @@ -609,13 +680,13 @@ def tensorize_vllm_model(engine_args: EngineArgs, engine = LLMEngine.from_engine_args(engine_args) engine.model_executor.collective_rpc( "save_tensorized_model", - kwargs=dict(tensorizer_config=tensorizer_config), + kwargs={"tensorizer_config": tensorizer_config.to_serializable()}, ) else: engine = V1LLMEngine.from_vllm_config(engine_config) engine.collective_rpc( "save_tensorized_model", - kwargs=dict(tensorizer_config=tensorizer_config), + kwargs={"tensorizer_config": tensorizer_config.to_serializable()}, ) diff --git a/vllm/model_executor/model_loader/tensorizer_loader.py b/vllm/model_executor/model_loader/tensorizer_loader.py index 0f493dc40fb7..b7cfd7fba0eb 100644 --- a/vllm/model_executor/model_loader/tensorizer_loader.py +++ b/vllm/model_executor/model_loader/tensorizer_loader.py @@ -20,17 +20,16 @@ logger = init_logger(__name__) BLACKLISTED_TENSORIZER_ARGS = { - "device", # vLLM decides this + "device", # vLLM decides this "dtype", # vLLM decides this - "mode", # Not meant to be configurable by the user + "mode", # Not meant to be configurable by the user } + def validate_config(config: dict): for k, v in config.items(): if v is not None and k in BLACKLISTED_TENSORIZER_ARGS: - raise ValueError( - f"{k} is not an allowed Tensorizer argument." 
- ) + raise ValueError(f"{k} is not an allowed Tensorizer argument.") class TensorizerLoader(BaseModelLoader): @@ -93,9 +92,9 @@ def _load_model_serialized( model_class = get_model_architecture(model_config)[0] tensorizer_config = copy.copy(self.tensorizer_config) - tensorizer_config.model_class = model_class - tensorizer_config.hf_config = model_config.hf_config - tensorizer_config.dtype = model_config.dtype + tensorizer_config._model_cls = model_class + tensorizer_config._hf_config = model_config.hf_config + tensorizer_config._model_cls_dtype = model_config.dtype model = load_with_tensorizer(tensorizer_config, vllm_config=vllm_config) @@ -127,10 +126,12 @@ def load_model(self, vllm_config: VllmConfig, def save_model( model: torch.nn.Module, tensorizer_config: Union[TensorizerConfig, dict], + model_config: ModelConfig, ) -> None: if isinstance(tensorizer_config, dict): tensorizer_config = TensorizerConfig(**tensorizer_config) serialize_vllm_model( model=model, tensorizer_config=tensorizer_config, + model_config=model_config, ) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 910c0e80bb31..9bc7148ade7a 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1558,6 +1558,7 @@ def save_tensorized_model( TensorizerLoader.save_model( self.model, tensorizer_config=tensorizer_config, + model_config=self.model_config ) def _get_prompt_logprobs_dict( diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 8c968faa7810..4a8193188eb0 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1245,6 +1245,7 @@ def save_tensorized_model( TensorizerLoader.save_model( self.model, tensorizer_config=tensorizer_config, + model_config=self.model_config, ) def get_max_block_per_batch(self) -> int: