Skip to content

Commit

Permalink
feat: LLM - Added tuning support for codechat-bison models
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 555829035
  • Loading branch information
Ark-kun authored and copybara-github committed Aug 11, 2023
1 parent 3a97c52 commit af6e455
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 1 deletion.
47 changes: 47 additions & 0 deletions tests/unit/aiplatform/test_language_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -724,6 +724,53 @@ def test_tune_chat_model(
].runtime_config.parameter_values
assert pipeline_arguments["large_model_reference"] == "chat-bison@001"

@pytest.mark.parametrize(
    "job_spec",
    [_TEST_PIPELINE_SPEC_JSON],
)
@pytest.mark.parametrize(
    "mock_request_urlopen",
    ["https://us-central1-kfp.pkg.dev/proj/repo/pack/latest"],
    indirect=True,
)
def test_tune_code_chat_model(
    self,
    mock_pipeline_service_create,
    mock_pipeline_job_get,
    mock_pipeline_bucket_exists,
    job_spec,
    mock_load_yaml_and_json,
    mock_gcs_from_string,
    mock_gcs_upload,
    mock_request_urlopen,
    mock_get_tuned_model,
):
    """Verifies that tuning a code chat model launches a pipeline job whose
    `large_model_reference` parameter names the codechat-bison base model."""
    aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)
    publisher_model = gca_publisher_model.PublisherModel(
        _CODECHAT_BISON_PUBLISHER_MODEL_DICT
    )
    with mock.patch.object(
        target=model_garden_service_client.ModelGardenServiceClient,
        attribute="get_publisher_model",
        return_value=publisher_model,
    ):
        model = preview_language_models.CodeChatModel.from_pretrained(
            "codechat-bison@001"
        )

        # tune_model must run while get_publisher_model is still patched,
        # since a new PublisherModel is fetched when tuning completes.
        model.tune_model(
            training_data=_TEST_TEXT_BISON_TRAINING_DF,
            tuning_job_location="europe-west4",
            tuned_model_location="us-central1",
        )
    created_job = mock_pipeline_service_create.call_args[1]["pipeline_job"]
    pipeline_arguments = created_job.runtime_config.parameter_values
    assert pipeline_arguments["large_model_reference"] == "codechat-bison@001"

@pytest.mark.usefixtures(
"get_model_with_tuned_version_label_mock",
"get_endpoint_with_models_mock",
Expand Down
1 change: 1 addition & 0 deletions vertexai/_model_garden/_model_garden_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
"text-bison": "https://us-kfp.pkg.dev/ml-pipeline/large-language-model-pipelines/tune-large-model/v2.0.0",
"code-bison": "https://us-kfp.pkg.dev/ml-pipeline/large-language-model-pipelines/tune-large-model/v3.0.0",
"chat-bison": "https://us-kfp.pkg.dev/ml-pipeline/large-language-model-pipelines/tune-large-chat-model/v3.0.0",
"codechat-bison": "https://us-kfp.pkg.dev/ml-pipeline/large-language-model-pipelines/tune-large-chat-model/v3.0.0",
}

_SDK_PRIVATE_PREVIEW_LAUNCH_STAGE = frozenset(
Expand Down
4 changes: 4 additions & 0 deletions vertexai/language_models/_language_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -739,6 +739,10 @@ def start_chat(
)


class _PreviewCodeChatModel(CodeChatModel, _TunableModelMixin):
    """Preview variant of CodeChatModel that additionally mixes in
    _TunableModelMixin, exposing `tune_model` for codechat-bison models."""

    # Preview launch stage distinguishes this class from the GA CodeChatModel.
    _LAUNCH_STAGE = _model_garden_models._SDK_PUBLIC_PREVIEW_LAUNCH_STAGE


class _ChatSessionBase:
"""_ChatSessionBase is a base class for all chat sessions."""

Expand Down
3 changes: 2 additions & 1 deletion vertexai/preview/language_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,21 @@

from vertexai.language_models._language_models import (
_PreviewChatModel,
_PreviewCodeChatModel,
_PreviewCodeGenerationModel,
_PreviewTextEmbeddingModel,
_PreviewTextGenerationModel,
ChatMessage,
ChatModel,
ChatSession,
CodeChatModel,
CodeChatSession,
InputOutputTextPair,
TextEmbedding,
TextGenerationResponse,
)

# In the vertexai.preview namespace, each public model name resolves to its
# preview variant (which layers additional capabilities, e.g. tuning, on the
# GA class of the same name).
ChatModel = _PreviewChatModel
CodeChatModel = _PreviewCodeChatModel
CodeGenerationModel = _PreviewCodeGenerationModel
TextGenerationModel = _PreviewTextGenerationModel
TextEmbeddingModel = _PreviewTextEmbeddingModel
Expand Down

0 comments on commit af6e455

Please sign in to comment.