Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion optimum/habana/checkpoint_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,9 @@ def get_checkpoint_files(model_name_or_path, local_rank, token=None):

# Extensions: .bin | .pt
# Creates a list of paths from all downloaded files in cache dir
file_list = [str(entry) for entry in Path(cached_repo_dir).rglob("*.[bp][it][n]") if entry.is_file()]
path = Path(cached_repo_dir)
globs = ["*.bin", "*.pt", "*.safetensors"]
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't we need some kind of if else check ; cached repo dir can have both .safetensors and .pt files.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The code that is already there loads .bin and .pt files, can the cached repo have both of those as well? I might be misunderstanding something

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes it can have both.

file_list = [str(entry) for glob in globs for entry in Path(cached_repo_dir).rglob(glob) if entry.is_file()]
return file_list


Expand Down