35 changes: 27 additions & 8 deletions examples/others/tensorize_vllm_model.py
@@ -3,6 +3,7 @@
import argparse
import dataclasses
import json
import logging
import os
import uuid

@@ -14,10 +15,9 @@
TensorizerConfig,
tensorize_lora_adapter,
tensorize_vllm_model,
tensorizer_kwargs_arg
tensorizer_kwargs_arg,
)
from vllm.utils import FlexibleArgumentParser
import logging

logger = logging.getLogger()

@@ -123,7 +123,7 @@
"""


def parse_args():
def get_parser():
parser = FlexibleArgumentParser(
description="An example script that can be used to serialize and "
"deserialize vLLM models. These models "
@@ -198,9 +198,17 @@ def parse_args():
deserialize_parser.add_argument(
"--path-to-tensors",
type=str,
required=True,
required=False,
help="The local path or S3 URI to the model tensors to deserialize. ")

deserialize_parser.add_argument(
"--serialized-directory",
type=str,
required=False,
help="Directory with model artifacts for loading. Assumes a "
"model.tensors file exists therein. Can supersede "
"--path-to-tensors.")

deserialize_parser.add_argument(
"--keyfile",
type=str,
@@ -218,7 +226,7 @@

TensorizerArgs.add_cli_args(deserialize_parser)

return parser.parse_args()
return parser

def merge_extra_config_with_tensorizer_config(extra_cfg: dict,
cfg: TensorizerConfig):
@@ -271,7 +279,8 @@ def deserialize():


if __name__ == '__main__':
args = parse_args()
parser = get_parser()
args = parser.parse_args()

s3_access_key_id = (getattr(args, 's3_access_key_id', None)
or os.environ.get("S3_ACCESS_KEY_ID", None))
@@ -300,6 +309,16 @@ def deserialize():
extra_config = json.loads(args.model_loader_extra_config)


tensorizer_dir = args.serialized_directory or extra_config.get("tensorizer_dir")
tensorizer_uri = args.path_to_tensors or extra_config.get("tensorizer_uri")

if tensorizer_dir and tensorizer_uri:
parser.error("--serialized-directory and --path-to-tensors cannot both be provided")

if not tensorizer_dir and not tensorizer_uri:
parser.error("Either --serialized-directory or --path-to-tensors must be provided")


if args.command == "serialize":
eng_args_dict = {f.name: getattr(args, f.name) for f in
dataclasses.fields(EngineArgs)}
@@ -308,7 +327,7 @@ def deserialize():
argparse.Namespace(**eng_args_dict)
)

input_dir = args.serialized_directory.rstrip('/')
input_dir = tensorizer_dir.rstrip('/')
suffix = args.suffix if args.suffix else uuid.uuid4().hex
base_path = f"{input_dir}/vllm/{model_ref}/{suffix}"
if engine_args.tensor_parallel_size > 1:
@@ -339,4 +358,4 @@
)

merge_extra_config_with_tensorizer_config(extra_config, tensorizer_config)
deserialize()
deserialize()
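
The switch from returning parsed args to returning the parser itself is what makes the new mutual-exclusion checks possible: the __main__ block can call parser.error(...) after inspecting the namespace. A minimal, self-contained sketch of that pattern (a toy parser, not the script's full argument set):

import argparse

def get_parser() -> argparse.ArgumentParser:
    # Return the parser, not the parsed namespace, so the caller can
    # still report argument errors through parser.error().
    parser = argparse.ArgumentParser(description="toy example")
    parser.add_argument("--path-to-tensors", required=False)
    parser.add_argument("--serialized-directory", required=False)
    return parser

if __name__ == "__main__":
    parser = get_parser()
    args = parser.parse_args()
    if args.path_to_tensors and args.serialized_directory:
        parser.error("--serialized-directory and --path-to-tensors "
                     "cannot both be provided")
    if not (args.path_to_tensors or args.serialized_directory):
        parser.error("Either --serialized-directory or --path-to-tensors "
                     "must be provided")

parser.error() prints the usage string and exits with status 2, which is why the validation lives in __main__ rather than inside get_parser().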
13 changes: 11 additions & 2 deletions tests/entrypoints/openai/test_tensorizer_entrypoint.py
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
import gc
import json
import os
import tempfile

import openai
@@ -57,6 +58,12 @@ def tensorize_model_and_lora(tmp_dir, model_uri):

@pytest.fixture(scope="module")
def server(model_uri, tensorize_model_and_lora):
# Here, model_uri is a directory containing a model.tensors file and
# all necessary model artifacts, in particular a HF `config.json`.
# Tensorizer can therefore infer the `TensorizerConfig`, so
# --model-loader-extra-config can be omitted entirely.

model_loader_extra_config = {
"tensorizer_uri": model_uri,
"stream_kwargs": {
@@ -70,11 +77,13 @@ def server(model_uri, tensorize_model_and_lora):

## Start OpenAI API server
args = [
"--load-format", "tensorizer", "--model-loader-extra-config",
"--load-format", "tensorizer", "--served-model-name", MODEL_NAME,
"--model-loader-extra-config",
json.dumps(model_loader_extra_config), "--enable-lora"
]

with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
model_dir = os.path.dirname(model_uri)
with RemoteOpenAIServer(model_dir, args) as remote_server:
yield remote_server
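
Since the comment above notes that Tensorizer can infer the TensorizerConfig from the directory contents, a hedged sketch of the same fixture with the extra config omitted entirely (reusing RemoteOpenAIServer, MODEL_NAME, and model_uri from this test module; an untested simplification, not part of the PR):

# model_uri points at a directory holding model.tensors plus the HF
# config.json, so no --model-loader-extra-config is passed at all.
args = [
    "--load-format", "tensorizer",
    "--served-model-name", MODEL_NAME,
    "--enable-lora",
]
model_dir = os.path.dirname(model_uri)
with RemoteOpenAIServer(model_dir, args) as remote_server:
    yield remote_server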


3 changes: 1 addition & 2 deletions tests/lora/test_llama_tp.py
@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
import subprocess
import sys
import json
from typing import Union

import pytest
@@ -238,7 +237,7 @@ def test_tp2_serialize_and_deserialize_lora(tmp_path, sql_lora_files,
tensor_parallel_size=2,
max_loras=2)

tensorizer_config_dict = tensorizer_config.to_dict()
tensorizer_config_dict = tensorizer_config.to_serializable()

print("lora adapter created")
assert do_sample(loaded_vllm_model,
58 changes: 38 additions & 20 deletions tests/tensorizer_loader/conftest.py
@@ -2,38 +2,60 @@
from typing import Callable

import pytest
import os

from vllm import LLM
from vllm import LLM, EngineArgs
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.model_executor.model_loader import tensorizer as tensorizer_mod
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.utils import get_distributed_init_method, get_open_port, get_ip

from vllm.v1.engine.core import EngineCore
from vllm.utils import get_distributed_init_method, get_ip, get_open_port
from vllm.v1.executor.abstract import UniProcExecutor
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.worker.gpu_worker import Worker
from vllm.worker.worker_base import WorkerWrapperBase

MODEL_REF = "facebook/opt-125m"


@pytest.fixture()
def model_ref():
return MODEL_REF


@pytest.fixture(autouse=True)
def allow_insecure_serialization(monkeypatch):
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")


@pytest.fixture(autouse=True)
def cleanup():
cleanup_dist_env_and_memory(shutdown_ray=True)


@pytest.fixture()
def just_serialize_model_tensors(model_ref, monkeypatch, tmp_path):

def noop(*args, **kwargs):
return None

args = EngineArgs(model=model_ref)
tc = TensorizerConfig(tensorizer_uri=f"{tmp_path}/model.tensors")

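# Stub out writing of extra artifacts (e.g. the HF config files) so the
# fixture produces only the raw model.tensors file under tmp_path.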
monkeypatch.setattr(tensorizer_mod, "serialize_extra_artifacts", noop)

tensorizer_mod.tensorize_vllm_model(args, tc)
yield tmp_path


@pytest.fixture(autouse=True)
def tensorizer_config():
config = TensorizerConfig(tensorizer_uri="vllm")
return config


def assert_from_collective_rpc(engine: LLM,
closure: Callable,
@pytest.fixture()
def model_path(model_ref, tmp_path):
yield tmp_path / model_ref / "model.tensors"


def assert_from_collective_rpc(engine: LLM, closure: Callable,
closure_kwargs: dict):
res = engine.collective_rpc(method=closure, kwargs=closure_kwargs)
return all(res)
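
A hypothetical usage of this helper (the closure and path are illustrative only, assuming collective_rpc invokes a callable with the worker as its first argument):

def tensors_exist(worker, path: str) -> bool:
    # Runs on every worker process; truthy iff the file is visible there.
    import os
    return os.path.exists(path)

assert assert_from_collective_rpc(llm, tensors_exist,
                                  {"path": "/tmp/model.tensors"})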
@@ -48,18 +70,14 @@ class DummyExecutor(UniProcExecutor):
def _init_executor(self) -> None:
"""Initialize the worker and load the model.
"""
self.driver_worker = WorkerWrapperBase(
vllm_config=self.vllm_config,
rpc_rank=0
)
self.driver_worker = WorkerWrapperBase(vllm_config=self.vllm_config,
rpc_rank=0)
distributed_init_method = get_distributed_init_method(
get_ip(), get_open_port()
)
get_ip(), get_open_port())
local_rank = 0
# set local rank as the device index if specified
device_info = self.vllm_config.device_config.device.__str__().split(
":"
)
":")
if len(device_info) > 1:
local_rank = int(device_info[1])
rank = 0
@@ -71,7 +89,7 @@ def _init_executor(self) -> None:
distributed_init_method=distributed_init_method,
is_driver_worker=is_driver_worker,
)
self.collective_rpc("init_worker", args=([kwargs],))
self.collective_rpc("init_worker", args=([kwargs], ))
self.collective_rpc("init_device")

@property
@@ -80,4 +98,4 @@ def max_concurrent_batches(self) -> int:

def shutdown(self):
if hasattr(self, 'thread_pool'):
self.thread_pool.shutdown(wait=False)
self.thread_pool.shutdown(wait=False)