From e3f8822057b02878e41178d41284c4ca71962567 Mon Sep 17 00:00:00 2001 From: Sun Choi Date: Wed, 31 Jan 2024 18:59:03 +0000 Subject: [PATCH 1/2] Add support for safetensors and sharded checkpoints --- optimum/habana/checkpoint_utils.py | 39 +++++++++++++++++++++++++++--- 1 file changed, 35 insertions(+), 4 deletions(-) diff --git a/optimum/habana/checkpoint_utils.py b/optimum/habana/checkpoint_utils.py index 0fdd1c6566..2908f33c64 100644 --- a/optimum/habana/checkpoint_utils.py +++ b/optimum/habana/checkpoint_utils.py @@ -3,6 +3,7 @@ from pathlib import Path import torch +import transformers from huggingface_hub import snapshot_download from transformers.utils import is_offline_mode @@ -56,10 +57,40 @@ def get_checkpoint_files(model_name_or_path, local_rank, token=None): """ cached_repo_dir = get_repo_root(model_name_or_path, local_rank=local_rank, token=token) - # Extensions: .bin | .pt - # Creates a list of paths from all downloaded files in cache dir - file_list = [str(entry) for entry in Path(cached_repo_dir).rglob("*.[bp][it][n]") if entry.is_file()] - return file_list + # Logic for loading individual weights from https://github.com/huggingface/transformers/blob/abbffc4525566a48a9733639797c812301218b83/src/transformers/trainer.py#L2061 + individual_weights = [ + os.path.join(cached_repo_dir, weight_name) + for weight_name in ( + transformers.modeling_utils.SAFE_WEIGHTS_NAME, + transformers.modeling_utils.WEIGHTS_NAME, + ) + ] + checkpoint_files = [] + for weight_file in individual_weights: + if os.path.isfile(weight_file): + checkpoint_files.append(weight_file) + break + if checkpoint_files: + return checkpoint_files + + # Code for loading sharded weights copied from https://github.com/huggingface/transformers/blob/abbffc4525566a48a9733639797c812301218b83/src/transformers/modeling_utils.py#L414 + index_file = os.path.join(cached_repo_dir, transformers.modeling_utils.WEIGHTS_INDEX_NAME) + safe_index_file = os.path.join(cached_repo_dir, transformers.modeling_utils.SAFE_WEIGHTS_INDEX_NAME) + + index_present = os.path.isfile(index_file) + safe_index_present = os.path.isfile(safe_index_file) + + if not index_present and not safe_index_present: + filenames = (transformers.modeling_utils.WEIGHTS_INDEX_NAME, transformers.modeling_utils.SAFE_WEIGHTS_INDEX_NAME) + raise ValueError(f"Can't find a checkpoint index ({' or '.join(filenames)}) in {cached_repo_dir}.") + + load_index = safe_index_file if safe_index_present else index_file + + with open(load_index, "r", encoding="utf-8") as f: + index = json.load(f) + + file_list = set(index["weight_map"].values()) + return [os.path.join(cached_repo_dir, entry) for entry in file_list] def write_checkpoints_json(model_name_or_path, local_rank, f, token=None): From 2619e264b4ccc6385847a75fe5c328dd2f117374 Mon Sep 17 00:00:00 2001 From: Taylor Jackle Spriggs Date: Wed, 21 Feb 2024 09:52:54 -0700 Subject: [PATCH 2/2] make style --- optimum/habana/checkpoint_utils.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/optimum/habana/checkpoint_utils.py b/optimum/habana/checkpoint_utils.py index 2908f33c64..b4e4f8c390 100644 --- a/optimum/habana/checkpoint_utils.py +++ b/optimum/habana/checkpoint_utils.py @@ -81,7 +81,10 @@ def get_checkpoint_files(model_name_or_path, local_rank, token=None): safe_index_present = os.path.isfile(safe_index_file) if not index_present and not safe_index_present: - filenames = (transformers.modeling_utils.WEIGHTS_INDEX_NAME, transformers.modeling_utils.SAFE_WEIGHTS_INDEX_NAME) + filenames = ( + transformers.modeling_utils.WEIGHTS_INDEX_NAME, + transformers.modeling_utils.SAFE_WEIGHTS_INDEX_NAME, + ) raise ValueError(f"Can't find a checkpoint index ({' or '.join(filenames)}) in {cached_repo_dir}.") load_index = safe_index_file if safe_index_present else index_file