From 1f0d5f3e3f95ee5056545e9d4742b96e9380a22e Mon Sep 17 00:00:00 2001 From: Vinny Senthil Date: Sun, 11 Apr 2021 14:28:02 -0600 Subject: [PATCH] feat: Make aiplatform.Dataset private (#296) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Removes all use of `aiplatform.Dataset` - Replaces internal references with `aiplatform._Dataset` - Replaces public references with specific dataset sub-class or TypeVar `datasets.SupportedDatasets` ## TODO after merging - [x] @sasha-gitg PTAL [at this list of updates](https://docs.google.com/document/d/1mKqldyCZaLgtYrgxpd6F3TJ35Ap3l6MhOxyiWTmZ4NQ/edit) to external notebooks and Colabs to no longer reference `aiplatform.Dataset` Fixes [b/184156060](http://b/184156060) 🦕 --- google/cloud/aiplatform/__init__.py | 4 +- google/cloud/aiplatform/datasets/__init__.py | 5 +- google/cloud/aiplatform/datasets/dataset.py | 8 +- .../aiplatform/datasets/image_dataset.py | 2 +- .../aiplatform/datasets/tabular_dataset.py | 2 +- .../cloud/aiplatform/datasets/text_dataset.py | 2 +- .../aiplatform/datasets/video_dataset.py | 2 +- google/cloud/aiplatform/training_jobs.py | 156 ++++++++++++++---- tests/system/aiplatform/test_dataset.py | 22 ++- .../test_automl_image_training_jobs.py | 2 +- .../test_automl_tabular_training_jobs.py | 4 +- .../test_automl_text_training_jobs.py | 2 +- .../test_automl_video_training_jobs.py | 2 +- tests/unit/aiplatform/test_datasets.py | 32 ++-- tests/unit/aiplatform/test_end_to_end.py | 9 +- tests/unit/aiplatform/test_training_jobs.py | 4 +- 16 files changed, 168 insertions(+), 90 deletions(-) diff --git a/google/cloud/aiplatform/__init__.py b/google/cloud/aiplatform/__init__.py index 9c94090548..58eb824454 100644 --- a/google/cloud/aiplatform/__init__.py +++ b/google/cloud/aiplatform/__init__.py @@ -20,9 +20,8 @@ from google.cloud.aiplatform import initializer from google.cloud.aiplatform.datasets import ( - Dataset, - TabularDataset, ImageDataset, + TabularDataset, TextDataset, VideoDataset, ) @@ -59,7 +58,6 @@ "CustomTrainingJob", "CustomContainerTrainingJob", "CustomPythonPackageTrainingJob", - "Dataset", "Endpoint", "ImageDataset", "Model", diff --git a/google/cloud/aiplatform/datasets/__init__.py b/google/cloud/aiplatform/datasets/__init__.py index cd4e936f78..57e2bad45d 100644 --- a/google/cloud/aiplatform/datasets/__init__.py +++ b/google/cloud/aiplatform/datasets/__init__.py @@ -15,14 +15,15 @@ # limitations under the License. 
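# A minimal usage sketch of the user-facing change (assumes an existing dataset
# resource and a previously constructed training job, mirroring the docstring
# examples updated later in this patch): callers now construct the concrete
# subclass for their data type instead of the no-longer-public
# `aiplatform.Dataset`, e.g.
#
#     ds = aiplatform.TabularDataset(
#         'projects/my-project/locations/us-central1/datasets/12345')
#     job.run(ds, replica_count=1, model_display_name='my-trained-model')
#
# Internal SDK code keeps the shared behavior through the private base class,
# imported as `_Dataset` below.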
# -from google.cloud.aiplatform.datasets.dataset import Dataset +from google.cloud.aiplatform.datasets.dataset import _Dataset from google.cloud.aiplatform.datasets.tabular_dataset import TabularDataset from google.cloud.aiplatform.datasets.image_dataset import ImageDataset from google.cloud.aiplatform.datasets.text_dataset import TextDataset from google.cloud.aiplatform.datasets.video_dataset import VideoDataset + __all__ = ( - "Dataset", + "_Dataset", "TabularDataset", "ImageDataset", "TextDataset", diff --git a/google/cloud/aiplatform/datasets/dataset.py b/google/cloud/aiplatform/datasets/dataset.py index 999b80c5e2..ffc1d6790e 100644 --- a/google/cloud/aiplatform/datasets/dataset.py +++ b/google/cloud/aiplatform/datasets/dataset.py @@ -33,7 +33,7 @@ from google.cloud.aiplatform.datasets import _datasources -class Dataset(base.AiPlatformResourceNounWithFutureManager): +class _Dataset(base.AiPlatformResourceNounWithFutureManager): """Managed dataset resource for AI Platform""" client_class = utils.DatasetClientWithOverride @@ -115,7 +115,7 @@ def create( request_metadata: Optional[Sequence[Tuple[str, str]]] = (), encryption_spec_key_name: Optional[str] = None, sync: bool = True, - ) -> "Dataset": + ) -> "_Dataset": """Creates a new dataset and optionally imports data into dataset when source and import_schema_uri are passed. @@ -241,7 +241,7 @@ def _create_and_import( request_metadata: Optional[Sequence[Tuple[str, str]]] = (), encryption_spec: Optional[gca_encryption_spec.EncryptionSpec] = None, sync: bool = True, - ) -> "Dataset": + ) -> "_Dataset": """Creates a new dataset and optionally imports data into dataset when source and import_schema_uri are passed. @@ -400,7 +400,7 @@ def import_data( import_schema_uri: str, data_item_labels: Optional[Dict] = None, sync: bool = True, - ) -> "Dataset": + ) -> "_Dataset": """Upload data to existing managed dataset. 
Args: diff --git a/google/cloud/aiplatform/datasets/image_dataset.py b/google/cloud/aiplatform/datasets/image_dataset.py index a2408d79f8..cea13014d8 100644 --- a/google/cloud/aiplatform/datasets/image_dataset.py +++ b/google/cloud/aiplatform/datasets/image_dataset.py @@ -26,7 +26,7 @@ from google.cloud.aiplatform import utils -class ImageDataset(datasets.Dataset): +class ImageDataset(datasets._Dataset): """Managed image dataset resource for AI Platform""" _supported_metadata_schema_uris: Optional[Tuple[str]] = ( diff --git a/google/cloud/aiplatform/datasets/tabular_dataset.py b/google/cloud/aiplatform/datasets/tabular_dataset.py index 52cc877f79..3dd217aad7 100644 --- a/google/cloud/aiplatform/datasets/tabular_dataset.py +++ b/google/cloud/aiplatform/datasets/tabular_dataset.py @@ -26,7 +26,7 @@ from google.cloud.aiplatform import utils -class TabularDataset(datasets.Dataset): +class TabularDataset(datasets._Dataset): """Managed tabular dataset resource for AI Platform""" _supported_metadata_schema_uris: Optional[Tuple[str]] = ( diff --git a/google/cloud/aiplatform/datasets/text_dataset.py b/google/cloud/aiplatform/datasets/text_dataset.py index fe997f95a1..2b791e5c82 100644 --- a/google/cloud/aiplatform/datasets/text_dataset.py +++ b/google/cloud/aiplatform/datasets/text_dataset.py @@ -26,7 +26,7 @@ from google.cloud.aiplatform import utils -class TextDataset(datasets.Dataset): +class TextDataset(datasets._Dataset): """Managed text dataset resource for AI Platform""" _supported_metadata_schema_uris: Optional[Tuple[str]] = ( diff --git a/google/cloud/aiplatform/datasets/video_dataset.py b/google/cloud/aiplatform/datasets/video_dataset.py index 09afc3f3d5..c50298f99a 100644 --- a/google/cloud/aiplatform/datasets/video_dataset.py +++ b/google/cloud/aiplatform/datasets/video_dataset.py @@ -26,7 +26,7 @@ from google.cloud.aiplatform import utils -class VideoDataset(datasets.Dataset): +class VideoDataset(datasets._Dataset): """Managed video dataset resource for AI Platform""" _supported_metadata_schema_uris: Optional[Tuple[str]] = ( diff --git a/google/cloud/aiplatform/training_jobs.py b/google/cloud/aiplatform/training_jobs.py index 02264f9244..8c3d933326 100644 --- a/google/cloud/aiplatform/training_jobs.py +++ b/google/cloud/aiplatform/training_jobs.py @@ -216,7 +216,7 @@ def run(self) -> Optional[models.Model]: @staticmethod def _create_input_data_config( - dataset: Optional[datasets.Dataset] = None, + dataset: Optional[datasets._Dataset] = None, annotation_schema_uri: Optional[str] = None, training_fraction_split: float = 0.8, validation_fraction_split: float = 0.1, @@ -228,7 +228,7 @@ def _create_input_data_config( """Constructs a input data config to pass to the training pipeline. Args: - dataset (datasets.Dataset): + dataset (datasets._Dataset): The dataset within the same Project from which data will be used to train the Model. The Dataset must use schema compatible with Model being trained, and what is compatible should be described in the used @@ -362,7 +362,7 @@ def _run_job( self, training_task_definition: str, training_task_inputs: Union[dict, proto.Message], - dataset: Optional[datasets.Dataset], + dataset: Optional[datasets._Dataset], training_fraction_split: float, validation_fraction_split: float, test_fraction_split: float, @@ -390,7 +390,7 @@ def _run_job( read access. training_task_inputs (Union[dict, proto.Message]): Required. The training task's input that corresponds to the training_task_definition parameter. 
- dataset (datasets.Dataset): + dataset (datasets._Dataset): The dataset within the same Project from which data will be used to train the Model. The Dataset must use schema compatible with Model being trained, and what is compatible should be described in the used @@ -1590,7 +1590,7 @@ def __init__( Usage with Dataset: - ds = aiplatform.Dataset( + ds = aiplatform.TabularDataset( 'projects/my-project/locations/us-central1/datasets/12345') job.run(ds, replica_count=1, model_display_name='my-trained-model') @@ -1765,7 +1765,14 @@ def __init__( # TODO(b/172368070) add timestamp split, training_pipeline.TimestampSplit def run( self, - dataset: Optional[datasets.Dataset] = None, + dataset: Optional[ + Union[ + datasets.ImageDataset, + datasets.TabularDataset, + datasets.TextDataset, + datasets.VideoDataset, + ] + ] = None, annotation_schema_uri: Optional[str] = None, model_display_name: Optional[str] = None, base_output_dir: Optional[str] = None, @@ -1797,9 +1804,16 @@ def run( of data will be used for training, 10% for validation, and 10% for test. Args: - dataset (datasets.Dataset): + dataset ( + Union[ + datasets.ImageDataset, + datasets.TabularDataset, + datasets.TextDataset, + datasets.VideoDataset, + ] + ): AI Platform to fit this training against. Custom training script should - retrieve datasets through passed in environement variables uris: + retrieve datasets through passed in environment variables uris: os.environ["AIP_TRAINING_DATA_URI"] os.environ["AIP_VALIDATION_DATA_URI"] @@ -1927,7 +1941,14 @@ def run( def _run( self, python_packager: _TrainingScriptPythonPackager, - dataset: Optional[datasets.Dataset], + dataset: Optional[ + Union[ + datasets.ImageDataset, + datasets.TabularDataset, + datasets.TextDataset, + datasets.VideoDataset, + ] + ], annotation_schema_uri: Optional[str], worker_pool_specs: _DistributedTrainingSpec, managed_model: Optional[gca_model.Model] = None, @@ -1945,7 +1966,14 @@ def _run( Args: python_packager (_TrainingScriptPythonPackager): Required. Python Packager pointing to training script locally. - dataset (datasets.Dataset): + dataset ( + Union[ + datasets.ImageDataset, + datasets.TabularDataset, + datasets.TextDataset, + datasets.VideoDataset, + ] + ): AI Platform to fit this training against. annotation_schema_uri (str): Google Cloud Storage URI points to a YAML file describing @@ -2080,7 +2108,7 @@ def __init__( Usage with Dataset: - ds = aiplatform.Dataset( + ds = aiplatform.TabularDataset( 'projects/my-project/locations/us-central1/datasets/12345') job.run(ds, replica_count=1, model_display_name='my-trained-model') @@ -2254,7 +2282,14 @@ def __init__( # TODO(b/172368070) add timestamp split, training_pipeline.TimestampSplit def run( self, - dataset: Optional[datasets.Dataset] = None, + dataset: Optional[ + Union[ + datasets.ImageDataset, + datasets.TabularDataset, + datasets.TextDataset, + datasets.VideoDataset, + ] + ] = None, annotation_schema_uri: Optional[str] = None, model_display_name: Optional[str] = None, base_output_dir: Optional[str] = None, @@ -2286,7 +2321,14 @@ def run( of data will be used for training, 10% for validation, and 10% for test. Args: - dataset (datasets.Dataset): + dataset ( + Union[ + datasets.ImageDataset, + datasets.TabularDataset, + datasets.TextDataset, + datasets.VideoDataset, + ] + ): AI Platform to fit this training against. 
Custom training script should retrieve datasets through passed in environment variables uris: @@ -2414,7 +2456,14 @@ def run( @base.optional_sync(construct_object_on_arg="managed_model") def _run( self, - dataset: Optional[datasets.Dataset], + dataset: Optional[ + Union[ + datasets.ImageDataset, + datasets.TabularDataset, + datasets.TextDataset, + datasets.VideoDataset, + ] + ], annotation_schema_uri: Optional[str], worker_pool_specs: _DistributedTrainingSpec, managed_model: Optional[gca_model.Model] = None, @@ -2429,7 +2478,14 @@ def _run( ) -> Optional[models.Model]: """Packages local script and launches training_job. Args: - dataset (datasets.Dataset): + dataset ( + Union[ + datasets.ImageDataset, + datasets.TabularDataset, + datasets.TextDataset, + datasets.VideoDataset, + ] + ): AI Platform to fit this training against. annotation_schema_uri (str): Google Cloud Storage URI points to a YAML file describing @@ -2649,7 +2705,7 @@ def __init__( def run( self, - dataset: datasets.Dataset, + dataset: datasets.TabularDataset, target_column: str, training_fraction_split: float = 0.8, validation_fraction_split: float = 0.1, @@ -2671,7 +2727,7 @@ def run( of data will be used for training, 10% for validation, and 10% for test. Args: - dataset (datasets.Dataset): + dataset (datasets.TabularDataset): Required. The dataset within the same Project from which data will be used to train the Model. The Dataset must use schema compatible with Model being trained, and what is compatible should be described in the used @@ -2764,7 +2820,7 @@ def run( @base.optional_sync() def _run( self, - dataset: datasets.Dataset, + dataset: datasets.TabularDataset, target_column: str, training_fraction_split: float = 0.8, validation_fraction_split: float = 0.1, @@ -2786,7 +2842,7 @@ def _run( of data will be used for training, 10% for validation, and 10% for test. Args: - dataset (datasets.Dataset): + dataset (datasets.TabularDataset): Required. The dataset within the same Project from which data will be used to train the Model. The Dataset must use schema compatible with Model being trained, and what is compatible should be described in the used @@ -3050,7 +3106,7 @@ def __init__( def run( self, - dataset: datasets.Dataset, + dataset: datasets.ImageDataset, training_fraction_split: float = 0.8, validation_fraction_split: float = 0.1, test_fraction_split: float = 0.1, @@ -3069,7 +3125,7 @@ def run( of data will be used for training, 10% for validation, and 10% for test. Args: - dataset (datasets.Dataset): + dataset (datasets.ImageDataset): Required. The dataset within the same Project from which data will be used to train the Model. The Dataset must use schema compatible with Model being trained, and what is compatible should be described in the used @@ -3141,7 +3197,7 @@ def run( @base.optional_sync() def _run( self, - dataset: datasets.Dataset, + dataset: datasets.ImageDataset, base_model: Optional[models.Model] = None, training_fraction_split: float = 0.8, validation_fraction_split: float = 0.1, @@ -3161,7 +3217,7 @@ def _run( of data will be used for training, 10% for validation, and 10% for test. Args: - dataset (datasets.Dataset): + dataset (datasets.ImageDataset): Required. The dataset within the same Project from which data will be used to train the Model. 
The Dataset must use schema compatible with Model being trained, and what is compatible should be described in the used @@ -3312,7 +3368,7 @@ def __init__( Usage with Dataset: - ds = aiplatform.Dataset( + ds = aiplatform.TabularDataset( 'projects/my-project/locations/us-central1/datasets/12345' ) @@ -3491,7 +3547,14 @@ def __init__( def run( self, - dataset: Optional[datasets.Dataset] = None, + dataset: Optional[ + Union[ + datasets.ImageDataset, + datasets.TabularDataset, + datasets.TextDataset, + datasets.VideoDataset, + ] + ] = None, annotation_schema_uri: Optional[str] = None, model_display_name: Optional[str] = None, base_output_dir: Optional[str] = None, @@ -3523,7 +3586,14 @@ def run( of data will be used for training, 10% for validation, and 10% for test. Args: - dataset (datasets.Dataset): + dataset ( + Union[ + datasets.ImageDataset, + datasets.TabularDataset, + datasets.TextDataset, + datasets.VideoDataset, + ] + ): AI Platform to fit this training against. Custom training script should retrieve datasets through passed in environement variables uris: @@ -3646,7 +3716,14 @@ def run( @base.optional_sync(construct_object_on_arg="managed_model") def _run( self, - dataset: Optional[datasets.Dataset], + dataset: Optional[ + Union[ + datasets.ImageDataset, + datasets.TabularDataset, + datasets.TextDataset, + datasets.VideoDataset, + ] + ], annotation_schema_uri: Optional[str], worker_pool_specs: _DistributedTrainingSpec, managed_model: Optional[gca_model.Model] = None, @@ -3662,7 +3739,14 @@ def _run( """Packages local script and launches training_job. Args: - dataset (datasets.Dataset): + dataset ( + Union[ + datasets.ImageDataset, + datasets.TabularDataset, + datasets.TextDataset, + datasets.VideoDataset, + ] + ): AI Platform to fit this training against. annotation_schema_uri (str): Google Cloud Storage URI points to a YAML file describing @@ -3861,7 +3945,7 @@ def __init__( def run( self, - dataset: datasets.Dataset, + dataset: datasets.VideoDataset, training_fraction_split: float = 0.8, test_fraction_split: float = 0.2, model_display_name: Optional[str] = None, @@ -3875,7 +3959,7 @@ def run( by default roughly 80% of data will be used for training, and 20% for test. Args: - dataset (datasets.Dataset): + dataset (datasets.VideoDataset): Required. The dataset within the same Project from which data will be used to train the Model. The Dataset must use schema compatible with Model being trained, and what is compatible should be described in the used @@ -3922,7 +4006,7 @@ def run( @base.optional_sync() def _run( self, - dataset: datasets.Dataset, + dataset: datasets.VideoDataset, training_fraction_split: float = 0.8, test_fraction_split: float = 0.2, model_display_name: Optional[str] = None, @@ -3936,7 +4020,7 @@ def _run( by default roughly 80% of data will be used for training, and 20% for test. Args: - dataset (datasets.Dataset): + dataset (datasets.VideoDataset): Required. The dataset within the same Project from which data will be used to train the Model. The Dataset must use schema compatible with Model being trained, and what is compatible should be described in the used @@ -4127,7 +4211,7 @@ def __init__( def run( self, - dataset: datasets.Dataset, + dataset: datasets.TextDataset, training_fraction_split: float = 0.8, validation_fraction_split: float = 0.1, test_fraction_split: float = 0.1, @@ -4144,7 +4228,7 @@ def run( of data will be used for training, 10% for validation, and 10% for test. Args: - dataset (datasets.Dataset): + dataset (datasets.TextDataset): Required. 
The dataset within the same Project from which data will be used to train the Model. The Dataset must use schema compatible with Model being trained, and what is compatible should be described in the used @@ -4194,7 +4278,7 @@ def run( @base.optional_sync() def _run( self, - dataset: datasets.Dataset, + dataset: datasets.TextDataset, training_fraction_split: float = 0.8, validation_fraction_split: float = 0.1, test_fraction_split: float = 0.1, @@ -4211,7 +4295,7 @@ def _run( of data will be used for training, 10% for validation, and 10% for test. Args: - dataset (datasets.Dataset): + dataset (datasets.TextDataset): Required. The dataset within the same Project from which data will be used to train the Model. The Dataset must use schema compatible with Model being trained, and what is compatible should be described in the used diff --git a/tests/system/aiplatform/test_dataset.py b/tests/system/aiplatform/test_dataset.py index e3b7d08874..e18390a76a 100644 --- a/tests/system/aiplatform/test_dataset.py +++ b/tests/system/aiplatform/test_dataset.py @@ -29,7 +29,7 @@ from google.cloud import aiplatform from google.cloud.aiplatform import utils from google.cloud.aiplatform import initializer -from google.cloud.aiplatform_v1beta1.types import dataset +from google.cloud.aiplatform_v1beta1.types import dataset as gca_dataset from google.cloud.aiplatform_v1beta1.services import dataset_service from test_utils.vpcsc_config import vpcsc_config @@ -101,7 +101,7 @@ def dataset_gapic_client(self): @pytest.fixture() def create_text_dataset(self, dataset_gapic_client, shared_state): - gapic_dataset = dataset.Dataset( + gapic_dataset = gca_dataset.Dataset( display_name=f"temp_sdk_integration_test_create_text_dataset_{uuid.uuid4()}", metadata_schema_uri=aiplatform.schema.dataset.metadata.text, ) @@ -116,7 +116,7 @@ def create_text_dataset(self, dataset_gapic_client, shared_state): @pytest.fixture() def create_tabular_dataset(self, dataset_gapic_client, shared_state): - gapic_dataset = dataset.Dataset( + gapic_dataset = gca_dataset.Dataset( display_name=f"temp_sdk_integration_test_create_tabular_dataset_{uuid.uuid4()}", metadata_schema_uri=aiplatform.schema.dataset.metadata.tabular, ) @@ -131,7 +131,7 @@ def create_tabular_dataset(self, dataset_gapic_client, shared_state): @pytest.fixture() def create_image_dataset(self, dataset_gapic_client, shared_state): - gapic_dataset = dataset.Dataset( + gapic_dataset = gca_dataset.Dataset( display_name=f"temp_sdk_integration_test_create_image_dataset_{uuid.uuid4()}", metadata_schema_uri=aiplatform.schema.dataset.metadata.image, ) @@ -163,7 +163,7 @@ def test_get_existing_dataset(self): aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) - flowers_dataset = aiplatform.Dataset(dataset_name=_TEST_IMAGE_DATASET_ID) + flowers_dataset = aiplatform.ImageDataset(dataset_name=_TEST_IMAGE_DATASET_ID) assert flowers_dataset.name == _TEST_IMAGE_DATASET_ID assert flowers_dataset.display_name == _TEST_DATASET_DISPLAY_NAME @@ -175,7 +175,7 @@ def test_get_nonexistent_dataset(self): # AI Platform service returns 404 with pytest.raises(exceptions.NotFound): - aiplatform.Dataset(dataset_name="0") + aiplatform.ImageDataset(dataset_name="0") @pytest.mark.usefixtures("create_text_dataset", "delete_new_dataset") def test_get_new_dataset_and_import(self, dataset_gapic_client, shared_state): @@ -185,7 +185,7 @@ def test_get_new_dataset_and_import(self, dataset_gapic_client, shared_state): assert shared_state["dataset_name"] aiplatform.init(project=_TEST_PROJECT, 
location=_TEST_LOCATION) - my_dataset = aiplatform.Dataset(dataset_name=shared_state["dataset_name"]) + my_dataset = aiplatform.TextDataset(dataset_name=shared_state["dataset_name"]) data_items_pre_import = dataset_gapic_client.list_data_items( parent=my_dataset.resource_name @@ -213,9 +213,8 @@ def test_create_and_import_image_dataset(self, dataset_gapic_client, shared_stat aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) - img_dataset = aiplatform.Dataset.create( + img_dataset = aiplatform.ImageDataset.create( display_name=f"temp_sdk_integration_create_and_import_dataset_{uuid.uuid4()}", - metadata_schema_uri=aiplatform.schema.dataset.metadata.image, gcs_source=_TEST_IMAGE_OBJECT_DETECTION_GCS_SOURCE, import_schema_uri=_TEST_IMAGE_OBJ_DET_IMPORT_SCHEMA, ) @@ -235,9 +234,8 @@ def test_create_tabular_dataset(self, dataset_gapic_client, shared_state): aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) - tabular_dataset = aiplatform.Dataset.create( + tabular_dataset = aiplatform.TabularDataset.create( display_name=f"temp_sdk_integration_create_and_import_dataset_{uuid.uuid4()}", - metadata_schema_uri=aiplatform.schema.dataset.metadata.tabular, gcs_source=[_TEST_TABULAR_CLASSIFICATION_GCS_SOURCE], ) @@ -270,7 +268,7 @@ def test_export_data(self, shared_state): staging_bucket=f"gs://{shared_state['staging_bucket']}", ) - text_dataset = aiplatform.Dataset(dataset_name=_TEST_TEXT_DATASET_ID) + text_dataset = aiplatform.TextDataset(dataset_name=_TEST_TEXT_DATASET_ID) exported_files = text_dataset.export_data( output_dir=f"gs://{shared_state['staging_bucket']}" diff --git a/tests/unit/aiplatform/test_automl_image_training_jobs.py b/tests/unit/aiplatform/test_automl_image_training_jobs.py index d85f4f3b97..039899b3ff 100644 --- a/tests/unit/aiplatform/test_automl_image_training_jobs.py +++ b/tests/unit/aiplatform/test_automl_image_training_jobs.py @@ -141,7 +141,7 @@ def mock_model_service_get(): @pytest.fixture def mock_dataset_image(): - ds = mock.MagicMock(datasets.Dataset) + ds = mock.MagicMock(datasets.ImageDataset) ds.name = _TEST_DATASET_NAME ds._latest_future = None ds._gca_resource = gca_dataset.Dataset( diff --git a/tests/unit/aiplatform/test_automl_tabular_training_jobs.py b/tests/unit/aiplatform/test_automl_tabular_training_jobs.py index 77435fa7ed..7499631d61 100644 --- a/tests/unit/aiplatform/test_automl_tabular_training_jobs.py +++ b/tests/unit/aiplatform/test_automl_tabular_training_jobs.py @@ -145,7 +145,7 @@ def mock_model_service_get(): @pytest.fixture def mock_dataset_tabular(): - ds = mock.MagicMock(datasets.Dataset) + ds = mock.MagicMock(datasets.TabularDataset) ds.name = _TEST_DATASET_NAME ds._latest_future = None ds._gca_resource = gca_dataset.Dataset( @@ -160,7 +160,7 @@ def mock_dataset_tabular(): @pytest.fixture def mock_dataset_nontabular(): - ds = mock.MagicMock(datasets.Dataset) + ds = mock.MagicMock(datasets.ImageDataset) ds.name = _TEST_DATASET_NAME ds._latest_future = None ds._gca_resource = gca_dataset.Dataset( diff --git a/tests/unit/aiplatform/test_automl_text_training_jobs.py b/tests/unit/aiplatform/test_automl_text_training_jobs.py index 84726aa6fa..0146a8815e 100644 --- a/tests/unit/aiplatform/test_automl_text_training_jobs.py +++ b/tests/unit/aiplatform/test_automl_text_training_jobs.py @@ -126,7 +126,7 @@ def mock_model_service_get(): @pytest.fixture def mock_dataset_text(): - ds = mock.MagicMock(datasets.Dataset) + ds = mock.MagicMock(datasets.TextDataset) ds.name = _TEST_DATASET_NAME ds._latest_future = None ds._gca_resource = 
gca_dataset.Dataset( diff --git a/tests/unit/aiplatform/test_automl_video_training_jobs.py b/tests/unit/aiplatform/test_automl_video_training_jobs.py index 8743c00c9d..ef91ed0dbf 100644 --- a/tests/unit/aiplatform/test_automl_video_training_jobs.py +++ b/tests/unit/aiplatform/test_automl_video_training_jobs.py @@ -122,7 +122,7 @@ def mock_model_service_get(): @pytest.fixture def mock_dataset_video(): - ds = mock.MagicMock(datasets.Dataset) + ds = mock.MagicMock(datasets.VideoDataset) ds.name = _TEST_DATASET_NAME ds._latest_future = None ds._gca_resource = gca_dataset.Dataset( diff --git a/tests/unit/aiplatform/test_datasets.py b/tests/unit/aiplatform/test_datasets.py index f49c38e62f..5d1f92fe79 100644 --- a/tests/unit/aiplatform/test_datasets.py +++ b/tests/unit/aiplatform/test_datasets.py @@ -287,28 +287,28 @@ def teardown_method(self): def test_init_dataset(self, get_dataset_mock): aiplatform.init(project=_TEST_PROJECT) - datasets.Dataset(dataset_name=_TEST_NAME) + datasets._Dataset(dataset_name=_TEST_NAME) get_dataset_mock.assert_called_once_with(name=_TEST_NAME) def test_init_dataset_with_id_only_with_project_and_location( self, get_dataset_mock ): aiplatform.init(project=_TEST_PROJECT) - datasets.Dataset( + datasets._Dataset( dataset_name=_TEST_ID, project=_TEST_PROJECT, location=_TEST_LOCATION ) get_dataset_mock.assert_called_once_with(name=_TEST_NAME) def test_init_dataset_with_project_and_location(self, get_dataset_mock): aiplatform.init(project=_TEST_PROJECT) - datasets.Dataset( + datasets._Dataset( dataset_name=_TEST_NAME, project=_TEST_PROJECT, location=_TEST_LOCATION ) get_dataset_mock.assert_called_once_with(name=_TEST_NAME) def test_init_dataset_with_alt_project_and_location(self, get_dataset_mock): aiplatform.init(project=_TEST_PROJECT) - datasets.Dataset( + datasets._Dataset( dataset_name=_TEST_NAME, project=_TEST_ALT_PROJECT, location=_TEST_LOCATION ) get_dataset_mock.assert_called_once_with(name=_TEST_NAME) @@ -316,7 +316,7 @@ def test_init_dataset_with_alt_project_and_location(self, get_dataset_mock): def test_init_dataset_with_project_and_alt_location(self): aiplatform.init(project=_TEST_PROJECT) with pytest.raises(RuntimeError): - datasets.Dataset( + datasets._Dataset( dataset_name=_TEST_NAME, project=_TEST_PROJECT, location=_TEST_ALT_LOCATION, @@ -324,7 +324,7 @@ def test_init_dataset_with_project_and_alt_location(self): def test_init_dataset_with_id_only(self, get_dataset_mock): aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) - datasets.Dataset(dataset_name=_TEST_ID) + datasets._Dataset(dataset_name=_TEST_ID) get_dataset_mock.assert_called_once_with(name=_TEST_NAME) @pytest.mark.usefixtures("get_dataset_without_name_mock") @@ -333,21 +333,21 @@ def test_init_dataset_with_id_only(self, get_dataset_mock): ) def test_init_dataset_with_id_only_without_project_or_location(self): with pytest.raises(GoogleAuthError): - datasets.Dataset( + datasets._Dataset( dataset_name=_TEST_ID, credentials=auth_credentials.AnonymousCredentials(), ) def test_init_dataset_with_location_override(self, get_dataset_mock): aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) - datasets.Dataset(dataset_name=_TEST_ID, location=_TEST_ALT_LOCATION) + datasets._Dataset(dataset_name=_TEST_ID, location=_TEST_ALT_LOCATION) get_dataset_mock.assert_called_once_with(name=_TEST_ALT_NAME) @pytest.mark.usefixtures("get_dataset_mock") def test_init_dataset_with_invalid_name(self): with pytest.raises(ValueError): aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) - 
datasets.Dataset(dataset_name=_TEST_INVALID_NAME) + datasets._Dataset(dataset_name=_TEST_INVALID_NAME) @pytest.mark.usefixtures("get_dataset_mock") @pytest.mark.parametrize("sync", [True, False]) @@ -358,7 +358,7 @@ def test_init_aiplatform_with_encryption_key_name_and_create_dataset( project=_TEST_PROJECT, encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME, ) - my_dataset = datasets.Dataset.create( + my_dataset = datasets._Dataset.create( display_name=_TEST_DISPLAY_NAME, metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_NONTABULAR, sync=sync, @@ -385,7 +385,7 @@ def test_init_aiplatform_with_encryption_key_name_and_create_dataset( def test_create_dataset_nontabular(self, create_dataset_mock, sync): aiplatform.init(project=_TEST_PROJECT) - my_dataset = datasets.Dataset.create( + my_dataset = datasets._Dataset.create( display_name=_TEST_DISPLAY_NAME, metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_NONTABULAR, encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME, @@ -412,7 +412,7 @@ def test_create_dataset_nontabular(self, create_dataset_mock, sync): def test_create_dataset_tabular(self, create_dataset_mock): aiplatform.init(project=_TEST_PROJECT) - datasets.Dataset.create( + datasets._Dataset.create( display_name=_TEST_DISPLAY_NAME, metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_TABULAR, bq_source=_TEST_SOURCE_URI_BQ, @@ -439,7 +439,7 @@ def test_create_and_import_dataset( ): aiplatform.init(project=_TEST_PROJECT) - my_dataset = datasets.Dataset.create( + my_dataset = datasets._Dataset.create( display_name=_TEST_DISPLAY_NAME, gcs_source=_TEST_SOURCE_URI_GCS, metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_NONTABULAR, @@ -483,7 +483,7 @@ def test_create_and_import_dataset( def test_import_data(self, import_data_mock, sync): aiplatform.init(project=_TEST_PROJECT) - my_dataset = datasets.Dataset(dataset_name=_TEST_NAME) + my_dataset = datasets._Dataset(dataset_name=_TEST_NAME) my_dataset.import_data( gcs_source=_TEST_SOURCE_URI_GCS, @@ -509,7 +509,7 @@ def test_import_data(self, import_data_mock, sync): def test_export_data(self, export_data_mock): aiplatform.init(project=_TEST_PROJECT) - my_dataset = datasets.Dataset(dataset_name=_TEST_NAME) + my_dataset = datasets._Dataset(dataset_name=_TEST_NAME) my_dataset.export_data(output_dir=_TEST_OUTPUT_DIR) @@ -528,7 +528,7 @@ def test_create_then_import( aiplatform.init(project=_TEST_PROJECT) - my_dataset = datasets.Dataset.create( + my_dataset = datasets._Dataset.create( display_name=_TEST_DISPLAY_NAME, metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_NONTABULAR, encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME, diff --git a/tests/unit/aiplatform/test_end_to_end.py b/tests/unit/aiplatform/test_end_to_end.py index 4937c95e34..6288628d4c 100644 --- a/tests/unit/aiplatform/test_end_to_end.py +++ b/tests/unit/aiplatform/test_end_to_end.py @@ -96,9 +96,8 @@ def test_dataset_create_to_model_predict( credentials=test_training_jobs._TEST_CREDENTIALS, ) - my_dataset = aiplatform.Dataset.create( + my_dataset = aiplatform.ImageDataset.create( display_name=test_datasets._TEST_DISPLAY_NAME, - metadata_schema_uri=test_datasets._TEST_METADATA_SCHEMA_URI_NONTABULAR, encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME, sync=sync, ) @@ -301,10 +300,8 @@ def test_dataset_create_to_model_predict_with_pipeline_fail( encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME, ) - my_dataset = aiplatform.Dataset.create( - display_name=test_datasets._TEST_DISPLAY_NAME, - metadata_schema_uri=test_datasets._TEST_METADATA_SCHEMA_URI_NONTABULAR, - sync=sync, + my_dataset = 
aiplatform.ImageDataset.create( + display_name=test_datasets._TEST_DISPLAY_NAME, sync=sync, ) my_dataset.import_data( diff --git a/tests/unit/aiplatform/test_training_jobs.py b/tests/unit/aiplatform/test_training_jobs.py index 07585d7c3a..2081f90df8 100644 --- a/tests/unit/aiplatform/test_training_jobs.py +++ b/tests/unit/aiplatform/test_training_jobs.py @@ -490,7 +490,7 @@ def mock_python_package_to_gcs(): @pytest.fixture def mock_tabular_dataset(): - ds = mock.MagicMock(datasets.Dataset) + ds = mock.MagicMock(datasets.TabularDataset) ds.name = _TEST_DATASET_NAME ds._latest_future = None ds._gca_resource = gca_dataset.Dataset( @@ -505,7 +505,7 @@ def mock_tabular_dataset(): @pytest.fixture def mock_nontabular_dataset(): - ds = mock.MagicMock(datasets.Dataset) + ds = mock.MagicMock(datasets.ImageDataset) ds.name = _TEST_DATASET_NAME ds._latest_future = None ds._gca_resource = gca_dataset.Dataset(