Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add Pandas DataFrame support to TabularDataset #1185

Merged
merged 12 commits into from
May 5, 2022
114 changes: 113 additions & 1 deletion google/cloud/aiplatform/datasets/tabular_dataset.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-

# Copyright 2020 Google LLC
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -19,12 +19,18 @@

sararob marked this conversation as resolved.
Show resolved Hide resolved
from google.auth import credentials as auth_credentials

from google.cloud import bigquery
from google.cloud.aiplatform import base
from google.cloud.aiplatform import datasets
from google.cloud.aiplatform.datasets import _datasources
from google.cloud.aiplatform import initializer
from google.cloud.aiplatform import schema
from google.cloud.aiplatform import utils

_AUTOML_TRAINING_MIN_ROWS = 1000

_LOGGER = base.Logger(__name__)


class TabularDataset(datasets._ColumnNamesDataset):
"""Managed tabular dataset resource for Vertex AI."""
Expand Down Expand Up @@ -146,6 +152,112 @@ def create(
create_request_timeout=create_request_timeout,
)

@classmethod
def create_from_dataframe(
cls,
df_source: "pd.DataFrame", # noqa: F821 - skip check for undefined name 'pd'
staging_path: str,
bq_schema: Optional[Union[str, bigquery.SchemaField]] = None,
display_name: Optional[str] = None,
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
) -> "TabularDataset":
"""Creates a new tabular dataset from a Pandas DataFrame.

Args:
df_source (pd.DataFrame):
Required. Pandas DataFrame containing the source data for
ingestion as a TabularDataset. This method will use the data
types from the provided DataFrame when creating the dataset.
staging_path (str):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems the location requirements should also be documented or a reference to this documentation should be provided: https://cloud.google.com/vertex-ai/docs/general/locations#bq-locations

If possible they should be validated but not a hard requirement. Is it possible for the dataset create to fail because of the regional requirements?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point, I just tested it and it can fail if the dataset location doesn't match the project location or the service doesn't have the right access to the dataset. I'll update the docstring to link to that page.

In terms of validating, the BQ client throws this error: google.api_core.exceptions.FailedPrecondition: 400 BigQuery Dataset location eu must be in the same location as the service location us-central1.

Do you think we should validate as well or let the BQ client handle validation? If we do validation, we'd need to use the BQ client to check the location of the provided BQ dataset string.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agree with relying on BQ client.

Required. The BigQuery table to stage the data
for Vertex. Because Vertex maintains a reference to this source
to create the Vertex Dataset, this BigQuery table should
not be deleted. Example: `bq://my-project.my-dataset.my-table`.
If the provided BigQuery table doesn't exist, this method will
create the table. If the provided BigQuery table already exists,
and the schemas of the BigQuery table and your DataFrame match,
this method will append the data in your local DataFrame to the table.
The location of the provided BigQuery table should conform to the location requirements
specified here: https://cloud.google.com/vertex-ai/docs/general/locations#bq-locations.
bq_schema (Optional[Union[str, bigquery.SchemaField]]):
Optional. If not set, BigQuery will autodetect the schema using your DataFrame's column types.
If set, BigQuery will use the schema you provide when creating the staging table. For more details,
see: https://cloud.google.com/python/docs/reference/bigquery/latest/google.cloud.bigquery.job.LoadJobConfig#google_cloud_bigquery_job_LoadJobConfig_schema
display_name (str):
Optional. The user-defined name of the Dataset.
The name can be up to 128 characters long and can be consist
of any UTF-8 charact
project (str):
Optional. Project to upload this dataset to. Overrides project set in
aiplatform.init.
location (str):
Optional. Location to upload this dataset to. Overrides location set in
aiplatform.init.
credentials (auth_credentials.Credentials):
Optional. Custom credentials to use to upload this dataset. Overrides
credentials set in aiplatform.init.
Returns:
tabular_dataset (TabularDataset):
Instantiated representation of the managed tabular dataset resource.
"""

if staging_path.startswith("bq://"):
bq_staging_path = staging_path[len("bq://") :]
else:
raise ValueError(
"Only BigQuery staging paths are supported. Provide a staging path in the format `bq://your-project.your-dataset.your-table`."
)

try:
import pyarrow # noqa: F401 - skip check for 'pyarrow' which is required when using 'google.cloud.bigquery'
except ImportError:
raise ImportError(
"Pyarrow is not installed, and is required to use the BigQuery client."
'Please install the SDK using "pip install google-cloud-aiplatform[datasets]"'
)

if len(df_source) < _AUTOML_TRAINING_MIN_ROWS:
_LOGGER.info(
"Your DataFrame has %s rows and AutoML requires %s rows to train on tabular data. You can still train a custom model once your dataset has been uploaded to Vertex, but you will not be able to use AutoML for training."
% (len(df_source), _AUTOML_TRAINING_MIN_ROWS),
)

bigquery_client = bigquery.Client(
project=project or initializer.global_config.project,
credentials=credentials or initializer.global_config.credentials,
)

try:
parquet_options = bigquery.format_options.ParquetOptions()
parquet_options.enable_list_inference = True

job_config = bigquery.LoadJobConfig(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will this config infer all the types? I see the enable_list_inference but I couldn't find a reference in the BQ docs for non list type inference.

Copy link
Contributor Author

@sararob sararob Apr 27, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this will infer the data types from the DF. From the BQ docs:

            The destination table to use for loading the data. If it is an
            existing table, the schema of the :class:`~pandas.DataFrame`
            must match the schema of the destination table. If the table
            does not yet exist, the schema is inferred from the
            :class:`~pandas.DataFrame`.

I added a bq_schema param to give the user the option to override the data type autodection, but I think I may need to add more client-side validation on that.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can rely on BQ client validation.

source_format=bigquery.SourceFormat.PARQUET,
parquet_options=parquet_options,
)

if bq_schema:
job_config.schema = bq_schema

job = bigquery_client.load_table_from_dataframe(
dataframe=df_source, destination=bq_staging_path, job_config=job_config
)

job.result()

finally:
dataset_from_dataframe = cls.create(
display_name=display_name,
bq_source=staging_path,
project=project,
location=location,
credentials=credentials,
)

return dataset_from_dataframe

def import_data(self):
raise NotImplementedError(
f"{self.__class__.__name__} class does not support 'import_data'"
Expand Down
4 changes: 4 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@
pipelines_extra_requires = [
"pyyaml>=5.3,<6",
]
datasets_extra_require = [
"pyarrow >= 3.0.0, < 8.0dev",
]
full_extra_require = list(
set(
tensorboard_extra_require
Expand All @@ -63,6 +66,7 @@
+ lit_extra_require
+ featurestore_extra_require
+ pipelines_extra_requires
+ datasets_extra_require
)
)
testing_extra_require = (
Expand Down
165 changes: 145 additions & 20 deletions tests/system/aiplatform/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,14 @@
import pytest
import importlib

import pandas as pd

from google import auth as google_auth
from google.api_core import exceptions
from google.api_core import client_options

from google.cloud import bigquery

from google.cloud import aiplatform
from google.cloud import storage
from google.cloud.aiplatform import utils
Expand All @@ -33,6 +37,8 @@

from test_utils.vpcsc_config import vpcsc_config

from tests.system.aiplatform import e2e_base

# TODO(vinnys): Replace with env var `BUILD_SPECIFIC_GCP_PROJECT` once supported
_, _TEST_PROJECT = google_auth.default()
TEST_BUCKET = os.environ.get(
Expand All @@ -55,40 +61,91 @@
_TEST_TEXT_ENTITY_IMPORT_SCHEMA = "gs://google-cloud-aiplatform/schema/dataset/ioformat/text_extraction_io_format_1.0.0.yaml"
_TEST_IMAGE_OBJ_DET_IMPORT_SCHEMA = "gs://google-cloud-aiplatform/schema/dataset/ioformat/image_bounding_box_io_format_1.0.0.yaml"

# create_from_dataframe
_TEST_BOOL_COL = "bool_col"
_TEST_BOOL_ARR_COL = "bool_array_col"
_TEST_DOUBLE_COL = "double_col"
_TEST_DOUBLE_ARR_COL = "double_array_col"
_TEST_INT_COL = "int64_col"
_TEST_INT_ARR_COL = "int64_array_col"
_TEST_STR_COL = "string_col"
_TEST_STR_ARR_COL = "string_array_col"
_TEST_BYTES_COL = "bytes_col"
_TEST_DF_COLUMN_NAMES = [
_TEST_BOOL_COL,
_TEST_BOOL_ARR_COL,
_TEST_DOUBLE_COL,
_TEST_DOUBLE_ARR_COL,
_TEST_INT_COL,
_TEST_INT_ARR_COL,
_TEST_STR_COL,
_TEST_STR_ARR_COL,
_TEST_BYTES_COL,
]
_TEST_DATAFRAME = pd.DataFrame(
data=[
[
False,
[True, False],
1.2,
[1.2, 3.4],
1,
[1, 2],
"test",
["test1", "test2"],
b"1",
],
[
True,
[True, True],
2.2,
[2.2, 4.4],
2,
[2, 3],
"test1",
["test2", "test3"],
b"0",
],
],
columns=_TEST_DF_COLUMN_NAMES,
)
_TEST_DATAFRAME_BQ_SCHEMA = [
bigquery.SchemaField(name="bool_col", field_type="BOOL"),
bigquery.SchemaField(name="bool_array_col", field_type="BOOL", mode="REPEATED"),
bigquery.SchemaField(name="double_col", field_type="FLOAT"),
bigquery.SchemaField(name="double_array_col", field_type="FLOAT", mode="REPEATED"),
bigquery.SchemaField(name="int64_col", field_type="INTEGER"),
bigquery.SchemaField(name="int64_array_col", field_type="INTEGER", mode="REPEATED"),
bigquery.SchemaField(name="string_col", field_type="STRING"),
bigquery.SchemaField(name="string_array_col", field_type="STRING", mode="REPEATED"),
bigquery.SchemaField(name="bytes_col", field_type="STRING"),
]


@pytest.mark.usefixtures(
"prepare_staging_bucket",
"delete_staging_bucket",
"prepare_bigquery_dataset",
"delete_bigquery_dataset",
"tear_down_resources",
)
class TestDataset(e2e_base.TestEndToEnd):

_temp_prefix = "temp-vertex-sdk-dataset-test"

class TestDataset:
def setup_method(self):
importlib.reload(initializer)
importlib.reload(aiplatform)

@pytest.fixture()
def shared_state(self):
shared_state = {}
yield shared_state

@pytest.fixture()
def create_staging_bucket(self, shared_state):
new_staging_bucket = f"temp-sdk-integration-{uuid.uuid4()}"

storage_client = storage.Client()
storage_client.create_bucket(new_staging_bucket)
shared_state["storage_client"] = storage_client
shared_state["staging_bucket"] = new_staging_bucket
yield

@pytest.fixture()
def delete_staging_bucket(self, shared_state):
yield
storage_client = shared_state["storage_client"]

# Delete temp staging bucket
bucket_to_delete = storage_client.get_bucket(shared_state["staging_bucket"])
bucket_to_delete.delete(force=True)

# Close Storage Client
storage_client._http._auth_request.session.close()
storage_client._http.close()

@pytest.fixture()
def dataset_gapic_client(self):
gapic_client = dataset_service.DatasetServiceClient(
Expand Down Expand Up @@ -253,6 +310,74 @@ def test_create_tabular_dataset(self, dataset_gapic_client, shared_state):
== aiplatform.schema.dataset.metadata.tabular
)

@pytest.mark.usefixtures("delete_new_dataset")
def test_create_tabular_dataset_from_dataframe(
self, dataset_gapic_client, shared_state
):
"""Use the Dataset.create_from_dataframe() method to create a new tabular dataset.
Then confirm the dataset was successfully created and references the BQ source."""

assert shared_state["bigquery_dataset"]

shared_state["resources"] = []

bigquery_dataset_id = shared_state["bigquery_dataset_id"]
bq_staging_table = f"bq://{bigquery_dataset_id}.test_table{uuid.uuid4()}"

aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)

tabular_dataset = aiplatform.TabularDataset.create_from_dataframe(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tabular_dataset should be appended to shared_state['resources'] so it's deleted after the test

see example: https://github.com/googleapis/python-aiplatform/blob/main/tests/system/aiplatform/test_e2e_tabular.py#L89

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated both new tests to use shared_state['resources']

df_source=_TEST_DATAFRAME,
staging_path=bq_staging_table,
display_name=f"temp_sdk_integration_create_and_import_dataset_from_dataframe{uuid.uuid4()}",
)
shared_state["resources"].extend([tabular_dataset])
shared_state["dataset_name"] = tabular_dataset.resource_name

gapic_metadata = tabular_dataset.to_dict()["metadata"]
bq_source = gapic_metadata["inputConfig"]["bigquerySource"]["uri"]

assert bq_staging_table == bq_source
assert (
tabular_dataset.metadata_schema_uri
== aiplatform.schema.dataset.metadata.tabular
)

@pytest.mark.usefixtures("delete_new_dataset")
def test_create_tabular_dataset_from_dataframe_with_provided_schema(
self, dataset_gapic_client, shared_state
):
"""Use the Dataset.create_from_dataframe() method to create a new tabular dataset,
passing in the optional `bq_schema` argument. Then confirm the dataset was successfully
created and references the BQ source."""

assert shared_state["bigquery_dataset"]

shared_state["resources"] = []

bigquery_dataset_id = shared_state["bigquery_dataset_id"]
bq_staging_table = f"bq://{bigquery_dataset_id}.test_table{uuid.uuid4()}"

aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)

tabular_dataset = aiplatform.TabularDataset.create_from_dataframe(
df_source=_TEST_DATAFRAME,
staging_path=bq_staging_table,
display_name=f"temp_sdk_integration_create_and_import_dataset_from_dataframe{uuid.uuid4()}",
bq_schema=_TEST_DATAFRAME_BQ_SCHEMA,
)
shared_state["resources"].extend([tabular_dataset])
shared_state["dataset_name"] = tabular_dataset.resource_name

gapic_metadata = tabular_dataset.to_dict()["metadata"]
bq_source = gapic_metadata["inputConfig"]["bigquerySource"]["uri"]

assert bq_staging_table == bq_source
assert (
tabular_dataset.metadata_schema_uri
== aiplatform.schema.dataset.metadata.tabular
)

# TODO(vinnys): Remove pytest skip once persistent resources are accessible
@pytest.mark.skip(reason="System tests cannot access persistent test resources")
@pytest.mark.usefixtures("create_staging_bucket", "delete_staging_bucket")
Expand Down
Loading