feat: add Pandas DataFrame support to TabularDataset #1185

Merged
12 commits merged on May 5, 2022
97 changes: 97 additions & 0 deletions google/cloud/aiplatform/datasets/tabular_dataset.py
@@ -19,12 +19,18 @@

from google.auth import credentials as auth_credentials

from google.cloud import bigquery
from google.cloud.aiplatform import base
from google.cloud.aiplatform import datasets
from google.cloud.aiplatform.datasets import _datasources
from google.cloud.aiplatform import initializer
from google.cloud.aiplatform import schema
from google.cloud.aiplatform import utils

_AUTOML_TRAINING_MIN_ROWS = 1000

_LOGGER = base.Logger(__name__)


class TabularDataset(datasets._ColumnNamesDataset):
"""Managed tabular dataset resource for Vertex AI."""
@@ -146,6 +152,97 @@ def create(
create_request_timeout=create_request_timeout,
)

@classmethod
def create_from_dataframe(
cls,
df_source: "pd.DataFrame", # noqa: F821 - skip check for undefined name 'pd'
staging_path: str,
display_name: Optional[str] = None,
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
) -> "TabularDataset":
"""Creates a new tabular dataset from a Pandas DataFrame.

Args:
df_source (pd.DataFrame):
Required. Pandas DataFrame containing the source data for
ingestion as a TabularDataset.
staging_path (str):
Required. The BigQuery table to stage the data
for Vertex. Because Vertex maintains a reference to this source
to create the Vertex Dataset, this BigQuery table should
not be deleted. Example: `bq://my-project.my-dataset.my-table`.
If the provided BigQuery table doesn't exist, this method will
create the table. If the provided BigQuery table already exists,
and the schemas of the BigQuery table and your DataFrame match,
this method will append the data in your local DataFrame to the table.
display_name (str):
Optional. The user-defined name of the Dataset.
The name can be up to 128 characters long and can consist
of any UTF-8 characters.
project (str):
Optional. Project to upload this dataset to. Overrides project set in
aiplatform.init.
location (str):
Optional. Location to upload this dataset to. Overrides location set in
aiplatform.init.
credentials (auth_credentials.Credentials):
Optional. Custom credentials to use to upload this dataset. Overrides
credentials set in aiplatform.init.
"""

Review thread on `staging_path`:

Member
It seems the location requirements should also be documented, or a reference to this documentation should be provided: https://cloud.google.com/vertex-ai/docs/general/locations#bq-locations

If possible they should be validated, though not as a hard requirement. Is it possible for the dataset create to fail because of the regional requirements?

Contributor Author
Good point, I just tested it and it can fail if the dataset location doesn't match the project location or the service doesn't have the right access to the dataset. I'll update the docstring to link to that page.

In terms of validating, the BQ client throws this error: google.api_core.exceptions.FailedPrecondition: 400 BigQuery Dataset location eu must be in the same location as the service location us-central1.

Do you think we should validate as well or let the BQ client handle validation? If we do validation, we'd need to use the BQ client to check the location of the provided BQ dataset string.

Member
Agree with relying on BQ client.

if len(df_source) < _AUTOML_TRAINING_MIN_ROWS:
_LOGGER.info(
"Your DataFrame has %s rows and AutoML requires %s rows to train on tabular data. You can still train a custom model once your dataset has been uploaded to Vertex, but you will not be able to use AutoML for training."
% (len(df_source), _AUTOML_TRAINING_MIN_ROWS),
)

try:
import pyarrow # noqa: F401 - skip check for 'pyarrow' which is required when using 'google.cloud.bigquery'
except ImportError:
raise ImportError(
"Pyarrow is not installed. Please install pyarrow to use the BigQuery client."
)

bigquery_client = bigquery.Client(
project=project or initializer.global_config.project,
credentials=credentials or initializer.global_config.credentials,
)

if staging_path.startswith("bq://"):
bq_staging_path = staging_path[len("bq://") :]
else:
raise ValueError(
"Only BigQuery staging paths are supported. Provide a staging path in the format `bq://your-project.your-dataset.your-table`."
)

try:
parquet_options = bigquery.format_options.ParquetOptions()
parquet_options.enable_list_inference = True

job_config = bigquery.LoadJobConfig(
source_format=bigquery.SourceFormat.PARQUET,
parquet_options=parquet_options,
)

Review thread on the `LoadJobConfig`:

Member
Will this config infer all the types? I see the enable_list_inference, but I couldn't find a reference in the BQ docs for non-list type inference.

Contributor Author (@sararob, Apr 27, 2022)
Yes, this will infer the data types from the DF. From the BQ docs:

            The destination table to use for loading the data. If it is an
            existing table, the schema of the :class:`~pandas.DataFrame`
            must match the schema of the destination table. If the table
            does not yet exist, the schema is inferred from the
            :class:`~pandas.DataFrame`.

I added a bq_schema param to give the user the option to override the data type autodetection, but I think I may need to add more client-side validation on that.

Member
I think we can rely on BQ client validation.
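As a rough sketch of the schema-override idea discussed above (not part of this diff; the column names and types are assumptions), an explicit schema could be passed to the load job config so BigQuery skips type auto-detection for those columns:

```python
from google.cloud import bigquery

# Hypothetical explicit schema overriding BigQuery's type auto-detection
# for the listed columns when loading the DataFrame into the staging table.
explicit_schema = [
    bigquery.SchemaField("int64_col", "INTEGER"),
    bigquery.SchemaField("string_col", "STRING"),
    bigquery.SchemaField("bool_col", "BOOLEAN"),
]

parquet_options = bigquery.format_options.ParquetOptions()
parquet_options.enable_list_inference = True

job_config = bigquery.LoadJobConfig(
    schema=explicit_schema,
    source_format=bigquery.SourceFormat.PARQUET,
    parquet_options=parquet_options,
)
```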

job = bigquery_client.load_table_from_dataframe(
dataframe=df_source, destination=bq_staging_path, job_config=job_config
)

job.result()

finally:
dataset_from_dataframe = cls.create(
display_name=display_name,
bq_source=staging_path,
project=project,
location=location,
credentials=credentials,
)

return dataset_from_dataframe

def import_data(self):
raise NotImplementedError(
f"{self.__class__.__name__} class does not support 'import_data'"
133 changes: 133 additions & 0 deletions tests/unit/aiplatform/test_datasets.py
@@ -22,6 +22,7 @@
from unittest import mock
from importlib import reload
from unittest.mock import patch
import pandas as pd

from google.api_core import operation
from google.auth.exceptions import GoogleAuthError
@@ -147,6 +148,30 @@

_TEST_LABELS = {"my_key": "my_value"}

# create_from_dataframe
Review thread:

Member
Preference for an integration test as well.

Contributor Author
Added 2 integration tests: one with the bq_schema param and one without.
_TEST_INVALID_SOURCE_URI_BQ = "my-project.my-dataset.table"

_TEST_BOOL_COL = "bool_col"
_TEST_BOOL_ARR_COL = "bool_array_col"
_TEST_DOUBLE_COL = "double_col"
_TEST_DOUBLE_ARR_COL = "double_array_col"
_TEST_INT_COL = "int64_col"
_TEST_INT_ARR_COL = "int64_array_col"
_TEST_STR_COL = "string_col"
_TEST_STR_ARR_COL = "string_array_col"
_TEST_BYTES_COL = "bytes_col"
_TEST_DF_COLUMN_NAMES = [
_TEST_BOOL_COL,
_TEST_BOOL_ARR_COL,
_TEST_DOUBLE_COL,
_TEST_DOUBLE_ARR_COL,
_TEST_INT_COL,
_TEST_INT_ARR_COL,
_TEST_STR_COL,
_TEST_STR_ARR_COL,
_TEST_BYTES_COL,
]


@pytest.fixture
def get_dataset_mock():
@@ -1378,6 +1403,114 @@ def test_create_dataset_with_labels(self, create_dataset_mock, sync):
timeout=None,
)

@pytest.mark.usefixtures("get_dataset_tabular_bq_mock")
@pytest.mark.parametrize(
"source_df",
[
pd.DataFrame(
data=[
[
False,
[True, False],
1.2,
[1.2, 3.4],
1,
[1, 2],
"test",
["test1", "test2"],
b"1",
],
[
True,
[True, True],
2.2,
[2.2, 4.4],
2,
[2, 3],
"test1",
["test2", "test3"],
b"0",
],
],
columns=_TEST_DF_COLUMN_NAMES,
),
],
)
@pytest.mark.parametrize("sync", [True, False])
def test_create_dataset_tabular_from_dataframe(
self, create_dataset_mock, source_df, bq_client_mock, sync
):
aiplatform.init(project=_TEST_PROJECT)

dataset_from_df = datasets.TabularDataset.create_from_dataframe(
display_name=_TEST_DISPLAY_NAME,
df_source=source_df,
staging_path=_TEST_SOURCE_URI_BQ,
)

if not sync:
dataset_from_df.wait()

assert dataset_from_df.metadata_schema_uri == _TEST_METADATA_SCHEMA_URI_TABULAR

expected_dataset = gca_dataset.Dataset(
display_name=_TEST_DISPLAY_NAME,
metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_TABULAR,
metadata=_TEST_METADATA_TABULAR_BQ,
)

create_dataset_mock.assert_called_once_with(
parent=_TEST_PARENT,
dataset=expected_dataset,
metadata=_TEST_REQUEST_METADATA,
timeout=None,
)

@pytest.mark.usefixtures("get_dataset_tabular_bq_mock")
@pytest.mark.parametrize(
"source_df",
[
pd.DataFrame(
data=[
[
False,
[True, False],
1.2,
[1.2, 3.4],
1,
[1, 2],
"test",
["test1", "test2"],
b"1",
],
[
True,
[True, True],
2.2,
[2.2, 4.4],
2,
[2, 3],
"test1",
["test2", "test3"],
b"0",
],
],
columns=_TEST_DF_COLUMN_NAMES,
),
],
)
@pytest.mark.parametrize("sync", [True, False])
def test_create_dataset_tabular_from_dataframe_with_invalid_bq_uri(
self, create_dataset_mock, source_df, bq_client_mock, sync
):
aiplatform.init(project=_TEST_PROJECT)
with pytest.raises(ValueError):
datasets.TabularDataset.create_from_dataframe(
display_name=_TEST_DISPLAY_NAME,
df_source=source_df,
staging_path=_TEST_INVALID_SOURCE_URI_BQ,
)


class TestTextDataset:
def setup_method(self):