From 4fe4558ea0aaf73e3c0e9715ae90cb729a4c5678 Mon Sep 17 00:00:00 2001
From: Sara Robinson
Date: Thu, 5 May 2022 12:21:55 -0400
Subject: [PATCH] feat: add Pandas DataFrame support to TabularDataset (#1185)

* add create_from_dataframe method
* add tests for create_from_dataframe
* update docstrings and run linter
* update docstrings and make display_name optional
* updates from sashas feedback: added integration test, update validations
* remove some logging
* update error handling on bq_schema arg
* updates from sashas feedback
* update bq_schema docstring
---
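Usage sketch of the new method (not part of the change below; the project, dataset,
table, and column names are placeholders, and the `[datasets]` extra supplies the
pyarrow dependency the loader needs):

    # pip install "google-cloud-aiplatform[datasets]"
    import pandas as pd
    from google.cloud import aiplatform

    aiplatform.init(project="my-project", location="us-central1")

    df = pd.DataFrame(
        {"sqft": [1200, 900], "bedrooms": [3, 2], "price": [450000.0, 310000.0]}
    )

    # Column dtypes are used to infer the BigQuery schema unless bq_schema is passed.
    dataset = aiplatform.TabularDataset.create_from_dataframe(
        df_source=df,
        staging_path="bq://my-project.my_dataset.housing_staging",
        display_name="housing-tabular-dataset",
    )
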
 .../aiplatform/datasets/tabular_dataset.py |  114 ++++++++++-
 setup.py                                   |    4 +
 tests/system/aiplatform/test_dataset.py    |  165 ++++++++++++++--
 tests/unit/aiplatform/test_datasets.py     |  177 ++++++++++++++++++
 4 files changed, 439 insertions(+), 21 deletions(-)

diff --git a/google/cloud/aiplatform/datasets/tabular_dataset.py b/google/cloud/aiplatform/datasets/tabular_dataset.py
index ec9769bb7f..732cebe26f 100644
--- a/google/cloud/aiplatform/datasets/tabular_dataset.py
+++ b/google/cloud/aiplatform/datasets/tabular_dataset.py
@@ -1,6 +1,6 @@
 # -*- coding: utf-8 -*-

-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -19,12 +19,18 @@

 from google.auth import credentials as auth_credentials

+from google.cloud import bigquery
+from google.cloud.aiplatform import base
 from google.cloud.aiplatform import datasets
 from google.cloud.aiplatform.datasets import _datasources
 from google.cloud.aiplatform import initializer
 from google.cloud.aiplatform import schema
 from google.cloud.aiplatform import utils

+_AUTOML_TRAINING_MIN_ROWS = 1000
+
+_LOGGER = base.Logger(__name__)
+

 class TabularDataset(datasets._ColumnNamesDataset):
     """Managed tabular dataset resource for Vertex AI."""

@@ -146,6 +152,112 @@ def create(
             create_request_timeout=create_request_timeout,
         )

+    @classmethod
+    def create_from_dataframe(
+        cls,
+        df_source: "pd.DataFrame",  # noqa: F821 - skip check for undefined name 'pd'
+        staging_path: str,
+        bq_schema: Optional[Union[str, bigquery.SchemaField]] = None,
+        display_name: Optional[str] = None,
+        project: Optional[str] = None,
+        location: Optional[str] = None,
+        credentials: Optional[auth_credentials.Credentials] = None,
+    ) -> "TabularDataset":
+        """Creates a new tabular dataset from a Pandas DataFrame.
+
+        Args:
+            df_source (pd.DataFrame):
+                Required. Pandas DataFrame containing the source data for
+                ingestion as a TabularDataset. This method will use the data
+                types from the provided DataFrame when creating the dataset.
+            staging_path (str):
+                Required. The BigQuery table to stage the data
+                for Vertex. Because Vertex maintains a reference to this source
+                to create the Vertex Dataset, this BigQuery table should
+                not be deleted. Example: `bq://my-project.my-dataset.my-table`.
+                If the provided BigQuery table doesn't exist, this method will
+                create the table. If the provided BigQuery table already exists,
+                and the schemas of the BigQuery table and your DataFrame match,
+                this method will append the data in your local DataFrame to the table.
+                The location of the provided BigQuery table should conform to the location requirements
+                specified here: https://cloud.google.com/vertex-ai/docs/general/locations#bq-locations.
+            bq_schema (Optional[Union[str, bigquery.SchemaField]]):
+                Optional. If not set, BigQuery will autodetect the schema using your DataFrame's column types.
+                If set, BigQuery will use the schema you provide when creating the staging table. For more details,
+                see: https://cloud.google.com/python/docs/reference/bigquery/latest/google.cloud.bigquery.job.LoadJobConfig#google_cloud_bigquery_job_LoadJobConfig_schema
+            display_name (str):
+                Optional. The user-defined name of the Dataset.
+                The name can be up to 128 characters long and can consist
+                of any UTF-8 characters.
+            project (str):
+                Optional. Project to upload this dataset to. Overrides project set in
+                aiplatform.init.
+            location (str):
+                Optional. Location to upload this dataset to. Overrides location set in
+                aiplatform.init.
+            credentials (auth_credentials.Credentials):
+                Optional. Custom credentials to use to upload this dataset. Overrides
+                credentials set in aiplatform.init.
+        Returns:
+            tabular_dataset (TabularDataset):
+                Instantiated representation of the managed tabular dataset resource.
+        """
+
+        if staging_path.startswith("bq://"):
+            bq_staging_path = staging_path[len("bq://") :]
+        else:
+            raise ValueError(
+                "Only BigQuery staging paths are supported. Provide a staging path in the format `bq://your-project.your-dataset.your-table`."
+            )
+
+        try:
+            import pyarrow  # noqa: F401 - skip check for 'pyarrow' which is required when using 'google.cloud.bigquery'
+        except ImportError:
+            raise ImportError(
+                "Pyarrow is not installed, and is required to use the BigQuery client. "
+                'Please install the SDK using "pip install google-cloud-aiplatform[datasets]"'
+            )
+
+        if len(df_source) < _AUTOML_TRAINING_MIN_ROWS:
+            _LOGGER.info(
+                "Your DataFrame has %s rows and AutoML requires %s rows to train on tabular data. You can still train a custom model once your dataset has been uploaded to Vertex, but you will not be able to use AutoML for training."
+                % (len(df_source), _AUTOML_TRAINING_MIN_ROWS),
+            )
+
+        bigquery_client = bigquery.Client(
+            project=project or initializer.global_config.project,
+            credentials=credentials or initializer.global_config.credentials,
+        )
+
+        try:
+            parquet_options = bigquery.format_options.ParquetOptions()
+            parquet_options.enable_list_inference = True
+
+            job_config = bigquery.LoadJobConfig(
+                source_format=bigquery.SourceFormat.PARQUET,
+                parquet_options=parquet_options,
+            )
+
+            if bq_schema:
+                job_config.schema = bq_schema
+
+            job = bigquery_client.load_table_from_dataframe(
+                dataframe=df_source, destination=bq_staging_path, job_config=job_config
+            )
+
+            job.result()
+
+        finally:
+            dataset_from_dataframe = cls.create(
+                display_name=display_name,
+                bq_source=staging_path,
+                project=project,
+                location=location,
+                credentials=credentials,
+            )
+
+        return dataset_from_dataframe
+
     def import_data(self):
         raise NotImplementedError(
             f"{self.__class__.__name__} class does not support 'import_data'"
diff --git a/setup.py b/setup.py
index 398b7ab654..7db8ad5f27 100644
--- a/setup.py
+++ b/setup.py
@@ -55,6 +55,9 @@
 pipelines_extra_requires = [
     "pyyaml>=5.3,<6",
 ]
+datasets_extra_require = [
+    "pyarrow >= 3.0.0, < 8.0dev",
+]
 full_extra_require = list(
     set(
         tensorboard_extra_require
@@ -63,6 +66,7 @@
         + lit_extra_require
         + featurestore_extra_require
        + pipelines_extra_requires
+        + datasets_extra_require
     )
 )
 testing_extra_require = (
diff --git a/tests/system/aiplatform/test_dataset.py b/tests/system/aiplatform/test_dataset.py
index d8d8bd53e3..81b5e420e9 100644
--- a/tests/system/aiplatform/test_dataset.py
+++ b/tests/system/aiplatform/test_dataset.py
@@ -20,10 +20,14 @@
 import pytest
 import importlib

+import pandas as pd
+
 from google import auth as google_auth
 from google.api_core import exceptions
 from google.api_core import client_options

+from google.cloud import bigquery
+
 from google.cloud import aiplatform
 from google.cloud import storage
 from google.cloud.aiplatform import utils
@@ -33,6 +37,8 @@

 from test_utils.vpcsc_config import vpcsc_config

+from tests.system.aiplatform import e2e_base
+
 # TODO(vinnys): Replace with env var `BUILD_SPECIFIC_GCP_PROJECT` once supported
 _, _TEST_PROJECT = google_auth.default()
 TEST_BUCKET = os.environ.get(
@@ -55,40 +61,91 @@
 _TEST_TEXT_ENTITY_IMPORT_SCHEMA = "gs://google-cloud-aiplatform/schema/dataset/ioformat/text_extraction_io_format_1.0.0.yaml"
 _TEST_IMAGE_OBJ_DET_IMPORT_SCHEMA = "gs://google-cloud-aiplatform/schema/dataset/ioformat/image_bounding_box_io_format_1.0.0.yaml"

+# create_from_dataframe
+_TEST_BOOL_COL = "bool_col"
+_TEST_BOOL_ARR_COL = "bool_array_col"
+_TEST_DOUBLE_COL = "double_col"
+_TEST_DOUBLE_ARR_COL = "double_array_col"
+_TEST_INT_COL = "int64_col"
+_TEST_INT_ARR_COL = "int64_array_col"
+_TEST_STR_COL = "string_col"
+_TEST_STR_ARR_COL = "string_array_col"
+_TEST_BYTES_COL = "bytes_col"
+_TEST_DF_COLUMN_NAMES = [
+    _TEST_BOOL_COL,
+    _TEST_BOOL_ARR_COL,
+    _TEST_DOUBLE_COL,
+    _TEST_DOUBLE_ARR_COL,
+    _TEST_INT_COL,
+    _TEST_INT_ARR_COL,
+    _TEST_STR_COL,
+    _TEST_STR_ARR_COL,
+    _TEST_BYTES_COL,
+]
+_TEST_DATAFRAME = pd.DataFrame(
+    data=[
+        [
+            False,
+            [True, False],
+            1.2,
+            [1.2, 3.4],
+            1,
+            [1, 2],
+            "test",
+            ["test1", "test2"],
+            b"1",
+        ],
+        [
+            True,
+            [True, True],
+            2.2,
+            [2.2, 4.4],
+            2,
+            [2, 3],
+            "test1",
+            ["test2", "test3"],
+            b"0",
+        ],
+    ],
+    columns=_TEST_DF_COLUMN_NAMES,
+)
+_TEST_DATAFRAME_BQ_SCHEMA = [
+    bigquery.SchemaField(name="bool_col", field_type="BOOL"),
+    bigquery.SchemaField(name="bool_array_col", field_type="BOOL", mode="REPEATED"),
+    bigquery.SchemaField(name="double_col", field_type="FLOAT"),
+    bigquery.SchemaField(name="double_array_col", field_type="FLOAT", mode="REPEATED"),
+    bigquery.SchemaField(name="int64_col", field_type="INTEGER"),
+    bigquery.SchemaField(name="int64_array_col", field_type="INTEGER", mode="REPEATED"),
+    bigquery.SchemaField(name="string_col", field_type="STRING"),
+    bigquery.SchemaField(name="string_array_col", field_type="STRING", mode="REPEATED"),
+    bigquery.SchemaField(name="bytes_col", field_type="STRING"),
+]
+
+
+@pytest.mark.usefixtures(
+    "prepare_staging_bucket",
+    "delete_staging_bucket",
+    "prepare_bigquery_dataset",
+    "delete_bigquery_dataset",
+    "tear_down_resources",
+)
+class TestDataset(e2e_base.TestEndToEnd):
+
+    _temp_prefix = "temp-vertex-sdk-dataset-test"

-class TestDataset:
     def setup_method(self):
         importlib.reload(initializer)
         importlib.reload(aiplatform)

-    @pytest.fixture()
-    def shared_state(self):
-        shared_state = {}
-        yield shared_state
-
     @pytest.fixture()
     def create_staging_bucket(self, shared_state):
         new_staging_bucket = f"temp-sdk-integration-{uuid.uuid4()}"
-
         storage_client = storage.Client()
         storage_client.create_bucket(new_staging_bucket)
         shared_state["storage_client"] = storage_client
         shared_state["staging_bucket"] = new_staging_bucket
         yield

-    @pytest.fixture()
-    def delete_staging_bucket(self, shared_state):
-        yield
-        storage_client = shared_state["storage_client"]
-
-        # Delete temp staging bucket
-        bucket_to_delete = storage_client.get_bucket(shared_state["staging_bucket"])
-        bucket_to_delete.delete(force=True)
-
-        # Close Storage Client
-        storage_client._http._auth_request.session.close()
-        storage_client._http.close()
-
     @pytest.fixture()
     def dataset_gapic_client(self):
         gapic_client = dataset_service.DatasetServiceClient(
@@ -253,6 +310,74 @@ def test_create_tabular_dataset(self, dataset_gapic_client, shared_state):
             == aiplatform.schema.dataset.metadata.tabular
         )

+    @pytest.mark.usefixtures("delete_new_dataset")
+    def test_create_tabular_dataset_from_dataframe(
+        self, dataset_gapic_client, shared_state
+    ):
+        """Use the Dataset.create_from_dataframe() method to create a new tabular dataset.
+        Then confirm the dataset was successfully created and references the BQ source."""
+
+        assert shared_state["bigquery_dataset"]
+
+        shared_state["resources"] = []
+
+        bigquery_dataset_id = shared_state["bigquery_dataset_id"]
+        bq_staging_table = f"bq://{bigquery_dataset_id}.test_table{uuid.uuid4()}"
+
+        aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)
+
+        tabular_dataset = aiplatform.TabularDataset.create_from_dataframe(
+            df_source=_TEST_DATAFRAME,
+            staging_path=bq_staging_table,
+            display_name=f"temp_sdk_integration_create_and_import_dataset_from_dataframe{uuid.uuid4()}",
+        )
+        shared_state["resources"].extend([tabular_dataset])
+        shared_state["dataset_name"] = tabular_dataset.resource_name
+
+        gapic_metadata = tabular_dataset.to_dict()["metadata"]
+        bq_source = gapic_metadata["inputConfig"]["bigquerySource"]["uri"]
+
+        assert bq_staging_table == bq_source
+        assert (
+            tabular_dataset.metadata_schema_uri
+            == aiplatform.schema.dataset.metadata.tabular
+        )
+
+    @pytest.mark.usefixtures("delete_new_dataset")
+    def test_create_tabular_dataset_from_dataframe_with_provided_schema(
+        self, dataset_gapic_client, shared_state
+    ):
+        """Use the Dataset.create_from_dataframe() method to create a new tabular dataset,
+        passing in the optional `bq_schema` argument. Then confirm the dataset was successfully
+        created and references the BQ source."""
+
+        assert shared_state["bigquery_dataset"]
+
+        shared_state["resources"] = []
+
+        bigquery_dataset_id = shared_state["bigquery_dataset_id"]
+        bq_staging_table = f"bq://{bigquery_dataset_id}.test_table{uuid.uuid4()}"
+
+        aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)
+
+        tabular_dataset = aiplatform.TabularDataset.create_from_dataframe(
+            df_source=_TEST_DATAFRAME,
+            staging_path=bq_staging_table,
+            display_name=f"temp_sdk_integration_create_and_import_dataset_from_dataframe{uuid.uuid4()}",
+            bq_schema=_TEST_DATAFRAME_BQ_SCHEMA,
+        )
+        shared_state["resources"].extend([tabular_dataset])
+        shared_state["dataset_name"] = tabular_dataset.resource_name
+
+        gapic_metadata = tabular_dataset.to_dict()["metadata"]
+        bq_source = gapic_metadata["inputConfig"]["bigquerySource"]["uri"]
+
+        assert bq_staging_table == bq_source
+        assert (
+            tabular_dataset.metadata_schema_uri
+            == aiplatform.schema.dataset.metadata.tabular
+        )
+
     # TODO(vinnys): Remove pytest skip once persistent resources are accessible
     @pytest.mark.skip(reason="System tests cannot access persistent test resources")
     @pytest.mark.usefixtures("create_staging_bucket", "delete_staging_bucket")
diff --git a/tests/unit/aiplatform/test_datasets.py b/tests/unit/aiplatform/test_datasets.py
index 5f6241802f..13ef13aebd 100644
--- a/tests/unit/aiplatform/test_datasets.py
+++ b/tests/unit/aiplatform/test_datasets.py
@@ -17,6 +17,8 @@
 import os

+import pandas as pd
+
 import pytest

 from unittest import mock
@@ -147,6 +149,72 @@
 _TEST_LABELS = {"my_key": "my_value"}

+# create_from_dataframe
+_TEST_INVALID_SOURCE_URI_BQ = "my-project.my-dataset.table"
+
+_TEST_BOOL_COL = "bool_col"
+_TEST_BOOL_ARR_COL = "bool_array_col"
+_TEST_DOUBLE_COL = "double_col"
+_TEST_DOUBLE_ARR_COL = "double_array_col"
+_TEST_INT_COL = "int64_col"
+_TEST_INT_ARR_COL = "int64_array_col"
+_TEST_STR_COL = "string_col"
+_TEST_STR_ARR_COL = "string_array_col"
+_TEST_BYTES_COL = "bytes_col"
+_TEST_DF_COLUMN_NAMES = [
+    _TEST_BOOL_COL,
+    _TEST_BOOL_ARR_COL,
+    _TEST_DOUBLE_COL,
+    _TEST_DOUBLE_ARR_COL,
+    _TEST_INT_COL,
+    _TEST_INT_ARR_COL,
+    _TEST_STR_COL,
+    _TEST_STR_ARR_COL,
+    _TEST_BYTES_COL,
+]
+_TEST_DATAFRAME = pd.DataFrame(
+    data=[
+        [
+            False,
+            [True, False],
+            1.2,
+            [1.2, 3.4],
+            1,
+            [1, 2],
+            "test",
+            ["test1", "test2"],
+            b"1",
+        ],
+        [
+            True,
+            [True, True],
+            2.2,
+            [2.2, 4.4],
+            2,
+            [2, 3],
+            "test1",
+            ["test2", "test3"],
+            b"0",
+        ],
+    ],
+    columns=_TEST_DF_COLUMN_NAMES,
+)
+_TEST_CREDENTIALS = mock.Mock(spec=auth_credentials.AnonymousCredentials())
+_TEST_DATAFRAME_BQ_SCHEMA = [
+    bigquery.SchemaField(name="bool_col", field_type="BOOL"),
+    bigquery.SchemaField(name="bool_array_col", field_type="BOOL", mode="REPEATED"),
+    bigquery.SchemaField(name="double_col", field_type="FLOAT"),
+    bigquery.SchemaField(name="double_array_col", field_type="FLOAT", mode="REPEATED"),
+    bigquery.SchemaField(name="int64_col", field_type="INTEGER"),
+    bigquery.SchemaField(name="int64_array_col", field_type="INTEGER", mode="REPEATED"),
+    bigquery.SchemaField(name="string_col", field_type="STRING"),
+    bigquery.SchemaField(name="string_array_col", field_type="STRING", mode="REPEATED"),
+    bigquery.SchemaField(name="bytes_col", field_type="STRING"),
+]
+_TEST_DATAFRAME_INVALID_BQ_SCHEMA = [
+    bigquery.SchemaField(name="bool_col", field_type="BOOL"),
+]
+

 @pytest.fixture
 def get_dataset_mock():
"double_array_col" +_TEST_INT_COL = "int64_col" +_TEST_INT_ARR_COL = "int64_array_col" +_TEST_STR_COL = "string_col" +_TEST_STR_ARR_COL = "string_array_col" +_TEST_BYTES_COL = "bytes_col" +_TEST_DF_COLUMN_NAMES = [ + _TEST_BOOL_COL, + _TEST_BOOL_ARR_COL, + _TEST_DOUBLE_COL, + _TEST_DOUBLE_ARR_COL, + _TEST_INT_COL, + _TEST_INT_ARR_COL, + _TEST_STR_COL, + _TEST_STR_ARR_COL, + _TEST_BYTES_COL, +] +_TEST_DATAFRAME = pd.DataFrame( + data=[ + [ + False, + [True, False], + 1.2, + [1.2, 3.4], + 1, + [1, 2], + "test", + ["test1", "test2"], + b"1", + ], + [ + True, + [True, True], + 2.2, + [2.2, 4.4], + 2, + [2, 3], + "test1", + ["test2", "test3"], + b"0", + ], + ], + columns=_TEST_DF_COLUMN_NAMES, +) +_TEST_CREDENTIALS = mock.Mock(spec=auth_credentials.AnonymousCredentials()) +_TEST_DATAFRAME_BQ_SCHEMA = [ + bigquery.SchemaField(name="bool_col", field_type="BOOL"), + bigquery.SchemaField(name="bool_array_col", field_type="BOOL", mode="REPEATED"), + bigquery.SchemaField(name="double_col", field_type="FLOAT"), + bigquery.SchemaField(name="double_array_col", field_type="FLOAT", mode="REPEATED"), + bigquery.SchemaField(name="int64_col", field_type="INTEGER"), + bigquery.SchemaField(name="int64_array_col", field_type="INTEGER", mode="REPEATED"), + bigquery.SchemaField(name="string_col", field_type="STRING"), + bigquery.SchemaField(name="string_array_col", field_type="STRING", mode="REPEATED"), + bigquery.SchemaField(name="bytes_col", field_type="STRING"), +] +_TEST_DATAFRAME_INVALID_BQ_SCHEMA = [ + bigquery.SchemaField(name="bool_col", field_type="BOOL"), +] + @pytest.fixture def get_dataset_mock(): @@ -1438,6 +1506,115 @@ def test_create_dataset_with_labels(self, create_dataset_mock, sync): timeout=None, ) + @pytest.mark.usefixtures("get_dataset_tabular_bq_mock") + @pytest.mark.parametrize( + "source_df", + [_TEST_DATAFRAME], + ) + def test_create_dataset_tabular_from_dataframe( + self, + create_dataset_mock, + source_df, + bq_client_mock, + ): + + aiplatform.init( + project=_TEST_PROJECT, + credentials=_TEST_CREDENTIALS, + ) + dataset_from_df = datasets.TabularDataset.create_from_dataframe( + display_name=_TEST_DISPLAY_NAME, + df_source=source_df, + staging_path=_TEST_SOURCE_URI_BQ, + ) + + dataset_from_df.wait() + + assert dataset_from_df.metadata_schema_uri == _TEST_METADATA_SCHEMA_URI_TABULAR + + expected_dataset = gca_dataset.Dataset( + display_name=_TEST_DISPLAY_NAME, + metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_TABULAR, + metadata=_TEST_METADATA_TABULAR_BQ, + ) + + create_dataset_mock.assert_called_once_with( + parent=_TEST_PARENT, + dataset=expected_dataset, + metadata=_TEST_REQUEST_METADATA, + timeout=None, + ) + + assert bq_client_mock.call_args_list[0] == mock.call( + project=_TEST_PROJECT, + credentials=_TEST_CREDENTIALS, + ) + + @pytest.mark.usefixtures("get_dataset_tabular_bq_mock") + @pytest.mark.parametrize( + "source_df", + [_TEST_DATAFRAME], + ) + def test_create_dataset_tabular_from_dataframe_with_schema( + self, + create_dataset_mock, + source_df, + bq_client_mock, + ): + + aiplatform.init( + project=_TEST_PROJECT, + credentials=_TEST_CREDENTIALS, + ) + + dataset_from_df = datasets.TabularDataset.create_from_dataframe( + display_name=_TEST_DISPLAY_NAME, + df_source=source_df, + staging_path=_TEST_SOURCE_URI_BQ, + bq_schema=_TEST_DATAFRAME_BQ_SCHEMA, + ) + + dataset_from_df.wait() + + assert dataset_from_df.metadata_schema_uri == _TEST_METADATA_SCHEMA_URI_TABULAR + + expected_dataset = gca_dataset.Dataset( + display_name=_TEST_DISPLAY_NAME, + 
+
+    @pytest.mark.usefixtures("get_dataset_tabular_bq_mock")
+    @pytest.mark.parametrize(
+        "source_df",
+        [_TEST_DATAFRAME],
+    )
+    def test_create_dataset_tabular_from_dataframe_with_invalid_bq_uri(
+        self,
+        create_dataset_mock,
+        source_df,
+        bq_client_mock,
+    ):
+        aiplatform.init(project=_TEST_PROJECT)
+        with pytest.raises(ValueError):
+            datasets.TabularDataset.create_from_dataframe(
+                display_name=_TEST_DISPLAY_NAME,
+                df_source=source_df,
+                staging_path=_TEST_INVALID_SOURCE_URI_BQ,
+            )
+

 class TestTextDataset:
     def setup_method(self):