Skip to content

Commit

Permalink
Moved transformations code out of column_names_dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
ivanmkc committed Aug 30, 2021
1 parent 3a6e089 commit 49673fe
Show file tree
Hide file tree
Showing 3 changed files with 116 additions and 93 deletions.
93 changes: 3 additions & 90 deletions google/cloud/aiplatform/datasets/column_names_dataset.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-

# Copyright 2020 Google LLC
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -19,7 +19,7 @@
import csv
import logging

from typing import List, Optional, Set, Union
from typing import List, Optional, Set

from google.auth import credentials as auth_credentials

Expand All @@ -28,16 +28,12 @@

from google.cloud.aiplatform import utils

from typing import Dict, List, Optional, Tuple
from typing import List, Optional

from google.cloud.aiplatform import base
from google.cloud.aiplatform import utils
from google.cloud.aiplatform.datasets import _Dataset

import warnings

_LOGGER = base.Logger(__name__)


class _ColumnNamesDataset(_Dataset):
@property
Expand Down Expand Up @@ -259,86 +255,3 @@ def _retrieve_bq_source_columns(
field
)
}

def _get_default_column_transformations(
self, target_column: str,
) -> Tuple[Dict, List[str]]:
"""Get default column transformations from the column names, while omitting the target column.
Args:
target_column (str):
Required. The name of the column values of which the Model is to predict.
Returns:
Dict
The default column transformations.
"""

column_names = [
column_name
for column_name in self.column_names
if column_name != target_column
]
column_transformations = [
{"auto": {"column_name": column_name}} for column_name in column_names
]

return (column_transformations, column_names)

@staticmethod
def _validate_and_get_column_transformations(
column_specs: Optional[Dict[str, str]],
column_transformations: Optional[Union[Dict, List[Dict]]],
) -> Dict:
"""Validates column specs and transformations, then returns processed transformations.
Args:
column_specs (Dict[str, str]):
Optional. Alternative to column_transformations where the keys of the dict
are column names and their respective values are one of
AutoMLTabularTrainingJob.column_data_types.
When creating transformation for BigQuery Struct column, the column
should be flattened using "." as the delimiter. Only columns with no child
should have a transformation.
If an input column has no transformations on it, such a column is
ignored by the training, except for the targetColumn, which should have
no transformations defined on.
Only one of column_transformations or column_specs should be passed.
column_transformations (Union[Dict, List[Dict]]):
Optional. Transformations to apply to the input columns (i.e. columns other
than the targetColumn). Each transformation may produce multiple
result values from the column's value, and all are used for training.
When creating transformation for BigQuery Struct column, the column
should be flattened using "." as the delimiter. Only columns with no child
should have a transformation.
If an input column has no transformations on it, such a column is
ignored by the training, except for the targetColumn, which should have
no transformations defined on.
Only one of column_transformations or column_specs should be passed.
Consider using column_specs as column_transformations will be deprecated eventually.
Returns:
Dict
The column transformations.
"""
# user populated transformations
if column_transformations is not None and column_specs is not None:
raise ValueError(
"Both column_transformations and column_specs were passed. Only one is allowed."
)
if column_transformations is not None:
warnings.simplefilter("always", DeprecationWarning)
warnings.warn(
"consider using column_specs instead. column_transformations will be deprecated in the future.",
DeprecationWarning,
stacklevel=2,
)

return column_transformations
elif column_specs is not None:
return [
{transformation: {"column_name": column_name}}
for column_name, transformation in column_specs.items()
]
else:
return None
9 changes: 6 additions & 3 deletions google/cloud/aiplatform/training_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from google.cloud.aiplatform.utils import _timestamped_gcs_dir
from google.cloud.aiplatform.utils import source_utils
from google.cloud.aiplatform.utils import worker_spec_utils
from google.cloud.aiplatform.utils import column_transformations_utils

