diff --git a/src/huggingface_hub/cli/_errors.py b/src/huggingface_hub/cli/_errors.py index c2b24c310c..68561bb53a 100644 --- a/src/huggingface_hub/cli/_errors.py +++ b/src/huggingface_hub/cli/_errors.py @@ -26,24 +26,63 @@ ) -CLI_ERROR_MAPPINGS: dict[type[Exception], Callable[[Exception], str]] = { - BucketNotFoundError: lambda e: ( - "Bucket not found. Check the bucket id (namespace/name). If the bucket is private, make sure you are authenticated." - ), - RepositoryNotFoundError: lambda e: ( - "Repository not found. Check the `repo_id` and `repo_type` parameters. If the repo is private, make sure you are authenticated." - ), - RevisionNotFoundError: lambda e: "Revision not found. Check the `revision` parameter.", - GatedRepoError: lambda e: "Access denied. This repository requires approval.", - LocalTokenNotFoundError: lambda e: "Not logged in. Run 'hf auth login' first.", - RemoteEntryNotFoundError: lambda e: "File not found in repository.", - HfHubHTTPError: lambda e: str(e), - ValueError: lambda e: f"Invalid value. {e}", +def _format_repo_not_found(error: RepositoryNotFoundError) -> str: + label = error.repo_type.capitalize() if error.repo_type else "Repository" + if error.repo_id: + msg = f"{label} '{error.repo_id}' not found." + else: + msg = f"{label} not found." + msg += " If the repo is private, make sure you are authenticated." + return msg + + +def _format_gated_repo(error: GatedRepoError) -> str: + label = error.repo_type if error.repo_type else "repository" + if error.repo_id: + return f"Access denied. {label.capitalize()} '{error.repo_id}' requires approval." + return f"Access denied. This {label} requires approval." + + +def _format_bucket_not_found(error: BucketNotFoundError) -> str: + if error.bucket_id: + return f"Bucket '{error.bucket_id}' not found. If the bucket is private, make sure you are authenticated." + return "Bucket not found. Check the bucket id (namespace/name). If the bucket is private, make sure you are authenticated." + + +def _format_entry_not_found(error: RemoteEntryNotFoundError) -> str: + label = error.repo_type if error.repo_type else "repository" + url = str(error.response.url) if error.response else None + if error.repo_id: + msg = f"File not found in {label} '{error.repo_id}'." + else: + msg = f"File not found in {label}." + if url: + msg += f"\nURL: {url}" + return msg + + +def _format_revision_not_found(error: RevisionNotFoundError) -> str: + label = error.repo_type if error.repo_type else "repository" + if error.repo_id: + return f"Revision not found in {label} '{error.repo_id}'." + return f"Revision not found in {label}. Check the revision parameter." + + +CLI_ERROR_MAPPINGS: dict[type[Exception], Callable[[Exception], str]] = { # type: ignore + # GatedRepoError must come before RepositoryNotFoundError (it's a subclass). + GatedRepoError: _format_gated_repo, # type: ignore[dict-item] + BucketNotFoundError: _format_bucket_not_found, # type: ignore[dict-item] + RepositoryNotFoundError: _format_repo_not_found, # type: ignore[dict-item] + RevisionNotFoundError: _format_revision_not_found, # type: ignore[dict-item] + LocalTokenNotFoundError: lambda _: "Not logged in. Run 'hf auth login' first.", + RemoteEntryNotFoundError: _format_entry_not_found, # type: ignore[dict-item] + HfHubHTTPError: lambda error: str(error), + ValueError: lambda error: f"Invalid value. {error}", } -def format_known_exception(e: Exception) -> Optional[str]: +def format_known_exception(error: Exception) -> Optional[str]: for exc_type, formatter in CLI_ERROR_MAPPINGS.items(): - if isinstance(e, exc_type): - return formatter(e) + if isinstance(error, exc_type): + return formatter(error) return None diff --git a/src/huggingface_hub/errors.py b/src/huggingface_hub/errors.py index 6acaec8225..26a86e73d1 100644 --- a/src/huggingface_hub/errors.py +++ b/src/huggingface_hub/errors.py @@ -192,6 +192,10 @@ class BucketNotFoundError(HfHubHTTPError): """ Raised when trying to access a bucket that does not exist. + Attributes: + bucket_id (`str` or `None`): + The bucket id (namespace/name) that was not found, if it could be determined from the request URL. + Example: ```py @@ -206,6 +210,8 @@ class BucketNotFoundError(HfHubHTTPError): ``` """ + bucket_id: Optional[str] = None + # REPOSITORY ERRORS @@ -215,6 +221,12 @@ class RepositoryNotFoundError(HfHubHTTPError): Raised when trying to access a hf.co URL with an invalid repository name, or with a private repo name the user does not have access to. + Attributes: + repo_id (`str` or `None`): + The repo id that was not found, if it could be determined from the request URL. + repo_type (`str` or `None`): + The repo type ("model", "dataset", or "space"), if it could be determined from the request URL. + Example: ```py @@ -230,6 +242,9 @@ class RepositoryNotFoundError(HfHubHTTPError): ``` """ + repo_id: Optional[str] = None + repo_type: Optional[str] = None + class GatedRepoError(RepositoryNotFoundError): """ @@ -279,6 +294,12 @@ class RevisionNotFoundError(HfHubHTTPError): Raised when trying to access a hf.co URL with a valid repository but an invalid revision. + Attributes: + repo_id (`str` or `None`): + The repo id, if it could be determined from the request URL. + repo_type (`str` or `None`): + The repo type ("model", "dataset", or "space"), if it could be determined from the request URL. + Example: ```py @@ -291,6 +312,9 @@ class RevisionNotFoundError(HfHubHTTPError): ``` """ + repo_id: Optional[str] = None + repo_type: Optional[str] = None + # ENTRY ERRORS class EntryNotFoundError(Exception): @@ -316,6 +340,12 @@ class RemoteEntryNotFoundError(HfHubHTTPError, EntryNotFoundError): Raised when trying to access a hf.co URL with a valid repository and revision but an invalid filename. + Attributes: + repo_id (`str` or `None`): + The repo id, if it could be determined from the request URL. + repo_type (`str` or `None`): + The repo type ("model", "dataset", or "space"), if it could be determined from the request URL. + Example: ```py @@ -328,6 +358,9 @@ class RemoteEntryNotFoundError(HfHubHTTPError, EntryNotFoundError): ``` """ + repo_id: Optional[str] = None + repo_type: Optional[str] = None + class LocalEntryNotFoundError(FileNotFoundError, EntryNotFoundError): """ diff --git a/src/huggingface_hub/utils/_http.py b/src/huggingface_hub/utils/_http.py index 38f606fd8e..79a694a637 100644 --- a/src/huggingface_hub/utils/_http.py +++ b/src/huggingface_hub/utils/_http.py @@ -25,7 +25,7 @@ from contextlib import contextmanager from dataclasses import dataclass from shlex import quote -from typing import Any, Callable, Generator, Mapping, Optional, Union +from typing import Any, Callable, Generator, Mapping, Optional, TypeVar, Union from urllib.parse import urlparse import httpx @@ -169,6 +169,47 @@ def parse_ratelimit_headers(headers: Mapping[str, str]) -> Optional[RateLimitInf flags=re.VERBOSE, ) +# Regex to extract repo_type and repo_id from API URLs. +# Captures: group(1) = repo_type plural (models/datasets/spaces), group(2) = first path segment, group(3) = optional second segment. +_REPO_ID_FROM_URL_REGEX = re.compile(r"^https?://[^/]+/api/(models|datasets|spaces)/([^/]+)(?:/([^/]+))?") + +# Regex to extract bucket_id (namespace/name) from bucket API URLs. +_BUCKET_ID_FROM_URL_REGEX = re.compile(r"^https?://[^/]+/api/buckets/([^/]+/[^/]+)") + +# Sub-paths that follow a repo_id in API URLs (not part of the repo name). +_REPO_URL_SUBPATHS = {"resolve", "tree", "blob", "raw", "refs", "commit", "discussions", "settings", "revision"} + + +def _parse_repo_info_from_url(url: str) -> tuple[Optional[str], Optional[str]]: + """Extract (repo_type, repo_id) from an API URL. + + Returns canonical repo_type values: "model", "dataset", "space" (or None). + + Examples: + >>> _parse_repo_info_from_url("https://huggingface.co/api/models/user/repo") + ("model", "user/repo") + >>> _parse_repo_info_from_url("https://huggingface.co/api/datasets/user/repo/resolve/main/data.csv") + ("dataset", "user/repo") + >>> _parse_repo_info_from_url("https://huggingface.co/api/models/bert-base-cased/resolve/main/config.json") + ("model", "bert-base-cased") + """ + match = _REPO_ID_FROM_URL_REGEX.search(url) + if not match: + return None, None + repo_type = constants.REPO_TYPES_MAPPING.get(match.group(1)) + first, second = match.group(2), match.group(3) + if second and second not in _REPO_URL_SUBPATHS: + repo_id = f"{first}/{second}" + else: + repo_id = first + return repo_type, repo_id + + +def _parse_bucket_id_from_url(url: str) -> Optional[str]: + """Extract bucket_id (namespace/name) from a bucket API URL.""" + match = _BUCKET_ID_FROM_URL_REGEX.search(url) + return match.group(1) if match else None + def hf_request_event_hook(request: httpx.Request) -> None: """ @@ -725,19 +766,34 @@ def hf_raise_for_status(response: httpx.Response, endpoint_name: Optional[str] = error_code = response.headers.get("X-Error-Code") error_message = response.headers.get("X-Error-Message") + # Parse repo info from request URL (used to enrich errors below) + request_url = ( + str(response.request.url) if response.request is not None and response.request.url is not None else None + ) + repo_type, repo_id = _parse_repo_info_from_url(request_url) if request_url else (None, None) + if error_code == "RevisionNotFound": message = f"{response.status_code} Client Error." + "\n\n" + f"Revision Not Found for url: {response.url}." - raise _format(RevisionNotFoundError, message, response) from e + revision_err = _format(RevisionNotFoundError, message, response) + revision_err.repo_type = repo_type + revision_err.repo_id = repo_id + raise revision_err from e elif error_code == "EntryNotFound": message = f"{response.status_code} Client Error." + "\n\n" + f"Entry Not Found for url: {response.url}." - raise _format(RemoteEntryNotFoundError, message, response) from e + entry_err = _format(RemoteEntryNotFoundError, message, response) + entry_err.repo_type = repo_type + entry_err.repo_id = repo_id + raise entry_err from e elif error_code == "GatedRepo": message = ( f"{response.status_code} Client Error." + "\n\n" + f"Cannot access gated repo for url {response.url}." ) - raise _format(GatedRepoError, message, response) from e + gated_err = _format(GatedRepoError, message, response) + gated_err.repo_type = repo_type + gated_err.repo_id = repo_id + raise gated_err from e elif error_message == "Access to this resource is disabled.": message = ( @@ -751,9 +807,8 @@ def hf_raise_for_status(response: httpx.Response, endpoint_name: Optional[str] = elif ( error_code == "RepoNotFound" - and response.request is not None - and response.request.url is not None - and BUCKET_API_REGEX.search(str(response.request.url)) is not None + and request_url is not None + and BUCKET_API_REGEX.search(request_url) is not None ): message = ( f"{response.status_code} Client Error." @@ -762,14 +817,15 @@ def hf_raise_for_status(response: httpx.Response, endpoint_name: Optional[str] = + "\nPlease make sure you specified the correct bucket id (namespace/name)." + "\nIf the bucket is private, make sure you are authenticated." ) - raise _format(BucketNotFoundError, message, response) from e + bucket_err = _format(BucketNotFoundError, message, response) + bucket_err.bucket_id = _parse_bucket_id_from_url(request_url) + raise bucket_err from e elif error_code == "RepoNotFound" or ( response.status_code == 401 and error_message != "Invalid credentials in Authorization header" - and response.request is not None - and response.request.url is not None - and REPO_API_REGEX.search(str(response.request.url)) is not None + and request_url is not None + and REPO_API_REGEX.search(request_url) is not None ): # 401 is misleading as it is returned for: # - private and gated repos if user is not authenticated @@ -785,7 +841,10 @@ def hf_raise_for_status(response: httpx.Response, endpoint_name: Optional[str] = " make sure you are authenticated. For more details, see" " https://huggingface.co/docs/huggingface_hub/authentication" ) - raise _format(RepositoryNotFoundError, message, response) from e + repo_err = _format(RepositoryNotFoundError, message, response) + repo_err.repo_type = repo_type + repo_err.repo_id = repo_id + raise repo_err from e elif response.status_code == 400: message = ( @@ -857,7 +916,10 @@ def _warn_on_warning_headers(response: httpx.Response) -> None: logger.warning(message) -def _format(error_type: type[HfHubHTTPError], custom_message: str, response: httpx.Response) -> HfHubHTTPError: +_HfHubHTTPErrorT = TypeVar("_HfHubHTTPErrorT", bound=HfHubHTTPError) + + +def _format(error_type: type[_HfHubHTTPErrorT], custom_message: str, response: httpx.Response) -> _HfHubHTTPErrorT: server_errors = [] # Retrieve server error from header diff --git a/tests/test_cli_errors.py b/tests/test_cli_errors.py new file mode 100644 index 0000000000..0163268132 --- /dev/null +++ b/tests/test_cli_errors.py @@ -0,0 +1,116 @@ +"""Tests for CLI error formatting utilities.""" + +from unittest.mock import Mock + +import httpx + +from huggingface_hub.cli._errors import ( + _format_bucket_not_found, + _format_entry_not_found, + _format_gated_repo, + _format_repo_not_found, + _format_revision_not_found, +) +from huggingface_hub.errors import ( + BucketNotFoundError, + GatedRepoError, + RemoteEntryNotFoundError, + RepositoryNotFoundError, + RevisionNotFoundError, +) + + +def _make_error(cls, **attrs): + """Helper to create an HfHubHTTPError subclass with custom attributes.""" + response = Mock(spec=httpx.Response) + response.headers = httpx.Headers({}) + response.request = Mock(spec=httpx.Request) + err = cls("test", response=response) + for key, value in attrs.items(): + setattr(err, key, value) + return err + + +class TestFormatRepoNotFound: + def test_with_repo_id_and_type(self): + err = _make_error(RepositoryNotFoundError, repo_id="user/repo", repo_type="model") + assert ( + _format_repo_not_found(err) + == "Model 'user/repo' not found. If the repo is private, make sure you are authenticated." + ) + + def test_with_repo_id_dataset(self): + err = _make_error(RepositoryNotFoundError, repo_id="user/data", repo_type="dataset") + assert "Dataset 'user/data' not found." in _format_repo_not_found(err) + + def test_with_repo_id_no_type(self): + err = _make_error(RepositoryNotFoundError, repo_id="user/repo", repo_type=None) + assert "Repository 'user/repo' not found." in _format_repo_not_found(err) + + def test_without_repo_id(self): + err = _make_error(RepositoryNotFoundError, repo_id=None, repo_type=None) + msg = _format_repo_not_found(err) + assert "Repository not found." in msg + assert "authenticated" in msg + + +class TestFormatGatedRepo: + def test_with_repo_id(self): + err = _make_error(GatedRepoError, repo_id="user/gated") + assert _format_gated_repo(err) == "Access denied. Repository 'user/gated' requires approval." + + def test_without_repo_id(self): + err = _make_error(GatedRepoError, repo_id=None) + assert _format_gated_repo(err) == "Access denied. This repository requires approval." + + +class TestFormatBucketNotFound: + def test_with_bucket_id(self): + err = _make_error(BucketNotFoundError, bucket_id="ns/bucket") + msg = _format_bucket_not_found(err) + assert "Bucket 'ns/bucket' not found." in msg + assert "authenticated" in msg + + def test_without_bucket_id(self): + err = _make_error(BucketNotFoundError, bucket_id=None) + msg = _format_bucket_not_found(err) + assert "Bucket not found." in msg + assert "namespace/name" in msg + + +class TestFormatEntryNotFound: + def test_with_repo_id_and_type(self): + err = _make_error(RemoteEntryNotFoundError, repo_id="user/repo", repo_type="dataset") + msg = _format_entry_not_found(err) + assert "File not found in dataset 'user/repo'." in msg + + def test_with_repo_id_no_type(self): + err = _make_error(RemoteEntryNotFoundError, repo_id="user/repo", repo_type=None) + msg = _format_entry_not_found(err) + assert "File not found in repository 'user/repo'." in msg + + def test_without_repo_id(self): + err = _make_error(RemoteEntryNotFoundError, repo_id=None, repo_type=None) + msg = _format_entry_not_found(err) + assert "File not found in repository." in msg + + def test_includes_url(self): + err = _make_error(RemoteEntryNotFoundError, repo_id="user/repo", repo_type="model") + err.response.url = "https://huggingface.co/api/models/user/repo/resolve/main/missing.bin" + msg = _format_entry_not_found(err) + assert "File not found in model 'user/repo'." in msg + assert "URL: https://huggingface.co/api/models/user/repo/resolve/main/missing.bin" in msg + + +class TestFormatRevisionNotFound: + def test_with_repo_id(self): + err = _make_error(RevisionNotFoundError, repo_id="user/repo", repo_type=None) + assert _format_revision_not_found(err) == "Revision not found in repository 'user/repo'." + + def test_with_repo_id_and_type(self): + err = _make_error(RevisionNotFoundError, repo_id="user/repo", repo_type="dataset") + assert _format_revision_not_found(err) == "Revision not found in dataset 'user/repo'." + + def test_without_repo_id(self): + err = _make_error(RevisionNotFoundError, repo_id=None, repo_type=None) + assert _format_revision_not_found(err) == "Revision not found in repository. Check the revision parameter." diff --git a/tests/test_utils_http.py b/tests/test_utils_http.py index 4a87dd20ad..fc81e2f8d5 100644 --- a/tests/test_utils_http.py +++ b/tests/test_utils_http.py @@ -17,6 +17,8 @@ _WARNED_TOPICS, RateLimitInfo, _adjust_range_header, + _parse_bucket_id_from_url, + _parse_repo_info_from_url, _warn_on_warning_headers, default_client_factory, fix_hf_endpoint_in_url, @@ -638,3 +640,65 @@ def test_warn_on_warning_headers(self, caplog): assert len(warnings) == 1 assert warnings == ["Another warning."] assert "Topic4" in _WARNED_TOPICS + + +class TestParseRepoInfoFromUrl: + def test_api_model_with_namespace(self): + assert _parse_repo_info_from_url("https://huggingface.co/api/models/user/repo") == ("model", "user/repo") + + def test_api_dataset_with_namespace(self): + assert _parse_repo_info_from_url("https://huggingface.co/api/datasets/user/repo") == ("dataset", "user/repo") + + def test_api_space_with_namespace(self): + assert _parse_repo_info_from_url("https://huggingface.co/api/spaces/user/repo") == ("space", "user/repo") + + def test_api_model_without_namespace(self): + assert _parse_repo_info_from_url("https://huggingface.co/api/models/bert-base-cased") == ( + "model", + "bert-base-cased", + ) + + def test_api_model_with_resolve_subpath(self): + repo_type, repo_id = _parse_repo_info_from_url( + "https://huggingface.co/api/models/user/repo/resolve/main/config.json" + ) + assert repo_type == "model" + assert repo_id == "user/repo" + + def test_api_dataset_with_tree_subpath(self): + repo_type, repo_id = _parse_repo_info_from_url("https://huggingface.co/api/datasets/user/repo/tree/main") + assert repo_type == "dataset" + assert repo_id == "user/repo" + + def test_api_model_single_name_with_subpath(self): + repo_type, repo_id = _parse_repo_info_from_url( + "https://huggingface.co/api/models/bert-base-cased/resolve/main/config.json" + ) + assert repo_type == "model" + assert repo_id == "bert-base-cased" + + def test_non_matching_url(self): + assert _parse_repo_info_from_url("https://huggingface.co/some/other/path") == (None, None) + + def test_staging_url(self): + assert _parse_repo_info_from_url("https://hub-ci.huggingface.co/api/models/user/repo") == ( + "model", + "user/repo", + ) + + +class TestParseBucketIdFromUrl: + def test_bucket_url(self): + assert _parse_bucket_id_from_url("https://huggingface.co/api/buckets/namespace/name") == "namespace/name" + + def test_bucket_url_with_subpath(self): + assert ( + _parse_bucket_id_from_url("https://huggingface.co/api/buckets/namespace/name/tree/prefix") + == "namespace/name" + ) + + def test_non_bucket_url(self): + assert _parse_bucket_id_from_url("https://huggingface.co/api/models/user/repo") is None + + def test_http_url(self): + assert _parse_bucket_id_from_url("http://localhost:8080/api/buckets/ns/name") == "ns/name"