feat: LLM - Added support for model distillation
PiperOrigin-RevId: 590578502
Ark-kun authored and copybara-github committed Dec 13, 2023
1 parent cfc5cba commit 28925e9
Showing 3 changed files with 386 additions and 21 deletions.
234 changes: 234 additions & 0 deletions tests/unit/aiplatform/test_language_models.py
@@ -758,6 +758,124 @@ def reverse_string_2(s):""",
"pipelineSpec": json.loads(_TEST_EVAL_PIPELINE_SPEC_JSON),
}
)
_TEST_DISTILLATION_PIPELINE_SPEC = {
"components": {},
"pipelineInfo": {
"description": "Vertex kfp pipeline for distillation.",
"name": "distillation",
},
"root": {
"dag": {"tasks": {}},
"inputDefinitions": {
"parameters": {
"accelerator_type": {
"defaultValue": "GPU",
"isOptional": True,
"parameterType": "STRING",
},
"api_endpoint": {
"defaultValue": "aiplatform.googleapis.com/ui",
"isOptional": True,
"parameterType": "STRING",
},
"dataset_uri": {"parameterType": "STRING"},
"enable_checkpoint_selection": {
"defaultValue": "default",
"isOptional": True,
"parameterType": "STRING",
},
"enable_early_stopping": {
"defaultValue": True,
"isOptional": True,
"parameterType": "BOOLEAN",
},
"encryption_spec_key_name": {
"defaultValue": "",
"isOptional": True,
"parameterType": "STRING",
},
"evaluation_data_uri": {
"defaultValue": "",
"isOptional": True,
"parameterType": "STRING",
},
"evaluation_interval": {
"defaultValue": 100,
"isOptional": True,
"parameterType": "NUMBER_INTEGER",
},
"evaluation_output_root_dir": {
"defaultValue": "",
"isOptional": True,
"parameterType": "STRING",
},
"learning_rate_multiplier": {
"defaultValue": 1,
"isOptional": True,
"parameterType": "NUMBER_DOUBLE",
},
"location": {
"defaultValue": "",
"isOptional": True,
"parameterType": "STRING",
},
"max_context_length": {
"defaultValue": "",
"isOptional": True,
"parameterType": "STRING",
},
"model_display_name": {
"defaultValue": "distilled-student-model",
"isOptional": True,
"parameterType": "STRING",
},
"project": {"parameterType": "STRING"},
"student_model_reference": {
"defaultValue": "text-bison@002",
"isOptional": True,
"parameterType": "STRING",
},
"teacher_model_reference": {
"defaultValue": "text-unicorn@001",
"isOptional": True,
"parameterType": "STRING",
},
"temperature": {
"defaultValue": 0,
"isOptional": True,
"parameterType": "NUMBER_DOUBLE",
},
"tensorboard_resource_id": {
"defaultValue": "",
"isOptional": True,
"parameterType": "STRING",
},
"tpu_training_skip_cmek": {
"defaultValue": False,
"isOptional": True,
"parameterType": "BOOLEAN",
},
"train_steps": {
"defaultValue": 300,
"isOptional": True,
"parameterType": "NUMBER_INTEGER",
},
"version": {
"defaultValue": "latest",
"isOptional": True,
"parameterType": "STRING",
},
}
},
},
"schemaVersion": "2.1.0",
"sdkVersion": "kfp-2.4.0",
}

_TEST_DISTILLATION_PIPELINE_SPEC_JSON = json.dumps(
_TEST_DISTILLATION_PIPELINE_SPEC,
)


# Eval classification spec

@@ -875,6 +993,10 @@ def reverse_string_2(s):""",
}
)

_URL_DATA = {
"https://us-kfp.pkg.dev/ml-pipeline/research/distillation/v1.0.0": _TEST_DISTILLATION_PIPELINE_SPEC_JSON,
}


@pytest.fixture
def mock_pipeline_bucket_exists():
@@ -1225,6 +1347,19 @@ def mock_request_urlopen_eval_classification(
yield request.param, mock_urlopen


@pytest.fixture
def mock_urllib_request_urlopen(request: str) -> Tuple[str, mock.MagicMock]:
    # The pipeline template URL is supplied via indirect parametrization.
    url = request.param
    data = _URL_DATA[url]
    with mock.patch.object(urllib_request, "urlopen") as mock_urlopen:
        # Make urlopen(url).read().decode() return the canned pipeline spec JSON.
        mock_read_response = mock.MagicMock()
        mock_decode_response = mock.MagicMock()
        mock_decode_response.return_value = data
        mock_read_response.return_value.decode = mock_decode_response
        mock_urlopen.return_value.read = mock_read_response
        yield url, mock_urlopen


@pytest.fixture
def get_endpoint_mock():
with mock.patch.object(
@@ -4251,3 +4386,102 @@ def test_model_evaluation_text_classification_base_model_only_summary_metrics(
)
assert eval_metrics.confidenceMetrics is None
assert eval_metrics.auPrc == _TEST_TEXT_CLASSIFICATION_METRICS["auPrc"]

@pytest.mark.parametrize(
"job_spec",
[
_TEST_DISTILLATION_PIPELINE_SPEC_JSON,
],
)
@pytest.mark.parametrize(
"mock_urllib_request_urlopen",
["https://us-kfp.pkg.dev/ml-pipeline/research/distillation/v1.0.0"],
indirect=True,
)
def test_text_generation_model_distill_from(
self,
mock_pipeline_service_create,
mock_pipeline_job_get,
mock_pipeline_bucket_exists,
job_spec,
mock_load_yaml_and_json,
mock_gcs_from_string,
mock_gcs_upload,
mock_urllib_request_urlopen,
mock_get_tuned_model,
):
"""Tests distilling the text generation model."""
aiplatform.init(
project=_TEST_PROJECT,
location=_TEST_LOCATION,
encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME,
)
with mock.patch.object(
target=model_garden_service_client.ModelGardenServiceClient,
attribute="get_publisher_model",
return_value=gca_publisher_model.PublisherModel(
_TEXT_BISON_PUBLISHER_MODEL_DICT
),
):
model = preview_language_models.TextGenerationModel.from_pretrained(
"text-bison@001"
)

dataset_uri = "gs://bucket/distillation.training_data.jsonl"
evaluation_data_uri = "gs://bucket/eval.jsonl"
evaluation_interval = 37
enable_early_stopping = True
enable_checkpoint_selection = True
tensorboard_name = (
f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/tensorboards/123"
)

tuning_job = model.distill_from(
dataset=dataset_uri,
teacher_model="text-unicorn@001",
learning_rate_multiplier=2.0,
train_steps=10,
evaluation_spec=preview_language_models.TuningEvaluationSpec(
evaluation_data=evaluation_data_uri,
evaluation_interval=evaluation_interval,
enable_early_stopping=enable_early_stopping,
enable_checkpoint_selection=enable_checkpoint_selection,
tensorboard=tensorboard_name,
),
accelerator_type="TPU",
)
call_kwargs = mock_pipeline_service_create.call_args[1]
pipeline_arguments = call_kwargs[
"pipeline_job"
].runtime_config.parameter_values
assert pipeline_arguments["teacher_model_reference"] == "text-unicorn@001"
assert pipeline_arguments["student_model_reference"] == "text-bison@001"
assert pipeline_arguments["dataset_uri"] == dataset_uri
assert pipeline_arguments["project"] == _TEST_PROJECT
assert pipeline_arguments["location"] == _TEST_LOCATION
assert pipeline_arguments["train_steps"] == 10
assert pipeline_arguments["learning_rate_multiplier"] == 2.0
assert pipeline_arguments["evaluation_data_uri"] == evaluation_data_uri
assert pipeline_arguments["evaluation_interval"] == evaluation_interval
assert pipeline_arguments["enable_early_stopping"] == enable_early_stopping
assert (
pipeline_arguments["enable_checkpoint_selection"]
== enable_checkpoint_selection
)
assert pipeline_arguments["tensorboard_resource_id"] == tensorboard_name
assert pipeline_arguments["accelerator_type"] == "TPU"
assert (
pipeline_arguments["encryption_spec_key_name"]
== _TEST_ENCRYPTION_KEY_NAME
)
assert (
call_kwargs["pipeline_job"].encryption_spec.kms_key_name
== _TEST_ENCRYPTION_KEY_NAME
)

# Testing the tuned model
tuned_model = tuning_job.get_tuned_model()
assert (
tuned_model._endpoint_name
== test_constants.EndpointConstants._TEST_ENDPOINT_NAME
)
111 changes: 111 additions & 0 deletions vertexai/language_models/_distillation.py
@@ -0,0 +1,111 @@
from typing import Optional, Union

from google.cloud import aiplatform
from google.cloud.aiplatform import initializer as aiplatform_initializer
from vertexai.language_models import _language_models
from vertexai.language_models import _language_models as tuning


class DistillationMixin:
_DISTILLATION_PIPELINE_URI = (
"https://us-kfp.pkg.dev/ml-pipeline/research/distillation/v1.0.0"
)

def distill_from(
self,
*,
dataset: str,
teacher_model: Union[str, _language_models._TextGenerationModel],
train_steps: Optional[int] = None,
learning_rate_multiplier: Optional[float] = None,
evaluation_spec: Optional[tuning.TuningEvaluationSpec] = None,
accelerator_type: Optional[tuning._ACCELERATOR_TYPE_TYPE] = None,
model_display_name: Optional[str] = None,
):
"""Tunes a smaller model with help from another bigger model.
Args:
dataset: A URI pointing to data in JSON lines format.
teacher_model: The teacher model to use for distillation.
train_steps: Number of training batches to use (batch size is 8 samples).
learning_rate_multiplier: Learning rate multiplier to use in tuning.
evaluation_spec: Specification for the model evaluation during tuning.
accelerator_type: Type of accelerator to use. Can be "TPU" or "GPU".
model_display_name: Custom display name for the tuned model.
Returns:
A tuning job for distillation.
Raises:
RuntimeError: If the model does not support distillation.
"""
if "/models/" not in self._endpoint_name:
raise RuntimeError(
f"Model does not support distillation: {self._endpoint_name}"
)
student_short_model_id = self._endpoint_name.split("/")[-1]

if isinstance(teacher_model, str):
teacher_short_model_id = teacher_model
elif isinstance(teacher_model, _language_models._LanguageModel):
if "/models/" not in teacher_model._endpoint_name:
raise RuntimeError(
f"Teacher model does not support distillation: {teacher_model._endpoint_name}"
)
teacher_short_model_id = teacher_model._endpoint_name.split("/")[-1]
else:
raise RuntimeError(f"Unsupported teacher model type: {teacher_model}")

pipeline_arguments = {
"teacher_model_reference": teacher_short_model_id,
"student_model_reference": student_short_model_id,
"dataset_uri": dataset,
"project": aiplatform_initializer.global_config.project,
"location": aiplatform_initializer.global_config.location,
}
if train_steps is not None:
pipeline_arguments["train_steps"] = train_steps
if learning_rate_multiplier is not None:
pipeline_arguments["learning_rate_multiplier"] = learning_rate_multiplier
if evaluation_spec is not None:
pipeline_arguments["evaluation_data_uri"] = evaluation_spec.evaluation_data
pipeline_arguments[
"evaluation_interval"
] = evaluation_spec.evaluation_interval
pipeline_arguments[
"enable_early_stopping"
] = evaluation_spec.enable_early_stopping
pipeline_arguments[
"enable_checkpoint_selection"
] = evaluation_spec.enable_checkpoint_selection
pipeline_arguments["tensorboard_resource_id"] = evaluation_spec.tensorboard
# pipeline_arguments["evaluation_output_root_dir"] = ...
if accelerator_type is not None:
pipeline_arguments["accelerator_type"] = accelerator_type
if aiplatform_initializer.global_config.encryption_spec_key_name is not None:
pipeline_arguments[
"encryption_spec_key_name"
] = aiplatform_initializer.global_config.encryption_spec_key_name
if model_display_name is None:
model_display_name = (
f"{student_short_model_id}"
f" distilled from {teacher_short_model_id}"
)
pipeline_arguments["model_display_name"] = model_display_name
# # Not exposing these parameters:
# temperature: Optional[float] = None,
# max_context_length: Optional[int] = None,
# tpu_training_skip_cmek: Optional[bool] = None,
# api_endpoint: Optional[str] = None,
# version: Optional[str] = None,
pipeline_job = aiplatform.PipelineJob(
template_path=self._DISTILLATION_PIPELINE_URI,
display_name=None,
parameter_values=pipeline_arguments,
)
pipeline_job.submit()
tuning_job = tuning._LanguageModelTuningJob(
base_model=self,
job=pipeline_job,
)
return tuning_job
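
For orientation, here is a minimal usage sketch of the new distill_from API, mirroring the unit test above. It assumes the preview TextGenerationModel picks up DistillationMixin (as the test exercises); the project, bucket URIs, and hyperparameter values are placeholders, not values taken from this commit.

from google.cloud import aiplatform
from vertexai.preview import language_models

# Placeholder project/location; distillation uses the initialized SDK context.
aiplatform.init(project="my-project", location="us-central1")

model = language_models.TextGenerationModel.from_pretrained("text-bison@001")

tuning_job = model.distill_from(
    dataset="gs://my-bucket/distillation.training_data.jsonl",  # JSON lines data
    teacher_model="text-unicorn@001",
    train_steps=10,
    learning_rate_multiplier=2.0,
    evaluation_spec=language_models.TuningEvaluationSpec(
        evaluation_data="gs://my-bucket/eval.jsonl",
        evaluation_interval=37,
        enable_early_stopping=True,
    ),
    accelerator_type="TPU",
)

# Waits for the distillation pipeline to finish and returns the distilled model.
distilled_model = tuning_job.get_tuned_model()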