
Commit

feat: Add support of newly added fields of ExportData API to SDK
PiperOrigin-RevId: 595178001
vertex-sdk-bot authored and copybara-github committed Jan 2, 2024
1 parent 4d98c55 commit ec3ea30
Showing 2 changed files with 245 additions and 18 deletions.
236 changes: 218 additions & 18 deletions google/cloud/aiplatform/datasets/dataset.py
@@ -15,7 +15,7 @@
# limitations under the License.
#

from typing import Dict, List, Optional, Sequence, Tuple, Union
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

from google.api_core import operation
from google.auth import credentials as auth_credentials
@@ -27,11 +27,13 @@
from google.cloud.aiplatform.compat.services import dataset_service_client
from google.cloud.aiplatform.compat.types import (
dataset as gca_dataset,
dataset_service as gca_dataset_service,
encryption_spec as gca_encryption_spec,
io as gca_io,
)
from google.cloud.aiplatform.datasets import _datasources
from google.protobuf import field_mask_pb2
from google.protobuf import json_format

_LOGGER = base.Logger(__name__)

@@ -561,6 +563,120 @@ def import_data(
)
return self

def _validate_and_convert_export_split(
self,
split: Union[Dict[str, str], Dict[str, float]],
) -> Union[gca_dataset.ExportFilterSplit, gca_dataset.ExportFractionSplit]:
"""
Validates the split for data export. Valid splits are dicts
encoding the contents of proto messages ExportFilterSplit or
ExportFractionSplit. If the split is valid, this function returns
the corresponding convertered proto message.
split (Union[Dict[str, str], Dict[str, float]]):
The instructions how the export data should be split between the
training, validation and test sets.
"""
if len(split) != 3:
raise ValueError(
"The provided split for data export does not provide enough "
"information. It must have three fields, mapping to the training, "
"validation and test splits respectively."
)

if not ("training_filter" in split or "training_fraction" in split):
raise ValueError(
"The provided split for data export does not provide enough "
"information. It must contain either a training_filter or a "
"training_fraction field."
)

if "training_filter" in split:
if (
"validation_filter" in split
and "test_filter" in split
and split["training_filter"] is str
and split["validation_filter"] is str
and split["test_filter"] is str
):
return gca_dataset.ExportFilterSplit(
training_filter=split["training_filter"],
validation_filter=split["validation_filter"],
test_filter=split["test_filter"],
)
else:
raise ValueError(
"The provided ExportFilterSplit does not contain all"
"three required fields: training_filter, "
"validation_filter and test_filter."
)
else:
if (
"validation_fraction" in split
and "test_fraction" in split
and split["training_fraction"] is float
and split["validation_fraction"] is float
and split["test_fraction"] is float
):
return gca_dataset.ExportFractionSplit(
training_fraction=split["training_fraction"],
validation_fraction=split["validation_fraction"],
test_fraction=split["test_fraction"],
)
else:
raise ValueError(
"The provided ExportFractionSplit does not contain all"
"three required fields: training_fraction, "
"validation_fraction and test_fraction."
)
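
As an editorial illustration (not part of the changed file), here are the two split shapes the validator accepts and the proto messages they are converted to; the filter expressions and fraction values are placeholders:

from google.cloud.aiplatform.compat.types import dataset as gca_dataset

# A filter split: all three *_filter keys must be present and hold strings.
filter_split = {
    "training_filter": "labels.aiplatform.googleapis.com/ml_use=training",
    "validation_filter": "labels.aiplatform.googleapis.com/ml_use=validation",
    "test_filter": "labels.aiplatform.googleapis.com/ml_use=test",
}
# ...converts to:
gca_dataset.ExportFilterSplit(
    training_filter=filter_split["training_filter"],
    validation_filter=filter_split["validation_filter"],
    test_filter=filter_split["test_filter"],
)

# A fraction split: all three *_fraction keys must be present and hold floats.
fraction_split = {
    "training_fraction": 0.8,
    "validation_fraction": 0.1,
    "test_fraction": 0.1,
}
# ...converts to:
gca_dataset.ExportFractionSplit(
    training_fraction=0.8,
    validation_fraction=0.1,
    test_fraction=0.1,
)

# Anything else (missing keys, wrong value types, or a mix of the two
# shapes) raises ValueError.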

def _get_completed_export_data_operation(
self,
output_dir: str,
export_use: Optional[gca_dataset.ExportDataConfig.ExportUse] = None,
annotation_filter: Optional[str] = None,
saved_query_id: Optional[str] = None,
annotation_schema_uri: Optional[str] = None,
split: Optional[
Union[gca_dataset.ExportFilterSplit, gca_dataset.ExportFractionSplit]
] = None,
) -> gca_dataset_service.ExportDataResponse:
"""Runs an export data operation and returns its completed response.

Builds the ExportDataConfig from the given arguments, starts the export
LRO against this dataset and blocks until it completes.
"""
self.wait()

# TODO(b/171311614): Add support for BigQuery export path
export_data_config = gca_dataset.ExportDataConfig(
gcs_destination=gca_io.GcsDestination(output_uri_prefix=output_dir)
)
if export_use is not None:
export_data_config.export_use = export_use
if annotation_filter is not None:
export_data_config.annotation_filter = annotation_filter
if saved_query_id is not None:
export_data_config.saved_query_id = saved_query_id
if annotation_schema_uri is not None:
export_data_config.annotation_schema_uri = annotation_schema_uri
if split is not None:
if isinstance(split, gca_dataset.ExportFilterSplit):
export_data_config.filter_split = split
elif isinstance(split, gca_dataset.ExportFractionSplit):
export_data_config.fraction_split = split

_LOGGER.log_action_start_against_resource("Exporting", "data", self)

export_lro = self.api_client.export_data(
name=self.resource_name, export_config=export_data_config
)

_LOGGER.log_action_started_against_resource_with_lro(
"Export", "data", self.__class__, export_lro
)

export_data_response = export_lro.result()

_LOGGER.log_action_completed_against_resource("data", "export", self)

return export_data_response

# TODO(b/174751568) add optional sync support
def export_data(self, output_dir: str) -> Sequence[str]:
"""Exports data to output dir to GCS.
@@ -585,29 +701,113 @@ def export_data(self, output_dir: str) -> Sequence[str]:
exported_files (Sequence[str]):
All of the files that are exported in this export operation.
"""
self.wait()

# TODO(b/171311614): Add support for BigQuery export path
export_data_config = gca_dataset.ExportDataConfig(
gcs_destination=gca_io.GcsDestination(output_uri_prefix=output_dir)
)

_LOGGER.log_action_start_against_resource("Exporting", "data", self)

export_lro = self.api_client.export_data(
name=self.resource_name, export_config=export_data_config
)

_LOGGER.log_action_started_against_resource_with_lro(
"Export", "data", self.__class__, export_lro
)

export_data_response = export_lro.result()

_LOGGER.log_action_completed_against_resource("data", "export", self)

return export_data_response.exported_files
return self._get_completed_export_data_operation(output_dir).exported_files

def export_data_for_custom_training(
self,
output_dir: str,
annotation_filter: Optional[str] = None,
saved_query_id: Optional[str] = None,
annotation_schema_uri: Optional[str] = None,
split: Optional[Union[Dict[str, str], Dict[str, float]]] = None,
) -> Dict[str, Any]:
"""Exports data to the output dir in GCS for the custom training use case.

Example annotation_schema_uri (image classification):
gs://google-cloud-aiplatform/schema/dataset/annotation/image_classification_1.0.0.yaml

Example split (filter split):
{
"training_filter": "labels.aiplatform.googleapis.com/ml_use=training",
"validation_filter": "labels.aiplatform.googleapis.com/ml_use=validation",
"test_filter": "labels.aiplatform.googleapis.com/ml_use=test",
}

Example split (fraction split):
{
"training_fraction": 0.7,
"validation_fraction": 0.2,
"test_fraction": 0.1,
}

Args:
output_dir (str):
Required. The Google Cloud Storage location where the output is to
be written to. In the given directory a new directory will be
created with name:
``export-data-<dataset-display-name>-<timestamp-of-export-call>``
where timestamp is in YYYYMMDDHHMMSS format. All export
output will be written into that directory. Inside that
directory, annotations with the same schema will be grouped
into sub directories which are named with the corresponding
annotations' schema title. Inside these sub directories, a
schema.yaml will be created to describe the output format.

If the uri doesn't end with '/', a '/' will be automatically
appended. The directory is created if it doesn't exist.
annotation_filter (str):
Optional. An expression for filtering what part of the Dataset
is to be exported.
Only Annotations that match this filter will be exported.
The filter syntax is the same as in
[ListAnnotations][DatasetService.ListAnnotations].
saved_query_id (str):
Optional. The ID of a SavedQuery (annotation set) under this
Dataset used for filtering Annotations for training.
Only used for custom training data export use cases.
Only applicable to Datasets that have SavedQueries.
Only Annotations that are associated with this SavedQuery are
used for training. When used in conjunction with
annotation_filter, the Annotations used for training are
filtered by both saved_query_id and annotation_filter.
Only one of saved_query_id and annotation_schema_uri should be
specified, as both of them represent the same thing: the
problem type.
annotation_schema_uri (str):
Optional. The Cloud Storage URI that points to a YAML file
describing the annotation schema. The schema is defined as an
OpenAPI 3.0.2 Schema Object. The schema files that can be used
here are found in
gs://google-cloud-aiplatform/schema/dataset/annotation/; note
that the chosen schema must be consistent with the
metadata_schema_uri of this Dataset.
Only used for custom training data export use cases.
Only applicable to Datasets that have DataItems and
Annotations.
Only Annotations that both match this schema and belong to
DataItems not ignored by the split method are used for the
training, validation or test role, depending on the role of
the DataItem they are on.
When used in conjunction with annotation_filter, the
Annotations used for training are filtered by both
annotation_filter and annotation_schema_uri.
split (Union[Dict[str, str], Dict[str, float]]):
Optional. Instructions for how the exported data should be
split between the training, validation and test sets.

Returns:
export_data_response (Dict):
Response message for DatasetService.ExportData in dictionary
format.
"""
if split is not None:
split = self._validate_and_convert_export_split(split)

return json_format.MessageToDict(
self._get_completed_export_data_operation(
output_dir,
gca_dataset.ExportDataConfig.ExportUse.CUSTOM_CODE_TRAINING,
annotation_filter,
saved_query_id,
annotation_schema_uri,
split,
)
)

def update(
self,
*,
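An illustrative usage sketch of the new export_data_for_custom_training method (editorial, not part of the commit): the project, dataset ID, bucket and schema URI below are placeholders, and it assumes an existing Vertex AI dataset that has annotations.

from google.cloud import aiplatform

aiplatform.init(project="my-project", location="us-central1")  # placeholder project

# Any dataset type works; ImageDataset is used here as an illustration.
dataset = aiplatform.ImageDataset("1234567890")  # placeholder dataset ID

# Fraction split: route 80/10/10 of the data items into training/validation/test.
response = dataset.export_data_for_custom_training(
    output_dir="gs://my-bucket/exports",  # placeholder bucket
    annotation_schema_uri=(
        "gs://google-cloud-aiplatform/schema/dataset/annotation/"
        "image_classification_1.0.0.yaml"
    ),
    split={
        "training_fraction": 0.8,
        "validation_fraction": 0.1,
        "test_fraction": 0.1,
    },
)

# The response is the ExportData LRO result rendered as a dict
# (see the system test below for the fields it inspects).
print(response)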
27 changes: 27 additions & 0 deletions tests/system/aiplatform/test_dataset.py
@@ -382,6 +382,33 @@ def test_export_data(self, storage_client, staging_bucket):

assert blob # Verify the returned GCS export path exists

def test_export_data_for_custom_training(self, staging_bucket):
"""Get an existing dataset, export data to a newly created folder in
Google Cloud Storage, then verify data was successfully exported."""

# pylint: disable=protected-access
# Custom training data export should be generic, hence using the base
# _Dataset class here in the test. In practice, users should be able to
# use this function in any inherited class of _Dataset.
dataset = aiplatform._Dataset(dataset_name=_TEST_TEXT_DATASET_ID)

split = {
"training_fraction": 0.6,
"validation_fraction": 0.2,
"test_fraction": 0.2,
}

export_data_response = dataset.export_data_for_custom_training(
output_dir=f"gs://{staging_bucket.name}",
annotation_schema_uri="gs://google-cloud-aiplatform/schema/dataset/annotation/text_classification_1.0.0.yaml",
split=split,
)

# Ensure three output paths (training, validation and test) are provided
assert len(export_data_response["exported_files"]) == 3
# Ensure data stats are calculated and present
assert export_data_response["data_stats"]["training_data_items_count"] > 0

def test_update_dataset(self):
"""Create a new dataset and use update() method to change its display_name, labels, and description.
Then confirm these fields of the dataset was successfully modifed."""
