From 2cf47c934ff527a6e88b7a05357c42307b1dfc14 Mon Sep 17 00:00:00 2001 From: Michael Hu Date: Tue, 17 May 2022 09:48:54 -0400 Subject: [PATCH] Add seq2seq job to init file. --- google/cloud/aiplatform/__init__.py | 2 ++ google/cloud/aiplatform/schema.py | 2 +- google/cloud/aiplatform/training_jobs.py | 16 +++++++++------- 3 files changed, 12 insertions(+), 8 deletions(-) diff --git a/google/cloud/aiplatform/__init__.py b/google/cloud/aiplatform/__init__.py index db7d0a7c18a..1ad69b2a540 100644 --- a/google/cloud/aiplatform/__init__.py +++ b/google/cloud/aiplatform/__init__.py @@ -63,6 +63,7 @@ CustomPythonPackageTrainingJob, AutoMLTabularTrainingJob, AutoMLForecastingTrainingJob, + SequenceToSequencePlusForecastingTrainingJob, AutoMLImageTrainingJob, AutoMLTextTrainingJob, AutoMLVideoTrainingJob, @@ -116,6 +117,7 @@ "Model", "ModelEvaluation", "PipelineJob", + "SequenceToSequencePlusForecastingTrainingJob", "TabularDataset", "Tensorboard", "TensorboardExperiment", diff --git a/google/cloud/aiplatform/schema.py b/google/cloud/aiplatform/schema.py index 8c8e7f32f3a..96a7a50bbde 100644 --- a/google/cloud/aiplatform/schema.py +++ b/google/cloud/aiplatform/schema.py @@ -23,7 +23,7 @@ class definition: custom_task = "gs://google-cloud-aiplatform/schema/trainingjob/definition/custom_task_1.0.0.yaml" automl_tabular = "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_tabular_1.0.0.yaml" automl_forecasting = "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_time_series_forecasting_1.0.0.yaml" - seq2seq_forecasting = "gs://google-cloud-aiplatform/schema/trainingjob/definition/seq2seq_plus_time_series_forecasting_1.0.0.yaml" + seq2seq_plus_forecasting = "gs://google-cloud-aiplatform/schema/trainingjob/definition/seq2seq_plus_time_series_forecasting_1.0.0.yaml" automl_image_classification = "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_image_classification_1.0.0.yaml" automl_image_object_detection = "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_image_object_detection_1.0.0.yaml" automl_text_classification = "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_text_classification_1.0.0.yaml" diff --git a/google/cloud/aiplatform/training_jobs.py b/google/cloud/aiplatform/training_jobs.py index 45834b46c5c..355c93397d7 100644 --- a/google/cloud/aiplatform/training_jobs.py +++ b/google/cloud/aiplatform/training_jobs.py @@ -1693,7 +1693,7 @@ def _training_task_definition(cls) -> str: """A GCS path to the YAML file that defines the training task. The definition files that can be used here are found in - gs://google-cloud- aiplatform/schema/trainingjob/definition/. + gs://google-cloud-aiplatform/schema/trainingjob/definition/. """ pass @@ -1907,13 +1907,13 @@ def run( if self._is_waiting_to_run(): raise RuntimeError( - f"{self.__class__._model_type} Forecasting Training is already " - "scheduled to run." + f"{self._model_type} Forecasting Training is already scheduled " + "to run." ) if self._has_run: raise RuntimeError( - f"{self.__class__._model_type} Forecasting Training has already run." + f"{self._model_type} Forecasting Training has already run." ) if additional_experiments: @@ -2218,7 +2218,7 @@ def _run( ) new_model = self._run_job( - training_task_definition=self.__class__._training_task_definition, + training_task_definition=self._training_task_definition, training_task_inputs=training_task_inputs_dict, dataset=dataset, training_fraction_split=training_fraction_split, @@ -4961,8 +4961,10 @@ def evaluated_data_items_bigquery_uri(self) -> Optional[str]: class SequenceToSequencePlusForecastingTrainingJob(_ForecastingTrainingJob): _model_type = "Seq2Seq" - _training_task_definition = schema.training_job.definition.seq2seq_forecasting - _supported_training_schemas = (schema.training_job.definition.seq2seq_forecasting,) + _training_task_definition = schema.training_job.definition.seq2seq_plus_forecasting + _supported_training_schemas = ( + schema.training_job.definition.seq2seq_plus_forecasting, + ) def __init__( self,