diff --git a/google/cloud/aiplatform/base.py b/google/cloud/aiplatform/base.py index 793b70d563..70b35329d1 100644 --- a/google/cloud/aiplatform/base.py +++ b/google/cloud/aiplatform/base.py @@ -20,13 +20,13 @@ import functools import inspect import threading -from typing import Any, Callable, Dict, Optional, Sequence, Union +from typing import Any, Callable, Dict, Optional, Sequence, Type, Union import proto from google.auth import credentials as auth_credentials -from google.cloud.aiplatform import utils from google.cloud.aiplatform import initializer +from google.cloud.aiplatform import utils class FutureManager(metaclass=abc.ABCMeta): @@ -232,8 +232,8 @@ class AiPlatformResourceNoun(metaclass=abc.ABCMeta): @property @classmethod @abc.abstractmethod - def client_class(cls) -> utils.AiPlatformServiceClient: - """Client class required to interact with resource.""" + def client_class(cls) -> Type[utils.AiPlatformServiceClientWithOverride]: + """Client class required to interact with resource with optional overrides.""" pass @property @@ -287,7 +287,7 @@ def _instantiate_client( cls, location: Optional[str] = None, credentials: Optional[auth_credentials.Credentials] = None, - ) -> utils.AiPlatformServiceClient: + ) -> utils.AiPlatformServiceClientWithOverride: """Helper method to instantiate service client for resource noun. Args: @@ -296,8 +296,8 @@ def _instantiate_client( Optional custom credentials to use when accessing interacting with resource noun. Returns: - client (utils.AiPlatformServiceClient): - Initialized service client for this service noun. + client (utils.AiPlatformServiceClientWithOverride): + Initialized service client for this service noun with optional overrides. """ return initializer.global_config.create_client( client_class=cls.client_class, diff --git a/google/cloud/aiplatform/compat/__init__.py b/google/cloud/aiplatform/compat/__init__.py new file mode 100644 index 0000000000..36d805c6cb --- /dev/null +++ b/google/cloud/aiplatform/compat/__init__.py @@ -0,0 +1,122 @@ +# -*- coding: utf-8 -*- + +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from google.cloud.aiplatform.compat import services +from google.cloud.aiplatform.compat import types + +V1BETA1 = "v1beta1" +V1 = "v1" + +DEFAULT_VERSION = V1 + +if DEFAULT_VERSION == V1BETA1: + + services.dataset_service_client = services.dataset_service_client_v1beta1 + services.endpoint_service_client = services.endpoint_service_client_v1beta1 + services.job_service_client = services.job_service_client_v1beta1 + services.model_service_client = services.model_service_client_v1beta1 + services.pipeline_service_client = services.pipeline_service_client_v1beta1 + services.prediction_service_client = services.prediction_service_client_v1beta1 + services.specialist_pool_service_client = ( + services.specialist_pool_service_client_v1beta1 + ) + + types.accelerator_type = types.accelerator_type_v1beta1 + types.annotation = types.annotation_v1beta1 + types.annotation_spec = types.annotation_spec_v1beta1 + types.batch_prediction_job = types.batch_prediction_job_v1beta1 + types.completion_stats = types.completion_stats_v1beta1 + types.custom_job = types.custom_job_v1beta1 + types.data_item = types.data_item_v1beta1 + types.data_labeling_job = types.data_labeling_job_v1beta1 + types.dataset = types.dataset_v1beta1 + types.dataset_service = types.dataset_service_v1beta1 + types.deployed_model_ref = types.deployed_model_ref_v1beta1 + types.encryption_spec = types.encryption_spec_v1beta1 + types.endpoint = types.endpoint_v1beta1 + types.endpoint_service = types.endpoint_service_v1beta1 + types.env_var = types.env_var_v1beta1 + types.explanation = types.explanation_v1beta1 + types.explanation_metadata = types.explanation_metadata_v1beta1 + types.hyperparameter_tuning_job = types.hyperparameter_tuning_job_v1beta1 + types.io = types.io_v1beta1 + types.job_service = types.job_service_v1beta1 + types.job_state = types.job_state_v1beta1 + types.machine_resources = types.machine_resources_v1beta1 + types.manual_batch_tuning_parameters = types.manual_batch_tuning_parameters_v1beta1 + types.model = types.model_v1beta1 + types.model_evaluation = types.model_evaluation_v1beta1 + types.model_evaluation_slice = types.model_evaluation_slice_v1beta1 + types.model_service = types.model_service_v1beta1 + types.operation = types.operation_v1beta1 + types.pipeline_service = types.pipeline_service_v1beta1 + types.pipeline_state = types.pipeline_state_v1beta1 + types.prediction_service = types.prediction_service_v1beta1 + types.specialist_pool = types.specialist_pool_v1beta1 + types.specialist_pool_service = types.specialist_pool_service_v1beta1 + types.training_pipeline = types.training_pipeline_v1beta1 + +if DEFAULT_VERSION == V1: + + services.dataset_service_client = services.dataset_service_client_v1 + services.endpoint_service_client = services.endpoint_service_client_v1 + services.job_service_client = services.job_service_client_v1 + services.model_service_client = services.model_service_client_v1 + services.pipeline_service_client = services.pipeline_service_client_v1 + services.prediction_service_client = services.prediction_service_client_v1 + services.specialist_pool_service_client = services.specialist_pool_service_client_v1 + + types.accelerator_type = types.accelerator_type_v1 + types.annotation = types.annotation_v1 + types.annotation_spec = types.annotation_spec_v1 + types.batch_prediction_job = types.batch_prediction_job_v1 + types.completion_stats = types.completion_stats_v1 + types.custom_job = types.custom_job_v1 + types.data_item = types.data_item_v1 + types.data_labeling_job = types.data_labeling_job_v1 + types.dataset = types.dataset_v1 + types.dataset_service = types.dataset_service_v1 + types.deployed_model_ref = types.deployed_model_ref_v1 + types.encryption_spec = types.encryption_spec_v1 + types.endpoint = types.endpoint_v1 + types.endpoint_service = types.endpoint_service_v1 + types.env_var = types.env_var_v1 + types.hyperparameter_tuning_job = types.hyperparameter_tuning_job_v1 + types.io = types.io_v1 + types.job_service = types.job_service_v1 + types.job_state = types.job_state_v1 + types.machine_resources = types.machine_resources_v1 + types.manual_batch_tuning_parameters = types.manual_batch_tuning_parameters_v1 + types.model = types.model_v1 + types.model_evaluation = types.model_evaluation_v1 + types.model_evaluation_slice = types.model_evaluation_slice_v1 + types.model_service = types.model_service_v1 + types.operation = types.operation_v1 + types.pipeline_service = types.pipeline_service_v1 + types.pipeline_state = types.pipeline_state_v1 + types.prediction_service = types.prediction_service_v1 + types.specialist_pool = types.specialist_pool_v1 + types.specialist_pool_service = types.specialist_pool_service_v1 + types.training_pipeline = types.training_pipeline_v1 + +__all__ = ( + DEFAULT_VERSION, + V1BETA1, + V1, + services, + types, +) diff --git a/google/cloud/aiplatform/compat/services/__init__.py b/google/cloud/aiplatform/compat/services/__init__.py new file mode 100644 index 0000000000..0888c27fbb --- /dev/null +++ b/google/cloud/aiplatform/compat/services/__init__.py @@ -0,0 +1,79 @@ +# -*- coding: utf-8 -*- + +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from google.cloud.aiplatform_v1beta1.services.dataset_service import ( + client as dataset_service_client_v1beta1, +) +from google.cloud.aiplatform_v1beta1.services.endpoint_service import ( + client as endpoint_service_client_v1beta1, +) +from google.cloud.aiplatform_v1beta1.services.job_service import ( + client as job_service_client_v1beta1, +) +from google.cloud.aiplatform_v1beta1.services.model_service import ( + client as model_service_client_v1beta1, +) +from google.cloud.aiplatform_v1beta1.services.pipeline_service import ( + client as pipeline_service_client_v1beta1, +) +from google.cloud.aiplatform_v1beta1.services.prediction_service import ( + client as prediction_service_client_v1beta1, +) +from google.cloud.aiplatform_v1beta1.services.specialist_pool_service import ( + client as specialist_pool_service_client_v1beta1, +) + +from google.cloud.aiplatform_v1.services.dataset_service import ( + client as dataset_service_client_v1, +) +from google.cloud.aiplatform_v1.services.endpoint_service import ( + client as endpoint_service_client_v1, +) +from google.cloud.aiplatform_v1.services.job_service import ( + client as job_service_client_v1, +) +from google.cloud.aiplatform_v1.services.model_service import ( + client as model_service_client_v1, +) +from google.cloud.aiplatform_v1.services.pipeline_service import ( + client as pipeline_service_client_v1, +) +from google.cloud.aiplatform_v1.services.prediction_service import ( + client as prediction_service_client_v1, +) +from google.cloud.aiplatform_v1.services.specialist_pool_service import ( + client as specialist_pool_service_client_v1, +) + +__all__ = ( + # v1 + dataset_service_client_v1, + endpoint_service_client_v1, + job_service_client_v1, + model_service_client_v1, + pipeline_service_client_v1, + prediction_service_client_v1, + specialist_pool_service_client_v1, + # v1beta1 + dataset_service_client_v1beta1, + endpoint_service_client_v1beta1, + job_service_client_v1beta1, + model_service_client_v1beta1, + pipeline_service_client_v1beta1, + prediction_service_client_v1beta1, + specialist_pool_service_client_v1beta1, +) diff --git a/google/cloud/aiplatform/compat/types/__init__.py b/google/cloud/aiplatform/compat/types/__init__.py new file mode 100644 index 0000000000..d03e0d2f3a --- /dev/null +++ b/google/cloud/aiplatform/compat/types/__init__.py @@ -0,0 +1,158 @@ +# -*- coding: utf-8 -*- + +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from google.cloud.aiplatform_v1beta1.types import ( + accelerator_type as accelerator_type_v1beta1, + annotation as annotation_v1beta1, + annotation_spec as annotation_spec_v1beta1, + batch_prediction_job as batch_prediction_job_v1beta1, + completion_stats as completion_stats_v1beta1, + custom_job as custom_job_v1beta1, + data_item as data_item_v1beta1, + data_labeling_job as data_labeling_job_v1beta1, + dataset as dataset_v1beta1, + dataset_service as dataset_service_v1beta1, + deployed_model_ref as deployed_model_ref_v1beta1, + encryption_spec as encryption_spec_v1beta1, + endpoint as endpoint_v1beta1, + endpoint_service as endpoint_service_v1beta1, + env_var as env_var_v1beta1, + explanation as explanation_v1beta1, + explanation_metadata as explanation_metadata_v1beta1, + hyperparameter_tuning_job as hyperparameter_tuning_job_v1beta1, + io as io_v1beta1, + job_service as job_service_v1beta1, + job_state as job_state_v1beta1, + machine_resources as machine_resources_v1beta1, + manual_batch_tuning_parameters as manual_batch_tuning_parameters_v1beta1, + model as model_v1beta1, + model_evaluation as model_evaluation_v1beta1, + model_evaluation_slice as model_evaluation_slice_v1beta1, + model_service as model_service_v1beta1, + operation as operation_v1beta1, + pipeline_service as pipeline_service_v1beta1, + pipeline_state as pipeline_state_v1beta1, + prediction_service as prediction_service_v1beta1, + specialist_pool as specialist_pool_v1beta1, + specialist_pool_service as specialist_pool_service_v1beta1, + training_pipeline as training_pipeline_v1beta1, +) +from google.cloud.aiplatform_v1.types import ( + accelerator_type as accelerator_type_v1, + annotation as annotation_v1, + annotation_spec as annotation_spec_v1, + batch_prediction_job as batch_prediction_job_v1, + completion_stats as completion_stats_v1, + custom_job as custom_job_v1, + data_item as data_item_v1, + data_labeling_job as data_labeling_job_v1, + dataset as dataset_v1, + dataset_service as dataset_service_v1, + deployed_model_ref as deployed_model_ref_v1, + encryption_spec as encryption_spec_v1, + endpoint as endpoint_v1, + endpoint_service as endpoint_service_v1, + env_var as env_var_v1, + hyperparameter_tuning_job as hyperparameter_tuning_job_v1, + io as io_v1, + job_service as job_service_v1, + job_state as job_state_v1, + machine_resources as machine_resources_v1, + manual_batch_tuning_parameters as manual_batch_tuning_parameters_v1, + model as model_v1, + model_evaluation as model_evaluation_v1, + model_evaluation_slice as model_evaluation_slice_v1, + model_service as model_service_v1, + operation as operation_v1, + pipeline_service as pipeline_service_v1, + pipeline_state as pipeline_state_v1, + prediction_service as prediction_service_v1, + specialist_pool as specialist_pool_v1, + specialist_pool_service as specialist_pool_service_v1, + training_pipeline as training_pipeline_v1, +) + +__all__ = ( + # v1 + accelerator_type_v1, + annotation_v1, + annotation_spec_v1, + batch_prediction_job_v1, + completion_stats_v1, + custom_job_v1, + data_item_v1, + data_labeling_job_v1, + dataset_v1, + dataset_service_v1, + deployed_model_ref_v1, + encryption_spec_v1, + endpoint_v1, + endpoint_service_v1, + env_var_v1, + hyperparameter_tuning_job_v1, + io_v1, + job_service_v1, + job_state_v1, + machine_resources_v1, + manual_batch_tuning_parameters_v1, + model_v1, + model_evaluation_v1, + model_evaluation_slice_v1, + model_service_v1, + operation_v1, + pipeline_service_v1, + pipeline_state_v1, + prediction_service_v1, + specialist_pool_v1, + specialist_pool_service_v1, + training_pipeline_v1, + # v1beta1 + accelerator_type_v1beta1, + annotation_v1beta1, + annotation_spec_v1beta1, + batch_prediction_job_v1beta1, + completion_stats_v1beta1, + custom_job_v1beta1, + data_item_v1beta1, + data_labeling_job_v1beta1, + dataset_v1beta1, + dataset_service_v1beta1, + deployed_model_ref_v1beta1, + encryption_spec_v1beta1, + endpoint_v1beta1, + endpoint_service_v1beta1, + env_var_v1beta1, + explanation_v1beta1, + explanation_metadata_v1beta1, + hyperparameter_tuning_job_v1beta1, + io_v1beta1, + job_service_v1beta1, + job_state_v1beta1, + machine_resources_v1beta1, + manual_batch_tuning_parameters_v1beta1, + model_v1beta1, + model_evaluation_v1beta1, + model_evaluation_slice_v1beta1, + model_service_v1beta1, + operation_v1beta1, + pipeline_service_v1beta1, + pipeline_state_v1beta1, + prediction_service_v1beta1, + specialist_pool_v1beta1, + specialist_pool_service_v1beta1, + training_pipeline_v1beta1, +) diff --git a/google/cloud/aiplatform/datasets/_datasources.py b/google/cloud/aiplatform/datasets/_datasources.py index 8b156c4c9e..eefd1b04fd 100644 --- a/google/cloud/aiplatform/datasets/_datasources.py +++ b/google/cloud/aiplatform/datasets/_datasources.py @@ -17,11 +17,13 @@ import abc from typing import Optional, Dict, Sequence, Union -from google.cloud.aiplatform_v1beta1.types import io as gca_io -from google.cloud.aiplatform_v1beta1.types import dataset as gca_dataset - from google.cloud.aiplatform import schema +from google.cloud.aiplatform.compat.types import ( + io as gca_io, + dataset as gca_dataset, +) + class Datasource(abc.ABC): """An abstract class that sets dataset_metadata""" diff --git a/google/cloud/aiplatform/datasets/dataset.py b/google/cloud/aiplatform/datasets/dataset.py index 872f736279..207c4a6f8d 100644 --- a/google/cloud/aiplatform/datasets/dataset.py +++ b/google/cloud/aiplatform/datasets/dataset.py @@ -24,19 +24,19 @@ from google.cloud.aiplatform import initializer from google.cloud.aiplatform import utils -from google.cloud.aiplatform.datasets import _datasources -from google.cloud.aiplatform_v1beta1.types import encryption_spec as gca_encryption_spec -from google.cloud.aiplatform_v1beta1.types import io as gca_io -from google.cloud.aiplatform_v1beta1.types import dataset as gca_dataset -from google.cloud.aiplatform_v1beta1.services.dataset_service import ( - client as dataset_service_client, +from google.cloud.aiplatform.compat.services import dataset_service_client +from google.cloud.aiplatform.compat.types import ( + dataset as gca_dataset, + encryption_spec as gca_encryption_spec, + io as gca_io, ) +from google.cloud.aiplatform.datasets import _datasources class Dataset(base.AiPlatformResourceNounWithFutureManager): """Managed dataset resource for AI Platform""" - client_class = dataset_service_client.DatasetServiceClient + client_class = utils.DatasetClientWithOverride _is_client_prediction_client = False _resource_noun = "datasets" _getter_method = "get_dataset" diff --git a/google/cloud/aiplatform/explain/__init__.py b/google/cloud/aiplatform/explain/__init__.py index 157ae37e1b..61b9181834 100644 --- a/google/cloud/aiplatform/explain/__init__.py +++ b/google/cloud/aiplatform/explain/__init__.py @@ -15,29 +15,33 @@ # limitations under the License. # -from google.cloud.aiplatform_v1beta1.types.explanation_metadata import ( - ExplanationMetadata, +from google.cloud.aiplatform.compat.types import ( + explanation_metadata_v1beta1 as explanation_metadata, + explanation_v1beta1 as explanation, ) -from google.cloud.aiplatform_v1beta1.types.explanation import ExplanationParameters -from google.cloud.aiplatform_v1beta1.types.explanation import FeatureNoiseSigma - -# Classes used by ExplanationParameters -from google.cloud.aiplatform_v1beta1.types.explanation import ( - IntegratedGradientsAttribution, -) -from google.cloud.aiplatform_v1beta1.types.explanation import SampledShapleyAttribution -from google.cloud.aiplatform_v1beta1.types.explanation import SmoothGradConfig -from google.cloud.aiplatform_v1beta1.types.explanation import XraiAttribution +ExplanationMetadata = explanation_metadata.ExplanationMetadata # ExplanationMetadata subclasses InputMetadata = ExplanationMetadata.InputMetadata OutputMetadata = ExplanationMetadata.OutputMetadata # InputMetadata subclasses -Encoding = ExplanationMetadata.InputMetadata.Encoding -FeatureValueDomain = ExplanationMetadata.InputMetadata.FeatureValueDomain -Visualization = ExplanationMetadata.InputMetadata.Visualization +Encoding = InputMetadata.Encoding +FeatureValueDomain = InputMetadata.FeatureValueDomain +Visualization = InputMetadata.Visualization + + +ExplanationParameters = explanation.ExplanationParameters +FeatureNoiseSigma = explanation.FeatureNoiseSigma + +# Classes used by ExplanationParameters +IntegratedGradientsAttribution = explanation.IntegratedGradientsAttribution + +SampledShapleyAttribution = explanation.SampledShapleyAttribution +SmoothGradConfig = explanation.SmoothGradConfig +XraiAttribution = explanation.XraiAttribution + __all__ = ( "Encoding", diff --git a/google/cloud/aiplatform/helpers/_decorators.py b/google/cloud/aiplatform/helpers/_decorators.py index 5d9aa28bea..95aac31c4f 100644 --- a/google/cloud/aiplatform/helpers/_decorators.py +++ b/google/cloud/aiplatform/helpers/_decorators.py @@ -68,3 +68,5 @@ def _from_map(map_): marshal = Marshal(name="google.cloud.aiplatform.v1beta1") marshal.register(Value, ConversionValueRule(marshal=marshal)) +marshal = Marshal(name="google.cloud.aiplatform.v1") +marshal.register(Value, ConversionValueRule(marshal=marshal)) diff --git a/google/cloud/aiplatform/initializer.py b/google/cloud/aiplatform/initializer.py index f544df2a7a..b84a006d02 100644 --- a/google/cloud/aiplatform/initializer.py +++ b/google/cloud/aiplatform/initializer.py @@ -27,9 +27,16 @@ import google.auth from google.auth import credentials as auth_credentials from google.auth.exceptions import GoogleAuthError -from google.cloud.aiplatform import utils + +from google.cloud.aiplatform import compat from google.cloud.aiplatform import constants -from google.cloud.aiplatform_v1beta1.types import encryption_spec as gca_encryption_spec +from google.cloud.aiplatform import utils + +from google.cloud.aiplatform.compat.types import ( + encryption_spec as gca_encryption_spec_compat, + encryption_spec_v1 as gca_encryption_spec_v1, + encryption_spec_v1beta1 as gca_encryption_spec_v1beta1, +) class _Config: @@ -91,17 +98,28 @@ def init( self._encryption_spec_key_name = encryption_spec_key_name def get_encryption_spec( - self, encryption_spec_key_name: Optional[str] - ) -> Optional[gca_encryption_spec.EncryptionSpec]: + self, + encryption_spec_key_name: Optional[str], + select_version: Optional[str] = compat.DEFAULT_VERSION, + ) -> Optional[ + Union[ + gca_encryption_spec_v1.EncryptionSpec, + gca_encryption_spec_v1beta1.EncryptionSpec, + ] + ]: """Creates a gca_encryption_spec.EncryptionSpec instance from the given key name. If the provided key name is None, it uses the default key name if provided. Args: encryption_spec_key_name (Optional[str]): The default encryption key name to use when creating resources. + select_version: The default version is set to compat.DEFAULT_VERSION """ kms_key_name = encryption_spec_key_name or self.encryption_spec_key_name encryption_spec = None if kms_key_name: + gca_encryption_spec = gca_encryption_spec_compat + if select_version == compat.V1BETA1: + gca_encryption_spec = gca_encryption_spec_v1beta1 encryption_spec = gca_encryption_spec.EncryptionSpec( kms_key_name=kms_key_name ) @@ -218,22 +236,22 @@ def common_location_path( def create_client( self, - client_class: Type[utils.AiPlatformServiceClient], + client_class: Type[utils.AiPlatformServiceClientWithOverride], credentials: Optional[auth_credentials.Credentials] = None, location_override: Optional[str] = None, prediction_client: bool = False, - ) -> Union[utils.WrappedClient, utils.AiPlatformServiceClient]: + ) -> utils.AiPlatformServiceClientWithOverride: """Instantiates a given AiPlatformServiceClient with optional overrides. Args: - client_class (utils.AiPlatformServiceClient): - (Required)An AI Platform Service Client. + client_class (utils.AiPlatformServiceClientWithOverride): + (Required) An AI Platform Service Client with optional overrides. credentials (auth_credentials.Credentials): Custom auth credentials. If not provided will use the current config. location_override (str): Optional location override. prediction_client (str): Optional flag to use a prediction endpoint. Returns: - client: Instantiated AI Platform Service client + client: Instantiated AI Platform Service client with optional overrides """ gapic_version = pkg_resources.get_distribution( "google-cloud-aiplatform", @@ -250,11 +268,7 @@ def create_client( "client_info": client_info, } - if prediction_client: - return client_class(**kwargs) - else: - kwargs["client_class"] = client_class - return utils.WrappedClient(**kwargs) + return client_class(**kwargs) # global config to store init parameters: ie, aiplatform.init(project=..., location=...) diff --git a/google/cloud/aiplatform/jobs.py b/google/cloud/aiplatform/jobs.py index 7315a6f662..4f6fd6d094 100644 --- a/google/cloud/aiplatform/jobs.py +++ b/google/cloud/aiplatform/jobs.py @@ -30,20 +30,22 @@ from google.cloud import aiplatform from google.cloud.aiplatform import base from google.cloud.aiplatform import initializer +from google.cloud.aiplatform import compat from google.cloud.aiplatform import constants from google.cloud.aiplatform import utils -from google.cloud.aiplatform_v1beta1.services.job_service import ( - client as job_service_client, +from google.cloud.aiplatform.compat.services import job_service_client +from google.cloud.aiplatform.compat.types import ( + io as gca_io_compat, + io_v1beta1 as gca_io_v1beta1, + job_state as gca_job_state, + batch_prediction_job as gca_bp_job_compat, + batch_prediction_job_v1 as gca_bp_job_v1, + batch_prediction_job_v1beta1 as gca_bp_job_v1beta1, + machine_resources as gca_machine_resources_compat, + machine_resources_v1beta1 as gca_machine_resources_v1beta1, + explanation_v1beta1 as gca_explanation_v1beta1, ) -from google.cloud.aiplatform_v1beta1.types import io as gca_io -from google.cloud.aiplatform_v1beta1.types import job_state as gca_job_state -from google.cloud.aiplatform_v1beta1.types import batch_prediction_job as gca_bp_job -from google.cloud.aiplatform_v1beta1.types import ( - machine_resources as gca_machine_resources, -) - -from google.cloud.aiplatform_v1beta1.types import explanation as gca_explanation logging.basicConfig(level=logging.INFO, stream=sys.stdout) _LOGGER = logging.getLogger(__name__) @@ -77,7 +79,7 @@ class _Job(base.AiPlatformResourceNounWithFutureManager): _delete_method (str): The name of the specific JobServiceClient delete method """ - client_class = job_service_client.JobServiceClient + client_class = utils.JobpointClientWithOverride _is_client_prediction_client = False def __init__( @@ -434,6 +436,15 @@ def create( f"{predictions_format} is not an accepted prediction format " f"type. Please choose from: {constants.BATCH_PREDICTION_OUTPUT_STORAGE_FORMATS}" ) + gca_bp_job = gca_bp_job_compat + gca_io = gca_io_compat + gca_machine_resources = gca_machine_resources_compat + select_version = compat.DEFAULT_VERSION + if generate_explanation: + gca_bp_job = gca_bp_job_v1beta1 + gca_io = gca_io_v1beta1 + gca_machine_resources = gca_machine_resources_v1beta1 + select_version = compat.V1BETA1 gapic_batch_prediction_job = gca_bp_job.BatchPredictionJob() @@ -475,7 +486,8 @@ def create( # Optional Fields gapic_batch_prediction_job.encryption_spec = initializer.global_config.get_encryption_spec( - encryption_spec_key_name=encryption_spec_key_name + encryption_spec_key_name=encryption_spec_key_name, + select_version=select_version, ) if model_parameters: @@ -507,7 +519,7 @@ def create( gapic_batch_prediction_job.generate_explanation = generate_explanation if explanation_metadata or explanation_parameters: - gapic_batch_prediction_job.explanation_spec = gca_explanation.ExplanationSpec( + gapic_batch_prediction_job.explanation_spec = gca_explanation_v1beta1.ExplanationSpec( metadata=explanation_metadata, parameters=explanation_parameters ) @@ -521,6 +533,7 @@ def create( project=project, location=location ), batch_prediction_job=gapic_batch_prediction_job, + generate_explanation=generate_explanation, project=project or initializer.global_config.project, location=location or initializer.global_config.location, credentials=credentials or initializer.global_config.credentials, @@ -533,7 +546,10 @@ def _create( cls, api_client: job_service_client.JobServiceClient, parent: str, - batch_prediction_job: gca_bp_job.BatchPredictionJob, + batch_prediction_job: Union[ + gca_bp_job_v1beta1.BatchPredictionJob, gca_bp_job_v1.BatchPredictionJob + ], + generate_explanation: bool, project: str, location: str, credentials: Optional[auth_credentials.Credentials], @@ -547,6 +563,9 @@ def _create( already set based on user's preferences. batch_prediction_job (gca_bp_job.BatchPredictionJob): Required. a batch prediction job proto for creating a batch prediction job on AI Platform. + generate_explanation (bool): + Required. Generate explanation along with the batch prediction + results. parent (str): Required. Also known as common location path, that usually contains the project and location that the user provided to the upstream method. @@ -572,6 +591,10 @@ def _create( by AI Platform. """ + # select v1beta1 if explain else use default v1 + if generate_explanation: + api_client = api_client.select_version(compat.V1BETA1) + gca_batch_prediction_job = api_client.create_batch_prediction_job( parent=parent, batch_prediction_job=batch_prediction_job ) diff --git a/google/cloud/aiplatform/models.py b/google/cloud/aiplatform/models.py index ae1ac51bfd..b19ace6d74 100644 --- a/google/cloud/aiplatform/models.py +++ b/google/cloud/aiplatform/models.py @@ -18,28 +18,29 @@ from typing import Dict, List, NamedTuple, Optional, Sequence, Tuple, Union from google.auth import credentials as auth_credentials -from google.cloud import aiplatform + from google.cloud.aiplatform import base +from google.cloud.aiplatform import compat +from google.cloud.aiplatform import explain from google.cloud.aiplatform import initializer -from google.cloud.aiplatform import utils from google.cloud.aiplatform import jobs +from google.cloud.aiplatform import utils -from google.cloud.aiplatform_v1beta1.services.endpoint_service import ( - client as endpoint_service_client, -) -from google.cloud.aiplatform_v1beta1.services.model_service import ( - client as model_service_client, +from google.cloud.aiplatform.compat.services import endpoint_service_client + +from google.cloud.aiplatform.compat.types import ( + encryption_spec as gca_encryption_spec, + endpoint as gca_endpoint_compat, + endpoint_v1 as gca_endpoint_v1, + endpoint_v1beta1 as gca_endpoint_v1beta1, + explanation_v1beta1 as gca_explanation_v1beta1, + machine_resources as gca_machine_resources_compat, + machine_resources_v1beta1 as gca_machine_resources_v1beta1, + model as gca_model_compat, + model_v1beta1 as gca_model_v1beta1, + env_var as gca_env_var_compat, + env_var_v1beta1 as gca_env_var_v1beta1, ) -from google.cloud.aiplatform_v1beta1.services.prediction_service import ( - client as prediction_service_client, -) - -from google.cloud.aiplatform_v1beta1.types import encryption_spec as gca_encryption_spec -from google.cloud.aiplatform_v1beta1.types import endpoint as gca_endpoint -from google.cloud.aiplatform_v1beta1.types import explanation as gca_explanation -from google.cloud.aiplatform_v1beta1.types import machine_resources -from google.cloud.aiplatform_v1beta1.types import model as gca_model -from google.cloud.aiplatform_v1beta1.types import env_var from google.protobuf import json_format @@ -62,12 +63,12 @@ class Prediction(NamedTuple): predictions: Dict[str, List] deployed_model_id: str - explanations: Optional[Sequence[gca_explanation.Explanation]] = None + explanations: Optional[Sequence[gca_explanation_v1beta1.Explanation]] = None class Endpoint(base.AiPlatformResourceNounWithFutureManager): - client_class = endpoint_service_client.EndpointServiceClient + client_class = utils.EndpointClientWithOverride _is_client_prediction_client = False _resource_noun = "endpoints" _getter_method = "get_endpoint" @@ -257,7 +258,7 @@ def _create( project=project, location=location ) - gapic_endpoint = gca_endpoint.Endpoint( + gapic_endpoint = gca_endpoint_compat.Endpoint( display_name=display_name, description=description, labels=labels, @@ -362,6 +363,8 @@ def _validate_deploy_args( deployed_model_display_name: Optional[str], traffic_split: Optional[Dict[str, int]], traffic_percentage: int, + explanation_metadata: Optional[explain.ExplanationMetadata] = None, + explanation_parameters: Optional[explain.ExplanationParameters] = None, ): """Helper method to validate deploy arguments. @@ -404,11 +407,21 @@ def _validate_deploy_args( not be provided. Traffic of previously deployed models at the endpoint will be scaled down to accommodate new deployed model's traffic. Should not be provided if traffic_split is provided. + explanation_metadata (explain.ExplanationMetadata): + Optional. Metadata describing the Model's input and output for explanation. + Both `explanation_metadata` and `explanation_parameters` must be + passed together when used. For more details, see + `Ref docs ` + explanation_parameters (explain.ExplanationParameters): + Optional. Parameters to configure explaining for Model's predictions. + For more details, see `Ref docs ` Raises: ValueError if Min or Max replica is negative. Traffic percentage > 100 or < 0. Or if traffic_split does not sum to 100. + ValueError if either explanation_metadata or explanation_parameters + but not both are specified. """ if min_replica_count < 0: raise ValueError("Min replica cannot be negative.") @@ -430,6 +443,11 @@ def _validate_deploy_args( "Sum of all traffic within traffic split needs to be 100." ) + if bool(explanation_metadata) != bool(explanation_parameters): + raise ValueError( + "Both `explanation_metadata` and `explanation_parameters` should be specified or None." + ) + # Raises ValueError if invalid accelerator if accelerator_type: utils.validate_accelerator_type(accelerator_type) @@ -445,10 +463,8 @@ def deploy( max_replica_count: int = 1, accelerator_type: Optional[str] = None, accelerator_count: Optional[int] = None, - explanation_metadata: Optional["aiplatform.explain.ExplanationMetadata"] = None, - explanation_parameters: Optional[ - "aiplatform.explain.ExplanationParameters" - ] = None, + explanation_metadata: Optional[explain.ExplanationMetadata] = None, + explanation_parameters: Optional[explain.ExplanationParameters] = None, metadata: Optional[Sequence[Tuple[str, str]]] = (), sync=True, ) -> None: @@ -501,12 +517,12 @@ def deploy( NVIDIA_TESLA_V100, NVIDIA_TESLA_P4, NVIDIA_TESLA_T4, TPU_V2, TPU_V3 accelerator_count (int): Optional. The number of accelerators to attach to a worker replica. - explanation_metadata (aiplatform.explain.ExplanationMetadata): + explanation_metadata (explain.ExplanationMetadata): Optional. Metadata describing the Model's input and output for explanation. Both `explanation_metadata` and `explanation_parameters` must be passed together when used. For more details, see `Ref docs ` - explanation_parameters (aiplatform.explain.ExplanationParameters): + explanation_parameters (explain.ExplanationParameters): Optional. Parameters to configure explaining for Model's predictions. For more details, see `Ref docs ` metadata (Sequence[Tuple[str, str]]): @@ -525,6 +541,8 @@ def deploy( deployed_model_display_name, traffic_split, traffic_percentage, + explanation_metadata, + explanation_parameters, ) self._deploy( @@ -555,10 +573,8 @@ def _deploy( max_replica_count: Optional[int] = 1, accelerator_type: Optional[str] = None, accelerator_count: Optional[int] = None, - explanation_metadata: Optional["aiplatform.explain.ExplanationMetadata"] = None, - explanation_parameters: Optional[ - "aiplatform.explain.ExplanationParameters" - ] = None, + explanation_metadata: Optional[explain.ExplanationMetadata] = None, + explanation_parameters: Optional[explain.ExplanationParameters] = None, metadata: Optional[Sequence[Tuple[str, str]]] = (), sync=True, ) -> None: @@ -611,12 +627,12 @@ def _deploy( NVIDIA_TESLA_V100, NVIDIA_TESLA_P4, NVIDIA_TESLA_T4, TPU_V2, TPU_V3 accelerator_count (int): Optional. The number of accelerators to attach to a worker replica. - explanation_metadata (aiplatform.explain.ExplanationMetadata): + explanation_metadata (explain.ExplanationMetadata): Optional. Metadata describing the Model's input and output for explanation. Both `explanation_metadata` and `explanation_parameters` must be passed together when used. For more details, see `Ref docs ` - explanation_parameters (aiplatform.explain.ExplanationParameters): + explanation_parameters (explain.ExplanationParameters): Optional. Parameters to configure explaining for Model's predictions. For more details, see `Ref docs ` metadata (Sequence[Tuple[str, str]]): @@ -666,10 +682,8 @@ def _deploy_call( max_replica_count: Optional[int] = 1, accelerator_type: Optional[str] = None, accelerator_count: Optional[int] = None, - explanation_metadata: Optional["aiplatform.explain.ExplanationMetadata"] = None, - explanation_parameters: Optional[ - "aiplatform.explain.ExplanationParameters" - ] = None, + explanation_metadata: Optional[explain.ExplanationMetadata] = None, + explanation_parameters: Optional[explain.ExplanationParameters] = None, metadata: Optional[Sequence[Tuple[str, str]]] = (), ): """Helper method to deploy model to endpoint. @@ -720,12 +734,12 @@ def _deploy_call( is not provided, the larger value of min_replica_count or 1 will be used. If value provided is smaller than min_replica_count, it will automatically be increased to be min_replica_count. - explanation_metadata (aiplatform.explain.ExplanationMetadata): + explanation_metadata (explain.ExplanationMetadata): Optional. Metadata describing the Model's input and output for explanation. Both `explanation_metadata` and `explanation_parameters` must be passed together when used. For more details, see `Ref docs ` - explanation_parameters (aiplatform.explain.ExplanationParameters): + explanation_parameters (explain.ExplanationParameters): Optional. Parameters to configure explaining for Model's predictions. For more details, see `Ref docs ` metadata (Sequence[Tuple[str, str]]): @@ -748,20 +762,22 @@ def _deploy_call( raise ValueError( "Both `accelerator_type` and `accelerator_count` should be specified or None." ) - if bool(explanation_metadata) != bool(explanation_parameters): - raise ValueError( - "Both `explanation_metadata` and `explanation_parameters` should be specified or None." - ) + + gca_endpoint = gca_endpoint_compat + gca_machine_resources = gca_machine_resources_compat + if explanation_metadata and explanation_parameters: + gca_endpoint = gca_endpoint_v1beta1 + gca_machine_resources = gca_machine_resources_v1beta1 if machine_type: - machine_spec = machine_resources.MachineSpec(machine_type=machine_type) + machine_spec = gca_machine_resources.MachineSpec(machine_type=machine_type) if accelerator_type and accelerator_count: utils.validate_accelerator_type(accelerator_type) machine_spec.accelerator_type = accelerator_type machine_spec.accelerator_count = accelerator_count - dedicated_resources = machine_resources.DedicatedResources( + dedicated_resources = gca_machine_resources.DedicatedResources( machine_spec=machine_spec, min_replica_count=min_replica_count, max_replica_count=max_replica_count, @@ -772,7 +788,7 @@ def _deploy_call( display_name=deployed_model_display_name, ) else: - automatic_resources = machine_resources.AutomaticResources( + automatic_resources = gca_machine_resources.AutomaticResources( min_replica_count=min_replica_count, max_replica_count=max_replica_count, ) @@ -784,6 +800,7 @@ def _deploy_call( # Service will throw error if both metadata and parameters are not provided if explanation_metadata and explanation_parameters: + api_client = api_client.select_version(compat.V1BETA1) explanation_spec = gca_endpoint.explanation.ExplanationSpec() explanation_spec.metadata = explanation_metadata explanation_spec.parameters = explanation_parameters @@ -914,8 +931,9 @@ def _undeploy( def _instantiate_prediction_client( location: Optional[str] = None, credentials: Optional[auth_credentials.Credentials] = None, - ) -> prediction_service_client.PredictionServiceClient: - """Helper method to instantiates prediction client for this endpoint. + ) -> utils.PredictionClientWithOverride: + + """Helper method to instantiates prediction client with optional overrides for this endpoint. Args: location (str): The location of this endpoint. @@ -924,10 +942,10 @@ def _instantiate_prediction_client( the prediction client. Returns: prediction_client (prediction_service_client.PredictionServiceClient): - Initalized prediction client. + Initalized prediction client with optional overrides. """ return initializer.global_config.create_client( - client_class=prediction_service_client.PredictionServiceClient, + client_class=utils.PredictionClientWithOverride, credentials=credentials, location_override=location, prediction_client=True, @@ -1014,7 +1032,9 @@ def explain( """ self.wait() - explain_response = self._prediction_client.explain( + explain_response = self._prediction_client.select_version( + compat.V1BETA1 + ).explain( endpoint=self.resource_name, instances=instances, parameters=parameters, @@ -1030,7 +1050,11 @@ def explain( explanations=explain_response.explanations, ) - def list_models(self) -> Sequence[gca_endpoint.DeployedModel]: + def list_models( + self, + ) -> Sequence[ + Union[gca_endpoint_v1.DeployedModel, gca_endpoint_v1beta1.DeployedModel] + ]: """Returns a list of the models deployed to this Endpoint. Returns: @@ -1079,7 +1103,7 @@ def delete(self, force: bool = False, sync: bool = True) -> None: class Model(base.AiPlatformResourceNounWithFutureManager): - client_class = model_service_client.ModelServiceClient + client_class = utils.ModelClientWithOverride _is_client_prediction_client = False _resource_noun = "models" _getter_method = "get_model" @@ -1143,10 +1167,8 @@ def upload( instance_schema_uri: Optional[str] = None, parameters_schema_uri: Optional[str] = None, prediction_schema_uri: Optional[str] = None, - explanation_metadata: Optional["aiplatform.explain.ExplanationMetadata"] = None, - explanation_parameters: Optional[ - "aiplatform.explain.ExplanationParameters" - ] = None, + explanation_metadata: Optional[explain.ExplanationMetadata] = None, + explanation_parameters: Optional[explain.ExplanationParameters] = None, project: Optional[str] = None, location: Optional[str] = None, credentials: Optional[auth_credentials.Credentials] = None, @@ -1255,12 +1277,12 @@ def upload( and probably different, including the URI scheme, than the one given on input. The output URI will point to a location where the user only has a read access. - explanation_metadata (aiplatform.explain.ExplanationMetadata): + explanation_metadata (explain.ExplanationMetadata): Optional. Metadata describing the Model's input and output for explanation. Both `explanation_metadata` and `explanation_parameters` must be passed together when used. For more details, see `Ref docs ` - explanation_parameters (aiplatform.explain.ExplanationParameters): + explanation_parameters (explain.ExplanationParameters): Optional. Parameters to configure explaining for Model's predictions. For more details, see `Ref docs ` project: Optional[str]=None, @@ -1296,13 +1318,21 @@ def upload( "Both `explanation_metadata` and `explanation_parameters` should be specified or None." ) + gca_endpoint = gca_endpoint_compat + gca_model = gca_model_compat + gca_env_var = gca_env_var_compat + if explanation_metadata and explanation_parameters: + gca_endpoint = gca_endpoint_v1beta1 + gca_model = gca_model_v1beta1 + gca_env_var = gca_env_var_v1beta1 + api_client = cls._instantiate_client(location, credentials) env = None ports = None if serving_container_environment_variables: env = [ - env_var.EnvVar(name=str(key), value=str(value)) + gca_env_var.EnvVar(name=str(key), value=str(value)) for key, value in serving_container_environment_variables.items() ] if serving_container_ports: @@ -1330,7 +1360,7 @@ def upload( # TODO(b/182388545) initializer.global_config.get_encryption_spec from a sync function encryption_spec = initializer.global_config.get_encryption_spec( - encryption_spec_key_name=encryption_spec_key_name + encryption_spec_key_name=encryption_spec_key_name, ) managed_model = gca_model.Model( @@ -1346,6 +1376,7 @@ def upload( # Override explanation_spec if both required fields are provided if explanation_metadata and explanation_parameters: + api_client = api_client.select_version(compat.V1BETA1) explanation_spec = gca_endpoint.explanation.ExplanationSpec() explanation_spec.metadata = explanation_metadata explanation_spec.parameters = explanation_parameters @@ -1355,7 +1386,6 @@ def upload( parent=initializer.global_config.common_location_path(project, location), model=managed_model, ) - managed_model = lro.result() fields = utils.extract_fields_from_resource_name(managed_model.model) return cls( @@ -1374,10 +1404,8 @@ def deploy( max_replica_count: Optional[int] = 1, accelerator_type: Optional[str] = None, accelerator_count: Optional[int] = None, - explanation_metadata: Optional["aiplatform.explain.ExplanationMetadata"] = None, - explanation_parameters: Optional[ - "aiplatform.explain.ExplanationParameters" - ] = None, + explanation_metadata: Optional[explain.ExplanationMetadata] = None, + explanation_parameters: Optional[explain.ExplanationParameters] = None, metadata: Optional[Sequence[Tuple[str, str]]] = (), encryption_spec_key_name: Optional[str] = None, sync=True, @@ -1431,12 +1459,12 @@ def deploy( NVIDIA_TESLA_V100, NVIDIA_TESLA_P4, NVIDIA_TESLA_T4, TPU_V2, TPU_V3 accelerator_count (int): Optional. The number of accelerators to attach to a worker replica. - explanation_metadata (aiplatform.explain.ExplanationMetadata): + explanation_metadata (explain.ExplanationMetadata): Optional. Metadata describing the Model's input and output for explanation. Both `explanation_metadata` and `explanation_parameters` must be passed together when used. For more details, see `Ref docs ` - explanation_parameters (aiplatform.explain.ExplanationParameters): + explanation_parameters (explain.ExplanationParameters): Optional. Parameters to configure explaining for Model's predictions. For more details, see `Ref docs ` metadata (Sequence[Tuple[str, str]]): @@ -1470,6 +1498,8 @@ def deploy( deployed_model_display_name, traffic_split, traffic_percentage, + explanation_metadata, + explanation_parameters, ) return self._deploy( @@ -1502,10 +1532,8 @@ def _deploy( max_replica_count: Optional[int] = 1, accelerator_type: Optional[str] = None, accelerator_count: Optional[int] = None, - explanation_metadata: Optional["aiplatform.explain.ExplanationMetadata"] = None, - explanation_parameters: Optional[ - "aiplatform.explain.ExplanationParameters" - ] = None, + explanation_metadata: Optional[explain.ExplanationMetadata] = None, + explanation_parameters: Optional[explain.ExplanationParameters] = None, metadata: Optional[Sequence[Tuple[str, str]]] = (), encryption_spec_key_name: Optional[str] = None, sync: bool = True, @@ -1559,12 +1587,12 @@ def _deploy( NVIDIA_TESLA_V100, NVIDIA_TESLA_P4, NVIDIA_TESLA_T4, TPU_V2, TPU_V3 accelerator_count (int): Optional. The number of accelerators to attach to a worker replica. - explanation_metadata (aiplatform.explain.ExplanationMetadata): + explanation_metadata (explain.ExplanationMetadata): Optional. Metadata describing the Model's input and output for explanation. Both `explanation_metadata` and `explanation_parameters` must be passed together when used. For more details, see `Ref docs ` - explanation_parameters (aiplatform.explain.ExplanationParameters): + explanation_parameters (explain.ExplanationParameters): Optional. Parameters to configure explaining for Model's predictions. For more details, see `Ref docs ` metadata (Sequence[Tuple[str, str]]): @@ -1638,10 +1666,8 @@ def batch_predict( starting_replica_count: Optional[int] = None, max_replica_count: Optional[int] = None, generate_explanation: Optional[bool] = False, - explanation_metadata: Optional["aiplatform.explain.ExplanationMetadata"] = None, - explanation_parameters: Optional[ - "aiplatform.explain.ExplanationParameters" - ] = None, + explanation_metadata: Optional[explain.ExplanationMetadata] = None, + explanation_parameters: Optional[explain.ExplanationParameters] = None, labels: Optional[dict] = None, credentials: Optional[auth_credentials.Credentials] = None, encryption_spec_key_name: Optional[str] = None, @@ -1758,7 +1784,7 @@ def batch_predict( keyed `explanation`. The value of the entry is a JSON object that conforms to the [aiplatform.gapic.Explanation] object. - `csv`: Generating explanations for CSV format is not supported. - explanation_metadata (aiplatform.explain.ExplanationMetadata): + explanation_metadata (explain.ExplanationMetadata): Optional. Explanation metadata configuration for this BatchPredictionJob. Can be specified only if `generate_explanation` is set to `True`. @@ -1767,7 +1793,7 @@ def batch_predict( a field of the `explanation_metadata` object is not populated, the corresponding field of the `Model.explanation_metadata` object is inherited. For more details, see `Ref docs ` - explanation_parameters (aiplatform.explain.ExplanationParameters): + explanation_parameters (explain.ExplanationParameters): Optional. Parameters to configure explaining for Model's predictions. Can be specified only if `generate_explanation` is set to `True`. diff --git a/google/cloud/aiplatform/training_jobs.py b/google/cloud/aiplatform/training_jobs.py index 6046291e68..aa7103103e 100644 --- a/google/cloud/aiplatform/training_jobs.py +++ b/google/cloud/aiplatform/training_jobs.py @@ -30,28 +30,24 @@ from google.auth import credentials as auth_credentials from google.cloud.aiplatform import base +from google.cloud.aiplatform import constants from google.cloud.aiplatform import datasets from google.cloud.aiplatform import initializer from google.cloud.aiplatform import models from google.cloud.aiplatform import schema -from google.cloud.aiplatform import constants from google.cloud.aiplatform import utils -from google.cloud.aiplatform_v1beta1.services.pipeline_service import ( - client as pipeline_service_client, -) -from google.cloud.aiplatform_v1beta1.types import env_var -from google.cloud.aiplatform_v1beta1.types import ( + +from google.cloud.aiplatform.compat.types import ( accelerator_type as gca_accelerator_type, -) -from google.cloud.aiplatform_v1beta1.types import io as gca_io -from google.cloud.aiplatform_v1beta1.types import model as gca_model -from google.cloud.aiplatform_v1beta1.types import pipeline_state as gca_pipeline_state -from google.cloud.aiplatform_v1beta1.types import ( + env_var as gca_env_var, + io as gca_io, + model as gca_model, + pipeline_state as gca_pipeline_state, training_pipeline as gca_training_pipeline, ) -from google.cloud.aiplatform.v1beta1.schema.trainingjob import ( - definition_v1beta1 as training_job_inputs, +from google.cloud.aiplatform.v1.schema.trainingjob import ( + definition_v1 as training_job_inputs, ) from google.cloud import storage @@ -73,7 +69,8 @@ class _TrainingJob(base.AiPlatformResourceNounWithFutureManager): - client_class = pipeline_service_client.PipelineServiceClient + + client_class = utils.PipelineClientWithOverride _is_client_prediction_client = False _resource_noun = "trainingPipelines" _getter_method = "get_training_pipeline" @@ -1335,7 +1332,7 @@ def __init__( if model_serving_container_environment_variables: env = [ - env_var.EnvVar(name=str(key), value=str(value)) + gca_env_var.EnvVar(name=str(key), value=str(value)) for key, value in model_serving_container_environment_variables.items() ] diff --git a/google/cloud/aiplatform/utils.py b/google/cloud/aiplatform/utils.py index df2b3905d3..ec39038942 100644 --- a/google/cloud/aiplatform/utils.py +++ b/google/cloud/aiplatform/utils.py @@ -17,6 +17,7 @@ import re +import abc import logging from typing import Any, Match, Optional, Type, TypeVar, Tuple @@ -25,38 +26,47 @@ from google.api_core import client_options from google.api_core import gapic_v1 from google.auth import credentials as auth_credentials +from google.cloud.aiplatform import compat from google.cloud.aiplatform import constants from google.cloud.aiplatform import initializer -from google.cloud.aiplatform_v1beta1.types import ( - accelerator_type as gca_accelerator_type, -) -from google.cloud.aiplatform_v1beta1.services.dataset_service import ( - client as dataset_client, -) -from google.cloud.aiplatform_v1beta1.services.endpoint_service import ( - client as endpoint_client, -) -from google.cloud.aiplatform_v1beta1.services.job_service import ( - client as job_service_client, -) -from google.cloud.aiplatform_v1beta1.services.model_service import ( - client as model_client, + +from google.cloud.aiplatform.compat.services import ( + dataset_service_client_v1beta1, + endpoint_service_client_v1beta1, + job_service_client_v1beta1, + model_service_client_v1beta1, + pipeline_service_client_v1beta1, + prediction_service_client_v1beta1, ) -from google.cloud.aiplatform_v1beta1.services.pipeline_service import ( - client as pipeline_service_client, +from google.cloud.aiplatform.compat.services import ( + dataset_service_client_v1, + endpoint_service_client_v1, + job_service_client_v1, + model_service_client_v1, + pipeline_service_client_v1, + prediction_service_client_v1, ) -from google.cloud.aiplatform_v1beta1.services.prediction_service import ( - client as prediction_client, + +from google.cloud.aiplatform.compat.types import ( + accelerator_type as gca_accelerator_type, ) AiPlatformServiceClient = TypeVar( "AiPlatformServiceClient", - dataset_client.DatasetServiceClient, - endpoint_client.EndpointServiceClient, - model_client.ModelServiceClient, - prediction_client.PredictionServiceClient, - pipeline_service_client.PipelineServiceClient, - job_service_client.JobServiceClient, + # v1beta1 + dataset_service_client_v1beta1.DatasetServiceClient, + endpoint_service_client_v1beta1.EndpointServiceClient, + model_service_client_v1beta1.ModelServiceClient, + prediction_service_client_v1beta1.PredictionServiceClient, + pipeline_service_client_v1beta1.PipelineServiceClient, + job_service_client_v1beta1.JobServiceClient, + # v1 + dataset_service_client_v1.DatasetServiceClient, + endpoint_service_client_v1.EndpointServiceClient, + model_service_client_v1.ModelServiceClient, + prediction_service_client_v1.PredictionServiceClient, + pipeline_service_client_v1.PipelineServiceClient, + job_service_client_v1.JobServiceClient, ) # TODO(b/170334193): Add support for resource names with non-integer IDs @@ -296,20 +306,68 @@ def extract_bucket_and_prefix_from_gcs_path(gcs_path: str) -> Tuple[str, Optiona return (gcs_bucket, gcs_blob_prefix) -class WrappedClient: - """Wrapper class for client that creates client at API invocation time.""" +class ClientWithOverride: + class WrappedClient: + """Wrapper class for client that creates client at API invocation time.""" + + def __init__( + self, + client_class: Type[AiPlatformServiceClient], + client_options: client_options.ClientOptions, + client_info: gapic_v1.client_info.ClientInfo, + credentials: Optional[auth_credentials.Credentials] = None, + ): + """Stores parameters needed to instantiate client. + + client_class (AiPlatformServiceClient): + Required. Class of the client to use. + client_options (client_options.ClientOptions): + Required. Client options to pass to client. + client_info (gapic_v1.client_info.ClientInfo): + Required. Client info to pass to client. + credentials (auth_credentials.credentials): + Optional. Client credentials to pass to client. + """ + + self._client_class = client_class + self._credentials = credentials + self._client_options = client_options + self._client_info = client_info + + def __getattr__(self, name: str) -> Any: + """Instantiates client and returns attribute of the client.""" + temporary_client = self._client_class( + credentials=self._credentials, + client_options=self._client_options, + client_info=self._client_info, + ) + return getattr(temporary_client, name) + + @property + @abc.abstractmethod + def _is_temporary(self) -> bool: + pass + + @property + @classmethod + @abc.abstractmethod + def _default_version(self) -> str: + pass + + @property + @classmethod + @abc.abstractmethod + def _version_map(self) -> Tuple: + pass def __init__( self, - client_class: Type[AiPlatformServiceClient], client_options: client_options.ClientOptions, client_info: gapic_v1.client_info.ClientInfo, credentials: Optional[auth_credentials.Credentials] = None, ): """Stores parameters needed to instantiate client. - client_class (AiPlatformServiceClient): - Required. Class of the client to use. client_options (client_options.ClientOptions): Required. Client options to pass to client. client_info (gapic_v1.client_info.ClientInfo): @@ -318,19 +376,93 @@ def __init__( Optional. Client credentials to pass to client. """ - self._client_class = client_class - self._credentials = credentials - self._client_options = client_options - self._client_info = client_info + self._clients = { + version: self.WrappedClient( + client_class=client_class, + client_options=client_options, + client_info=client_info, + credentials=credentials, + ) + if self._is_temporary + else client_class( + client_options=client_options, + client_info=client_info, + credentials=credentials, + ) + for version, client_class in self._version_map + } def __getattr__(self, name: str) -> Any: """Instantiates client and returns attribute of the client.""" - temporary_client = self._client_class( - credentials=self._credentials, - client_options=self._client_options, - client_info=self._client_info, - ) - return getattr(temporary_client, name) + return getattr(self._clients[self._default_version], name) + + def select_version(self, version: str) -> AiPlatformServiceClient: + return self._clients[version] + + +class DatasetClientWithOverride(ClientWithOverride): + _is_temporary = True + _default_version = compat.DEFAULT_VERSION + _version_map = ( + (compat.V1, dataset_service_client_v1.DatasetServiceClient), + (compat.V1BETA1, dataset_service_client_v1beta1.DatasetServiceClient), + ) + + +class EndpointClientWithOverride(ClientWithOverride): + _is_temporary = True + _default_version = compat.DEFAULT_VERSION + _version_map = ( + (compat.V1, endpoint_service_client_v1.EndpointServiceClient), + (compat.V1BETA1, endpoint_service_client_v1beta1.EndpointServiceClient), + ) + + +class JobpointClientWithOverride(ClientWithOverride): + _is_temporary = True + _default_version = compat.DEFAULT_VERSION + _version_map = ( + (compat.V1, job_service_client_v1.JobServiceClient), + (compat.V1BETA1, job_service_client_v1beta1.JobServiceClient), + ) + + +class ModelClientWithOverride(ClientWithOverride): + _is_temporary = True + _default_version = compat.DEFAULT_VERSION + _version_map = ( + (compat.V1, model_service_client_v1.ModelServiceClient), + (compat.V1BETA1, model_service_client_v1beta1.ModelServiceClient), + ) + + +class PipelineClientWithOverride(ClientWithOverride): + _is_temporary = True + _default_version = compat.DEFAULT_VERSION + _version_map = ( + (compat.V1, pipeline_service_client_v1.PipelineServiceClient), + (compat.V1BETA1, pipeline_service_client_v1beta1.PipelineServiceClient), + ) + + +class PredictionClientWithOverride(ClientWithOverride): + _is_temporary = False + _default_version = compat.DEFAULT_VERSION + _version_map = ( + (compat.V1, prediction_service_client_v1.PredictionServiceClient), + (compat.V1BETA1, prediction_service_client_v1beta1.PredictionServiceClient), + ) + + +AiPlatformServiceClientWithOverride = TypeVar( + "AiPlatformServiceClientWithOverride", + DatasetClientWithOverride, + EndpointClientWithOverride, + JobpointClientWithOverride, + ModelClientWithOverride, + PipelineClientWithOverride, + PredictionClientWithOverride, +) class LoggingWarningFilter(logging.Filter): diff --git a/samples/snippets/requirements.txt b/samples/snippets/requirements.txt index b9fd33d5c1..481213275f 100644 --- a/samples/snippets/requirements.txt +++ b/samples/snippets/requirements.txt @@ -1,3 +1,3 @@ pytest==6.2.2 google-cloud-storage>=1.26.0, <2.0.0dev -google-cloud-aiplatform==0.5.1 +google-cloud-aiplatform==0.6.0 diff --git a/tests/unit/aiplatform/test_automl_image_training_jobs.py b/tests/unit/aiplatform/test_automl_image_training_jobs.py index 27745ba3c7..d85f4f3b97 100644 --- a/tests/unit/aiplatform/test_automl_image_training_jobs.py +++ b/tests/unit/aiplatform/test_automl_image_training_jobs.py @@ -6,26 +6,26 @@ from google.protobuf import struct_pb2 from google.cloud import aiplatform -from google.cloud.aiplatform import schema -from google.cloud.aiplatform import models + from google.cloud.aiplatform import datasets from google.cloud.aiplatform import initializer - +from google.cloud.aiplatform import models +from google.cloud.aiplatform import schema from google.cloud.aiplatform import training_jobs -from google.cloud.aiplatform_v1beta1.services.model_service import ( +from google.cloud.aiplatform_v1.services.model_service import ( client as model_service_client, ) -from google.cloud.aiplatform_v1beta1.services.pipeline_service import ( +from google.cloud.aiplatform_v1.services.pipeline_service import ( client as pipeline_service_client, ) -from google.cloud.aiplatform_v1beta1.types import model as gca_model -from google.cloud.aiplatform_v1beta1.types import pipeline_state as gca_pipeline_state -from google.cloud.aiplatform_v1beta1.types import ( +from google.cloud.aiplatform_v1.types import ( + dataset as gca_dataset, + encryption_spec as gca_encryption_spec, + model as gca_model, + pipeline_state as gca_pipeline_state, training_pipeline as gca_training_pipeline, ) -from google.cloud.aiplatform_v1beta1.types import dataset as gca_dataset -from google.cloud.aiplatform_v1beta1.types import encryption_spec as gca_encryption_spec _TEST_PROJECT = "test-project" _TEST_LOCATION = "us-central1" diff --git a/tests/unit/aiplatform/test_automl_tabular_training_jobs.py b/tests/unit/aiplatform/test_automl_tabular_training_jobs.py index fc5ae5a8ef..77435fa7ed 100644 --- a/tests/unit/aiplatform/test_automl_tabular_training_jobs.py +++ b/tests/unit/aiplatform/test_automl_tabular_training_jobs.py @@ -3,25 +3,25 @@ from unittest import mock from google.cloud import aiplatform + from google.cloud.aiplatform import datasets from google.cloud.aiplatform import initializer from google.cloud.aiplatform import schema -from google.cloud.aiplatform.training_jobs import AutoMLTabularTrainingJob +from google.cloud.aiplatform import training_jobs -from google.cloud.aiplatform_v1beta1.services.model_service import ( +from google.cloud.aiplatform_v1.services.model_service import ( client as model_service_client, ) -from google.cloud.aiplatform_v1beta1.services.pipeline_service import ( +from google.cloud.aiplatform_v1.services.pipeline_service import ( client as pipeline_service_client, ) -from google.cloud.aiplatform_v1beta1.types import model as gca_model -from google.cloud.aiplatform_v1beta1.types import pipeline_state as gca_pipeline_state -from google.cloud.aiplatform_v1beta1.types import ( +from google.cloud.aiplatform_v1.types import ( + dataset as gca_dataset, + encryption_spec as gca_encryption_spec, + model as gca_model, + pipeline_state as gca_pipeline_state, training_pipeline as gca_training_pipeline, ) -from google.cloud.aiplatform_v1beta1.types import encryption_spec as gca_encryption_spec -from google.cloud.aiplatform_v1beta1 import Dataset as GapicDataset - from google.protobuf import json_format from google.protobuf import struct_pb2 @@ -148,7 +148,7 @@ def mock_dataset_tabular(): ds = mock.MagicMock(datasets.Dataset) ds.name = _TEST_DATASET_NAME ds._latest_future = None - ds._gca_resource = GapicDataset( + ds._gca_resource = gca_dataset.Dataset( display_name=_TEST_DATASET_DISPLAY_NAME, metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_TABULAR, labels={}, @@ -163,7 +163,7 @@ def mock_dataset_nontabular(): ds = mock.MagicMock(datasets.Dataset) ds.name = _TEST_DATASET_NAME ds._latest_future = None - ds._gca_resource = GapicDataset( + ds._gca_resource = gca_dataset.Dataset( display_name=_TEST_DATASET_DISPLAY_NAME, metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_NONTABULAR, labels={}, @@ -195,7 +195,7 @@ def test_run_call_pipeline_service_create( encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME, ) - job = AutoMLTabularTrainingJob( + job = training_jobs.AutoMLTabularTrainingJob( display_name=_TEST_DISPLAY_NAME, optimization_objective=_TEST_TRAINING_OPTIMIZATION_OBJECTIVE_NAME, optimization_prediction_type=_TEST_TRAINING_OPTIMIZATION_PREDICTION_TYPE, @@ -276,7 +276,7 @@ def test_run_call_pipeline_if_no_model_display_name( ): aiplatform.init(project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME) - job = AutoMLTabularTrainingJob( + job = training_jobs.AutoMLTabularTrainingJob( display_name=_TEST_DISPLAY_NAME, optimization_objective=_TEST_TRAINING_OPTIMIZATION_OBJECTIVE_NAME, optimization_prediction_type=_TEST_TRAINING_OPTIMIZATION_PREDICTION_TYPE, @@ -340,7 +340,7 @@ def test_run_called_twice_raises( ): aiplatform.init(project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME) - job = AutoMLTabularTrainingJob( + job = training_jobs.AutoMLTabularTrainingJob( display_name=_TEST_DISPLAY_NAME, optimization_prediction_type=_TEST_TRAINING_OPTIMIZATION_PREDICTION_TYPE, optimization_objective=_TEST_TRAINING_OPTIMIZATION_OBJECTIVE_NAME, @@ -377,7 +377,7 @@ def test_run_raises_if_pipeline_fails( aiplatform.init(project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME) - job = AutoMLTabularTrainingJob( + job = training_jobs.AutoMLTabularTrainingJob( display_name=_TEST_DISPLAY_NAME, optimization_prediction_type=_TEST_TRAINING_OPTIMIZATION_PREDICTION_TYPE, optimization_objective=_TEST_TRAINING_OPTIMIZATION_OBJECTIVE_NAME, @@ -406,7 +406,7 @@ def test_run_raises_if_pipeline_fails( def test_raises_before_run_is_called(self, mock_pipeline_service_create): aiplatform.init(project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME) - job = AutoMLTabularTrainingJob( + job = training_jobs.AutoMLTabularTrainingJob( display_name=_TEST_DISPLAY_NAME, optimization_prediction_type=_TEST_TRAINING_OPTIMIZATION_PREDICTION_TYPE, optimization_objective=_TEST_TRAINING_OPTIMIZATION_OBJECTIVE_NAME, diff --git a/tests/unit/aiplatform/test_automl_text_training_jobs.py b/tests/unit/aiplatform/test_automl_text_training_jobs.py index d77d38d6a0..84726aa6fa 100644 --- a/tests/unit/aiplatform/test_automl_text_training_jobs.py +++ b/tests/unit/aiplatform/test_automl_text_training_jobs.py @@ -3,28 +3,28 @@ from unittest import mock from google.cloud import aiplatform -from google.cloud.aiplatform import schema -from google.cloud.aiplatform import models + from google.cloud.aiplatform import datasets from google.cloud.aiplatform import initializer - +from google.cloud.aiplatform import models +from google.cloud.aiplatform import schema from google.cloud.aiplatform import training_jobs -from google.cloud.aiplatform_v1beta1.services.model_service import ( +from google.cloud.aiplatform_v1.services.model_service import ( client as model_service_client, ) -from google.cloud.aiplatform_v1beta1.services.pipeline_service import ( +from google.cloud.aiplatform_v1.services.pipeline_service import ( client as pipeline_service_client, ) -from google.cloud.aiplatform_v1beta1.types import model as gca_model -from google.cloud.aiplatform_v1beta1.types import pipeline_state as gca_pipeline_state -from google.cloud.aiplatform_v1beta1.types import ( +from google.cloud.aiplatform_v1.types import ( + dataset as gca_dataset, + encryption_spec as gca_encryption_spec, + model as gca_model, + pipeline_state as gca_pipeline_state, training_pipeline as gca_training_pipeline, ) -from google.cloud.aiplatform_v1beta1.types import dataset as gca_dataset -from google.cloud.aiplatform_v1beta1.types import encryption_spec as gca_encryption_spec -from google.cloud.aiplatform.v1beta1.schema.trainingjob import ( - definition_v1beta1 as training_job_inputs, +from google.cloud.aiplatform.v1.schema.trainingjob import ( + definition_v1 as training_job_inputs, ) _TEST_PROJECT = "test-project" diff --git a/tests/unit/aiplatform/test_automl_video_training_jobs.py b/tests/unit/aiplatform/test_automl_video_training_jobs.py index 7398385cf4..8743c00c9d 100644 --- a/tests/unit/aiplatform/test_automl_video_training_jobs.py +++ b/tests/unit/aiplatform/test_automl_video_training_jobs.py @@ -6,26 +6,26 @@ from google.protobuf import struct_pb2 from google.cloud import aiplatform -from google.cloud.aiplatform import schema -from google.cloud.aiplatform import models + from google.cloud.aiplatform import datasets from google.cloud.aiplatform import initializer - +from google.cloud.aiplatform import models +from google.cloud.aiplatform import schema from google.cloud.aiplatform import training_jobs -from google.cloud.aiplatform_v1beta1.services.model_service import ( +from google.cloud.aiplatform_v1.services.model_service import ( client as model_service_client, ) -from google.cloud.aiplatform_v1beta1.services.pipeline_service import ( +from google.cloud.aiplatform_v1.services.pipeline_service import ( client as pipeline_service_client, ) -from google.cloud.aiplatform_v1beta1.types import model as gca_model -from google.cloud.aiplatform_v1beta1.types import pipeline_state as gca_pipeline_state -from google.cloud.aiplatform_v1beta1.types import ( +from google.cloud.aiplatform_v1.types import ( + dataset as gca_dataset, + encryption_spec as gca_encryption_spec, + model as gca_model, + pipeline_state as gca_pipeline_state, training_pipeline as gca_training_pipeline, ) -from google.cloud.aiplatform_v1beta1.types import dataset as gca_dataset -from google.cloud.aiplatform_v1beta1.types import encryption_spec as gca_encryption_spec _TEST_PROJECT = "test-project" _TEST_LOCATION = "us-central1" diff --git a/tests/unit/aiplatform/test_datasets.py b/tests/unit/aiplatform/test_datasets.py index 60458bcc70..2ac7489d5f 100644 --- a/tests/unit/aiplatform/test_datasets.py +++ b/tests/unit/aiplatform/test_datasets.py @@ -28,18 +28,21 @@ from google.auth import credentials as auth_credentials from google.cloud import aiplatform + from google.cloud.aiplatform import datasets from google.cloud.aiplatform import initializer from google.cloud.aiplatform import schema -from google.cloud.aiplatform_v1beta1 import GcsSource -from google.cloud.aiplatform_v1beta1 import GcsDestination -from google.cloud.aiplatform_v1beta1 import ImportDataConfig -from google.cloud.aiplatform_v1beta1 import ExportDataConfig -from google.cloud.aiplatform_v1beta1 import DatasetServiceClient -from google.cloud.aiplatform_v1beta1 import Dataset as GapicDataset -from google.cloud.aiplatform_v1beta1.types import dataset_service -from google.cloud.aiplatform_v1beta1.types import encryption_spec as gca_encryption_spec +from google.cloud.aiplatform_v1.services.dataset_service import ( + client as dataset_service_client, +) + +from google.cloud.aiplatform_v1.types import ( + dataset as gca_dataset, + dataset_service as gca_dataset_service, + encryption_spec as gca_encryption_spec, + io as gca_io, +) # project _TEST_PROJECT = "test-project" @@ -109,8 +112,10 @@ @pytest.fixture def get_dataset_mock(): - with patch.object(DatasetServiceClient, "get_dataset") as get_dataset_mock: - get_dataset_mock.return_value = GapicDataset( + with patch.object( + dataset_service_client.DatasetServiceClient, "get_dataset" + ) as get_dataset_mock: + get_dataset_mock.return_value = gca_dataset.Dataset( display_name=_TEST_DISPLAY_NAME, metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_NONTABULAR, name=_TEST_NAME, @@ -122,8 +127,10 @@ def get_dataset_mock(): @pytest.fixture def get_dataset_without_name_mock(): - with patch.object(DatasetServiceClient, "get_dataset") as get_dataset_mock: - get_dataset_mock.return_value = GapicDataset( + with patch.object( + dataset_service_client.DatasetServiceClient, "get_dataset" + ) as get_dataset_mock: + get_dataset_mock.return_value = gca_dataset.Dataset( display_name=_TEST_DISPLAY_NAME, metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_NONTABULAR, encryption_spec=_TEST_ENCRYPTION_SPEC, @@ -133,8 +140,10 @@ def get_dataset_without_name_mock(): @pytest.fixture def get_dataset_image_mock(): - with patch.object(DatasetServiceClient, "get_dataset") as get_dataset_mock: - get_dataset_mock.return_value = GapicDataset( + with patch.object( + dataset_service_client.DatasetServiceClient, "get_dataset" + ) as get_dataset_mock: + get_dataset_mock.return_value = gca_dataset.Dataset( display_name=_TEST_DISPLAY_NAME, metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_IMAGE, metadata=_TEST_NONTABULAR_DATASET_METADATA, @@ -146,8 +155,10 @@ def get_dataset_image_mock(): @pytest.fixture def get_dataset_tabular_mock(): - with patch.object(DatasetServiceClient, "get_dataset") as get_dataset_mock: - get_dataset_mock.return_value = GapicDataset( + with patch.object( + dataset_service_client.DatasetServiceClient, "get_dataset" + ) as get_dataset_mock: + get_dataset_mock.return_value = gca_dataset.Dataset( display_name=_TEST_DISPLAY_NAME, metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_TABULAR, metadata=_TEST_METADATA_TABULAR_BQ, @@ -159,8 +170,10 @@ def get_dataset_tabular_mock(): @pytest.fixture def get_dataset_text_mock(): - with patch.object(DatasetServiceClient, "get_dataset") as get_dataset_mock: - get_dataset_mock.return_value = GapicDataset( + with patch.object( + dataset_service_client.DatasetServiceClient, "get_dataset" + ) as get_dataset_mock: + get_dataset_mock.return_value = gca_dataset.Dataset( display_name=_TEST_DISPLAY_NAME, metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_TEXT, metadata=_TEST_NONTABULAR_DATASET_METADATA, @@ -172,8 +185,10 @@ def get_dataset_text_mock(): @pytest.fixture def get_dataset_video_mock(): - with patch.object(DatasetServiceClient, "get_dataset") as get_dataset_mock: - get_dataset_mock.return_value = GapicDataset( + with patch.object( + dataset_service_client.DatasetServiceClient, "get_dataset" + ) as get_dataset_mock: + get_dataset_mock.return_value = gca_dataset.Dataset( display_name=_TEST_DISPLAY_NAME, metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_VIDEO, metadata=_TEST_NONTABULAR_DATASET_METADATA, @@ -185,9 +200,11 @@ def get_dataset_video_mock(): @pytest.fixture def create_dataset_mock(): - with patch.object(DatasetServiceClient, "create_dataset") as create_dataset_mock: + with patch.object( + dataset_service_client.DatasetServiceClient, "create_dataset" + ) as create_dataset_mock: create_dataset_lro_mock = mock.Mock(operation.Operation) - create_dataset_lro_mock.result.return_value = GapicDataset( + create_dataset_lro_mock.result.return_value = gca_dataset.Dataset( name=_TEST_NAME, display_name=_TEST_DISPLAY_NAME, metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_TEXT, @@ -200,11 +217,11 @@ def create_dataset_mock(): @pytest.fixture def delete_dataset_mock(): with mock.patch.object( - DatasetServiceClient, "delete_dataset" + dataset_service_client.DatasetServiceClient, "delete_dataset" ) as delete_dataset_mock: delete_dataset_lro_mock = mock.Mock(operation.Operation) delete_dataset_lro_mock.result.return_value = ( - dataset_service.DeleteDatasetRequest() + gca_dataset_service.DeleteDatasetRequest() ) delete_dataset_mock.return_value = delete_dataset_lro_mock yield delete_dataset_mock @@ -212,14 +229,18 @@ def delete_dataset_mock(): @pytest.fixture def import_data_mock(): - with patch.object(DatasetServiceClient, "import_data") as import_data_mock: + with patch.object( + dataset_service_client.DatasetServiceClient, "import_data" + ) as import_data_mock: import_data_mock.return_value = mock.Mock(operation.Operation) yield import_data_mock @pytest.fixture def export_data_mock(): - with patch.object(DatasetServiceClient, "export_data") as export_data_mock: + with patch.object( + dataset_service_client.DatasetServiceClient, "export_data" + ) as export_data_mock: export_data_mock.return_value = mock.Mock(operation.Operation) yield export_data_mock @@ -283,7 +304,7 @@ def test_init_aiplatform_with_encryption_key_name_and_create_dataset( if not sync: my_dataset.wait() - expected_dataset = GapicDataset( + expected_dataset = gca_dataset.Dataset( display_name=_TEST_DISPLAY_NAME, metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_NONTABULAR, metadata=_TEST_NONTABULAR_DATASET_METADATA, @@ -311,7 +332,7 @@ def test_create_dataset_nontabular(self, create_dataset_mock, sync): if not sync: my_dataset.wait() - expected_dataset = GapicDataset( + expected_dataset = gca_dataset.Dataset( display_name=_TEST_DISPLAY_NAME, metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_NONTABULAR, metadata=_TEST_NONTABULAR_DATASET_METADATA, @@ -335,7 +356,7 @@ def test_create_dataset_tabular(self, create_dataset_mock): encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME, ) - expected_dataset = GapicDataset( + expected_dataset = gca_dataset.Dataset( display_name=_TEST_DISPLAY_NAME, metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_TABULAR, metadata=_TEST_METADATA_TABULAR_BQ, @@ -368,15 +389,15 @@ def test_create_and_import_dataset( if not sync: my_dataset.wait() - expected_dataset = GapicDataset( + expected_dataset = gca_dataset.Dataset( display_name=_TEST_DISPLAY_NAME, metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_NONTABULAR, metadata=_TEST_NONTABULAR_DATASET_METADATA, encryption_spec=_TEST_ENCRYPTION_SPEC, ) - expected_import_config = ImportDataConfig( - gcs_source=GcsSource(uris=[_TEST_SOURCE_URI_GCS]), + expected_import_config = gca_dataset.ImportDataConfig( + gcs_source=gca_io.GcsSource(uris=[_TEST_SOURCE_URI_GCS]), import_schema_uri=_TEST_IMPORT_SCHEMA_URI, data_item_labels=_TEST_DATA_LABEL_ITEMS, ) @@ -411,8 +432,8 @@ def test_import_data(self, import_data_mock, sync): if not sync: my_dataset.wait() - expected_import_config = ImportDataConfig( - gcs_source=GcsSource(uris=[_TEST_SOURCE_URI_GCS]), + expected_import_config = gca_dataset.ImportDataConfig( + gcs_source=gca_io.GcsSource(uris=[_TEST_SOURCE_URI_GCS]), import_schema_uri=_TEST_IMPORT_SCHEMA_URI, data_item_labels=_TEST_DATA_LABEL_ITEMS, ) @@ -429,8 +450,8 @@ def test_export_data(self, export_data_mock): my_dataset.export_data(output_dir=_TEST_OUTPUT_DIR) - expected_export_config = ExportDataConfig( - gcs_destination=GcsDestination(output_uri_prefix=_TEST_OUTPUT_DIR) + expected_export_config = gca_dataset.ExportDataConfig( + gcs_destination=gca_io.GcsDestination(output_uri_prefix=_TEST_OUTPUT_DIR) ) export_data_mock.assert_called_once_with( @@ -461,15 +482,15 @@ def test_create_then_import( if not sync: my_dataset.wait() - expected_dataset = GapicDataset( + expected_dataset = gca_dataset.Dataset( display_name=_TEST_DISPLAY_NAME, metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_NONTABULAR, metadata=_TEST_NONTABULAR_DATASET_METADATA, encryption_spec=_TEST_ENCRYPTION_SPEC, ) - expected_import_config = ImportDataConfig( - gcs_source=GcsSource(uris=[_TEST_SOURCE_URI_GCS]), + expected_import_config = gca_dataset.ImportDataConfig( + gcs_source=gca_io.GcsSource(uris=[_TEST_SOURCE_URI_GCS]), import_schema_uri=_TEST_IMPORT_SCHEMA_URI, data_item_labels=_TEST_DATA_LABEL_ITEMS, ) @@ -536,7 +557,7 @@ def test_create_dataset(self, create_dataset_mock, sync): if not sync: my_dataset.wait() - expected_dataset = GapicDataset( + expected_dataset = gca_dataset.Dataset( display_name=_TEST_DISPLAY_NAME, metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_IMAGE, metadata=_TEST_NONTABULAR_DATASET_METADATA, @@ -567,7 +588,7 @@ def test_create_and_import_dataset( if not sync: my_dataset.wait() - expected_dataset = GapicDataset( + expected_dataset = gca_dataset.Dataset( display_name=_TEST_DISPLAY_NAME, metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_IMAGE, metadata=_TEST_NONTABULAR_DATASET_METADATA, @@ -580,8 +601,8 @@ def test_create_and_import_dataset( metadata=_TEST_REQUEST_METADATA, ) - expected_import_config = ImportDataConfig( - gcs_source=GcsSource(uris=[_TEST_SOURCE_URI_GCS]), + expected_import_config = gca_dataset.ImportDataConfig( + gcs_source=gca_io.GcsSource(uris=[_TEST_SOURCE_URI_GCS]), import_schema_uri=_TEST_IMPORT_SCHEMA_URI_IMAGE, ) import_data_mock.assert_called_once_with( @@ -607,8 +628,8 @@ def test_import_data(self, import_data_mock, sync): if not sync: my_dataset.wait() - expected_import_config = ImportDataConfig( - gcs_source=GcsSource(uris=[_TEST_SOURCE_URI_GCS]), + expected_import_config = gca_dataset.ImportDataConfig( + gcs_source=gca_io.GcsSource(uris=[_TEST_SOURCE_URI_GCS]), import_schema_uri=_TEST_IMPORT_SCHEMA_URI_IMAGE, ) @@ -638,7 +659,7 @@ def test_create_then_import( if not sync: my_dataset.wait() - expected_dataset = GapicDataset( + expected_dataset = gca_dataset.Dataset( display_name=_TEST_DISPLAY_NAME, metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_IMAGE, metadata=_TEST_NONTABULAR_DATASET_METADATA, @@ -652,8 +673,8 @@ def test_create_then_import( get_dataset_image_mock.assert_called_once_with(name=_TEST_NAME) - expected_import_config = ImportDataConfig( - gcs_source=GcsSource(uris=[_TEST_SOURCE_URI_GCS]), + expected_import_config = gca_dataset.ImportDataConfig( + gcs_source=gca_io.GcsSource(uris=[_TEST_SOURCE_URI_GCS]), import_schema_uri=_TEST_IMPORT_SCHEMA_URI_IMAGE, ) @@ -700,7 +721,7 @@ def test_create_dataset_with_default_encryption_key( if not sync: my_dataset.wait() - expected_dataset = GapicDataset( + expected_dataset = gca_dataset.Dataset( display_name=_TEST_DISPLAY_NAME, metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_TABULAR, metadata=_TEST_METADATA_TABULAR_BQ, @@ -728,7 +749,7 @@ def test_create_dataset(self, create_dataset_mock, sync): if not sync: my_dataset.wait() - expected_dataset = GapicDataset( + expected_dataset = gca_dataset.Dataset( display_name=_TEST_DISPLAY_NAME, metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_TABULAR, metadata=_TEST_METADATA_TABULAR_BQ, @@ -784,7 +805,7 @@ def test_create_dataset(self, create_dataset_mock, sync): if not sync: my_dataset.wait() - expected_dataset = GapicDataset( + expected_dataset = gca_dataset.Dataset( display_name=_TEST_DISPLAY_NAME, metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_TEXT, metadata=_TEST_NONTABULAR_DATASET_METADATA, @@ -815,7 +836,7 @@ def test_create_and_import_dataset( if not sync: my_dataset.wait() - expected_dataset = GapicDataset( + expected_dataset = gca_dataset.Dataset( display_name=_TEST_DISPLAY_NAME, metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_TEXT, metadata=_TEST_NONTABULAR_DATASET_METADATA, @@ -828,8 +849,8 @@ def test_create_and_import_dataset( metadata=_TEST_REQUEST_METADATA, ) - expected_import_config = ImportDataConfig( - gcs_source=GcsSource(uris=[_TEST_SOURCE_URI_GCS]), + expected_import_config = gca_dataset.ImportDataConfig( + gcs_source=gca_io.GcsSource(uris=[_TEST_SOURCE_URI_GCS]), import_schema_uri=_TEST_IMPORT_SCHEMA_URI_TEXT, ) import_data_mock.assert_called_once_with( @@ -857,8 +878,8 @@ def test_import_data(self, import_data_mock, sync): if not sync: my_dataset.wait() - expected_import_config = ImportDataConfig( - gcs_source=GcsSource(uris=[_TEST_SOURCE_URI_GCS]), + expected_import_config = gca_dataset.ImportDataConfig( + gcs_source=gca_io.GcsSource(uris=[_TEST_SOURCE_URI_GCS]), import_schema_uri=_TEST_IMPORT_SCHEMA_URI_TEXT, ) @@ -888,7 +909,7 @@ def test_create_then_import( if not sync: my_dataset.wait() - expected_dataset = GapicDataset( + expected_dataset = gca_dataset.Dataset( display_name=_TEST_DISPLAY_NAME, metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_TEXT, metadata=_TEST_NONTABULAR_DATASET_METADATA, @@ -902,8 +923,8 @@ def test_create_then_import( get_dataset_text_mock.assert_called_once_with(name=_TEST_NAME) - expected_import_config = ImportDataConfig( - gcs_source=GcsSource(uris=[_TEST_SOURCE_URI_GCS]), + expected_import_config = gca_dataset.ImportDataConfig( + gcs_source=gca_io.GcsSource(uris=[_TEST_SOURCE_URI_GCS]), import_schema_uri=_TEST_IMPORT_SCHEMA_URI_TEXT, ) @@ -948,7 +969,7 @@ def test_create_dataset(self, create_dataset_mock, sync): if not sync: my_dataset.wait() - expected_dataset = GapicDataset( + expected_dataset = gca_dataset.Dataset( display_name=_TEST_DISPLAY_NAME, metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_VIDEO, metadata=_TEST_NONTABULAR_DATASET_METADATA, @@ -979,7 +1000,7 @@ def test_create_and_import_dataset( if not sync: my_dataset.wait() - expected_dataset = GapicDataset( + expected_dataset = gca_dataset.Dataset( display_name=_TEST_DISPLAY_NAME, metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_VIDEO, metadata=_TEST_NONTABULAR_DATASET_METADATA, @@ -992,8 +1013,8 @@ def test_create_and_import_dataset( metadata=_TEST_REQUEST_METADATA, ) - expected_import_config = ImportDataConfig( - gcs_source=GcsSource(uris=[_TEST_SOURCE_URI_GCS]), + expected_import_config = gca_dataset.ImportDataConfig( + gcs_source=gca_io.GcsSource(uris=[_TEST_SOURCE_URI_GCS]), import_schema_uri=_TEST_IMPORT_SCHEMA_URI_VIDEO, ) import_data_mock.assert_called_once_with( @@ -1019,8 +1040,8 @@ def test_import_data(self, import_data_mock, sync): if not sync: my_dataset.wait() - expected_import_config = ImportDataConfig( - gcs_source=GcsSource(uris=[_TEST_SOURCE_URI_GCS]), + expected_import_config = gca_dataset.ImportDataConfig( + gcs_source=gca_io.GcsSource(uris=[_TEST_SOURCE_URI_GCS]), import_schema_uri=_TEST_IMPORT_SCHEMA_URI_VIDEO, ) @@ -1050,7 +1071,7 @@ def test_create_then_import( if not sync: my_dataset.wait() - expected_dataset = GapicDataset( + expected_dataset = gca_dataset.Dataset( display_name=_TEST_DISPLAY_NAME, metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_VIDEO, metadata=_TEST_NONTABULAR_DATASET_METADATA, @@ -1064,8 +1085,8 @@ def test_create_then_import( get_dataset_video_mock.assert_called_once_with(name=_TEST_NAME) - expected_import_config = ImportDataConfig( - gcs_source=GcsSource(uris=[_TEST_SOURCE_URI_GCS]), + expected_import_config = gca_dataset.ImportDataConfig( + gcs_source=gca_io.GcsSource(uris=[_TEST_SOURCE_URI_GCS]), import_schema_uri=_TEST_IMPORT_SCHEMA_URI_VIDEO, ) diff --git a/tests/unit/aiplatform/test_end_to_end.py b/tests/unit/aiplatform/test_end_to_end.py index d2e09c47db..4937c95e34 100644 --- a/tests/unit/aiplatform/test_end_to_end.py +++ b/tests/unit/aiplatform/test_end_to_end.py @@ -25,16 +25,14 @@ from google.cloud.aiplatform import schema from google.cloud.aiplatform import training_jobs -from google.cloud.aiplatform_v1beta1 import GcsSource -from google.cloud.aiplatform_v1beta1 import ImportDataConfig -from google.cloud.aiplatform_v1beta1 import Dataset as GapicDataset -from google.cloud.aiplatform_v1beta1.types import io as gca_io -from google.cloud.aiplatform_v1beta1.types import model as gca_model -from google.cloud.aiplatform_v1beta1.types import pipeline_state as gca_pipeline_state -from google.cloud.aiplatform_v1beta1.types import ( +from google.cloud.aiplatform_v1.types import ( + dataset as gca_dataset, + encryption_spec as gca_encryption_spec, + io as gca_io, + model as gca_model, + pipeline_state as gca_pipeline_state, training_pipeline as gca_training_pipeline, ) -from google.cloud.aiplatform_v1beta1.types import EncryptionSpec import test_datasets from test_datasets import create_dataset_mock # noqa: F401 @@ -61,7 +59,9 @@ # dataset_encryption _TEST_ENCRYPTION_KEY_NAME = "key_1234" -_TEST_ENCRYPTION_SPEC = EncryptionSpec(kms_key_name=_TEST_ENCRYPTION_KEY_NAME) +_TEST_ENCRYPTION_SPEC = gca_encryption_spec.EncryptionSpec( + kms_key_name=_TEST_ENCRYPTION_KEY_NAME +) class TestEndToEnd: @@ -168,15 +168,15 @@ def test_dataset_create_to_model_predict( parameters={"param": 3.0}, ) - expected_dataset = GapicDataset( + expected_dataset = gca_dataset.Dataset( display_name=test_datasets._TEST_DISPLAY_NAME, metadata_schema_uri=test_datasets._TEST_METADATA_SCHEMA_URI_NONTABULAR, metadata=test_datasets._TEST_NONTABULAR_DATASET_METADATA, encryption_spec=_TEST_ENCRYPTION_SPEC, ) - expected_import_config = ImportDataConfig( - gcs_source=GcsSource(uris=[test_datasets._TEST_SOURCE_URI_GCS]), + expected_import_config = gca_dataset.ImportDataConfig( + gcs_source=gca_io.GcsSource(uris=[test_datasets._TEST_SOURCE_URI_GCS]), import_schema_uri=test_datasets._TEST_IMPORT_SCHEMA_URI, data_item_labels=test_datasets._TEST_DATA_LABEL_ITEMS, ) @@ -287,7 +287,6 @@ def test_dataset_create_to_model_predict_with_pipeline_fail( self, create_dataset_mock, # noqa: F811 import_data_mock, # noqa: F811 - predict_client_predict_mock, # noqa: F811 mock_python_package_to_gcs, # noqa: F811 mock_pipeline_service_create_and_get_with_fail, # noqa: F811 mock_model_service_get, # noqa: F811 @@ -355,15 +354,15 @@ def test_dataset_create_to_model_predict_with_pipeline_fail( with pytest.raises(RuntimeError): created_endpoint.wait() - expected_dataset = GapicDataset( + expected_dataset = gca_dataset.Dataset( display_name=test_datasets._TEST_DISPLAY_NAME, metadata_schema_uri=test_datasets._TEST_METADATA_SCHEMA_URI_NONTABULAR, metadata=test_datasets._TEST_NONTABULAR_DATASET_METADATA, encryption_spec=_TEST_ENCRYPTION_SPEC, ) - expected_import_config = ImportDataConfig( - gcs_source=GcsSource(uris=[test_datasets._TEST_SOURCE_URI_GCS]), + expected_import_config = gca_dataset.ImportDataConfig( + gcs_source=gca_io.GcsSource(uris=[test_datasets._TEST_SOURCE_URI_GCS]), import_schema_uri=test_datasets._TEST_IMPORT_SCHEMA_URI, data_item_labels=test_datasets._TEST_DATA_LABEL_ITEMS, ) diff --git a/tests/unit/aiplatform/test_endpoints.py b/tests/unit/aiplatform/test_endpoints.py index 292cc797e1..7b18e1e497 100644 --- a/tests/unit/aiplatform/test_endpoints.py +++ b/tests/unit/aiplatform/test_endpoints.py @@ -24,23 +24,41 @@ from google.auth import credentials as auth_credentials from google.cloud import aiplatform + from google.cloud.aiplatform import initializer from google.cloud.aiplatform import models -from google.cloud.aiplatform_v1beta1.services.model_service.client import ( - ModelServiceClient, -) -from google.cloud.aiplatform_v1beta1.services.endpoint_service.client import ( - EndpointServiceClient, +from google.cloud.aiplatform import utils + +from google.cloud.aiplatform_v1beta1.services.endpoint_service import ( + client as endpoint_service_client_v1beta1, ) from google.cloud.aiplatform_v1beta1.services.prediction_service import ( + client as prediction_service_client_v1beta1, +) +from google.cloud.aiplatform_v1beta1.types import ( + endpoint as gca_endpoint_v1beta1, + machine_resources as gca_machine_resources_v1beta1, + prediction_service as gca_prediction_service_v1beta1, + endpoint_service as gca_endpoint_service_v1beta1, +) + +from google.cloud.aiplatform_v1.services.model_service import ( + client as model_service_client, +) +from google.cloud.aiplatform_v1.services.endpoint_service import ( + client as endpoint_service_client, +) +from google.cloud.aiplatform_v1.services.prediction_service import ( client as prediction_service_client, ) -from google.cloud.aiplatform_v1beta1.types import endpoint as gca_endpoint -from google.cloud.aiplatform_v1beta1.types import model as gca_model -from google.cloud.aiplatform_v1beta1.types import machine_resources -from google.cloud.aiplatform_v1beta1.types import prediction_service -from google.cloud.aiplatform_v1beta1.types import endpoint_service -from google.cloud.aiplatform_v1beta1.types import encryption_spec as gca_encryption_spec +from google.cloud.aiplatform_v1.types import ( + endpoint as gca_endpoint, + model as gca_model, + machine_resources as gca_machine_resources, + prediction_service as gca_prediction_service, + endpoint_service as gca_endpoint_service, + encryption_spec as gca_encryption_spec, +) _TEST_PROJECT = "test-project" _TEST_PROJECT_2 = "test-project-2" @@ -76,10 +94,12 @@ _TEST_ACCELERATOR_TYPE = "NVIDIA_TESLA_P100" _TEST_ACCELERATOR_COUNT = 2 -_TEST_EXPLANATIONS = [prediction_service.explanation.Explanation(attributions=[])] +_TEST_EXPLANATIONS = [ + gca_prediction_service_v1beta1.explanation.Explanation(attributions=[]) +] _TEST_ATTRIBUTIONS = [ - prediction_service.explanation.Attribution( + gca_prediction_service_v1beta1.explanation.Attribution( baseline_output_value=1.0, instance_output_value=2.0, feature_attributions=3.0, @@ -120,7 +140,9 @@ @pytest.fixture def get_endpoint_mock(): - with mock.patch.object(EndpointServiceClient, "get_endpoint") as get_endpoint_mock: + with mock.patch.object( + endpoint_service_client.EndpointServiceClient, "get_endpoint" + ) as get_endpoint_mock: get_endpoint_mock.return_value = gca_endpoint.Endpoint( display_name=_TEST_DISPLAY_NAME, name=_TEST_ENDPOINT_NAME, @@ -131,7 +153,9 @@ def get_endpoint_mock(): @pytest.fixture def get_endpoint_with_models_mock(): - with mock.patch.object(EndpointServiceClient, "get_endpoint") as get_endpoint_mock: + with mock.patch.object( + endpoint_service_client.EndpointServiceClient, "get_endpoint" + ) as get_endpoint_mock: get_endpoint_mock.return_value = gca_endpoint.Endpoint( display_name=_TEST_DISPLAY_NAME, name=_TEST_ENDPOINT_NAME, @@ -142,7 +166,9 @@ def get_endpoint_with_models_mock(): @pytest.fixture def get_model_mock(): - with mock.patch.object(ModelServiceClient, "get_model") as get_model_mock: + with mock.patch.object( + model_service_client.ModelServiceClient, "get_model" + ) as get_model_mock: get_model_mock.return_value = gca_model.Model( display_name=_TEST_DISPLAY_NAME, name=_TEST_MODEL_NAME, ) @@ -152,7 +178,7 @@ def get_model_mock(): @pytest.fixture def create_endpoint_mock(): with mock.patch.object( - EndpointServiceClient, "create_endpoint" + endpoint_service_client.EndpointServiceClient, "create_endpoint" ) as create_endpoint_mock: create_endpoint_lro_mock = mock.Mock(ga_operation.Operation) create_endpoint_lro_mock.result.return_value = gca_endpoint.Endpoint( @@ -164,12 +190,30 @@ def create_endpoint_mock(): @pytest.fixture def deploy_model_mock(): - with mock.patch.object(EndpointServiceClient, "deploy_model") as deploy_model_mock: + with mock.patch.object( + endpoint_service_client.EndpointServiceClient, "deploy_model" + ) as deploy_model_mock: deployed_model = gca_endpoint.DeployedModel( model=_TEST_MODEL_NAME, display_name=_TEST_DISPLAY_NAME, ) deploy_model_lro_mock = mock.Mock(ga_operation.Operation) - deploy_model_lro_mock.result.return_value = endpoint_service.DeployModelResponse( + deploy_model_lro_mock.result.return_value = gca_endpoint_service.DeployModelResponse( + deployed_model=deployed_model, + ) + deploy_model_mock.return_value = deploy_model_lro_mock + yield deploy_model_mock + + +@pytest.fixture +def deploy_model_with_explanations_mock(): + with mock.patch.object( + endpoint_service_client_v1beta1.EndpointServiceClient, "deploy_model" + ) as deploy_model_mock: + deployed_model = gca_endpoint_v1beta1.DeployedModel( + model=_TEST_MODEL_NAME, display_name=_TEST_DISPLAY_NAME, + ) + deploy_model_lro_mock = mock.Mock(ga_operation.Operation) + deploy_model_lro_mock.result.return_value = gca_endpoint_service_v1beta1.DeployModelResponse( deployed_model=deployed_model, ) deploy_model_mock.return_value = deploy_model_lro_mock @@ -179,11 +223,11 @@ def deploy_model_mock(): @pytest.fixture def undeploy_model_mock(): with mock.patch.object( - EndpointServiceClient, "undeploy_model" + endpoint_service_client.EndpointServiceClient, "undeploy_model" ) as undeploy_model_mock: undeploy_model_lro_mock = mock.Mock(ga_operation.Operation) undeploy_model_lro_mock.result.return_value = ( - endpoint_service.UndeployModelResponse() + gca_endpoint_service.UndeployModelResponse() ) undeploy_model_mock.return_value = undeploy_model_lro_mock yield undeploy_model_mock @@ -192,11 +236,11 @@ def undeploy_model_mock(): @pytest.fixture def delete_endpoint_mock(): with mock.patch.object( - EndpointServiceClient, "delete_endpoint" + endpoint_service_client.EndpointServiceClient, "delete_endpoint" ) as delete_endpoint_mock: delete_endpoint_lro_mock = mock.Mock(ga_operation.Operation) delete_endpoint_lro_mock.result.return_value = ( - endpoint_service.DeleteEndpointRequest() + gca_endpoint_service.DeleteEndpointRequest() ) delete_endpoint_mock.return_value = delete_endpoint_lro_mock yield delete_endpoint_mock @@ -223,14 +267,11 @@ def sdk_undeploy_all_mock(): @pytest.fixture def create_client_mock(): with mock.patch.object( - initializer.global_config, "create_client" + initializer.global_config, "create_client", autospec=True, ) as create_client_mock: - - def side_effect(client_class, *arg, **kwargs): - return mock.Mock(spec=client_class) - - create_client_mock.side_effect = side_effect - + create_client_mock.return_value = mock.Mock( + spec=endpoint_service_client.EndpointServiceClient + ) yield create_client_mock @@ -239,7 +280,7 @@ def predict_client_predict_mock(): with mock.patch.object( prediction_service_client.PredictionServiceClient, "predict" ) as predict_mock: - predict_mock.return_value = prediction_service.PredictResponse( + predict_mock.return_value = gca_prediction_service.PredictResponse( deployed_model_id=_TEST_MODEL_ID ) predict_mock.return_value.predictions.extend(_TEST_PREDICTION) @@ -249,9 +290,9 @@ def predict_client_predict_mock(): @pytest.fixture def predict_client_explain_mock(): with mock.patch.object( - prediction_service_client.PredictionServiceClient, "explain" + prediction_service_client_v1beta1.PredictionServiceClient, "explain" ) as predict_mock: - predict_mock.return_value = prediction_service.ExplainResponse( + predict_mock.return_value = gca_prediction_service_v1beta1.ExplainResponse( deployed_model_id=_TEST_MODEL_ID, ) predict_mock.return_value.predictions.extend(_TEST_PREDICTION) @@ -280,13 +321,13 @@ def test_constructor(self, create_client_mock): create_client_mock.assert_has_calls( [ mock.call( - client_class=EndpointServiceClient, + client_class=utils.EndpointClientWithOverride, credentials=initializer.global_config.credentials, location_override=_TEST_LOCATION, prediction_client=False, ), mock.call( - client_class=prediction_service_client.PredictionServiceClient, + client_class=utils.PredictionClientWithOverride, credentials=None, location_override=_TEST_LOCATION, prediction_client=True, @@ -307,7 +348,7 @@ def test_constructor_with_endpoint_name(self, get_endpoint_mock): def test_constructor_with_custom_project(self, get_endpoint_mock): aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) models.Endpoint(endpoint_name=_TEST_ID, project=_TEST_PROJECT_2) - test_endpoint_resource_name = EndpointServiceClient.endpoint_path( + test_endpoint_resource_name = endpoint_service_client.EndpointServiceClient.endpoint_path( _TEST_PROJECT_2, _TEST_LOCATION, _TEST_ID ) get_endpoint_mock.assert_called_with(name=test_endpoint_resource_name) @@ -315,7 +356,7 @@ def test_constructor_with_custom_project(self, get_endpoint_mock): def test_constructor_with_custom_location(self, get_endpoint_mock): aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) models.Endpoint(endpoint_name=_TEST_ID, location=_TEST_LOCATION_2) - test_endpoint_resource_name = EndpointServiceClient.endpoint_path( + test_endpoint_resource_name = endpoint_service_client.EndpointServiceClient.endpoint_path( _TEST_PROJECT, _TEST_LOCATION_2, _TEST_ID ) get_endpoint_mock.assert_called_with(name=test_endpoint_resource_name) @@ -328,13 +369,13 @@ def test_constructor_with_custom_credentials(self, create_client_mock): create_client_mock.assert_has_calls( [ mock.call( - client_class=EndpointServiceClient, + client_class=utils.EndpointClientWithOverride, credentials=creds, location_override=_TEST_LOCATION, prediction_client=False, ), mock.call( - client_class=prediction_service_client.PredictionServiceClient, + client_class=utils.PredictionClientWithOverride, credentials=creds, location_override=_TEST_LOCATION, prediction_client=True, @@ -418,7 +459,7 @@ def test_deploy(self, deploy_model_mock, sync): if not sync: test_endpoint.wait() - automatic_resources = machine_resources.AutomaticResources( + automatic_resources = gca_machine_resources.AutomaticResources( min_replica_count=1, max_replica_count=1, ) deployed_model = gca_endpoint.DeployedModel( @@ -446,7 +487,7 @@ def test_deploy_with_display_name(self, deploy_model_mock, sync): if not sync: test_endpoint.wait() - automatic_resources = machine_resources.AutomaticResources( + automatic_resources = gca_machine_resources.AutomaticResources( min_replica_count=1, max_replica_count=1, ) deployed_model = gca_endpoint.DeployedModel( @@ -523,7 +564,7 @@ def test_deploy_raise_error_traffic_split(self, sync): def test_deploy_with_traffic_percent(self, deploy_model_mock, sync): aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) with mock.patch.object( - EndpointServiceClient, "get_endpoint" + endpoint_service_client.EndpointServiceClient, "get_endpoint" ) as get_endpoint_mock: get_endpoint_mock.return_value = gca_endpoint.Endpoint( display_name=_TEST_DISPLAY_NAME, @@ -536,7 +577,7 @@ def test_deploy_with_traffic_percent(self, deploy_model_mock, sync): test_endpoint.deploy(model=test_model, traffic_percentage=70, sync=sync) if not sync: test_endpoint.wait() - automatic_resources = machine_resources.AutomaticResources( + automatic_resources = gca_machine_resources.AutomaticResources( min_replica_count=1, max_replica_count=1, ) deployed_model = gca_endpoint.DeployedModel( @@ -556,7 +597,7 @@ def test_deploy_with_traffic_percent(self, deploy_model_mock, sync): def test_deploy_with_traffic_split(self, deploy_model_mock, sync): aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) with mock.patch.object( - EndpointServiceClient, "get_endpoint" + endpoint_service_client.EndpointServiceClient, "get_endpoint" ) as get_endpoint_mock: get_endpoint_mock.return_value = gca_endpoint.Endpoint( display_name=_TEST_DISPLAY_NAME, @@ -572,7 +613,7 @@ def test_deploy_with_traffic_split(self, deploy_model_mock, sync): if not sync: test_endpoint.wait() - automatic_resources = machine_resources.AutomaticResources( + automatic_resources = gca_machine_resources.AutomaticResources( min_replica_count=1, max_replica_count=1, ) deployed_model = gca_endpoint.DeployedModel( @@ -604,12 +645,12 @@ def test_deploy_with_dedicated_resources(self, deploy_model_mock, sync): if not sync: test_endpoint.wait() - expected_machine_spec = machine_resources.MachineSpec( + expected_machine_spec = gca_machine_resources.MachineSpec( machine_type=_TEST_MACHINE_TYPE, accelerator_type=_TEST_ACCELERATOR_TYPE, accelerator_count=_TEST_ACCELERATOR_COUNT, ) - expected_dedicated_resources = machine_resources.DedicatedResources( + expected_dedicated_resources = gca_machine_resources.DedicatedResources( machine_spec=expected_machine_spec, min_replica_count=1, max_replica_count=1, @@ -628,7 +669,7 @@ def test_deploy_with_dedicated_resources(self, deploy_model_mock, sync): @pytest.mark.usefixtures("get_endpoint_mock", "get_model_mock") @pytest.mark.parametrize("sync", [True, False]) - def test_deploy_with_explanations(self, deploy_model_mock, sync): + def test_deploy_with_explanations(self, deploy_model_with_explanations_mock, sync): aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME) test_model = models.Model(_TEST_ID) @@ -645,26 +686,26 @@ def test_deploy_with_explanations(self, deploy_model_mock, sync): if not sync: test_endpoint.wait() - expected_machine_spec = machine_resources.MachineSpec( + expected_machine_spec = gca_machine_resources_v1beta1.MachineSpec( machine_type=_TEST_MACHINE_TYPE, accelerator_type=_TEST_ACCELERATOR_TYPE, accelerator_count=_TEST_ACCELERATOR_COUNT, ) - expected_dedicated_resources = machine_resources.DedicatedResources( + expected_dedicated_resources = gca_machine_resources_v1beta1.DedicatedResources( machine_spec=expected_machine_spec, min_replica_count=1, max_replica_count=1, ) - expected_deployed_model = gca_endpoint.DeployedModel( + expected_deployed_model = gca_endpoint_v1beta1.DeployedModel( dedicated_resources=expected_dedicated_resources, model=test_model.resource_name, display_name=None, - explanation_spec=gca_endpoint.explanation.ExplanationSpec( + explanation_spec=gca_endpoint_v1beta1.explanation.ExplanationSpec( metadata=_TEST_EXPLANATION_METADATA, parameters=_TEST_EXPLANATION_PARAMETERS, ), ) - deploy_model_mock.assert_called_once_with( + deploy_model_with_explanations_mock.assert_called_once_with( endpoint=test_endpoint.resource_name, deployed_model=expected_deployed_model, traffic_split={"0": 100}, @@ -681,7 +722,7 @@ def test_deploy_with_min_replica_count(self, deploy_model_mock, sync): if not sync: test_endpoint.wait() - automatic_resources = machine_resources.AutomaticResources( + automatic_resources = gca_machine_resources.AutomaticResources( min_replica_count=2, max_replica_count=2, ) deployed_model = gca_endpoint.DeployedModel( @@ -705,7 +746,7 @@ def test_deploy_with_max_replica_count(self, deploy_model_mock, sync): test_endpoint.deploy(model=test_model, max_replica_count=2, sync=sync) if not sync: test_endpoint.wait() - automatic_resources = machine_resources.AutomaticResources( + automatic_resources = gca_machine_resources.AutomaticResources( min_replica_count=1, max_replica_count=2, ) deployed_model = gca_endpoint.DeployedModel( @@ -790,7 +831,7 @@ def test_unallocate_traffic(self, model1, model2, model3, deployed_model): def test_undeploy(self, undeploy_model_mock, sync): aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) with mock.patch.object( - EndpointServiceClient, "get_endpoint" + endpoint_service_client.EndpointServiceClient, "get_endpoint" ) as get_endpoint_mock: get_endpoint_mock.return_value = gca_endpoint.Endpoint( display_name=_TEST_DISPLAY_NAME, @@ -814,7 +855,7 @@ def test_undeploy(self, undeploy_model_mock, sync): def test_undeploy_with_traffic_split(self, undeploy_model_mock, sync): aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) with mock.patch.object( - EndpointServiceClient, "get_endpoint" + endpoint_service_client.EndpointServiceClient, "get_endpoint" ) as get_endpoint_mock: get_endpoint_mock.return_value = gca_endpoint.Endpoint( display_name=_TEST_DISPLAY_NAME, diff --git a/tests/unit/aiplatform/test_initializer.py b/tests/unit/aiplatform/test_initializer.py index 041a498e38..1d97ad2e9a 100644 --- a/tests/unit/aiplatform/test_initializer.py +++ b/tests/unit/aiplatform/test_initializer.py @@ -22,10 +22,12 @@ import google.auth from google.auth import credentials + from google.cloud.aiplatform import initializer from google.cloud.aiplatform import constants from google.cloud.aiplatform import utils -from google.cloud.aiplatform_v1beta1.services.model_service import ( + +from google.cloud.aiplatform_v1.services.model_service import ( client as model_service_client, ) @@ -97,10 +99,10 @@ def test_common_location_path_overrides(self): def test_create_client_returns_client(self): initializer.global_config.init(project=_TEST_PROJECT, location=_TEST_LOCATION) client = initializer.global_config.create_client( - model_service_client.ModelServiceClient + client_class=utils.ModelClientWithOverride ) assert client._client_class is model_service_client.ModelServiceClient - assert isinstance(client, utils.WrappedClient) + assert isinstance(client, utils.ModelClientWithOverride) assert ( client._transport._host == f"{_TEST_LOCATION}-{constants.API_BASE_PATH}:443" ) @@ -109,12 +111,12 @@ def test_create_client_overrides(self): initializer.global_config.init(project=_TEST_PROJECT, location=_TEST_LOCATION) creds = credentials.AnonymousCredentials() client = initializer.global_config.create_client( - model_service_client.ModelServiceClient, + client_class=utils.ModelClientWithOverride, credentials=creds, location_override=_TEST_LOCATION_2, prediction_client=True, ) - assert isinstance(client, model_service_client.ModelServiceClient) + assert isinstance(client, utils.ModelClientWithOverride) assert ( client._transport._host == f"{_TEST_LOCATION_2}-{constants.API_BASE_PATH}:443" @@ -124,7 +126,7 @@ def test_create_client_overrides(self): def test_create_client_user_agent(self): initializer.global_config.init(project=_TEST_PROJECT, location=_TEST_LOCATION) client = initializer.global_config.create_client( - model_service_client.ModelServiceClient + client_class=utils.ModelClientWithOverride ) for wrapped_method in client._transport._wrapped_methods.values(): diff --git a/tests/unit/aiplatform/test_jobs.py b/tests/unit/aiplatform/test_jobs.py index 255ba05088..53fe9d2d0a 100644 --- a/tests/unit/aiplatform/test_jobs.py +++ b/tests/unit/aiplatform/test_jobs.py @@ -27,17 +27,27 @@ from google.auth import credentials as auth_credentials from google.cloud import aiplatform + from google.cloud.aiplatform import jobs from google.cloud.aiplatform import initializer from google.cloud.aiplatform_v1beta1.services.job_service import ( - client as job_service_client, + client as job_service_client_v1beta1, ) -from google.cloud.aiplatform_v1beta1 import types as gapic_types + from google.cloud.aiplatform_v1beta1.types import ( + batch_prediction_job as gca_batch_prediction_job_v1beta1, + explanation as gca_explanation_v1beta1, + io as gca_io_v1beta1, + machine_resources as gca_machine_resources_v1beta1, +) + +from google.cloud.aiplatform_v1.services.job_service import client as job_service_client + +from google.cloud.aiplatform_v1.types import ( + batch_prediction_job as gca_batch_prediction_job, io as gca_io, job_state as gca_job_state, - batch_prediction_job as gca_batch_prediction_job, ) _TEST_PROJECT = "test-project" @@ -216,6 +226,19 @@ def create_batch_prediction_job_mock(): yield create_batch_prediction_job_mock +@pytest.fixture +def create_batch_prediction_job_with_explanations_mock(): + with mock.patch.object( + job_service_client_v1beta1.JobServiceClient, "create_batch_prediction_job" + ) as create_batch_prediction_job_mock: + create_batch_prediction_job_mock.return_value = gca_batch_prediction_job_v1beta1.BatchPredictionJob( + name=_TEST_BATCH_PREDICTION_JOB_NAME, + display_name=_TEST_DISPLAY_NAME, + state=_TEST_JOB_STATE_SUCCESS, + ) + yield create_batch_prediction_job_mock + + @pytest.fixture def get_batch_prediction_job_gcs_output_mock(): with patch.object( @@ -458,7 +481,9 @@ def test_batch_predict_gcs_source_bq_dest( @pytest.mark.parametrize("sync", [True, False]) @pytest.mark.usefixtures("get_batch_prediction_job_mock") - def test_batch_predict_with_all_args(self, create_batch_prediction_job_mock, sync): + def test_batch_predict_with_all_args( + self, create_batch_prediction_job_with_explanations_mock, sync + ): aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) creds = auth_credentials.AnonymousCredentials() @@ -486,21 +511,23 @@ def test_batch_predict_with_all_args(self, create_batch_prediction_job_mock, syn batch_prediction_job.wait() # Construct expected request - expected_gapic_batch_prediction_job = gca_batch_prediction_job.BatchPredictionJob( + expected_gapic_batch_prediction_job = gca_batch_prediction_job_v1beta1.BatchPredictionJob( display_name=_TEST_BATCH_PREDICTION_JOB_DISPLAY_NAME, model=_TEST_MODEL_NAME, - input_config=gca_batch_prediction_job.BatchPredictionJob.InputConfig( + input_config=gca_batch_prediction_job_v1beta1.BatchPredictionJob.InputConfig( instances_format="jsonl", - gcs_source=gca_io.GcsSource(uris=[_TEST_BATCH_PREDICTION_GCS_SOURCE]), + gcs_source=gca_io_v1beta1.GcsSource( + uris=[_TEST_BATCH_PREDICTION_GCS_SOURCE] + ), ), - output_config=gca_batch_prediction_job.BatchPredictionJob.OutputConfig( - gcs_destination=gca_io.GcsDestination( + output_config=gca_batch_prediction_job_v1beta1.BatchPredictionJob.OutputConfig( + gcs_destination=gca_io_v1beta1.GcsDestination( output_uri_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX ), predictions_format="csv", ), - dedicated_resources=gapic_types.BatchDedicatedResources( - machine_spec=gapic_types.MachineSpec( + dedicated_resources=gca_machine_resources_v1beta1.BatchDedicatedResources( + machine_spec=gca_machine_resources_v1beta1.MachineSpec( machine_type=_TEST_MACHINE_TYPE, accelerator_type=_TEST_ACCELERATOR_TYPE, accelerator_count=_TEST_ACCELERATOR_COUNT, @@ -509,14 +536,14 @@ def test_batch_predict_with_all_args(self, create_batch_prediction_job_mock, syn max_replica_count=_TEST_MAX_REPLICA_COUNT, ), generate_explanation=True, - explanation_spec=gapic_types.ExplanationSpec( + explanation_spec=gca_explanation_v1beta1.ExplanationSpec( metadata=_TEST_EXPLANATION_METADATA, parameters=_TEST_EXPLANATION_PARAMETERS, ), labels=_TEST_LABEL, ) - create_batch_prediction_job_mock.assert_called_once_with( + create_batch_prediction_job_with_explanations_mock.assert_called_once_with( parent=f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}", batch_prediction_job=expected_gapic_batch_prediction_job, ) diff --git a/tests/unit/aiplatform/test_lro.py b/tests/unit/aiplatform/test_lro.py index 26685d4f15..0ce2e85594 100644 --- a/tests/unit/aiplatform/test_lro.py +++ b/tests/unit/aiplatform/test_lro.py @@ -17,14 +17,15 @@ from google.api_core import operation + from google.cloud import aiplatform + from google.cloud.aiplatform import base from google.cloud.aiplatform import initializer from google.cloud.aiplatform import lro -from google.cloud.aiplatform_v1beta1.services.model_service.client import ( - ModelServiceClient, -) -from google.cloud.aiplatform_v1beta1.types import model as gca_model +from google.cloud.aiplatform import utils + +from google.cloud.aiplatform_v1.types import model as gca_model from google.longrunning import operations_pb2 from google.protobuf import struct_pb2 as struct @@ -45,7 +46,7 @@ def teardown_module(module): class AiPlatformResourceNounImpl(base.AiPlatformResourceNoun): - client_class = ModelServiceClient + client_class = utils.ModelClientWithOverride _is_client_prediction_client = False _resource_noun = None _getter_method = None diff --git a/tests/unit/aiplatform/test_models.py b/tests/unit/aiplatform/test_models.py index c79c98dbd0..ff0310a003 100644 --- a/tests/unit/aiplatform/test_models.py +++ b/tests/unit/aiplatform/test_models.py @@ -22,25 +22,54 @@ from google.api_core import operation as ga_operation from google.auth import credentials as auth_credentials + from google.cloud import aiplatform + from google.cloud.aiplatform import initializer from google.cloud.aiplatform import models -from google.cloud.aiplatform_v1beta1.services.model_service.client import ( - ModelServiceClient, +from google.cloud.aiplatform import utils + +from google.cloud.aiplatform_v1beta1.services.endpoint_service import ( + client as endpoint_service_client_v1beta1, +) +from google.cloud.aiplatform_v1beta1.services.job_service import ( + client as job_service_client_v1beta1, +) +from google.cloud.aiplatform_v1beta1.services.model_service import ( + client as model_service_client_v1beta1, +) +from google.cloud.aiplatform_v1beta1.types import ( + batch_prediction_job as gca_batch_prediction_job_v1beta1, + env_var as gca_env_var_v1beta1, + explanation as gca_explanation_v1beta1, + io as gca_io_v1beta1, + model as gca_model_v1beta1, + endpoint as gca_endpoint_v1beta1, + machine_resources as gca_machine_resources_v1beta1, + model_service as gca_model_service_v1beta1, + endpoint_service as gca_endpoint_service_v1beta1, + encryption_spec as gca_encryption_spec_v1beta1, +) + +from google.cloud.aiplatform_v1.services.endpoint_service import ( + client as endpoint_service_client, +) +from google.cloud.aiplatform_v1.services.job_service import client as job_service_client +from google.cloud.aiplatform_v1.services.model_service import ( + client as model_service_client, ) -from google.cloud.aiplatform_v1beta1.services.endpoint_service.client import ( - EndpointServiceClient, +from google.cloud.aiplatform_v1.types import ( + batch_prediction_job as gca_batch_prediction_job, + io as gca_io, + job_state as gca_job_state, + model as gca_model, + endpoint as gca_endpoint, + machine_resources as gca_machine_resources, + model_service as gca_model_service, + endpoint_service as gca_endpoint_service, + encryption_spec as gca_encryption_spec, ) -from google.cloud.aiplatform_v1beta1.services import job_service -from google.cloud.aiplatform_v1beta1 import types as gapic_types -from google.cloud.aiplatform_v1beta1.types import batch_prediction_job -from google.cloud.aiplatform_v1beta1.types import env_var -from google.cloud.aiplatform_v1beta1.types import model as gca_model -from google.cloud.aiplatform_v1beta1.types import endpoint as gca_endpoint -from google.cloud.aiplatform_v1beta1.types import machine_resources -from google.cloud.aiplatform_v1beta1.types import model_service -from google.cloud.aiplatform_v1beta1.types import endpoint_service -from google.cloud.aiplatform_v1beta1.types import encryption_spec as gca_encryption_spec + from test_endpoints import create_endpoint_mock # noqa: F401 @@ -82,7 +111,7 @@ f"bq://{_TEST_BATCH_PREDICTION_BQ_PREFIX}" ) _TEST_BATCH_PREDICTION_DISPLAY_NAME = "test-batch-prediction-job" -_TEST_BATCH_PREDICTION_JOB_NAME = job_service.JobServiceClient.batch_prediction_job_path( +_TEST_BATCH_PREDICTION_JOB_NAME = job_service_client.JobServiceClient.batch_prediction_job_path( project=_TEST_PROJECT, location=_TEST_LOCATION, batch_prediction_job=_TEST_ID ) @@ -112,12 +141,27 @@ _TEST_ENCRYPTION_SPEC = gca_encryption_spec.EncryptionSpec( kms_key_name=_TEST_ENCRYPTION_KEY_NAME ) +_TEST_ENCRYPTION_SPEC_V1BETA1 = gca_encryption_spec_v1beta1.EncryptionSpec( + kms_key_name=_TEST_ENCRYPTION_KEY_NAME +) + +_TEST_MODEL_RESOURCE_NAME = model_service_client.ModelServiceClient.model_path( + _TEST_PROJECT, _TEST_LOCATION, _TEST_ID +) +_TEST_MODEL_RESOURCE_NAME_CUSTOM_PROJECT = model_service_client.ModelServiceClient.model_path( + _TEST_PROJECT_2, _TEST_LOCATION, _TEST_ID +) +_TEST_MODEL_RESOURCE_NAME_CUSTOM_LOCATION = model_service_client.ModelServiceClient.model_path( + _TEST_PROJECT, _TEST_LOCATION_2, _TEST_ID +) @pytest.fixture def get_endpoint_mock(): - with mock.patch.object(EndpointServiceClient, "get_endpoint") as get_endpoint_mock: - test_endpoint_resource_name = EndpointServiceClient.endpoint_path( + with mock.patch.object( + endpoint_service_client.EndpointServiceClient, "get_endpoint" + ) as get_endpoint_mock: + test_endpoint_resource_name = endpoint_service_client.EndpointServiceClient.endpoint_path( _TEST_PROJECT, _TEST_LOCATION, _TEST_ID ) get_endpoint_mock.return_value = gca_endpoint.Endpoint( @@ -128,36 +172,141 @@ def get_endpoint_mock(): @pytest.fixture def get_model_mock(): - with mock.patch.object(ModelServiceClient, "get_model") as get_model_mock: - test_model_resource_name = ModelServiceClient.model_path( - _TEST_PROJECT, _TEST_LOCATION, _TEST_ID + with mock.patch.object( + model_service_client.ModelServiceClient, "get_model" + ) as get_model_mock: + get_model_mock.return_value = gca_model.Model( + display_name=_TEST_MODEL_NAME, name=_TEST_MODEL_RESOURCE_NAME, + ) + yield get_model_mock + + +@pytest.fixture +def get_model_with_explanations_mock(): + with mock.patch.object( + model_service_client_v1beta1.ModelServiceClient, "get_model" + ) as get_model_mock: + get_model_mock.return_value = gca_model_v1beta1.Model( + display_name=_TEST_MODEL_NAME, name=_TEST_MODEL_RESOURCE_NAME, + ) + yield get_model_mock + + +@pytest.fixture +def get_model_with_custom_location_mock(): + with mock.patch.object( + model_service_client.ModelServiceClient, "get_model" + ) as get_model_mock: + get_model_mock.return_value = gca_model.Model( + display_name=_TEST_MODEL_NAME, + name=_TEST_MODEL_RESOURCE_NAME_CUSTOM_LOCATION, ) + yield get_model_mock + + +@pytest.fixture +def get_model_with_custom_project_mock(): + with mock.patch.object( + model_service_client.ModelServiceClient, "get_model" + ) as get_model_mock: get_model_mock.return_value = gca_model.Model( - display_name=_TEST_MODEL_NAME, name=test_model_resource_name, + display_name=_TEST_MODEL_NAME, + name=_TEST_MODEL_RESOURCE_NAME_CUSTOM_PROJECT, ) yield get_model_mock +@pytest.fixture +def upload_model_mock(): + with mock.patch.object( + model_service_client.ModelServiceClient, "upload_model" + ) as upload_model_mock: + mock_lro = mock.Mock(ga_operation.Operation) + mock_lro.result.return_value = gca_model_service.UploadModelResponse( + model=_TEST_MODEL_RESOURCE_NAME + ) + upload_model_mock.return_value = mock_lro + yield upload_model_mock + + +@pytest.fixture +def upload_model_with_explanations_mock(): + with mock.patch.object( + model_service_client_v1beta1.ModelServiceClient, "upload_model" + ) as upload_model_mock: + mock_lro = mock.Mock(ga_operation.Operation) + mock_lro.result.return_value = gca_model_service_v1beta1.UploadModelResponse( + model=_TEST_MODEL_RESOURCE_NAME + ) + upload_model_mock.return_value = mock_lro + yield upload_model_mock + + +@pytest.fixture +def upload_model_with_custom_project_mock(): + with mock.patch.object( + model_service_client.ModelServiceClient, "upload_model" + ) as upload_model_mock: + mock_lro = mock.Mock(ga_operation.Operation) + mock_lro.result.return_value = gca_model_service.UploadModelResponse( + model=_TEST_MODEL_RESOURCE_NAME_CUSTOM_PROJECT + ) + upload_model_mock.return_value = mock_lro + yield upload_model_mock + + +@pytest.fixture +def upload_model_with_custom_location_mock(): + with mock.patch.object( + model_service_client.ModelServiceClient, "upload_model" + ) as upload_model_mock: + mock_lro = mock.Mock(ga_operation.Operation) + mock_lro.result.return_value = gca_model_service.UploadModelResponse( + model=_TEST_MODEL_RESOURCE_NAME_CUSTOM_LOCATION + ) + upload_model_mock.return_value = mock_lro + yield upload_model_mock + + @pytest.fixture def delete_model_mock(): - with mock.patch.object(ModelServiceClient, "delete_model") as delete_model_mock: + with mock.patch.object( + model_service_client.ModelServiceClient, "delete_model" + ) as delete_model_mock: delete_model_lro_mock = mock.Mock(ga_operation.Operation) - delete_model_lro_mock.result.return_value = model_service.DeleteModelRequest() + delete_model_lro_mock.result.return_value = ( + gca_model_service.DeleteModelRequest() + ) delete_model_mock.return_value = delete_model_lro_mock yield delete_model_mock @pytest.fixture def deploy_model_mock(): - with mock.patch.object(EndpointServiceClient, "deploy_model") as deploy_model_mock: - test_model_resource_name = ModelServiceClient.model_path( - _TEST_PROJECT, _TEST_LOCATION, _TEST_ID - ) + with mock.patch.object( + endpoint_service_client.EndpointServiceClient, "deploy_model" + ) as deploy_model_mock: deployed_model = gca_endpoint.DeployedModel( - model=test_model_resource_name, display_name=_TEST_MODEL_NAME, + model=_TEST_MODEL_RESOURCE_NAME, display_name=_TEST_MODEL_NAME, ) deploy_model_lro_mock = mock.Mock(ga_operation.Operation) - deploy_model_lro_mock.result.return_value = endpoint_service.DeployModelResponse( + deploy_model_lro_mock.result.return_value = gca_endpoint_service.DeployModelResponse( + deployed_model=deployed_model, + ) + deploy_model_mock.return_value = deploy_model_lro_mock + yield deploy_model_mock + + +@pytest.fixture +def deploy_model_with_explanations_mock(): + with mock.patch.object( + endpoint_service_client_v1beta1.EndpointServiceClient, "deploy_model" + ) as deploy_model_mock: + deployed_model = gca_endpoint_v1beta1.DeployedModel( + model=_TEST_MODEL_RESOURCE_NAME, display_name=_TEST_MODEL_NAME, + ) + deploy_model_lro_mock = mock.Mock(ga_operation.Operation) + deploy_model_lro_mock.result.return_value = gca_endpoint_service_v1beta1.DeployModelResponse( deployed_model=deployed_model, ) deploy_model_mock.return_value = deploy_model_lro_mock @@ -167,10 +316,12 @@ def deploy_model_mock(): @pytest.fixture def get_batch_prediction_job_mock(): with mock.patch.object( - job_service.JobServiceClient, "get_batch_prediction_job" + job_service_client.JobServiceClient, "get_batch_prediction_job" ) as get_batch_prediction_job_mock: - batch_prediction_mock = mock.Mock(spec=batch_prediction_job.BatchPredictionJob) - batch_prediction_mock.state = gapic_types.job_state.JobState.JOB_STATE_SUCCEEDED + batch_prediction_mock = mock.Mock( + spec=gca_batch_prediction_job.BatchPredictionJob + ) + batch_prediction_mock.state = gca_job_state.JobState.JOB_STATE_SUCCEEDED batch_prediction_mock.name = _TEST_BATCH_PREDICTION_JOB_NAME get_batch_prediction_job_mock.return_value = batch_prediction_mock yield get_batch_prediction_job_mock @@ -179,16 +330,39 @@ def get_batch_prediction_job_mock(): @pytest.fixture def create_batch_prediction_job_mock(): with mock.patch.object( - job_service.JobServiceClient, "create_batch_prediction_job" + job_service_client.JobServiceClient, "create_batch_prediction_job" ) as create_batch_prediction_job_mock: batch_prediction_job_mock = mock.Mock( - spec=batch_prediction_job.BatchPredictionJob + spec=gca_batch_prediction_job.BatchPredictionJob ) batch_prediction_job_mock.name = _TEST_BATCH_PREDICTION_JOB_NAME create_batch_prediction_job_mock.return_value = batch_prediction_job_mock yield create_batch_prediction_job_mock +@pytest.fixture +def create_batch_prediction_job_with_explanations_mock(): + with mock.patch.object( + job_service_client_v1beta1.JobServiceClient, "create_batch_prediction_job" + ) as create_batch_prediction_job_mock: + batch_prediction_job_mock = mock.Mock( + spec=gca_batch_prediction_job_v1beta1.BatchPredictionJob + ) + batch_prediction_job_mock.name = _TEST_BATCH_PREDICTION_JOB_NAME + create_batch_prediction_job_mock.return_value = batch_prediction_job_mock + yield create_batch_prediction_job_mock + + +@pytest.fixture +def create_client_mock(): + with mock.patch.object( + initializer.global_config, "create_client" + ) as create_client_mock: + api_client_mock = mock.Mock(spec=model_service_client.ModelServiceClient) + create_client_mock.return_value = api_client_mock + yield create_client_mock + + class TestModel: def setup_method(self): importlib.reload(initializer) @@ -197,155 +371,101 @@ def setup_method(self): def teardown_method(self): initializer.global_pool.shutdown(wait=True) - def test_constructor_creates_client(self): + def test_constructor_creates_client(self, create_client_mock): aiplatform.init( project=_TEST_PROJECT, location=_TEST_LOCATION, credentials=_TEST_CREDENTIALS, ) - with mock.patch.object( - initializer.global_config, "create_client" - ) as create_client_mock: - api_client_mock = mock.Mock(spec=ModelServiceClient) - create_client_mock.return_value = api_client_mock - models.Model(_TEST_ID) - create_client_mock.assert_called_once_with( - client_class=ModelServiceClient, - credentials=initializer.global_config.credentials, - location_override=_TEST_LOCATION, - prediction_client=False, - ) + models.Model(_TEST_ID) + create_client_mock.assert_called_once_with( + client_class=utils.ModelClientWithOverride, + credentials=initializer.global_config.credentials, + location_override=_TEST_LOCATION, + prediction_client=False, + ) - def test_constructor_create_client_with_custom_location(self): + def test_constructor_create_client_with_custom_location(self, create_client_mock): aiplatform.init( project=_TEST_PROJECT, location=_TEST_LOCATION, credentials=_TEST_CREDENTIALS, ) - with mock.patch.object( - initializer.global_config, "create_client" - ) as create_client_mock: - api_client_mock = mock.Mock(spec=ModelServiceClient) - create_client_mock.return_value = api_client_mock - - models.Model(_TEST_ID, location=_TEST_LOCATION_2) - create_client_mock.assert_called_once_with( - client_class=ModelServiceClient, - credentials=initializer.global_config.credentials, - location_override=_TEST_LOCATION_2, - prediction_client=False, - ) + models.Model(_TEST_ID, location=_TEST_LOCATION_2) + create_client_mock.assert_called_once_with( + client_class=utils.ModelClientWithOverride, + credentials=initializer.global_config.credentials, + location_override=_TEST_LOCATION_2, + prediction_client=False, + ) - def test_constructor_creates_client_with_custom_credentials(self): + def test_constructor_creates_client_with_custom_credentials( + self, create_client_mock + ): aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) - with mock.patch.object( - initializer.global_config, "create_client" - ) as create_client_mock: - api_client_mock = mock.Mock(spec=ModelServiceClient) - create_client_mock.return_value = api_client_mock - creds = auth_credentials.AnonymousCredentials() - models.Model(_TEST_ID, credentials=creds) - create_client_mock.assert_called_once_with( - client_class=ModelServiceClient, - credentials=creds, - location_override=_TEST_LOCATION, - prediction_client=False, - ) + creds = auth_credentials.AnonymousCredentials() + models.Model(_TEST_ID, credentials=creds) + create_client_mock.assert_called_once_with( + client_class=utils.ModelClientWithOverride, + credentials=creds, + location_override=_TEST_LOCATION, + prediction_client=False, + ) - def test_constructor_gets_model(self): + def test_constructor_gets_model(self, get_model_mock): aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) - with mock.patch.object( - initializer.global_config, "create_client" - ) as create_client_mock: - api_client_mock = mock.Mock(spec=ModelServiceClient) - create_client_mock.return_value = api_client_mock - - models.Model(_TEST_ID) - test_model_resource_name = ModelServiceClient.model_path( - _TEST_PROJECT, _TEST_LOCATION, _TEST_ID - ) - api_client_mock.get_model.assert_called_once_with( - name=test_model_resource_name - ) + models.Model(_TEST_ID) + get_model_mock.assert_called_once_with(name=_TEST_MODEL_RESOURCE_NAME) - def test_constructor_gets_model_with_custom_project(self): + def test_constructor_gets_model_with_custom_project(self, get_model_mock): aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) - with mock.patch.object( - initializer.global_config, "create_client" - ) as create_client_mock: - api_client_mock = mock.Mock(spec=ModelServiceClient) - create_client_mock.return_value = api_client_mock - models.Model(_TEST_ID, project=_TEST_PROJECT_2) - test_model_resource_name = ModelServiceClient.model_path( - _TEST_PROJECT_2, _TEST_LOCATION, _TEST_ID - ) - api_client_mock.get_model.assert_called_once_with( - name=test_model_resource_name - ) + models.Model(_TEST_ID, project=_TEST_PROJECT_2) + test_model_resource_name = model_service_client.ModelServiceClient.model_path( + _TEST_PROJECT_2, _TEST_LOCATION, _TEST_ID + ) + get_model_mock.assert_called_once_with(name=test_model_resource_name) - def test_constructor_gets_model_with_custom_location(self): + def test_constructor_gets_model_with_custom_location(self, get_model_mock): aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) - with mock.patch.object( - initializer.global_config, "create_client" - ) as create_client_mock: - api_client_mock = mock.Mock(spec=ModelServiceClient) - create_client_mock.return_value = api_client_mock - models.Model(_TEST_ID, location=_TEST_LOCATION_2) - test_model_resource_name = ModelServiceClient.model_path( - _TEST_PROJECT, _TEST_LOCATION_2, _TEST_ID - ) - api_client_mock.get_model.assert_called_once_with( - name=test_model_resource_name - ) + models.Model(_TEST_ID, location=_TEST_LOCATION_2) + test_model_resource_name = model_service_client.ModelServiceClient.model_path( + _TEST_PROJECT, _TEST_LOCATION_2, _TEST_ID + ) + get_model_mock.assert_called_once_with(name=test_model_resource_name) @pytest.mark.parametrize("sync", [True, False]) - def test_upload_uploads_and_gets_model(self, sync): + def test_upload_uploads_and_gets_model( + self, upload_model_mock, get_model_mock, sync + ): aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) - with mock.patch.object( - initializer.global_config, "create_client" - ) as create_client_mock: - api_client_mock = mock.Mock(spec=ModelServiceClient) - mock_lro = mock.Mock(ga_operation.Operation) - test_model_resource_name = ModelServiceClient.model_path( - _TEST_PROJECT, _TEST_LOCATION, _TEST_ID - ) - mock_lro.result.return_value = model_service.UploadModelResponse( - model=test_model_resource_name - ) - api_client_mock.upload_model.return_value = mock_lro - create_client_mock.return_value = api_client_mock - - # Custom Container workflow, does not pass `artifact_uri` - my_model = models.Model.upload( - display_name=_TEST_MODEL_NAME, - serving_container_image_uri=_TEST_SERVING_CONTAINER_IMAGE, - serving_container_predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, - serving_container_health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, - sync=sync, - ) + my_model = models.Model.upload( + display_name=_TEST_MODEL_NAME, + serving_container_image_uri=_TEST_SERVING_CONTAINER_IMAGE, + serving_container_predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + serving_container_health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + sync=sync, + ) - if not sync: - my_model.wait() + if not sync: + my_model.wait() - container_spec = gca_model.ModelContainerSpec( - image_uri=_TEST_SERVING_CONTAINER_IMAGE, - predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, - health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, - ) + container_spec = gca_model.ModelContainerSpec( + image_uri=_TEST_SERVING_CONTAINER_IMAGE, + predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + ) - managed_model = gca_model.Model( - display_name=_TEST_MODEL_NAME, container_spec=container_spec, - ) + managed_model = gca_model.Model( + display_name=_TEST_MODEL_NAME, container_spec=container_spec, + ) - api_client_mock.upload_model.assert_called_once_with( - parent=initializer.global_config.common_location_path(), - model=managed_model, - ) + upload_model_mock.assert_called_once_with( + parent=initializer.global_config.common_location_path(), + model=managed_model, + ) - api_client_mock.get_model.assert_called_once_with( - name=test_model_resource_name - ) + get_model_mock.assert_called_once_with(name=_TEST_MODEL_RESOURCE_NAME) def test_upload_raises_with_impartial_explanation_spec(self): @@ -363,192 +483,171 @@ def test_upload_raises_with_impartial_explanation_spec(self): assert e.match(regexp=r"`explanation_parameters` should be specified or None.") @pytest.mark.parametrize("sync", [True, False]) - def test_upload_uploads_and_gets_model_with_all_args(self, sync): + def test_upload_uploads_and_gets_model_with_all_args( + self, upload_model_with_explanations_mock, get_model_mock, sync + ): aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) - with mock.patch.object( - initializer.global_config, "create_client" - ) as create_client_mock: - api_client_mock = mock.Mock(spec=ModelServiceClient) - mock_lro = mock.Mock(ga_operation.Operation) - test_model_resource_name = ModelServiceClient.model_path( - _TEST_PROJECT, _TEST_LOCATION, _TEST_ID - ) - mock_lro.result.return_value = model_service.UploadModelResponse( - model=test_model_resource_name - ) - api_client_mock.upload_model.return_value = mock_lro - create_client_mock.return_value = api_client_mock - my_model = models.Model.upload( - display_name=_TEST_MODEL_NAME, - artifact_uri=_TEST_ARTIFACT_URI, - serving_container_image_uri=_TEST_SERVING_CONTAINER_IMAGE, - serving_container_predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, - serving_container_health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + my_model = models.Model.upload( + display_name=_TEST_MODEL_NAME, + artifact_uri=_TEST_ARTIFACT_URI, + serving_container_image_uri=_TEST_SERVING_CONTAINER_IMAGE, + serving_container_predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + serving_container_health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + instance_schema_uri=_TEST_INSTANCE_SCHEMA_URI, + parameters_schema_uri=_TEST_PARAMETERS_SCHEMA_URI, + prediction_schema_uri=_TEST_PREDICTION_SCHEMA_URI, + description=_TEST_DESCRIPTION, + serving_container_command=_TEST_SERVING_CONTAINER_COMMAND, + serving_container_args=_TEST_SERVING_CONTAINER_ARGS, + serving_container_environment_variables=_TEST_SERVING_CONTAINER_ENVIRONMENT_VARIABLES, + serving_container_ports=_TEST_SERVING_CONTAINER_PORTS, + explanation_metadata=_TEST_EXPLANATION_METADATA, + explanation_parameters=_TEST_EXPLANATION_PARAMETERS, + sync=sync, + ) + + if not sync: + my_model.wait() + + env = [ + gca_env_var_v1beta1.EnvVar(name=str(key), value=str(value)) + for key, value in _TEST_SERVING_CONTAINER_ENVIRONMENT_VARIABLES.items() + ] + + ports = [ + gca_model_v1beta1.Port(container_port=port) + for port in _TEST_SERVING_CONTAINER_PORTS + ] + + container_spec = gca_model_v1beta1.ModelContainerSpec( + image_uri=_TEST_SERVING_CONTAINER_IMAGE, + predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + command=_TEST_SERVING_CONTAINER_COMMAND, + args=_TEST_SERVING_CONTAINER_ARGS, + env=env, + ports=ports, + ) + + managed_model = gca_model_v1beta1.Model( + display_name=_TEST_MODEL_NAME, + description=_TEST_DESCRIPTION, + artifact_uri=_TEST_ARTIFACT_URI, + container_spec=container_spec, + predict_schemata=gca_model_v1beta1.PredictSchemata( instance_schema_uri=_TEST_INSTANCE_SCHEMA_URI, parameters_schema_uri=_TEST_PARAMETERS_SCHEMA_URI, prediction_schema_uri=_TEST_PREDICTION_SCHEMA_URI, - description=_TEST_DESCRIPTION, - serving_container_command=_TEST_SERVING_CONTAINER_COMMAND, - serving_container_args=_TEST_SERVING_CONTAINER_ARGS, - serving_container_environment_variables=_TEST_SERVING_CONTAINER_ENVIRONMENT_VARIABLES, - serving_container_ports=_TEST_SERVING_CONTAINER_PORTS, - explanation_metadata=_TEST_EXPLANATION_METADATA, - explanation_parameters=_TEST_EXPLANATION_PARAMETERS, - sync=sync, - ) - - if not sync: - my_model.wait() - - env = [ - env_var.EnvVar(name=str(key), value=str(value)) - for key, value in _TEST_SERVING_CONTAINER_ENVIRONMENT_VARIABLES.items() - ] - - ports = [ - gca_model.Port(container_port=port) - for port in _TEST_SERVING_CONTAINER_PORTS - ] - - container_spec = gca_model.ModelContainerSpec( - image_uri=_TEST_SERVING_CONTAINER_IMAGE, - predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, - health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, - command=_TEST_SERVING_CONTAINER_COMMAND, - args=_TEST_SERVING_CONTAINER_ARGS, - env=env, - ports=ports, - ) - - managed_model = gca_model.Model( - display_name=_TEST_MODEL_NAME, - description=_TEST_DESCRIPTION, - artifact_uri=_TEST_ARTIFACT_URI, - container_spec=container_spec, - predict_schemata=gca_model.PredictSchemata( - instance_schema_uri=_TEST_INSTANCE_SCHEMA_URI, - parameters_schema_uri=_TEST_PARAMETERS_SCHEMA_URI, - prediction_schema_uri=_TEST_PREDICTION_SCHEMA_URI, - ), - explanation_spec=gca_model.explanation.ExplanationSpec( - metadata=_TEST_EXPLANATION_METADATA, - parameters=_TEST_EXPLANATION_PARAMETERS, - ), - ) - - api_client_mock.upload_model.assert_called_once_with( - parent=initializer.global_config.common_location_path(), - model=managed_model, - ) + ), + explanation_spec=gca_model_v1beta1.explanation.ExplanationSpec( + metadata=_TEST_EXPLANATION_METADATA, + parameters=_TEST_EXPLANATION_PARAMETERS, + ), + ) - api_client_mock.get_model.assert_called_once_with( - name=test_model_resource_name - ) + upload_model_with_explanations_mock.assert_called_once_with( + parent=initializer.global_config.common_location_path(), + model=managed_model, + ) + get_model_mock.assert_called_once_with(name=_TEST_MODEL_RESOURCE_NAME) + @pytest.mark.usefixtures("get_model_with_custom_project_mock") @pytest.mark.parametrize("sync", [True, False]) - def test_upload_uploads_and_gets_model_with_custom_project(self, sync): + def test_upload_uploads_and_gets_model_with_custom_project( + self, + upload_model_with_custom_project_mock, + get_model_with_custom_project_mock, + sync, + ): aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) - with mock.patch.object( - initializer.global_config, "create_client" - ) as create_client_mock: - api_client_mock = mock.Mock(spec=ModelServiceClient) - mock_lro = mock.Mock(ga_operation.Operation) - test_model_resource_name = ModelServiceClient.model_path( - _TEST_PROJECT_2, _TEST_LOCATION, _TEST_ID - ) - mock_lro.result.return_value = model_service.UploadModelResponse( - model=test_model_resource_name - ) - api_client_mock.upload_model.return_value = mock_lro - create_client_mock.return_value = api_client_mock - my_model = models.Model.upload( - display_name=_TEST_MODEL_NAME, - artifact_uri=_TEST_ARTIFACT_URI, - serving_container_image_uri=_TEST_SERVING_CONTAINER_IMAGE, - serving_container_predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, - serving_container_health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, - project=_TEST_PROJECT_2, - sync=sync, - ) + test_model_resource_name = model_service_client.ModelServiceClient.model_path( + _TEST_PROJECT_2, _TEST_LOCATION, _TEST_ID + ) - if not sync: - my_model.wait() + my_model = models.Model.upload( + display_name=_TEST_MODEL_NAME, + artifact_uri=_TEST_ARTIFACT_URI, + serving_container_image_uri=_TEST_SERVING_CONTAINER_IMAGE, + serving_container_predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + serving_container_health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + project=_TEST_PROJECT_2, + sync=sync, + ) - container_spec = gca_model.ModelContainerSpec( - image_uri=_TEST_SERVING_CONTAINER_IMAGE, - predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, - health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, - ) + if not sync: + my_model.wait() - managed_model = gca_model.Model( - display_name=_TEST_MODEL_NAME, - artifact_uri=_TEST_ARTIFACT_URI, - container_spec=container_spec, - ) + container_spec = gca_model.ModelContainerSpec( + image_uri=_TEST_SERVING_CONTAINER_IMAGE, + predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + ) - api_client_mock.upload_model.assert_called_once_with( - parent=f"projects/{_TEST_PROJECT_2}/locations/{_TEST_LOCATION}", - model=managed_model, - ) + managed_model = gca_model.Model( + display_name=_TEST_MODEL_NAME, + artifact_uri=_TEST_ARTIFACT_URI, + container_spec=container_spec, + ) - api_client_mock.get_model.assert_called_once_with( - name=test_model_resource_name - ) + upload_model_with_custom_project_mock.assert_called_once_with( + parent=f"projects/{_TEST_PROJECT_2}/locations/{_TEST_LOCATION}", + model=managed_model, + ) + get_model_with_custom_project_mock.assert_called_once_with( + name=test_model_resource_name + ) + + @pytest.mark.usefixtures("get_model_with_custom_location_mock") @pytest.mark.parametrize("sync", [True, False]) - def test_upload_uploads_and_gets_model_with_custom_location(self, sync): + def test_upload_uploads_and_gets_model_with_custom_location( + self, + upload_model_with_custom_location_mock, + get_model_with_custom_location_mock, + sync, + ): aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) - with mock.patch.object( - initializer.global_config, "create_client" - ) as create_client_mock: - api_client_mock = mock.Mock(spec=ModelServiceClient) - mock_lro = mock.Mock(ga_operation.Operation) - test_model_resource_name = ModelServiceClient.model_path( - _TEST_PROJECT, _TEST_LOCATION_2, _TEST_ID - ) - mock_lro.result.return_value = model_service.UploadModelResponse( - model=test_model_resource_name - ) - api_client_mock.upload_model.return_value = mock_lro - create_client_mock.return_value = api_client_mock + test_model_resource_name = model_service_client.ModelServiceClient.model_path( + _TEST_PROJECT, _TEST_LOCATION_2, _TEST_ID + ) - my_model = models.Model.upload( - display_name=_TEST_MODEL_NAME, - artifact_uri=_TEST_ARTIFACT_URI, - serving_container_image_uri=_TEST_SERVING_CONTAINER_IMAGE, - serving_container_predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, - serving_container_health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, - location=_TEST_LOCATION_2, - sync=sync, - ) + my_model = models.Model.upload( + display_name=_TEST_MODEL_NAME, + artifact_uri=_TEST_ARTIFACT_URI, + serving_container_image_uri=_TEST_SERVING_CONTAINER_IMAGE, + serving_container_predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + serving_container_health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + location=_TEST_LOCATION_2, + sync=sync, + ) - if not sync: - my_model.wait() + if not sync: + my_model.wait() - container_spec = gca_model.ModelContainerSpec( - image_uri=_TEST_SERVING_CONTAINER_IMAGE, - predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, - health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, - ) + container_spec = gca_model.ModelContainerSpec( + image_uri=_TEST_SERVING_CONTAINER_IMAGE, + predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + ) - managed_model = gca_model.Model( - display_name=_TEST_MODEL_NAME, - artifact_uri=_TEST_ARTIFACT_URI, - container_spec=container_spec, - ) + managed_model = gca_model.Model( + display_name=_TEST_MODEL_NAME, + artifact_uri=_TEST_ARTIFACT_URI, + container_spec=container_spec, + ) - api_client_mock.upload_model.assert_called_once_with( - parent=f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION_2}", - model=managed_model, - ) + upload_model_with_custom_location_mock.assert_called_once_with( + parent=f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION_2}", + model=managed_model, + ) - api_client_mock.get_model.assert_called_once_with( - name=test_model_resource_name - ) + get_model_with_custom_location_mock.assert_called_once_with( + name=test_model_resource_name + ) @pytest.mark.usefixtures("get_endpoint_mock", "get_model_mock") @pytest.mark.parametrize("sync", [True, False]) @@ -562,7 +661,7 @@ def test_deploy(self, deploy_model_mock, sync): if not sync: test_endpoint.wait() - automatic_resources = machine_resources.AutomaticResources( + automatic_resources = gca_machine_resources.AutomaticResources( min_replica_count=1, max_replica_count=1, ) deployed_model = gca_endpoint.DeployedModel( @@ -589,7 +688,7 @@ def test_deploy_no_endpoint(self, deploy_model_mock, sync): if not sync: test_endpoint.wait() - automatic_resources = machine_resources.AutomaticResources( + automatic_resources = gca_machine_resources.AutomaticResources( min_replica_count=1, max_replica_count=1, ) deployed_model = gca_endpoint.DeployedModel( @@ -621,12 +720,12 @@ def test_deploy_no_endpoint_dedicated_resources(self, deploy_model_mock, sync): if not sync: test_endpoint.wait() - expected_machine_spec = machine_resources.MachineSpec( + expected_machine_spec = gca_machine_resources.MachineSpec( machine_type=_TEST_MACHINE_TYPE, accelerator_type=_TEST_ACCELERATOR_TYPE, accelerator_count=_TEST_ACCELERATOR_COUNT, ) - expected_dedicated_resources = machine_resources.DedicatedResources( + expected_dedicated_resources = gca_machine_resources.DedicatedResources( machine_spec=expected_machine_spec, min_replica_count=1, max_replica_count=1 ) expected_deployed_model = gca_endpoint.DeployedModel( @@ -645,7 +744,9 @@ def test_deploy_no_endpoint_dedicated_resources(self, deploy_model_mock, sync): "get_endpoint_mock", "get_model_mock", "create_endpoint_mock" ) @pytest.mark.parametrize("sync", [True, False]) - def test_deploy_no_endpoint_with_explanations(self, deploy_model_mock, sync): + def test_deploy_no_endpoint_with_explanations( + self, deploy_model_with_explanations_mock, sync + ): aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) test_model = models.Model(_TEST_ID) test_endpoint = test_model.deploy( @@ -660,24 +761,24 @@ def test_deploy_no_endpoint_with_explanations(self, deploy_model_mock, sync): if not sync: test_endpoint.wait() - expected_machine_spec = machine_resources.MachineSpec( + expected_machine_spec = gca_machine_resources_v1beta1.MachineSpec( machine_type=_TEST_MACHINE_TYPE, accelerator_type=_TEST_ACCELERATOR_TYPE, accelerator_count=_TEST_ACCELERATOR_COUNT, ) - expected_dedicated_resources = machine_resources.DedicatedResources( + expected_dedicated_resources = gca_machine_resources_v1beta1.DedicatedResources( machine_spec=expected_machine_spec, min_replica_count=1, max_replica_count=1 ) - expected_deployed_model = gca_endpoint.DeployedModel( + expected_deployed_model = gca_endpoint_v1beta1.DeployedModel( dedicated_resources=expected_dedicated_resources, model=test_model.resource_name, display_name=None, - explanation_spec=gca_endpoint.explanation.ExplanationSpec( + explanation_spec=gca_endpoint_v1beta1.explanation.ExplanationSpec( metadata=_TEST_EXPLANATION_METADATA, parameters=_TEST_EXPLANATION_PARAMETERS, ), ) - deploy_model_mock.assert_called_once_with( + deploy_model_with_explanations_mock.assert_called_once_with( endpoint=test_endpoint.resource_name, deployed_model=expected_deployed_model, traffic_split={"0": 100}, @@ -726,19 +827,17 @@ def test_init_aiplatform_with_encryption_key_name_and_batch_predict_gcs_source_a batch_prediction_job.wait() # Construct expected request - expected_gapic_batch_prediction_job = gapic_types.BatchPredictionJob( + expected_gapic_batch_prediction_job = gca_batch_prediction_job.BatchPredictionJob( display_name=_TEST_BATCH_PREDICTION_DISPLAY_NAME, - model=ModelServiceClient.model_path( + model=model_service_client.ModelServiceClient.model_path( _TEST_PROJECT, _TEST_LOCATION, _TEST_ID ), - input_config=gapic_types.BatchPredictionJob.InputConfig( + input_config=gca_batch_prediction_job.BatchPredictionJob.InputConfig( instances_format="jsonl", - gcs_source=gapic_types.GcsSource( - uris=[_TEST_BATCH_PREDICTION_GCS_SOURCE] - ), + gcs_source=gca_io.GcsSource(uris=[_TEST_BATCH_PREDICTION_GCS_SOURCE]), ), - output_config=gapic_types.BatchPredictionJob.OutputConfig( - gcs_destination=gapic_types.GcsDestination( + output_config=gca_batch_prediction_job.BatchPredictionJob.OutputConfig( + gcs_destination=gca_io.GcsDestination( output_uri_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX ), predictions_format="jsonl", @@ -773,19 +872,17 @@ def test_batch_predict_gcs_source_and_dest( batch_prediction_job.wait() # Construct expected request - expected_gapic_batch_prediction_job = gapic_types.BatchPredictionJob( + expected_gapic_batch_prediction_job = gca_batch_prediction_job.BatchPredictionJob( display_name=_TEST_BATCH_PREDICTION_DISPLAY_NAME, - model=ModelServiceClient.model_path( + model=model_service_client.ModelServiceClient.model_path( _TEST_PROJECT, _TEST_LOCATION, _TEST_ID ), - input_config=gapic_types.BatchPredictionJob.InputConfig( + input_config=gca_batch_prediction_job.BatchPredictionJob.InputConfig( instances_format="jsonl", - gcs_source=gapic_types.GcsSource( - uris=[_TEST_BATCH_PREDICTION_GCS_SOURCE] - ), + gcs_source=gca_io.GcsSource(uris=[_TEST_BATCH_PREDICTION_GCS_SOURCE]), ), - output_config=gapic_types.BatchPredictionJob.OutputConfig( - gcs_destination=gapic_types.GcsDestination( + output_config=gca_batch_prediction_job.BatchPredictionJob.OutputConfig( + gcs_destination=gca_io.GcsDestination( output_uri_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX ), predictions_format="jsonl", @@ -817,19 +914,17 @@ def test_batch_predict_gcs_source_bq_dest( batch_prediction_job.wait() # Construct expected request - expected_gapic_batch_prediction_job = gapic_types.BatchPredictionJob( + expected_gapic_batch_prediction_job = gca_batch_prediction_job.BatchPredictionJob( display_name=_TEST_BATCH_PREDICTION_DISPLAY_NAME, - model=ModelServiceClient.model_path( + model=model_service_client.ModelServiceClient.model_path( _TEST_PROJECT, _TEST_LOCATION, _TEST_ID ), - input_config=gapic_types.BatchPredictionJob.InputConfig( + input_config=gca_batch_prediction_job.BatchPredictionJob.InputConfig( instances_format="jsonl", - gcs_source=gapic_types.GcsSource( - uris=[_TEST_BATCH_PREDICTION_GCS_SOURCE] - ), + gcs_source=gca_io.GcsSource(uris=[_TEST_BATCH_PREDICTION_GCS_SOURCE]), ), - output_config=gapic_types.BatchPredictionJob.OutputConfig( - bigquery_destination=gapic_types.BigQueryDestination( + output_config=gca_batch_prediction_job.BatchPredictionJob.OutputConfig( + bigquery_destination=gca_io.BigQueryDestination( output_uri=_TEST_BATCH_PREDICTION_BQ_DEST_PREFIX_WITH_PROTOCOL ), predictions_format="bigquery", @@ -843,7 +938,9 @@ def test_batch_predict_gcs_source_bq_dest( @pytest.mark.parametrize("sync", [True, False]) @pytest.mark.usefixtures("get_model_mock", "get_batch_prediction_job_mock") - def test_batch_predict_with_all_args(self, create_batch_prediction_job_mock, sync): + def test_batch_predict_with_all_args( + self, create_batch_prediction_job_with_explanations_mock, sync + ): aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) test_model = models.Model(_TEST_ID) creds = auth_credentials.AnonymousCredentials() @@ -873,25 +970,25 @@ def test_batch_predict_with_all_args(self, create_batch_prediction_job_mock, syn batch_prediction_job.wait() # Construct expected request - expected_gapic_batch_prediction_job = gapic_types.BatchPredictionJob( + expected_gapic_batch_prediction_job = gca_batch_prediction_job_v1beta1.BatchPredictionJob( display_name=_TEST_BATCH_PREDICTION_DISPLAY_NAME, - model=ModelServiceClient.model_path( + model=model_service_client_v1beta1.ModelServiceClient.model_path( _TEST_PROJECT, _TEST_LOCATION, _TEST_ID ), - input_config=gapic_types.BatchPredictionJob.InputConfig( + input_config=gca_batch_prediction_job_v1beta1.BatchPredictionJob.InputConfig( instances_format="jsonl", - gcs_source=gapic_types.GcsSource( + gcs_source=gca_io_v1beta1.GcsSource( uris=[_TEST_BATCH_PREDICTION_GCS_SOURCE] ), ), - output_config=gapic_types.BatchPredictionJob.OutputConfig( - gcs_destination=gapic_types.GcsDestination( + output_config=gca_batch_prediction_job_v1beta1.BatchPredictionJob.OutputConfig( + gcs_destination=gca_io_v1beta1.GcsDestination( output_uri_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX ), predictions_format="csv", ), - dedicated_resources=gapic_types.BatchDedicatedResources( - machine_spec=gapic_types.MachineSpec( + dedicated_resources=gca_machine_resources_v1beta1.BatchDedicatedResources( + machine_spec=gca_machine_resources_v1beta1.MachineSpec( machine_type=_TEST_MACHINE_TYPE, accelerator_type=_TEST_ACCELERATOR_TYPE, accelerator_count=_TEST_ACCELERATOR_COUNT, @@ -900,15 +997,15 @@ def test_batch_predict_with_all_args(self, create_batch_prediction_job_mock, syn max_replica_count=_TEST_MAX_REPLICA_COUNT, ), generate_explanation=True, - explanation_spec=gapic_types.ExplanationSpec( + explanation_spec=gca_explanation_v1beta1.ExplanationSpec( metadata=_TEST_EXPLANATION_METADATA, parameters=_TEST_EXPLANATION_PARAMETERS, ), labels=_TEST_LABEL, - encryption_spec=_TEST_ENCRYPTION_SPEC, + encryption_spec=_TEST_ENCRYPTION_SPEC_V1BETA1, ) - create_batch_prediction_job_mock.assert_called_once_with( + create_batch_prediction_job_with_explanations_mock.assert_called_once_with( parent=f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}", batch_prediction_job=expected_gapic_batch_prediction_job, ) diff --git a/tests/unit/aiplatform/test_training_jobs.py b/tests/unit/aiplatform/test_training_jobs.py index 42458a9dc0..33d43321ef 100644 --- a/tests/unit/aiplatform/test_training_jobs.py +++ b/tests/unit/aiplatform/test_training_jobs.py @@ -31,26 +31,28 @@ from google.auth import credentials as auth_credentials from google.cloud import aiplatform + from google.cloud.aiplatform import datasets from google.cloud.aiplatform import initializer from google.cloud.aiplatform import schema from google.cloud.aiplatform import training_jobs -from google.cloud.aiplatform_v1beta1.services.model_service import ( + +from google.cloud.aiplatform_v1.services.model_service import ( client as model_service_client, ) -from google.cloud.aiplatform_v1beta1.services.pipeline_service import ( +from google.cloud.aiplatform_v1.services.pipeline_service import ( client as pipeline_service_client, ) -from google.cloud.aiplatform_v1beta1.types import io as gca_io -from google.cloud.aiplatform_v1beta1.types import env_var -from google.cloud.aiplatform_v1beta1.types import model as gca_model -from google.cloud.aiplatform_v1beta1.types import pipeline_state as gca_pipeline_state -from google.cloud.aiplatform_v1beta1.types import ( +from google.cloud.aiplatform_v1.types import ( + dataset as gca_dataset, + encryption_spec as gca_encryption_spec, + env_var as gca_env_var, + io as gca_io, + model as gca_model, + pipeline_state as gca_pipeline_state, training_pipeline as gca_training_pipeline, ) -from google.cloud.aiplatform_v1beta1.types import encryption_spec as gca_encryption_spec -from google.cloud.aiplatform_v1beta1 import Dataset as GapicDataset from google.cloud import storage from google.protobuf import json_format @@ -490,7 +492,7 @@ def mock_tabular_dataset(): ds = mock.MagicMock(datasets.Dataset) ds.name = _TEST_DATASET_NAME ds._latest_future = None - ds._gca_resource = GapicDataset( + ds._gca_resource = gca_dataset.Dataset( display_name=_TEST_DATASET_DISPLAY_NAME, metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_TABULAR, labels={}, @@ -505,7 +507,7 @@ def mock_nontabular_dataset(): ds = mock.MagicMock(datasets.Dataset) ds.name = _TEST_DATASET_NAME ds._latest_future = None - ds._gca_resource = GapicDataset( + ds._gca_resource = gca_dataset.Dataset( display_name=_TEST_DATASET_DISPLAY_NAME, metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_NONTABULAR, labels={}, @@ -608,7 +610,7 @@ def test_run_call_pipeline_service_create_with_tabular_dataset( ) env = [ - env_var.EnvVar(name=str(key), value=str(value)) + gca_env_var.EnvVar(name=str(key), value=str(value)) for key, value in _TEST_MODEL_SERVING_CONTAINER_ENVIRONMENT_VARIABLES.items() ] @@ -758,7 +760,7 @@ def test_run_call_pipeline_service_create_with_bigquery_destination( ) env = [ - env_var.EnvVar(name=str(key), value=str(value)) + gca_env_var.EnvVar(name=str(key), value=str(value)) for key, value in _TEST_MODEL_SERVING_CONTAINER_ENVIRONMENT_VARIABLES.items() ] @@ -1457,7 +1459,7 @@ def test_run_call_pipeline_service_create_with_nontabular_dataset( ) env = [ - env_var.EnvVar(name=str(key), value=str(value)) + gca_env_var.EnvVar(name=str(key), value=str(value)) for key, value in _TEST_MODEL_SERVING_CONTAINER_ENVIRONMENT_VARIABLES.items() ] @@ -1689,7 +1691,7 @@ def test_run_call_pipeline_service_create_with_tabular_dataset( ) env = [ - env_var.EnvVar(name=str(key), value=str(value)) + gca_env_var.EnvVar(name=str(key), value=str(value)) for key, value in _TEST_MODEL_SERVING_CONTAINER_ENVIRONMENT_VARIABLES.items() ] @@ -1835,7 +1837,7 @@ def test_run_call_pipeline_service_create_with_bigquery_destination( ) env = [ - env_var.EnvVar(name=str(key), value=str(value)) + gca_env_var.EnvVar(name=str(key), value=str(value)) for key, value in _TEST_MODEL_SERVING_CONTAINER_ENVIRONMENT_VARIABLES.items() ] @@ -2450,7 +2452,7 @@ def test_run_call_pipeline_service_create_with_nontabular_dataset( ) env = [ - env_var.EnvVar(name=str(key), value=str(value)) + gca_env_var.EnvVar(name=str(key), value=str(value)) for key, value in _TEST_MODEL_SERVING_CONTAINER_ENVIRONMENT_VARIABLES.items() ] @@ -2882,7 +2884,7 @@ def test_run_call_pipeline_service_create_with_tabular_dataset( ) env = [ - env_var.EnvVar(name=str(key), value=str(value)) + gca_env_var.EnvVar(name=str(key), value=str(value)) for key, value in _TEST_MODEL_SERVING_CONTAINER_ENVIRONMENT_VARIABLES.items() ] @@ -3030,7 +3032,7 @@ def test_run_call_pipeline_service_create_with_bigquery_destination( ) env = [ - env_var.EnvVar(name=str(key), value=str(value)) + gca_env_var.EnvVar(name=str(key), value=str(value)) for key, value in _TEST_MODEL_SERVING_CONTAINER_ENVIRONMENT_VARIABLES.items() ] @@ -3660,7 +3662,7 @@ def test_run_call_pipeline_service_create_with_nontabular_dataset( ) env = [ - env_var.EnvVar(name=str(key), value=str(value)) + gca_env_var.EnvVar(name=str(key), value=str(value)) for key, value in _TEST_MODEL_SERVING_CONTAINER_ENVIRONMENT_VARIABLES.items() ] diff --git a/tests/unit/aiplatform/test_utils.py b/tests/unit/aiplatform/test_utils.py index a81180738f..3032475069 100644 --- a/tests/unit/aiplatform/test_utils.py +++ b/tests/unit/aiplatform/test_utils.py @@ -25,11 +25,18 @@ from google.api_core import client_options from google.api_core import gapic_v1 from google.cloud import aiplatform +from google.cloud.aiplatform import compat from google.cloud.aiplatform import utils + from google.cloud.aiplatform_v1beta1.services.model_service import ( - client as model_service_client, + client as model_service_client_v1beta1, +) +from google.cloud.aiplatform_v1.services.model_service import ( + client as model_service_client_v1, ) +model_service_client_default = model_service_client_v1 + @pytest.mark.parametrize( "resource_name, expected", @@ -251,12 +258,48 @@ def test_wrapped_client(): test_client_info = gapic_v1.client_info.ClientInfo() test_client_options = client_options.ClientOptions() - wrapped_client = utils.WrappedClient( - client_class=model_service_client.ModelServiceClient, + wrapped_client = utils.ClientWithOverride.WrappedClient( + client_class=model_service_client_default.ModelServiceClient, client_options=test_client_options, client_info=test_client_info, ) assert isinstance( - wrapped_client.get_model.__self__, model_service_client.ModelServiceClient + wrapped_client.get_model.__self__, + model_service_client_default.ModelServiceClient, + ) + + +def test_client_w_override_default_version(): + + test_client_info = gapic_v1.client_info.ClientInfo() + test_client_options = client_options.ClientOptions() + + client_w_override = utils.ModelClientWithOverride( + client_options=test_client_options, client_info=test_client_info, + ) + assert isinstance( + client_w_override._clients[ + client_w_override._default_version + ].get_model.__self__, + model_service_client_default.ModelServiceClient, + ) + + +def test_client_w_override_select_version(): + + test_client_info = gapic_v1.client_info.ClientInfo() + test_client_options = client_options.ClientOptions() + + client_w_override = utils.ModelClientWithOverride( + client_options=test_client_options, client_info=test_client_info, + ) + + assert isinstance( + client_w_override.select_version(compat.V1BETA1).get_model.__self__, + model_service_client_v1beta1.ModelServiceClient, + ) + assert isinstance( + client_w_override.select_version(compat.V1).get_model.__self__, + model_service_client_v1.ModelServiceClient, )