Skip to content

Commit

Permalink
Expose RepoUrl info in CommitInfo object (#2487)
Browse files Browse the repository at this point in the history
* Expose RepoUrl info in CommitInfo object

* add test

* fix test

* fix tests
  • Loading branch information
Wauplin authored Aug 26, 2024
1 parent 6438044 commit bd209e7
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 7 deletions.
10 changes: 10 additions & 0 deletions src/huggingface_hub/hf_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,9 @@ class CommitInfo(str):
`create_pr=True` is passed. Can be passed as `discussion_num` in
[`get_discussion_details`]. Example: `1`.
repo_url (`RepoUrl`):
Repo URL of the commit containing info like repo_id, repo_type, etc.
_url (`str`, *optional*):
Legacy url for `str` compatibility. Can be the url to the uploaded file on the Hub (if returned by
[`upload_file`]), to the uploaded folder on the Hub (if returned by [`upload_folder`]) or to the commit on
Expand All @@ -402,6 +405,9 @@ class CommitInfo(str):
oid: str
pr_url: Optional[str] = None

# Computed from `commit_url` in `__post_init__`
repo_url: RepoUrl = field(init=False)

# Computed from `pr_url` in `__post_init__`
pr_revision: Optional[str] = field(init=False)
pr_num: Optional[str] = field(init=False)
Expand All @@ -417,6 +423,10 @@ def __post_init__(self):
See https://docs.python.org/3.10/library/dataclasses.html#post-init-processing.
"""
# Repo info
self.repo_url = RepoUrl(self.commit_url.split("/commit/")[0])

# PR info
if self.pr_url is not None:
self.pr_revision = _parse_revision_from_pr_url(self.pr_url)
self.pr_num = int(self.pr_revision.split("/")[-1])
Expand Down
3 changes: 1 addition & 2 deletions src/huggingface_hub/utils/_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,8 +147,7 @@ def hf_raise_for_status(response: Response, endpoint_name: Optional[str] = None)
message = (
f"\n\n{response.status_code} Forbidden: {error_message}."
+ f"\nCannot access content at: {response.url}."
+ "\nIf you are trying to create or update content, "
+ "make sure your token has the correct permissions."
+ "\nMake sure your token has the correct permissions."
)
raise HfHubHTTPError(message, response=response) from e

Expand Down
27 changes: 22 additions & 5 deletions tests/test_hf_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@
DUMMY_DATASET_ID_REVISION_ONE_SPECIFIC_COMMIT,
DUMMY_MODEL_ID,
DUMMY_MODEL_ID_REVISION_ONE_SPECIFIC_COMMIT,
ENDPOINT_PRODUCTION,
SAMPLE_DATASET_IDENTIFIER,
repo_name,
require_git_lfs,
Expand Down Expand Up @@ -2480,6 +2481,7 @@ def setUp(self) -> None:
self.api.list_repo_files = self.repo_files_mock

self.create_commit_mock = Mock()
self.create_commit_mock.return_value.commit_url = f"{ENDPOINT_STAGING}/username/repo_id/commit/dummy_sha"
self.create_commit_mock.return_value.pr_url = None
self.api.create_commit = self.create_commit_mock

Expand Down Expand Up @@ -2698,7 +2700,7 @@ def test_repo_type_and_id_from_hf_id_on_correct_values(self):

for key, value in possible_values.items():
self.assertEqual(
repo_type_and_id_from_hf_id(key, hub_url="https://huggingface.co"),
repo_type_and_id_from_hf_id(key, hub_url=ENDPOINT_PRODUCTION),
tuple(value),
)

Expand All @@ -2711,7 +2713,7 @@ def test_repo_type_and_id_from_hf_id_on_wrong_values(self):
"spaeces/user/id", # with typo in repo type
]:
with self.assertRaises(ValueError):
repo_type_and_id_from_hf_id(hub_id, hub_url="https://huggingface.co")
repo_type_and_id_from_hf_id(hub_id, hub_url=ENDPOINT_PRODUCTION)


class HfApiDiscussionsTest(HfApiCommonTest):
Expand Down Expand Up @@ -3041,12 +3043,13 @@ def greet(name):
iface.launch()
""".encode()

@with_production_testing
def setUp(self):
super().setUp()

# If generating new VCR => use personal token and REMOVE IT from the VCR
self.repo_id = "user/tmp_test_space" # no need to be unique as it's a VCRed test
self.api = HfApi(token="hf_fake_token", endpoint="https://huggingface.co")
self.api = HfApi(token="hf_fake_token", endpoint=ENDPOINT_PRODUCTION)

# Create a Space
self.api.create_repo(repo_id=self.repo_id, repo_type="space", space_sdk="gradio", private=True)
Expand Down Expand Up @@ -3132,6 +3135,7 @@ def test_static_space_runtime(self) -> None:
runtime = self.api.get_space_runtime("victor/static-space")
self.assertIsInstance(runtime.raw, dict)

@with_production_testing
def test_pause_and_restart_space(self) -> None:
# Upload a fake app.py file
self.api.upload_file(path_or_fileobj=b"", path_in_repo="app.py", repo_id=self.repo_id, repo_type="space")
Expand Down Expand Up @@ -3671,7 +3675,7 @@ def test_user_agent_is_overwritten(self, mock_build_hf_headers: Mock) -> None:
self.assertEqual(mock_build_hf_headers.call_args[1]["user_agent"], {"A": "B"})


@patch("huggingface_hub.constants.ENDPOINT", "https://huggingface.co")
@patch("huggingface_hub.constants.ENDPOINT", ENDPOINT_PRODUCTION)
class RepoUrlTest(unittest.TestCase):
def test_repo_url_class(self):
url = RepoUrl("https://huggingface.co/gpt2")
Expand All @@ -3697,7 +3701,7 @@ def test_repo_url_class(self):
def test_repo_url_endpoint(self):
# Implicit endpoint
url = RepoUrl("https://huggingface.co/gpt2")
self.assertEqual(url.endpoint, "https://huggingface.co")
self.assertEqual(url.endpoint, ENDPOINT_PRODUCTION)

# Explicit endpoint
url = RepoUrl("https://example.com/gpt2", endpoint="https://example.com")
Expand Down Expand Up @@ -3751,6 +3755,19 @@ def test_repo_url_canonical_dataset(self):
self.assertEqual(url.repo_id, "squad")
self.assertEqual(url.repo_type, "dataset")

def test_repo_url_in_commit_info(self):
info = CommitInfo(
commit_url="https://huggingface.co/Wauplin/test-repo-id-mixin/commit/52d172a8b276e529d5260d6f3f76c85be5889dee",
commit_message="Dummy message",
commit_description="Dummy description",
oid="52d172a8b276e529d5260d6f3f76c85be5889dee",
pr_url=None,
)
assert isinstance(info.repo_url, RepoUrl)
assert info.repo_url.endpoint == "https://huggingface.co"
assert info.repo_url.repo_id == "Wauplin/test-repo-id-mixin"
assert info.repo_url.repo_type == "model"


class HfApiDuplicateSpaceTest(HfApiCommonTest):
@unittest.skip("Duplicating Space doesn't work on staging.")
Expand Down

0 comments on commit bd209e7

Please sign in to comment.