diff --git a/ads/aqua/common/utils.py b/ads/aqua/common/utils.py index 6e1e09aca..001b22e8c 100644 --- a/ads/aqua/common/utils.py +++ b/ads/aqua/common/utils.py @@ -1,5 +1,5 @@ #!/usr/bin/env python -# Copyright (c) 2024 Oracle and/or its affiliates. +# Copyright (c) 2024, 2025 Oracle and/or its affiliates. # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ """AQUA utils and constants.""" @@ -11,6 +11,7 @@ import random import re import shlex +import shutil import subprocess from datetime import datetime, timedelta from functools import wraps @@ -21,6 +22,8 @@ import fsspec import oci from cachetools import TTLCache, cached +from huggingface_hub.constants import HF_HUB_CACHE +from huggingface_hub.file_download import repo_folder_name from huggingface_hub.hf_api import HfApi, ModelInfo from huggingface_hub.utils import ( GatedRepoError, @@ -788,7 +791,9 @@ def get_ocid_substring(ocid: str, key_len: int) -> str: return ocid[-key_len:] if ocid and len(ocid) > key_len else "" -def upload_folder(os_path: str, local_dir: str, model_name: str, exclude_pattern: str = None) -> str: +def upload_folder( + os_path: str, local_dir: str, model_name: str, exclude_pattern: str = None +) -> str: """Upload the local folder to the object storage Args: @@ -818,6 +823,48 @@ def upload_folder(os_path: str, local_dir: str, model_name: str, exclude_pattern return f"oci://{os_details.bucket}@{os_details.namespace}" + "/" + object_path +def cleanup_local_hf_model_artifact( + model_name: str, + local_dir: str = None, +): + """ + Helper function that deletes local artifacts downloaded from Hugging Face to free up disk space. + Parameters + ---------- + model_name (str): Name of the huggingface model + local_dir (str): Local directory where the object is downloaded + + """ + if local_dir and os.path.exists(local_dir): + model_dir = os.path.join(local_dir, model_name) + model_dir = ( + os.path.dirname(model_dir) + if "/" in model_name or os.sep in model_name + else model_dir + ) + shutil.rmtree(model_dir, ignore_errors=True) + if os.path.exists(model_dir): + logger.debug( + f"Could not delete local model artifact directory: {model_dir}" + ) + else: + logger.debug(f"Deleted local model artifact directory: {model_dir}.") + + hf_local_path = os.path.join( + HF_HUB_CACHE, repo_folder_name(repo_id=model_name, repo_type="model") + ) + shutil.rmtree(hf_local_path, ignore_errors=True) + + if os.path.exists(hf_local_path): + logger.debug( + f"Could not clear the local Hugging Face cache directory {hf_local_path} for the model {model_name}." + ) + else: + logger.debug( + f"Cleared contents of local Hugging Face cache directory {hf_local_path} for the model {model_name}." + ) + + def is_service_managed_container(container): return container and container.startswith(SERVICE_MANAGED_CONTAINER_URI_SCHEME) diff --git a/ads/aqua/extension/model_handler.py b/ads/aqua/extension/model_handler.py index 42f90ffef..3024dc392 100644 --- a/ads/aqua/extension/model_handler.py +++ b/ads/aqua/extension/model_handler.py @@ -1,5 +1,5 @@ #!/usr/bin/env python -# Copyright (c) 2024 Oracle and/or its affiliates. +# Copyright (c) 2024, 2025 Oracle and/or its affiliates. # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ from typing import Optional @@ -128,6 +128,10 @@ def post(self, *args, **kwargs): # noqa: ARG002 download_from_hf = ( str(input_data.get("download_from_hf", "false")).lower() == "true" ) + local_dir = input_data.get("local_dir") + cleanup_model_cache = ( + str(input_data.get("cleanup_model_cache", "true")).lower() == "true" + ) inference_container_uri = input_data.get("inference_container_uri") allow_patterns = input_data.get("allow_patterns") ignore_patterns = input_data.get("ignore_patterns") @@ -139,6 +143,8 @@ def post(self, *args, **kwargs): # noqa: ARG002 model=model, os_path=os_path, download_from_hf=download_from_hf, + local_dir=local_dir, + cleanup_model_cache=cleanup_model_cache, inference_container=inference_container, finetuning_container=finetuning_container, compartment_id=compartment_id, diff --git a/ads/aqua/model/entities.py b/ads/aqua/model/entities.py index ecdb8b8e7..6dd5eba21 100644 --- a/ads/aqua/model/entities.py +++ b/ads/aqua/model/entities.py @@ -1,5 +1,5 @@ #!/usr/bin/env python -# Copyright (c) 2024 Oracle and/or its affiliates. +# Copyright (c) 2024, 2025 Oracle and/or its affiliates. # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ """ @@ -283,6 +283,7 @@ class ImportModelDetails(CLIBuilderMixin): os_path: str download_from_hf: Optional[bool] = True local_dir: Optional[str] = None + cleanup_model_cache: Optional[bool] = True inference_container: Optional[str] = None finetuning_container: Optional[str] = None compartment_id: Optional[str] = None diff --git a/ads/aqua/model/model.py b/ads/aqua/model/model.py index 02e0df00f..16e1865a4 100644 --- a/ads/aqua/model/model.py +++ b/ads/aqua/model/model.py @@ -1,5 +1,5 @@ #!/usr/bin/env python -# Copyright (c) 2024 Oracle and/or its affiliates. +# Copyright (c) 2024, 2025 Oracle and/or its affiliates. # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ import os import pathlib @@ -23,6 +23,7 @@ from ads.aqua.common.utils import ( LifecycleStatus, _build_resource_identifier, + cleanup_local_hf_model_artifact, copy_model_config, create_word_icon, generate_tei_cmd_var, @@ -1322,20 +1323,18 @@ def _download_model_from_hf( Returns ------- model_artifact_path (str): Location where the model artifacts are downloaded. - """ # Download the model from hub - if not local_dir: - local_dir = os.path.join(os.path.expanduser("~"), "cached-model") - local_dir = os.path.join(local_dir, model_name) - os.makedirs(local_dir, exist_ok=True) + if local_dir: + local_dir = os.path.join(local_dir, model_name) + os.makedirs(local_dir, exist_ok=True) snapshot_download( repo_id=model_name, local_dir=local_dir, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns, ) - # Upload to object storage and skip .cache/huggingface/ folder + # Upload to object storage model_artifact_path = upload_folder( os_path=os_path, local_dir=local_dir, @@ -1365,6 +1364,8 @@ def register( ignore_patterns (list): Model files matching any of the patterns are not downloaded. Example: ["*.json"] will ignore all .json files. ["folder/*"] will ignore all files under `folder`. Patterns are Standard Wildcards (globbing patterns) and rules can be found here: https://docs.python.org/3/library/fnmatch.html + cleanup_model_cache (bool): Deletes downloaded files from local machine after model is successfully + registered. Set to True by default. Returns: AquaModel: @@ -1474,6 +1475,14 @@ def register( detail=validation_result.telemetry_model_name, ) + if ( + import_model_details.download_from_hf + and import_model_details.cleanup_model_cache + ): + cleanup_local_hf_model_artifact( + model_name=model_name, local_dir=import_model_details.local_dir + ) + return AquaModel(**aqua_model_attributes) def _if_show(self, model: DataScienceModel) -> bool: diff --git a/tests/unitary/with_extras/aqua/test_model.py b/tests/unitary/with_extras/aqua/test_model.py index cabb8c523..158646bdb 100644 --- a/tests/unitary/with_extras/aqua/test_model.py +++ b/tests/unitary/with_extras/aqua/test_model.py @@ -1,6 +1,6 @@ #!/usr/bin/env python # -*- coding: utf-8 -*-- -# Copyright (c) 2024 Oracle and/or its affiliates. +# Copyright (c) 2024, 2025 Oracle and/or its affiliates. # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ import os @@ -653,12 +653,12 @@ def test_get_model_fine_tuned( } @pytest.mark.parametrize( - ("artifact_location_set", "download_from_hf"), + ("artifact_location_set", "download_from_hf", "cleanup_model_cache"), [ - (True, True), - (True, False), - (False, True), - (False, False), + (True, True, True), + (True, False, True), + (False, True, False), + (False, False, True), ], ) @patch("ads.model.service.oci_datascience_model.OCIDataScienceModel.create") @@ -683,6 +683,7 @@ def test_import_verified_model( mock_ocidsc_create, artifact_location_set, download_from_hf, + cleanup_model_cache, mock_get_hf_model_info, mock_init_client, ): @@ -747,6 +748,7 @@ def test_import_verified_model( os_path=os_path, local_dir=str(tmpdir), download_from_hf=True, + cleanup_model_cache=cleanup_model_cache, allow_patterns=["*.json"], ignore_patterns=["test.json"], ) @@ -761,6 +763,20 @@ def test_import_verified_model( f"oci os object bulk-upload --src-dir {str(tmpdir)}/{model_name} --prefix prefix/path/{model_name}/ -bn aqua-bkt -ns aqua-ns --auth api_key --profile DEFAULT --no-overwrite --exclude {HF_METADATA_FOLDER}*" ) ) + if cleanup_model_cache: + cache_dir = os.path.join( + os.path.expanduser("~"), + ".cache", + "huggingface", + "hub", + "models--oracle--aqua-1t-mega-model", + ) + assert ( + os.path.exists(f"{str(tmpdir)}/{os.path.dirname(model_name)}") + is False + ) + assert os.path.exists(cache_dir) is False + else: model: AquaModel = app.register( model="ocid1.datasciencemodel.xxx.xxxx.", @@ -1183,14 +1199,14 @@ def test_import_model_with_input_tags( "model": "oracle/oracle-1it", "inference_container": "odsc-vllm-serving", }, - "ads aqua model register --model oracle/oracle-1it --os_path oci://aqua-bkt@aqua-ns/path --download_from_hf True --inference_container odsc-vllm-serving", + "ads aqua model register --model oracle/oracle-1it --os_path oci://aqua-bkt@aqua-ns/path --download_from_hf True --cleanup_model_cache True --inference_container odsc-vllm-serving", ), ( { "os_path": "oci://aqua-bkt@aqua-ns/path", "model": "ocid1.datasciencemodel.oc1.iad.", }, - "ads aqua model register --model ocid1.datasciencemodel.oc1.iad. --os_path oci://aqua-bkt@aqua-ns/path --download_from_hf True", + "ads aqua model register --model ocid1.datasciencemodel.oc1.iad. --os_path oci://aqua-bkt@aqua-ns/path --download_from_hf True --cleanup_model_cache True", ), ( { @@ -1198,7 +1214,7 @@ def test_import_model_with_input_tags( "model": "oracle/oracle-1it", "download_from_hf": False, }, - "ads aqua model register --model oracle/oracle-1it --os_path oci://aqua-bkt@aqua-ns/path --download_from_hf False", + "ads aqua model register --model oracle/oracle-1it --os_path oci://aqua-bkt@aqua-ns/path --download_from_hf False --cleanup_model_cache True", ), ( { @@ -1207,7 +1223,7 @@ def test_import_model_with_input_tags( "download_from_hf": True, "model_file": "test_model_file", }, - "ads aqua model register --model oracle/oracle-1it --os_path oci://aqua-bkt@aqua-ns/path --download_from_hf True --model_file test_model_file", + "ads aqua model register --model oracle/oracle-1it --os_path oci://aqua-bkt@aqua-ns/path --download_from_hf True --cleanup_model_cache True --model_file test_model_file", ), ( { @@ -1216,7 +1232,7 @@ def test_import_model_with_input_tags( "inference_container": "odsc-tei-serving", "inference_container_uri": ".ocir.io//", }, - "ads aqua model register --model oracle/oracle-1it --os_path oci://aqua-bkt@aqua-ns/path --download_from_hf True --inference_container odsc-tei-serving --inference_container_uri .ocir.io//", + "ads aqua model register --model oracle/oracle-1it --os_path oci://aqua-bkt@aqua-ns/path --download_from_hf True --cleanup_model_cache True --inference_container odsc-tei-serving --inference_container_uri .ocir.io//", ), ( { @@ -1227,7 +1243,7 @@ def test_import_model_with_input_tags( "defined_tags": {"dtag1": "dvalue1", "dtag2": "dvalue2"}, }, "ads aqua model register --model oracle/oracle-1it --os_path oci://aqua-bkt@aqua-ns/path " - "--download_from_hf True --inference_container odsc-vllm-serving --freeform_tags " + "--download_from_hf True --cleanup_model_cache True --inference_container odsc-vllm-serving --freeform_tags " '{"ftag1": "fvalue1", "ftag2": "fvalue2"} --defined_tags {"dtag1": "dvalue1", "dtag2": "dvalue2"}', ), ], diff --git a/tests/unitary/with_extras/aqua/test_model_handler.py b/tests/unitary/with_extras/aqua/test_model_handler.py index bf02174b9..391f6a19d 100644 --- a/tests/unitary/with_extras/aqua/test_model_handler.py +++ b/tests/unitary/with_extras/aqua/test_model_handler.py @@ -1,6 +1,6 @@ #!/usr/bin/env python # -*- coding: utf-8 -*-- -# Copyright (c) 2024 Oracle and/or its affiliates. +# Copyright (c) 2024, 2025 Oracle and/or its affiliates. # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ from unittest import TestCase @@ -213,6 +213,8 @@ def test_register( project_id=None, model_file=model_file, download_from_hf=download_from_hf, + local_dir=None, + cleanup_model_cache=True, inference_container_uri=inference_container_uri, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns,