Skip to content

Commit

Permalink
Fixed test_automl_tabular_training_jobs.py
Browse files Browse the repository at this point in the history
  • Loading branch information
ivanmkc committed Aug 17, 2021
1 parent 78872ab commit 72e20f0
Showing 1 changed file with 20 additions and 32 deletions.
52 changes: 20 additions & 32 deletions tests/unit/aiplatform/test_automl_tabular_training_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,12 +125,8 @@
_TEST_FRACTION_SPLIT_VALIDATION = 0.2
_TEST_FRACTION_SPLIT_TEST = 0.2

_TEST_SPLIT_DEFAULT = gca_training_pipeline.FractionSplit(
training_fraction=0.8, validation_fraction=0.1, test_fraction=0.1,
)

_TEST_SPLIT_PREDEFINED_COLUMN_NAME = "split"
_TEST_SPLIT_PREDEFINED_COLUMN_NAME = "timestamp"
_TEST_SPLIT_TIMESTAMP_COLUMN_NAME = "timestamp"

_TEST_OUTPUT_PYTHON_PACKAGE_PATH = "gs://test/ouput/python/trainer.tar.gz"

Expand Down Expand Up @@ -307,7 +303,7 @@ def test_run_call_pipeline_service_create(
training_fraction_split=_TEST_FRACTION_SPLIT_TRAINING,
validation_fraction_split=_TEST_FRACTION_SPLIT_VALIDATION,
test_fraction_split=_TEST_FRACTION_SPLIT_TEST,
timestamp_split_column_name=_TEST_SPLIT_PREDEFINED_COLUMN_NAME,
timestamp_split_column_name=_TEST_SPLIT_TIMESTAMP_COLUMN_NAME,
weight_column=_TEST_TRAINING_WEIGHT_COLUMN,
budget_milli_node_hours=_TEST_TRAINING_BUDGET_MILLI_NODE_HOURS,
disable_early_stopping=_TEST_TRAINING_DISABLE_EARLY_STOPPING,
Expand All @@ -321,7 +317,7 @@ def test_run_call_pipeline_service_create(
training_fraction=_TEST_FRACTION_SPLIT_TRAINING,
validation_fraction=_TEST_FRACTION_SPLIT_VALIDATION,
test_fraction=_TEST_FRACTION_SPLIT_TEST,
key=_TEST_SPLIT_PREDEFINED_COLUMN_NAME,
key=_TEST_SPLIT_TIMESTAMP_COLUMN_NAME,
)

true_managed_model = gca_model.Model(
Expand Down Expand Up @@ -392,15 +388,13 @@ def test_run_call_pipeline_if_no_model_display_name(
if not sync:
model_from_job.wait()

true_fraction_split = _TEST_SPLIT_DEFAULT

# Test that the model display name defaults to the job display name
true_managed_model = gca_model.Model(
display_name=_TEST_DISPLAY_NAME, encryption_spec=_TEST_MODEL_ENCRYPTION_SPEC
)

true_input_data_config = gca_training_pipeline.InputDataConfig(
fraction_split=true_fraction_split, dataset_id=mock_dataset_tabular.name,
dataset_id=mock_dataset_tabular.name,
)

true_training_pipeline = gca_training_pipeline.TrainingPipeline(
Expand Down Expand Up @@ -527,19 +521,13 @@ def test_run_call_pipeline_service_create_if_set_additional_experiments(
if not sync:
model_from_job.wait()

true_fraction_split = gca_training_pipeline.FractionSplit(
training_fraction=_TEST_FRACTION_SPLIT_TRAINING,
validation_fraction=_TEST_FRACTION_SPLIT_VALIDATION,
test_fraction=_TEST_FRACTION_SPLIT_TEST,
)

true_managed_model = gca_model.Model(
display_name=_TEST_MODEL_DISPLAY_NAME,
encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC,
)

true_input_data_config = gca_training_pipeline.InputDataConfig(
fraction_split=true_fraction_split, dataset_id=mock_dataset_tabular.name,
dataset_id=mock_dataset_tabular.name,
)

true_training_pipeline = gca_training_pipeline.TrainingPipeline(
Expand Down Expand Up @@ -860,7 +848,6 @@ def test_splits_fraction(
mock_pipeline_service_get,
mock_dataset_tabular,
mock_model_service_get,
mock_model,
sync,
):
"""
Expand All @@ -884,8 +871,11 @@ def test_splits_fraction(

model_from_job = job.run(
dataset=mock_dataset_tabular,
target_column=_TEST_TRAINING_TARGET_COLUMN,
weight_column=_TEST_TRAINING_WEIGHT_COLUMN,
model_display_name=_TEST_MODEL_DISPLAY_NAME,
training_fraction_split=_TEST_FRACTION_SPLIT_TRAINING,
validation_fraction_split=_TEST_FRACTION_SPLIT_VALIDATION,
test_fraction_split=_TEST_FRACTION_SPLIT_TEST,
disable_early_stopping=_TEST_TRAINING_DISABLE_EARLY_STOPPING,
sync=sync,
Expand All @@ -902,7 +892,6 @@ def test_splits_fraction(

true_managed_model = gca_model.Model(
display_name=_TEST_MODEL_DISPLAY_NAME,
description=mock_model._gca_resource.description,
encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC,
)

Expand Down Expand Up @@ -931,7 +920,6 @@ def test_splits_timestamp(
mock_pipeline_service_get,
mock_dataset_tabular,
mock_model_service_get,
mock_model,
sync,
):
"""
Expand All @@ -955,8 +943,11 @@ def test_splits_timestamp(

model_from_job = job.run(
dataset=mock_dataset_tabular,
target_column=_TEST_TRAINING_TARGET_COLUMN,
weight_column=_TEST_TRAINING_WEIGHT_COLUMN,
model_display_name=_TEST_MODEL_DISPLAY_NAME,
training_fraction_split=_TEST_FRACTION_SPLIT_TRAINING,
validation_fraction_split=_TEST_FRACTION_SPLIT_VALIDATION,
test_fraction_split=_TEST_FRACTION_SPLIT_TEST,
timestamp_split_column_name=_TEST_SPLIT_PREDEFINED_COLUMN_NAME,
disable_early_stopping=_TEST_TRAINING_DISABLE_EARLY_STOPPING,
Expand All @@ -966,7 +957,7 @@ def test_splits_timestamp(
if not sync:
model_from_job.wait()

true_fraction_split = gca_training_pipeline.TimestampSplit(
true_split = gca_training_pipeline.TimestampSplit(
training_fraction=_TEST_FRACTION_SPLIT_TRAINING,
validation_fraction=_TEST_FRACTION_SPLIT_VALIDATION,
test_fraction=_TEST_FRACTION_SPLIT_TEST,
Expand All @@ -975,12 +966,11 @@ def test_splits_timestamp(

true_managed_model = gca_model.Model(
display_name=_TEST_MODEL_DISPLAY_NAME,
description=mock_model._gca_resource.description,
encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC,
)

true_input_data_config = gca_training_pipeline.InputDataConfig(
fraction_split=true_fraction_split, dataset_id=mock_dataset_tabular.name,
timestamp_split=true_split, dataset_id=mock_dataset_tabular.name,
)

true_training_pipeline = gca_training_pipeline.TrainingPipeline(
Expand All @@ -1004,7 +994,6 @@ def test_splits_predefined(
mock_pipeline_service_get,
mock_dataset_tabular,
mock_model_service_get,
mock_model,
sync,
):
"""
Expand All @@ -1028,6 +1017,8 @@ def test_splits_predefined(

model_from_job = job.run(
dataset=mock_dataset_tabular,
target_column=_TEST_TRAINING_TARGET_COLUMN,
weight_column=_TEST_TRAINING_WEIGHT_COLUMN,
model_display_name=_TEST_MODEL_DISPLAY_NAME,
predefined_split_column_name=_TEST_SPLIT_PREDEFINED_COLUMN_NAME,
disable_early_stopping=_TEST_TRAINING_DISABLE_EARLY_STOPPING,
Expand All @@ -1037,18 +1028,17 @@ def test_splits_predefined(
if not sync:
model_from_job.wait()

true_filter_split = gca_training_pipeline.PredefinedSplit(
true_split = gca_training_pipeline.PredefinedSplit(
key=_TEST_SPLIT_PREDEFINED_COLUMN_NAME
)

true_managed_model = gca_model.Model(
display_name=_TEST_MODEL_DISPLAY_NAME,
description=mock_model._gca_resource.description,
encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC,
)

true_input_data_config = gca_training_pipeline.InputDataConfig(
filter_split=true_filter_split, dataset_id=mock_dataset_tabular.name,
predefined_split=true_split, dataset_id=mock_dataset_tabular.name,
)

true_training_pipeline = gca_training_pipeline.TrainingPipeline(
Expand All @@ -1072,7 +1062,6 @@ def test_splits_default(
mock_pipeline_service_get,
mock_dataset_tabular,
mock_model_service_get,
mock_model,
sync,
):
"""
Expand All @@ -1096,6 +1085,8 @@ def test_splits_default(

model_from_job = job.run(
dataset=mock_dataset_tabular,
target_column=_TEST_TRAINING_TARGET_COLUMN,
weight_column=_TEST_TRAINING_WEIGHT_COLUMN,
model_display_name=_TEST_MODEL_DISPLAY_NAME,
disable_early_stopping=_TEST_TRAINING_DISABLE_EARLY_STOPPING,
sync=sync,
Expand All @@ -1104,16 +1095,13 @@ def test_splits_default(
if not sync:
model_from_job.wait()

true_default_split = _TEST_SPLIT_DEFAULT

true_managed_model = gca_model.Model(
display_name=_TEST_MODEL_DISPLAY_NAME,
description=mock_model._gca_resource.description,
encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC,
)

true_input_data_config = gca_training_pipeline.InputDataConfig(
fraction_split=true_default_split, dataset_id=mock_dataset_tabular.name,
dataset_id=mock_dataset_tabular.name,
)

true_training_pipeline = gca_training_pipeline.TrainingPipeline(
Expand Down

0 comments on commit 72e20f0

Please sign in to comment.