diff --git a/google/cloud/aiplatform/training_jobs.py b/google/cloud/aiplatform/training_jobs.py index 52418096be..15ef20af74 100644 --- a/google/cloud/aiplatform/training_jobs.py +++ b/google/cloud/aiplatform/training_jobs.py @@ -308,10 +308,14 @@ def run(self) -> Optional[models.Model]: def _create_input_data_config( dataset: Optional[datasets._Dataset] = None, annotation_schema_uri: Optional[str] = None, - training_fraction_split: float = 0.8, - validation_fraction_split: float = 0.1, - test_fraction_split: float = 0.1, + training_fraction_split: Optional[float] = None, + validation_fraction_split: Optional[float] = None, + test_fraction_split: Optional[float] = None, + training_filter_split: Optional[str] = None, + validation_filter_split: Optional[str] = None, + test_filter_split: Optional[str] = None, predefined_split_column_name: Optional[str] = None, + timestamp_split_column_name: Optional[str] = None, gcs_destination_uri_prefix: Optional[str] = None, bigquery_destination: Optional[str] = None, ) -> Optional[gca_training_pipeline.InputDataConfig]: @@ -349,17 +353,35 @@ def _create_input_data_config( and ``annotation_schema_uri``. training_fraction_split (float): - The fraction of the input data that is to be - used to train the Model. This is ignored if Dataset is not provided. - training_fraction_split (float): - The fraction of the input data that is to be - used to train the Model. This is ignored if Dataset is not provided. + Optional. The fraction of the input data that is to be used to train + the Model. This is ignored if Dataset is not provided. validation_fraction_split (float): - The fraction of the input data that is to be - used to validate the Model. This is ignored if Dataset is not provided. + Optional. The fraction of the input data that is to be used to validate + the Model. This is ignored if Dataset is not provided. test_fraction_split (float): - The fraction of the input data that is to be - used to evaluate the Model. This is ignored if Dataset is not provided. + Optional. The fraction of the input data that is to be used to evaluate + the Model. This is ignored if Dataset is not provided. + training_filter_split (str): + Optional. A filter on DataItems of the Dataset. DataItems that match + this filter are used to train the Model. A filter with same syntax + as the one used in DatasetService.ListDataItems may be used. If a + single DataItem is matched by more than one of the FilterSplit filters, + then it is assigned to the first set that applies to it in the training, + validation, test order. This is ignored if Dataset is not provided. + validation_filter_split (str): + Optional. A filter on DataItems of the Dataset. DataItems that match + this filter are used to validate the Model. A filter with same syntax + as the one used in DatasetService.ListDataItems may be used. If a + single DataItem is matched by more than one of the FilterSplit filters, + then it is assigned to the first set that applies to it in the training, + validation, test order. This is ignored if Dataset is not provided. + test_filter_split (str): + Optional. A filter on DataItems of the Dataset. DataItems that match + this filter are used to test the Model. A filter with same syntax + as the one used in DatasetService.ListDataItems may be used. If a + single DataItem is matched by more than one of the FilterSplit filters, + then it is assigned to the first set that applies to it in the training, + validation, test order. This is ignored if Dataset is not provided. predefined_split_column_name (str): Optional. The key is a name of one of the Dataset's data columns. The value of the key (either the label's value or @@ -370,6 +392,16 @@ def _create_input_data_config( ignored by the pipeline. Supported only for tabular and time series Datasets. + timestamp_split_column_name (str): + Optional. The key is a name of one of the Dataset's data + columns. The value of the key values of the key (the values in + the column) must be in RFC 3339 `date-time` format, where + `time-offset` = `"Z"` (e.g. 1985-04-12T23:20:50.52Z). If for a + piece of data the key is not present or has an invalid value, + that piece is ignored by the pipeline. + + Supported only for tabular and time series Datasets. + This parameter must be used with training_fraction_split, validation_fraction_split and test_fraction_split. gcs_destination_uri_prefix (str): Optional. The Google Cloud Storage location. @@ -396,33 +428,97 @@ def _create_input_data_config( - AIP_TRAINING_DATA_URI ="bigquery_destination.dataset_*.training" - AIP_VALIDATION_DATA_URI = "bigquery_destination.dataset_*.validation" - AIP_TEST_DATA_URI = "bigquery_destination.dataset_*.test" + Raises: + ValueError: When more than 1 type of split configuration is passed or when + the split configuartion passed is incompatible with the dataset schema. """ input_data_config = None if dataset: - # Create fraction split spec - fraction_split = gca_training_pipeline.FractionSplit( - training_fraction=training_fraction_split, - validation_fraction=validation_fraction_split, - test_fraction=test_fraction_split, - ) - - # Create predefined split spec + # Initialize all possible splits + filter_split = None predefined_split = None - if predefined_split_column_name: - if dataset._gca_resource.metadata_schema_uri not in ( - schema.dataset.metadata.tabular, - schema.dataset.metadata.time_series, + timestamp_split = None + fraction_split = None + + # Create filter split + if any( + [ + training_filter_split is not None, + validation_filter_split is not None, + test_filter_split is not None, + ] + ): + if all( + [ + training_filter_split is not None, + validation_filter_split is not None, + test_filter_split is not None, + ] ): + filter_split = gca_training_pipeline.FilterSplit( + training_filter=training_filter_split, + validation_filter=validation_filter_split, + test_filter=test_filter_split, + ) + else: raise ValueError( - "A pre-defined split may only be used with a tabular or time series Dataset" + "All filter splits must be passed together or not at all" ) + # Create predefined split + if predefined_split_column_name: predefined_split = gca_training_pipeline.PredefinedSplit( key=predefined_split_column_name ) - # Create GCS destination + # Create timestamp split or fraction split + if timestamp_split_column_name: + timestamp_split = gca_training_pipeline.TimestampSplit( + training_fraction=training_fraction_split, + validation_fraction=validation_fraction_split, + test_fraction=test_fraction_split, + key=timestamp_split_column_name, + ) + elif any( + [ + training_fraction_split is not None, + validation_fraction_split is not None, + test_fraction_split is not None, + ] + ): + fraction_split = gca_training_pipeline.FractionSplit( + training_fraction=training_fraction_split, + validation_fraction=validation_fraction_split, + test_fraction=test_fraction_split, + ) + + splits = [ + split + for split in [ + filter_split, + predefined_split, + timestamp_split_column_name, + fraction_split, + ] + if split is not None + ] + + # Fallback to fraction split if nothing else is specified + if len(splits) == 0: + _LOGGER.info( + "No dataset split provided. The service will use a default split." + ) + elif len(splits) > 1: + raise ValueError( + """Can only specify one of: + 1. training_filter_split, validation_filter_split, test_filter_split + 2. predefined_split_column_name + 3. timestamp_split_column_name, training_fraction_split, validation_fraction_split, test_fraction_split + 4. training_fraction_split, validation_fraction_split, test_fraction_split""" + ) + + # create GCS destination gcs_destination = None if gcs_destination_uri_prefix: gcs_destination = gca_io.GcsDestination( @@ -439,7 +535,9 @@ def _create_input_data_config( # create input data config input_data_config = gca_training_pipeline.InputDataConfig( fraction_split=fraction_split, + filter_split=filter_split, predefined_split=predefined_split, + timestamp_split=timestamp_split, dataset_id=dataset.name, annotation_schema_uri=annotation_schema_uri, gcs_destination=gcs_destination, @@ -453,11 +551,15 @@ def _run_job( training_task_definition: str, training_task_inputs: Union[dict, proto.Message], dataset: Optional[datasets._Dataset], - training_fraction_split: float, - validation_fraction_split: float, - test_fraction_split: float, - annotation_schema_uri: Optional[str] = None, + training_fraction_split: Optional[float] = None, + validation_fraction_split: Optional[float] = None, + test_fraction_split: Optional[float] = None, + training_filter_split: Optional[str] = None, + validation_filter_split: Optional[str] = None, + test_filter_split: Optional[str] = None, predefined_split_column_name: Optional[str] = None, + timestamp_split_column_name: Optional[str] = None, + annotation_schema_uri: Optional[str] = None, model: Optional[gca_model.Model] = None, gcs_destination_uri_prefix: Optional[str] = None, bigquery_destination: Optional[str] = None, @@ -488,15 +590,6 @@ def _run_job( [google.cloud.aiplatform.v1beta1.TrainingPipeline.training_task_definition]. For tabular Datasets, all their data is exported to training, to pick and choose from. - training_fraction_split (float): - The fraction of the input data that is to be - used to train the Model. This is ignored if Dataset is not provided. - validation_fraction_split (float): - The fraction of the input data that is to be - used to validate the Model. This is ignored if Dataset is not provided. - test_fraction_split (float): - The fraction of the input data that is to be - used to evaluate the Model. This is ignored if Dataset is not provided. annotation_schema_uri (str): Google Cloud Storage URI points to a YAML file describing annotation schema. The schema is defined as an OpenAPI 3.0.2 @@ -519,6 +612,36 @@ def _run_job( ``annotations_filter`` and ``annotation_schema_uri``. + training_fraction_split (float): + Optional. The fraction of the input data that is to be used to train + the Model. This is ignored if Dataset is not provided. + validation_fraction_split (float): + Optional. The fraction of the input data that is to be used to validate + the Model. This is ignored if Dataset is not provided. + test_fraction_split (float): + Optional. The fraction of the input data that is to be used to evaluate + the Model. This is ignored if Dataset is not provided. + training_filter_split (str): + Optional. A filter on DataItems of the Dataset. DataItems that match + this filter are used to train the Model. A filter with same syntax + as the one used in DatasetService.ListDataItems may be used. If a + single DataItem is matched by more than one of the FilterSplit filters, + then it is assigned to the first set that applies to it in the training, + validation, test order. This is ignored if Dataset is not provided. + validation_filter_split (str): + Optional. A filter on DataItems of the Dataset. DataItems that match + this filter are used to validate the Model. A filter with same syntax + as the one used in DatasetService.ListDataItems may be used. If a + single DataItem is matched by more than one of the FilterSplit filters, + then it is assigned to the first set that applies to it in the training, + validation, test order. This is ignored if Dataset is not provided. + test_filter_split (str): + Optional. A filter on DataItems of the Dataset. DataItems that match + this filter are used to test the Model. A filter with same syntax + as the one used in DatasetService.ListDataItems may be used. If a + single DataItem is matched by more than one of the FilterSplit filters, + then it is assigned to the first set that applies to it in the training, + validation, test order. This is ignored if Dataset is not provided. predefined_split_column_name (str): Optional. The key is a name of one of the Dataset's data columns. The value of the key (either the label's value or @@ -529,6 +652,16 @@ def _run_job( ignored by the pipeline. Supported only for tabular and time series Datasets. + timestamp_split_column_name (str): + Optional. The key is a name of one of the Dataset's data + columns. The value of the key values of the key (the values in + the column) must be in RFC 3339 `date-time` format, where + `time-offset` = `"Z"` (e.g. 1985-04-12T23:20:50.52Z). If for a + piece of data the key is not present or has an invalid value, + that piece is ignored by the pipeline. + + Supported only for tabular and time series Datasets. + This parameter must be used with training_fraction_split, validation_fraction_split and test_fraction_split. model (~.model.Model): Optional. Describes the Model that may be uploaded (via [ModelService.UploadMode][]) by this TrainingPipeline. The @@ -583,7 +716,11 @@ def _run_job( training_fraction_split=training_fraction_split, validation_fraction_split=validation_fraction_split, test_fraction_split=test_fraction_split, + training_filter_split=training_filter_split, + validation_filter_split=validation_filter_split, + test_filter_split=test_filter_split, predefined_split_column_name=predefined_split_column_name, + timestamp_split_column_name=timestamp_split_column_name, gcs_destination_uri_prefix=gcs_destination_uri_prefix, bigquery_destination=bigquery_destination, ) @@ -1574,8 +1711,6 @@ def __init__( self._requirements = requirements self._script_path = script_path - # TODO(b/172365904) add filter split, training_pipeline.FilterSplit - # TODO(b/172368070) add timestamp split, training_pipeline.TimestampSplit def run( self, dataset: Optional[ @@ -1601,10 +1736,14 @@ def run( accelerator_count: int = 0, boot_disk_type: str = "pd-ssd", boot_disk_size_gb: int = 100, - training_fraction_split: float = 0.8, - validation_fraction_split: float = 0.1, - test_fraction_split: float = 0.1, + training_fraction_split: Optional[float] = None, + validation_fraction_split: Optional[float] = None, + test_fraction_split: Optional[float] = None, + training_filter_split: Optional[str] = None, + validation_filter_split: Optional[str] = None, + test_filter_split: Optional[str] = None, predefined_split_column_name: Optional[str] = None, + timestamp_split_column_name: Optional[str] = None, tensorboard: Optional[str] = None, sync=True, ) -> Optional[models.Model]: @@ -1616,12 +1755,36 @@ def run( ie: replica_count = 10 will result in 1 chief and 9 workers All replicas have same machine_type, accelerator_type, and accelerator_count - Data fraction splits: - Any of ``training_fraction_split``, ``validation_fraction_split`` and - ``test_fraction_split`` may optionally be provided, they must sum to up to 1. If - the provided ones sum to less than 1, the remainder is assigned to sets as - decided by Vertex AI.If none of the fractions are set, by default roughly 80% - of data will be used for training, 10% for validation, and 10% for test. + If training on a Vertex AI dataset, you can use one of the following split configurations: + Data fraction splits: + Any of ``training_fraction_split``, ``validation_fraction_split`` and + ``test_fraction_split`` may optionally be provided, they must sum to up to 1. If + the provided ones sum to less than 1, the remainder is assigned to sets as + decided by Vertex AI. If none of the fractions are set, by default roughly 80% + of data will be used for training, 10% for validation, and 10% for test. + + Data filter splits: + Assigns input data to training, validation, and test sets + based on the given filters, data pieces not matched by any + filter are ignored. Currently only supported for Datasets + containing DataItems. + If any of the filters in this message are to match nothing, then + they can be set as '-' (the minus sign). + If using filter splits, all of ``training_filter_split``, ``validation_filter_split`` and + ``test_filter_split`` must be provided. + Supported only for unstructured Datasets. + + Predefined splits: + Assigns input data to training, validation, and test sets based on the value of a provided key. + If using predefined splits, ``predefined_split_column_name`` must be provided. + Supported only for tabular Datasets. + + Timestamp splits: + Assigns input data to training, validation, and test sets + based on a provided timestamps. The youngest data pieces are + assigned to training set, next to validation set, and the oldest + to the test set. + Supported only for tabular Datasets. Args: dataset ( @@ -1745,14 +1908,35 @@ def run( Size in GB of the boot disk, default is 100GB. boot disk size must be within the range of [100, 64000]. training_fraction_split (float): - The fraction of the input data that is to be - used to train the Model. This is ignored if Dataset is not provided. + Optional. The fraction of the input data that is to be used to train + the Model. This is ignored if Dataset is not provided. validation_fraction_split (float): - The fraction of the input data that is to be - used to validate the Model. This is ignored if Dataset is not provided. + Optional. The fraction of the input data that is to be used to validate + the Model. This is ignored if Dataset is not provided. test_fraction_split (float): - The fraction of the input data that is to be - used to evaluate the Model. This is ignored if Dataset is not provided. + Optional. The fraction of the input data that is to be used to evaluate + the Model. This is ignored if Dataset is not provided. + training_filter_split (str): + Optional. A filter on DataItems of the Dataset. DataItems that match + this filter are used to train the Model. A filter with same syntax + as the one used in DatasetService.ListDataItems may be used. If a + single DataItem is matched by more than one of the FilterSplit filters, + then it is assigned to the first set that applies to it in the training, + validation, test order. This is ignored if Dataset is not provided. + validation_filter_split (str): + Optional. A filter on DataItems of the Dataset. DataItems that match + this filter are used to validate the Model. A filter with same syntax + as the one used in DatasetService.ListDataItems may be used. If a + single DataItem is matched by more than one of the FilterSplit filters, + then it is assigned to the first set that applies to it in the training, + validation, test order. This is ignored if Dataset is not provided. + test_filter_split (str): + Optional. A filter on DataItems of the Dataset. DataItems that match + this filter are used to test the Model. A filter with same syntax + as the one used in DatasetService.ListDataItems may be used. If a + single DataItem is matched by more than one of the FilterSplit filters, + then it is assigned to the first set that applies to it in the training, + validation, test order. This is ignored if Dataset is not provided. predefined_split_column_name (str): Optional. The key is a name of one of the Dataset's data columns. The value of the key (either the label's value or @@ -1762,6 +1946,15 @@ def run( key is not present or has an invalid value, that piece is ignored by the pipeline. + Supported only for tabular and time series Datasets. + timestamp_split_column_name (str): + Optional. The key is a name of one of the Dataset's data + columns. The value of the key values of the key (the values in + the column) must be in RFC 3339 `date-time` format, where + `time-offset` = `"Z"` (e.g. 1985-04-12T23:20:50.52Z). If for a + piece of data the key is not present or has an invalid value, + that piece is ignored by the pipeline. + Supported only for tabular and time series Datasets. tensorboard (str): Optional. The name of a Vertex AI @@ -1818,7 +2011,11 @@ def run( training_fraction_split=training_fraction_split, validation_fraction_split=validation_fraction_split, test_fraction_split=test_fraction_split, + training_filter_split=training_filter_split, + validation_filter_split=validation_filter_split, + test_filter_split=test_filter_split, predefined_split_column_name=predefined_split_column_name, + timestamp_split_column_name=timestamp_split_column_name, tensorboard=tensorboard, sync=sync, ) @@ -1844,10 +2041,14 @@ def _run( service_account: Optional[str] = None, network: Optional[str] = None, bigquery_destination: Optional[str] = None, - training_fraction_split: float = 0.8, - validation_fraction_split: float = 0.1, - test_fraction_split: float = 0.1, + training_fraction_split: Optional[float] = None, + validation_fraction_split: Optional[float] = None, + test_fraction_split: Optional[float] = None, + training_filter_split: Optional[str] = None, + validation_filter_split: Optional[str] = None, + test_filter_split: Optional[str] = None, predefined_split_column_name: Optional[str] = None, + timestamp_split_column_name: Optional[str] = None, tensorboard: Optional[str] = None, sync=True, ) -> Optional[models.Model]: @@ -1918,14 +2119,35 @@ def _run( - AIP_VALIDATION_DATA_URI = "bigquery_destination.dataset_*.validation" - AIP_TEST_DATA_URI = "bigquery_destination.dataset_*.test" training_fraction_split (float): - The fraction of the input data that is to be - used to train the Model. + Optional. The fraction of the input data that is to be used to train + the Model. This is ignored if Dataset is not provided. validation_fraction_split (float): - The fraction of the input data that is to be - used to validate the Model. + Optional. The fraction of the input data that is to be used to validate + the Model. This is ignored if Dataset is not provided. test_fraction_split (float): - The fraction of the input data that is to be - used to evaluate the Model. + Optional. The fraction of the input data that is to be used to evaluate + the Model. This is ignored if Dataset is not provided. + training_filter_split (str): + Optional. A filter on DataItems of the Dataset. DataItems that match + this filter are used to train the Model. A filter with same syntax + as the one used in DatasetService.ListDataItems may be used. If a + single DataItem is matched by more than one of the FilterSplit filters, + then it is assigned to the first set that applies to it in the training, + validation, test order. This is ignored if Dataset is not provided. + validation_filter_split (str): + Optional. A filter on DataItems of the Dataset. DataItems that match + this filter are used to validate the Model. A filter with same syntax + as the one used in DatasetService.ListDataItems may be used. If a + single DataItem is matched by more than one of the FilterSplit filters, + then it is assigned to the first set that applies to it in the training, + validation, test order. This is ignored if Dataset is not provided. + test_filter_split (str): + Optional. A filter on DataItems of the Dataset. DataItems that match + this filter are used to test the Model. A filter with same syntax + as the one used in DatasetService.ListDataItems may be used. If a + single DataItem is matched by more than one of the FilterSplit filters, + then it is assigned to the first set that applies to it in the training, + validation, test order. This is ignored if Dataset is not provided. predefined_split_column_name (str): Optional. The key is a name of one of the Dataset's data columns. The value of the key (either the label's value or @@ -1935,6 +2157,15 @@ def _run( key is not present or has an invalid value, that piece is ignored by the pipeline. + Supported only for tabular and time series Datasets. + timestamp_split_column_name (str): + Optional. The key is a name of one of the Dataset's data + columns. The value of the key values of the key (the values in + the column) must be in RFC 3339 `date-time` format, where + `time-offset` = `"Z"` (e.g. 1985-04-12T23:20:50.52Z). If for a + piece of data the key is not present or has an invalid value, + that piece is ignored by the pipeline. + Supported only for tabular and time series Datasets. tensorboard (str): Optional. The name of a Vertex AI @@ -2001,7 +2232,11 @@ def _run( training_fraction_split=training_fraction_split, validation_fraction_split=validation_fraction_split, test_fraction_split=test_fraction_split, + training_filter_split=training_filter_split, + validation_filter_split=validation_filter_split, + test_filter_split=test_filter_split, predefined_split_column_name=predefined_split_column_name, + timestamp_split_column_name=timestamp_split_column_name, model=managed_model, gcs_destination_uri_prefix=base_output_dir, bigquery_destination=bigquery_destination, @@ -2238,8 +2473,6 @@ def __init__( self._command = command - # TODO(b/172365904) add filter split, training_pipeline.FilterSplit - # TODO(b/172368070) add timestamp split, training_pipeline.TimestampSplit def run( self, dataset: Optional[ @@ -2265,10 +2498,14 @@ def run( accelerator_count: int = 0, boot_disk_type: str = "pd-ssd", boot_disk_size_gb: int = 100, - training_fraction_split: float = 0.8, - validation_fraction_split: float = 0.1, - test_fraction_split: float = 0.1, + training_fraction_split: Optional[float] = None, + validation_fraction_split: Optional[float] = None, + test_fraction_split: Optional[float] = None, + training_filter_split: Optional[str] = None, + validation_filter_split: Optional[str] = None, + test_filter_split: Optional[str] = None, predefined_split_column_name: Optional[str] = None, + timestamp_split_column_name: Optional[str] = None, tensorboard: Optional[str] = None, sync=True, ) -> Optional[models.Model]: @@ -2280,12 +2517,36 @@ def run( ie: replica_count = 10 will result in 1 chief and 9 workers All replicas have same machine_type, accelerator_type, and accelerator_count - Data fraction splits: - Any of ``training_fraction_split``, ``validation_fraction_split`` and - ``test_fraction_split`` may optionally be provided, they must sum to up to 1. If - the provided ones sum to less than 1, the remainder is assigned to sets as - decided by Vertex AI. If none of the fractions are set, by default roughly 80% - of data will be used for training, 10% for validation, and 10% for test. + If training on a Vertex AI dataset, you can use one of the following split configurations: + Data fraction splits: + Any of ``training_fraction_split``, ``validation_fraction_split`` and + ``test_fraction_split`` may optionally be provided, they must sum to up to 1. If + the provided ones sum to less than 1, the remainder is assigned to sets as + decided by Vertex AI. If none of the fractions are set, by default roughly 80% + of data will be used for training, 10% for validation, and 10% for test. + + Data filter splits: + Assigns input data to training, validation, and test sets + based on the given filters, data pieces not matched by any + filter are ignored. Currently only supported for Datasets + containing DataItems. + If any of the filters in this message are to match nothing, then + they can be set as '-' (the minus sign). + If using filter splits, all of ``training_filter_split``, ``validation_filter_split`` and + ``test_filter_split`` must be provided. + Supported only for unstructured Datasets. + + Predefined splits: + Assigns input data to training, validation, and test sets based on the value of a provided key. + If using predefined splits, ``predefined_split_column_name`` must be provided. + Supported only for tabular Datasets. + + Timestamp splits: + Assigns input data to training, validation, and test sets + based on a provided timestamps. The youngest data pieces are + assigned to training set, next to validation set, and the oldest + to the test set. + Supported only for tabular Datasets. Args: dataset (Union[datasets.ImageDataset,datasets.TabularDataset,datasets.TextDataset,datasets.VideoDataset]): @@ -2402,14 +2663,35 @@ def run( Size in GB of the boot disk, default is 100GB. boot disk size must be within the range of [100, 64000]. training_fraction_split (float): - The fraction of the input data that is to be - used to train the Model. This is ignored if Dataset is not provided. + Optional. The fraction of the input data that is to be used to train + the Model. This is ignored if Dataset is not provided. validation_fraction_split (float): - The fraction of the input data that is to be - used to validate the Model. This is ignored if Dataset is not provided. + Optional. The fraction of the input data that is to be used to validate + the Model. This is ignored if Dataset is not provided. test_fraction_split (float): - The fraction of the input data that is to be - used to evaluate the Model. This is ignored if Dataset is not provided. + Optional. The fraction of the input data that is to be used to evaluate + the Model. This is ignored if Dataset is not provided. + training_filter_split (str): + Optional. A filter on DataItems of the Dataset. DataItems that match + this filter are used to train the Model. A filter with same syntax + as the one used in DatasetService.ListDataItems may be used. If a + single DataItem is matched by more than one of the FilterSplit filters, + then it is assigned to the first set that applies to it in the training, + validation, test order. This is ignored if Dataset is not provided. + validation_filter_split (str): + Optional. A filter on DataItems of the Dataset. DataItems that match + this filter are used to validate the Model. A filter with same syntax + as the one used in DatasetService.ListDataItems may be used. If a + single DataItem is matched by more than one of the FilterSplit filters, + then it is assigned to the first set that applies to it in the training, + validation, test order. This is ignored if Dataset is not provided. + test_filter_split (str): + Optional. A filter on DataItems of the Dataset. DataItems that match + this filter are used to test the Model. A filter with same syntax + as the one used in DatasetService.ListDataItems may be used. If a + single DataItem is matched by more than one of the FilterSplit filters, + then it is assigned to the first set that applies to it in the training, + validation, test order. This is ignored if Dataset is not provided. predefined_split_column_name (str): Optional. The key is a name of one of the Dataset's data columns. The value of the key (either the label's value or @@ -2419,6 +2701,15 @@ def run( key is not present or has an invalid value, that piece is ignored by the pipeline. + Supported only for tabular and time series Datasets. + timestamp_split_column_name (str): + Optional. The key is a name of one of the Dataset's data + columns. The value of the key values of the key (the values in + the column) must be in RFC 3339 `date-time` format, where + `time-offset` = `"Z"` (e.g. 1985-04-12T23:20:50.52Z). If for a + piece of data the key is not present or has an invalid value, + that piece is ignored by the pipeline. + Supported only for tabular and time series Datasets. tensorboard (str): Optional. The name of a Vertex AI @@ -2474,7 +2765,11 @@ def run( training_fraction_split=training_fraction_split, validation_fraction_split=validation_fraction_split, test_fraction_split=test_fraction_split, + training_filter_split=training_filter_split, + validation_filter_split=validation_filter_split, + test_filter_split=test_filter_split, predefined_split_column_name=predefined_split_column_name, + timestamp_split_column_name=timestamp_split_column_name, tensorboard=tensorboard, sync=sync, ) @@ -2499,10 +2794,14 @@ def _run( service_account: Optional[str] = None, network: Optional[str] = None, bigquery_destination: Optional[str] = None, - training_fraction_split: float = 0.8, - validation_fraction_split: float = 0.1, - test_fraction_split: float = 0.1, + training_fraction_split: Optional[float] = None, + validation_fraction_split: Optional[float] = None, + test_fraction_split: Optional[float] = None, + training_filter_split: Optional[str] = None, + validation_filter_split: Optional[str] = None, + test_filter_split: Optional[str] = None, predefined_split_column_name: Optional[str] = None, + timestamp_split_column_name: Optional[str] = None, tensorboard: Optional[str] = None, sync=True, ) -> Optional[models.Model]: @@ -2569,14 +2868,35 @@ def _run( - AIP_VALIDATION_DATA_URI = "bigquery_destination.dataset_*.validation" - AIP_TEST_DATA_URI = "bigquery_destination.dataset_*.test" training_fraction_split (float): - The fraction of the input data that is to be - used to train the Model. + Optional. The fraction of the input data that is to be used to train + the Model. This is ignored if Dataset is not provided. validation_fraction_split (float): - The fraction of the input data that is to be - used to validate the Model. + Optional. The fraction of the input data that is to be used to validate + the Model. This is ignored if Dataset is not provided. test_fraction_split (float): - The fraction of the input data that is to be - used to evaluate the Model. + Optional. The fraction of the input data that is to be used to evaluate + the Model. This is ignored if Dataset is not provided. + training_filter_split (str): + Optional. A filter on DataItems of the Dataset. DataItems that match + this filter are used to train the Model. A filter with same syntax + as the one used in DatasetService.ListDataItems may be used. If a + single DataItem is matched by more than one of the FilterSplit filters, + then it is assigned to the first set that applies to it in the training, + validation, test order. This is ignored if Dataset is not provided. + validation_filter_split (str): + Optional. A filter on DataItems of the Dataset. DataItems that match + this filter are used to validate the Model. A filter with same syntax + as the one used in DatasetService.ListDataItems may be used. If a + single DataItem is matched by more than one of the FilterSplit filters, + then it is assigned to the first set that applies to it in the training, + validation, test order. This is ignored if Dataset is not provided. + test_filter_split (str): + Optional. A filter on DataItems of the Dataset. DataItems that match + this filter are used to test the Model. A filter with same syntax + as the one used in DatasetService.ListDataItems may be used. If a + single DataItem is matched by more than one of the FilterSplit filters, + then it is assigned to the first set that applies to it in the training, + validation, test order. This is ignored if Dataset is not provided. predefined_split_column_name (str): Optional. The key is a name of one of the Dataset's data columns. The value of the key (either the label's value or @@ -2586,6 +2906,15 @@ def _run( key is not present or has an invalid value, that piece is ignored by the pipeline. + Supported only for tabular and time series Datasets. + timestamp_split_column_name (str): + Optional. The key is a name of one of the Dataset's data + columns. The value of the key values of the key (the values in + the column) must be in RFC 3339 `date-time` format, where + `time-offset` = `"Z"` (e.g. 1985-04-12T23:20:50.52Z). If for a + piece of data the key is not present or has an invalid value, + that piece is ignored by the pipeline. + Supported only for tabular and time series Datasets. tensorboard (str): Optional. The name of a Vertex AI @@ -2646,7 +2975,11 @@ def _run( training_fraction_split=training_fraction_split, validation_fraction_split=validation_fraction_split, test_fraction_split=test_fraction_split, + training_filter_split=training_filter_split, + validation_filter_split=validation_filter_split, + test_filter_split=test_filter_split, predefined_split_column_name=predefined_split_column_name, + timestamp_split_column_name=timestamp_split_column_name, model=managed_model, gcs_destination_uri_prefix=base_output_dir, bigquery_destination=bigquery_destination, @@ -2848,10 +3181,11 @@ def run( self, dataset: datasets.TabularDataset, target_column: str, - training_fraction_split: float = 0.8, - validation_fraction_split: float = 0.1, - test_fraction_split: float = 0.1, + training_fraction_split: Optional[float] = None, + validation_fraction_split: Optional[float] = None, + test_fraction_split: Optional[float] = None, predefined_split_column_name: Optional[str] = None, + timestamp_split_column_name: Optional[str] = None, weight_column: Optional[str] = None, budget_milli_node_hours: int = 1000, model_display_name: Optional[str] = None, @@ -2864,12 +3198,25 @@ def run( ) -> models.Model: """Runs the training job and returns a model. - Data fraction splits: - Any of ``training_fraction_split``, ``validation_fraction_split`` and - ``test_fraction_split`` may optionally be provided, they must sum to up to 1. If - the provided ones sum to less than 1, the remainder is assigned to sets as - decided by Vertex AI. If none of the fractions are set, by default roughly 80% - of data will be used for training, 10% for validation, and 10% for test. + If training on a Vertex AI dataset, you can use one of the following split configurations: + Data fraction splits: + Any of ``training_fraction_split``, ``validation_fraction_split`` and + ``test_fraction_split`` may optionally be provided, they must sum to up to 1. If + the provided ones sum to less than 1, the remainder is assigned to sets as + decided by Vertex AI. If none of the fractions are set, by default roughly 80% + of data will be used for training, 10% for validation, and 10% for test. + + Predefined splits: + Assigns input data to training, validation, and test sets based on the value of a provided key. + If using predefined splits, ``predefined_split_column_name`` must be provided. + Supported only for tabular Datasets. + + Timestamp splits: + Assigns input data to training, validation, and test sets + based on a provided timestamps. The youngest data pieces are + assigned to training set, next to validation set, and the oldest + to the test set. + Supported only for tabular Datasets. Args: dataset (datasets.TabularDataset): @@ -2883,14 +3230,14 @@ def run( target_column (str): Required. The name of the column values of which the Model is to predict. training_fraction_split (float): - Required. The fraction of the input data that is to be - used to train the Model. This is ignored if Dataset is not provided. + Optional. The fraction of the input data that is to be used to train + the Model. This is ignored if Dataset is not provided. validation_fraction_split (float): - Required. The fraction of the input data that is to be - used to validate the Model. This is ignored if Dataset is not provided. + Optional. The fraction of the input data that is to be used to validate + the Model. This is ignored if Dataset is not provided. test_fraction_split (float): - Required. The fraction of the input data that is to be - used to evaluate the Model. This is ignored if Dataset is not provided. + Optional. The fraction of the input data that is to be used to evaluate + the Model. This is ignored if Dataset is not provided. predefined_split_column_name (str): Optional. The key is a name of one of the Dataset's data columns. The value of the key (either the label's value or @@ -2901,6 +3248,16 @@ def run( ignored by the pipeline. Supported only for tabular and time series Datasets. + timestamp_split_column_name (str): + Optional. The key is a name of one of the Dataset's data + columns. The value of the key values of the key (the values in + the column) must be in RFC 3339 `date-time` format, where + `time-offset` = `"Z"` (e.g. 1985-04-12T23:20:50.52Z). If for a + piece of data the key is not present or has an invalid value, + that piece is ignored by the pipeline. + + Supported only for tabular and time series Datasets. + This parameter must be used with training_fraction_split, validation_fraction_split and test_fraction_split. weight_column (str): Optional. Name of the column that should be used as the weight column. Higher values in this column give more importance to the row @@ -2992,6 +3349,7 @@ def run( validation_fraction_split=validation_fraction_split, test_fraction_split=test_fraction_split, predefined_split_column_name=predefined_split_column_name, + timestamp_split_column_name=timestamp_split_column_name, weight_column=weight_column, budget_milli_node_hours=budget_milli_node_hours, model_display_name=model_display_name, @@ -3008,10 +3366,11 @@ def _run( self, dataset: datasets.TabularDataset, target_column: str, - training_fraction_split: float = 0.8, - validation_fraction_split: float = 0.1, - test_fraction_split: float = 0.1, + training_fraction_split: Optional[float] = None, + validation_fraction_split: Optional[float] = None, + test_fraction_split: Optional[float] = None, predefined_split_column_name: Optional[str] = None, + timestamp_split_column_name: Optional[str] = None, weight_column: Optional[str] = None, budget_milli_node_hours: int = 1000, model_display_name: Optional[str] = None, @@ -3024,12 +3383,25 @@ def _run( ) -> models.Model: """Runs the training job and returns a model. - Data fraction splits: - Any of ``training_fraction_split``, ``validation_fraction_split`` and - ``test_fraction_split`` may optionally be provided, they must sum to up to 1. If - the provided ones sum to less than 1, the remainder is assigned to sets as - decided by Vertex AI. If none of the fractions are set, by default roughly 80% - of data will be used for training, 10% for validation, and 10% for test. + If training on a Vertex AI dataset, you can use one of the following split configurations: + Data fraction splits: + Any of ``training_fraction_split``, ``validation_fraction_split`` and + ``test_fraction_split`` may optionally be provided, they must sum to up to 1. If + the provided ones sum to less than 1, the remainder is assigned to sets as + decided by Vertex AI. If none of the fractions are set, by default roughly 80% + of data will be used for training, 10% for validation, and 10% for test. + + Predefined splits: + Assigns input data to training, validation, and test sets based on the value of a provided key. + If using predefined splits, ``predefined_split_column_name`` must be provided. + Supported only for tabular Datasets. + + Timestamp splits: + Assigns input data to training, validation, and test sets + based on a provided timestamps. The youngest data pieces are + assigned to training set, next to validation set, and the oldest + to the test set. + Supported only for tabular Datasets. Args: dataset (datasets.TabularDataset): @@ -3043,14 +3415,14 @@ def _run( target_column (str): Required. The name of the column values of which the Model is to predict. training_fraction_split (float): - Required. The fraction of the input data that is to be - used to train the Model. This is ignored if Dataset is not provided. + Optional. The fraction of the input data that is to be used to train + the Model. This is ignored if Dataset is not provided. validation_fraction_split (float): - Required. The fraction of the input data that is to be - used to validate the Model. This is ignored if Dataset is not provided. + Optional. The fraction of the input data that is to be used to validate + the Model. This is ignored if Dataset is not provided. test_fraction_split (float): - Required. The fraction of the input data that is to be - used to evaluate the Model. This is ignored if Dataset is not provided. + Optional. The fraction of the input data that is to be used to evaluate + the Model. This is ignored if Dataset is not provided. predefined_split_column_name (str): Optional. The key is a name of one of the Dataset's data columns. The value of the key (either the label's value or @@ -3061,6 +3433,16 @@ def _run( ignored by the pipeline. Supported only for tabular and time series Datasets. + timestamp_split_column_name (str): + Optional. The key is a name of one of the Dataset's data + columns. The value of the key values of the key (the values in + the column) must be in RFC 3339 `date-time` format, where + `time-offset` = `"Z"` (e.g. 1985-04-12T23:20:50.52Z). If for a + piece of data the key is not present or has an invalid value, + that piece is ignored by the pipeline. + + Supported only for tabular and time series Datasets. + This parameter must be used with training_fraction_split, validation_fraction_split and test_fraction_split. weight_column (str): Optional. Name of the column that should be used as the weight column. Higher values in this column give more importance to the row @@ -3200,6 +3582,7 @@ def _run( validation_fraction_split=validation_fraction_split, test_fraction_split=test_fraction_split, predefined_split_column_name=predefined_split_column_name, + timestamp_split_column_name=timestamp_split_column_name, model=model, ) @@ -3740,9 +4123,9 @@ def _run( training_task_definition=training_task_definition, training_task_inputs=training_task_inputs_dict, dataset=dataset, - training_fraction_split=0.8, - validation_fraction_split=0.1, - test_fraction_split=0.1, + training_fraction_split=None, + validation_fraction_split=None, + test_fraction_split=None, predefined_split_column_name=predefined_split_column_name, model=model, ) @@ -3929,9 +4312,12 @@ def __init__( def run( self, dataset: datasets.ImageDataset, - training_fraction_split: float = 0.8, - validation_fraction_split: float = 0.1, - test_fraction_split: float = 0.1, + training_fraction_split: Optional[float] = None, + validation_fraction_split: Optional[float] = None, + test_fraction_split: Optional[float] = None, + training_filter_split: Optional[str] = None, + validation_filter_split: Optional[str] = None, + test_filter_split: Optional[str] = None, budget_milli_node_hours: int = 1000, model_display_name: Optional[str] = None, model_labels: Optional[Dict[str, str]] = None, @@ -3940,12 +4326,24 @@ def run( ) -> models.Model: """Runs the AutoML Image training job and returns a model. - Data fraction splits: - Any of ``training_fraction_split``, ``validation_fraction_split`` and - ``test_fraction_split`` may optionally be provided, they must sum to up to 1. If - the provided ones sum to less than 1, the remainder is assigned to sets as - decided by Vertex AI. If none of the fractions are set, by default roughly 80% - of data will be used for training, 10% for validation, and 10% for test. + If training on a Vertex AI dataset, you can use one of the following split configurations: + Data fraction splits: + Any of ``training_fraction_split``, ``validation_fraction_split`` and + ``test_fraction_split`` may optionally be provided, they must sum to up to 1. If + the provided ones sum to less than 1, the remainder is assigned to sets as + decided by Vertex AI. If none of the fractions are set, by default roughly 80% + of data will be used for training, 10% for validation, and 10% for test. + + Data filter splits: + Assigns input data to training, validation, and test sets + based on the given filters, data pieces not matched by any + filter are ignored. Currently only supported for Datasets + containing DataItems. + If any of the filters in this message are to match nothing, then + they can be set as '-' (the minus sign). + If using filter splits, all of ``training_filter_split``, ``validation_filter_split`` and + ``test_filter_split`` must be provided. + Supported only for unstructured Datasets. Args: dataset (datasets.ImageDataset): @@ -3956,15 +4354,36 @@ def run( [google.cloud.aiplatform.v1beta1.TrainingPipeline.training_task_definition]. For tabular Datasets, all their data is exported to training, to pick and choose from. - training_fraction_split: float = 0.8 - Required. The fraction of the input data that is to be - used to train the Model. This is ignored if Dataset is not provided. - validation_fraction_split: float = 0.1 - Required. The fraction of the input data that is to be - used to validate the Model. This is ignored if Dataset is not provided. - test_fraction_split: float = 0.1 - Required. The fraction of the input data that is to be - used to evaluate the Model. This is ignored if Dataset is not provided. + training_fraction_split (float): + Optional. The fraction of the input data that is to be used to train + the Model. This is ignored if Dataset is not provided. + validation_fraction_split (float): + Optional. The fraction of the input data that is to be used to validate + the Model. This is ignored if Dataset is not provided. + test_fraction_split (float): + Optional. The fraction of the input data that is to be used to evaluate + the Model. This is ignored if Dataset is not provided. + training_filter_split (str): + Optional. A filter on DataItems of the Dataset. DataItems that match + this filter are used to train the Model. A filter with same syntax + as the one used in DatasetService.ListDataItems may be used. If a + single DataItem is matched by more than one of the FilterSplit filters, + then it is assigned to the first set that applies to it in the training, + validation, test order. This is ignored if Dataset is not provided. + validation_filter_split (str): + Optional. A filter on DataItems of the Dataset. DataItems that match + this filter are used to validate the Model. A filter with same syntax + as the one used in DatasetService.ListDataItems may be used. If a + single DataItem is matched by more than one of the FilterSplit filters, + then it is assigned to the first set that applies to it in the training, + validation, test order. This is ignored if Dataset is not provided. + test_filter_split (str): + Optional. A filter on DataItems of the Dataset. DataItems that match + this filter are used to test the Model. A filter with same syntax + as the one used in DatasetService.ListDataItems may be used. If a + single DataItem is matched by more than one of the FilterSplit filters, + then it is assigned to the first set that applies to it in the training, + validation, test order. This is ignored if Dataset is not provided. budget_milli_node_hours: int = 1000 Optional. The train budget of creating this Model, expressed in milli node hours i.e. 1,000 value in this field means 1 node hour. @@ -4026,6 +4445,9 @@ def run( training_fraction_split=training_fraction_split, validation_fraction_split=validation_fraction_split, test_fraction_split=test_fraction_split, + training_filter_split=training_filter_split, + validation_filter_split=validation_filter_split, + test_filter_split=test_filter_split, budget_milli_node_hours=budget_milli_node_hours, model_display_name=model_display_name, model_labels=model_labels, @@ -4038,9 +4460,12 @@ def _run( self, dataset: datasets.ImageDataset, base_model: Optional[models.Model] = None, - training_fraction_split: float = 0.8, - validation_fraction_split: float = 0.1, - test_fraction_split: float = 0.1, + training_fraction_split: Optional[float] = None, + validation_fraction_split: Optional[float] = None, + test_fraction_split: Optional[float] = None, + training_filter_split: Optional[str] = None, + validation_filter_split: Optional[str] = None, + test_filter_split: Optional[str] = None, budget_milli_node_hours: int = 1000, model_display_name: Optional[str] = None, model_labels: Optional[Dict[str, str]] = None, @@ -4049,12 +4474,24 @@ def _run( ) -> models.Model: """Runs the training job and returns a model. - Data fraction splits: - Any of ``training_fraction_split``, ``validation_fraction_split`` and - ``test_fraction_split`` may optionally be provided, they must sum to up to 1. If - the provided ones sum to less than 1, the remainder is assigned to sets as - decided by Vertex AI. If none of the fractions are set, by default roughly 80% - of data will be used for training, 10% for validation, and 10% for test. + If training on a Vertex AI dataset, you can use one of the following split configurations: + Data fraction splits: + Any of ``training_fraction_split``, ``validation_fraction_split`` and + ``test_fraction_split`` may optionally be provided, they must sum to up to 1. If + the provided ones sum to less than 1, the remainder is assigned to sets as + decided by Vertex AI. If none of the fractions are set, by default roughly 80% + of data will be used for training, 10% for validation, and 10% for test. + + Data filter splits: + Assigns input data to training, validation, and test sets + based on the given filters, data pieces not matched by any + filter are ignored. Currently only supported for Datasets + containing DataItems. + If any of the filters in this message are to match nothing, then + they can be set as '-' (the minus sign). + If using filter splits, all of ``training_filter_split``, ``validation_filter_split`` and + ``test_filter_split`` must be provided. + Supported only for unstructured Datasets. Args: dataset (datasets.ImageDataset): @@ -4072,14 +4509,35 @@ def _run( must be in the same Project and Location as the new Model to train, and have the same model_type. training_fraction_split (float): - Required. The fraction of the input data that is to be - used to train the Model. This is ignored if Dataset is not provided. + Optional. The fraction of the input data that is to be used to train + the Model. This is ignored if Dataset is not provided. validation_fraction_split (float): - Required. The fraction of the input data that is to be - used to validate the Model. This is ignored if Dataset is not provided. + Optional. The fraction of the input data that is to be used to validate + the Model. This is ignored if Dataset is not provided. test_fraction_split (float): - Required. The fraction of the input data that is to be - used to evaluate the Model. This is ignored if Dataset is not provided. + Optional. The fraction of the input data that is to be used to evaluate + the Model. This is ignored if Dataset is not provided. + training_filter_split (str): + Optional. A filter on DataItems of the Dataset. DataItems that match + this filter are used to train the Model. A filter with same syntax + as the one used in DatasetService.ListDataItems may be used. If a + single DataItem is matched by more than one of the FilterSplit filters, + then it is assigned to the first set that applies to it in the training, + validation, test order. This is ignored if Dataset is not provided. + validation_filter_split (str): + Optional. A filter on DataItems of the Dataset. DataItems that match + this filter are used to validate the Model. A filter with same syntax + as the one used in DatasetService.ListDataItems may be used. If a + single DataItem is matched by more than one of the FilterSplit filters, + then it is assigned to the first set that applies to it in the training, + validation, test order. This is ignored if Dataset is not provided. + test_filter_split (str): + Optional. A filter on DataItems of the Dataset. DataItems that match + this filter are used to test the Model. A filter with same syntax + as the one used in DatasetService.ListDataItems may be used. If a + single DataItem is matched by more than one of the FilterSplit filters, + then it is assigned to the first set that applies to it in the training, + validation, test order. This is ignored if Dataset is not provided. budget_milli_node_hours (int): Optional. The train budget of creating this Model, expressed in milli node hours i.e. 1,000 value in this field means 1 node hour. @@ -4162,6 +4620,9 @@ def _run( training_fraction_split=training_fraction_split, validation_fraction_split=validation_fraction_split, test_fraction_split=test_fraction_split, + training_filter_split=training_filter_split, + validation_filter_split=validation_filter_split, + test_filter_split=test_filter_split, model=model_tbt, ) @@ -4437,10 +4898,14 @@ def run( accelerator_count: int = 0, boot_disk_type: str = "pd-ssd", boot_disk_size_gb: int = 100, - training_fraction_split: float = 0.8, - validation_fraction_split: float = 0.1, - test_fraction_split: float = 0.1, + training_fraction_split: Optional[float] = None, + validation_fraction_split: Optional[float] = None, + test_fraction_split: Optional[float] = None, + training_filter_split: Optional[str] = None, + validation_filter_split: Optional[str] = None, + test_filter_split: Optional[str] = None, predefined_split_column_name: Optional[str] = None, + timestamp_split_column_name: Optional[str] = None, tensorboard: Optional[str] = None, sync=True, ) -> Optional[models.Model]: @@ -4452,12 +4917,36 @@ def run( ie: replica_count = 10 will result in 1 chief and 9 workers All replicas have same machine_type, accelerator_type, and accelerator_count - Data fraction splits: - Any of ``training_fraction_split``, ``validation_fraction_split`` and - ``test_fraction_split`` may optionally be provided, they must sum to up to 1. If - the provided ones sum to less than 1, the remainder is assigned to sets as - decided by Vertex AI.If none of the fractions are set, by default roughly 80% - of data will be used for training, 10% for validation, and 10% for test. + If training on a Vertex AI dataset, you can use one of the following split configurations: + Data fraction splits: + Any of ``training_fraction_split``, ``validation_fraction_split`` and + ``test_fraction_split`` may optionally be provided, they must sum to up to 1. If + the provided ones sum to less than 1, the remainder is assigned to sets as + decided by Vertex AI. If none of the fractions are set, by default roughly 80% + of data will be used for training, 10% for validation, and 10% for test. + + Data filter splits: + Assigns input data to training, validation, and test sets + based on the given filters, data pieces not matched by any + filter are ignored. Currently only supported for Datasets + containing DataItems. + If any of the filters in this message are to match nothing, then + they can be set as '-' (the minus sign). + If using filter splits, all of ``training_filter_split``, ``validation_filter_split`` and + ``test_filter_split`` must be provided. + Supported only for unstructured Datasets. + + Predefined splits: + Assigns input data to training, validation, and test sets based on the value of a provided key. + If using predefined splits, ``predefined_split_column_name`` must be provided. + Supported only for tabular Datasets. + + Timestamp splits: + Assigns input data to training, validation, and test sets + based on a provided timestamps. The youngest data pieces are + assigned to training set, next to validation set, and the oldest + to the test set. + Supported only for tabular Datasets. Args: dataset (Union[datasets.ImageDataset,datasets.TabularDataset,datasets.TextDataset,datasets.VideoDataset,]): @@ -4574,14 +5063,35 @@ def run( Size in GB of the boot disk, default is 100GB. boot disk size must be within the range of [100, 64000]. training_fraction_split (float): - The fraction of the input data that is to be - used to train the Model. This is ignored if Dataset is not provided. + Optional. The fraction of the input data that is to be used to train + the Model. This is ignored if Dataset is not provided. validation_fraction_split (float): - The fraction of the input data that is to be - used to validate the Model. This is ignored if Dataset is not provided. + Optional. The fraction of the input data that is to be used to validate + the Model. This is ignored if Dataset is not provided. test_fraction_split (float): - The fraction of the input data that is to be - used to evaluate the Model. This is ignored if Dataset is not provided. + Optional. The fraction of the input data that is to be used to evaluate + the Model. This is ignored if Dataset is not provided. + training_filter_split (str): + Optional. A filter on DataItems of the Dataset. DataItems that match + this filter are used to train the Model. A filter with same syntax + as the one used in DatasetService.ListDataItems may be used. If a + single DataItem is matched by more than one of the FilterSplit filters, + then it is assigned to the first set that applies to it in the training, + validation, test order. This is ignored if Dataset is not provided. + validation_filter_split (str): + Optional. A filter on DataItems of the Dataset. DataItems that match + this filter are used to validate the Model. A filter with same syntax + as the one used in DatasetService.ListDataItems may be used. If a + single DataItem is matched by more than one of the FilterSplit filters, + then it is assigned to the first set that applies to it in the training, + validation, test order. This is ignored if Dataset is not provided. + test_filter_split (str): + Optional. A filter on DataItems of the Dataset. DataItems that match + this filter are used to test the Model. A filter with same syntax + as the one used in DatasetService.ListDataItems may be used. If a + single DataItem is matched by more than one of the FilterSplit filters, + then it is assigned to the first set that applies to it in the training, + validation, test order. This is ignored if Dataset is not provided. predefined_split_column_name (str): Optional. The key is a name of one of the Dataset's data columns. The value of the key (either the label's value or @@ -4591,6 +5101,15 @@ def run( key is not present or has an invalid value, that piece is ignored by the pipeline. + Supported only for tabular and time series Datasets. + timestamp_split_column_name (str): + Optional. The key is a name of one of the Dataset's data + columns. The value of the key values of the key (the values in + the column) must be in RFC 3339 `date-time` format, where + `time-offset` = `"Z"` (e.g. 1985-04-12T23:20:50.52Z). If for a + piece of data the key is not present or has an invalid value, + that piece is ignored by the pipeline. + Supported only for tabular and time series Datasets. tensorboard (str): Optional. The name of a Vertex AI @@ -4640,7 +5159,11 @@ def run( training_fraction_split=training_fraction_split, validation_fraction_split=validation_fraction_split, test_fraction_split=test_fraction_split, + training_filter_split=training_filter_split, + validation_filter_split=validation_filter_split, + test_filter_split=test_filter_split, predefined_split_column_name=predefined_split_column_name, + timestamp_split_column_name=timestamp_split_column_name, bigquery_destination=bigquery_destination, tensorboard=tensorboard, sync=sync, @@ -4665,10 +5188,14 @@ def _run( base_output_dir: Optional[str] = None, service_account: Optional[str] = None, network: Optional[str] = None, - training_fraction_split: float = 0.8, - validation_fraction_split: float = 0.1, - test_fraction_split: float = 0.1, + training_fraction_split: Optional[float] = None, + validation_fraction_split: Optional[float] = None, + test_fraction_split: Optional[float] = None, + training_filter_split: Optional[str] = None, + validation_filter_split: Optional[str] = None, + test_filter_split: Optional[str] = None, predefined_split_column_name: Optional[str] = None, + timestamp_split_column_name: Optional[str] = None, bigquery_destination: Optional[str] = None, tensorboard: Optional[str] = None, sync=True, @@ -4723,14 +5250,35 @@ def _run( Private services access must already be configured for the network. If left unspecified, the job is not peered with any network. training_fraction_split (float): - The fraction of the input data that is to be - used to train the Model. + Optional. The fraction of the input data that is to be used to train + the Model. This is ignored if Dataset is not provided. validation_fraction_split (float): - The fraction of the input data that is to be - used to validate the Model. + Optional. The fraction of the input data that is to be used to validate + the Model. This is ignored if Dataset is not provided. test_fraction_split (float): - The fraction of the input data that is to be - used to evaluate the Model. + Optional. The fraction of the input data that is to be used to evaluate + the Model. This is ignored if Dataset is not provided. + training_filter_split (str): + Optional. A filter on DataItems of the Dataset. DataItems that match + this filter are used to train the Model. A filter with same syntax + as the one used in DatasetService.ListDataItems may be used. If a + single DataItem is matched by more than one of the FilterSplit filters, + then it is assigned to the first set that applies to it in the training, + validation, test order. This is ignored if Dataset is not provided. + validation_filter_split (str): + Optional. A filter on DataItems of the Dataset. DataItems that match + this filter are used to validate the Model. A filter with same syntax + as the one used in DatasetService.ListDataItems may be used. If a + single DataItem is matched by more than one of the FilterSplit filters, + then it is assigned to the first set that applies to it in the training, + validation, test order. This is ignored if Dataset is not provided. + test_filter_split (str): + Optional. A filter on DataItems of the Dataset. DataItems that match + this filter are used to test the Model. A filter with same syntax + as the one used in DatasetService.ListDataItems may be used. If a + single DataItem is matched by more than one of the FilterSplit filters, + then it is assigned to the first set that applies to it in the training, + validation, test order. This is ignored if Dataset is not provided. predefined_split_column_name (str): Optional. The key is a name of one of the Dataset's data columns. The value of the key (either the label's value or @@ -4740,6 +5288,15 @@ def _run( key is not present or has an invalid value, that piece is ignored by the pipeline. + Supported only for tabular and time series Datasets. + timestamp_split_column_name (str): + Optional. The key is a name of one of the Dataset's data + columns. The value of the key values of the key (the values in + the column) must be in RFC 3339 `date-time` format, where + `time-offset` = `"Z"` (e.g. 1985-04-12T23:20:50.52Z). If for a + piece of data the key is not present or has an invalid value, + that piece is ignored by the pipeline. + Supported only for tabular and time series Datasets. tensorboard (str): Optional. The name of a Vertex AI @@ -4800,7 +5357,11 @@ def _run( training_fraction_split=training_fraction_split, validation_fraction_split=validation_fraction_split, test_fraction_split=test_fraction_split, + training_filter_split=training_filter_split, + validation_filter_split=validation_filter_split, + test_filter_split=test_filter_split, predefined_split_column_name=predefined_split_column_name, + timestamp_split_column_name=timestamp_split_column_name, model=managed_model, gcs_destination_uri_prefix=base_output_dir, bigquery_destination=bigquery_destination, @@ -4945,18 +5506,32 @@ def __init__( def run( self, dataset: datasets.VideoDataset, - training_fraction_split: float = 0.8, - test_fraction_split: float = 0.2, + training_fraction_split: Optional[float] = None, + test_fraction_split: Optional[float] = None, + training_filter_split: Optional[str] = None, + test_filter_split: Optional[str] = None, model_display_name: Optional[str] = None, model_labels: Optional[Dict[str, str]] = None, sync: bool = True, ) -> models.Model: """Runs the AutoML Image training job and returns a model. - Data fraction splits: - ``training_fraction_split``, and ``test_fraction_split`` may optionally - be provided, they must sum to up to 1. If none of the fractions are set, - by default roughly 80% of data will be used for training, and 20% for test. + If training on a Vertex AI dataset, you can use one of the following split configurations: + Data fraction splits: + ``training_fraction_split``, and ``test_fraction_split`` may optionally + be provided, they must sum to up to 1. If none of the fractions are set, + by default roughly 80% of data will be used for training, and 20% for test. + + Data filter splits: + Assigns input data to training, validation, and test sets + based on the given filters, data pieces not matched by any + filter are ignored. Currently only supported for Datasets + containing DataItems. + If any of the filters in this message are to match nothing, then + they can be set as '-' (the minus sign). + If using filter splits, all of ``training_filter_split``, ``validation_filter_split`` and + ``test_filter_split`` must be provided. + Supported only for unstructured Datasets. Args: dataset (datasets.VideoDataset): @@ -4967,12 +5542,26 @@ def run( [google.cloud.aiplatform.v1beta1.TrainingPipeline.training_task_definition]. For tabular Datasets, all their data is exported to training, to pick and choose from. - training_fraction_split: float = 0.8 - Required. The fraction of the input data that is to be - used to train the Model. This is ignored if Dataset is not provided. - test_fraction_split: float = 0.2 - Required. The fraction of the input data that is to be - used to evaluate the Model. This is ignored if Dataset is not provided. + training_fraction_split (float): + Optional. The fraction of the input data that is to be used to train + the Model. This is ignored if Dataset is not provided. + test_fraction_split (float): + Optional. The fraction of the input data that is to be used to evaluate + the Model. This is ignored if Dataset is not provided. + training_filter_split (str): + Optional. A filter on DataItems of the Dataset. DataItems that match + this filter are used to train the Model. A filter with same syntax + as the one used in DatasetService.ListDataItems may be used. If a + single DataItem is matched by more than one of the FilterSplit filters, + then it is assigned to the first set that applies to it in the training, + validation, test order. This is ignored if Dataset is not provided. + test_filter_split (str): + Optional. A filter on DataItems of the Dataset. DataItems that match + this filter are used to test the Model. A filter with same syntax + as the one used in DatasetService.ListDataItems may be used. If a + single DataItem is matched by more than one of the FilterSplit filters, + then it is assigned to the first set that applies to it in the training, + validation, test order. This is ignored if Dataset is not provided. model_display_name (str): Optional. The display name of the managed Vertex AI Model. The name can be up to 128 characters long and can be consist of any UTF-8 @@ -5014,6 +5603,8 @@ def run( dataset=dataset, training_fraction_split=training_fraction_split, test_fraction_split=test_fraction_split, + training_filter_split=training_filter_split, + test_filter_split=test_filter_split, model_display_name=model_display_name, model_labels=model_labels, sync=sync, @@ -5023,18 +5614,32 @@ def run( def _run( self, dataset: datasets.VideoDataset, - training_fraction_split: float = 0.8, - test_fraction_split: float = 0.2, + training_fraction_split: Optional[float] = None, + test_fraction_split: Optional[float] = None, + training_filter_split: Optional[str] = None, + test_filter_split: Optional[str] = None, model_display_name: Optional[str] = None, model_labels: Optional[Dict[str, str]] = None, sync: bool = True, ) -> models.Model: """Runs the training job and returns a model. - Data fraction splits: - Any of ``training_fraction_split``, and ``test_fraction_split`` may optionally - be provided, they must sum to up to 1. If none of the fractions are set, - by default roughly 80% of data will be used for training, and 20% for test. + If training on a Vertex AI dataset, you can use one of the following split configurations: + Data fraction splits: + Any of ``training_fraction_split``, and ``test_fraction_split`` may optionally + be provided, they must sum to up to 1. If none of the fractions are set, + by default roughly 80% of data will be used for training, and 20% for test. + + Data filter splits: + Assigns input data to training, validation, and test sets + based on the given filters, data pieces not matched by any + filter are ignored. Currently only supported for Datasets + containing DataItems. + If any of the filters in this message are to match nothing, then + they can be set as '-' (the minus sign). + If using filter splits, all of ``training_filter_split``, ``validation_filter_split`` and + ``test_filter_split`` must be provided. + Supported only for unstructured Datasets. Args: dataset (datasets.VideoDataset): @@ -5046,11 +5651,25 @@ def _run( For tabular Datasets, all their data is exported to training, to pick and choose from. training_fraction_split (float): - Required. The fraction of the input data that is to be - used to train the Model. This is ignored if Dataset is not provided. + Optional. The fraction of the input data that is to be used to train + the Model. This is ignored if Dataset is not provided. test_fraction_split (float): - Required. The fraction of the input data that is to be - used to evaluate the Model. This is ignored if Dataset is not provided. + Optional. The fraction of the input data that is to be used to evaluate + the Model. This is ignored if Dataset is not provided. + training_filter_split (str): + Optional. A filter on DataItems of the Dataset. DataItems that match + this filter are used to train the Model. A filter with same syntax + as the one used in DatasetService.ListDataItems may be used. If a + single DataItem is matched by more than one of the FilterSplit filters, + then it is assigned to the first set that applies to it in the training, + validation, test order. This is ignored if Dataset is not provided. + test_filter_split (str): + Optional. A filter on DataItems of the Dataset. DataItems that match + this filter are used to test the Model. A filter with same syntax + as the one used in DatasetService.ListDataItems may be used. If a + single DataItem is matched by more than one of the FilterSplit filters, + then it is assigned to the first set that applies to it in the training, + validation, test order. This is ignored if Dataset is not provided. model_display_name (str): Optional. The display name of the managed Vertex AI Model. The name can be up to 128 characters long and can be consist of any UTF-8 @@ -5091,13 +5710,22 @@ def _run( model_tbt.display_name = model_display_name or self._display_name model_tbt.labels = model_labels or self._labels + # AutoMLVideo does not support validation, so pass in '-' if any other filter split is provided. + validation_filter_split = ( + "-" + if all([training_filter_split is not None, test_filter_split is not None]) + else None + ) + return self._run_job( training_task_definition=training_task_definition, training_task_inputs=training_task_inputs_dict, dataset=dataset, training_fraction_split=training_fraction_split, - validation_fraction_split=0.0, test_fraction_split=test_fraction_split, + training_filter_split=training_filter_split, + validation_filter_split=validation_filter_split, + test_filter_split=test_filter_split, model=model_tbt, ) @@ -5252,21 +5880,36 @@ def __init__( def run( self, dataset: datasets.TextDataset, - training_fraction_split: float = 0.8, - validation_fraction_split: float = 0.1, - test_fraction_split: float = 0.1, + training_fraction_split: Optional[float] = None, + validation_fraction_split: Optional[float] = None, + test_fraction_split: Optional[float] = None, + training_filter_split: Optional[str] = None, + validation_filter_split: Optional[str] = None, + test_filter_split: Optional[str] = None, model_display_name: Optional[str] = None, model_labels: Optional[Dict[str, str]] = None, sync: bool = True, ) -> models.Model: """Runs the training job and returns a model. - Data fraction splits: - Any of ``training_fraction_split``, ``validation_fraction_split`` and - ``test_fraction_split`` may optionally be provided, they must sum to up to 1. If - the provided ones sum to less than 1, the remainder is assigned to sets as - decided by Vertex AI. If none of the fractions are set, by default roughly 80% - of data will be used for training, 10% for validation, and 10% for test. + If training on a Vertex AI dataset, you can use one of the following split configurations: + Data fraction splits: + Any of ``training_fraction_split``, ``validation_fraction_split`` and + ``test_fraction_split`` may optionally be provided, they must sum to up to 1. If + the provided ones sum to less than 1, the remainder is assigned to sets as + decided by Vertex AI. If none of the fractions are set, by default roughly 80% + of data will be used for training, 10% for validation, and 10% for test. + + Data filter splits: + Assigns input data to training, validation, and test sets + based on the given filters, data pieces not matched by any + filter are ignored. Currently only supported for Datasets + containing DataItems. + If any of the filters in this message are to match nothing, then + they can be set as '-' (the minus sign). + If using filter splits, all of ``training_filter_split``, ``validation_filter_split`` and + ``test_filter_split`` must be provided. + Supported only for unstructured Datasets. Args: dataset (datasets.TextDataset): @@ -5275,15 +5918,36 @@ def run( and what is compatible should be described in the used TrainingPipeline's [training_task_definition] [google.cloud.aiplatform.v1beta1.TrainingPipeline.training_task_definition]. - training_fraction_split: float = 0.8 - Required. The fraction of the input data that is to be - used to train the Model. This is ignored if Dataset is not provided. - validation_fraction_split: float = 0.1 - Required. The fraction of the input data that is to be - used to validate the Model. This is ignored if Dataset is not provided. - test_fraction_split: float = 0.1 - Required. The fraction of the input data that is to be - used to evaluate the Model. This is ignored if Dataset is not provided. + training_fraction_split (float): + Optional. The fraction of the input data that is to be used to train + the Model. This is ignored if Dataset is not provided. + validation_fraction_split (float): + Optional. The fraction of the input data that is to be used to validate + the Model. This is ignored if Dataset is not provided. + test_fraction_split (float): + Optional. The fraction of the input data that is to be used to evaluate + the Model. This is ignored if Dataset is not provided. + training_filter_split (str): + Optional. A filter on DataItems of the Dataset. DataItems that match + this filter are used to train the Model. A filter with same syntax + as the one used in DatasetService.ListDataItems may be used. If a + single DataItem is matched by more than one of the FilterSplit filters, + then it is assigned to the first set that applies to it in the training, + validation, test order. This is ignored if Dataset is not provided. + validation_filter_split (str): + Optional. A filter on DataItems of the Dataset. DataItems that match + this filter are used to validate the Model. A filter with same syntax + as the one used in DatasetService.ListDataItems may be used. If a + single DataItem is matched by more than one of the FilterSplit filters, + then it is assigned to the first set that applies to it in the training, + validation, test order. This is ignored if Dataset is not provided. + test_filter_split (str): + Optional. A filter on DataItems of the Dataset. DataItems that match + this filter are used to test the Model. A filter with same syntax + as the one used in DatasetService.ListDataItems may be used. If a + single DataItem is matched by more than one of the FilterSplit filters, + then it is assigned to the first set that applies to it in the training, + validation, test order. This is ignored if Dataset is not provided. model_display_name (str): Optional. The display name of the managed Vertex AI Model. The name can be up to 128 characters long and can consist @@ -5327,6 +5991,9 @@ def run( training_fraction_split=training_fraction_split, validation_fraction_split=validation_fraction_split, test_fraction_split=test_fraction_split, + training_filter_split=training_filter_split, + validation_filter_split=validation_filter_split, + test_filter_split=test_filter_split, model_display_name=model_display_name, model_labels=model_labels, sync=sync, @@ -5336,21 +6003,36 @@ def run( def _run( self, dataset: datasets.TextDataset, - training_fraction_split: float = 0.8, - validation_fraction_split: float = 0.1, - test_fraction_split: float = 0.1, + training_fraction_split: Optional[float] = None, + validation_fraction_split: Optional[float] = None, + test_fraction_split: Optional[float] = None, + training_filter_split: Optional[str] = None, + validation_filter_split: Optional[str] = None, + test_filter_split: Optional[str] = None, model_display_name: Optional[str] = None, model_labels: Optional[Dict[str, str]] = None, sync: bool = True, ) -> models.Model: """Runs the training job and returns a model. - Data fraction splits: - Any of ``training_fraction_split``, ``validation_fraction_split`` and - ``test_fraction_split`` may optionally be provided, they must sum to up to 1. If - the provided ones sum to less than 1, the remainder is assigned to sets as - decided by Vertex AI. If none of the fractions are set, by default roughly 80% - of data will be used for training, 10% for validation, and 10% for test. + If training on a Vertex AI dataset, you can use one of the following split configurations: + Data fraction splits: + Any of ``training_fraction_split``, ``validation_fraction_split`` and + ``test_fraction_split`` may optionally be provided, they must sum to up to 1. If + the provided ones sum to less than 1, the remainder is assigned to sets as + decided by Vertex AI. If none of the fractions are set, by default roughly 80% + of data will be used for training, 10% for validation, and 10% for test. + + Data filter splits: + Assigns input data to training, validation, and test sets + based on the given filters, data pieces not matched by any + filter are ignored. Currently only supported for Datasets + containing DataItems. + If any of the filters in this message are to match nothing, then + they can be set as '-' (the minus sign). + If using filter splits, all of ``training_filter_split``, ``validation_filter_split`` and + ``test_filter_split`` must be provided. + Supported only for unstructured Datasets. Args: dataset (datasets.TextDataset): @@ -5362,14 +6044,35 @@ def _run( For Text Datasets, all their data is exported to training, to pick and choose from. training_fraction_split (float): - Required. The fraction of the input data that is to be - used to train the Model. This is ignored if Dataset is not provided. + Optional. The fraction of the input data that is to be used to train + the Model. This is ignored if Dataset is not provided. validation_fraction_split (float): - Required. The fraction of the input data that is to be - used to validate the Model. This is ignored if Dataset is not provided. + Optional. The fraction of the input data that is to be used to validate + the Model. This is ignored if Dataset is not provided. test_fraction_split (float): - Required. The fraction of the input data that is to be - used to evaluate the Model. This is ignored if Dataset is not provided. + Optional. The fraction of the input data that is to be used to evaluate + the Model. This is ignored if Dataset is not provided. + training_filter_split (str): + Optional. A filter on DataItems of the Dataset. DataItems that match + this filter are used to train the Model. A filter with same syntax + as the one used in DatasetService.ListDataItems may be used. If a + single DataItem is matched by more than one of the FilterSplit filters, + then it is assigned to the first set that applies to it in the training, + validation, test order. This is ignored if Dataset is not provided. + validation_filter_split (str): + Optional. A filter on DataItems of the Dataset. DataItems that match + this filter are used to validate the Model. A filter with same syntax + as the one used in DatasetService.ListDataItems may be used. If a + single DataItem is matched by more than one of the FilterSplit filters, + then it is assigned to the first set that applies to it in the training, + validation, test order. This is ignored if Dataset is not provided. + test_filter_split (str): + Optional. A filter on DataItems of the Dataset. DataItems that match + this filter are used to test the Model. A filter with same syntax + as the one used in DatasetService.ListDataItems may be used. If a + single DataItem is matched by more than one of the FilterSplit filters, + then it is assigned to the first set that applies to it in the training, + validation, test order. This is ignored if Dataset is not provided. model_display_name (str): Optional. If the script produces a managed Vertex AI Model. The display name of the Model. The name can be up to 128 characters long and can be consist @@ -5409,7 +6112,9 @@ def _run( training_fraction_split=training_fraction_split, validation_fraction_split=validation_fraction_split, test_fraction_split=test_fraction_split, - predefined_split_column_name=None, + training_filter_split=training_filter_split, + validation_filter_split=validation_filter_split, + test_filter_split=test_filter_split, model=model, ) diff --git a/tests/unit/aiplatform/test_automl_forecasting_training_jobs.py b/tests/unit/aiplatform/test_automl_forecasting_training_jobs.py index d699563327..8dc1f362ba 100644 --- a/tests/unit/aiplatform/test_automl_forecasting_training_jobs.py +++ b/tests/unit/aiplatform/test_automl_forecasting_training_jobs.py @@ -103,14 +103,11 @@ _TEST_DATASET_NAME = "test-dataset-name" _TEST_MODEL_DISPLAY_NAME = "model-display-name" + _TEST_LABELS = {"key": "value"} _TEST_MODEL_LABELS = {"model_key": "model_value"} -_TEST_TRAINING_FRACTION_SPLIT = 0.8 -_TEST_VALIDATION_FRACTION_SPLIT = 0.1 -_TEST_TEST_FRACTION_SPLIT = 0.1 -_TEST_PREDEFINED_SPLIT_COLUMN_NAME = "split" -_TEST_OUTPUT_PYTHON_PACKAGE_PATH = "gs://test/ouput/python/trainer.tar.gz" +_TEST_PREDEFINED_SPLIT_COLUMN_NAME = "split" _TEST_MODEL_NAME = "projects/my-project/locations/us-central1/models/12345" @@ -261,18 +258,11 @@ def test_run_call_pipeline_service_create( if not sync: model_from_job.wait() - true_fraction_split = gca_training_pipeline.FractionSplit( - training_fraction=_TEST_TRAINING_FRACTION_SPLIT, - validation_fraction=_TEST_VALIDATION_FRACTION_SPLIT, - test_fraction=_TEST_TEST_FRACTION_SPLIT, - ) - true_managed_model = gca_model.Model( display_name=_TEST_MODEL_DISPLAY_NAME, labels=_TEST_MODEL_LABELS ) true_input_data_config = gca_training_pipeline.InputDataConfig( - fraction_split=true_fraction_split, predefined_split=gca_training_pipeline.PredefinedSplit( key=_TEST_PREDEFINED_SPLIT_COLUMN_NAME ), @@ -348,19 +338,12 @@ def test_run_call_pipeline_if_no_model_display_name_nor_model_labels( if not sync: model_from_job.wait() - true_fraction_split = gca_training_pipeline.FractionSplit( - training_fraction=_TEST_TRAINING_FRACTION_SPLIT, - validation_fraction=_TEST_VALIDATION_FRACTION_SPLIT, - test_fraction=_TEST_TEST_FRACTION_SPLIT, - ) - # Test that if defaults to the job display name true_managed_model = gca_model.Model( display_name=_TEST_DISPLAY_NAME, labels=_TEST_LABELS, ) true_input_data_config = gca_training_pipeline.InputDataConfig( - fraction_split=true_fraction_split, dataset_id=mock_dataset_time_series.name, ) @@ -422,17 +405,10 @@ def test_run_call_pipeline_if_set_additional_experiments( if not sync: model_from_job.wait() - true_fraction_split = gca_training_pipeline.FractionSplit( - training_fraction=_TEST_TRAINING_FRACTION_SPLIT, - validation_fraction=_TEST_VALIDATION_FRACTION_SPLIT, - test_fraction=_TEST_TEST_FRACTION_SPLIT, - ) - # Test that if defaults to the job display name true_managed_model = gca_model.Model(display_name=_TEST_DISPLAY_NAME) true_input_data_config = gca_training_pipeline.InputDataConfig( - fraction_split=true_fraction_split, dataset_id=mock_dataset_time_series.name, ) diff --git a/tests/unit/aiplatform/test_automl_image_training_jobs.py b/tests/unit/aiplatform/test_automl_image_training_jobs.py index a46f960b1c..7f092f12d1 100644 --- a/tests/unit/aiplatform/test_automl_image_training_jobs.py +++ b/tests/unit/aiplatform/test_automl_image_training_jobs.py @@ -74,6 +74,10 @@ _TEST_FRACTION_SPLIT_VALIDATION = 0.2 _TEST_FRACTION_SPLIT_TEST = 0.2 +_TEST_FILTER_SPLIT_TRAINING = "train" +_TEST_FILTER_SPLIT_VALIDATION = "validate" +_TEST_FILTER_SPLIT_TEST = "test" + _TEST_MODEL_NAME = ( f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/models/{_TEST_MODEL_ID}" ) @@ -159,6 +163,7 @@ def mock_model_service_get(): def mock_dataset_image(): ds = mock.MagicMock(datasets.ImageDataset) ds.name = _TEST_DATASET_NAME + ds.metadata_schema_uri = _TEST_METADATA_SCHEMA_URI_IMAGE ds._latest_future = None ds._exception = None ds._gca_resource = gca_dataset.Dataset( @@ -172,7 +177,7 @@ def mock_dataset_image(): @pytest.fixture -def mock_model_image(): +def mock_model(): model = mock.MagicMock(models.Model) model.name = _TEST_MODEL_ID model._latest_future = None @@ -193,7 +198,7 @@ def setup_method(self): def teardown_method(self): initializer.global_pool.shutdown(wait=True) - def test_init_all_parameters(self, mock_model_image): + def test_init_all_parameters(self, mock_model): """Ensure all private members are set correctly at initialization""" aiplatform.init(project=_TEST_PROJECT) @@ -202,7 +207,7 @@ def test_init_all_parameters(self, mock_model_image): display_name=_TEST_DISPLAY_NAME, prediction_type=_TEST_PREDICTION_TYPE_ICN, model_type=_TEST_MODEL_TYPE_MOBILE, - base_model=mock_model_image, + base_model=mock_model, multi_label=True, ) @@ -210,9 +215,9 @@ def test_init_all_parameters(self, mock_model_image): assert job._model_type == _TEST_MODEL_TYPE_MOBILE assert job._prediction_type == _TEST_PREDICTION_TYPE_ICN assert job._multi_label is True - assert job._base_model == mock_model_image + assert job._base_model == mock_model - def test_init_wrong_parameters(self, mock_model_image): + def test_init_wrong_parameters(self, mock_model): """Ensure correct exceptions are raised when initializing with invalid args""" aiplatform.init(project=_TEST_PROJECT) @@ -233,7 +238,7 @@ def test_init_wrong_parameters(self, mock_model_image): training_jobs.AutoMLImageTrainingJob( display_name=_TEST_DISPLAY_NAME, prediction_type=_TEST_PREDICTION_TYPE_IOD, - base_model=mock_model_image, + base_model=mock_model, ) @pytest.mark.parametrize("sync", [True, False]) @@ -243,7 +248,7 @@ def test_run_call_pipeline_service_create( mock_pipeline_service_get, mock_dataset_image, mock_model_service_get, - mock_model_image, + mock_model, sync, ): """Create and run an AutoML ICN training job, verify calls and return value""" @@ -254,18 +259,16 @@ def test_run_call_pipeline_service_create( ) job = training_jobs.AutoMLImageTrainingJob( - display_name=_TEST_DISPLAY_NAME, - base_model=mock_model_image, - labels=_TEST_LABELS, + display_name=_TEST_DISPLAY_NAME, base_model=mock_model, labels=_TEST_LABELS, ) model_from_job = job.run( dataset=mock_dataset_image, model_display_name=_TEST_MODEL_DISPLAY_NAME, model_labels=_TEST_MODEL_LABELS, - training_fraction_split=_TEST_FRACTION_SPLIT_TRAINING, - validation_fraction_split=_TEST_FRACTION_SPLIT_VALIDATION, - test_fraction_split=_TEST_FRACTION_SPLIT_TEST, + training_filter_split=_TEST_FILTER_SPLIT_TRAINING, + validation_filter_split=_TEST_FILTER_SPLIT_VALIDATION, + test_filter_split=_TEST_FILTER_SPLIT_TEST, budget_milli_node_hours=_TEST_TRAINING_BUDGET_MILLI_NODE_HOURS, disable_early_stopping=_TEST_TRAINING_DISABLE_EARLY_STOPPING, sync=sync, @@ -274,21 +277,21 @@ def test_run_call_pipeline_service_create( if not sync: model_from_job.wait() - true_fraction_split = gca_training_pipeline.FractionSplit( - training_fraction=_TEST_FRACTION_SPLIT_TRAINING, - validation_fraction=_TEST_FRACTION_SPLIT_VALIDATION, - test_fraction=_TEST_FRACTION_SPLIT_TEST, + true_filter_split = gca_training_pipeline.FilterSplit( + training_filter=_TEST_FILTER_SPLIT_TRAINING, + validation_filter=_TEST_FILTER_SPLIT_VALIDATION, + test_filter=_TEST_FILTER_SPLIT_TEST, ) true_managed_model = gca_model.Model( display_name=_TEST_MODEL_DISPLAY_NAME, - labels=mock_model_image._gca_resource.labels, - description=mock_model_image._gca_resource.description, + labels=mock_model._gca_resource.labels, + description=mock_model._gca_resource.description, encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC, ) true_input_data_config = gca_training_pipeline.InputDataConfig( - fraction_split=true_fraction_split, dataset_id=mock_dataset_image.name, + filter_split=true_filter_split, dataset_id=mock_dataset_image.name, ) true_training_pipeline = gca_training_pipeline.TrainingPipeline( @@ -333,9 +336,6 @@ def test_run_call_pipeline_if_no_model_display_name_nor_model_labels( model_from_job = job.run( dataset=mock_dataset_image, - training_fraction_split=_TEST_FRACTION_SPLIT_TRAINING, - validation_fraction_split=_TEST_FRACTION_SPLIT_VALIDATION, - test_fraction_split=_TEST_FRACTION_SPLIT_TEST, budget_milli_node_hours=_TEST_TRAINING_BUDGET_MILLI_NODE_HOURS, disable_early_stopping=_TEST_TRAINING_DISABLE_EARLY_STOPPING, ) @@ -343,12 +343,6 @@ def test_run_call_pipeline_if_no_model_display_name_nor_model_labels( if not sync: model_from_job.wait() - true_fraction_split = gca_training_pipeline.FractionSplit( - training_fraction=_TEST_FRACTION_SPLIT_TRAINING, - validation_fraction=_TEST_FRACTION_SPLIT_VALIDATION, - test_fraction=_TEST_FRACTION_SPLIT_TEST, - ) - # Test that if defaults to the job display name true_managed_model = gca_model.Model( display_name=_TEST_DISPLAY_NAME, @@ -357,7 +351,7 @@ def test_run_call_pipeline_if_no_model_display_name_nor_model_labels( ) true_input_data_config = gca_training_pipeline.InputDataConfig( - fraction_split=true_fraction_split, dataset_id=mock_dataset_image.name, + dataset_id=mock_dataset_image.name ) true_training_pipeline = gca_training_pipeline.TrainingPipeline( @@ -398,13 +392,38 @@ def test_run_called_twice_raises(self, mock_dataset_image, sync): with pytest.raises(RuntimeError): job.run( + dataset=mock_dataset_image, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + sync=sync, + ) + + @pytest.mark.usefixtures( + "mock_pipeline_service_create", + "mock_pipeline_service_get", + "mock_model_service_get", + ) + @pytest.mark.parametrize("sync", [True, False]) + def test_run_with_two_split_raises( + self, mock_dataset_image, sync, + ): + aiplatform.init(project=_TEST_PROJECT) + + job = training_jobs.AutoMLImageTrainingJob(display_name=_TEST_DISPLAY_NAME,) + + with pytest.raises(ValueError): + model_from_job = job.run( dataset=mock_dataset_image, model_display_name=_TEST_MODEL_DISPLAY_NAME, training_fraction_split=_TEST_FRACTION_SPLIT_TRAINING, validation_fraction_split=_TEST_FRACTION_SPLIT_VALIDATION, test_fraction_split=_TEST_FRACTION_SPLIT_TEST, + training_filter_split=_TEST_FILTER_SPLIT_TRAINING, + validation_filter_split=_TEST_FILTER_SPLIT_VALIDATION, + test_filter_split=_TEST_FILTER_SPLIT_TEST, sync=sync, ) + if not sync: + model_from_job.wait() @pytest.mark.parametrize("sync", [True, False]) def test_run_raises_if_pipeline_fails( @@ -444,3 +463,226 @@ def test_raises_before_run_is_called(self, mock_pipeline_service_create): with pytest.raises(RuntimeError): job.state + + @pytest.mark.parametrize("sync", [True, False]) + def test_splits_fraction( + self, + mock_pipeline_service_create, + mock_pipeline_service_get, + mock_dataset_image, + mock_model_service_get, + mock_model, + sync, + ): + """ + Initiate aiplatform with encryption key name. + Create and run an AutoML Video Classification training job, verify calls and return value + """ + + aiplatform.init( + project=_TEST_PROJECT, + encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME, + ) + job = training_jobs.AutoMLImageTrainingJob( + display_name=_TEST_DISPLAY_NAME, base_model=mock_model + ) + + model_from_job = job.run( + dataset=mock_dataset_image, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + training_fraction_split=_TEST_FRACTION_SPLIT_TRAINING, + validation_fraction_split=_TEST_FRACTION_SPLIT_VALIDATION, + test_fraction_split=_TEST_FRACTION_SPLIT_TEST, + disable_early_stopping=_TEST_TRAINING_DISABLE_EARLY_STOPPING, + sync=sync, + ) + + if not sync: + model_from_job.wait() + + true_fraction_split = gca_training_pipeline.FractionSplit( + training_fraction=_TEST_FRACTION_SPLIT_TRAINING, + validation_fraction=_TEST_FRACTION_SPLIT_VALIDATION, + test_fraction=_TEST_FRACTION_SPLIT_TEST, + ) + + true_managed_model = gca_model.Model( + display_name=_TEST_MODEL_DISPLAY_NAME, + description=mock_model._gca_resource.description, + encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC, + ) + + true_input_data_config = gca_training_pipeline.InputDataConfig( + fraction_split=true_fraction_split, dataset_id=mock_dataset_image.name, + ) + + true_training_pipeline = gca_training_pipeline.TrainingPipeline( + display_name=_TEST_DISPLAY_NAME, + training_task_definition=schema.training_job.definition.automl_image_classification, + training_task_inputs=_TEST_TRAINING_TASK_INPUTS_WITH_BASE_MODEL, + model_to_upload=true_managed_model, + input_data_config=true_input_data_config, + encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC, + ) + + mock_pipeline_service_create.assert_called_once_with( + parent=initializer.global_config.common_location_path(), + training_pipeline=true_training_pipeline, + ) + + @pytest.mark.parametrize("sync", [True, False]) + def test_splits_filter( + self, + mock_pipeline_service_create, + mock_pipeline_service_get, + mock_dataset_image, + mock_model_service_get, + mock_model, + sync, + ): + """ + Initiate aiplatform with encryption key name. + Create and run an AutoML Video Classification training job, verify calls and return value + """ + + aiplatform.init( + project=_TEST_PROJECT, + encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME, + ) + + job = training_jobs.AutoMLImageTrainingJob( + display_name=_TEST_DISPLAY_NAME, base_model=mock_model + ) + + model_from_job = job.run( + dataset=mock_dataset_image, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + training_filter_split=_TEST_FILTER_SPLIT_TRAINING, + validation_filter_split=_TEST_FILTER_SPLIT_VALIDATION, + test_filter_split=_TEST_FILTER_SPLIT_TEST, + disable_early_stopping=_TEST_TRAINING_DISABLE_EARLY_STOPPING, + sync=sync, + ) + + if not sync: + model_from_job.wait() + + true_filter_split = gca_training_pipeline.FilterSplit( + training_filter=_TEST_FILTER_SPLIT_TRAINING, + validation_filter=_TEST_FILTER_SPLIT_VALIDATION, + test_filter=_TEST_FILTER_SPLIT_TEST, + ) + + true_managed_model = gca_model.Model( + display_name=_TEST_MODEL_DISPLAY_NAME, + description=mock_model._gca_resource.description, + encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC, + ) + + true_input_data_config = gca_training_pipeline.InputDataConfig( + filter_split=true_filter_split, dataset_id=mock_dataset_image.name, + ) + + true_training_pipeline = gca_training_pipeline.TrainingPipeline( + display_name=_TEST_DISPLAY_NAME, + training_task_definition=schema.training_job.definition.automl_image_classification, + training_task_inputs=_TEST_TRAINING_TASK_INPUTS_WITH_BASE_MODEL, + model_to_upload=true_managed_model, + input_data_config=true_input_data_config, + encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC, + ) + + mock_pipeline_service_create.assert_called_once_with( + parent=initializer.global_config.common_location_path(), + training_pipeline=true_training_pipeline, + ) + + @pytest.mark.parametrize("sync", [True, False]) + def test_splits_default( + self, + mock_pipeline_service_create, + mock_pipeline_service_get, + mock_dataset_image, + mock_model_service_get, + mock_model, + sync, + ): + """ + Initiate aiplatform with encryption key name. + Create and run an AutoML Video Classification training job, verify calls and return value + """ + + aiplatform.init( + project=_TEST_PROJECT, + encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME, + ) + + job = training_jobs.AutoMLImageTrainingJob( + display_name=_TEST_DISPLAY_NAME, base_model=mock_model + ) + + model_from_job = job.run( + dataset=mock_dataset_image, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + disable_early_stopping=_TEST_TRAINING_DISABLE_EARLY_STOPPING, + sync=sync, + ) + + if not sync: + model_from_job.wait() + + true_managed_model = gca_model.Model( + display_name=_TEST_MODEL_DISPLAY_NAME, + description=mock_model._gca_resource.description, + encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC, + ) + + true_input_data_config = gca_training_pipeline.InputDataConfig( + dataset_id=mock_dataset_image.name, + ) + + true_training_pipeline = gca_training_pipeline.TrainingPipeline( + display_name=_TEST_DISPLAY_NAME, + training_task_definition=schema.training_job.definition.automl_image_classification, + training_task_inputs=_TEST_TRAINING_TASK_INPUTS_WITH_BASE_MODEL, + model_to_upload=true_managed_model, + input_data_config=true_input_data_config, + encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC, + ) + + mock_pipeline_service_create.assert_called_once_with( + parent=initializer.global_config.common_location_path(), + training_pipeline=true_training_pipeline, + ) + + def test_splits_filter_incomplete( + self, + mock_pipeline_service_create, + mock_pipeline_service_get, + mock_dataset_image, + mock_model_service_get, + mock_model, + ): + """ + Initiate aiplatform with encryption key name. + Create and run an AutoML Video Classification training job, verify calls and return value + """ + + aiplatform.init( + project=_TEST_PROJECT, + encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME, + ) + + job = training_jobs.AutoMLImageTrainingJob( + display_name=_TEST_DISPLAY_NAME, base_model=mock_model + ) + + with pytest.raises(ValueError): + job.run( + dataset=mock_dataset_image, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + training_filter_split=_TEST_FILTER_SPLIT_TRAINING, + validation_fraction_split=None, + test_filter_split=_TEST_FILTER_SPLIT_TEST, + disable_early_stopping=_TEST_TRAINING_DISABLE_EARLY_STOPPING, + ) diff --git a/tests/unit/aiplatform/test_automl_tabular_training_jobs.py b/tests/unit/aiplatform/test_automl_tabular_training_jobs.py index 2c380206e4..41614b738f 100644 --- a/tests/unit/aiplatform/test_automl_tabular_training_jobs.py +++ b/tests/unit/aiplatform/test_automl_tabular_training_jobs.py @@ -140,10 +140,12 @@ _TEST_LABELS = {"key": "value"} _TEST_MODEL_LABELS = {"model_key": "model_value"} -_TEST_TRAINING_FRACTION_SPLIT = 0.6 -_TEST_VALIDATION_FRACTION_SPLIT = 0.2 -_TEST_TEST_FRACTION_SPLIT = 0.2 -_TEST_PREDEFINED_SPLIT_COLUMN_NAME = "split" +_TEST_FRACTION_SPLIT_TRAINING = 0.6 +_TEST_FRACTION_SPLIT_VALIDATION = 0.2 +_TEST_FRACTION_SPLIT_TEST = 0.2 + +_TEST_SPLIT_PREDEFINED_COLUMN_NAME = "split" +_TEST_SPLIT_TIMESTAMP_COLUMN_NAME = "timestamp" _TEST_OUTPUT_PYTHON_PACKAGE_PATH = "gs://test/ouput/python/trainer.tar.gz" @@ -325,10 +327,6 @@ def test_run_call_pipeline_service_create( target_column=_TEST_TRAINING_TARGET_COLUMN, model_display_name=_TEST_MODEL_DISPLAY_NAME, model_labels=_TEST_MODEL_LABELS, - training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, - validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, - test_fraction_split=_TEST_TEST_FRACTION_SPLIT, - predefined_split_column_name=_TEST_PREDEFINED_SPLIT_COLUMN_NAME, weight_column=_TEST_TRAINING_WEIGHT_COLUMN, budget_milli_node_hours=_TEST_TRAINING_BUDGET_MILLI_NODE_HOURS, disable_early_stopping=_TEST_TRAINING_DISABLE_EARLY_STOPPING, @@ -342,12 +340,6 @@ def test_run_call_pipeline_service_create( if not sync: model_from_job.wait() - true_fraction_split = gca_training_pipeline.FractionSplit( - training_fraction=_TEST_TRAINING_FRACTION_SPLIT, - validation_fraction=_TEST_VALIDATION_FRACTION_SPLIT, - test_fraction=_TEST_TEST_FRACTION_SPLIT, - ) - true_managed_model = gca_model.Model( display_name=_TEST_MODEL_DISPLAY_NAME, labels=_TEST_MODEL_LABELS, @@ -355,10 +347,6 @@ def test_run_call_pipeline_service_create( ) true_input_data_config = gca_training_pipeline.InputDataConfig( - fraction_split=true_fraction_split, - predefined_split=gca_training_pipeline.PredefinedSplit( - key=_TEST_PREDEFINED_SPLIT_COLUMN_NAME - ), dataset_id=mock_dataset_tabular.name, ) @@ -417,10 +405,6 @@ def test_run_call_pipeline_service_create_with_export_eval_data_items( dataset=mock_dataset_tabular, target_column=_TEST_TRAINING_TARGET_COLUMN, model_display_name=_TEST_MODEL_DISPLAY_NAME, - training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, - validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, - test_fraction_split=_TEST_TEST_FRACTION_SPLIT, - predefined_split_column_name=_TEST_PREDEFINED_SPLIT_COLUMN_NAME, weight_column=_TEST_TRAINING_WEIGHT_COLUMN, budget_milli_node_hours=_TEST_TRAINING_BUDGET_MILLI_NODE_HOURS, disable_early_stopping=_TEST_TRAINING_DISABLE_EARLY_STOPPING, @@ -437,22 +421,12 @@ def test_run_call_pipeline_service_create_with_export_eval_data_items( if not sync: model_from_job.wait() - true_fraction_split = gca_training_pipeline.FractionSplit( - training_fraction=_TEST_TRAINING_FRACTION_SPLIT, - validation_fraction=_TEST_VALIDATION_FRACTION_SPLIT, - test_fraction=_TEST_TEST_FRACTION_SPLIT, - ) - true_managed_model = gca_model.Model( display_name=_TEST_MODEL_DISPLAY_NAME, encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC, ) true_input_data_config = gca_training_pipeline.InputDataConfig( - fraction_split=true_fraction_split, - predefined_split=gca_training_pipeline.PredefinedSplit( - key=_TEST_PREDEFINED_SPLIT_COLUMN_NAME - ), dataset_id=mock_dataset_tabular.name, ) @@ -508,9 +482,6 @@ def test_run_call_pipeline_if_no_model_display_name_nor_model_labels( model_from_job = job.run( dataset=mock_dataset_tabular, target_column=_TEST_TRAINING_TARGET_COLUMN, - training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, - validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, - test_fraction_split=_TEST_TEST_FRACTION_SPLIT, weight_column=_TEST_TRAINING_WEIGHT_COLUMN, budget_milli_node_hours=_TEST_TRAINING_BUDGET_MILLI_NODE_HOURS, disable_early_stopping=_TEST_TRAINING_DISABLE_EARLY_STOPPING, @@ -523,12 +494,6 @@ def test_run_call_pipeline_if_no_model_display_name_nor_model_labels( if not sync: model_from_job.wait() - true_fraction_split = gca_training_pipeline.FractionSplit( - training_fraction=_TEST_TRAINING_FRACTION_SPLIT, - validation_fraction=_TEST_VALIDATION_FRACTION_SPLIT, - test_fraction=_TEST_TEST_FRACTION_SPLIT, - ) - # Test that if defaults to the job display name true_managed_model = gca_model.Model( display_name=_TEST_DISPLAY_NAME, @@ -537,7 +502,7 @@ def test_run_call_pipeline_if_no_model_display_name_nor_model_labels( ) true_input_data_config = gca_training_pipeline.InputDataConfig( - fraction_split=true_fraction_split, dataset_id=mock_dataset_tabular.name, + dataset_id=mock_dataset_tabular.name, ) true_training_pipeline = gca_training_pipeline.TrainingPipeline( @@ -584,10 +549,6 @@ def test_run_call_pipeline_service_create_if_no_column_transformations( dataset=mock_dataset_tabular, target_column=_TEST_TRAINING_TARGET_COLUMN, model_display_name=_TEST_MODEL_DISPLAY_NAME, - training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, - validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, - test_fraction_split=_TEST_TEST_FRACTION_SPLIT, - predefined_split_column_name=_TEST_PREDEFINED_SPLIT_COLUMN_NAME, weight_column=_TEST_TRAINING_WEIGHT_COLUMN, budget_milli_node_hours=_TEST_TRAINING_BUDGET_MILLI_NODE_HOURS, disable_early_stopping=_TEST_TRAINING_DISABLE_EARLY_STOPPING, @@ -601,22 +562,12 @@ def test_run_call_pipeline_service_create_if_no_column_transformations( if not sync: model_from_job.wait() - true_fraction_split = gca_training_pipeline.FractionSplit( - training_fraction=_TEST_TRAINING_FRACTION_SPLIT, - validation_fraction=_TEST_VALIDATION_FRACTION_SPLIT, - test_fraction=_TEST_TEST_FRACTION_SPLIT, - ) - true_managed_model = gca_model.Model( display_name=_TEST_MODEL_DISPLAY_NAME, encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC, ) true_input_data_config = gca_training_pipeline.InputDataConfig( - fraction_split=true_fraction_split, - predefined_split=gca_training_pipeline.PredefinedSplit( - key=_TEST_PREDEFINED_SPLIT_COLUMN_NAME - ), dataset_id=mock_dataset_tabular.name, ) @@ -665,10 +616,6 @@ def test_run_call_pipeline_service_create_if_set_additional_experiments( dataset=mock_dataset_tabular, target_column=_TEST_TRAINING_TARGET_COLUMN, model_display_name=_TEST_MODEL_DISPLAY_NAME, - training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, - validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, - test_fraction_split=_TEST_TEST_FRACTION_SPLIT, - predefined_split_column_name=_TEST_PREDEFINED_SPLIT_COLUMN_NAME, weight_column=_TEST_TRAINING_WEIGHT_COLUMN, budget_milli_node_hours=_TEST_TRAINING_BUDGET_MILLI_NODE_HOURS, disable_early_stopping=_TEST_TRAINING_DISABLE_EARLY_STOPPING, @@ -682,22 +629,12 @@ def test_run_call_pipeline_service_create_if_set_additional_experiments( if not sync: model_from_job.wait() - true_fraction_split = gca_training_pipeline.FractionSplit( - training_fraction=_TEST_TRAINING_FRACTION_SPLIT, - validation_fraction=_TEST_VALIDATION_FRACTION_SPLIT, - test_fraction=_TEST_TEST_FRACTION_SPLIT, - ) - true_managed_model = gca_model.Model( display_name=_TEST_MODEL_DISPLAY_NAME, encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC, ) true_input_data_config = gca_training_pipeline.InputDataConfig( - fraction_split=true_fraction_split, - predefined_split=gca_training_pipeline.PredefinedSplit( - key=_TEST_PREDEFINED_SPLIT_COLUMN_NAME - ), dataset_id=mock_dataset_tabular.name, ) @@ -746,10 +683,6 @@ def test_run_call_pipeline_service_create_with_column_specs( dataset=mock_dataset_tabular_alternative, target_column=_TEST_TRAINING_TARGET_COLUMN, model_display_name=_TEST_MODEL_DISPLAY_NAME, - training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, - validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, - test_fraction_split=_TEST_TEST_FRACTION_SPLIT, - predefined_split_column_name=_TEST_PREDEFINED_SPLIT_COLUMN_NAME, weight_column=_TEST_TRAINING_WEIGHT_COLUMN, budget_milli_node_hours=_TEST_TRAINING_BUDGET_MILLI_NODE_HOURS, disable_early_stopping=_TEST_TRAINING_DISABLE_EARLY_STOPPING, @@ -759,19 +692,9 @@ def test_run_call_pipeline_service_create_with_column_specs( if not sync: model_from_job.wait() - true_fraction_split = gca_training_pipeline.FractionSplit( - training_fraction=_TEST_TRAINING_FRACTION_SPLIT, - validation_fraction=_TEST_VALIDATION_FRACTION_SPLIT, - test_fraction=_TEST_TEST_FRACTION_SPLIT, - ) - true_managed_model = gca_model.Model(display_name=_TEST_MODEL_DISPLAY_NAME) true_input_data_config = gca_training_pipeline.InputDataConfig( - fraction_split=true_fraction_split, - predefined_split=gca_training_pipeline.PredefinedSplit( - key=_TEST_PREDEFINED_SPLIT_COLUMN_NAME - ), dataset_id=mock_dataset_tabular_alternative.name, ) @@ -858,10 +781,6 @@ def test_run_call_pipeline_service_create_with_column_specs_not_auto( dataset=mock_dataset_tabular_alternative, target_column=_TEST_TRAINING_TARGET_COLUMN, model_display_name=_TEST_MODEL_DISPLAY_NAME, - training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, - validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, - test_fraction_split=_TEST_TEST_FRACTION_SPLIT, - predefined_split_column_name=_TEST_PREDEFINED_SPLIT_COLUMN_NAME, weight_column=_TEST_TRAINING_WEIGHT_COLUMN, budget_milli_node_hours=_TEST_TRAINING_BUDGET_MILLI_NODE_HOURS, disable_early_stopping=_TEST_TRAINING_DISABLE_EARLY_STOPPING, @@ -871,19 +790,9 @@ def test_run_call_pipeline_service_create_with_column_specs_not_auto( if not sync: model_from_job.wait() - true_fraction_split = gca_training_pipeline.FractionSplit( - training_fraction=_TEST_TRAINING_FRACTION_SPLIT, - validation_fraction=_TEST_VALIDATION_FRACTION_SPLIT, - test_fraction=_TEST_TEST_FRACTION_SPLIT, - ) - true_managed_model = gca_model.Model(display_name=_TEST_MODEL_DISPLAY_NAME) true_input_data_config = gca_training_pipeline.InputDataConfig( - fraction_split=true_fraction_split, - predefined_split=gca_training_pipeline.PredefinedSplit( - key=_TEST_PREDEFINED_SPLIT_COLUMN_NAME - ), dataset_id=mock_dataset_tabular_alternative.name, ) @@ -923,9 +832,6 @@ def test_run_called_twice_raises(self, mock_dataset_tabular, sync): dataset=mock_dataset_tabular, target_column=_TEST_TRAINING_TARGET_COLUMN, model_display_name=_TEST_MODEL_DISPLAY_NAME, - training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, - validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, - test_fraction_split=_TEST_TEST_FRACTION_SPLIT, sync=sync, ) @@ -938,9 +844,6 @@ def test_run_called_twice_raises(self, mock_dataset_tabular, sync): dataset=mock_dataset_tabular, target_column=_TEST_TRAINING_TARGET_COLUMN, model_display_name=_TEST_MODEL_DISPLAY_NAME, - training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, - validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, - test_fraction_split=_TEST_TEST_FRACTION_SPLIT, sync=sync, ) @@ -965,9 +868,6 @@ def test_run_raises_if_pipeline_fails( model_display_name=_TEST_MODEL_DISPLAY_NAME, dataset=mock_dataset_tabular, target_column=_TEST_TRAINING_TARGET_COLUMN, - training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, - validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, - test_fraction_split=_TEST_TEST_FRACTION_SPLIT, sync=sync, ) @@ -996,9 +896,6 @@ def test_wait_for_resource_creation_does_not_fail_if_creation_does_not_fail( model_display_name=_TEST_MODEL_DISPLAY_NAME, dataset=mock_dataset_tabular, target_column=_TEST_TRAINING_TARGET_COLUMN, - training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, - validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, - test_fraction_split=_TEST_TEST_FRACTION_SPLIT, sync=False, ) @@ -1033,9 +930,6 @@ def test_create_fails(self, mock_dataset_tabular, sync): model_display_name=_TEST_MODEL_DISPLAY_NAME, dataset=mock_dataset_tabular, target_column=_TEST_TRAINING_TARGET_COLUMN, - training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, - validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, - test_fraction_split=_TEST_TEST_FRACTION_SPLIT, sync=sync, ) assert e.match("Mock fail") @@ -1065,9 +959,6 @@ def test_create_fails(self, mock_dataset_tabular, sync): model_display_name=_TEST_MODEL_DISPLAY_NAME, dataset=mock_dataset_tabular, target_column=_TEST_TRAINING_TARGET_COLUMN, - training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, - validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, - test_fraction_split=_TEST_TEST_FRACTION_SPLIT, sync=sync, ) @@ -1163,3 +1054,280 @@ def test_properties_throw_if_not_available(self): assert e.match( regexp=r"AutoMLTabularTrainingJob resource has not been created" ) + + @pytest.mark.parametrize("sync", [True, False]) + def test_splits_fraction( + self, + mock_pipeline_service_create, + mock_pipeline_service_get, + mock_dataset_tabular, + mock_model_service_get, + sync, + ): + """ + Initiate aiplatform with encryption key name. + Create and run an AutoML Video Classification training job, verify calls and return value + """ + + aiplatform.init( + project=_TEST_PROJECT, + encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME, + ) + + job = training_jobs.AutoMLTabularTrainingJob( + display_name=_TEST_DISPLAY_NAME, + optimization_prediction_type=_TEST_TRAINING_OPTIMIZATION_PREDICTION_TYPE, + optimization_objective=_TEST_TRAINING_OPTIMIZATION_OBJECTIVE_NAME, + column_transformations=_TEST_TRAINING_COLUMN_TRANSFORMATIONS, + optimization_objective_recall_value=None, + optimization_objective_precision_value=None, + ) + + model_from_job = job.run( + dataset=mock_dataset_tabular, + target_column=_TEST_TRAINING_TARGET_COLUMN, + weight_column=_TEST_TRAINING_WEIGHT_COLUMN, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + training_fraction_split=_TEST_FRACTION_SPLIT_TRAINING, + validation_fraction_split=_TEST_FRACTION_SPLIT_VALIDATION, + test_fraction_split=_TEST_FRACTION_SPLIT_TEST, + disable_early_stopping=_TEST_TRAINING_DISABLE_EARLY_STOPPING, + sync=sync, + ) + + if not sync: + model_from_job.wait() + + true_fraction_split = gca_training_pipeline.FractionSplit( + training_fraction=_TEST_FRACTION_SPLIT_TRAINING, + validation_fraction=_TEST_FRACTION_SPLIT_VALIDATION, + test_fraction=_TEST_FRACTION_SPLIT_TEST, + ) + + true_managed_model = gca_model.Model( + display_name=_TEST_MODEL_DISPLAY_NAME, + encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC, + ) + + true_input_data_config = gca_training_pipeline.InputDataConfig( + fraction_split=true_fraction_split, dataset_id=mock_dataset_tabular.name, + ) + + true_training_pipeline = gca_training_pipeline.TrainingPipeline( + display_name=_TEST_DISPLAY_NAME, + training_task_definition=schema.training_job.definition.automl_tabular, + training_task_inputs=_TEST_TRAINING_TASK_INPUTS, + model_to_upload=true_managed_model, + input_data_config=true_input_data_config, + encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC, + ) + + mock_pipeline_service_create.assert_called_once_with( + parent=initializer.global_config.common_location_path(), + training_pipeline=true_training_pipeline, + ) + + @pytest.mark.parametrize("sync", [True, False]) + def test_splits_timestamp( + self, + mock_pipeline_service_create, + mock_pipeline_service_get, + mock_dataset_tabular, + mock_model_service_get, + sync, + ): + """ + Initiate aiplatform with encryption key name. + Create and run an AutoML Video Classification training job, verify calls and return value + """ + + aiplatform.init( + project=_TEST_PROJECT, + encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME, + ) + + job = training_jobs.AutoMLTabularTrainingJob( + display_name=_TEST_DISPLAY_NAME, + optimization_prediction_type=_TEST_TRAINING_OPTIMIZATION_PREDICTION_TYPE, + optimization_objective=_TEST_TRAINING_OPTIMIZATION_OBJECTIVE_NAME, + column_transformations=_TEST_TRAINING_COLUMN_TRANSFORMATIONS, + optimization_objective_recall_value=None, + optimization_objective_precision_value=None, + ) + + model_from_job = job.run( + dataset=mock_dataset_tabular, + target_column=_TEST_TRAINING_TARGET_COLUMN, + weight_column=_TEST_TRAINING_WEIGHT_COLUMN, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + training_fraction_split=_TEST_FRACTION_SPLIT_TRAINING, + validation_fraction_split=_TEST_FRACTION_SPLIT_VALIDATION, + test_fraction_split=_TEST_FRACTION_SPLIT_TEST, + timestamp_split_column_name=_TEST_SPLIT_TIMESTAMP_COLUMN_NAME, + disable_early_stopping=_TEST_TRAINING_DISABLE_EARLY_STOPPING, + sync=sync, + ) + + if not sync: + model_from_job.wait() + + true_split = gca_training_pipeline.TimestampSplit( + training_fraction=_TEST_FRACTION_SPLIT_TRAINING, + validation_fraction=_TEST_FRACTION_SPLIT_VALIDATION, + test_fraction=_TEST_FRACTION_SPLIT_TEST, + key=_TEST_SPLIT_TIMESTAMP_COLUMN_NAME, + ) + + true_managed_model = gca_model.Model( + display_name=_TEST_MODEL_DISPLAY_NAME, + encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC, + ) + + true_input_data_config = gca_training_pipeline.InputDataConfig( + timestamp_split=true_split, dataset_id=mock_dataset_tabular.name, + ) + + true_training_pipeline = gca_training_pipeline.TrainingPipeline( + display_name=_TEST_DISPLAY_NAME, + training_task_definition=schema.training_job.definition.automl_tabular, + training_task_inputs=_TEST_TRAINING_TASK_INPUTS, + model_to_upload=true_managed_model, + input_data_config=true_input_data_config, + encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC, + ) + + mock_pipeline_service_create.assert_called_once_with( + parent=initializer.global_config.common_location_path(), + training_pipeline=true_training_pipeline, + ) + + @pytest.mark.parametrize("sync", [True, False]) + def test_splits_predefined( + self, + mock_pipeline_service_create, + mock_pipeline_service_get, + mock_dataset_tabular, + mock_model_service_get, + sync, + ): + """ + Initiate aiplatform with encryption key name. + Create and run an AutoML Video Classification training job, verify calls and return value + """ + + aiplatform.init( + project=_TEST_PROJECT, + encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME, + ) + + job = training_jobs.AutoMLTabularTrainingJob( + display_name=_TEST_DISPLAY_NAME, + optimization_prediction_type=_TEST_TRAINING_OPTIMIZATION_PREDICTION_TYPE, + optimization_objective=_TEST_TRAINING_OPTIMIZATION_OBJECTIVE_NAME, + column_transformations=_TEST_TRAINING_COLUMN_TRANSFORMATIONS, + optimization_objective_recall_value=None, + optimization_objective_precision_value=None, + ) + + model_from_job = job.run( + dataset=mock_dataset_tabular, + target_column=_TEST_TRAINING_TARGET_COLUMN, + weight_column=_TEST_TRAINING_WEIGHT_COLUMN, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + predefined_split_column_name=_TEST_SPLIT_PREDEFINED_COLUMN_NAME, + disable_early_stopping=_TEST_TRAINING_DISABLE_EARLY_STOPPING, + sync=sync, + ) + + if not sync: + model_from_job.wait() + + true_split = gca_training_pipeline.PredefinedSplit( + key=_TEST_SPLIT_PREDEFINED_COLUMN_NAME + ) + + true_managed_model = gca_model.Model( + display_name=_TEST_MODEL_DISPLAY_NAME, + encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC, + ) + + true_input_data_config = gca_training_pipeline.InputDataConfig( + predefined_split=true_split, dataset_id=mock_dataset_tabular.name, + ) + + true_training_pipeline = gca_training_pipeline.TrainingPipeline( + display_name=_TEST_DISPLAY_NAME, + training_task_definition=schema.training_job.definition.automl_tabular, + training_task_inputs=_TEST_TRAINING_TASK_INPUTS, + model_to_upload=true_managed_model, + input_data_config=true_input_data_config, + encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC, + ) + + mock_pipeline_service_create.assert_called_once_with( + parent=initializer.global_config.common_location_path(), + training_pipeline=true_training_pipeline, + ) + + @pytest.mark.parametrize("sync", [True, False]) + def test_splits_default( + self, + mock_pipeline_service_create, + mock_pipeline_service_get, + mock_dataset_tabular, + mock_model_service_get, + sync, + ): + """ + Initiate aiplatform with encryption key name. + Create and run an AutoML Video Classification training job, verify calls and return value + """ + + aiplatform.init( + project=_TEST_PROJECT, + encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME, + ) + + job = training_jobs.AutoMLTabularTrainingJob( + display_name=_TEST_DISPLAY_NAME, + optimization_prediction_type=_TEST_TRAINING_OPTIMIZATION_PREDICTION_TYPE, + optimization_objective=_TEST_TRAINING_OPTIMIZATION_OBJECTIVE_NAME, + column_transformations=_TEST_TRAINING_COLUMN_TRANSFORMATIONS, + optimization_objective_recall_value=None, + optimization_objective_precision_value=None, + ) + + model_from_job = job.run( + dataset=mock_dataset_tabular, + target_column=_TEST_TRAINING_TARGET_COLUMN, + weight_column=_TEST_TRAINING_WEIGHT_COLUMN, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + disable_early_stopping=_TEST_TRAINING_DISABLE_EARLY_STOPPING, + sync=sync, + ) + + if not sync: + model_from_job.wait() + + true_managed_model = gca_model.Model( + display_name=_TEST_MODEL_DISPLAY_NAME, + encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC, + ) + + true_input_data_config = gca_training_pipeline.InputDataConfig( + dataset_id=mock_dataset_tabular.name, + ) + + true_training_pipeline = gca_training_pipeline.TrainingPipeline( + display_name=_TEST_DISPLAY_NAME, + training_task_definition=schema.training_job.definition.automl_tabular, + training_task_inputs=_TEST_TRAINING_TASK_INPUTS, + model_to_upload=true_managed_model, + input_data_config=true_input_data_config, + encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC, + ) + + mock_pipeline_service_create.assert_called_once_with( + parent=initializer.global_config.common_location_path(), + training_pipeline=true_training_pipeline, + ) diff --git a/tests/unit/aiplatform/test_automl_text_training_jobs.py b/tests/unit/aiplatform/test_automl_text_training_jobs.py index 583789c00e..20220a1247 100644 --- a/tests/unit/aiplatform/test_automl_text_training_jobs.py +++ b/tests/unit/aiplatform/test_automl_text_training_jobs.py @@ -59,6 +59,10 @@ _TEST_FRACTION_SPLIT_TRAINING = 0.6 _TEST_FRACTION_SPLIT_VALIDATION = 0.2 _TEST_FRACTION_SPLIT_TEST = 0.2 +_TEST_FILTER_SPLIT_TRAINING = "train" +_TEST_FILTER_SPLIT_VALIDATION = "validate" +_TEST_FILTER_SPLIT_TEST = "test" +_TEST_PREDEFINED_SPLIT_COLUMN_NAME = "predefined_column" _TEST_MODEL_NAME = ( f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/models/{_TEST_MODEL_ID}" @@ -145,6 +149,7 @@ def mock_model_service_get(): def mock_dataset_text(): ds = mock.MagicMock(datasets.TextDataset) ds.name = _TEST_DATASET_NAME + ds.metadata_schema_uri = _TEST_METADATA_SCHEMA_URI_TEXT ds._latest_future = None ds._exception = None ds._gca_resource = gca_dataset.Dataset( @@ -270,28 +275,19 @@ def test_init_aiplatform_with_encryption_key_name_and_create_training_job( model_from_job = job.run( dataset=mock_dataset_text, model_display_name=_TEST_MODEL_DISPLAY_NAME, - training_fraction_split=_TEST_FRACTION_SPLIT_TRAINING, - validation_fraction_split=_TEST_FRACTION_SPLIT_VALIDATION, - test_fraction_split=_TEST_FRACTION_SPLIT_TEST, sync=sync, ) if not sync: model_from_job.wait() - true_fraction_split = gca_training_pipeline.FractionSplit( - training_fraction=_TEST_FRACTION_SPLIT_TRAINING, - validation_fraction=_TEST_FRACTION_SPLIT_VALIDATION, - test_fraction=_TEST_FRACTION_SPLIT_TEST, - ) - true_managed_model = gca_model.Model( display_name=_TEST_MODEL_DISPLAY_NAME, encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC, ) true_input_data_config = gca_training_pipeline.InputDataConfig( - fraction_split=true_fraction_split, dataset_id=mock_dataset_text.name, + dataset_id=mock_dataset_text.name, ) true_training_pipeline = gca_training_pipeline.TrainingPipeline( @@ -334,19 +330,19 @@ def test_run_call_pipeline_service_create_classification( dataset=mock_dataset_text, model_display_name=_TEST_MODEL_DISPLAY_NAME, model_labels=_TEST_MODEL_LABELS, - training_fraction_split=_TEST_FRACTION_SPLIT_TRAINING, - validation_fraction_split=_TEST_FRACTION_SPLIT_VALIDATION, - test_fraction_split=_TEST_FRACTION_SPLIT_TEST, + training_filter_split=_TEST_FILTER_SPLIT_TRAINING, + validation_filter_split=_TEST_FILTER_SPLIT_VALIDATION, + test_filter_split=_TEST_FILTER_SPLIT_TEST, sync=sync, ) if not sync: model_from_job.wait() - true_fraction_split = gca_training_pipeline.FractionSplit( - training_fraction=_TEST_FRACTION_SPLIT_TRAINING, - validation_fraction=_TEST_FRACTION_SPLIT_VALIDATION, - test_fraction=_TEST_FRACTION_SPLIT_TEST, + true_filter_split = gca_training_pipeline.FilterSplit( + training_filter=_TEST_FILTER_SPLIT_TRAINING, + validation_filter=_TEST_FILTER_SPLIT_VALIDATION, + test_filter=_TEST_FILTER_SPLIT_TEST, ) true_managed_model = gca_model.Model( @@ -356,7 +352,7 @@ def test_run_call_pipeline_service_create_classification( ) true_input_data_config = gca_training_pipeline.InputDataConfig( - fraction_split=true_fraction_split, dataset_id=mock_dataset_text.name, + filter_split=true_filter_split, dataset_id=mock_dataset_text.name, ) true_training_pipeline = gca_training_pipeline.TrainingPipeline( @@ -472,19 +468,19 @@ def test_run_call_pipeline_service_create_sentiment( dataset=mock_dataset_text, model_display_name=_TEST_MODEL_DISPLAY_NAME, model_labels=_TEST_MODEL_LABELS, - training_fraction_split=_TEST_FRACTION_SPLIT_TRAINING, - validation_fraction_split=_TEST_FRACTION_SPLIT_VALIDATION, - test_fraction_split=_TEST_FRACTION_SPLIT_TEST, + training_filter_split=_TEST_FILTER_SPLIT_TRAINING, + validation_filter_split=_TEST_FILTER_SPLIT_VALIDATION, + test_filter_split=_TEST_FILTER_SPLIT_TEST, sync=sync, ) if not sync: model_from_job.wait() - true_fraction_split = gca_training_pipeline.FractionSplit( - training_fraction=_TEST_FRACTION_SPLIT_TRAINING, - validation_fraction=_TEST_FRACTION_SPLIT_VALIDATION, - test_fraction=_TEST_FRACTION_SPLIT_TEST, + true_filter_split = gca_training_pipeline.FilterSplit( + training_filter=_TEST_FILTER_SPLIT_TRAINING, + validation_filter=_TEST_FILTER_SPLIT_VALIDATION, + test_filter=_TEST_FILTER_SPLIT_TEST, ) true_managed_model = gca_model.Model( @@ -492,7 +488,7 @@ def test_run_call_pipeline_service_create_sentiment( ) true_input_data_config = gca_training_pipeline.InputDataConfig( - fraction_split=true_fraction_split, dataset_id=mock_dataset_text.name, + filter_split=true_filter_split, dataset_id=mock_dataset_text.name, ) true_training_pipeline = gca_training_pipeline.TrainingPipeline( @@ -537,9 +533,6 @@ def test_run_call_pipeline_if_no_model_display_name_nor_model_labels( model_from_job = job.run( dataset=mock_dataset_text, - training_fraction_split=_TEST_FRACTION_SPLIT_TRAINING, - validation_fraction_split=_TEST_FRACTION_SPLIT_VALIDATION, - test_fraction_split=_TEST_FRACTION_SPLIT_TEST, model_display_name=None, # Omit model_display_name sync=sync, ) @@ -547,19 +540,13 @@ def test_run_call_pipeline_if_no_model_display_name_nor_model_labels( if not sync: model_from_job.wait() - true_fraction_split = gca_training_pipeline.FractionSplit( - training_fraction=_TEST_FRACTION_SPLIT_TRAINING, - validation_fraction=_TEST_FRACTION_SPLIT_VALIDATION, - test_fraction=_TEST_FRACTION_SPLIT_TEST, - ) - # Test that if defaults to the job display name true_managed_model = gca_model.Model( display_name=_TEST_DISPLAY_NAME, labels=_TEST_LABELS, ) true_input_data_config = gca_training_pipeline.InputDataConfig( - fraction_split=true_fraction_split, dataset_id=mock_dataset_text.name, + dataset_id=mock_dataset_text.name, ) true_training_pipeline = gca_training_pipeline.TrainingPipeline( @@ -602,13 +589,42 @@ def test_run_called_twice_raises(self, mock_dataset_text, sync): with pytest.raises(RuntimeError): job.run( + dataset=mock_dataset_text, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + sync=sync, + ) + + @pytest.mark.usefixtures( + "mock_pipeline_service_create", + "mock_pipeline_service_get", + "mock_model_service_get", + ) + @pytest.mark.parametrize("sync", [True, False]) + def test_run_with_two_split_raises( + self, mock_dataset_text, sync, + ): + aiplatform.init(project=_TEST_PROJECT) + + job = training_jobs.AutoMLTextTrainingJob( + display_name=_TEST_DISPLAY_NAME, + prediction_type="classification", + multi_label=True, + ) + + with pytest.raises(ValueError): + model_from_job = job.run( dataset=mock_dataset_text, model_display_name=_TEST_MODEL_DISPLAY_NAME, training_fraction_split=_TEST_FRACTION_SPLIT_TRAINING, validation_fraction_split=_TEST_FRACTION_SPLIT_VALIDATION, test_fraction_split=_TEST_FRACTION_SPLIT_TEST, + training_filter_split=_TEST_FILTER_SPLIT_TRAINING, + validation_filter_split=_TEST_FILTER_SPLIT_VALIDATION, + test_filter_split=_TEST_FILTER_SPLIT_TEST, sync=sync, ) + if not sync: + model_from_job.wait() @pytest.mark.parametrize("sync", [True, False]) def test_run_raises_if_pipeline_fails( @@ -638,3 +654,198 @@ def test_run_raises_if_pipeline_fails( with pytest.raises(RuntimeError): job.get_model() + + @pytest.mark.parametrize("sync", [True, False]) + def test_splits_fraction( + self, + mock_pipeline_service_create, + mock_pipeline_service_get, + mock_dataset_text, + mock_model_service_get, + mock_model, + sync, + ): + """ + Initiate aiplatform with encryption key name. + Create and run an AutoML Video Classification training job, verify calls and return value + """ + + aiplatform.init( + project=_TEST_PROJECT, + encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME, + ) + + job = training_jobs.AutoMLTextTrainingJob( + display_name=_TEST_DISPLAY_NAME, + prediction_type=_TEST_PREDICTION_TYPE_CLASSIFICATION, + multi_label=_TEST_CLASSIFICATION_MULTILABEL, + ) + + model_from_job = job.run( + dataset=mock_dataset_text, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + training_fraction_split=_TEST_FRACTION_SPLIT_TRAINING, + validation_fraction_split=_TEST_FRACTION_SPLIT_VALIDATION, + test_fraction_split=_TEST_FRACTION_SPLIT_TEST, + sync=sync, + ) + + if not sync: + model_from_job.wait() + + true_fraction_split = gca_training_pipeline.FractionSplit( + training_fraction=_TEST_FRACTION_SPLIT_TRAINING, + validation_fraction=_TEST_FRACTION_SPLIT_VALIDATION, + test_fraction=_TEST_FRACTION_SPLIT_TEST, + ) + + true_managed_model = gca_model.Model( + display_name=_TEST_MODEL_DISPLAY_NAME, + description=mock_model._gca_resource.description, + encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC, + ) + + true_input_data_config = gca_training_pipeline.InputDataConfig( + fraction_split=true_fraction_split, dataset_id=mock_dataset_text.name, + ) + + true_training_pipeline = gca_training_pipeline.TrainingPipeline( + display_name=_TEST_DISPLAY_NAME, + training_task_definition=schema.training_job.definition.automl_text_classification, + training_task_inputs=_TEST_TRAINING_TASK_INPUTS_CLASSIFICATION, + model_to_upload=true_managed_model, + input_data_config=true_input_data_config, + encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC, + ) + + mock_pipeline_service_create.assert_called_once_with( + parent=initializer.global_config.common_location_path(), + training_pipeline=true_training_pipeline, + ) + + @pytest.mark.parametrize("sync", [True, False]) + def test_splits_filter( + self, + mock_pipeline_service_create, + mock_pipeline_service_get, + mock_dataset_text, + mock_model_service_get, + mock_model, + sync, + ): + """ + Initiate aiplatform with encryption key name. + Create and run an AutoML Video Classification training job, verify calls and return value + """ + + aiplatform.init( + project=_TEST_PROJECT, + encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME, + ) + + job = training_jobs.AutoMLTextTrainingJob( + display_name=_TEST_DISPLAY_NAME, + prediction_type=_TEST_PREDICTION_TYPE_CLASSIFICATION, + multi_label=_TEST_CLASSIFICATION_MULTILABEL, + ) + + model_from_job = job.run( + dataset=mock_dataset_text, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + training_filter_split=_TEST_FILTER_SPLIT_TRAINING, + validation_filter_split=_TEST_FILTER_SPLIT_VALIDATION, + test_filter_split=_TEST_FILTER_SPLIT_TEST, + sync=sync, + ) + + if not sync: + model_from_job.wait() + + true_filter_split = gca_training_pipeline.FilterSplit( + training_filter=_TEST_FILTER_SPLIT_TRAINING, + validation_filter=_TEST_FILTER_SPLIT_VALIDATION, + test_filter=_TEST_FILTER_SPLIT_TEST, + ) + + true_managed_model = gca_model.Model( + display_name=_TEST_MODEL_DISPLAY_NAME, + description=mock_model._gca_resource.description, + encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC, + ) + + true_input_data_config = gca_training_pipeline.InputDataConfig( + filter_split=true_filter_split, dataset_id=mock_dataset_text.name, + ) + + true_training_pipeline = gca_training_pipeline.TrainingPipeline( + display_name=_TEST_DISPLAY_NAME, + training_task_definition=schema.training_job.definition.automl_text_classification, + training_task_inputs=_TEST_TRAINING_TASK_INPUTS_CLASSIFICATION, + model_to_upload=true_managed_model, + input_data_config=true_input_data_config, + encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC, + ) + + mock_pipeline_service_create.assert_called_once_with( + parent=initializer.global_config.common_location_path(), + training_pipeline=true_training_pipeline, + ) + + @pytest.mark.parametrize("sync", [True, False]) + def test_splits_default( + self, + mock_pipeline_service_create, + mock_pipeline_service_get, + mock_dataset_text, + mock_model_service_get, + mock_model, + sync, + ): + """ + Initiate aiplatform with encryption key name. + Create and run an AutoML Video Classification training job, verify calls and return value + """ + + aiplatform.init( + project=_TEST_PROJECT, + encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME, + ) + + job = training_jobs.AutoMLTextTrainingJob( + display_name=_TEST_DISPLAY_NAME, + prediction_type=_TEST_PREDICTION_TYPE_CLASSIFICATION, + multi_label=_TEST_CLASSIFICATION_MULTILABEL, + ) + + model_from_job = job.run( + dataset=mock_dataset_text, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + sync=sync, + ) + + if not sync: + model_from_job.wait() + + true_managed_model = gca_model.Model( + display_name=_TEST_MODEL_DISPLAY_NAME, + description=mock_model._gca_resource.description, + encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC, + ) + + true_input_data_config = gca_training_pipeline.InputDataConfig( + dataset_id=mock_dataset_text.name, + ) + + true_training_pipeline = gca_training_pipeline.TrainingPipeline( + display_name=_TEST_DISPLAY_NAME, + training_task_definition=schema.training_job.definition.automl_text_classification, + training_task_inputs=_TEST_TRAINING_TASK_INPUTS_CLASSIFICATION, + model_to_upload=true_managed_model, + input_data_config=true_input_data_config, + encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC, + ) + + mock_pipeline_service_create.assert_called_once_with( + parent=initializer.global_config.common_location_path(), + training_pipeline=true_training_pipeline, + ) diff --git a/tests/unit/aiplatform/test_automl_video_training_jobs.py b/tests/unit/aiplatform/test_automl_video_training_jobs.py index fc7d6f38e3..7326050ae4 100644 --- a/tests/unit/aiplatform/test_automl_video_training_jobs.py +++ b/tests/unit/aiplatform/test_automl_video_training_jobs.py @@ -54,7 +54,13 @@ ) _TEST_FRACTION_SPLIT_TRAINING = 0.8 +_TEST_FRACTION_SPLIT_VALIDATION = 0.0 _TEST_FRACTION_SPLIT_TEST = 0.2 +_TEST_ALTERNATE_FRACTION_SPLIT_TRAINING = 0.7 +_TEST_ALTERNATE_FRACTION_SPLIT_TEST = 0.3 +_TEST_FILTER_SPLIT_TRAINING = "train" +_TEST_FILTER_SPLIT_VALIDATION = "-" +_TEST_FILTER_SPLIT_TEST = "test" _TEST_MODEL_NAME = ( f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/models/{_TEST_MODEL_ID}" @@ -141,6 +147,7 @@ def mock_model_service_get(): def mock_dataset_video(): ds = mock.MagicMock(datasets.VideoDataset) ds.name = _TEST_DATASET_NAME + ds.metadata_schema_uri = _TEST_METADATA_SCHEMA_URI_VIDEO ds._latest_future = None ds._exception = None ds._gca_resource = gca_dataset.Dataset( @@ -231,6 +238,72 @@ def test_init_aiplatform_with_encryption_key_name_and_create_training_job( model_type=_TEST_MODEL_TYPE_CLOUD, ) + model_from_job = job.run( + dataset=mock_dataset_video, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + sync=sync, + ) + + if not sync: + model_from_job.wait() + + true_managed_model = gca_model.Model( + display_name=_TEST_MODEL_DISPLAY_NAME, + description=mock_model._gca_resource.description, + encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC, + ) + + true_input_data_config = gca_training_pipeline.InputDataConfig( + dataset_id=mock_dataset_video.name, + ) + + true_training_pipeline = gca_training_pipeline.TrainingPipeline( + display_name=_TEST_DISPLAY_NAME, + training_task_definition=schema.training_job.definition.automl_video_classification, + training_task_inputs=_TEST_TRAINING_TASK_INPUTS, + model_to_upload=true_managed_model, + input_data_config=true_input_data_config, + encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC, + ) + + mock_pipeline_service_create.assert_called_once_with( + parent=initializer.global_config.common_location_path(), + training_pipeline=true_training_pipeline, + ) + + mock_model_service_get.assert_called_once_with(name=_TEST_MODEL_NAME) + assert job._gca_resource is mock_pipeline_service_get.return_value + assert model_from_job._gca_resource is mock_model_service_get.return_value + assert job.get_model()._gca_resource is mock_model_service_get.return_value + assert not job.has_failed + assert job.state == gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED + + @pytest.mark.parametrize("sync", [True, False]) + def test_splits_fraction( + self, + mock_pipeline_service_create, + mock_pipeline_service_get, + mock_dataset_video, + mock_model_service_get, + mock_model, + sync, + ): + """ + Initiate aiplatform with encryption key name. + Create and run an AutoML Video Classification training job, verify calls and return value + """ + + aiplatform.init( + project=_TEST_PROJECT, + encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME, + ) + + job = training_jobs.AutoMLVideoTrainingJob( + display_name=_TEST_DISPLAY_NAME, + prediction_type=_TEST_PREDICTION_TYPE_VCN, + model_type=_TEST_MODEL_TYPE_CLOUD, + ) + model_from_job = job.run( dataset=mock_dataset_video, model_display_name=_TEST_MODEL_DISPLAY_NAME, @@ -244,6 +317,7 @@ def test_init_aiplatform_with_encryption_key_name_and_create_training_job( true_fraction_split = gca_training_pipeline.FractionSplit( training_fraction=_TEST_FRACTION_SPLIT_TRAINING, + validation_fraction=_TEST_FRACTION_SPLIT_VALIDATION, test_fraction=_TEST_FRACTION_SPLIT_TEST, ) @@ -271,12 +345,131 @@ def test_init_aiplatform_with_encryption_key_name_and_create_training_job( training_pipeline=true_training_pipeline, ) - mock_model_service_get.assert_called_once_with(name=_TEST_MODEL_NAME) - assert job._gca_resource is mock_pipeline_service_get.return_value - assert model_from_job._gca_resource is mock_model_service_get.return_value - assert job.get_model()._gca_resource is mock_model_service_get.return_value - assert not job.has_failed - assert job.state == gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED + @pytest.mark.parametrize("sync", [True, False]) + def test_splits_filter( + self, + mock_pipeline_service_create, + mock_pipeline_service_get, + mock_dataset_video, + mock_model_service_get, + mock_model, + sync, + ): + """ + Initiate aiplatform with encryption key name. + Create and run an AutoML Video Classification training job, verify calls and return value + """ + + aiplatform.init( + project=_TEST_PROJECT, + encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME, + ) + + job = training_jobs.AutoMLVideoTrainingJob( + display_name=_TEST_DISPLAY_NAME, + prediction_type=_TEST_PREDICTION_TYPE_VCN, + model_type=_TEST_MODEL_TYPE_CLOUD, + ) + + model_from_job = job.run( + dataset=mock_dataset_video, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + training_filter_split=_TEST_FILTER_SPLIT_TRAINING, + test_filter_split=_TEST_FILTER_SPLIT_TEST, + sync=sync, + ) + + if not sync: + model_from_job.wait() + + true_filter_split = gca_training_pipeline.FilterSplit( + training_filter=_TEST_FILTER_SPLIT_TRAINING, + validation_filter=_TEST_FILTER_SPLIT_VALIDATION, + test_filter=_TEST_FILTER_SPLIT_TEST, + ) + + true_managed_model = gca_model.Model( + display_name=_TEST_MODEL_DISPLAY_NAME, + description=mock_model._gca_resource.description, + encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC, + ) + + true_input_data_config = gca_training_pipeline.InputDataConfig( + filter_split=true_filter_split, dataset_id=mock_dataset_video.name, + ) + + true_training_pipeline = gca_training_pipeline.TrainingPipeline( + display_name=_TEST_DISPLAY_NAME, + training_task_definition=schema.training_job.definition.automl_video_classification, + training_task_inputs=_TEST_TRAINING_TASK_INPUTS, + model_to_upload=true_managed_model, + input_data_config=true_input_data_config, + encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC, + ) + + mock_pipeline_service_create.assert_called_once_with( + parent=initializer.global_config.common_location_path(), + training_pipeline=true_training_pipeline, + ) + + @pytest.mark.parametrize("sync", [True, False]) + def test_splits_default( + self, + mock_pipeline_service_create, + mock_pipeline_service_get, + mock_dataset_video, + mock_model_service_get, + mock_model, + sync, + ): + """ + Initiate aiplatform with encryption key name. + Create and run an AutoML Video Classification training job, verify calls and return value + """ + + aiplatform.init( + project=_TEST_PROJECT, + encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME, + ) + + job = training_jobs.AutoMLVideoTrainingJob( + display_name=_TEST_DISPLAY_NAME, + prediction_type=_TEST_PREDICTION_TYPE_VCN, + model_type=_TEST_MODEL_TYPE_CLOUD, + ) + + model_from_job = job.run( + dataset=mock_dataset_video, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + sync=sync, + ) + + if not sync: + model_from_job.wait() + + true_managed_model = gca_model.Model( + display_name=_TEST_MODEL_DISPLAY_NAME, + description=mock_model._gca_resource.description, + encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC, + ) + + true_input_data_config = gca_training_pipeline.InputDataConfig( + dataset_id=mock_dataset_video.name, + ) + + true_training_pipeline = gca_training_pipeline.TrainingPipeline( + display_name=_TEST_DISPLAY_NAME, + training_task_definition=schema.training_job.definition.automl_video_classification, + training_task_inputs=_TEST_TRAINING_TASK_INPUTS, + model_to_upload=true_managed_model, + input_data_config=true_input_data_config, + encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC, + ) + + mock_pipeline_service_create.assert_called_once_with( + parent=initializer.global_config.common_location_path(), + training_pipeline=true_training_pipeline, + ) @pytest.mark.parametrize("sync", [True, False]) def test_run_call_pipeline_service_create( @@ -305,17 +498,18 @@ def test_run_call_pipeline_service_create( dataset=mock_dataset_video, model_display_name=_TEST_MODEL_DISPLAY_NAME, model_labels=_TEST_MODEL_LABELS, - training_fraction_split=_TEST_FRACTION_SPLIT_TRAINING, - test_fraction_split=_TEST_FRACTION_SPLIT_TEST, + training_filter_split=_TEST_FILTER_SPLIT_TRAINING, + test_filter_split=_TEST_FILTER_SPLIT_TEST, sync=sync, ) if not sync: model_from_job.wait() - true_fraction_split = gca_training_pipeline.FractionSplit( - training_fraction=_TEST_FRACTION_SPLIT_TRAINING, - test_fraction=_TEST_FRACTION_SPLIT_TEST, + true_filter_split = gca_training_pipeline.FilterSplit( + training_filter=_TEST_FILTER_SPLIT_TRAINING, + validation_filter=_TEST_FILTER_SPLIT_VALIDATION, + test_filter=_TEST_FILTER_SPLIT_TEST, ) true_managed_model = gca_model.Model( @@ -326,7 +520,7 @@ def test_run_call_pipeline_service_create( ) true_input_data_config = gca_training_pipeline.InputDataConfig( - fraction_split=true_fraction_split, dataset_id=mock_dataset_video.name, + filter_split=true_filter_split, dataset_id=mock_dataset_video.name, ) true_training_pipeline = gca_training_pipeline.TrainingPipeline( @@ -371,16 +565,17 @@ def test_run_call_pipeline_if_no_model_display_name_nor_model_labels( model_from_job = job.run( dataset=mock_dataset_video, - training_fraction_split=_TEST_FRACTION_SPLIT_TRAINING, - test_fraction_split=_TEST_FRACTION_SPLIT_TEST, + training_fraction_split=_TEST_ALTERNATE_FRACTION_SPLIT_TRAINING, + test_fraction_split=_TEST_ALTERNATE_FRACTION_SPLIT_TEST, ) if not sync: model_from_job.wait() true_fraction_split = gca_training_pipeline.FractionSplit( - training_fraction=_TEST_FRACTION_SPLIT_TRAINING, - test_fraction=_TEST_FRACTION_SPLIT_TEST, + training_fraction=_TEST_ALTERNATE_FRACTION_SPLIT_TRAINING, + validation_fraction=_TEST_FRACTION_SPLIT_VALIDATION, + test_fraction=_TEST_ALTERNATE_FRACTION_SPLIT_TEST, ) # Test that if defaults to the job display name @@ -422,19 +617,41 @@ def test_run_called_twice_raises( job.run( dataset=mock_dataset_video, model_display_name=_TEST_MODEL_DISPLAY_NAME, - training_fraction_split=_TEST_FRACTION_SPLIT_TRAINING, - test_fraction_split=_TEST_FRACTION_SPLIT_TEST, sync=sync, ) with pytest.raises(RuntimeError): job.run( + dataset=mock_dataset_video, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + sync=sync, + ) + + @pytest.mark.usefixtures( + "mock_pipeline_service_create", + "mock_pipeline_service_get", + "mock_model_service_get", + ) + @pytest.mark.parametrize("sync", [True, False]) + def test_run_with_two_split_raises( + self, mock_dataset_video, sync, + ): + aiplatform.init(project=_TEST_PROJECT) + + job = training_jobs.AutoMLVideoTrainingJob(display_name=_TEST_DISPLAY_NAME,) + + with pytest.raises(ValueError): + model_from_job = job.run( dataset=mock_dataset_video, model_display_name=_TEST_MODEL_DISPLAY_NAME, training_fraction_split=_TEST_FRACTION_SPLIT_TRAINING, test_fraction_split=_TEST_FRACTION_SPLIT_TEST, + training_filter_split=_TEST_FILTER_SPLIT_TEST, + test_filter_split=_TEST_FILTER_SPLIT_TEST, sync=sync, ) + if not sync: + model_from_job.wait() @pytest.mark.parametrize("sync", [True, False]) def test_run_raises_if_pipeline_fails( diff --git a/tests/unit/aiplatform/test_training_jobs.py b/tests/unit/aiplatform/test_training_jobs.py index 3e694e6a1e..0fd781b380 100644 --- a/tests/unit/aiplatform/test_training_jobs.py +++ b/tests/unit/aiplatform/test_training_jobs.py @@ -103,13 +103,14 @@ _TEST_LABELS = {"key": "value"} _TEST_MODEL_LABELS = {"model_key": "model_value"} -_TEST_DEFAULT_TRAINING_FRACTION_SPLIT = 0.8 -_TEST_DEFAULT_VALIDATION_FRACTION_SPLIT = 0.1 -_TEST_DEFAULT_TEST_FRACTION_SPLIT = 0.1 _TEST_TRAINING_FRACTION_SPLIT = 0.6 _TEST_VALIDATION_FRACTION_SPLIT = 0.2 _TEST_TEST_FRACTION_SPLIT = 0.2 +_TEST_TRAINING_FILTER_SPLIT = "train" +_TEST_VALIDATION_FILTER_SPLIT = "validate" +_TEST_TEST_FILTER_SPLIT = "test" _TEST_PREDEFINED_SPLIT_COLUMN_NAME = "split" +_TEST_TIMESTAMP_SPLIT_COLUMN_NAME = "timestamp" _TEST_PROJECT = "test-project" _TEST_LOCATION = "us-central1" @@ -579,6 +580,7 @@ def mock_python_package_to_gcs(): def mock_tabular_dataset(): ds = mock.MagicMock(datasets.TabularDataset) ds.name = _TEST_DATASET_NAME + ds.metadata_schema_uri = _TEST_METADATA_SCHEMA_URI_TABULAR ds._latest_future = None ds._exception = None ds._gca_resource = gca_dataset.Dataset( @@ -595,6 +597,7 @@ def mock_tabular_dataset(): def mock_nontabular_dataset(): ds = mock.MagicMock(datasets.ImageDataset) ds.name = _TEST_DATASET_NAME + ds.metadata_schema_uri = _TEST_METADATA_SCHEMA_URI_NONTABULAR ds._latest_future = None ds._exception = None ds._gca_resource = gca_dataset.Dataset( @@ -668,7 +671,7 @@ def test_run_call_pipeline_service_create_with_tabular_dataset( training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, test_fraction_split=_TEST_TEST_FRACTION_SPLIT, - predefined_split_column_name=_TEST_PREDEFINED_SPLIT_COLUMN_NAME, + timestamp_split_column_name=_TEST_TIMESTAMP_SPLIT_COLUMN_NAME, tensorboard=_TEST_TENSORBOARD_RESOURCE_NAME, sync=sync, ) @@ -708,10 +711,11 @@ def test_run_call_pipeline_service_create_with_tabular_dataset( }, } - true_fraction_split = gca_training_pipeline.FractionSplit( + true_timestamp_split = gca_training_pipeline.TimestampSplit( training_fraction=_TEST_TRAINING_FRACTION_SPLIT, validation_fraction=_TEST_VALIDATION_FRACTION_SPLIT, test_fraction=_TEST_TEST_FRACTION_SPLIT, + key=_TEST_TIMESTAMP_SPLIT_COLUMN_NAME, ) env = [ @@ -748,10 +752,7 @@ def test_run_call_pipeline_service_create_with_tabular_dataset( ) true_input_data_config = gca_training_pipeline.InputDataConfig( - fraction_split=true_fraction_split, - predefined_split=gca_training_pipeline.PredefinedSplit( - key=_TEST_PREDEFINED_SPLIT_COLUMN_NAME - ), + timestamp_split=true_timestamp_split, dataset_id=mock_tabular_dataset.name, gcs_destination=gca_io.GcsDestination( output_uri_prefix=_TEST_BASE_OUTPUT_DIR @@ -843,9 +844,6 @@ def test_run_call_pipeline_service_create_with_bigquery_destination( accelerator_type=_TEST_ACCELERATOR_TYPE, accelerator_count=_TEST_ACCELERATOR_COUNT, model_display_name=_TEST_MODEL_DISPLAY_NAME, - training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, - validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, - test_fraction_split=_TEST_TEST_FRACTION_SPLIT, predefined_split_column_name=_TEST_PREDEFINED_SPLIT_COLUMN_NAME, sync=sync, ) @@ -879,12 +877,6 @@ def test_run_call_pipeline_service_create_with_bigquery_destination( }, } - true_fraction_split = gca_training_pipeline.FractionSplit( - training_fraction=_TEST_TRAINING_FRACTION_SPLIT, - validation_fraction=_TEST_VALIDATION_FRACTION_SPLIT, - test_fraction=_TEST_TEST_FRACTION_SPLIT, - ) - env = [ gca_env_var.EnvVar(name=str(key), value=str(value)) for key, value in _TEST_MODEL_SERVING_CONTAINER_ENVIRONMENT_VARIABLES.items() @@ -918,7 +910,6 @@ def test_run_call_pipeline_service_create_with_bigquery_destination( ) true_input_data_config = gca_training_pipeline.InputDataConfig( - fraction_split=true_fraction_split, predefined_split=gca_training_pipeline.PredefinedSplit( key=_TEST_PREDEFINED_SPLIT_COLUMN_NAME ), @@ -1049,6 +1040,34 @@ def test_run_with_invalid_accelerator_type_raises( accelerator_type=_TEST_INVALID_ACCELERATOR_TYPE, accelerator_count=_TEST_ACCELERATOR_COUNT, model_display_name=_TEST_MODEL_DISPLAY_NAME, + sync=sync, + ) + + @pytest.mark.parametrize("sync", [True, False]) + def test_run_with_two_splits_raises( + self, + mock_pipeline_service_create, + mock_python_package_to_gcs, + mock_tabular_dataset, + mock_model_service_get, + sync, + ): + aiplatform.init(project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME) + + job = training_jobs.CustomTrainingJob( + display_name=_TEST_DISPLAY_NAME, + script_path=_TEST_LOCAL_SCRIPT_FILE_NAME, + container_uri=_TEST_TRAINING_CONTAINER_IMAGE, + ) + + with pytest.raises(ValueError): + job.run( + dataset=mock_tabular_dataset, + replica_count=1, + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_INVALID_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + predefined_split_column_name=_TEST_PREDEFINED_SPLIT_COLUMN_NAME, training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, test_fraction_split=_TEST_TEST_FRACTION_SPLIT, @@ -1123,6 +1142,9 @@ def test_run_call_pipeline_service_create_with_no_dataset( training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, test_fraction_split=_TEST_TEST_FRACTION_SPLIT, + training_filter_split=_TEST_TRAINING_FILTER_SPLIT, + validation_filter_split=_TEST_VALIDATION_FILTER_SPLIT, + test_filter_split=_TEST_TEST_FILTER_SPLIT, sync=sync, ) @@ -1379,9 +1401,6 @@ def test_run_call_pipeline_service_create_distributed_training( accelerator_type=_TEST_ACCELERATOR_TYPE, accelerator_count=_TEST_ACCELERATOR_COUNT, model_display_name=_TEST_MODEL_DISPLAY_NAME, - training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, - validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, - test_fraction_split=_TEST_TEST_FRACTION_SPLIT, sync=sync, ) @@ -1441,12 +1460,6 @@ def test_run_call_pipeline_service_create_distributed_training( }, ] - true_fraction_split = gca_training_pipeline.FractionSplit( - training_fraction=_TEST_TRAINING_FRACTION_SPLIT, - validation_fraction=_TEST_VALIDATION_FRACTION_SPLIT, - test_fraction=_TEST_TEST_FRACTION_SPLIT, - ) - true_container_spec = gca_model.ModelContainerSpec( image_uri=_TEST_SERVING_CONTAINER_IMAGE, predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, @@ -1464,7 +1477,6 @@ def test_run_call_pipeline_service_create_distributed_training( ) true_input_data_config = gca_training_pipeline.InputDataConfig( - fraction_split=true_fraction_split, dataset_id=mock_tabular_dataset.name, gcs_destination=gca_io.GcsDestination( output_uri_prefix=_TEST_BASE_OUTPUT_DIR @@ -1660,6 +1672,9 @@ def test_run_call_pipeline_service_create_with_nontabular_dataset_without_model_ machine_type=_TEST_MACHINE_TYPE, accelerator_type=_TEST_ACCELERATOR_TYPE, accelerator_count=_TEST_ACCELERATOR_COUNT, + training_filter_split=_TEST_TRAINING_FILTER_SPLIT, + validation_filter_split=_TEST_VALIDATION_FILTER_SPLIT, + test_filter_split=_TEST_TEST_FILTER_SPLIT, sync=sync, ) @@ -1693,10 +1708,10 @@ def test_run_call_pipeline_service_create_with_nontabular_dataset_without_model_ }, } - true_fraction_split = gca_training_pipeline.FractionSplit( - training_fraction=_TEST_DEFAULT_TRAINING_FRACTION_SPLIT, - validation_fraction=_TEST_DEFAULT_VALIDATION_FRACTION_SPLIT, - test_fraction=_TEST_DEFAULT_TEST_FRACTION_SPLIT, + true_filter_split = gca_training_pipeline.FilterSplit( + training_filter=_TEST_TRAINING_FILTER_SPLIT, + validation_filter=_TEST_VALIDATION_FILTER_SPLIT, + test_filter=_TEST_TEST_FILTER_SPLIT, ) env = [ @@ -1732,7 +1747,7 @@ def test_run_call_pipeline_service_create_with_nontabular_dataset_without_model_ ) true_input_data_config = gca_training_pipeline.InputDataConfig( - fraction_split=true_fraction_split, + filter_split=true_filter_split, dataset_id=mock_nontabular_dataset.name, annotation_schema_uri=_TEST_ANNOTATION_SCHEMA_URI, gcs_destination=gca_io.GcsDestination( @@ -1909,9 +1924,6 @@ def test_run_call_pipeline_service_create_with_tabular_dataset( accelerator_count=_TEST_ACCELERATOR_COUNT, model_display_name=_TEST_MODEL_DISPLAY_NAME, model_labels=_TEST_MODEL_LABELS, - training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, - validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, - test_fraction_split=_TEST_TEST_FRACTION_SPLIT, predefined_split_column_name=_TEST_PREDEFINED_SPLIT_COLUMN_NAME, service_account=_TEST_SERVICE_ACCOUNT, tensorboard=_TEST_TENSORBOARD_RESOURCE_NAME, @@ -1946,12 +1958,6 @@ def test_run_call_pipeline_service_create_with_tabular_dataset( }, } - true_fraction_split = gca_training_pipeline.FractionSplit( - training_fraction=_TEST_TRAINING_FRACTION_SPLIT, - validation_fraction=_TEST_VALIDATION_FRACTION_SPLIT, - test_fraction=_TEST_TEST_FRACTION_SPLIT, - ) - env = [ gca_env_var.EnvVar(name=str(key), value=str(value)) for key, value in _TEST_MODEL_SERVING_CONTAINER_ENVIRONMENT_VARIABLES.items() @@ -1986,7 +1992,6 @@ def test_run_call_pipeline_service_create_with_tabular_dataset( ) true_input_data_config = gca_training_pipeline.InputDataConfig( - fraction_split=true_fraction_split, predefined_split=gca_training_pipeline.PredefinedSplit( key=_TEST_PREDEFINED_SPLIT_COLUMN_NAME ), @@ -2079,7 +2084,7 @@ def test_run_call_pipeline_service_create_with_bigquery_destination( training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, test_fraction_split=_TEST_TEST_FRACTION_SPLIT, - predefined_split_column_name=_TEST_PREDEFINED_SPLIT_COLUMN_NAME, + timestamp_split_column_name=_TEST_TIMESTAMP_SPLIT_COLUMN_NAME, sync=sync, ) @@ -2106,10 +2111,11 @@ def test_run_call_pipeline_service_create_with_bigquery_destination( }, } - true_fraction_split = gca_training_pipeline.FractionSplit( + true_timestamp_split = gca_training_pipeline.TimestampSplit( training_fraction=_TEST_TRAINING_FRACTION_SPLIT, validation_fraction=_TEST_VALIDATION_FRACTION_SPLIT, test_fraction=_TEST_TEST_FRACTION_SPLIT, + key=_TEST_TIMESTAMP_SPLIT_COLUMN_NAME, ) env = [ @@ -2145,10 +2151,7 @@ def test_run_call_pipeline_service_create_with_bigquery_destination( ) true_input_data_config = gca_training_pipeline.InputDataConfig( - fraction_split=true_fraction_split, - predefined_split=gca_training_pipeline.PredefinedSplit( - key=_TEST_PREDEFINED_SPLIT_COLUMN_NAME - ), + timestamp_split=true_timestamp_split, dataset_id=mock_tabular_dataset.name, bigquery_destination=gca_io.BigQueryDestination( output_uri=_TEST_BIGQUERY_DESTINATION @@ -2276,6 +2279,33 @@ def test_run_with_invalid_accelerator_type_raises( accelerator_type=_TEST_INVALID_ACCELERATOR_TYPE, accelerator_count=_TEST_ACCELERATOR_COUNT, model_display_name=_TEST_MODEL_DISPLAY_NAME, + sync=sync, + ) + + @pytest.mark.parametrize("sync", [True, False]) + def test_run_with_two_split_raises( + self, + mock_pipeline_service_create, + mock_python_package_to_gcs, + mock_tabular_dataset, + mock_model_service_get, + sync, + ): + aiplatform.init(project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME) + + job = training_jobs.CustomContainerTrainingJob( + display_name=_TEST_DISPLAY_NAME, + container_uri=_TEST_TRAINING_CONTAINER_IMAGE, + ) + + with pytest.raises(ValueError): + job.run( + dataset=mock_tabular_dataset, + replica_count=1, + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_INVALID_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + predefined_split_column_name=_TEST_PREDEFINED_SPLIT_COLUMN_NAME, training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, test_fraction_split=_TEST_TEST_FRACTION_SPLIT, @@ -2432,9 +2462,6 @@ def test_run_returns_none_if_no_model_to_upload( machine_type=_TEST_MACHINE_TYPE, accelerator_type=_TEST_ACCELERATOR_TYPE, accelerator_count=_TEST_ACCELERATOR_COUNT, - training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, - validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, - test_fraction_split=_TEST_TEST_FRACTION_SPLIT, sync=sync, ) @@ -2734,6 +2761,9 @@ def test_run_call_pipeline_service_create_with_nontabular_dataset( accelerator_count=_TEST_ACCELERATOR_COUNT, model_display_name=_TEST_MODEL_DISPLAY_NAME, model_labels=_TEST_MODEL_LABELS, + training_filter_split=_TEST_TRAINING_FILTER_SPLIT, + validation_filter_split=_TEST_VALIDATION_FILTER_SPLIT, + test_filter_split=_TEST_TEST_FILTER_SPLIT, sync=sync, ) @@ -2760,10 +2790,10 @@ def test_run_call_pipeline_service_create_with_nontabular_dataset( }, } - true_fraction_split = gca_training_pipeline.FractionSplit( - training_fraction=_TEST_DEFAULT_TRAINING_FRACTION_SPLIT, - validation_fraction=_TEST_DEFAULT_VALIDATION_FRACTION_SPLIT, - test_fraction=_TEST_DEFAULT_TEST_FRACTION_SPLIT, + true_filter_split = gca_training_pipeline.FilterSplit( + training_filter=_TEST_TRAINING_FILTER_SPLIT, + validation_filter=_TEST_VALIDATION_FILTER_SPLIT, + test_filter=_TEST_TEST_FILTER_SPLIT, ) env = [ @@ -2799,7 +2829,7 @@ def test_run_call_pipeline_service_create_with_nontabular_dataset( ) true_input_data_config = gca_training_pipeline.InputDataConfig( - fraction_split=true_fraction_split, + filter_split=true_filter_split, dataset_id=mock_nontabular_dataset.name, annotation_schema_uri=_TEST_ANNOTATION_SCHEMA_URI, gcs_destination=gca_io.GcsDestination( @@ -3257,7 +3287,6 @@ def test_run_call_pipeline_service_create_with_tabular_dataset( training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, test_fraction_split=_TEST_TEST_FRACTION_SPLIT, - predefined_split_column_name=_TEST_PREDEFINED_SPLIT_COLUMN_NAME, sync=sync, ) @@ -3331,9 +3360,6 @@ def test_run_call_pipeline_service_create_with_tabular_dataset( true_input_data_config = gca_training_pipeline.InputDataConfig( fraction_split=true_fraction_split, - predefined_split=gca_training_pipeline.PredefinedSplit( - key=_TEST_PREDEFINED_SPLIT_COLUMN_NAME - ), dataset_id=mock_tabular_dataset.name, gcs_destination=gca_io.GcsDestination( output_uri_prefix=_TEST_BASE_OUTPUT_DIR @@ -3421,9 +3447,6 @@ def test_run_call_pipeline_service_create_with_tabular_dataset_without_model_dis machine_type=_TEST_MACHINE_TYPE, accelerator_type=_TEST_ACCELERATOR_TYPE, accelerator_count=_TEST_ACCELERATOR_COUNT, - training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, - validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, - test_fraction_split=_TEST_TEST_FRACTION_SPLIT, predefined_split_column_name=_TEST_PREDEFINED_SPLIT_COLUMN_NAME, sync=sync, ) @@ -3452,12 +3475,6 @@ def test_run_call_pipeline_service_create_with_tabular_dataset_without_model_dis }, } - true_fraction_split = gca_training_pipeline.FractionSplit( - training_fraction=_TEST_TRAINING_FRACTION_SPLIT, - validation_fraction=_TEST_VALIDATION_FRACTION_SPLIT, - test_fraction=_TEST_TEST_FRACTION_SPLIT, - ) - env = [ gca_env_var.EnvVar(name=str(key), value=str(value)) for key, value in _TEST_MODEL_SERVING_CONTAINER_ENVIRONMENT_VARIABLES.items() @@ -3492,7 +3509,6 @@ def test_run_call_pipeline_service_create_with_tabular_dataset_without_model_dis ) true_input_data_config = gca_training_pipeline.InputDataConfig( - fraction_split=true_fraction_split, predefined_split=gca_training_pipeline.PredefinedSplit( key=_TEST_PREDEFINED_SPLIT_COLUMN_NAME ), @@ -3582,7 +3598,7 @@ def test_run_call_pipeline_service_create_with_bigquery_destination( training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, test_fraction_split=_TEST_TEST_FRACTION_SPLIT, - predefined_split_column_name=_TEST_PREDEFINED_SPLIT_COLUMN_NAME, + timestamp_split_column_name=_TEST_TIMESTAMP_SPLIT_COLUMN_NAME, sync=sync, ) @@ -3610,10 +3626,11 @@ def test_run_call_pipeline_service_create_with_bigquery_destination( }, } - true_fraction_split = gca_training_pipeline.FractionSplit( + true_timestamp_split = gca_training_pipeline.TimestampSplit( training_fraction=_TEST_TRAINING_FRACTION_SPLIT, validation_fraction=_TEST_VALIDATION_FRACTION_SPLIT, test_fraction=_TEST_TEST_FRACTION_SPLIT, + key=_TEST_TIMESTAMP_SPLIT_COLUMN_NAME, ) env = [ @@ -3649,10 +3666,7 @@ def test_run_call_pipeline_service_create_with_bigquery_destination( ) true_input_data_config = gca_training_pipeline.InputDataConfig( - fraction_split=true_fraction_split, - predefined_split=gca_training_pipeline.PredefinedSplit( - key=_TEST_PREDEFINED_SPLIT_COLUMN_NAME - ), + timestamp_split=true_timestamp_split, dataset_id=mock_tabular_dataset.name, bigquery_destination=gca_io.BigQueryDestination( output_uri=_TEST_BIGQUERY_DESTINATION @@ -3726,9 +3740,6 @@ def test_run_called_twice_raises( accelerator_type=_TEST_ACCELERATOR_TYPE, accelerator_count=_TEST_ACCELERATOR_COUNT, model_display_name=_TEST_MODEL_DISPLAY_NAME, - training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, - validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, - test_fraction_split=_TEST_TEST_FRACTION_SPLIT, sync=sync, ) @@ -3742,9 +3753,6 @@ def test_run_called_twice_raises( accelerator_type=_TEST_ACCELERATOR_TYPE, accelerator_count=_TEST_ACCELERATOR_COUNT, model_display_name=_TEST_MODEL_DISPLAY_NAME, - training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, - validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, - test_fraction_split=_TEST_TEST_FRACTION_SPLIT, sync=sync, ) @@ -3788,6 +3796,38 @@ def test_run_with_invalid_accelerator_type_raises( sync=sync, ) + @pytest.mark.parametrize("sync", [True, False]) + def test_run_with_two_split_raises( + self, + mock_pipeline_service_create, + mock_python_package_to_gcs, + mock_tabular_dataset, + mock_model_service_get, + sync, + ): + aiplatform.init(project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME) + + job = training_jobs.CustomPythonPackageTrainingJob( + display_name=_TEST_DISPLAY_NAME, + python_package_gcs_uri=_TEST_OUTPUT_PYTHON_PACKAGE_PATH, + python_module_name=_TEST_PYTHON_MODULE_NAME, + container_uri=_TEST_TRAINING_CONTAINER_IMAGE, + ) + + with pytest.raises(ValueError): + job.run( + dataset=mock_tabular_dataset, + replica_count=1, + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_INVALID_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + predefined_split_column_name=_TEST_PREDEFINED_SPLIT_COLUMN_NAME, + training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, + validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, + test_fraction_split=_TEST_TEST_FRACTION_SPLIT, + sync=sync, + ) + @pytest.mark.parametrize("sync", [True, False]) def test_run_with_incomplete_model_info_raises_with_model_to_upload( self, @@ -4013,9 +4053,6 @@ def test_run_raises_if_pipeline_fails( machine_type=_TEST_MACHINE_TYPE, accelerator_type=_TEST_ACCELERATOR_TYPE, accelerator_count=_TEST_ACCELERATOR_COUNT, - training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, - validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, - test_fraction_split=_TEST_TEST_FRACTION_SPLIT, sync=sync, ) @@ -4250,6 +4287,9 @@ def test_run_call_pipeline_service_create_with_nontabular_dataset_without_model_ accelerator_count=_TEST_ACCELERATOR_COUNT, service_account=_TEST_SERVICE_ACCOUNT, tensorboard=_TEST_TENSORBOARD_RESOURCE_NAME, + training_filter_split=_TEST_TRAINING_FILTER_SPLIT, + validation_filter_split=_TEST_VALIDATION_FILTER_SPLIT, + test_filter_split=_TEST_TEST_FILTER_SPLIT, sync=sync, ) @@ -4277,10 +4317,10 @@ def test_run_call_pipeline_service_create_with_nontabular_dataset_without_model_ }, } - true_fraction_split = gca_training_pipeline.FractionSplit( - training_fraction=_TEST_DEFAULT_TRAINING_FRACTION_SPLIT, - validation_fraction=_TEST_DEFAULT_VALIDATION_FRACTION_SPLIT, - test_fraction=_TEST_DEFAULT_TEST_FRACTION_SPLIT, + true_filter_split = gca_training_pipeline.FilterSplit( + training_filter=_TEST_TRAINING_FILTER_SPLIT, + validation_filter=_TEST_VALIDATION_FILTER_SPLIT, + test_filter=_TEST_TEST_FILTER_SPLIT, ) env = [ @@ -4316,7 +4356,7 @@ def test_run_call_pipeline_service_create_with_nontabular_dataset_without_model_ ) true_input_data_config = gca_training_pipeline.InputDataConfig( - fraction_split=true_fraction_split, + filter_split=true_filter_split, dataset_id=mock_nontabular_dataset.name, annotation_schema_uri=_TEST_ANNOTATION_SCHEMA_URI, gcs_destination=gca_io.GcsDestination(