From af6e45556d6b093189f363a95f2be45e0008aebd Mon Sep 17 00:00:00 2001 From: Alexey Volkov Date: Fri, 11 Aug 2023 00:27:07 -0700 Subject: [PATCH] feat: LLM - Added tuning support for `codechat-bison` models PiperOrigin-RevId: 555829035 --- tests/unit/aiplatform/test_language_models.py | 47 +++++++++++++++++++ .../_model_garden/_model_garden_models.py | 1 + vertexai/language_models/_language_models.py | 4 ++ vertexai/preview/language_models.py | 3 +- 4 files changed, 54 insertions(+), 1 deletion(-) diff --git a/tests/unit/aiplatform/test_language_models.py b/tests/unit/aiplatform/test_language_models.py index fe04517fed..4d5e286070 100644 --- a/tests/unit/aiplatform/test_language_models.py +++ b/tests/unit/aiplatform/test_language_models.py @@ -724,6 +724,53 @@ def test_tune_chat_model( ].runtime_config.parameter_values assert pipeline_arguments["large_model_reference"] == "chat-bison@001" + @pytest.mark.parametrize( + "job_spec", + [_TEST_PIPELINE_SPEC_JSON], + ) + @pytest.mark.parametrize( + "mock_request_urlopen", + ["https://us-central1-kfp.pkg.dev/proj/repo/pack/latest"], + indirect=True, + ) + def test_tune_code_chat_model( + self, + mock_pipeline_service_create, + mock_pipeline_job_get, + mock_pipeline_bucket_exists, + job_spec, + mock_load_yaml_and_json, + mock_gcs_from_string, + mock_gcs_upload, + mock_request_urlopen, + mock_get_tuned_model, + ): + """Tests tuning a code chat model.""" + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + with mock.patch.object( + target=model_garden_service_client.ModelGardenServiceClient, + attribute="get_publisher_model", + return_value=gca_publisher_model.PublisherModel( + _CODECHAT_BISON_PUBLISHER_MODEL_DICT + ), + ): + model = preview_language_models.CodeChatModel.from_pretrained( + "codechat-bison@001" + ) + + # The tune_model call needs to be inside the PublisherModel mock + # since it gets a new PublisherModel when tuning completes. + model.tune_model( + training_data=_TEST_TEXT_BISON_TRAINING_DF, + tuning_job_location="europe-west4", + tuned_model_location="us-central1", + ) + call_kwargs = mock_pipeline_service_create.call_args[1] + pipeline_arguments = call_kwargs[ + "pipeline_job" + ].runtime_config.parameter_values + assert pipeline_arguments["large_model_reference"] == "codechat-bison@001" + @pytest.mark.usefixtures( "get_model_with_tuned_version_label_mock", "get_endpoint_with_models_mock", diff --git a/vertexai/_model_garden/_model_garden_models.py b/vertexai/_model_garden/_model_garden_models.py index 0fa77edb54..a32a5d2b7b 100644 --- a/vertexai/_model_garden/_model_garden_models.py +++ b/vertexai/_model_garden/_model_garden_models.py @@ -34,6 +34,7 @@ "text-bison": "https://us-kfp.pkg.dev/ml-pipeline/large-language-model-pipelines/tune-large-model/v2.0.0", "code-bison": "https://us-kfp.pkg.dev/ml-pipeline/large-language-model-pipelines/tune-large-model/v3.0.0", "chat-bison": "https://us-kfp.pkg.dev/ml-pipeline/large-language-model-pipelines/tune-large-chat-model/v3.0.0", + "codechat-bison": "https://us-kfp.pkg.dev/ml-pipeline/large-language-model-pipelines/tune-large-chat-model/v3.0.0", } _SDK_PRIVATE_PREVIEW_LAUNCH_STAGE = frozenset( diff --git a/vertexai/language_models/_language_models.py b/vertexai/language_models/_language_models.py index 8aa67ec92f..4cd02a1142 100644 --- a/vertexai/language_models/_language_models.py +++ b/vertexai/language_models/_language_models.py @@ -739,6 +739,10 @@ def start_chat( ) +class _PreviewCodeChatModel(CodeChatModel, _TunableModelMixin): + _LAUNCH_STAGE = _model_garden_models._SDK_PUBLIC_PREVIEW_LAUNCH_STAGE + + class _ChatSessionBase: """_ChatSessionBase is a base class for all chat sessions.""" diff --git a/vertexai/preview/language_models.py b/vertexai/preview/language_models.py index 057d73fdcd..6ecf2a6d54 100644 --- a/vertexai/preview/language_models.py +++ b/vertexai/preview/language_models.py @@ -16,13 +16,13 @@ from vertexai.language_models._language_models import ( _PreviewChatModel, + _PreviewCodeChatModel, _PreviewCodeGenerationModel, _PreviewTextEmbeddingModel, _PreviewTextGenerationModel, ChatMessage, ChatModel, ChatSession, - CodeChatModel, CodeChatSession, InputOutputTextPair, TextEmbedding, @@ -30,6 +30,7 @@ ) ChatModel = _PreviewChatModel +CodeChatModel = _PreviewCodeChatModel CodeGenerationModel = _PreviewCodeGenerationModel TextGenerationModel = _PreviewTextGenerationModel TextEmbeddingModel = _PreviewTextEmbeddingModel