Skip to content

Commit

Permalink
Add seq2seq job to init file.
Browse files Browse the repository at this point in the history
  • Loading branch information
TheMichaelHu committed May 17, 2022
1 parent 950b827 commit 2cf47c9
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 8 deletions.
2 changes: 2 additions & 0 deletions google/cloud/aiplatform/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
CustomPythonPackageTrainingJob,
AutoMLTabularTrainingJob,
AutoMLForecastingTrainingJob,
SequenceToSequencePlusForecastingTrainingJob,
AutoMLImageTrainingJob,
AutoMLTextTrainingJob,
AutoMLVideoTrainingJob,
Expand Down Expand Up @@ -116,6 +117,7 @@
"Model",
"ModelEvaluation",
"PipelineJob",
"SequenceToSequencePlusForecastingTrainingJob",
"TabularDataset",
"Tensorboard",
"TensorboardExperiment",
Expand Down
2 changes: 1 addition & 1 deletion google/cloud/aiplatform/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class definition:
custom_task = "gs://google-cloud-aiplatform/schema/trainingjob/definition/custom_task_1.0.0.yaml"
automl_tabular = "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_tabular_1.0.0.yaml"
automl_forecasting = "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_time_series_forecasting_1.0.0.yaml"
seq2seq_forecasting = "gs://google-cloud-aiplatform/schema/trainingjob/definition/seq2seq_plus_time_series_forecasting_1.0.0.yaml"
seq2seq_plus_forecasting = "gs://google-cloud-aiplatform/schema/trainingjob/definition/seq2seq_plus_time_series_forecasting_1.0.0.yaml"
automl_image_classification = "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_image_classification_1.0.0.yaml"
automl_image_object_detection = "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_image_object_detection_1.0.0.yaml"
automl_text_classification = "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_text_classification_1.0.0.yaml"
Expand Down
16 changes: 9 additions & 7 deletions google/cloud/aiplatform/training_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1693,7 +1693,7 @@ def _training_task_definition(cls) -> str:
"""A GCS path to the YAML file that defines the training task.
The definition files that can be used here are found in
gs://google-cloud- aiplatform/schema/trainingjob/definition/.
gs://google-cloud-aiplatform/schema/trainingjob/definition/.
"""
pass

Expand Down Expand Up @@ -1907,13 +1907,13 @@ def run(

if self._is_waiting_to_run():
raise RuntimeError(
f"{self.__class__._model_type} Forecasting Training is already "
"scheduled to run."
f"{self._model_type} Forecasting Training is already scheduled "
"to run."
)

if self._has_run:
raise RuntimeError(
f"{self.__class__._model_type} Forecasting Training has already run."
f"{self._model_type} Forecasting Training has already run."
)

if additional_experiments:
Expand Down Expand Up @@ -2218,7 +2218,7 @@ def _run(
)

new_model = self._run_job(
training_task_definition=self.__class__._training_task_definition,
training_task_definition=self._training_task_definition,
training_task_inputs=training_task_inputs_dict,
dataset=dataset,
training_fraction_split=training_fraction_split,
Expand Down Expand Up @@ -4961,8 +4961,10 @@ def evaluated_data_items_bigquery_uri(self) -> Optional[str]:

class SequenceToSequencePlusForecastingTrainingJob(_ForecastingTrainingJob):
_model_type = "Seq2Seq"
_training_task_definition = schema.training_job.definition.seq2seq_forecasting
_supported_training_schemas = (schema.training_job.definition.seq2seq_forecasting,)
_training_task_definition = schema.training_job.definition.seq2seq_plus_forecasting
_supported_training_schemas = (
schema.training_job.definition.seq2seq_plus_forecasting,
)

def __init__(
self,
Expand Down

0 comments on commit 2cf47c9

Please sign in to comment.