diff --git a/optimum/habana/checkpoint_utils.py b/optimum/habana/checkpoint_utils.py index 1a14c52b2c..8cf5070b34 100644 --- a/optimum/habana/checkpoint_utils.py +++ b/optimum/habana/checkpoint_utils.py @@ -3,8 +3,8 @@ from pathlib import Path import torch -import transformers -from huggingface_hub import snapshot_download +from huggingface_hub import list_repo_files, snapshot_download +from transformers import modeling_utils from transformers.utils import is_offline_mode @@ -22,7 +22,12 @@ def get_repo_root(model_name_or_path, local_rank=-1, token=None): print("Offline mode: forcing local_files_only=True") # Only download PyTorch weights by default - allow_patterns = ["*.bin"] + if any(".bin" in filename for filename in list_repo_files(model_name_or_path, token=token)): + allow_patterns = ["*.bin"] + elif any( + ".safetensors" in filename for filename in list_repo_files(model_name_or_path, token=token) + ): # Some models like Falcon-180b are in only safetensors format + allow_patterns = ["*.safetensors"] # Download only on first process if local_rank in [-1, 0]: @@ -52,45 +57,25 @@ def get_repo_root(model_name_or_path, local_rank=-1, token=None): def get_checkpoint_files(model_name_or_path, local_rank, token=None): - """ - Gets the list of files for the specified model checkpoint. - """ cached_repo_dir = get_repo_root(model_name_or_path, local_rank=local_rank, token=token) - # Logic for loading individual weights from https://github.com/huggingface/transformers/blob/abbffc4525566a48a9733639797c812301218b83/src/transformers/trainer.py#L2061 - individual_weights = [ - os.path.join(cached_repo_dir, weight_name) - for weight_name in ( - transformers.modeling_utils.SAFE_WEIGHTS_NAME, - transformers.modeling_utils.WEIGHTS_NAME, - ) - ] - checkpoint_files = [] - for weight_file in individual_weights: - if os.path.isfile(weight_file): - checkpoint_files.append(weight_file) - break - if checkpoint_files: - return checkpoint_files - - # Code for loading sharded weights copied from https://github.com/huggingface/transformers/blob/abbffc4525566a48a9733639797c812301218b83/src/transformers/modeling_utils.py#L414 - index_file = os.path.join(cached_repo_dir, transformers.modeling_utils.WEIGHTS_INDEX_NAME) - safe_index_file = os.path.join(cached_repo_dir, transformers.modeling_utils.SAFE_WEIGHTS_INDEX_NAME) + # Extensions: .bin | .safetensors | .pt + # Creates a list of paths from all downloaded files in cache dir - index_present = os.path.isfile(index_file) - safe_index_present = os.path.isfile(safe_index_file) - - if not index_present and not safe_index_present: - filenames = (transformers.modeling_utils.WEIGHTS_INDEX_NAME, transformers.modeling_utils.SAFE_WEIGHTS_INDEX_NAME) - raise ValueError(f"Can't find a checkpoint index ({' or '.join(filenames)}) in {cached_repo_dir}.") - - load_index = safe_index_file if safe_index_present else index_file + if any(file.suffix == ".bin" for file in Path(cached_repo_dir).rglob("*")): + (name, ext) = os.path.splitext(modeling_utils.WEIGHTS_NAME) + elif any(file.suffix == ".safetensors" for file in Path(cached_repo_dir).rglob("*")): + (name, ext) = os.path.splitext(modeling_utils.SAFE_WEIGHTS_NAME) + else: + (name, ext) = ("*", ".pt") - with open(load_index, "r", encoding="utf-8") as f: - index = json.load(f) + file_list = [ + str(entry) + for entry in Path(cached_repo_dir).rglob("*") + if (entry.is_file() and entry.name.startswith(name) and entry.name.endswith(ext)) + ] - file_list = set(index["weight_map"].values()) - return [os.path.join(cached_repo_dir, entry) for entry in file_list] + return file_list def write_checkpoints_json(model_name_or_path, local_rank, f, token=None):