Skip to content

Commit 4b01501

Browse files
authored
Merge pull request #807 from NatLibFi/add-trust-option-to-download-cli-command
Add `--trust-repo` option to `download` CLI command
2 parents 1e2c43c + 654a376 commit 4b01501

File tree

4 files changed

+124
-10
lines changed

4 files changed

+124
-10
lines changed

annif/cli.py

+16-6
Original file line numberDiff line numberDiff line change
@@ -633,9 +633,9 @@ def run_upload(
633633
that match the given `project_ids_pattern` to archive files, and uploads the
634634
archives along with the project configurations to the specified Hugging Face
635635
Hub repository. An authentication token and commit message can be given with
636-
options. If the README.md does not exist in the repository it is created with
637-
default contents and metadata of the uploaded projects, if it exists, its
638-
metadata are updated as necessary.
636+
options. If the README.md does not exist in the repository it is
637+
created with default contents and metadata of the uploaded projects, if it exists,
638+
its metadata are updated as necessary.
639639
"""
640640
from huggingface_hub import HfApi
641641
from huggingface_hub.utils import HfHubHTTPError, HFValidationError
@@ -692,8 +692,14 @@ def run_upload(
692692
is_flag=True,
693693
help="Replace an existing project/vocabulary/config with the downloaded one",
694694
)
695+
@click.option(
696+
"--trust-repo",
697+
default=False,
698+
is_flag=True,
699+
help="Allow download from the repository even when it has no entries in the cache",
700+
)
695701
@cli_util.common_options
696-
def run_download(project_ids_pattern, repo_id, token, revision, force):
702+
def run_download(project_ids_pattern, repo_id, token, revision, force, trust_repo):
697703
"""
698704
Download selected projects and their vocabularies from a Hugging Face Hub
699705
repository.
@@ -702,10 +708,14 @@ def run_download(project_ids_pattern, repo_id, token, revision, force):
702708
configuration files of the projects that match the given
703709
`project_ids_pattern` from the specified Hugging Face Hub repository and
704710
unzips the archives to `data/` directory and places the configuration files
705-
to `projects.d/` directory. An authentication token and revision can
706-
be given with options.
711+
to `projects.d/` directory. An authentication token and revision can be given with
712+
options. If the repository hasn’t been used for downloads previously
713+
(i.e., it doesn’t appear in the Hugging Face Hub cache on local system), the
714+
`--trust-repo` option needs to be used.
707715
"""
708716

717+
hfh_util.check_is_download_allowed(trust_repo, repo_id)
718+
709719
project_ids = hfh_util.get_matching_project_ids_from_hf_hub(
710720
project_ids_pattern, repo_id, token, revision
711721
)

annif/hfh_util.py

+29
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,35 @@
2424
logger = annif.logger
2525

2626

27+
def check_is_download_allowed(trust_repo, repo_id):
    """Decide whether projects may be downloaded from the given repository.

    Downloading is permitted when `trust_repo` is truthy (the `--trust-repo`
    CLI flag) or when `_is_repo_in_cache` reports the repository as already
    cached. The cache is not consulted at all when the trust flag is set.

    Raises OperationFailedException when neither condition holds.
    """
    if not trust_repo:
        # No explicit trust given: fall back to the cache as evidence that the
        # repository has been used before.
        if not _is_repo_in_cache(repo_id):
            raise OperationFailedException(
                f'Cannot download projects from untrusted repo "{repo_id}"'
            )
        logger.debug(
            f'Download allowed from "{repo_id}" because repo is already in cache.'
        )
    else:
        logger.warning(
            f'Download allowed from "{repo_id}" because "--trust-repo" flag is used.'
        )
43+
44+
45+
def _is_repo_in_cache(repo_id):
    """Tell whether `repo_id` is among the repos found by scanning the HFH cache.

    Returns False (after logging at debug level) when the scan raises
    CacheNotFound.
    """
    from huggingface_hub import CacheNotFound, scan_cache_dir

    try:
        cache_info = scan_cache_dir()
    except CacheNotFound as exc:
        logger.debug(str(exc) + "\nNo HFH cache found.")
        return False
    cached_repo_ids = [repo.repo_id for repo in cache_info.repos]
    return repo_id in cached_repo_ids
54+
55+
2756
def get_matching_projects(pattern: str) -> list[AnnifProject]:
2857
"""
2958
Get projects that match the given pattern.

tests/test_cli.py

+35-4
Original file line numberDiff line numberDiff line change
@@ -1149,10 +1149,30 @@ def test_upload_nonexistent_repo():
11491149
assert "Repository Not Found for url:" in failed_result.output
11501150

11511151

1152+
@mock.patch("annif.hfh_util._is_repo_in_cache", return_value=False)
def test_download_not_allowed_default(mock_is_repo_in_cache):
    """Without --trust-repo, downloading from an uncached repo must fail."""
    # --trust-repo is left out on purpose, so its default (False) applies
    cli_args = ["download", "dummy-fi", "dummy-repo"]
    failed_result = runner.invoke(annif.cli.cli, cli_args)

    expected_message = 'Cannot download projects from untrusted repo "dummy-repo"'
    assert failed_result.exception
    assert failed_result.exit_code != 0
    assert expected_message in failed_result.output
1169+
1170+
11521171
def hf_hub_download_mock_side_effect(filename, repo_id, token, revision):
11531172
return "tests/huggingface-cache/" + filename # Mocks the downloaded file paths
11541173

11551174

1175+
@mock.patch("annif.hfh_util.check_is_download_allowed", return_value=True)
11561176
@mock.patch(
11571177
"huggingface_hub.list_repo_files",
11581178
return_value=[ # Mocks the filenames in repo
@@ -1170,7 +1190,11 @@ def hf_hub_download_mock_side_effect(filename, repo_id, token, revision):
11701190
)
11711191
@mock.patch("annif.hfh_util.copy_project_config")
11721192
def test_download_dummy_fi(
1173-
copy_project_config, hf_hub_download, list_repo_files, testdatadir
1193+
copy_project_config,
1194+
hf_hub_download,
1195+
list_repo_files,
1196+
check_is_download_allowed,
1197+
testdatadir,
11741198
):
11751199
result = runner.invoke(
11761200
annif.cli.cli,
@@ -1211,6 +1235,7 @@ def test_download_dummy_fi(
12111235
]
12121236

12131237

1238+
@mock.patch("annif.hfh_util.check_is_download_allowed", return_value=True)
12141239
@mock.patch(
12151240
"huggingface_hub.list_repo_files",
12161241
return_value=[ # Mock filenames in repo
@@ -1228,7 +1253,11 @@ def test_download_dummy_fi(
12281253
)
12291254
@mock.patch("annif.hfh_util.copy_project_config")
12301255
def test_download_dummy_fi_and_en(
1231-
copy_project_config, hf_hub_download, list_repo_files, testdatadir
1256+
copy_project_config,
1257+
hf_hub_download,
1258+
list_repo_files,
1259+
check_is_download_allowed,
1260+
testdatadir,
12321261
):
12331262
result = runner.invoke(
12341263
annif.cli.cli,
@@ -1285,6 +1314,7 @@ def test_download_dummy_fi_and_en(
12851314
]
12861315

12871316

1317+
@mock.patch("annif.hfh_util.check_is_download_allowed", return_value=True)
12881318
@mock.patch(
12891319
"huggingface_hub.list_repo_files",
12901320
side_effect=HFValidationError,
@@ -1293,8 +1323,7 @@ def test_download_dummy_fi_and_en(
12931323
"huggingface_hub.hf_hub_download",
12941324
)
12951325
def test_download_list_repo_files_failed(
1296-
hf_hub_download,
1297-
list_repo_files,
1326+
hf_hub_download, list_repo_files, check_is_download_allowed
12981327
):
12991328
failed_result = runner.invoke(
13001329
annif.cli.cli,
@@ -1311,6 +1340,7 @@ def test_download_list_repo_files_failed(
13111340
assert not hf_hub_download.called
13121341

13131342

1343+
@mock.patch("annif.hfh_util.check_is_download_allowed", return_value=True)
13141344
@mock.patch(
13151345
"huggingface_hub.list_repo_files",
13161346
return_value=[ # Mock filenames in repo
@@ -1326,6 +1356,7 @@ def test_download_list_repo_files_failed(
13261356
def test_download_hf_hub_download_failed(
13271357
hf_hub_download,
13281358
list_repo_files,
1359+
check_is_download_allowed,
13291360
):
13301361
failed_result = runner.invoke(
13311362
annif.cli.cli,

tests/test_hfh_util.py

+44
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,60 @@
11
"""Unit test module for Hugging Face Hub utilities."""
22

33
import io
4+
import logging
45
import os.path
56
import zipfile
67
from datetime import datetime, timezone
78
from unittest import mock
89

910
import huggingface_hub
11+
import pytest
1012
from huggingface_hub.utils import EntryNotFoundError
1113

1214
import annif.hfh_util
1315
from annif.config import AnnifConfigCFG
16+
from annif.exception import OperationFailedException
17+
18+
19+
@mock.patch("annif.hfh_util._is_repo_in_cache", return_value=False)
def test_download_allowed_trust_repo(mock_is_repo_in_cache, caplog):
    """With the trust flag set, the check passes and logs a warning."""
    expected_warning = (
        'Download allowed from "dummy-repo" because "--trust-repo" flag is used.'
    )

    with caplog.at_level(logging.WARNING, logger="annif"):
        annif.hfh_util.check_is_download_allowed(True, "dummy-repo")

    assert expected_warning in caplog.text
30+
31+
32+
@mock.patch("annif.hfh_util._is_repo_in_cache", return_value=True)
def test_download_allowed_repo_in_cache(mock_is_repo_in_cache, caplog):
    """Without the trust flag, a cached repo passes and logs at debug level."""
    expected_debug = (
        'Download allowed from "dummy-repo" because repo is already in cache.'
    )

    with caplog.at_level(logging.DEBUG, logger="annif"):
        annif.hfh_util.check_is_download_allowed(False, "dummy-repo")

    assert expected_debug in caplog.text
43+
44+
45+
@mock.patch("huggingface_hub.scan_cache_dir")  # Bypass CacheNotFound on CI/CD
def test_download_not_allowed(mock_scan_cache_dir):
    """An untrusted repo that is absent from the cache must be rejected.

    `scan_cache_dir` is patched so the test does not depend on a real Hugging
    Face Hub cache existing on the machine running the tests.
    """
    trust_repo = False
    repo_id = "dummy-repo"
    # Configure the object that the patched scan_cache_dir() actually returns.
    # Patching HFCacheInfo separately had no effect: the mocked scan function
    # never instantiates it, and the test only passed because a MagicMock
    # iterates as an empty sequence by default.
    mock_scan_cache_dir.return_value.repos = frozenset()

    with pytest.raises(OperationFailedException) as excinfo:
        annif.hfh_util.check_is_download_allowed(trust_repo, repo_id)
    # Checked after the context manager exits; inside it this would not run.
    assert (
        str(excinfo.value)
        == 'Cannot download projects from untrusted repo "dummy-repo"'
    )
1458

1559

1660
def test_archive_dir(testdatadir):

0 commit comments

Comments
 (0)