diff --git a/tests/unit/aiplatform/test_language_models.py b/tests/unit/aiplatform/test_language_models.py
index 284b8dff9f..734977ceff 100644
--- a/tests/unit/aiplatform/test_language_models.py
+++ b/tests/unit/aiplatform/test_language_models.py
@@ -55,10 +55,19 @@ from vertexai.preview import language_models
 from google.cloud.aiplatform_v1 import Execution as GapicExecution
+from google.cloud.aiplatform.compat.types import (
+    encryption_spec as gca_encryption_spec,
+)
 
 _TEST_PROJECT = "test-project"
 _TEST_LOCATION = "us-central1"
 
+# CMEK encryption
+_TEST_ENCRYPTION_KEY_NAME = "key_1234"
+_TEST_ENCRYPTION_SPEC = gca_encryption_spec.EncryptionSpec(
+    kms_key_name=_TEST_ENCRYPTION_KEY_NAME
+)
+
 _TEXT_BISON_PUBLISHER_MODEL_DICT = {
     "name": "publishers/google/models/text-bison",
     "version_id": "001",
@@ -166,21 +175,44 @@
         "dag": {"tasks": {}},
         "inputDefinitions": {
             "parameters": {
-                "project": {"parameterType": "STRING"},
-                "location": {
+                "api_endpoint": {
+                    "defaultValue": "aiplatform.googleapis.com/ui",
+                    "isOptional": True,
                     "parameterType": "STRING",
                 },
-                "large_model_reference": {
+                "dataset_name": {
+                    "defaultValue": "",
+                    "isOptional": True,
+                    "parameterType": "STRING",
+                },
+                "dataset_uri": {
+                    "defaultValue": "",
+                    "isOptional": True,
                     "parameterType": "STRING",
                 },
-                "model_display_name": {
+                "encryption_spec_key_name": {
+                    "defaultValue": "",
+                    "isOptional": True,
                     "parameterType": "STRING",
                 },
+                "large_model_reference": {
+                    "defaultValue": "text-bison-001",
+                    "isOptional": True,
+                    "parameterType": "STRING",
+                },
+                "learning_rate": {
+                    "defaultValue": 3,
+                    "isOptional": True,
+                    "parameterType": "NUMBER_DOUBLE",
+                },
+                "location": {"parameterType": "STRING"},
+                "model_display_name": {"parameterType": "STRING"},
+                "project": {"parameterType": "STRING"},
                 "train_steps": {
+                    "defaultValue": 1000,
+                    "isOptional": True,
                     "parameterType": "NUMBER_INTEGER",
                 },
-                "dataset_uri": {"parameterType": "STRING"},
-                "dataset_name": {"parameterType": "STRING"},
             }
         },
     },
@@ -480,6 +512,7 @@ def test_tune_model(
         aiplatform.init(
             project=_TEST_PROJECT,
             location=_TEST_LOCATION,
+            encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME,
         )
         with mock.patch.object(
             target=model_garden_service_client_v1beta1.ModelGardenServiceClient,
@@ -497,6 +530,11 @@ def test_tune_model(
                 tuning_job_location="europe-west4",
                 tuned_model_location="us-central1",
             )
+            call_kwargs = mock_pipeline_service_create.call_args[1]
+            assert (
+                call_kwargs["pipeline_job"].encryption_spec.kms_key_name
+                == _TEST_ENCRYPTION_KEY_NAME
+            )
 
     @pytest.mark.usefixtures(
         "get_model_with_tuned_version_label_mock",
@@ -518,7 +556,6 @@ def test_get_tuned_model(
                 _TEXT_BISON_PUBLISHER_MODEL_DICT
            ),
         ):
-
             tuned_model = language_models.TextGenerationModel.get_tuned_model(
                 test_constants.ModelConstants._TEST_MODEL_RESOURCE_NAME
             )
diff --git a/vertexai/language_models/_language_models.py b/vertexai/language_models/_language_models.py
index d873899806..7a2dd04bcc 100644
--- a/vertexai/language_models/_language_models.py
+++ b/vertexai/language_models/_language_models.py
@@ -842,6 +842,10 @@ def _launch_tuning_job_on_jsonl_data(
         pipeline_arguments["dataset_name"] = dataset_name_or_uri
     if dataset_name_or_uri.startswith("gs://"):
         pipeline_arguments["dataset_uri"] = dataset_name_or_uri
+    if aiplatform_initializer.global_config.encryption_spec_key_name:
+        pipeline_arguments["encryption_spec_key_name"] = (
+            aiplatform_initializer.global_config.encryption_spec_key_name
+        )
     job = aiplatform.PipelineJob(
         template_path=tuning_pipeline_uri,
         display_name=None,
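
Usage sketch (not part of the patch): the change means a CMEK key configured via `init(encryption_spec_key_name=...)` is picked up by `_launch_tuning_job_on_jsonl_data` and forwarded to the tuning `PipelineJob`, which is what the new assertion in `test_tune_model` checks. The project, bucket, and key names below are placeholders, not values from this PR.

```python
# Illustrative only: project, bucket, and KMS key names are made up.
import vertexai
from vertexai.language_models import TextGenerationModel

vertexai.init(
    project="my-project",
    location="us-central1",
    # A real CMEK key name is a full KMS resource path:
    # projects/<project>/locations/<location>/keyRings/<ring>/cryptoKeys/<key>
    encryption_spec_key_name=(
        "projects/my-project/locations/us-central1/"
        "keyRings/my-ring/cryptoKeys/my-key"
    ),
)

model = TextGenerationModel.from_pretrained("text-bison@001")

# With this change, the key configured above is copied into the tuning
# pipeline's "encryption_spec_key_name" parameter, so the PipelineJob
# (and its artifacts) run under the customer-managed key.
model.tune_model(
    training_data="gs://my-bucket/tuning_data.jsonl",
    train_steps=100,
    tuning_job_location="europe-west4",
    tuned_model_location="us-central1",
)
```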