diff --git a/tests/tgis/__init__.py b/tests/tgis/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/tgis/test_hub.py b/tests/tgis/test_hub.py new file mode 100644 index 000000000000..5ecde4bc67e5 --- /dev/null +++ b/tests/tgis/test_hub.py @@ -0,0 +1,50 @@ +from pathlib import Path + +import pytest +from huggingface_hub.utils import LocalEntryNotFoundError + +from vllm.tgis_utils.hub import (convert_files, download_weights, weight_files, + weight_hub_files) + + +def test_convert_files(): + model_id = "bigscience/bloom-560m" + local_pt_files = download_weights(model_id, extension=".bin") + local_pt_files = [Path(p) for p in local_pt_files] + local_st_files = [ + p.parent / f"{p.stem.removeprefix('pytorch_')}.safetensors" + for p in local_pt_files + ] + convert_files(local_pt_files, local_st_files, discard_names=[]) + + found_st_files = weight_files(model_id) + + assert all([str(p) in found_st_files for p in local_st_files]) + + +def test_weight_hub_files(): + filenames = weight_hub_files("bigscience/bloom-560m") + assert filenames == ["model.safetensors"] + + +def test_weight_hub_files_llm(): + filenames = weight_hub_files("bigscience/bloom") + assert filenames == [ + f"model_{i:05d}-of-00072.safetensors" for i in range(1, 73) + ] + + +def test_weight_hub_files_empty(): + filenames = weight_hub_files("bigscience/bloom", ".errors") + assert filenames == [] + + +def test_download_weights(): + files = download_weights("bigscience/bloom-560m") + local_files = weight_files("bigscience/bloom-560m") + assert files == local_files + + +def test_weight_files_error(): + with pytest.raises(LocalEntryNotFoundError): + weight_files("bert-base-uncased") \ No newline at end of file diff --git a/vllm/scripts.py b/vllm/scripts.py index 3f334be925ee..f2ee45abc042 100644 --- a/vllm/scripts.py +++ b/vllm/scripts.py @@ -3,6 +3,7 @@ import os import signal import sys +from pathlib import Path from typing import Optional from openai import OpenAI @@ -11,6 +12,7 @@ from vllm.entrypoints.openai.cli_args import make_arg_parser from vllm.utils import FlexibleArgumentParser +from vllm.tgis_utils.scripts import tgis_cli def registrer_signal_handlers(): @@ -142,6 +144,37 @@ def main(): "used for models that support system prompts.")) chat_parser.set_defaults(dispatch_function=interactive_cli, command="chat") + download_weights_parser = subparsers.add_parser( + "download-weights", + help=("Download the weights of a given model"), + usage="vllm download-weights [options]") + download_weights_parser.add_argument("model_name") + download_weights_parser.add_argument("--revision") + download_weights_parser.add_argument("--token") + download_weights_parser.add_argument("--extension", default=".safetensors") + download_weights_parser.add_argument("--auto_convert", default=True) + download_weights_parser.set_defaults(dispatch_function=tgis_cli, + command="download-weights") + + convert_to_safetensors_parser = subparsers.add_parser( + "convert-to-safetensors", + help=("Convert model weights to safetensors"), + usage="vllm convert-to-safetensors [options]") + convert_to_safetensors_parser.add_argument("model_name") + convert_to_safetensors_parser.add_argument("--revision") + convert_to_safetensors_parser.set_defaults( + dispatch_function=tgis_cli, command="convert-to-safetensors") + + convert_to_fast_tokenizer_parser = subparsers.add_parser( + "convert-to-fast-tokenizer", + help=("Convert to fast tokenizer"), + usage="vllm convert-to-fast-tokenizer [options]") + convert_to_fast_tokenizer_parser.add_argument("model_name") + convert_to_fast_tokenizer_parser.add_argument("--revision") + convert_to_fast_tokenizer_parser.add_argument("--output_path") + convert_to_fast_tokenizer_parser.set_defaults( + dispatch_function=tgis_cli, command="convert-to-fast-tokenizer") + args = parser.parse_args() # One of the sub commands should be executed. if hasattr(args, "dispatch_function"): diff --git a/vllm/tgis_utils/hub.py b/vllm/tgis_utils/hub.py new file mode 100644 index 000000000000..4361b189fdea --- /dev/null +++ b/vllm/tgis_utils/hub.py @@ -0,0 +1,270 @@ +import concurrent +import datetime +import glob +import json +import logging +import os +from collections import defaultdict +from concurrent.futures import ThreadPoolExecutor +from functools import partial +from pathlib import Path +from typing import Dict, List, Optional + +import torch +from huggingface_hub import HfApi, hf_hub_download, try_to_load_from_cache +from huggingface_hub.utils import LocalEntryNotFoundError +from safetensors.torch import (_find_shared_tensors, _is_complete, load_file, + save_file) +from tqdm import tqdm + +TRUST_REMOTE_CODE = os.getenv("TRUST_REMOTE_CODE") == "true" +logger = logging.getLogger(__name__) + + +def weight_hub_files(model_name, + extension=".safetensors", + revision=None, + auth_token=None): + """Get the safetensors filenames on the hub""" + exts = [extension] if isinstance(extension, str) else extension + api = HfApi() + info = api.model_info(model_name, revision=revision, token=auth_token) + filenames = [ + s.rfilename for s in info.siblings if any( + s.rfilename.endswith(ext) and len(s.rfilename.split("/")) == 1 + and "arguments" not in s.rfilename and "args" not in s.rfilename + and "training" not in s.rfilename for ext in exts) + ] + return filenames + + +def weight_files(model_name, extension=".safetensors", revision=None): + """Get the local safetensors filenames""" + filenames = weight_hub_files(model_name, extension) + files = [] + for filename in filenames: + cache_file = try_to_load_from_cache(model_name, + filename=filename, + revision=revision) + if cache_file is None: + raise LocalEntryNotFoundError( + f"File {filename} of model {model_name} not found in " + f"{os.getenv('HUGGINGFACE_HUB_CACHE', 'the local cache')}. " + f"Please run `vllm \ + download-weights {model_name}` first.") + files.append(cache_file) + + return files + + +def download_weights(model_name, + extension=".safetensors", + revision=None, + auth_token=None): + """Download the safetensors files from the hub""" + filenames = weight_hub_files(model_name, + extension, + revision=revision, + auth_token=auth_token) + + download_function = partial( + hf_hub_download, + repo_id=model_name, + local_files_only=False, + revision=revision, + token=auth_token, + ) + + print(f"Downloading {len(filenames)} files for model {model_name}") + executor = ThreadPoolExecutor(max_workers=5) + futures = [ + executor.submit(download_function, filename=filename) + for filename in filenames + ] + files = [ + future.result() + for future in tqdm(concurrent.futures.as_completed(futures), + total=len(futures)) + ] + + return files + + +def get_model_path(model_name: str, revision: Optional[str] = None): + """Get path to model dir in local huggingface hub (model) cache""" + config_file = "config.json" + err = None + try: + config_path = try_to_load_from_cache( + model_name, + config_file, + cache_dir=os.getenv("TRANSFORMERS_CACHE" + ), # will fall back to HUGGINGFACE_HUB_CACHE + revision=revision, + ) + if config_path is not None: + return config_path.removesuffix(f"/{config_file}") + except ValueError as e: + err = e + + if os.path.isfile(f"{model_name}/{config_file}"): + return model_name # Just treat the model name as an explicit model path + + if err is not None: + raise err + + raise ValueError( + f"Weights not found in local cache for model {model_name}") + + +def local_weight_files(model_path: str, extension=".safetensors"): + """Get the local safetensors filenames""" + ext = "" if extension is None else extension + return glob.glob(f"{model_path}/*{ext}") + + +def local_index_files(model_path: str, extension=".safetensors"): + """Get the local .index.json filename""" + ext = "" if extension is None else extension + return glob.glob(f"{model_path}/*{ext}.index.json") + + +def _remove_duplicate_names( + state_dict: Dict[str, torch.Tensor], + *, + preferred_names: List[str] = None, + discard_names: List[str] = None, +) -> Dict[str, List[str]]: + if preferred_names is None: + preferred_names = [] + preferred_names = set(preferred_names) + if discard_names is None: + discard_names = [] + discard_names = set(discard_names) + + shareds = _find_shared_tensors(state_dict) + to_remove = defaultdict(list) + for shared in shareds: + # _find_shared_tensors returns a list of sets of names of tensors that + # have the same data, including sets with one element that aren't shared + if len(shared) == 1: + continue + + complete_names = set( + [name for name in shared if _is_complete(state_dict[name])]) + if not complete_names: + raise RuntimeError(f"Error while trying to find names to remove \ + to save state dict, but found no suitable name to \ + keep for saving amongst: {shared}. None is covering \ + the entire storage.Refusing to save/load the model \ + since you could be storing much more \ + memory than needed. Please refer to\ + https://huggingface.co/docs/safetensors/torch_shared_tensors \ + for more information. \ + Or open an issue.") + + keep_name = sorted(list(complete_names))[0] + + # Mechanism to preferentially select keys to keep + # coming from the on-disk file to allow + # loading models saved with a different choice + # of keep_name + preferred = complete_names.difference(discard_names) + if preferred: + keep_name = sorted(list(preferred))[0] + + if preferred_names: + preferred = preferred_names.intersection(complete_names) + if preferred: + keep_name = sorted(list(preferred))[0] + for name in sorted(shared): + if name != keep_name: + to_remove[keep_name].append(name) + return to_remove + + +def convert_file(pt_file: Path, sf_file: Path, discard_names: List[str]): + """ + Convert a pytorch file to a safetensors file + This will remove duplicate tensors from the file. + + Unfortunately, this might not respect *transformers* convention. + Forcing us to check for potentially different keys during load when looking + for specific tensors (making tensor sharing explicit). + """ + loaded = torch.load(pt_file, map_location="cpu") + if "state_dict" in loaded: + loaded = loaded["state_dict"] + to_removes = _remove_duplicate_names(loaded, discard_names=discard_names) + + metadata = {"format": "pt"} + for kept_name, to_remove_group in to_removes.items(): + for to_remove in to_remove_group: + if to_remove not in metadata: + metadata[to_remove] = kept_name + del loaded[to_remove] + # Force tensors to be contiguous + loaded = {k: v.contiguous() for k, v in loaded.items()} + + dirname = os.path.dirname(sf_file) + os.makedirs(dirname, exist_ok=True) + save_file(loaded, sf_file, metadata=metadata) + reloaded = load_file(sf_file) + for k in loaded: + pt_tensor = loaded[k] + sf_tensor = reloaded[k] + if not torch.equal(pt_tensor, sf_tensor): + raise RuntimeError(f"The output tensors do not match for key {k}") + + +def convert_index_file(source_file: Path, dest_file: Path, + pt_files: List[Path], sf_files: List[Path]): + weight_file_map = {s.name: d.name for s, d in zip(pt_files, sf_files)} + + logger.info( + "Converting pytorch .bin.index.json files to .safetensors.index.json") + with open(source_file, "r") as f: + index = json.load(f) + + index["weight_map"] = { + k: weight_file_map[v] + for k, v in index["weight_map"].items() + } + + with open(dest_file, "w") as f: + json.dump(index, f, indent=4) + + +def convert_files(pt_files: List[Path], + sf_files: List[Path], + discard_names: List[str] = None): + assert len(pt_files) == len(sf_files) + + # Filter non-inference files + pairs = [ + p for p in zip(pt_files, sf_files) if not any(s in p[0].name for s in [ + "arguments", + "args", + "training", + "optimizer", + "scheduler", + "index", + ]) + ] + + N = len(pairs) + + if N == 0: + logger.warning("No pytorch .bin weight files found to convert") + return + + logger.info("Converting %d pytorch .bin files to .safetensors...", N) + + for i, (pt_file, sf_file) in enumerate(pairs): + file_count = (i + 1) / N + logger.info('Converting: [%d] "$s"', file_count, pt_file.name) + start = datetime.datetime.now() + convert_file(pt_file, sf_file, discard_names) + elapsed = datetime.datetime.now() - start + logger.info('Converted: [%d] "%s" -- Took: %d', file_count, + sf_file.name, elapsed) diff --git a/vllm/tgis_utils/scripts.py b/vllm/tgis_utils/scripts.py new file mode 100644 index 000000000000..e1c56e9ec8c9 --- /dev/null +++ b/vllm/tgis_utils/scripts.py @@ -0,0 +1,96 @@ +# The CLI entrypoint to vLLM. +import argparse +import os +import signal +import sys +from pathlib import Path +from typing import Optional + +from vllm.model_executor.model_loader.weight_utils import convert_bin_to_safetensor_file +from vllm.scripts import registrer_signal_handlers + + +def tgis_cli(args: argparse.Namespace) -> None: + registrer_signal_handlers() + + if args.command == "download-weights": + download_weights(args.model_name, args.revision, args.token, + args.extension, args.auto_convert) + elif args.command == "convert-to-safetensors": + convert_bin_to_safetensor_file(args.model_name, args.revision) + elif args.command == "convert-to-fast-tokenizer": + convert_to_fast_tokenizer(args.model_name, args.revision, + args.output_path) + + +def download_weights( + model_name: str, + revision: Optional[str] = None, + token: Optional[str] = None, + extension: str = ".safetensors", + auto_convert: bool = True, +) -> None: + from vllm.tgis_utils import hub + + print(extension) + meta_exts = [".json", ".py", ".model", ".md"] + + extensions = extension.split(",") + + if len(extensions) == 1 and extensions[0] not in meta_exts: + extensions.extend(meta_exts) + + files = hub.download_weights(model_name, + extensions, + revision=revision, + auth_token=token) + + if auto_convert and ".safetensors" in extensions: + if not hub.local_weight_files(hub.get_model_path(model_name, revision), + ".safetensors"): + if ".bin" not in extensions: + print(".safetensors weights not found, \ + downloading pytorch weights to convert...") + hub.download_weights(model_name, + ".bin", + revision=revision, + auth_token=token) + + print(".safetensors weights not found, \ + converting from pytorch weights...") + convert_bin_to_safetensor_file(model_name, revision) + elif not any(f.endswith(".safetensors") for f in files): + print(".safetensors weights not found on hub, \ + but were found locally. Remove them first to re-convert") + if auto_convert: + convert_to_fast_tokenizer(model_name, revision) + + +def convert_to_fast_tokenizer( + model_name: str, + revision: Optional[str] = None, + output_path: Optional[str] = None, +): + from vllm.tgis_utils import hub + + # Check for existing "tokenizer.json" + model_path = hub.get_model_path(model_name, revision) + + if os.path.exists(os.path.join(model_path, "tokenizer.json")): + print(f"Model {model_name} already has a fast tokenizer") + return + + if output_path is not None: + if not os.path.isdir(output_path): + print(f"Output path {output_path} must exist and be a directory") + return + else: + output_path = model_path + + import transformers + + tokenizer = transformers.AutoTokenizer.from_pretrained(model_name, + revision=revision) + tokenizer.save_pretrained(output_path) + + print(f"Saved tokenizer to {output_path}")