Merged
Commits
24 commits
adf4bd3
tests: Add support for unit testing within the worker processes
sangstar May 28, 2025
b349690
fix: Add a `._debug_ctx`, arbitrary serializer and deserializer kwargs
sangstar May 28, 2025
2a3375c
fix: Add `--serialization-kwargs` support to `tensorize_vllm_model.py…
sangstar May 28, 2025
2fe2570
fix: Rm pointless `._debug_ctx` attribute
sangstar May 28, 2025
66fe8aa
tests: Rm `_debug` in initializer
sangstar May 28, 2025
0dc3b6f
fix: Add test to confirm `deserialization_kwargs` passed to `TensorDe…
sangstar May 29, 2025
f85a816
fix: Add `stream_kwargs` and add test for it
sangstar May 29, 2025
6c862f5
tests: Update entrypoint test to include new kwargs, add other test
sangstar May 29, 2025
dcd5286
docs: Update md on using Tensorizer with vLLM
sangstar May 29, 2025
2b818ab
tests: Print outputs from subprocesses for test
sangstar May 29, 2025
927625a
docs: Fix link
sangstar May 29, 2025
0faa197
docs: Add more updates to md
sangstar May 29, 2025
b46b588
docs: Clarify model tensors location
sangstar May 29, 2025
b22eac1
fix: Allow parsing JSON strings with newlines
sangstar Jun 3, 2025
ca73096
tests: Confirm server starts successfully for `vllm serve` test
sangstar Jun 3, 2025
2626504
Update docs/models/extensions/tensorizer.md
sangstar Jun 10, 2025
44d25f9
Apply suggestions from code review
sangstar Jun 10, 2025
2cefde1
fix: Properly assert type for serialization/deserialization args
sangstar Jun 10, 2025
3eabef5
tests: Resolve review comments on tests
sangstar Jun 10, 2025
eaebcca
docs: Add clarifications on deprecated args, usage doc
sangstar Jun 10, 2025
b51ee42
Apply suggestions from code review
sangstar Jun 11, 2025
a0ff85b
Update docs/models/extensions/tensorizer.md
sangstar Jun 11, 2025
2a8a617
Update examples/others/tensorize_vllm_model.py
sangstar Jun 11, 2025
6127ddb
fix: Implement changes from second review
sangstar Jun 11, 2025
106 changes: 102 additions & 4 deletions docs/models/extensions/tensorizer.md
@@ -6,11 +6,109 @@ title: Loading models with CoreWeave's Tensorizer
vLLM supports loading models with [CoreWeave's Tensorizer](https://docs.coreweave.com/coreweave-machine-learning-and-ai/inference/tensorizer).
vLLM model tensors that have been serialized to disk, an HTTP/HTTPS endpoint, or S3 endpoint can be deserialized
at runtime extremely quickly directly to the GPU, resulting in significantly
lower Pod startup times and CPU memory usage. Tensor encryption is also supported.

For more information on CoreWeave's Tensorizer, please refer to
[CoreWeave's Tensorizer documentation](https://github.com/coreweave/tensorizer). For more information on serializing a vLLM model, as well as a general usage guide for using Tensorizer with vLLM, see
the [vLLM example script](https://docs.vllm.ai/en/latest/examples/tensorize_vllm_model.html).
vLLM fully integrates Tensorizer into its model-loading machinery. The
following sections give a brief overview of how to get started with
Tensorizer on vLLM.

## The basics
To load a model using Tensorizer, it first needs to be serialized by Tensorizer.
The example script in [examples/others/tensorize_vllm_model.py](https://github.com/vllm-project/vllm/blob/main/examples/others/tensorize_vllm_model.py)
takes care of this process.

The `TensorizerConfig` class, defined in
[vllm/model_executor/model_loader/tensorizer.py](https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/model_loader/tensorizer.py),
is used to customize Tensorizer's behaviour, and is passed to every
serialization or deserialization operation. When loading with Tensorizer
through the vLLM library rather than a model-serving entrypoint, it is
passed directly to the `LLM` entrypoint class. Here's an example of loading
a model saved at `"s3://my-bucket/vllm/facebook/opt-125m/v1/model.tensors"`:

```python
from vllm import LLM
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig

path_to_tensors = "s3://my-bucket/vllm/facebook/opt-125m/v1/model.tensors"

model_ref = "facebook/opt-125m"
tensorizer_config = TensorizerConfig(
    tensorizer_uri=path_to_tensors
)

llm = LLM(
    model_ref,
    load_format="tensorizer",
    model_loader_extra_config=tensorizer_config,
)
```

However, the above code will not work until you have serialized the model
tensors with Tensorizer to produce the `model.tensors` file shown. The
following section walks through an end-to-end example of serializing
`facebook/opt-125m` with the example script, and then loading it for
inference.

## Saving a vLLM model with Tensorizer
To save a model with Tensorizer, call the example script with the necessary
CLI arguments. The script's docstring explains the CLI arguments and their
proper usage in detail, so we'll use one of the examples from the docstring
directly, assuming we want to save our model to an example S3 bucket,
`s3://my-bucket`:

```bash
python examples/others/tensorize_vllm_model.py \
    --model facebook/opt-125m \
    serialize \
    --serialized-directory s3://my-bucket \
    --suffix v1
```

This saves the model tensors to
`s3://my-bucket/vllm/facebook/opt-125m/v1/model.tensors`, along with all
other artifacts needed to load the model, in that same directory.
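Tensor encryption is supported as well, via the script's `--keyfile` argument. As a minimal sketch, you could generate a keyfile before serializing (this assumes a random 32-byte binary key is acceptable as a Tensorizer encryption key; the `tensorizer` library also offers its own key-generation facilities, so check its documentation for the recommended approach):

```shell
# Hypothetical keyfile path; any writable location works.
# Assumption: a 32-byte random binary key is a valid Tensorizer encryption key.
head -c 32 /dev/urandom > tensorizer.key
```

The same keyfile would then be passed as `--keyfile tensorizer.key` to both the `serialize` and `deserialize` subcommands, since a model serialized with encryption can only be deserialized with its key.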

## Serving the model using Tensorizer
Once the model is serialized where you want it, pass the directory
containing the model artifacts to `vllm serve`. In the case above, one can
simply do:

```bash
vllm serve s3://my-bucket/vllm/facebook/opt-125m/v1 --load-format=tensorizer
```

Please note that object storage authentication is still required.
To authenticate, Tensorizer searches for a `~/.s3cfg` configuration file
([s3cmd](https://s3tools.org/kb/item14.htm)'s format),
or `~/.aws/config` and `~/.aws/credentials` files (`boto3` and the
`aws` CLI's format), but authentication can also be configured using
[any normal `boto3` environment variables](https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html#using-environment-variables),
such as `AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY`,
and `AWS_ENDPOINT_URL_S3`.
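For example, configuring credentials through environment variables might look like this (the values shown are placeholders, not real credentials):

```shell
# Placeholder credentials -- substitute your own values.
export AWS_ACCESS_KEY_ID="AKIAEXAMPLEKEY"
export AWS_SECRET_ACCESS_KEY="examplesecret"
# Only needed for S3-compatible storage outside AWS (hypothetical endpoint):
export AWS_ENDPOINT_URL_S3="https://object.example.com"
```

With these set, `vllm serve s3://... --load-format=tensorizer` can authenticate without an `~/.s3cfg` or `~/.aws` configuration file being present.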

If only the `model.tensors` file exists, you can load just that file by
using the `--model-loader-extra-config` CLI parameter, which expects a
JSON string with the kwargs for a `TensorizerConfig` configuration object:

```bash
#!/bin/sh

MODEL_LOADER_EXTRA_CONFIG='{
    "tensorizer_uri": "s3://my-bucket/vllm/facebook/opt-125m/v1/model.tensors",
    "stream_kwargs": {"force_http": false},
    "deserialization_kwargs": {"verify_hash": true, "num_readers": 8}
}'

vllm serve facebook/opt-125m \
    --load-format=tensorizer \
    --model-loader-extra-config="$MODEL_LOADER_EXTRA_CONFIG"
```

Note that in this case, if the directory at
`s3://my-bucket/vllm/facebook/opt-125m/v1/` doesn't contain all the
artifacts needed to load the model, you'll want to pass `facebook/opt-125m`
as the model tag, as was done in the example script above. vLLM will then
resolve the remaining model artifacts by pulling them from the
HuggingFace Hub.
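Since `--model-loader-extra-config` is just a JSON string — and, with this change, one that may contain newlines — its structure can be sanity-checked with nothing but the standard library. This sketch uses only `json`; the keys mirror the example config above:

```python
import json

# The same multi-line JSON string passed via --model-loader-extra-config above.
raw = '''{
    "tensorizer_uri": "s3://my-bucket/vllm/facebook/opt-125m/v1/model.tensors",
    "stream_kwargs": {"force_http": false},
    "deserialization_kwargs": {"verify_hash": true, "num_readers": 8}
}'''

# json.loads handles embedded newlines fine; each top-level key becomes a
# keyword argument for TensorizerConfig.
config = json.loads(raw)
print(config["deserialization_kwargs"]["num_readers"])  # → 8
```

This is a useful way to catch malformed JSON in a shell variable before handing it to `vllm serve`.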

!!! note
    To use this feature, install `tensorizer` by running `pip install vllm[tensorizer]`.
65 changes: 47 additions & 18 deletions examples/others/tensorize_vllm_model.py
@@ -14,8 +14,13 @@
    TensorizerConfig,
    tensorize_lora_adapter,
    tensorize_vllm_model,
    tensorizer_kwargs_arg,
)
from vllm.utils import FlexibleArgumentParser
import logging

logger = logging.getLogger()


# yapf conflicts with isort for this docstring
# yapf: disable
@@ -140,7 +145,7 @@ def parse_args():

)

    subparsers = parser.add_subparsers(dest='command', required=True)

    serialize_parser = subparsers.add_parser(
        'serialize', help="Serialize a model to `--serialized-directory`")
@@ -170,6 +175,14 @@ def parse_args():
        "where `suffix` is given by `--suffix` or a random UUID if not "
        "provided.")

    serialize_parser.add_argument(
        "--serialization-kwargs",
        type=tensorizer_kwargs_arg,
        required=False,
        help=("A JSON string containing additional keyword arguments to "
              "pass to Tensorizer's `TensorSerializer` during "
              "serialization."))

    serialize_parser.add_argument(
        "--keyfile",
        type=str,
Expand All @@ -195,11 +208,27 @@ def parse_args():
        help=("Path to a binary key to use to decrypt the model weights,"
              " if the model was serialized with encryption"))

    deserialize_parser.add_argument(
        "--deserialization-kwargs",
        type=tensorizer_kwargs_arg,
        required=False,
        help=("A JSON string containing additional keyword arguments to "
              "pass to Tensorizer's `TensorDeserializer` during "
              "deserialization."))

    TensorizerArgs.add_cli_args(deserialize_parser)

    return parser.parse_args()


def merge_extra_config_with_tensorizer_config(extra_cfg: dict,
                                              cfg: TensorizerConfig):
    for k, v in extra_cfg.items():
        if hasattr(cfg, k):
            setattr(cfg, k, v)
            logger.info(
                "Updating TensorizerConfig with %s from "
                "--model-loader-extra-config provided", k)

def deserialize():
    if args.lora_path:
Expand Down Expand Up @@ -266,13 +295,10 @@ def deserialize():
else:
    keyfile = None

extra_config = {}
if args.model_loader_extra_config:
    extra_config = json.loads(args.model_loader_extra_config)


if args.command == "serialize":
    eng_args_dict = {f.name: getattr(args, f.name) for f in
Expand All @@ -293,21 +319,24 @@ def deserialize():
    tensorizer_config = TensorizerConfig(
        tensorizer_uri=model_path,
        encryption_keyfile=keyfile,
        serialization_kwargs=args.serialization_kwargs or {},
        **credentials
    )

    if args.lora_path:
        tensorizer_config.lora_dir = tensorizer_config.tensorizer_dir
        tensorize_lora_adapter(args.lora_path, tensorizer_config)

    merge_extra_config_with_tensorizer_config(extra_config, tensorizer_config)
    tensorize_vllm_model(engine_args, tensorizer_config)

elif args.command == "deserialize":
    tensorizer_config = TensorizerConfig(
        tensorizer_uri=args.path_to_tensors,
        encryption_keyfile=keyfile,
        deserialization_kwargs=args.deserialization_kwargs or {},
        **credentials
    )

    merge_extra_config_with_tensorizer_config(extra_config, tensorizer_config)
    deserialize()
10 changes: 8 additions & 2 deletions tests/entrypoints/openai/test_tensorizer_entrypoint.py
@@ -59,12 +59,18 @@ def tensorize_model_and_lora(tmp_dir, model_uri):
def server(model_uri, tensorize_model_and_lora):
    model_loader_extra_config = {
        "tensorizer_uri": model_uri,
        "stream_kwargs": {
            "force_http": False,
        },
        "deserialization_kwargs": {
            "verify_hash": True,
            "num_readers": 8,
        }
    }

    ## Start OpenAI API server
    args = [
        "--load-format", "tensorizer", "--model-loader-extra-config",
        json.dumps(model_loader_extra_config), "--enable-lora"
    ]

4 changes: 3 additions & 1 deletion tests/lora/test_llama_tp.py
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
import subprocess
import sys
import json
from typing import Union

import pytest
@@ -210,7 +211,8 @@ def test_tp2_serialize_and_deserialize_lora(tmp_path, sql_lora_files,
            f"{VLLM_PATH}/examples/others/tensorize_vllm_model.py", "--model",
            MODEL_PATH, "--lora-path", lora_path, "--tensor-parallel-size",
            str(tp_size), "serialize", "--serialized-directory",
            str(tmp_path), "--suffix", suffix, "--serialization-kwargs",
            '{"limit_cpu_concurrency": 4}'
        ],
        check=True,
        capture_output=True,
Expand Down
67 changes: 67 additions & 0 deletions tests/tensorizer_loader/conftest.py
@@ -1,9 +1,25 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Callable

import pytest
import os

from vllm import LLM
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.utils import get_distributed_init_method, get_open_port, get_ip

from vllm.v1.engine.core import EngineCore
from vllm.v1.executor.abstract import UniProcExecutor
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.worker.gpu_worker import Worker
from vllm.worker.worker_base import WorkerWrapperBase


@pytest.fixture(autouse=True)
def allow_insecure_serialization(monkeypatch):
    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")

@pytest.fixture(autouse=True)
def cleanup():
@@ -14,3 +30,54 @@ def cleanup():
def tensorizer_config():
    config = TensorizerConfig(tensorizer_uri="vllm")
    return config


def assert_from_collective_rpc(engine: LLM,
                               closure: Callable,
                               closure_kwargs: dict):
    res = engine.collective_rpc(method=closure, kwargs=closure_kwargs)
    return all(res)
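# assert_from_collective_rpc simply fans a callable out to every worker and
# checks that all of them returned a truthy result. A toy stand-in (with a
# hypothetical FakeEngine in place of a real LLM) shows the aggregation:

```python
from typing import Callable


class FakeEngine:
    """Hypothetical stand-in: pretends to fan an RPC out to 4 workers."""

    def collective_rpc(self, method: Callable, kwargs: dict):
        # Run the closure once per "worker" and collect the results.
        return [method(worker_id=i, **kwargs) for i in range(4)]


def check(worker_id: int, threshold: int) -> bool:
    return worker_id < threshold


engine = FakeEngine()
results = engine.collective_rpc(method=check, kwargs={"threshold": 4})
print(all(results))  # → True: every "worker" passed the check
```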


# This is an object pulled from tests/v1/engine/test_engine_core.py
# Modified to strip the `load_model` method from its `_init_executor`
# method. It's purely used as a dummy utility to run methods that test
# Tensorizer functionality
class DummyExecutor(UniProcExecutor):

    def _init_executor(self) -> None:
        """Initialize the worker and load the model."""
        self.driver_worker = WorkerWrapperBase(
            vllm_config=self.vllm_config,
            rpc_rank=0
        )
        distributed_init_method = get_distributed_init_method(
            get_ip(), get_open_port()
        )
        local_rank = 0
        # Set the local rank as the device index if specified
        device_info = str(self.vllm_config.device_config.device).split(":")
        if len(device_info) > 1:
            local_rank = int(device_info[1])
        rank = 0
        is_driver_worker = True
        kwargs = dict(
            vllm_config=self.vllm_config,
            local_rank=local_rank,
            rank=rank,
            distributed_init_method=distributed_init_method,
            is_driver_worker=is_driver_worker,
        )
        self.collective_rpc("init_worker", args=([kwargs],))
        self.collective_rpc("init_device")

    @property
    def max_concurrent_batches(self) -> int:
        return 2

    def shutdown(self):
        if hasattr(self, 'thread_pool'):
            self.thread_pool.shutdown(wait=False)