diff --git a/src/transformers/hf_argparser.py b/src/transformers/hf_argparser.py index f1fde58fc2de..e2fcc504642c 100644 --- a/src/transformers/hf_argparser.py +++ b/src/transformers/hf_argparser.py @@ -23,7 +23,7 @@ from pathlib import Path from typing import Any, Iterable, List, NewType, Optional, Tuple, Union -from sparsezoo import Zoo +from sparsezoo import Model from .utils.logging import get_logger @@ -263,8 +263,8 @@ def _download_dataclass_zoo_stub_files(data_class: DataClass): logger.info(f"Downloading framework files for SparseZoo stub: {val}") - zoo_model = Zoo.load_model_from_stub(val) - framework_file_paths = zoo_model.download_framework_files() + zoo_model = Model(val) + framework_file_paths = [file.path for file in zoo_model.training.default.files] assert framework_file_paths, "Unable to download any framework files for SparseZoo stub {val}" framework_file_names = [os.path.basename(path) for path in framework_file_paths] if "pytorch_model.bin" not in framework_file_names or ("config.json" not in framework_file_names):