Skip to content

Commit

Permalink
feat: LLM - Added support for CMEK in tuning
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 536899267
  • Loading branch information
Ark-kun authored and copybara-github committed Jun 1, 2023
1 parent 056b0bd commit aebf74a
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 7 deletions.
51 changes: 44 additions & 7 deletions tests/unit/aiplatform/test_language_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,19 @@

from vertexai.preview import language_models
from google.cloud.aiplatform_v1 import Execution as GapicExecution
from google.cloud.aiplatform.compat.types import (
encryption_spec as gca_encryption_spec,
)

_TEST_PROJECT = "test-project"
_TEST_LOCATION = "us-central1"

# CMEK encryption
_TEST_ENCRYPTION_KEY_NAME = "key_1234"
_TEST_ENCRYPTION_SPEC = gca_encryption_spec.EncryptionSpec(
kms_key_name=_TEST_ENCRYPTION_KEY_NAME
)

_TEXT_BISON_PUBLISHER_MODEL_DICT = {
"name": "publishers/google/models/text-bison",
"version_id": "001",
Expand Down Expand Up @@ -166,21 +175,44 @@
"dag": {"tasks": {}},
"inputDefinitions": {
"parameters": {
"project": {"parameterType": "STRING"},
"location": {
"api_endpoint": {
"defaultValue": "aiplatform.googleapis.com/ui",
"isOptional": True,
"parameterType": "STRING",
},
"large_model_reference": {
"dataset_name": {
"defaultValue": "",
"isOptional": True,
"parameterType": "STRING",
},
"dataset_uri": {
"defaultValue": "",
"isOptional": True,
"parameterType": "STRING",
},
"model_display_name": {
"encryption_spec_key_name": {
"defaultValue": "",
"isOptional": True,
"parameterType": "STRING",
},
"large_model_reference": {
"defaultValue": "text-bison-001",
"isOptional": True,
"parameterType": "STRING",
},
"learning_rate": {
"defaultValue": 3,
"isOptional": True,
"parameterType": "NUMBER_DOUBLE",
},
"location": {"parameterType": "STRING"},
"model_display_name": {"parameterType": "STRING"},
"project": {"parameterType": "STRING"},
"train_steps": {
"defaultValue": 1000,
"isOptional": True,
"parameterType": "NUMBER_INTEGER",
},
"dataset_uri": {"parameterType": "STRING"},
"dataset_name": {"parameterType": "STRING"},
}
},
},
Expand Down Expand Up @@ -480,6 +512,7 @@ def test_tune_model(
aiplatform.init(
project=_TEST_PROJECT,
location=_TEST_LOCATION,
encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME,
)
with mock.patch.object(
target=model_garden_service_client_v1beta1.ModelGardenServiceClient,
Expand All @@ -497,6 +530,11 @@ def test_tune_model(
tuning_job_location="europe-west4",
tuned_model_location="us-central1",
)
call_kwargs = mock_pipeline_service_create.call_args[1]
assert (
call_kwargs["pipeline_job"].encryption_spec.kms_key_name
== _TEST_ENCRYPTION_KEY_NAME
)

@pytest.mark.usefixtures(
"get_model_with_tuned_version_label_mock",
Expand All @@ -518,7 +556,6 @@ def test_get_tuned_model(
_TEXT_BISON_PUBLISHER_MODEL_DICT
),
):

tuned_model = language_models.TextGenerationModel.get_tuned_model(
test_constants.ModelConstants._TEST_MODEL_RESOURCE_NAME
)
Expand Down
4 changes: 4 additions & 0 deletions vertexai/language_models/_language_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -842,6 +842,10 @@ def _launch_tuning_job_on_jsonl_data(
pipeline_arguments["dataset_name"] = dataset_name_or_uri
if dataset_name_or_uri.startswith("gs://"):
pipeline_arguments["dataset_uri"] = dataset_name_or_uri
if aiplatform_initializer.global_config.encryption_spec_key_name:
pipeline_arguments["encryption_spec_key_name"] = (
aiplatform_initializer.global_config.encryption_spec_key_name
)
job = aiplatform.PipelineJob(
template_path=tuning_pipeline_uri,
display_name=None,
Expand Down

0 comments on commit aebf74a

Please sign in to comment.