Skip to content

Commit

Permalink
feat: Add incremental training to AutoMLImageTrainingJob.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 517272484
  • Loading branch information
vertex-sdk-bot authored and copybara-github committed Mar 17, 2023
1 parent 091d74f commit bb92380
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 10 deletions.
21 changes: 21 additions & 0 deletions google/cloud/aiplatform/training_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -5270,6 +5270,7 @@ def __init__(
multi_label: bool = False,
model_type: str = "CLOUD",
base_model: Optional[models.Model] = None,
incremental_train_base_model: Optional[models.Model] = None,
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
Expand Down Expand Up @@ -5335,6 +5336,12 @@ def __init__(
Otherwise, the new model will be trained from scratch. The `base` model
must be in the same Project and Location as the new Model to train,
and have the same model_type.
incremental_train_base_model: Optional[models.Model] = None
Optional for both Image Classification and Object detection models, to
incrementally train a new model using an existing model as the starting point, with
a reduced training time. If not specified, the new model will be trained from scratch.
The `base` model must be in the same Project and Location as the new Model to train,
and have the same prediction_type and model_type.
project (str):
Optional. Project to run training in. Overrides project set in aiplatform.init.
location (str):
Expand Down Expand Up @@ -5423,6 +5430,7 @@ def __init__(
self._prediction_type = prediction_type
self._multi_label = multi_label
self._base_model = base_model
self._incremental_train_base_model = incremental_train_base_model

def run(
self,
Expand Down Expand Up @@ -5603,6 +5611,7 @@ def run(
return self._run(
dataset=dataset,
base_model=self._base_model,
incremental_train_base_model=self._incremental_train_base_model,
training_fraction_split=training_fraction_split,
validation_fraction_split=validation_fraction_split,
test_fraction_split=test_fraction_split,
Expand All @@ -5627,6 +5636,7 @@ def _run(
self,
dataset: datasets.ImageDataset,
base_model: Optional[models.Model] = None,
incremental_train_base_model: Optional[models.Model] = None,
training_fraction_split: Optional[float] = None,
validation_fraction_split: Optional[float] = None,
test_fraction_split: Optional[float] = None,
Expand Down Expand Up @@ -5681,6 +5691,12 @@ def _run(
Otherwise, the new model will be trained from scratch. The `base` model
must be in the same Project and Location as the new Model to train,
and have the same model_type.
incremental_train_base_model: Optional[models.Model] = None
Optional for both Image Classification and Object detection models, to
incrementally train a new model using an existing model as the starting point, with
a reduced training time. If not specified, the new model will be trained from scratch.
The `base` model must be in the same Project and Location as the new Model to train,
and have the same prediction_type and model_type.
model_id (str):
Optional. The ID to use for the Model produced by this job,
which will become the final component of the model resource name.
Expand Down Expand Up @@ -5818,6 +5834,11 @@ def _run(
# Set ID of Vertex AI Model to base this training job off of
training_task_inputs_dict["baseModelId"] = base_model.name

if incremental_train_base_model:
training_task_inputs_dict[
"uptrainBaseModelId"
] = incremental_train_base_model.name

return self._run_job(
training_task_definition=training_task_definition,
training_task_inputs=training_task_inputs_dict,
Expand Down
44 changes: 34 additions & 10 deletions tests/unit/aiplatform/test_automl_image_training_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,17 @@
struct_pb2.Value(),
)

_TEST_TRAINING_TASK_INPUTS_WITH_UPTRAIN_BASE_MODEL = json_format.ParseDict(
{
"modelType": "CLOUD",
"budgetMilliNodeHours": _TEST_TRAINING_BUDGET_MILLI_NODE_HOURS,
"multiLabel": False,
"disableEarlyStopping": _TEST_TRAINING_DISABLE_EARLY_STOPPING,
"uptrainBaseModelId": _TEST_MODEL_ID,
},
struct_pb2.Value(),
)

_TEST_FRACTION_SPLIT_TRAINING = 0.6
_TEST_FRACTION_SPLIT_VALIDATION = 0.2
_TEST_FRACTION_SPLIT_TEST = 0.2
Expand Down Expand Up @@ -213,6 +224,20 @@ def mock_model():
yield model


@pytest.fixture
def mock_uptrain_base_model():
model = mock.MagicMock(models.Model)
model.name = _TEST_MODEL_ID
model._latest_future = None
model._exception = None
model._gca_resource = gca_model.Model(
display_name=_TEST_MODEL_DISPLAY_NAME,
description="This is the mock uptrain base Model's description",
name=_TEST_MODEL_NAME,
)
yield model


@pytest.mark.usefixtures("google_auth_mock")
class TestAutoMLImageTrainingJob:
def setup_method(self):
Expand All @@ -223,7 +248,7 @@ def teardown_method(self):
initializer.global_pool.shutdown(wait=True)

def test_init_all_parameters(self, mock_model):
"""Ensure all private members are set correctly at initialization"""
"""Ensure all private members are set correctly at initialization."""

aiplatform.init(project=_TEST_PROJECT)

Expand Down Expand Up @@ -275,7 +300,7 @@ def test_run_call_pipeline_service_create(
mock_pipeline_service_get,
mock_dataset_image,
mock_model_service_get,
mock_model,
mock_uptrain_base_model,
sync,
):
"""Create and run an AutoML ICN training job, verify calls and return value"""
Expand All @@ -287,7 +312,7 @@ def test_run_call_pipeline_service_create(

job = training_jobs.AutoMLImageTrainingJob(
display_name=_TEST_DISPLAY_NAME,
base_model=mock_model,
incremental_train_base_model=mock_uptrain_base_model,
labels=_TEST_LABELS,
)

Expand Down Expand Up @@ -315,8 +340,7 @@ def test_run_call_pipeline_service_create(

true_managed_model = gca_model.Model(
display_name=_TEST_MODEL_DISPLAY_NAME,
labels=mock_model._gca_resource.labels,
description=mock_model._gca_resource.description,
labels=_TEST_MODEL_LABELS,
encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC,
version_aliases=["default"],
)
Expand All @@ -330,7 +354,7 @@ def test_run_call_pipeline_service_create(
display_name=_TEST_DISPLAY_NAME,
labels=_TEST_LABELS,
training_task_definition=schema.training_job.definition.automl_image_classification,
training_task_inputs=_TEST_TRAINING_TASK_INPUTS_WITH_BASE_MODEL,
training_task_inputs=_TEST_TRAINING_TASK_INPUTS_WITH_UPTRAIN_BASE_MODEL,
model_to_upload=true_managed_model,
input_data_config=true_input_data_config,
encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC,
Expand Down Expand Up @@ -754,7 +778,7 @@ def test_splits_default(
mock_pipeline_service_get,
mock_dataset_image,
mock_model_service_get,
mock_model,
mock_uptrain_base_model,
sync,
):
"""
Expand All @@ -768,7 +792,8 @@ def test_splits_default(
)

job = training_jobs.AutoMLImageTrainingJob(
display_name=_TEST_DISPLAY_NAME, base_model=mock_model
display_name=_TEST_DISPLAY_NAME,
incremental_train_base_model=mock_uptrain_base_model,
)

model_from_job = job.run(
Expand All @@ -785,7 +810,6 @@ def test_splits_default(

true_managed_model = gca_model.Model(
display_name=_TEST_MODEL_DISPLAY_NAME,
description=mock_model._gca_resource.description,
encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC,
version_aliases=["default"],
)
Expand All @@ -797,7 +821,7 @@ def test_splits_default(
true_training_pipeline = gca_training_pipeline.TrainingPipeline(
display_name=_TEST_DISPLAY_NAME,
training_task_definition=schema.training_job.definition.automl_image_classification,
training_task_inputs=_TEST_TRAINING_TASK_INPUTS_WITH_BASE_MODEL,
training_task_inputs=_TEST_TRAINING_TASK_INPUTS_WITH_UPTRAIN_BASE_MODEL,
model_to_upload=true_managed_model,
input_data_config=true_input_data_config,
encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC,
Expand Down

0 comments on commit bb92380

Please sign in to comment.