diff --git a/src/huggingface_hub/_login.py b/src/huggingface_hub/_login.py index 7700ab5b23..90a130fa61 100644 --- a/src/huggingface_hub/_login.py +++ b/src/huggingface_hub/_login.py @@ -19,6 +19,8 @@ from pathlib import Path from typing import Optional +import typer + from . import constants from .utils import ( ANSI, @@ -244,8 +246,6 @@ def interpreter_login(*, skip_if_logged_in: bool = False) -> None: logger.info("User is already logged in.") return - from .cli.cache import _ask_for_confirmation_no_tui - print(_HF_LOGO_ASCII) if get_token() is not None: logger.info( @@ -261,7 +261,7 @@ def interpreter_login(*, skip_if_logged_in: bool = False) -> None: if os.name == "nt": logger.info("Token can be pasted using 'Right-Click'.") token = getpass("Enter your token (input will not be visible): ") - add_to_git_credential = _ask_for_confirmation_no_tui("Add token as git credential?") + add_to_git_credential = typer.confirm("Add token as git credential?") _login(token=token, add_to_git_credential=add_to_git_credential) diff --git a/src/huggingface_hub/cli/cache.py b/src/huggingface_hub/cli/cache.py index 35f7540821..939c991833 100644 --- a/src/huggingface_hub/cli/cache.py +++ b/src/huggingface_hub/cli/cache.py @@ -12,368 +12,625 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-"""Contains the 'hf cache' command group with 'scan' and 'delete' subcommands.""" +"""Contains the 'hf cache' command group with cache management subcommands.""" -import os +import csv +import json +import re +import sys import time +from collections import defaultdict +from dataclasses import dataclass from enum import Enum -from functools import wraps -from tempfile import mkstemp -from typing import Annotated, Any, Callable, Iterable, Optional, Union +from typing import Annotated, Any, Callable, Dict, List, Mapping, Optional, Tuple import typer -from ..utils import ANSI, CachedRepoInfo, CachedRevisionInfo, CacheNotFound, HFCacheInfo, scan_cache_dir, tabulate +from ..utils import ( + ANSI, + CachedRepoInfo, + CachedRevisionInfo, + CacheNotFound, + HFCacheInfo, + _format_size, + scan_cache_dir, + tabulate, +) +from ..utils._parsing import parse_duration, parse_size from ._cli_utils import typer_factory -# --- DELETE helpers (from delete_cache.py) --- -try: - from InquirerPy import inquirer - from InquirerPy.base.control import Choice - from InquirerPy.separator import Separator +cache_cli = typer_factory(help="Manage local cache directory.") + + +#### Cache helper utilities + + +class OutputFormat(str, Enum): + table = "table" + json = "json" + csv = "csv" + + +@dataclass(frozen=True) +class _DeletionResolution: + revisions: frozenset[str] + selected: dict[CachedRepoInfo, frozenset[CachedRevisionInfo]] + missing: tuple[str, ...] 
+ + +_FILTER_PATTERN = re.compile(r"^(?P<key>[a-zA-Z_]+)\s*(?P<op>==|!=|>=|<=|>|<|=)\s*(?P<value>.+)$") +_ALLOWED_OPERATORS = {"=", "!=", ">", "<", ">=", "<="} +_FILTER_KEYS = {"accessed", "modified", "refs", "size", "type"} + + +@dataclass(frozen=True) +class CacheDeletionCounts: + """Simple counters summarizing cache deletions for CLI messaging.""" + + repo_count: int + partial_revision_count: int + total_revision_count: int + + +CacheEntry = Tuple[CachedRepoInfo, Optional[CachedRevisionInfo]] +RepoRefsMap = Dict[CachedRepoInfo, frozenset[str]] + - _inquirer_py_available = True -except ImportError: - _inquirer_py_available = False +def summarize_deletions( + selected_by_repo: Mapping[CachedRepoInfo, frozenset[CachedRevisionInfo]], +) -> CacheDeletionCounts: + """Summarize deletions across repositories.""" + repo_count = 0 + total_revisions = 0 + revisions_in_full_repos = 0 -_CANCEL_DELETION_STR = "CANCEL_DELETION" + for repo, revisions in selected_by_repo.items(): + total_revisions += len(revisions) + if len(revisions) == len(repo.revisions): + repo_count += 1 + revisions_in_full_repos += len(revisions) + partial_revision_count = total_revisions - revisions_in_full_repos + return CacheDeletionCounts(repo_count, partial_revision_count, total_revisions) -class SortingOption(str, Enum): - alphabetical = "alphabetical" - lastUpdated = "lastUpdated" - lastUsed = "lastUsed" - size = "size" +def print_cache_selected_revisions(selected_by_repo: Mapping[CachedRepoInfo, frozenset[CachedRevisionInfo]]) -> None: + """Pretty-print selected cache revisions during confirmation prompts.""" + for repo in sorted(selected_by_repo.keys(), key=lambda repo: (repo.repo_type, repo.repo_id.lower())): + repo_key = f"{repo.repo_type}/{repo.repo_id}" + revisions = sorted(selected_by_repo[repo], key=lambda rev: rev.commit_hash) + if len(revisions) == len(repo.revisions): + print(f" - {repo_key} (entire repo)") + continue -def require_inquirer_py(fn: Callable) -> Callable: - @wraps(fn) - def _inner(*args, 
**kwargs): - if not _inquirer_py_available: - raise ImportError( - "The 'cache delete' command requires extra dependencies for the TUI.\n" - "Please run 'pip install \"huggingface_hub[cli]\"' to install them.\n" - "Otherwise, disable TUI using the '--disable-tui' flag." + print(f" - {repo_key}:") + for revision in revisions: + refs = " ".join(sorted(revision.refs)) or "(detached)" + print(f" {revision.commit_hash} [{refs}] {revision.size_on_disk_str}") + + +def build_cache_index( + hf_cache_info: HFCacheInfo, +) -> Tuple[ + Dict[str, CachedRepoInfo], + Dict[str, Tuple[CachedRepoInfo, CachedRevisionInfo]], +]: + """Create lookup tables so CLI commands can resolve repo ids and revisions quickly.""" + repo_lookup: dict[str, CachedRepoInfo] = {} + revision_lookup: dict[str, tuple[CachedRepoInfo, CachedRevisionInfo]] = {} + for repo in hf_cache_info.repos: + repo_key = repo.cache_id.lower() + repo_lookup[repo_key] = repo + for revision in repo.revisions: + revision_lookup[revision.commit_hash.lower()] = (repo, revision) + return repo_lookup, revision_lookup + + +def collect_cache_entries( + hf_cache_info: HFCacheInfo, *, include_revisions: bool +) -> Tuple[List[CacheEntry], RepoRefsMap]: + """Flatten cache metadata into rows consumed by `hf cache ls`.""" + entries: List[CacheEntry] = [] + repo_refs_map: RepoRefsMap = {} + sorted_repos = sorted(hf_cache_info.repos, key=lambda repo: (repo.repo_type, repo.repo_id.lower())) + for repo in sorted_repos: + repo_refs_map[repo] = frozenset({ref for revision in repo.revisions for ref in revision.refs}) + if include_revisions: + for revision in sorted(repo.revisions, key=lambda rev: rev.commit_hash): + entries.append((repo, revision)) + else: + entries.append((repo, None)) + if include_revisions: + entries.sort( + key=lambda entry: ( + entry[0].cache_id, + entry[1].commit_hash if entry[1] is not None else "", ) - return fn(*args, **kwargs) + ) + else: + entries.sort(key=lambda entry: entry[0].cache_id) + return entries, 
repo_refs_map + + +def compile_cache_filter( + expr: str, repo_refs_map: RepoRefsMap +) -> Callable[[CachedRepoInfo, Optional[CachedRevisionInfo], float], bool]: + """Convert a `hf cache ls` filter expression into the yes/no test we apply to each cache entry before displaying it.""" + match = _FILTER_PATTERN.match(expr.strip()) + if not match: + raise ValueError(f"Invalid filter expression: '{expr}'.") + + key = match.group("key").lower() + op = match.group("op") + value_raw = match.group("value").strip() + + if op not in _ALLOWED_OPERATORS: + raise ValueError(f"Unsupported operator '{op}' in filter '{expr}'. Must be one of {list(_ALLOWED_OPERATORS)}.") + + if key not in _FILTER_KEYS: + raise ValueError(f"Unsupported filter key '{key}' in '{expr}'. Must be one of {list(_FILTER_KEYS)}.") + # at this point we know that key is in `_FILTER_KEYS` + if key == "size": + size_threshold = parse_size(value_raw) + return lambda repo, revision, _: _compare_numeric( + revision.size_on_disk if revision is not None else repo.size_on_disk, + op, + size_threshold, + ) + + if key in {"modified", "accessed"}: + seconds = parse_duration(value_raw.strip()) + + def _time_filter(repo: CachedRepoInfo, revision: Optional[CachedRevisionInfo], now: float) -> bool: + timestamp = ( + repo.last_accessed + if key == "accessed" + else revision.last_modified + if revision is not None + else repo.last_modified + ) + if timestamp is None: + return False + return _compare_numeric(now - timestamp, op, seconds) + + return _time_filter + + if key == "type": + expected = value_raw.lower() + + if op != "=": + raise ValueError(f"Only '=' is supported for 'type' filters. Got '{op}'.") + + def _type_filter(repo: CachedRepoInfo, revision: Optional[CachedRevisionInfo], _: float) -> bool: + return repo.repo_type.lower() == expected + + return _type_filter + + else: # key == "refs" + if op != "=": + raise ValueError(f"Only '=' is supported for 'refs' filters. 
Got {op}.") + + def _refs_filter(repo: CachedRepoInfo, revision: Optional[CachedRevisionInfo], _: float) -> bool: + refs = revision.refs if revision is not None else repo_refs_map.get(repo, frozenset()) + return value_raw.lower() in [ref.lower() for ref in refs] + + return _refs_filter + + +def _build_cache_export_payload( + entries: List[CacheEntry], *, include_revisions: bool, repo_refs_map: RepoRefsMap +) -> List[Dict[str, Any]]: + """Normalize cache entries into serializable records for JSON/CSV exports.""" + payload: List[Dict[str, Any]] = [] + for repo, revision in entries: + if include_revisions: + if revision is None: + continue + record: Dict[str, Any] = { + "repo_id": repo.repo_id, + "repo_type": repo.repo_type, + "revision": revision.commit_hash, + "snapshot_path": str(revision.snapshot_path), + "size_on_disk": revision.size_on_disk, + "last_accessed": repo.last_accessed, + "last_modified": revision.last_modified, + "refs": sorted(revision.refs), + } + else: + record = { + "repo_id": repo.repo_id, + "repo_type": repo.repo_type, + "size_on_disk": repo.size_on_disk, + "last_accessed": repo.last_accessed, + "last_modified": repo.last_modified, + "refs": sorted(repo_refs_map.get(repo, frozenset())), + } + payload.append(record) + return payload + + +def print_cache_entries_table( + entries: List[CacheEntry], *, include_revisions: bool, repo_refs_map: RepoRefsMap +) -> None: + """Render cache entries as a table and show a human-readable summary.""" + if not entries: + message = "No cached revisions found." if include_revisions else "No cached repositories found." 
+ print(message) + return + table_rows: List[List[str]] + if include_revisions: + headers = ["ID", "REVISION", "SIZE", "LAST_MODIFIED", "REFS"] + table_rows = [ + [ + repo.cache_id, + revision.commit_hash, + revision.size_on_disk_str.rjust(8), + revision.last_modified_str, + " ".join(sorted(revision.refs)), + ] + for repo, revision in entries + if revision is not None + ] + else: + headers = ["ID", "SIZE", "LAST_ACCESSED", "LAST_MODIFIED", "REFS"] + table_rows = [ + [ + repo.cache_id, + repo.size_on_disk_str.rjust(8), + repo.last_accessed_str or "", + repo.last_modified_str, + " ".join(sorted(repo_refs_map.get(repo, frozenset()))), + ] + for repo, _ in entries + ] + + print(tabulate(table_rows, headers=headers)) # type: ignore[arg-type] + + unique_repos = {repo for repo, _ in entries} + repo_count = len(unique_repos) + if include_revisions: + revision_count = sum(1 for _, revision in entries if revision is not None) + total_size = sum(revision.size_on_disk for _, revision in entries if revision is not None) + else: + revision_count = sum(len(repo.revisions) for repo in unique_repos) + total_size = sum(repo.size_on_disk for repo in unique_repos) - return _inner + summary = f"\nFound {repo_count} repo(s) for a total of {revision_count} revision(s) and {_format_size(total_size)} on disk." 
+ print(ANSI.bold(summary)) -cache_cli = typer_factory(help="Manage local cache directory.") +def print_cache_entries_json( + entries: List[CacheEntry], *, include_revisions: bool, repo_refs_map: RepoRefsMap +) -> None: + """Dump cache entries as JSON for scripting or automation.""" + payload = _build_cache_export_payload(entries, include_revisions=include_revisions, repo_refs_map=repo_refs_map) + json.dump(payload, sys.stdout, indent=2) + sys.stdout.write("\n") + + +def print_cache_entries_csv(entries: List[CacheEntry], *, include_revisions: bool, repo_refs_map: RepoRefsMap) -> None: + """Export cache entries as CSV rows with the shared payload format.""" + records = _build_cache_export_payload(entries, include_revisions=include_revisions, repo_refs_map=repo_refs_map) + writer = csv.writer(sys.stdout) + + if include_revisions: + headers = [ + "repo_id", + "repo_type", + "revision", + "snapshot_path", + "size_on_disk", + "last_accessed", + "last_modified", + "refs", + ] + else: + headers = ["repo_id", "repo_type", "size_on_disk", "last_accessed", "last_modified", "refs"] + + writer.writerow(headers) + if not records: + return -@cache_cli.command("scan", help="Scan the cache directory") -def cache_scan( - dir: Annotated[ + for record in records: + refs = record["refs"] + if include_revisions: + row = [ + record.get("repo_id", ""), + record.get("repo_type", ""), + record.get("revision", ""), + record.get("snapshot_path", ""), + record.get("size_on_disk"), + record.get("last_accessed"), + record.get("last_modified"), + " ".join(refs) if refs else "", + ] + else: + row = [ + record.get("repo_id", ""), + record.get("repo_type", ""), + record.get("size_on_disk"), + record.get("last_accessed"), + record.get("last_modified"), + " ".join(refs) if refs else "", + ] + writer.writerow(row) + + +def _compare_numeric(left: Optional[float], op: str, right: float) -> bool: + """Evaluate numeric comparisons for filters.""" + if left is None: + return False + + comparisons = { + 
"=": left == right, + "!=": left != right, + ">": left > right, + "<": left < right, + ">=": left >= right, + "<=": left <= right, + } + + if op not in comparisons: + raise ValueError(f"Unsupported numeric comparison operator: {op}") + + return comparisons[op] + + +def _resolve_deletion_targets(hf_cache_info: HFCacheInfo, targets: list[str]) -> _DeletionResolution: + """Resolve the deletion targets into a deletion resolution.""" + repo_lookup, revision_lookup = build_cache_index(hf_cache_info) + + selected: dict[CachedRepoInfo, set[CachedRevisionInfo]] = defaultdict(set) + revisions: set[str] = set() + missing: list[str] = [] + + for raw_target in targets: + target = raw_target.strip() + if not target: + continue + lowered = target.lower() + + if re.fullmatch(r"[0-9a-fA-F]{40}", lowered): + match = revision_lookup.get(lowered) + if match is None: + missing.append(raw_target) + continue + repo, revision = match + selected[repo].add(revision) + revisions.add(revision.commit_hash) + continue + + matched_repo = repo_lookup.get(lowered) + if matched_repo is None: + missing.append(raw_target) + continue + + for revision in matched_repo.revisions: + selected[matched_repo].add(revision) + revisions.add(revision.commit_hash) + + frozen_selected = {repo: frozenset(revs) for repo, revs in selected.items()} + return _DeletionResolution( + revisions=frozenset(revisions), + selected=frozen_selected, + missing=tuple(missing), + ) + + +#### Cache CLI commands + + +@cache_cli.command() +def ls( + cache_dir: Annotated[ Optional[str], typer.Option( help="Cache directory to scan (defaults to Hugging Face cache).", ), ] = None, - verbose: Annotated[ - int, + revisions: Annotated[ + bool, + typer.Option( + help="Include revisions in the output instead of aggregated repositories.", + ), + ] = False, + filter: Annotated[ + Optional[list[str]], + typer.Option( + "-f", + "--filter", + help="Filter entries (e.g. 'size>1GB', 'type=model', 'accessed>7d'). 
Can be used multiple times.", + ), + ] = None, + format: Annotated[ + OutputFormat, + typer.Option( + help="Output format.", + ), + ] = OutputFormat.table, + quiet: Annotated[ + bool, typer.Option( - "-v", - "--verbose", - count=True, - help="Increase verbosity (-v, -vv, -vvv).", + "-q", + "--quiet", + help="Print only IDs (repo IDs or revision hashes).", ), - ] = 0, + ] = False, ) -> None: + """List cached repositories or revisions.""" try: - t0 = time.time() - hf_cache_info = scan_cache_dir(dir) - t1 = time.time() + hf_cache_info = scan_cache_dir(cache_dir) except CacheNotFound as exc: print(f"Cache directory not found: {str(exc.cache_dir)}") + raise typer.Exit(code=1) + + filters = filter or [] + + entries, repo_refs_map = collect_cache_entries(hf_cache_info, include_revisions=revisions) + try: + filter_fns = [compile_cache_filter(expr, repo_refs_map) for expr in filters] + except ValueError as exc: + raise typer.BadParameter(str(exc)) from exc + + now = time.time() + for fn in filter_fns: + entries = [entry for entry in entries if fn(entry[0], entry[1], now)] + + if quiet: + for repo, revision in entries: + print(revision.commit_hash if revision is not None else repo.cache_id) return - print(get_table(hf_cache_info, verbosity=verbose)) - print( - f"\nDone in {round(t1 - t0, 1)}s. Scanned {len(hf_cache_info.repos)} repo(s)" - f" for a total of {ANSI.red(hf_cache_info.size_on_disk_str)}." - ) - if len(hf_cache_info.warnings) > 0: - message = f"Got {len(hf_cache_info.warnings)} warning(s) while scanning." 
- if verbose >= 3: - print(ANSI.gray(message)) - for warning in hf_cache_info.warnings: - print(ANSI.gray(str(warning))) - else: - print(ANSI.gray(message + " Use -vvv to print details.")) + formatters = { + OutputFormat.table: print_cache_entries_table, + OutputFormat.json: print_cache_entries_json, + OutputFormat.csv: print_cache_entries_csv, + } + return formatters[format](entries, include_revisions=revisions, repo_refs_map=repo_refs_map) -@cache_cli.command("delete", help="Delete revisions from the cache directory") -def cache_delete( - dir: Annotated[ + +@cache_cli.command() +def rm( + targets: Annotated[ + list[str], + typer.Argument( + help="One or more repo IDs (e.g. model/bert-base-uncased) or revision hashes to delete.", + ), + ], + cache_dir: Annotated[ Optional[str], typer.Option( - help="Cache directory (defaults to Hugging Face cache).", + help="Cache directory to scan (defaults to Hugging Face cache).", ), ] = None, - disable_tui: Annotated[ + yes: Annotated[ bool, typer.Option( - help="Disable Terminal User Interface (TUI) mode. Useful if your platform/terminal doesn't support the multiselect menu.", + "-y", + "--yes", + help="Skip confirmation prompt.", ), ] = False, - sort: Annotated[ - Optional[SortingOption], + dry_run: Annotated[ + bool, typer.Option( - help="Sort repositories by the specified criteria. 
Options: 'alphabetical' (A-Z), 'lastUpdated' (newest first), 'lastUsed' (most recent first), 'size' (largest first).", + help="Preview deletions without removing anything.", ), - ] = None, + ] = False, ) -> None: - hf_cache_info = scan_cache_dir(dir) - sort_by = sort.value if sort is not None else None - if disable_tui: - selected_hashes = _manual_review_no_tui(hf_cache_info, preselected=[], sort_by=sort_by) - else: - selected_hashes = _manual_review_tui(hf_cache_info, preselected=[], sort_by=sort_by) - if len(selected_hashes) > 0 and _CANCEL_DELETION_STR not in selected_hashes: - confirm_message = _get_expectations_str(hf_cache_info, selected_hashes) + " Confirm deletion ?" - if disable_tui: - confirmed = _ask_for_confirmation_no_tui(confirm_message) - else: - confirmed = _ask_for_confirmation_tui(confirm_message) - if confirmed: - strategy = hf_cache_info.delete_revisions(*selected_hashes) - print("Start deletion.") - strategy.execute() - print( - f"Done. Deleted {len(strategy.repos)} repo(s) and" - f" {len(strategy.snapshots)} revision(s) for a total of" - f" {strategy.expected_freed_size_str}." - ) - return - print("Deletion is cancelled. 
Do nothing.") - - -def get_table(hf_cache_info: HFCacheInfo, *, verbosity: int = 0) -> str: - if verbosity == 0: - return tabulate( - rows=[ - [ - repo.repo_id, - repo.repo_type, - "{:>12}".format(repo.size_on_disk_str), - repo.nb_files, - repo.last_accessed_str, - repo.last_modified_str, - ", ".join(sorted(repo.refs)), - str(repo.repo_path), - ] - for repo in sorted(hf_cache_info.repos, key=lambda repo: repo.repo_path) - ], - headers=[ - "REPO ID", - "REPO TYPE", - "SIZE ON DISK", - "NB FILES", - "LAST_ACCESSED", - "LAST_MODIFIED", - "REFS", - "LOCAL PATH", - ], - ) - else: - return tabulate( - rows=[ - [ - repo.repo_id, - repo.repo_type, - revision.commit_hash, - "{:>12}".format(revision.size_on_disk_str), - revision.nb_files, - revision.last_modified_str, - ", ".join(sorted(revision.refs)), - str(revision.snapshot_path), - ] - for repo in sorted(hf_cache_info.repos, key=lambda repo: repo.repo_path) - for revision in sorted(repo.revisions, key=lambda revision: revision.commit_hash) - ], - headers=[ - "REPO ID", - "REPO TYPE", - "REVISION", - "SIZE ON DISK", - "NB FILES", - "LAST_MODIFIED", - "REFS", - "LOCAL PATH", - ], - ) + """Remove cached repositories or revisions.""" + try: + hf_cache_info = scan_cache_dir(cache_dir) + except CacheNotFound as exc: + print(f"Cache directory not found: {str(exc.cache_dir)}") + raise typer.Exit(code=1) + resolution = _resolve_deletion_targets(hf_cache_info, targets) -def _get_repo_sorting_key(repo: CachedRepoInfo, sort_by: Optional[str] = None): - if sort_by == "alphabetical": - return (repo.repo_type, repo.repo_id.lower()) - elif sort_by == "lastUpdated": - return -max(rev.last_modified for rev in repo.revisions) - elif sort_by == "lastUsed": - return -repo.last_accessed - elif sort_by == "size": - return -repo.size_on_disk - else: - return (repo.repo_type, repo.repo_id) - - -@require_inquirer_py -def _manual_review_tui(hf_cache_info: HFCacheInfo, preselected: list[str], sort_by: Optional[str] = None) -> list[str]: - choices = 
_get_tui_choices_from_scan(repos=hf_cache_info.repos, preselected=preselected, sort_by=sort_by) - checkbox = inquirer.checkbox( - message="Select revisions to delete:", - choices=choices, - cycle=False, - height=100, - instruction=_get_expectations_str( - hf_cache_info, selected_hashes=[c.value for c in choices if isinstance(c, Choice) and c.enabled] - ), - long_instruction="Press to select, to validate and to quit without modification.", - transformer=lambda result: f"{len(result)} revision(s) selected.", - ) + if resolution.missing: + print("Could not find the following targets in the cache:") + for entry in resolution.missing: + print(f" - {entry}") - def _update_expectations(_): - checkbox._instruction = _get_expectations_str( - hf_cache_info, - selected_hashes=[choice["value"] for choice in checkbox.content_control.choices if choice["enabled"]], - ) + if len(resolution.revisions) == 0: + print("Nothing to delete.") + raise typer.Exit(code=0) - checkbox.kb_func_lookup["toggle"].append({"func": _update_expectations}) - try: - return checkbox.execute() - except KeyboardInterrupt: - return [] + strategy = hf_cache_info.delete_revisions(*sorted(resolution.revisions)) + counts = summarize_deletions(resolution.selected) + summary_parts: list[str] = [] + if counts.repo_count: + summary_parts.append(f"{counts.repo_count} repo(s)") + if counts.partial_revision_count: + summary_parts.append(f"{counts.partial_revision_count} revision(s)") + if not summary_parts: + summary_parts.append(f"{counts.total_revision_count} revision(s)") -@require_inquirer_py -def _ask_for_confirmation_tui(message: str, default: bool = True) -> bool: - return inquirer.confirm(message, default=default).execute() + summary_text = " and ".join(summary_parts) + print(f"About to delete {summary_text} totalling {strategy.expected_freed_size_str}.") + print_cache_selected_revisions(resolution.selected) + if dry_run: + print("Dry run: no files were deleted.") + return -def _get_tui_choices_from_scan( - 
repos: Iterable[CachedRepoInfo], preselected: list[str], sort_by: Optional[str] = None -) -> list: - choices: list[Union["Choice", "Separator"]] = [] - choices.append( - Choice( - _CANCEL_DELETION_STR, name="None of the following (if selected, nothing will be deleted).", enabled=False - ) + if not yes and not typer.confirm("Proceed with deletion?", default=False): + print("Deletion cancelled.") + return + + strategy.execute() + counts = summarize_deletions(resolution.selected) + print( + f"Deleted {counts.repo_count} repo(s) and {counts.total_revision_count} revision(s); freed {strategy.expected_freed_size_str}." ) - sorted_repos = sorted(repos, key=lambda repo: _get_repo_sorting_key(repo, sort_by)) - for repo in sorted_repos: - choices.append( - Separator( - f"\n{repo.repo_type.capitalize()} {repo.repo_id} ({repo.size_on_disk_str}, used {repo.last_accessed_str})" - ) - ) - for revision in sorted(repo.revisions, key=_revision_sorting_order): - choices.append( - Choice( - revision.commit_hash, - name=( - f"{revision.commit_hash[:8]}: {', '.join(sorted(revision.refs)) or '(detached)'} # modified {revision.last_modified_str}" - ), - enabled=revision.commit_hash in preselected, - ) - ) - return choices -def _manual_review_no_tui( - hf_cache_info: HFCacheInfo, preselected: list[str], sort_by: Optional[str] = None -) -> list[str]: - fd, tmp_path = mkstemp(suffix=".txt") - os.close(fd) - lines = [] - sorted_repos = sorted(hf_cache_info.repos, key=lambda repo: _get_repo_sorting_key(repo, sort_by)) - for repo in sorted_repos: - lines.append( - f"\n# {repo.repo_type.capitalize()} {repo.repo_id} ({repo.size_on_disk_str}, used {repo.last_accessed_str})" - ) - for revision in sorted(repo.revisions, key=_revision_sorting_order): - lines.append( - f"{'' if revision.commit_hash in preselected else '#'} {revision.commit_hash} # Refs: {', '.join(sorted(revision.refs)) or '(detached)'} # modified {revision.last_modified_str}" - ) - with open(tmp_path, "w") as f: - 
f.write(_MANUAL_REVIEW_NO_TUI_INSTRUCTIONS) - f.write("\n".join(lines)) - instructions = f""" - TUI is disabled. In order to select which revisions you want to delete, please edit - the following file using the text editor of your choice. Instructions for manual - editing are located at the beginning of the file. Edit the file, save it and confirm - to continue. - File to edit: {ANSI.bold(tmp_path)} - """ - print("\n".join(line.strip() for line in instructions.strip().split("\n"))) - while True: - selected_hashes = _read_manual_review_tmp_file(tmp_path) - if _ask_for_confirmation_no_tui( - _get_expectations_str(hf_cache_info, selected_hashes) + " Continue ?", default=False - ): - break - os.remove(tmp_path) - return sorted(selected_hashes) - - -def _ask_for_confirmation_no_tui(message: str, default: bool = True) -> bool: - YES = ("y", "yes", "1") - NO = ("n", "no", "0") - DEFAULT = "" - ALL = YES + NO + (DEFAULT,) - full_message = message + (" (Y/n) " if default else " (y/N) ") - while True: - answer = input(full_message).lower() - if answer == DEFAULT: - return default - if answer in YES: - return True - if answer in NO: - return False - print(f"Invalid input. Must be one of {ALL}") - - -def _get_expectations_str(hf_cache_info: HFCacheInfo, selected_hashes: list[str]) -> str: - if _CANCEL_DELETION_STR in selected_hashes: - return "Nothing will be deleted." - strategy = hf_cache_info.delete_revisions(*selected_hashes) - return f"{len(selected_hashes)} revisions selected counting for {strategy.expected_freed_size_str}." 
- - -def _read_manual_review_tmp_file(tmp_path: str) -> list[str]: - with open(tmp_path) as f: - content = f.read() - lines = [line.strip() for line in content.split("\n")] - selected_lines = [line for line in lines if not line.startswith("#")] - selected_hashes = [line.split("#")[0].strip() for line in selected_lines] - return [hash for hash in selected_hashes if len(hash) > 0] - - -_MANUAL_REVIEW_NO_TUI_INSTRUCTIONS = f""" -# INSTRUCTIONS -# ------------ -# This is a temporary file created by running `hf cache delete --disable-tui`. It contains a set of revisions that can be deleted from your local cache directory. -# -# Please manually review the revisions you want to delete: -# - Revision hashes can be commented out with '#'. -# - Only non-commented revisions in this file will be deleted. -# - Revision hashes that are removed from this file are ignored as well. -# - If `{_CANCEL_DELETION_STR}` line is uncommented, the all cache deletion is cancelled and no changes will be applied. -# -# Once you've manually reviewed this file, please confirm deletion in the terminal. This file will be automatically removed once done. 
-# ------------ +@cache_cli.command() +def prune( + cache_dir: Annotated[ + Optional[str], + typer.Option( + help="Cache directory to scan (defaults to Hugging Face cache).", + ), + ] = None, + yes: Annotated[ + bool, + typer.Option( + "-y", + "--yes", + help="Skip confirmation prompt.", + ), + ] = False, + dry_run: Annotated[ + bool, + typer.Option( + help="Preview deletions without removing anything.", + ), + ] = False, +) -> None: + """Remove detached revisions from the cache.""" + try: + hf_cache_info = scan_cache_dir(cache_dir) + except CacheNotFound as exc: + print(f"Cache directory not found: {str(exc.cache_dir)}") + raise typer.Exit(code=1) + + selected: dict[CachedRepoInfo, frozenset[CachedRevisionInfo]] = {} + revisions: set[str] = set() + for repo in hf_cache_info.repos: + detached = frozenset(revision for revision in repo.revisions if len(revision.refs) == 0) + if not detached: + continue + selected[repo] = detached + revisions.update(revision.commit_hash for revision in detached) + + if len(revisions) == 0: + print("No unreferenced revisions found. Nothing to prune.") + return -# KILL SWITCH -# ------------ -# Un-comment following line to completely cancel the deletion process -# {_CANCEL_DELETION_STR} -# ------------ + resolution = _DeletionResolution( + revisions=frozenset(revisions), + selected=selected, + missing=(), + ) + strategy = hf_cache_info.delete_revisions(*sorted(resolution.revisions)) + counts = summarize_deletions(selected) -# REVISIONS -# ------------ -""".strip() + print( + f"About to delete {counts.total_revision_count} unreferenced revision(s) ({strategy.expected_freed_size_str} total)." 
+ ) + print_cache_selected_revisions(selected) + if dry_run: + print("Dry run: no files were deleted.") + return + + if not yes and not typer.confirm("Proceed?"): + print("Pruning cancelled.") + return -def _revision_sorting_order(revision: CachedRevisionInfo) -> Any: - return revision.last_modified + strategy.execute() + print(f"Deleted {counts.total_revision_count} unreferenced revision(s); freed {strategy.expected_freed_size_str}.") diff --git a/src/huggingface_hub/utils/_cache_manager.py b/src/huggingface_hub/utils/_cache_manager.py index 3ec03f35e8..f7080b1685 100644 --- a/src/huggingface_hub/utils/_cache_manager.py +++ b/src/huggingface_hub/utils/_cache_manager.py @@ -16,7 +16,6 @@ import os import shutil -import time from collections import defaultdict from dataclasses import dataclass from pathlib import Path @@ -26,6 +25,7 @@ from ..constants import HF_HUB_CACHE from . import logging +from ._parsing import format_timesince from ._terminal import tabulate @@ -79,7 +79,7 @@ def blob_last_accessed_str(self) -> str: Example: "2 weeks ago". """ - return _format_timesince(self.blob_last_accessed) + return format_timesince(self.blob_last_accessed) @property def blob_last_modified_str(self) -> str: @@ -89,7 +89,7 @@ def blob_last_modified_str(self) -> str: Example: "2 weeks ago". """ - return _format_timesince(self.blob_last_modified) + return format_timesince(self.blob_last_modified) @property def size_on_disk_str(self) -> str: @@ -153,7 +153,7 @@ def last_modified_str(self) -> str: Example: "2 weeks ago". """ - return _format_timesince(self.last_modified) + return format_timesince(self.last_modified) @property def size_on_disk_str(self) -> str: @@ -223,7 +223,7 @@ def last_accessed_str(self) -> str: Example: "2 weeks ago". """ - return _format_timesince(self.last_accessed) + return format_timesince(self.last_accessed) @property def last_modified_str(self) -> str: @@ -233,7 +233,7 @@ def last_modified_str(self) -> str: Example: "2 weeks ago". 
""" - return _format_timesince(self.last_modified) + return format_timesince(self.last_modified) @property def size_on_disk_str(self) -> str: @@ -244,6 +244,11 @@ def size_on_disk_str(self) -> str: """ return _format_size(self.size_on_disk) + @property + def cache_id(self) -> str: + """Canonical `type/id` identifier used across cache tooling.""" + return f"{self.repo_type}/{self.repo_id}" + @property def refs(self) -> dict[str, CachedRevisionInfo]: """ @@ -816,33 +821,6 @@ def _format_size(num: int) -> str: return f"{num_f:.1f}Y" -_TIMESINCE_CHUNKS = ( - # Label, divider, max value - ("second", 1, 60), - ("minute", 60, 60), - ("hour", 60 * 60, 24), - ("day", 60 * 60 * 24, 6), - ("week", 60 * 60 * 24 * 7, 6), - ("month", 60 * 60 * 24 * 30, 11), - ("year", 60 * 60 * 24 * 365, None), -) - - -def _format_timesince(ts: float) -> str: - """Format timestamp in seconds into a human-readable string, relative to now. - - Vaguely inspired by Django's `timesince` formatter. - """ - delta = time.time() - ts - if delta < 20: - return "a few seconds ago" - for label, divider, max_value in _TIMESINCE_CHUNKS: # noqa: B007 - value = round(delta / divider) - if max_value is not None and value <= max_value: - break - return f"{value} {label}{'s' if value > 1 else ''} ago" - - def _try_delete_path(path: Path, path_type: str) -> None: """Try to delete a local file or folder. diff --git a/src/huggingface_hub/utils/_parsing.py b/src/huggingface_hub/utils/_parsing.py new file mode 100644 index 0000000000..dbb16c6bab --- /dev/null +++ b/src/huggingface_hub/utils/_parsing.py @@ -0,0 +1,98 @@ +# coding=utf-8 +# Copyright 2025-present, the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Parsing helpers shared across modules."""
+
+import re
+import time
+from typing import Dict
+
+
+RE_NUMBER_WITH_UNIT = re.compile(r"(\d+)([a-z]+)", re.IGNORECASE)
+
+BYTE_UNITS: Dict[str, int] = {
+    "k": 1_000,
+    "m": 1_000_000,
+    "g": 1_000_000_000,
+    "t": 1_000_000_000_000,
+    "p": 1_000_000_000_000_000,
+}
+
+TIME_UNITS: Dict[str, int] = {
+    "s": 1,
+    "m": 60,
+    "h": 60 * 60,
+    "d": 24 * 60 * 60,
+    "w": 7 * 24 * 60 * 60,
+    "mo": 30 * 24 * 60 * 60,
+    "y": 365 * 24 * 60 * 60,
+}
+
+
+def parse_size(value: str) -> int:
+    """Parse a size expressed as a string with digits and unit (like `"10M"`) to an integer (in bytes)."""
+    return _parse_with_unit(value, BYTE_UNITS)
+
+
+def parse_duration(value: str) -> int:
+    """Parse a duration expressed as a string with digits and unit (like `"10s"`) to an integer (in seconds)."""
+    return _parse_with_unit(value, TIME_UNITS)
+
+
+def _parse_with_unit(value: str, units: Dict[str, int]) -> int:
+    """Parse a numeric value with optional unit."""
+    stripped = value.strip()
+    if not stripped:
+        raise ValueError("Value cannot be empty.")
+    try:
+        return int(value)
+    except ValueError:
+        pass
+
+    match = RE_NUMBER_WITH_UNIT.fullmatch(stripped)
+    if not match:
+        raise ValueError(f"Invalid value '{value}'. Must match pattern '\\d+[a-z]+' or be a plain number.")
+
+    number = int(match.group(1))
+    unit = match.group(2).lower()
+
+    if unit not in units:
+        raise ValueError(f"Unknown unit '{unit}'. 
Must be one of {list(units.keys())}.") + + return number * units[unit] + + +def format_timesince(ts: float) -> str: + """Format timestamp in seconds into a human-readable string, relative to now. + + Vaguely inspired by Django's `timesince` formatter. + """ + _TIMESINCE_CHUNKS = ( + # Label, divider, max value + ("second", 1, 60), + ("minute", 60, 60), + ("hour", 60 * 60, 24), + ("day", 60 * 60 * 24, 6), + ("week", 60 * 60 * 24 * 7, 6), + ("month", 60 * 60 * 24 * 30, 11), + ("year", 60 * 60 * 24 * 365, None), + ) + delta = time.time() - ts + if delta < 20: + return "a few seconds ago" + for label, divider, max_value in _TIMESINCE_CHUNKS: # noqa: B007 + value = round(delta / divider) + if max_value is not None and value <= max_value: + break + return f"{value} {label}{'s' if value > 1 else ''} ago" diff --git a/tests/test_cli.py b/tests/test_cli.py index 7ea7d084c6..bba09076af 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,3 +1,4 @@ +import json import os import warnings from contextlib import contextmanager @@ -10,12 +11,12 @@ from typer.testing import CliRunner from huggingface_hub.cli._cli_utils import RepoType -from huggingface_hub.cli.cache import _CANCEL_DELETION_STR +from huggingface_hub.cli.cache import CacheDeletionCounts from huggingface_hub.cli.download import download from huggingface_hub.cli.hf import app from huggingface_hub.cli.upload import _resolve_upload_paths, upload from huggingface_hub.errors import RevisionNotFoundError -from huggingface_hub.utils import SoftTemporaryDirectory +from huggingface_hub.utils import CachedRepoInfo, CachedRevisionInfo, SoftTemporaryDirectory from .testing_utils import DUMMY_MODEL_ID @@ -25,52 +26,186 @@ def runner() -> CliRunner: return CliRunner() +def _make_revision(commit_hash: str, *, refs: Optional[set[str]] = None) -> CachedRevisionInfo: + return CachedRevisionInfo( + commit_hash=commit_hash, + snapshot_path=Path(f"/tmp/{commit_hash}"), + size_on_disk=0, + files=frozenset(), + refs=frozenset(refs 
or set()), + last_modified=0.0, + ) + + +def _make_repo(repo_id: str, *, revisions: list[CachedRevisionInfo]) -> CachedRepoInfo: + return CachedRepoInfo( + repo_id=repo_id, + repo_type="model", + repo_path=Path(f"/tmp/{repo_id.replace('/', '_')}"), + size_on_disk=0, + nb_files=0, + revisions=frozenset(revisions), + last_accessed=0.0, + last_modified=0.0, + ) + + class TestCacheCommand: - def test_scan_cache_basic(self, runner: CliRunner) -> None: - with patch("huggingface_hub.cli.cache.scan_cache_dir") as mock_run: - result = runner.invoke(app, ["cache", "scan"]) + def test_ls_table_output(self, runner: CliRunner) -> None: + repo = _make_repo("user/model", revisions=[_make_revision("a" * 40, refs={"main"})]) + entries = [(repo, None)] + repo_refs_map = {repo: frozenset({"main"})} + + with ( + patch("huggingface_hub.cli.cache.scan_cache_dir"), + patch( + "huggingface_hub.cli.cache.collect_cache_entries", + return_value=(entries, repo_refs_map), + ), + ): + result = runner.invoke(app, ["cache", "ls"]) + assert result.exit_code == 0 - mock_run.assert_called_once_with(None) + stdout = result.stdout + assert "model/user/model" in stdout + assert "main" in stdout + + def test_ls_json_with_filter_and_revisions(self, runner: CliRunner) -> None: + revision = _make_revision("b" * 40, refs={"main"}) + repo = _make_repo("user/model", revisions=[revision]) + entries = [(repo, revision)] + repo_refs_map = {repo: frozenset({"main"})} + + def true_filter(repo: CachedRepoInfo, revision_obj: Optional[CachedRevisionInfo], now: float) -> bool: + return True - def test_scan_cache_verbose(self, runner: CliRunner) -> None: with ( - patch("huggingface_hub.cli.cache.scan_cache_dir") as mock_run, - patch("huggingface_hub.cli.cache.get_table") as get_table_mock, + patch("huggingface_hub.cli.cache.scan_cache_dir"), + patch( + "huggingface_hub.cli.cache.collect_cache_entries", + return_value=(entries, repo_refs_map), + ), + patch( + "huggingface_hub.cli.cache.compile_cache_filter", + 
return_value=true_filter, + ) as compile_mock, + ): + result = runner.invoke( + app, + ["cache", "ls", "--revisions", "--filter", "size>1", "--format", "json"], + ) + + assert result.exit_code == 0 + compile_mock.assert_called_once_with("size>1", repo_refs_map) + payload = json.loads(result.stdout) + assert payload and payload[0]["revision"] == revision.commit_hash + + def test_ls_quiet_revisions(self, runner: CliRunner) -> None: + revision = _make_revision("c" * 40, refs=set()) + repo = _make_repo("user/model", revisions=[revision]) + entries = [(repo, revision)] + repo_refs_map = {repo: frozenset()} + + with ( + patch("huggingface_hub.cli.cache.scan_cache_dir"), + patch( + "huggingface_hub.cli.cache.collect_cache_entries", + return_value=(entries, repo_refs_map), + ), ): - result = runner.invoke(app, ["cache", "scan", "-v"]) + result = runner.invoke(app, ["cache", "ls", "--revisions", "--quiet"]) + assert result.exit_code == 0 - mock_run.assert_called_once_with(None) - get_table_mock.assert_called_once_with(mock_run.return_value, verbosity=1) + assert result.stdout.strip() == revision.commit_hash + + def test_rm_revision_executes_strategy(self, runner: CliRunner) -> None: + revision = _make_revision("c" * 40) + repo = _make_repo("user/model", revisions=[revision]) + + repo_lookup = {"model/user/model": repo} + revision_lookup = {revision.commit_hash.lower(): (repo, revision)} + + strategy = Mock() + strategy.expected_freed_size_str = "0B" + + hf_cache_info = Mock() + hf_cache_info.delete_revisions.return_value = strategy + + counts = CacheDeletionCounts(repo_count=0, partial_revision_count=1, total_revision_count=1) + + with ( + patch("huggingface_hub.cli.cache.scan_cache_dir", return_value=hf_cache_info), + patch("huggingface_hub.cli.cache.build_cache_index", return_value=(repo_lookup, revision_lookup)), + patch( + "huggingface_hub.cli.cache.summarize_deletions", + return_value=counts, + ), + patch("huggingface_hub.cli.cache.print_cache_selected_revisions") as 
print_mock, + ): + result = runner.invoke(app, ["cache", "rm", revision.commit_hash, "--yes"]) - def test_scan_cache_with_dir(self, runner: CliRunner) -> None: - with patch("huggingface_hub.cli.cache.scan_cache_dir") as mock_run: - result = runner.invoke(app, ["cache", "scan", "--dir", "something"]) assert result.exit_code == 0 - mock_run.assert_called_once_with("something") + hf_cache_info.delete_revisions.assert_called_once_with(revision.commit_hash) + strategy.execute.assert_called_once_with() + print_mock.assert_called_once() + + def test_rm_dry_run_skips_execute(self, runner: CliRunner) -> None: + revision = _make_revision("d" * 40) + repo = _make_repo("user/model", revisions=[revision]) + repo_lookup = {"model/user/model": repo} + revision_lookup = {revision.commit_hash.lower(): (repo, revision)} + + strategy = Mock() + strategy.expected_freed_size_str = "0B" + + hf_cache_info = Mock() + hf_cache_info.delete_revisions.return_value = strategy + + counts = CacheDeletionCounts(repo_count=0, partial_revision_count=1, total_revision_count=1) - def test_scan_cache_ultra_verbose(self, runner: CliRunner) -> None: with ( - patch("huggingface_hub.cli.cache.scan_cache_dir") as mock_run, - patch("huggingface_hub.cli.cache.get_table") as get_table_mock, + patch("huggingface_hub.cli.cache.scan_cache_dir", return_value=hf_cache_info), + patch("huggingface_hub.cli.cache.build_cache_index", return_value=(repo_lookup, revision_lookup)), + patch( + "huggingface_hub.cli.cache.summarize_deletions", + return_value=counts, + ), + patch("huggingface_hub.cli.cache.print_cache_selected_revisions"), ): - result = runner.invoke(app, ["cache", "scan", "-vvv"]) + result = runner.invoke(app, ["cache", "rm", revision.commit_hash, "--dry-run"]) + assert result.exit_code == 0 - mock_run.assert_called_once_with(None) - get_table_mock.assert_called_once_with(mock_run.return_value, verbosity=3) + hf_cache_info.delete_revisions.assert_called_once_with(revision.commit_hash) + 
strategy.execute.assert_not_called() + + def test_prune_dry_run(self, runner: CliRunner) -> None: + referenced = _make_revision("e" * 40, refs={"main"}) + detached = _make_revision("f" * 40, refs=set()) + repo = _make_repo("user/model", revisions=[referenced, detached]) - def test_delete_cache_with_dir(self, runner: CliRunner) -> None: hf_cache_info = Mock() + hf_cache_info.repos = frozenset({repo}) + + strategy = Mock() + strategy.expected_freed_size_str = "0B" + hf_cache_info.delete_revisions.return_value = strategy + + counts = CacheDeletionCounts(repo_count=0, partial_revision_count=1, total_revision_count=1) + with ( - patch("huggingface_hub.cli.cache.scan_cache_dir", return_value=hf_cache_info) as scan_mock, + patch("huggingface_hub.cli.cache.scan_cache_dir", return_value=hf_cache_info), patch( - "huggingface_hub.cli.cache._manual_review_tui", - return_value=[_CANCEL_DELETION_STR], - ) as review_mock, + "huggingface_hub.cli.cache.summarize_deletions", + return_value=counts, + ), + patch("huggingface_hub.cli.cache.print_cache_selected_revisions") as print_mock, ): - result = runner.invoke(app, ["cache", "delete", "--dir", "something"]) + result = runner.invoke(app, ["cache", "prune", "--dry-run"]) + assert result.exit_code == 0 - scan_mock.assert_called_once_with("something") - review_mock.assert_called_once_with(hf_cache_info, preselected=[], sort_by=None) + hf_cache_info.delete_revisions.assert_called_once_with(detached.commit_hash) + strategy.execute.assert_not_called() + print_mock.assert_called_once() class TestUploadCommand: diff --git a/tests/test_command_delete_cache.py b/tests/test_command_delete_cache.py deleted file mode 100644 index 49041ea6ec..0000000000 --- a/tests/test_command_delete_cache.py +++ /dev/null @@ -1,354 +0,0 @@ -import os -import unittest -from pathlib import Path -from tempfile import mkstemp -from unittest.mock import Mock, patch - -from InquirerPy.base.control import Choice -from InquirerPy.separator import Separator - -from 
huggingface_hub.cli.cache import ( - _CANCEL_DELETION_STR, - _ask_for_confirmation_no_tui, - _get_expectations_str, - _get_tui_choices_from_scan, - _manual_review_no_tui, - _read_manual_review_tmp_file, -) -from huggingface_hub.utils import SoftTemporaryDirectory, capture_output - - -class TestDeleteCacheHelpers(unittest.TestCase): - def test_get_tui_choices_from_scan_empty(self) -> None: - choices = _get_tui_choices_from_scan(repos={}, preselected=[], sort_by=None) - assert len(choices) == 1 - assert isinstance(choices[0], Choice) - assert choices[0].value == _CANCEL_DELETION_STR - assert len(choices[0].name) != 0 # Something displayed to the user - assert not choices[0].enabled - - def test_get_tui_choices_from_scan_with_preselection(self) -> None: - choices = _get_tui_choices_from_scan( - repos=_get_cache_mock().repos, - preselected=[ - "dataset_revision_hash_id", # dataset_1 is preselected - "a_revision_id_that_does_not_exist", # unknown but will not complain - "older_hash_id", # only the oldest revision from model_2 - ], - sort_by=None, # Don't sort to maintain original order - ) - assert len(choices) == 8 - - # Item to cancel everything - assert isinstance(choices[0], Choice) - assert choices[0].value == _CANCEL_DELETION_STR - assert len(choices[0].name) != 0 - assert not choices[0].enabled - - # Dataset repo separator - assert isinstance(choices[1], Separator) - assert choices[1]._line == "\nDataset dummy_dataset (8M, used 2 weeks ago)" - - # Only revision of `dummy_dataset` - assert isinstance(choices[2], Choice) - assert choices[2].value == "dataset_revision_hash_id" - assert choices[2].name == "dataset_: (detached) # modified 1 day ago" - assert choices[2].enabled # preselected - - # Model `dummy_model` separator - assert isinstance(choices[3], Separator) - assert choices[3]._line == "\nModel dummy_model (1.4K, used 2 years ago)" - - # Recent revision of `dummy_model` (appears first due to sorting by last_modified) - assert isinstance(choices[4], Choice) 
- assert choices[4].value == "recent_hash_id" - assert choices[4].name == "recent_h: main # modified 2 years ago" - assert not choices[4].enabled - - # Oldest revision of `dummy_model` - assert isinstance(choices[5], Choice) - assert choices[5].value == "older_hash_id" - assert choices[5].name == "older_ha: (detached) # modified 3 years ago" - assert choices[5].enabled # preselected - - # Model `gpt2` separator - assert isinstance(choices[6], Separator) - assert choices[6]._line == "\nModel gpt2 (3.6G, used 2 hours ago)" - - # Only revision of `gpt2` - assert isinstance(choices[7], Choice) - assert choices[7].value == "abcdef123456789" - assert choices[7].name == "abcdef12: main, refs/pr/1 # modified 2 years ago" - assert not choices[7].enabled - - def test_get_tui_choices_from_scan_with_sort_size(self) -> None: - """Test sorting by size.""" - choices = _get_tui_choices_from_scan(repos=_get_cache_mock().repos, preselected=[], sort_by="size") - - # Verify repo order: gpt2 (3.6G) -> dummy_dataset (8M) -> dummy_model (1.4K) - assert isinstance(choices[1], Separator) - assert "gpt2" in choices[1]._line - - assert isinstance(choices[3], Separator) - assert "dummy_dataset" in choices[3]._line - - assert isinstance(choices[5], Separator) - assert "dummy_model" in choices[5]._line - - def test_get_expectations_str_on_no_deletion_item(self) -> None: - """Test `_get_instructions` when `_CANCEL_DELETION_STR` is passed.""" - assert ( - _get_expectations_str( - hf_cache_info=Mock(), - selected_hashes=["hash_1", _CANCEL_DELETION_STR, "hash_2"], - ) - == "Nothing will be deleted." 
- ) - - def test_get_expectations_str_with_selection(self) -> None: - """Test `_get_instructions` with 2 revisions selected.""" - strategy_mock = Mock() - strategy_mock.expected_freed_size_str = "5.1M" - - cache_mock = Mock() - cache_mock.delete_revisions.return_value = strategy_mock - - assert ( - _get_expectations_str( - hf_cache_info=cache_mock, - selected_hashes=["hash_1", "hash_2"], - ) - == "2 revisions selected counting for 5.1M." - ) - cache_mock.delete_revisions.assert_called_once_with("hash_1", "hash_2") - - def test_read_manual_review_tmp_file(self) -> None: - """Test `_read_manual_review_tmp_file`.""" - - with SoftTemporaryDirectory() as tmp_dir: - tmp_path = Path(tmp_dir) / "file.txt" - - with tmp_path.open("w") as f: - f.writelines( - [ - "# something commented out\n", - "###\n", - "\n\n\n\n", # some empty lines - " # Something commented out after spaces\n", - "a_revision_hash\n", - "a_revision_hash_with_a_comment # 2 years ago\n", - " a_revision_hash_after_spaces\n", - " a_revision_hash_with_a_comment_after_spaces # 2years ago\n", - " # hash_commented_out # 2 years ago\n", - "a_revision_hash\n", # Duplicate - "", # empty line - ] - ) - - # Only non-commented lines are returned - # Order is kept and lines are not de-duplicated - assert _read_manual_review_tmp_file(tmp_path) == [ - "a_revision_hash", - "a_revision_hash_with_a_comment", - "a_revision_hash_after_spaces", - "a_revision_hash_with_a_comment_after_spaces", - "a_revision_hash", - ] - - @patch("huggingface_hub.cli.cache.input") - @patch("huggingface_hub.cli.cache.mkstemp") - def test_manual_review_no_tui(self, mock_mkstemp: Mock, mock_input: Mock) -> None: - # Mock file creation so that we know the file location in test - fd, tmp_path = mkstemp() - mock_mkstemp.return_value = fd, tmp_path - - # Mock cache - cache_mock = _get_cache_mock() - - # Mock input from user - def _input_answers(): - self.assertTrue(os.path.isfile(tmp_path)) # not deleted yet - with open(tmp_path) as f: - content = 
f.read() - self.assertTrue(content.startswith("# INSTRUCTIONS")) - - # older_hash_id is not commented - self.assertIn("\n older_hash_id # Refs: (detached)", content) - # same for abcdef123456789 - self.assertIn("\n abcdef123456789 # Refs: main, refs/pr/1", content) - # dataset revision is not preselected - self.assertIn("# dataset_revision_hash_id", content) - # same for recent_hash_id - self.assertIn("# recent_hash_id", content) - - # Select dataset revision - content = content.replace("# dataset_revision_hash_id", "dataset_revision_hash_id") - # Deselect abcdef123456789 - content = content.replace("abcdef123456789", "# abcdef123456789") - with open(tmp_path, "w") as f: - f.write(content) - - yield "no" # User edited the file and want to see the strategy diff - yield "y" # User confirms - - mock_input.side_effect = _input_answers() - - # Run manual review - with capture_output() as output: - selected_hashes = _manual_review_no_tui( - hf_cache_info=cache_mock, - preselected=["abcdef123456789", "older_hash_id"], - sort_by=None, - ) - - # Tmp file has been created but is now deleted - mock_mkstemp.assert_called_once_with(suffix=".txt") - assert not os.path.isfile(tmp_path) # now deleted - - # User changed the selection - assert selected_hashes == ["dataset_revision_hash_id", "older_hash_id"] - - # Check printed instructions - printed = output.getvalue() - assert printed.startswith("TUI is disabled. 
In order to") - assert str(tmp_path) in printed - - # Check input called twice - assert mock_input.call_count == 2 - - @patch("huggingface_hub.cli.cache.input") - def test_ask_for_confirmation_no_tui(self, mock_input: Mock) -> None: - """Test `_ask_for_confirmation_no_tui`.""" - # Answer yes - mock_input.side_effect = ("y",) - value = _ask_for_confirmation_no_tui("custom message 1", default=True) - mock_input.assert_called_with("custom message 1 (Y/n) ") - assert value - - # Answer no - mock_input.side_effect = ("NO",) - value = _ask_for_confirmation_no_tui("custom message 2", default=True) - mock_input.assert_called_with("custom message 2 (Y/n) ") - assert not value - - # Answer invalid, then default - mock_input.side_effect = ("foo", "") - with capture_output() as output: - value = _ask_for_confirmation_no_tui("custom message 3", default=False) - mock_input.assert_called_with("custom message 3 (y/N) ") - assert not value - assert output.getvalue() == "Invalid input. Must be one of ('y', 'yes', '1', 'n', 'no', '0', '')\n" - - def test_get_tui_choices_from_scan_with_different_sorts(self) -> None: - """Test different sorting modes.""" - cache_mock = _get_cache_mock() - - # Test size sorting (largest first) - order: gpt2 (3.6G) -> dummy_dataset (8M) -> dummy_model (1.4K) - size_choices = _get_tui_choices_from_scan(cache_mock.repos, [], sort_by="size") - # Separators at positions 1, 3, 5 - assert isinstance(size_choices[1], Separator) - assert "gpt2" in size_choices[1]._line - assert isinstance(size_choices[3], Separator) - assert "dummy_dataset" in size_choices[3]._line - assert isinstance(size_choices[5], Separator) - assert "dummy_model" in size_choices[5]._line - - # Test alphabetical sorting - order: dummy_dataset -> dummy_model -> gpt2 - alpha_choices = _get_tui_choices_from_scan(cache_mock.repos, [], sort_by="alphabetical") - # Separators at positions 1, 3, 6 (dummy_model has 2 revisions) - assert isinstance(alpha_choices[1], Separator) - assert "dummy_dataset" 
in alpha_choices[1]._line - assert isinstance(alpha_choices[3], Separator) - assert "dummy_model" in alpha_choices[3]._line - assert isinstance(alpha_choices[6], Separator) - assert "gpt2" in alpha_choices[6]._line - - # Test lastUpdated sorting - order: dummy_dataset (1 day) -> gpt2 (2 years) -> dummy_model (3 years) - updated_choices = _get_tui_choices_from_scan(cache_mock.repos, [], sort_by="lastUpdated") - # Separators at positions 1, 3, 5 - assert isinstance(updated_choices[1], Separator) - assert "dummy_dataset" in updated_choices[1]._line - assert isinstance(updated_choices[3], Separator) - assert "gpt2" in updated_choices[3]._line - assert isinstance(updated_choices[5], Separator) - assert "dummy_model" in updated_choices[5]._line - - # Test lastUsed sorting - order: gpt2 (2h) -> dummy_dataset (2w) -> dummy_model (2y) - used_choices = _get_tui_choices_from_scan(cache_mock.repos, [], sort_by="lastUsed") - # Separators at positions 1, 3, 5 - assert isinstance(used_choices[1], Separator) - assert "gpt2" in used_choices[1]._line - assert isinstance(used_choices[3], Separator) - assert "dummy_dataset" in used_choices[3]._line - assert isinstance(used_choices[5], Separator) - assert "dummy_model" in used_choices[5]._line - - -def _get_cache_mock() -> Mock: - # First model with 1 revision - model_1 = Mock() - model_1.repo_type = "model" - model_1.repo_id = "gpt2" - model_1.size_on_disk_str = "3.6G" - model_1.last_accessed = 1660000000 - model_1.last_accessed_str = "2 hours ago" - model_1.size_on_disk = 3.6 * 1024**3 # 3.6 GiB - - model_1_revision_1 = Mock() - model_1_revision_1.commit_hash = "abcdef123456789" - model_1_revision_1.refs = {"main", "refs/pr/1"} - model_1_revision_1.last_modified = 123456789000 # 2 years ago - model_1_revision_1.last_modified_str = "2 years ago" - - model_1.revisions = {model_1_revision_1} - - # Second model with 2 revisions - model_2 = Mock() - model_2.repo_type = "model" - model_2.repo_id = "dummy_model" - model_2.size_on_disk_str = 
"1.4K" - model_2.last_accessed = 1550000000 - model_2.last_accessed_str = "2 years ago" - model_2.size_on_disk = 1.4 * 1024 # 1.4K - - model_2_revision_1 = Mock() - model_2_revision_1.commit_hash = "recent_hash_id" - model_2_revision_1.refs = {"main"} - model_2_revision_1.last_modified = 123456789 # 2 years ago - model_2_revision_1.last_modified_str = "2 years ago" - - model_2_revision_2 = Mock() - model_2_revision_2.commit_hash = "older_hash_id" - model_2_revision_2.refs = {} - model_2_revision_2.last_modified = 12345678000 # 3 years ago - model_2_revision_2.last_modified_str = "3 years ago" - - model_2.revisions = {model_2_revision_1, model_2_revision_2} - - # And a dataset with 1 revision - dataset_1 = Mock() - dataset_1.repo_type = "dataset" - dataset_1.repo_id = "dummy_dataset" - dataset_1.size_on_disk_str = "8M" - dataset_1.last_accessed = 1659000000 - dataset_1.last_accessed_str = "2 weeks ago" - dataset_1.size_on_disk = 8 * 1024**2 # 8 MiB - - dataset_1_revision_1 = Mock() - dataset_1_revision_1.commit_hash = "dataset_revision_hash_id" - dataset_1_revision_1.refs = {} - dataset_1_revision_1.last_modified = 1234567890000 # 1 day ago (newest) - dataset_1_revision_1.last_modified_str = "1 day ago" - - dataset_1.revisions = {dataset_1_revision_1} - - # Fake cache - strategy_mock = Mock() - strategy_mock.repos = [] - strategy_mock.snapshots = [] - strategy_mock.expected_freed_size_str = "5.1M" - - cache_mock = Mock() - cache_mock.repos = {model_1, model_2, dataset_1} - cache_mock.delete_revisions.return_value = strategy_mock - return cache_mock diff --git a/tests/test_utils_cache.py b/tests/test_utils_cache.py index e7d5d350a3..e3a7edf806 100644 --- a/tests/test_utils_cache.py +++ b/tests/test_utils_cache.py @@ -1,17 +1,14 @@ import os -import tempfile import time import unittest from pathlib import Path -from typing import Any from unittest.mock import Mock import pytest from huggingface_hub._snapshot_download import snapshot_download -from 
huggingface_hub.cli.cache import cache_scan -from huggingface_hub.utils import DeleteCacheStrategy, HFCacheInfo, _format_size, capture_output, scan_cache_dir -from huggingface_hub.utils._cache_manager import CacheNotFound, _format_timesince, _try_delete_path +from huggingface_hub.utils import DeleteCacheStrategy, HFCacheInfo, _format_size, scan_cache_dir +from huggingface_hub.utils._cache_manager import CacheNotFound, _try_delete_path from .testing_utils import rmtree_with_retry, with_production_testing, xfail_on_windows @@ -230,71 +227,6 @@ def test_scan_cache_on_valid_cache_windows(self) -> None: pr_1_readme_file.blob_path, main_readme_file.blob_path ) - @xfail_on_windows("Size on disk and paths differ on Windows. Not useful to test.") - def test_cli_scan_cache_quiet(self) -> None: - """Test output from CLI scan cache with non verbose output. - - End-to-end test just to see if output is in expected format. - """ - with capture_output() as output: - cache_scan(dir=self.cache_dir, verbose=0) - - expected_output = f""" - REPO ID REPO TYPE SIZE ON DISK NB FILES LAST_ACCESSED LAST_MODIFIED REFS LOCAL PATH - ----------------------------- --------- ------------ -------- ----------------- ----------------- --------------- --------------------------------------------------------- - {DATASET_ID} dataset 2.3K 1 a few seconds ago a few seconds ago main {self.cache_dir}/{DATASET_PATH} - {MODEL_ID} model 1.5K 3 a few seconds ago a few seconds ago main, refs/pr/1 {self.cache_dir}/{MODEL_PATH} - - Done in 0.0s. Scanned 2 repo(s) for a total of \x1b[1m\x1b[31m3.8K\x1b[0m. - """ - - assert is_sublist( - expected_output.replace("-", "").split(), - output.getvalue().replace("-", "").split(), - ) - - @xfail_on_windows("Size on disk and paths differ on Windows. Not useful to test.") - def test_cli_scan_cache_verbose(self) -> None: - """Test output from CLI scan cache with verbose output. - - End-to-end test just to see if output is in expected format. 
- """ - with capture_output() as output: - cache_scan(dir=self.cache_dir, verbose=1) - - expected_output = f""" - REPO ID REPO TYPE REVISION SIZE ON DISK NB FILES LAST_MODIFIED REFS LOCAL PATH - ----------------------------- --------- ---------------------------------------- ------------ -------- ----------------- --------- ------------------------------------------------------------------------------------------------------------ - {DATASET_ID} dataset {REPO_B_MAIN_HASH} 2.3K 1 a few seconds ago main {self.cache_dir}/{DATASET_PATH}/snapshots/{REPO_B_MAIN_HASH} - {MODEL_ID} model {REPO_A_PR_1_HASH} 1.5K 3 a few seconds ago refs/pr/1 {self.cache_dir}/{MODEL_PATH}/snapshots/{REPO_A_PR_1_HASH} - {MODEL_ID} model {REPO_A_MAIN_HASH} 1.5K 2 a few seconds ago main {self.cache_dir}/{MODEL_PATH}/snapshots/{REPO_A_MAIN_HASH} - {MODEL_ID} model {REPO_A_OTHER_HASH} 1.5K 1 a few seconds ago {self.cache_dir}/{MODEL_PATH}/snapshots/{REPO_A_OTHER_HASH} - - Done in 0.0s. Scanned 2 repo(s) for a total of \x1b[1m\x1b[31m3.8K\x1b[0m. - """ - - assert is_sublist( - expected_output.replace("-", "").split(), - output.getvalue().replace("-", "").split(), - ) - - def test_cli_scan_missing_cache(self) -> None: - """Test output from CLI scan cache when cache does not exist. - - End-to-end test just to see if output is in expected format. 
- """ - tmp_dir = tempfile.mkdtemp() - os.rmdir(tmp_dir) - - with capture_output() as output: - cache_scan(dir=tmp_dir, verbose=0) - - expected_output = f""" - Cache directory not found: {Path(tmp_dir).resolve()} - """ - - assert is_sublist(expected_output.split(), output.getvalue().split()) - @pytest.mark.usefixtures("fx_cache_dir") class TestCorruptedCacheUtils(unittest.TestCase): @@ -822,17 +754,3 @@ def test_format_size(self) -> None: expected, msg=f"Wrong formatting for {size} == '{expected}'", ) - - def test_format_timesince(self) -> None: - """Test `_format_timesince` formatter.""" - for ts, expected in self.SINCE.items(): - self.assertEqual( - _format_timesince(time.time() - ts), - expected, - msg=f"Wrong formatting for {ts} == '{expected}'", - ) - - -def is_sublist(sub: list[Any], full: list[Any]) -> bool: - it = iter(full) - return all(item in it for item in sub) diff --git a/tests/test_utils_parsing.py b/tests/test_utils_parsing.py new file mode 100644 index 0000000000..1723cfe7ac --- /dev/null +++ b/tests/test_utils_parsing.py @@ -0,0 +1,89 @@ +import time + +import pytest + +from huggingface_hub.utils._parsing import format_timesince, parse_duration, parse_size + + +@pytest.mark.parametrize( + ("value", "expected"), + [ + ("10", 10), + ("10k", 10_000), + ("5M", 5_000_000), + ("2G", 2_000_000_000), + ("1T", 1_000_000_000_000), + ("0", 0), + ], +) +def test_parse_size_valid(value, expected): + assert parse_size(value) == expected + + +@pytest.mark.parametrize( + "value", + [ + "1.5G", + "3KB", + "-5M", + "10X", + "abc", + "", + "123abc456", + " 10 K", + ], +) +def test_parse_size_invalid(value): + with pytest.raises(ValueError): + parse_size(value) + + +@pytest.mark.parametrize( + ("value", "expected"), + [ + ("10s", 10), + ("5m", 300), + ("2h", 7_200), + ("1d", 86_400), + ("1w", 604_800), + ("1mo", 2_592_000), + ("1y", 31_536_000), + ("0", 0), + ], +) +def test_parse_duration_valid(value, expected): + assert parse_duration(value) == expected + + 
+@pytest.mark.parametrize( + "value", + [ + "1.5h", + "3month", + "-5m", + "10X", + "abc", + "", + "123abc456", + " 10 m", + ], +) +def test_parse_duration_invalid(value): + with pytest.raises(ValueError): + parse_duration(value) + + +@pytest.mark.parametrize( + ("value", "expected"), + [ + (1, "a few seconds ago"), + (15, "a few seconds ago"), + (25, "25 seconds ago"), + (80, "1 minute ago"), + (1000, "17 minutes ago"), + (4000, "1 hour ago"), + (8000, "2 hours ago"), + ], +) +def test_format_timesince(value, expected): + assert format_timesince(time.time() - value) == expected