from google.cloud.aiplatform.v1.schema.trainingjob import (
definition_v1 as training_job_inputs,
Expand Down Expand Up @@ -3148,7 +3149,7 @@ def __init__(
model_encryption_spec_key_name=model_encryption_spec_key_name,
)

self._column_transformations = datasets._ColumnNamesDataset._validate_and_get_column_transformations(
self._column_transformations = column_transformations_utils.validate_and_get_column_transformations(
column_specs, column_transformations
)

Expand Down Expand Up @@ -3510,7 +3511,9 @@ def _run(
(
self._column_transformations,
column_names,
) = dataset._get_default_column_transformations(target_column)
) = column_transformations_utils.get_default_column_transformations(
dataset=dataset, target_column=target_column
)

_LOGGER.info(
"The column transformation of type 'auto' was set for the following columns: %s."
Expand Down Expand Up @@ -3730,7 +3733,7 @@ def __init__(
model_encryption_spec_key_name=model_encryption_spec_key_name,
)

self._column_transformations = datasets._ColumnNamesDataset._validate_and_get_column_transformations(
self._column_transformations = column_transformations_utils.validate_and_get_column_transformations(
column_specs, column_transformations
)

Expand Down
107 changes: 107 additions & 0 deletions google/cloud/aiplatform/utils/column_transformations_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
# -*- coding: utf-8 -*-

# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from google.cloud.aiplatform import base

from google.cloud.aiplatform.datasets.column_names_dataset import _ColumnNamesDataset
from typing import Dict, List, Optional, Union, Tuple

import warnings


def get_default_column_transformations(
dataset: _ColumnNamesDataset, target_column: str,
) -> Tuple[Dict, List[str]]:
"""Get default column transformations from the column names, while omitting the target column.
Args:
target_column (str):
Required. The name of the column values of which the Model is to predict.
Returns:
Dict
The default column transformations.
"""

column_names = [
column_name
for column_name in dataset.column_names
if column_name != target_column
]
column_transformations = [
{"auto": {"column_name": column_name}} for column_name in column_names
]

return (column_transformations, column_names)


def validate_and_get_column_transformations(
column_specs: Optional[Dict[str, str]],
column_transformations: Optional[Union[Dict, List[Dict]]],
) -> Dict:
"""Validates column specs and transformations, then returns processed transformations.
Args:
column_specs (Dict[str, str]):
Optional. Alternative to column_transformations where the keys of the dict
are column names and their respective values are one of
AutoMLTabularTrainingJob.column_data_types.
When creating transformation for BigQuery Struct column, the column
should be flattened using "." as the delimiter. Only columns with no child
should have a transformation.
If an input column has no transformations on it, such a column is
ignored by the training, except for the targetColumn, which should have
no transformations defined on.
Only one of column_transformations or column_specs should be passed.
column_transformations (Union[Dict, List[Dict]]):
Optional. Transformations to apply to the input columns (i.e. columns other
than the targetColumn). Each transformation may produce multiple
result values from the column's value, and all are used for training.
When creating transformation for BigQuery Struct column, the column
should be flattened using "." as the delimiter. Only columns with no child
should have a transformation.
If an input column has no transformations on it, such a column is
ignored by the training, except for the targetColumn, which should have
no transformations defined on.
Only one of column_transformations or column_specs should be passed.
Consider using column_specs as column_transformations will be deprecated eventually.
Returns:
Dict
The column transformations.
"""
# user populated transformations
if column_transformations is not None and column_specs is not None:
raise ValueError(
"Both column_transformations and column_specs were passed. Only one is allowed."
)
if column_transformations is not None:
warnings.simplefilter("always", DeprecationWarning)
warnings.warn(
"consider using column_specs instead. column_transformations will be deprecated in the future.",
DeprecationWarning,
stacklevel=2,
)

return column_transformations
elif column_specs is not None:
return [
{transformation: {"column_name": column_name}}
for column_name, transformation in column_specs.items()
]
else:
return None

0 comments on commit 49673fe

Please sign in to comment.