Skip to content

Commit

Permalink
make e2e test parameterized
Browse files Browse the repository at this point in the history
  • Loading branch information
TheMichaelHu committed Jun 3, 2022
1 parent e5b08d3 commit 09f6799
Showing 1 changed file with 30 additions and 54 deletions.
84 changes: 30 additions & 54 deletions tests/system/aiplatform/test_e2e_forecasting.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#

from google.cloud import aiplatform
from google.cloud.aiplatform import training_jobs
from google.cloud.aiplatform.compat.types import job_state
from google.cloud.aiplatform.compat.types import pipeline_state
import pytest
Expand All @@ -35,7 +36,16 @@ class TestEndToEndForecasting(e2e_base.TestEndToEnd):

_temp_prefix = "temp-vertex-sdk-e2e-forecasting"

def test_end_to_end_forecasting(self, shared_state):
@pytest.mark.parametrize(
"training_job",
[
training_jobs.AutoMLForecastingTrainingJob,
pytest.param(
training_jobs.SequenceToSequencePlusForecastingTrainingJob,
marks=pytest.mark.skip(reason="Seq2Seq not yet released.")),
],
)
def test_end_to_end_forecasting(self, shared_state, training_job):
"""Builds a dataset, trains models, and gets batch predictions."""
resources = []

Expand All @@ -45,14 +55,13 @@ def test_end_to_end_forecasting(self, shared_state):
staging_bucket=shared_state["staging_bucket_name"],
)
try:
# Create and import to single managed dataset for both training
# jobs.
ds = aiplatform.TimeSeriesDataset.create(
display_name=self._make_display_name("dataset"),
bq_source=[_TRAINING_DATASET_BQ_PATH],
sync=False,
create_request_timeout=180.0,
)
resources.append(ds)

time_column = "date"
time_series_identifier_column = "store_name"
Expand All @@ -65,22 +74,15 @@ def test_end_to_end_forecasting(self, shared_state):
"county": "categorical",
}

# Define both training jobs
automl_job = aiplatform.AutoMLForecastingTrainingJob(
display_name=self._make_display_name("train-housing-automl"),
job = training_job(
display_name=self._make_display_name(
"train-housing-forecasting"),
optimization_objective="minimize-rmse",
column_specs=column_specs,
)
seq2seq_job = aiplatform.SequenceToSequencePlusForecastingTrainingJob(
display_name=self._make_display_name("train-housing-seq2seq"),
optimization_objective="minimize-rmse",
column_specs=column_specs,
)
resources.extend([automl_job, seq2seq_job])
resources.append(job)

# Kick off both training jobs, AutoML job will take approx one hour
# to run.
automl_model = automl_job.run(
model = job.run(
dataset=ds,
target_column=target_column,
time_column=time_column,
Expand All @@ -93,29 +95,18 @@ def test_end_to_end_forecasting(self, shared_state):
data_granularity_unit="day",
data_granularity_count=1,
budget_milli_node_hours=1000,
model_display_name=self._make_display_name("automl-liquor-model"),
holiday_regions=["GLOBAL"],
hierarchy_group_total_weight=1,
window_stride_length=1,
model_display_name=self._make_display_name(
"forecasting-liquor-model"),
sync=False,
)
seq2seq_model = seq2seq_job.run(
dataset=ds,
target_column=target_column,
time_column=time_column,
time_series_identifier_column=time_series_identifier_column,
available_at_forecast_columns=[time_column],
unavailable_at_forecast_columns=[target_column],
time_series_attribute_columns=["city", "zip_code", "county"],
forecast_horizon=30,
context_window=30,
data_granularity_unit="day",
data_granularity_count=1,
budget_milli_node_hours=1000,
model_display_name=self._make_display_name("seq2seq-liquor-model"),
sync=False,
)
resources.extend([automl_model, seq2seq_model])
resources.append(model)

automl_batch_prediction_job = automl_model.batch_predict(
job_display_name=self._make_display_name("automl-liquor-model"),
batch_prediction_job = model.batch_predict(
job_display_name=self._make_display_name(
"forecasting-liquor-model"),
instances_format="bigquery",
machine_type="n1-standard-4",
bigquery_source=_PREDICTION_DATASET_BQ_PATH,
Expand All @@ -124,32 +115,17 @@ def test_end_to_end_forecasting(self, shared_state):
),
sync=False,
)
seq2seq_batch_prediction_job = seq2seq_model.batch_predict(
job_display_name=self._make_display_name("seq2seq-liquor-model"),
instances_format="bigquery",
machine_type="n1-standard-4",
bigquery_source=_PREDICTION_DATASET_BQ_PATH,
gcs_destination_prefix=(
f'gs://{shared_state["staging_bucket_name"]}/bp_results/'
),
sync=False,
)
resources.extend(
[automl_batch_prediction_job, seq2seq_batch_prediction_job]
)

automl_batch_prediction_job.wait()
seq2seq_batch_prediction_job.wait()
resources.append(batch_prediction_job)

batch_prediction_job.wait()
assert (
automl_job.state
job.state
== pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED
)
assert (
automl_batch_prediction_job.state
batch_prediction_job.state
== job_state.JobState.JOB_STATE_SUCCEEDED
)
finally:
for resource in resources:
resource.wait_for_resource_creation()
resource.delete()

0 comments on commit 09f6799

Please sign in to comment.