Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ODSC-6682] Delete HF cache by default while registering models #1044

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 49 additions & 2 deletions ads/aqua/common/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#!/usr/bin/env python
# Copyright (c) 2024 Oracle and/or its affiliates.
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
"""AQUA utils and constants."""

Expand All @@ -11,6 +11,7 @@
import random
import re
import shlex
import shutil
import subprocess
from datetime import datetime, timedelta
from functools import wraps
Expand All @@ -21,6 +22,8 @@
import fsspec
import oci
from cachetools import TTLCache, cached
from huggingface_hub.constants import HF_HUB_CACHE
from huggingface_hub.file_download import repo_folder_name
from huggingface_hub.hf_api import HfApi, ModelInfo
from huggingface_hub.utils import (
GatedRepoError,
Expand Down Expand Up @@ -788,7 +791,9 @@ def get_ocid_substring(ocid: str, key_len: int) -> str:
return ocid[-key_len:] if ocid and len(ocid) > key_len else ""


def upload_folder(os_path: str, local_dir: str, model_name: str, exclude_pattern: str = None) -> str:
def upload_folder(
os_path: str, local_dir: str, model_name: str, exclude_pattern: str = None
) -> str:
"""Upload the local folder to the object storage

Args:
Expand Down Expand Up @@ -818,6 +823,48 @@ def upload_folder(os_path: str, local_dir: str, model_name: str, exclude_pattern
return f"oci://{os_details.bucket}@{os_details.namespace}" + "/" + object_path


def cleanup_local_hf_model_artifact(
model_name: str,
local_dir: str = None,
):
"""
Helper function that deletes local artifacts downloaded from Hugging Face to free up disk space.
Parameters
----------
model_name (str): Name of the huggingface model
local_dir (str): Local directory where the object is downloaded

"""
if local_dir and os.path.exists(local_dir):
model_dir = os.path.join(local_dir, model_name)
model_dir = (
os.path.dirname(model_dir)
if "/" in model_name or os.sep in model_name
else model_dir
)
shutil.rmtree(model_dir, ignore_errors=True)
if os.path.exists(model_dir):
logger.debug(
f"Could not delete local model artifact directory: {model_dir}"
)
else:
logger.debug(f"Deleted local model artifact directory: {model_dir}.")

hf_local_path = os.path.join(
HF_HUB_CACHE, repo_folder_name(repo_id=model_name, repo_type="model")
)
shutil.rmtree(hf_local_path, ignore_errors=True)

if os.path.exists(hf_local_path):
logger.debug(
f"Could not clear the local Hugging Face cache directory {hf_local_path} for the model {model_name}."
)
else:
logger.debug(
f"Cleared contents of local Hugging Face cache directory {hf_local_path} for the model {model_name}."
)


def is_service_managed_container(container):
return container and container.startswith(SERVICE_MANAGED_CONTAINER_URI_SCHEME)

Expand Down
8 changes: 7 additions & 1 deletion ads/aqua/extension/model_handler.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#!/usr/bin/env python
# Copyright (c) 2024 Oracle and/or its affiliates.
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

from typing import Optional
Expand Down Expand Up @@ -128,6 +128,10 @@ def post(self, *args, **kwargs): # noqa: ARG002
download_from_hf = (
str(input_data.get("download_from_hf", "false")).lower() == "true"
)
local_dir = input_data.get("local_dir")
cleanup_model_cache = (
str(input_data.get("cleanup_model_cache", "true")).lower() == "true"
)
inference_container_uri = input_data.get("inference_container_uri")
allow_patterns = input_data.get("allow_patterns")
ignore_patterns = input_data.get("ignore_patterns")
Expand All @@ -139,6 +143,8 @@ def post(self, *args, **kwargs): # noqa: ARG002
model=model,
os_path=os_path,
download_from_hf=download_from_hf,
local_dir=local_dir,
cleanup_model_cache=cleanup_model_cache,
inference_container=inference_container,
finetuning_container=finetuning_container,
compartment_id=compartment_id,
Expand Down
3 changes: 2 additions & 1 deletion ads/aqua/model/entities.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#!/usr/bin/env python
# Copyright (c) 2024 Oracle and/or its affiliates.
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

"""
Expand Down Expand Up @@ -283,6 +283,7 @@ class ImportModelDetails(CLIBuilderMixin):
os_path: str
download_from_hf: Optional[bool] = True
local_dir: Optional[str] = None
cleanup_model_cache: Optional[bool] = True
inference_container: Optional[str] = None
finetuning_container: Optional[str] = None
compartment_id: Optional[str] = None
Expand Down
23 changes: 16 additions & 7 deletions ads/aqua/model/model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#!/usr/bin/env python
# Copyright (c) 2024 Oracle and/or its affiliates.
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
import os
import pathlib
Expand All @@ -23,6 +23,7 @@
from ads.aqua.common.utils import (
LifecycleStatus,
_build_resource_identifier,
cleanup_local_hf_model_artifact,
copy_model_config,
create_word_icon,
generate_tei_cmd_var,
Expand Down Expand Up @@ -1322,20 +1323,18 @@ def _download_model_from_hf(
Returns
-------
model_artifact_path (str): Location where the model artifacts are downloaded.

"""
# Download the model from hub
if not local_dir:
local_dir = os.path.join(os.path.expanduser("~"), "cached-model")
local_dir = os.path.join(local_dir, model_name)
os.makedirs(local_dir, exist_ok=True)
if local_dir:
local_dir = os.path.join(local_dir, model_name)
os.makedirs(local_dir, exist_ok=True)
snapshot_download(
repo_id=model_name,
local_dir=local_dir,
allow_patterns=allow_patterns,
ignore_patterns=ignore_patterns,
)
# Upload to object storage and skip .cache/huggingface/ folder
# Upload to object storage
model_artifact_path = upload_folder(
os_path=os_path,
local_dir=local_dir,
Expand Down Expand Up @@ -1365,6 +1364,8 @@ def register(
ignore_patterns (list): Model files matching any of the patterns are not downloaded.
Example: ["*.json"] will ignore all .json files. ["folder/*"] will ignore all files under `folder`.
Patterns are Standard Wildcards (globbing patterns) and rules can be found here: https://docs.python.org/3/library/fnmatch.html
cleanup_model_cache (bool): Deletes downloaded files from local machine after model is successfully
registered. Set to True by default.

Returns:
AquaModel:
Expand Down Expand Up @@ -1474,6 +1475,14 @@ def register(
detail=validation_result.telemetry_model_name,
)

if (
import_model_details.download_from_hf
and import_model_details.cleanup_model_cache
):
cleanup_local_hf_model_artifact(
model_name=model_name, local_dir=import_model_details.local_dir
)

return AquaModel(**aqua_model_attributes)

def _if_show(self, model: DataScienceModel) -> bool:
Expand Down
40 changes: 28 additions & 12 deletions tests/unitary/with_extras/aqua/test_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*--
# Copyright (c) 2024 Oracle and/or its affiliates.
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

import os
Expand Down Expand Up @@ -653,12 +653,12 @@ def test_get_model_fine_tuned(
}

@pytest.mark.parametrize(
("artifact_location_set", "download_from_hf"),
("artifact_location_set", "download_from_hf", "cleanup_model_cache"),
[
(True, True),
(True, False),
(False, True),
(False, False),
(True, True, True),
(True, False, True),
(False, True, False),
(False, False, True),
],
)
@patch("ads.model.service.oci_datascience_model.OCIDataScienceModel.create")
Expand All @@ -683,6 +683,7 @@ def test_import_verified_model(
mock_ocidsc_create,
artifact_location_set,
download_from_hf,
cleanup_model_cache,
mock_get_hf_model_info,
mock_init_client,
):
Expand Down Expand Up @@ -747,6 +748,7 @@ def test_import_verified_model(
os_path=os_path,
local_dir=str(tmpdir),
download_from_hf=True,
cleanup_model_cache=cleanup_model_cache,
allow_patterns=["*.json"],
ignore_patterns=["test.json"],
)
Expand All @@ -761,6 +763,20 @@ def test_import_verified_model(
f"oci os object bulk-upload --src-dir {str(tmpdir)}/{model_name} --prefix prefix/path/{model_name}/ -bn aqua-bkt -ns aqua-ns --auth api_key --profile DEFAULT --no-overwrite --exclude {HF_METADATA_FOLDER}*"
)
)
if cleanup_model_cache:
cache_dir = os.path.join(
os.path.expanduser("~"),
".cache",
"huggingface",
"hub",
"models--oracle--aqua-1t-mega-model",
)
assert (
os.path.exists(f"{str(tmpdir)}/{os.path.dirname(model_name)}")
is False
)
assert os.path.exists(cache_dir) is False

else:
model: AquaModel = app.register(
model="ocid1.datasciencemodel.xxx.xxxx.",
Expand Down Expand Up @@ -1183,22 +1199,22 @@ def test_import_model_with_input_tags(
"model": "oracle/oracle-1it",
"inference_container": "odsc-vllm-serving",
},
"ads aqua model register --model oracle/oracle-1it --os_path oci://aqua-bkt@aqua-ns/path --download_from_hf True --inference_container odsc-vllm-serving",
"ads aqua model register --model oracle/oracle-1it --os_path oci://aqua-bkt@aqua-ns/path --download_from_hf True --cleanup_model_cache True --inference_container odsc-vllm-serving",
),
(
{
"os_path": "oci://aqua-bkt@aqua-ns/path",
"model": "ocid1.datasciencemodel.oc1.iad.<OCID>",
},
"ads aqua model register --model ocid1.datasciencemodel.oc1.iad.<OCID> --os_path oci://aqua-bkt@aqua-ns/path --download_from_hf True",
"ads aqua model register --model ocid1.datasciencemodel.oc1.iad.<OCID> --os_path oci://aqua-bkt@aqua-ns/path --download_from_hf True --cleanup_model_cache True",
),
(
{
"os_path": "oci://aqua-bkt@aqua-ns/path",
"model": "oracle/oracle-1it",
"download_from_hf": False,
},
"ads aqua model register --model oracle/oracle-1it --os_path oci://aqua-bkt@aqua-ns/path --download_from_hf False",
"ads aqua model register --model oracle/oracle-1it --os_path oci://aqua-bkt@aqua-ns/path --download_from_hf False --cleanup_model_cache True",
),
(
{
Expand All @@ -1207,7 +1223,7 @@ def test_import_model_with_input_tags(
"download_from_hf": True,
"model_file": "test_model_file",
},
"ads aqua model register --model oracle/oracle-1it --os_path oci://aqua-bkt@aqua-ns/path --download_from_hf True --model_file test_model_file",
"ads aqua model register --model oracle/oracle-1it --os_path oci://aqua-bkt@aqua-ns/path --download_from_hf True --cleanup_model_cache True --model_file test_model_file",
),
(
{
Expand All @@ -1216,7 +1232,7 @@ def test_import_model_with_input_tags(
"inference_container": "odsc-tei-serving",
"inference_container_uri": "<region>.ocir.io/<your_tenancy>/<your_image>",
},
"ads aqua model register --model oracle/oracle-1it --os_path oci://aqua-bkt@aqua-ns/path --download_from_hf True --inference_container odsc-tei-serving --inference_container_uri <region>.ocir.io/<your_tenancy>/<your_image>",
"ads aqua model register --model oracle/oracle-1it --os_path oci://aqua-bkt@aqua-ns/path --download_from_hf True --cleanup_model_cache True --inference_container odsc-tei-serving --inference_container_uri <region>.ocir.io/<your_tenancy>/<your_image>",
),
(
{
Expand All @@ -1227,7 +1243,7 @@ def test_import_model_with_input_tags(
"defined_tags": {"dtag1": "dvalue1", "dtag2": "dvalue2"},
},
"ads aqua model register --model oracle/oracle-1it --os_path oci://aqua-bkt@aqua-ns/path "
"--download_from_hf True --inference_container odsc-vllm-serving --freeform_tags "
"--download_from_hf True --cleanup_model_cache True --inference_container odsc-vllm-serving --freeform_tags "
'{"ftag1": "fvalue1", "ftag2": "fvalue2"} --defined_tags {"dtag1": "dvalue1", "dtag2": "dvalue2"}',
),
],
Expand Down
4 changes: 3 additions & 1 deletion tests/unitary/with_extras/aqua/test_model_handler.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*--
# Copyright (c) 2024 Oracle and/or its affiliates.
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

from unittest import TestCase
Expand Down Expand Up @@ -213,6 +213,8 @@ def test_register(
project_id=None,
model_file=model_file,
download_from_hf=download_from_hf,
local_dir=None,
cleanup_model_cache=True,
inference_container_uri=inference_container_uri,
allow_patterns=allow_patterns,
ignore_patterns=ignore_patterns,
Expand Down
Loading