diff --git a/optimum/habana/checkpoint_utils.py b/optimum/habana/checkpoint_utils.py index e0fc139f5d..41feb2a37e 100644 --- a/optimum/habana/checkpoint_utils.py +++ b/optimum/habana/checkpoint_utils.py @@ -58,7 +58,9 @@ def get_checkpoint_files(model_name_or_path, local_rank, token=None): # Extensions: .bin | .pt # Creates a list of paths from all downloaded files in cache dir - file_list = [str(entry) for entry in Path(cached_repo_dir).rglob("*.[bp][it][n]") if entry.is_file()] + path = Path(cached_repo_dir) + globs = ["*.bin", "*.pt", "*.safetensors"] + file_list = [str(entry) for glob in globs for entry in Path(cached_repo_dir).rglob(glob) if entry.is_file()] return file_list