From 98fee776c4cdc8d524eaef8b9f879af3fbb312ea Mon Sep 17 00:00:00 2001 From: Sylvain Gugger Date: Thu, 9 Mar 2023 14:27:51 -0500 Subject: [PATCH 1/3] Add a progress bar for the total download of shards --- src/transformers/models/detr/configuration_detr.py | 1 + src/transformers/utils/hub.py | 10 ++++++++-- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/detr/configuration_detr.py b/src/transformers/models/detr/configuration_detr.py index b3da5f86b016..955b71de1ec5 100644 --- a/src/transformers/models/detr/configuration_detr.py +++ b/src/transformers/models/detr/configuration_detr.py @@ -239,6 +239,7 @@ def hidden_size(self) -> int: @classmethod def from_backbone_config(cls, backbone_config: PretrainedConfig, **kwargs): """Instantiate a [`DetrConfig`] (or a derived class) from a pre-trained backbone model configuration. + Args: backbone_config ([`PretrainedConfig`]): The backbone configuration. diff --git a/src/transformers/utils/hub.py b/src/transformers/utils/hub.py index 3403867eafe8..fea194585289 100644 --- a/src/transformers/utils/hub.py +++ b/src/transformers/utils/hub.py @@ -390,7 +390,7 @@ def cached_file( if isinstance(cache_dir, Path): cache_dir = str(cache_dir) - if _commit_hash is not None: + if _commit_hash is not None and not force_download: # If the file is cached under that commit hash, we return it directly. resolved_file = try_to_load_from_cache( path_or_repo_id, full_filename, cache_dir=cache_dir, revision=_commit_hash @@ -913,7 +913,13 @@ def get_checkpoint_shard_files( # At this stage pretrained_model_name_or_path is a model identifier on the Hub cached_filenames = [] - for shard_filename in shard_filenames: + # Check if the model is already cached or not. We only check the last shard, which should cover most cases + # (including an interrupted download).
+ last_shard = try_to_load_from_cache( + pretrained_model_name_or_path, shard_filenames[-1], cache_dir=cache_dir, revision=_commit_hash + ) + show_progress_bar = last_shard is _CACHED_NO_EXIST or force_download + for shard_filename in tqdm(shard_filenames, desc="Downloading shards", disable=not show_progress_bar): try: # Load from URL cached_filename = cached_file( From aa02a9d8ceb7bc6d95330b2635619648a4069520 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger Date: Thu, 9 Mar 2023 15:27:43 -0500 Subject: [PATCH 2/3] Check for no cache at all --- src/transformers/utils/hub.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/utils/hub.py b/src/transformers/utils/hub.py index fea194585289..b65c51febf27 100644 --- a/src/transformers/utils/hub.py +++ b/src/transformers/utils/hub.py @@ -918,7 +918,7 @@ def get_checkpoint_shard_files( last_shard = try_to_load_from_cache( pretrained_model_name_or_path, shard_filenames[-1], cache_dir=cache_dir, revision=_commit_hash ) - show_progress_bar = last_shard is _CACHED_NO_EXIST or force_download + show_progress_bar = last_shard is None or last_shard is _CACHED_NO_EXIST or force_download for shard_filename in tqdm(shard_filenames, desc="Downloading shards", disable=not show_progress_bar): try: # Load from URL From e170abaefa53672e71661eb936ea22a1e55f1c4f Mon Sep 17 00:00:00 2001 From: Sylvain Gugger Date: Thu, 9 Mar 2023 15:29:38 -0500 Subject: [PATCH 3/3] Fix check --- src/transformers/utils/hub.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/utils/hub.py b/src/transformers/utils/hub.py index b65c51febf27..db00878c9ae4 100644 --- a/src/transformers/utils/hub.py +++ b/src/transformers/utils/hub.py @@ -918,7 +918,7 @@ def get_checkpoint_shard_files( last_shard = try_to_load_from_cache( pretrained_model_name_or_path, shard_filenames[-1], cache_dir=cache_dir, revision=_commit_hash ) - show_progress_bar = last_shard is None or last_shard is _CACHED_NO_EXIST or 
force_download + show_progress_bar = last_shard is None or force_download for shard_filename in tqdm(shard_filenames, desc="Downloading shards", disable=not show_progress_bar): try: # Load from URL