From 0f6b6701e96fb0ec345e81560d03059a7900160f Mon Sep 17 00:00:00 2001 From: Eugene Kim Date: Tue, 25 Jan 2022 06:14:32 -0800 Subject: [PATCH 1/6] feat: Enable europe-west6 and northamerica-northeast2 regions Co-authored-by: Eugene Kim --- google/cloud/aiplatform/constants/base.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/google/cloud/aiplatform/constants/base.py b/google/cloud/aiplatform/constants/base.py index 34ca06a4b59..5e09fce1d03 100644 --- a/google/cloud/aiplatform/constants/base.py +++ b/google/cloud/aiplatform/constants/base.py @@ -28,7 +28,9 @@ "europe-west2", "europe-west3", "europe-west4", + "europe-west6", "northamerica-northeast1", + "northamerica-northeast2", "us-central1", "us-east1", "us-east4", From c840728e503eea3300e9629405978e28c6aafec7 Mon Sep 17 00:00:00 2001 From: Morgan Du Date: Wed, 26 Jan 2022 20:16:22 -0800 Subject: [PATCH 2/6] feat: enable feature store batch serve to BigQuery and GCS for csv and tfrecord (#919) * feat: add batch_serve_to_bq for bigquery table and batch_serve_to_gcs for csv and tfrecord files in Featurestore class * fix: change entity_type_ids and entity_type_destination_fields to serving_feature_ids and feature_destination_fields * fix: remove white space * Update google/cloud/aiplatform/featurestore/featurestore.py Co-authored-by: sasha-gitg <44654632+sasha-gitg@users.noreply.github.com> * Update google/cloud/aiplatform/featurestore/featurestore.py Co-authored-by: sasha-gitg <44654632+sasha-gitg@users.noreply.github.com> * Update google/cloud/aiplatform/featurestore/featurestore.py Co-authored-by: sasha-gitg <44654632+sasha-gitg@users.noreply.github.com> * Update google/cloud/aiplatform/featurestore/featurestore.py Co-authored-by: sasha-gitg <44654632+sasha-gitg@users.noreply.github.com> * Update google/cloud/aiplatform/featurestore/featurestore.py Co-authored-by: sasha-gitg <44654632+sasha-gitg@users.noreply.github.com> * fix: Featurestore create method example usage * fix: get_timestamp_proto for 
millisecond precision cap * fix: unit tests for get_timestamp_proto Co-authored-by: sasha-gitg <44654632+sasha-gitg@users.noreply.github.com> --- .../aiplatform/featurestore/featurestore.py | 477 +++++++++++++++++- google/cloud/aiplatform/utils/__init__.py | 12 +- .../aiplatform/utils/featurestore_utils.py | 1 + tests/system/aiplatform/e2e_base.py | 33 ++ tests/system/aiplatform/test_featurestore.py | 116 ++++- tests/unit/aiplatform/test_featurestores.py | 334 ++++++++++++ tests/unit/aiplatform/test_utils.py | 26 +- 7 files changed, 959 insertions(+), 40 deletions(-) diff --git a/google/cloud/aiplatform/featurestore/featurestore.py b/google/cloud/aiplatform/featurestore/featurestore.py index 4b98ccfd7d2..6d02bb9f76f 100644 --- a/google/cloud/aiplatform/featurestore/featurestore.py +++ b/google/cloud/aiplatform/featurestore/featurestore.py @@ -15,13 +15,18 @@ # limitations under the License. # -from typing import Dict, List, Optional, Sequence, Tuple +from typing import Dict, List, Optional, Sequence, Tuple, Union from google.auth import credentials as auth_credentials from google.protobuf import field_mask_pb2 from google.cloud.aiplatform import base -from google.cloud.aiplatform.compat.types import featurestore as gca_featurestore +from google.cloud.aiplatform.compat.types import ( + feature_selector as gca_feature_selector, + featurestore as gca_featurestore, + featurestore_service as gca_featurestore_service, + io as gca_io, +) from google.cloud.aiplatform import featurestore from google.cloud.aiplatform import initializer from google.cloud.aiplatform import utils @@ -384,14 +389,8 @@ def create( Example Usage: - my_entity_type = aiplatform.EntityType.create( - entity_type_id='my_entity_type_id', - featurestore_name='projects/123/locations/us-central1/featurestores/my_featurestore_id' - ) - or - my_entity_type = aiplatform.EntityType.create( - entity_type_id='my_entity_type_id', - featurestore_name='my_featurestore_id', + my_featurestore = 
aiplatform.Featurestore.create( + featurestore_id='my_featurestore_id', ) Args: @@ -556,3 +555,461 @@ def create_entity_type( request_metadata=request_metadata, sync=sync, ) + + def _batch_read_feature_values( + self, + batch_read_feature_values_request: gca_featurestore_service.BatchReadFeatureValuesRequest, + request_metadata: Optional[Sequence[Tuple[str, str]]] = (), + ) -> "Featurestore": + """Batch read Feature values from the Featurestore to a destination storage. + + Args: + batch_read_feature_values_request (gca_featurestore_service.BatchReadFeatureValuesRequest): + Required. Request of batch read feature values. + request_metadata (Sequence[Tuple[str, str]]): + Optional. Strings which should be sent along with the request as metadata. + + Returns: + Featurestore: The featurestore resource object batch read feature values from. + """ + + _LOGGER.log_action_start_against_resource( + "Serving", "feature values", self, + ) + + batch_read_lro = self.api_client.batch_read_feature_values( + request=batch_read_feature_values_request, metadata=request_metadata, + ) + + _LOGGER.log_action_started_against_resource_with_lro( + "Serve", "feature values", self.__class__, batch_read_lro + ) + + batch_read_lro.result() + + _LOGGER.log_action_completed_against_resource("feature values", "served", self) + + return self + + def _validate_and_get_batch_read_feature_values_request( + self, + serving_feature_ids: Dict[str, List[str]], + destination: Union[ + gca_io.BigQueryDestination, + gca_io.CsvDestination, + gca_io.TFRecordDestination, + ], + feature_destination_fields: Optional[Dict[str, str]] = None, + read_instances: Optional[Union[gca_io.BigQuerySource, gca_io.CsvSource]] = None, + pass_through_fields: Optional[List[str]] = None, + ) -> gca_featurestore_service.BatchReadFeatureValuesRequest: + """Validates and gets batch_read_feature_values_request + + Args: + serving_feature_ids (Dict[str, List[str]]): + Required. 
A user defined dictionary to define the entity_types and their features for batch serve/read. + The keys of the dictionary are the serving entity_type ids and + the values are lists of serving feature ids in each entity_type. + + Example: + serving_feature_ids = { + 'my_entity_type_id_1': ['feature_id_1_1', 'feature_id_1_2'], + 'my_entity_type_id_2': ['feature_id_2_1', 'feature_id_2_2'], + } + + destination (Union[gca_io.BigQueryDestination, gca_io.CsvDestination, gca_io.TFRecordDestination]): + Required. BigQuery destination, Csv destination or TFRecord destination. + + feature_destination_fields (Dict[str, str]): + Optional. A user defined dictionary to map a feature's fully qualified resource name to + its destination field name. If the destination field name is not defined, + the feature ID will be used as its destination field name. + + Example: + feature_destination_fields = { + 'projects/123/locations/us-central1/featurestores/fs_id/entityTypes/et_id1/features/f_id11': 'foo', + 'projects/123/locations/us-central1/featurestores/fs_id/entityTypes/et_id2/features/f_id22': 'bar', + } + + read_instances (Union[gca_io.BigQuerySource, gca_io.CsvSource]): + Optional. BigQuery source or Csv source for read instances. + pass_through_fields (List[str]): + Optional. When not empty, the specified fields in the + read_instances source will be joined as-is in the output, + in addition to those fields from the Featurestore Entity. + + For BigQuery source, the type of the pass-through values + will be automatically inferred. For CSV source, the + pass-through values will be passed as opaque bytes. 
+ + Returns: + gca_featurestore_service.BatchReadFeatureValuesRequest: batch read feature values request + """ + + featurestore_name_components = self._parse_resource_name(self.resource_name) + + feature_destination_fields = feature_destination_fields or {} + + entity_type_specs = [] + for entity_type_id, feature_ids in serving_feature_ids.items(): + destination_feature_settings = [] + for feature_id in feature_ids: + feature_resource_name = featurestore.Feature._format_resource_name( + project=featurestore_name_components["project"], + location=featurestore_name_components["location"], + featurestore=featurestore_name_components["featurestore"], + entity_type=entity_type_id, + feature=feature_id, + ) + + feature_destination_field = feature_destination_fields.get( + feature_resource_name + ) + if feature_destination_field: + destination_feature_setting_proto = gca_featurestore_service.DestinationFeatureSetting( + feature_id=feature_id, + destination_field=feature_destination_field, + ) + destination_feature_settings.append( + destination_feature_setting_proto + ) + + entity_type_spec = gca_featurestore_service.BatchReadFeatureValuesRequest.EntityTypeSpec( + entity_type_id=entity_type_id, + feature_selector=gca_feature_selector.FeatureSelector( + id_matcher=gca_feature_selector.IdMatcher(ids=feature_ids) + ), + settings=destination_feature_settings or None, + ) + entity_type_specs.append(entity_type_spec) + + batch_read_feature_values_request = gca_featurestore_service.BatchReadFeatureValuesRequest( + featurestore=self.resource_name, entity_type_specs=entity_type_specs, + ) + + if isinstance(destination, gca_io.BigQueryDestination): + batch_read_feature_values_request.destination = gca_featurestore_service.FeatureValueDestination( + bigquery_destination=destination + ) + elif isinstance(destination, gca_io.CsvDestination): + batch_read_feature_values_request.destination = gca_featurestore_service.FeatureValueDestination( + csv_destination=destination + ) + elif 
isinstance(destination, gca_io.TFRecordDestination): + batch_read_feature_values_request.destination = gca_featurestore_service.FeatureValueDestination( + tfrecord_destination=destination + ) + + if isinstance(read_instances, gca_io.BigQuerySource): + batch_read_feature_values_request.bigquery_read_instances = read_instances + elif isinstance(read_instances, gca_io.CsvSource): + batch_read_feature_values_request.csv_read_instances = read_instances + + if pass_through_fields is not None: + batch_read_feature_values_request.pass_through_fields = [ + gca_featurestore_service.BatchReadFeatureValuesRequest.PassThroughField( + field_name=pass_through_field + ) + for pass_through_field in pass_through_fields + ] + + return batch_read_feature_values_request + + def _get_read_instances( + self, read_instances: Union[str, List[str]], + ) -> Union[gca_io.BigQuerySource, gca_io.CsvSource]: + """Gets read_instances + + Args: + read_instances (Union[str, List[str]]): + Required. Read_instances can be either BigQuery URI to the input table, + or Google Cloud Storage URI(-s) to the csv file(s). + + Returns: + Union[gca_io.BigQuerySource, gca_io.CsvSource]: BigQuery source or Csv source for read instances. + + Raises: + TypeError if read_instances is not a string or a list of strings. + ValueError if read_instances uri does not start with 'bq://' or 'gs://'. + ValueError if uris in read_instances do not start with 'gs://'. + """ + if isinstance(read_instances, str): + if not ( + read_instances.startswith("bq://") or read_instances.startswith("gs://") + ): + raise ValueError( + "The read_instances accepts a single uri starts with 'bq://' or 'gs://'." + ) + elif isinstance(read_instances, list) and all( + [isinstance(e, str) for e in read_instances] + ): + if not all([e.startswith("gs://") for e in read_instances]): + raise ValueError( + "The read_instances accepts a list of uris start with 'gs://' only." 
+ ) + else: + raise TypeError( + "The read_instances type should to be either a str or a List[str]." + ) + + if isinstance(read_instances, str): + if read_instances.startswith("bq://"): + return gca_io.BigQuerySource(input_uri=read_instances) + else: + read_instances = [read_instances] + + return gca_io.CsvSource(gcs_source=gca_io.GcsSource(uris=read_instances)) + + @base.optional_sync(return_input_arg="self") + def batch_serve_to_bq( + self, + bq_destination_output_uri: str, + serving_feature_ids: Dict[str, List[str]], + feature_destination_fields: Optional[Dict[str, str]] = None, + read_instances: Optional[Union[str, List[str]]] = None, + pass_through_fields: Optional[List[str]] = None, + request_metadata: Optional[Sequence[Tuple[str, str]]] = (), + sync: bool = True, + ) -> "Featurestore": + """ Batch serves feature values to BigQuery destination + + Args: + bq_destination_output_uri (str): + Required. BigQuery URI to the detination table. + + Example: + 'bq://project.dataset.table_name' + + It requires an existing BigQuery destination Dataset, under the same project as the Featurestore. + + serving_feature_ids (Dict[str, List[str]]): + Required. A user defined dictionary to define the entity_types and their features for batch serve/read. + The keys of the dictionary are the serving entity_type ids and + the values are lists of serving feature ids in each entity_type. + + Example: + serving_feature_ids = { + 'my_entity_type_id_1': ['feature_id_1_1', 'feature_id_1_2'], + 'my_entity_type_id_2': ['feature_id_2_1', 'feature_id_2_2'], + } + + feature_destination_fields (Dict[str, str]): + Optional. A user defined dictionary to map a feature's fully qualified resource name to + its destination field name. If the destination field name is not defined, + the feature ID will be used as its destination field name. 
+ + Example: + feature_destination_fields = { + 'projects/123/locations/us-central1/featurestores/fs_id/entityTypes/et_id1/features/f_id11': 'foo', + 'projects/123/locations/us-central1/featurestores/fs_id/entityTypes/et_id2/features/f_id22': 'bar', + } + + read_instances (Union[str, List[str]]): + Optional. Read_instances can be either BigQuery URI to the input table, + or Google Cloud Storage URI(-s) to the + csv file(s). May contain wildcards. For more + information on wildcards, see + https://cloud.google.com/storage/docs/gsutil/addlhelp/WildcardNames. + Example: + 'bq://project.dataset.table_name' + or + ["gs://my_bucket/my_file_1.csv", "gs://my_bucket/my_file_2.csv"] + + Each read instance consists of exactly one read timestamp + and one or more entity IDs identifying entities of the + corresponding EntityTypes whose Features are requested. + + Each output instance contains Feature values of requested + entities concatenated together as of the read time. + + An example read instance may be + ``foo_entity_id, bar_entity_id, 2020-01-01T10:00:00.123Z``. + + An example output instance may be + ``foo_entity_id, bar_entity_id, 2020-01-01T10:00:00.123Z, foo_entity_feature1_value, bar_entity_feature2_value``. + + Timestamp in each read instance must be millisecond-aligned. + + The columns can be in any order. + + Values in the timestamp column must use the RFC 3339 format, + e.g. ``2012-07-30T10:43:17.123Z``. + + pass_through_fields (List[str]): + Optional. When not empty, the specified fields in the + read_instances source will be joined as-is in the output, + in addition to those fields from the Featurestore Entity. + + For BigQuery source, the type of the pass-through values + will be automatically inferred. For CSV source, the + pass-through values will be passed as opaque bytes. + + Returns: + Featurestore: The featurestore resource object batch read feature values from. + + Raises: + NotFound: if the BigQuery destination Dataset does not exist. 
+ FailedPrecondition: if the BigQuery destination Dataset/Table is in a different project. + """ + batch_read_feature_values_request = self._validate_and_get_batch_read_feature_values_request( + serving_feature_ids=serving_feature_ids, + destination=gca_io.BigQueryDestination( + output_uri=bq_destination_output_uri + ), + feature_destination_fields=feature_destination_fields, + read_instances=read_instances + if read_instances is None + else self._get_read_instances(read_instances), + pass_through_fields=pass_through_fields, + ) + + return self._batch_read_feature_values( + batch_read_feature_values_request=batch_read_feature_values_request, + request_metadata=request_metadata, + ) + + @base.optional_sync(return_input_arg="self") + def batch_serve_to_gcs( + self, + gcs_destination_output_uri_prefix: str, + gcs_destination_type: str, + serving_feature_ids: Dict[str, List[str]], + feature_destination_fields: Optional[Dict[str, str]] = None, + read_instances: Optional[Union[str, List[str]]] = None, + pass_through_fields: Optional[List[str]] = None, + request_metadata: Optional[Sequence[Tuple[str, str]]] = (), + sync: bool = True, + ) -> "Featurestore": + """ Batch serves feature values to GCS destination + + Args: + gcs_destination_output_uri_prefix (str): + Required. Google Cloud Storage URI to output + directory. If the uri doesn't end with '/', a + '/' will be automatically appended. The + directory is created if it doesn't exist. + + Example: + "gs://bucket/path/to/prefix" + + gcs_destination_type (str): + Required. The type of the destination files(s), + the value of gcs_destination_type can only be either `csv`, or `tfrecord`. + + For CSV format. Array Feature value types are not allowed in CSV format. + + For TFRecord format. 
+ + Below are the mapping from Feature value type in + Featurestore to Feature value type in TFRecord: + + :: + + Value type in Featurestore | Value type in TFRecord + DOUBLE, DOUBLE_ARRAY | FLOAT_LIST + INT64, INT64_ARRAY | INT64_LIST + STRING, STRING_ARRAY, BYTES | BYTES_LIST + true -> byte_string("true"), false -> byte_string("false") + BOOL, BOOL_ARRAY (true, false) | BYTES_LIST + + serving_feature_ids (Dict[str, List[str]]): + Required. A user defined dictionary to define the entity_types and their features for batch serve/read. + The keys of the dictionary are the serving entity_type ids and + the values are lists of serving feature ids in each entity_type. + + Example: + serving_feature_ids = { + 'my_entity_type_id_1': ['feature_id_1_1', 'feature_id_1_2'], + 'my_entity_type_id_2': ['feature_id_2_1', 'feature_id_2_2'], + } + + feature_destination_fields (Dict[str, str]): + Optional. A user defined dictionary to map a feature's fully qualified resource name to + its destination field name. If the destination field name is not defined, + the feature ID will be used as its destination field name. + + Example: + feature_destination_fields = { + 'projects/123/locations/us-central1/featurestores/fs_id/entityTypes/et_id1/features/f_id11': 'foo', + 'projects/123/locations/us-central1/featurestores/fs_id/entityTypes/et_id2/features/f_id22': 'bar', + } + + read_instances (Union[str, List[str]]): + Optional. Read_instances can be either BigQuery URI to the input table, + or Google Cloud Storage URI(-s) to the + csv file(s). May contain wildcards. For more + information on wildcards, see + https://cloud.google.com/storage/docs/gsutil/addlhelp/WildcardNames. + Example: + 'bq://project.dataset.table_name' + or + ["gs://my_bucket/my_file_1.csv", "gs://my_bucket/my_file_2.csv"] + + Each read instance consists of exactly one read timestamp + and one or more entity IDs identifying entities of the + corresponding EntityTypes whose Features are requested. 
+ + Each output instance contains Feature values of requested + entities concatenated together as of the read time. + + An example read instance may be + ``foo_entity_id, bar_entity_id, 2020-01-01T10:00:00.123Z``. + + An example output instance may be + ``foo_entity_id, bar_entity_id, 2020-01-01T10:00:00.123Z, foo_entity_feature1_value, bar_entity_feature2_value``. + + Timestamp in each read instance must be millisecond-aligned. + + The columns can be in any order. + + Values in the timestamp column must use the RFC 3339 format, + e.g. ``2012-07-30T10:43:17.123Z``. + + pass_through_fields (List[str]): + Optional. When not empty, the specified fields in the + read_instances source will be joined as-is in the output, + in addition to those fields from the Featurestore Entity. + + For BigQuery source, the type of the pass-through values + will be automatically inferred. For CSV source, the + pass-through values will be passed as opaque bytes. + + Returns: + Featurestore: The featurestore resource object batch read feature values from. + + Raises: + ValueError if gcs_destination_type is not supported. + + """ + destination = None + if gcs_destination_type not in featurestore_utils.GCS_DESTINATION_TYPE: + raise ValueError( + "Only %s are supported gcs_destination_type, not `%s`. 
" + % ( + "`" + "`, `".join(featurestore_utils.GCS_DESTINATION_TYPE) + "`", + gcs_destination_type, + ) + ) + + gcs_destination = gca_io.GcsDestination( + output_uri_prefix=gcs_destination_output_uri_prefix + ) + if gcs_destination_type == "csv": + destination = gca_io.CsvDestination(gcs_destination=gcs_destination) + if gcs_destination_type == "tfrecord": + destination = gca_io.TFRecordDestination(gcs_destination=gcs_destination) + + batch_read_feature_values_request = self._validate_and_get_batch_read_feature_values_request( + serving_feature_ids=serving_feature_ids, + destination=destination, + feature_destination_fields=feature_destination_fields, + read_instances=read_instances + if read_instances is None + else self._get_read_instances(read_instances), + pass_through_fields=pass_through_fields, + ) + + return self._batch_read_feature_values( + batch_read_feature_values_request=batch_read_feature_values_request, + request_metadata=request_metadata, + ) diff --git a/google/cloud/aiplatform/utils/__init__.py b/google/cloud/aiplatform/utils/__init__.py index 26b28dcdd7e..c5c21a2a0b0 100644 --- a/google/cloud/aiplatform/utils/__init__.py +++ b/google/cloud/aiplatform/utils/__init__.py @@ -628,9 +628,11 @@ def get_timestamp_proto( """ if not time: time = datetime.datetime.now() - t = time.timestamp() - seconds = int(t) - # must not have higher than millisecond precision. 
- nanos = int((t % 1 * 1e6) * 1e3) - return timestamp_pb2.Timestamp(seconds=seconds, nanos=nanos) + time_str = time.isoformat(sep=" ", timespec="milliseconds") + time = datetime.datetime.strptime(time_str, "%Y-%m-%d %H:%M:%S.%f") + + timestamp_proto = timestamp_pb2.Timestamp() + timestamp_proto.FromDatetime(time) + + return timestamp_proto diff --git a/google/cloud/aiplatform/utils/featurestore_utils.py b/google/cloud/aiplatform/utils/featurestore_utils.py index e9d26b62bed..45dbbbf44f5 100644 --- a/google/cloud/aiplatform/utils/featurestore_utils.py +++ b/google/cloud/aiplatform/utils/featurestore_utils.py @@ -29,6 +29,7 @@ RESOURCE_ID_PATTERN_REGEX = r"[a-z_][a-z0-9_]{0,59}" GCS_SOURCE_TYPE = {"csv", "avro"} +GCS_DESTINATION_TYPE = {"csv", "tfrecord"} _FEATURE_VALUE_TYPE_UNSPECIFIED = "VALUE_TYPE_UNSPECIFIED" diff --git a/tests/system/aiplatform/e2e_base.py b/tests/system/aiplatform/e2e_base.py index 61b9e7f36c6..3a9f87e8ae5 100644 --- a/tests/system/aiplatform/e2e_base.py +++ b/tests/system/aiplatform/e2e_base.py @@ -24,6 +24,7 @@ from google.api_core import exceptions from google.cloud import aiplatform +from google.cloud import bigquery from google.cloud import storage from google.cloud.aiplatform import initializer @@ -90,6 +91,38 @@ def delete_staging_bucket(self, shared_state: Dict[str, Any]): bucket = shared_state["bucket"] bucket.delete(force=True) + @pytest.fixture(scope="class") + def prepare_bigquery_dataset( + self, shared_state: Dict[str, Any] + ) -> Generator[bigquery.dataset.Dataset, None, None]: + """Create a bigquery dataset and store bigquery resource object in shared state.""" + + bigquery_client = bigquery.Client(project=_PROJECT) + shared_state["bigquery_client"] = bigquery_client + + dataset_name = f"{self._temp_prefix.lower()}_{uuid.uuid4()}".replace("-", "_") + dataset_id = f"{_PROJECT}.{dataset_name}" + shared_state["bigquery_dataset_id"] = dataset_id + + dataset = bigquery.Dataset(dataset_id) + dataset.location = _LOCATION + 
shared_state["bigquery_dataset"] = bigquery_client.create_dataset(dataset) + + yield + + @pytest.fixture(scope="class") + def delete_bigquery_dataset(self, shared_state: Dict[str, Any]): + """Delete the bigquery dataset""" + + yield + + # Get the bigquery dataset id used for testing and wipe it + bigquery_dataset = shared_state["bigquery_dataset"] + bigquery_client = shared_state["bigquery_client"] + bigquery_client.delete_dataset( + bigquery_dataset.dataset_id, delete_contents=True, not_found_ok=True + ) # Make an API request. + @pytest.fixture(scope="class", autouse=True) def teardown(self, shared_state: Dict[str, Any]): """Delete every Vertex AI resource created during test""" diff --git a/tests/system/aiplatform/test_featurestore.py b/tests/system/aiplatform/test_featurestore.py index d22119ea223..b67dec6883f 100644 --- a/tests/system/aiplatform/test_featurestore.py +++ b/tests/system/aiplatform/test_featurestore.py @@ -16,6 +16,7 @@ # import logging +import pytest from google.cloud import aiplatform from tests.system.aiplatform import e2e_base @@ -29,6 +30,8 @@ "gs://cloud-samples-data-us-central1/vertex-ai/feature-store/datasets/movies.avro" ) +_TEST_READ_INSTANCE_SRC = "gs://cloud-samples-data-us-central1/vertex-ai/feature-store/datasets/movie_prediction.csv" + _TEST_FEATURESTORE_ID = "movie_prediction" _TEST_USER_ENTITY_TYPE_ID = "users" _TEST_MOVIE_ENTITY_TYPE_ID = "movies" @@ -42,6 +45,12 @@ _TEST_MOVIE_AVERAGE_RATING_FEATURE_ID = "average_rating" +@pytest.mark.usefixtures( + "prepare_staging_bucket", + "delete_staging_bucket", + "prepare_bigquery_dataset", + "delete_bigquery_dataset", +) class TestFeaturestore(e2e_base.TestEndToEnd): _temp_prefix = "temp_vertex_sdk_e2e_featurestore_test" @@ -131,7 +140,7 @@ def test_create_get_list_features(self, shared_state): user_age_feature = user_entity_type.create_feature( feature_id=_TEST_USER_AGE_FEATURE_ID, value_type="INT64" ) - + shared_state["user_age_feature_resource_name"] = user_age_feature.resource_name 
get_user_age_feature = user_entity_type.get_feature( feature_id=_TEST_USER_AGE_FEATURE_ID ) @@ -142,6 +151,9 @@ def test_create_get_list_features(self, shared_state): value_type="STRING", entity_type_name=user_entity_type_name, ) + shared_state[ + "user_gender_feature_resource_name" + ] = user_gender_feature.resource_name get_user_gender_feature = aiplatform.Feature( feature_name=user_gender_feature.resource_name @@ -153,6 +165,9 @@ def test_create_get_list_features(self, shared_state): user_liked_genres_feature = user_entity_type.create_feature( feature_id=_TEST_USER_LIKED_GENRES_FEATURE_ID, value_type="STRING_ARRAY", ) + shared_state[ + "user_liked_genres_feature_resource_name" + ] = user_liked_genres_feature.resource_name get_user_liked_genres_feature = aiplatform.Feature( feature_name=user_liked_genres_feature.resource_name @@ -250,6 +265,105 @@ def test_search_features(self, shared_state): len(list_searched_features) - shared_state["base_list_searched_features"] ) == 6 + def test_batch_serve_to_gcs(self, shared_state, caplog): + + assert shared_state["featurestore"] + assert shared_state["bucket"] + assert shared_state["user_age_feature_resource_name"] + assert shared_state["user_gender_feature_resource_name"] + assert shared_state["user_liked_genres_feature_resource_name"] + + featurestore = shared_state["featurestore"] + bucket_name = shared_state["staging_bucket_name"] + user_age_feature_resource_name = shared_state["user_age_feature_resource_name"] + user_gender_feature_resource_name = shared_state[ + "user_gender_feature_resource_name" + ] + user_liked_genres_feature_resource_name = shared_state[ + "user_liked_genres_feature_resource_name" + ] + + aiplatform.init( + project=e2e_base._PROJECT, location=e2e_base._LOCATION, + ) + + caplog.set_level(logging.INFO) + + featurestore.batch_serve_to_gcs( + serving_feature_ids={ + _TEST_USER_ENTITY_TYPE_ID: [ + _TEST_USER_AGE_FEATURE_ID, + _TEST_USER_GENDER_FEATURE_ID, + _TEST_USER_LIKED_GENRES_FEATURE_ID, + ], + 
_TEST_MOVIE_ENTITY_TYPE_ID: [ + _TEST_MOVIE_TITLE_FEATURE_ID, + _TEST_MOVIE_GENRES_FEATURE_ID, + _TEST_MOVIE_AVERAGE_RATING_FEATURE_ID, + ], + }, + feature_destination_fields={ + user_age_feature_resource_name: "user_age_dest", + user_gender_feature_resource_name: "user_gender_dest", + user_liked_genres_feature_resource_name: "user_liked_genres_dest", + }, + read_instances=_TEST_READ_INSTANCE_SRC, + gcs_destination_output_uri_prefix=f"gs://{bucket_name}/featurestore_test/tfrecord", + gcs_destination_type="tfrecord", + ) + assert "Featurestore feature values served." in caplog.text + + caplog.clear() + + def test_batch_serve_to_bq(self, shared_state, caplog): + + assert shared_state["featurestore"] + assert shared_state["bigquery_dataset"] + assert shared_state["user_age_feature_resource_name"] + assert shared_state["user_gender_feature_resource_name"] + assert shared_state["user_liked_genres_feature_resource_name"] + + featurestore = shared_state["featurestore"] + bigquery_dataset_id = shared_state["bigquery_dataset_id"] + user_age_feature_resource_name = shared_state["user_age_feature_resource_name"] + user_gender_feature_resource_name = shared_state[ + "user_gender_feature_resource_name" + ] + user_liked_genres_feature_resource_name = shared_state[ + "user_liked_genres_feature_resource_name" + ] + + aiplatform.init( + project=e2e_base._PROJECT, location=e2e_base._LOCATION, + ) + + caplog.set_level(logging.INFO) + + featurestore.batch_serve_to_bq( + serving_feature_ids={ + _TEST_USER_ENTITY_TYPE_ID: [ + _TEST_USER_AGE_FEATURE_ID, + _TEST_USER_GENDER_FEATURE_ID, + _TEST_USER_LIKED_GENRES_FEATURE_ID, + ], + _TEST_MOVIE_ENTITY_TYPE_ID: [ + _TEST_MOVIE_TITLE_FEATURE_ID, + _TEST_MOVIE_GENRES_FEATURE_ID, + _TEST_MOVIE_AVERAGE_RATING_FEATURE_ID, + ], + }, + feature_destination_fields={ + user_age_feature_resource_name: "user_age_dest", + user_gender_feature_resource_name: "user_gender_dest", + user_liked_genres_feature_resource_name: "user_liked_genres_dest", + }, + 
read_instances=_TEST_READ_INSTANCE_SRC, + bq_destination_output_uri=f"bq://{bigquery_dataset_id}.test_table", + ) + + assert "Featurestore feature values served." in caplog.text + caplog.clear() + def test_online_reads(self, shared_state): assert shared_state["user_entity_type"] assert shared_state["movie_entity_type"] diff --git a/tests/unit/aiplatform/test_featurestores.py b/tests/unit/aiplatform/test_featurestores.py index a92043969e3..97cec0056f0 100644 --- a/tests/unit/aiplatform/test_featurestores.py +++ b/tests/unit/aiplatform/test_featurestores.py @@ -207,6 +207,11 @@ "my_feature_id_1": "my_feature_id_1_source_field", } +_TEST_SERVING_FEATURE_IDS = { + "my_entity_type_id_1": ["my_feature_id_1_1", "my_feature_id_1_2"], + "my_entity_type_id_2": ["my_feature_id_2_1", "my_feature_id_2_2"], +} + _TEST_FEATURE_TIME_FIELD = "feature_time_field" _TEST_FEATURE_TIME = datetime.datetime.now() @@ -214,6 +219,7 @@ _TEST_GCS_AVRO_SOURCE_URIS = [ "gs://my_bucket/my_file_1.avro", ] +_TEST_GCS_CSV_SOURCE_URI = "gs://my_bucket/my_file_1.csv" _TEST_GCS_CSV_SOURCE_URIS = [ "gs://my_bucket/my_file_1.csv", ] @@ -221,6 +227,13 @@ _TEST_GCS_SOURCE_TYPE_AVRO = "avro" _TEST_GCS_SOURCE_TYPE_INVALID = "json" +_TEST_BQ_DESTINATION_URI = "bq://project.dataset.table_name" +_TEST_GCS_OUTPUT_URI_PREFIX = "gs://my_bucket/path/to_prefix" + +_TEST_GCS_DESTINATION_TYPE_CSV = "csv" +_TEST_GCS_DESTINATION_TYPE_TFRECORD = "tfrecord" +_TEST_GCS_DESTINATION_TYPE_INVALID = "json" + _TEST_BQ_SOURCE = gca_io.BigQuerySource(input_uri=_TEST_BQ_SOURCE_URI) _TEST_AVRO_SOURCE = gca_io.AvroSource( gcs_source=gca_io.GcsSource(uris=_TEST_GCS_AVRO_SOURCE_URIS) @@ -229,6 +242,14 @@ gcs_source=gca_io.GcsSource(uris=_TEST_GCS_CSV_SOURCE_URIS) ) +_TEST_BQ_DESTINATION = gca_io.BigQueryDestination(output_uri=_TEST_BQ_DESTINATION_URI) +_TEST_CSV_DESTINATION = gca_io.CsvDestination( + gcs_destination=gca_io.GcsDestination(output_uri_prefix=_TEST_GCS_OUTPUT_URI_PREFIX) +) +_TEST_TFRECORD_DESTINATION = 
gca_io.TFRecordDestination( + gcs_destination=gca_io.GcsDestination(output_uri_prefix=_TEST_GCS_OUTPUT_URI_PREFIX) +) + _TEST_READ_ENTITY_ID = "entity_id_1" _TEST_READ_ENTITY_IDS = ["entity_id_1"] @@ -243,6 +264,26 @@ ) +def _get_entity_type_spec_proto_with_feature_ids( + entity_type_id, feature_ids, feature_destination_fields=None +): + feature_destination_fields = feature_destination_fields or {} + entity_type_spec_proto = gca_featurestore_service.BatchReadFeatureValuesRequest.EntityTypeSpec( + entity_type_id=entity_type_id, + feature_selector=gca_feature_selector.FeatureSelector( + id_matcher=gca_feature_selector.IdMatcher(ids=feature_ids) + ), + settings=[ + gca_featurestore_service.DestinationFeatureSetting( + feature_id=feature_id, destination_field=feature_destination_field + ) + for feature_id, feature_destination_field in feature_destination_fields.items() + ] + or None, + ) + return entity_type_spec_proto + + def _get_header_proto(feature_ids): header_proto = copy.deepcopy(_TEST_BASE_HEADER_PROTO) header_proto.feature_descriptors = [ @@ -351,6 +392,17 @@ def create_featurestore_mock(): yield create_featurestore_mock +@pytest.fixture +def batch_read_feature_values_mock(): + with patch.object( + featurestore_service_client.FeaturestoreServiceClient, + "batch_read_feature_values", + ) as batch_read_feature_values_mock: + batch_read_feature_values_lro_mock = mock.Mock(operation.Operation) + batch_read_feature_values_mock.return_value = batch_read_feature_values_lro_mock + yield batch_read_feature_values_mock + + # ALL EntityType Mocks @pytest.fixture def get_entity_type_mock(): @@ -875,6 +927,288 @@ def test_create_featurestore(self, create_featurestore_mock, sync): metadata=_TEST_REQUEST_METADATA, ) + @pytest.mark.usefixtures("get_featurestore_mock") + @pytest.mark.parametrize( + "serving_feature_ids, feature_destination_fields, expected_entity_type_specs", + [ + ( + { + "my_entity_type_id_1": ["my_feature_id_1_1", "my_feature_id_1_2"], + 
"my_entity_type_id_2": ["my_feature_id_2_1", "my_feature_id_2_2"], + }, + None, + [ + _get_entity_type_spec_proto_with_feature_ids( + entity_type_id="my_entity_type_id_1", + feature_ids=["my_feature_id_1_1", "my_feature_id_1_2"], + ), + _get_entity_type_spec_proto_with_feature_ids( + entity_type_id="my_entity_type_id_2", + feature_ids=["my_feature_id_2_1", "my_feature_id_2_2"], + ), + ], + ), + ( + { + "my_entity_type_id_1": ["my_feature_id_1_1", "my_feature_id_1_2"], + "my_entity_type_id_2": ["my_feature_id_2_1", "my_feature_id_2_2"], + }, + { + f"{_TEST_FEATURESTORE_NAME}/entityTypes/my_entity_type_id_1/features/my_feature_id_1_1": "my_feature_id_1_1_dest", + f"{_TEST_FEATURESTORE_NAME}/entityTypes/my_entity_type_id_1/features/my_feature_id_1_2": "my_feature_id_1_2_dest", + }, + [ + _get_entity_type_spec_proto_with_feature_ids( + entity_type_id="my_entity_type_id_1", + feature_ids=["my_feature_id_1_1", "my_feature_id_1_2"], + feature_destination_fields={ + "my_feature_id_1_1": "my_feature_id_1_1_dest", + "my_feature_id_1_2": "my_feature_id_1_2_dest", + }, + ), + _get_entity_type_spec_proto_with_feature_ids( + entity_type_id="my_entity_type_id_2", + feature_ids=["my_feature_id_2_1", "my_feature_id_2_2"], + ), + ], + ), + ( + { + "my_entity_type_id_1": ["my_feature_id_1_1", "my_feature_id_1_2"], + "my_entity_type_id_2": ["my_feature_id_2_1", "my_feature_id_2_2"], + }, + { + f"{_TEST_FEATURESTORE_NAME}/entityTypes/my_entity_type_id_1/features/my_feature_id_1_1": "my_feature_id_1_1_dest", + f"{_TEST_FEATURESTORE_NAME}/entityTypes/my_entity_type_id_2/features/my_feature_id_2_1": "my_feature_id_2_1_dest", + }, + [ + _get_entity_type_spec_proto_with_feature_ids( + entity_type_id="my_entity_type_id_1", + feature_ids=["my_feature_id_1_1", "my_feature_id_1_2"], + feature_destination_fields={ + "my_feature_id_1_1": "my_feature_id_1_1_dest" + }, + ), + _get_entity_type_spec_proto_with_feature_ids( + entity_type_id="my_entity_type_id_2", + feature_ids=["my_feature_id_2_1", 
"my_feature_id_2_2"], + feature_destination_fields={ + "my_feature_id_2_1": "my_feature_id_2_1_dest" + }, + ), + ], + ), + ], + ) + def test_validate_and_get_batch_read_feature_values_request( + self, + serving_feature_ids, + feature_destination_fields, + expected_entity_type_specs, + ): + + aiplatform.init(project=_TEST_PROJECT) + my_featurestore = aiplatform.Featurestore( + featurestore_name=_TEST_FEATURESTORE_NAME + ) + expected_batch_read_feature_values_request = gca_featurestore_service.BatchReadFeatureValuesRequest( + featurestore=my_featurestore.resource_name, + destination=gca_featurestore_service.FeatureValueDestination( + bigquery_destination=_TEST_BQ_DESTINATION, + ), + entity_type_specs=expected_entity_type_specs, + ) + assert ( + expected_batch_read_feature_values_request + == my_featurestore._validate_and_get_batch_read_feature_values_request( + serving_feature_ids=serving_feature_ids, + destination=_TEST_BQ_DESTINATION, + feature_destination_fields=feature_destination_fields, + ) + ) + + @pytest.mark.usefixtures("get_featurestore_mock") + def test_validate_and_get_batch_read_feature_values_request_with_read_instances( + self, + ): + aiplatform.init(project=_TEST_PROJECT) + my_featurestore = aiplatform.Featurestore( + featurestore_name=_TEST_FEATURESTORE_NAME + ) + expected_entity_type_specs = [ + _get_entity_type_spec_proto_with_feature_ids( + entity_type_id="my_entity_type_id_1", + feature_ids=["my_feature_id_1_1", "my_feature_id_1_2"], + ), + _get_entity_type_spec_proto_with_feature_ids( + entity_type_id="my_entity_type_id_2", + feature_ids=["my_feature_id_2_1", "my_feature_id_2_2"], + ), + ] + expected_batch_read_feature_values_request = gca_featurestore_service.BatchReadFeatureValuesRequest( + featurestore=my_featurestore.resource_name, + destination=gca_featurestore_service.FeatureValueDestination( + bigquery_destination=_TEST_BQ_DESTINATION, + ), + entity_type_specs=expected_entity_type_specs, + bigquery_read_instances=_TEST_BQ_SOURCE, + ) + 
assert ( + expected_batch_read_feature_values_request + == my_featurestore._validate_and_get_batch_read_feature_values_request( + serving_feature_ids=_TEST_SERVING_FEATURE_IDS, + destination=_TEST_BQ_DESTINATION, + read_instances=_TEST_BQ_SOURCE, + ) + ) + + @pytest.mark.usefixtures("get_featurestore_mock") + @pytest.mark.parametrize( + "read_instances, expected", + [ + (_TEST_BQ_SOURCE_URI, _TEST_BQ_SOURCE), + (_TEST_GCS_CSV_SOURCE_URIS, _TEST_CSV_SOURCE), + (_TEST_GCS_CSV_SOURCE_URI, _TEST_CSV_SOURCE), + ], + ) + def test_get_read_instances(self, read_instances, expected): + aiplatform.init(project=_TEST_PROJECT) + my_featurestore = aiplatform.Featurestore( + featurestore_name=_TEST_FEATURESTORE_NAME + ) + assert expected == my_featurestore._get_read_instances( + read_instances=read_instances + ) + + @pytest.mark.usefixtures("get_featurestore_mock") + @pytest.mark.parametrize( + "read_instances", + [[1, 2, 3, 4, 5], 1, (_TEST_GCS_CSV_SOURCE_URI, _TEST_GCS_CSV_SOURCE_URI)], + ) + def test_get_read_instances_with_raise_typeerror(self, read_instances): + aiplatform.init(project=_TEST_PROJECT) + my_featurestore = aiplatform.Featurestore( + featurestore_name=_TEST_FEATURESTORE_NAME + ) + with pytest.raises(TypeError): + my_featurestore._get_read_instances(read_instances=read_instances) + + @pytest.mark.usefixtures("get_featurestore_mock") + @pytest.mark.parametrize( + "read_instances", + [ + "gcs://my_bucket/my_file_1.csv", + "bigquery://my_bucket/my_file_1.csv", + "my_bucket/my_file_1.csv", + [_TEST_BQ_SOURCE_URI], + ], + ) + def test_get_read_instances_with_raise_valueerror(self, read_instances): + aiplatform.init(project=_TEST_PROJECT) + my_featurestore = aiplatform.Featurestore( + featurestore_name=_TEST_FEATURESTORE_NAME + ) + with pytest.raises(ValueError): + my_featurestore._get_read_instances(read_instances=read_instances) + + @pytest.mark.parametrize("sync", [True, False]) + @pytest.mark.usefixtures("get_featurestore_mock") + def test_batch_serve_to_bq(self, 
batch_read_feature_values_mock, sync): + aiplatform.init(project=_TEST_PROJECT) + my_featurestore = aiplatform.Featurestore( + featurestore_name=_TEST_FEATURESTORE_NAME + ) + + expected_entity_type_specs = [ + _get_entity_type_spec_proto_with_feature_ids( + entity_type_id="my_entity_type_id_1", + feature_ids=["my_feature_id_1_1", "my_feature_id_1_2"], + ), + _get_entity_type_spec_proto_with_feature_ids( + entity_type_id="my_entity_type_id_2", + feature_ids=["my_feature_id_2_1", "my_feature_id_2_2"], + ), + ] + + expected_batch_read_feature_values_request = gca_featurestore_service.BatchReadFeatureValuesRequest( + featurestore=my_featurestore.resource_name, + destination=gca_featurestore_service.FeatureValueDestination( + bigquery_destination=_TEST_BQ_DESTINATION, + ), + entity_type_specs=expected_entity_type_specs, + ) + + my_featurestore.batch_serve_to_bq( + bq_destination_output_uri=_TEST_BQ_DESTINATION_URI, + serving_feature_ids=_TEST_SERVING_FEATURE_IDS, + sync=sync, + ) + + if not sync: + my_featurestore.wait() + + batch_read_feature_values_mock.assert_called_once_with( + request=expected_batch_read_feature_values_request, + metadata=_TEST_REQUEST_METADATA, + ) + + @pytest.mark.parametrize("sync", [True, False]) + @pytest.mark.usefixtures("get_featurestore_mock") + def test_batch_serve_to_gcs(self, batch_read_feature_values_mock, sync): + aiplatform.init(project=_TEST_PROJECT) + my_featurestore = aiplatform.Featurestore( + featurestore_name=_TEST_FEATURESTORE_NAME + ) + + expected_entity_type_specs = [ + _get_entity_type_spec_proto_with_feature_ids( + entity_type_id="my_entity_type_id_1", + feature_ids=["my_feature_id_1_1", "my_feature_id_1_2"], + ), + _get_entity_type_spec_proto_with_feature_ids( + entity_type_id="my_entity_type_id_2", + feature_ids=["my_feature_id_2_1", "my_feature_id_2_2"], + ), + ] + + expected_batch_read_feature_values_request = gca_featurestore_service.BatchReadFeatureValuesRequest( + featurestore=my_featurestore.resource_name, + 
destination=gca_featurestore_service.FeatureValueDestination( + tfrecord_destination=_TEST_TFRECORD_DESTINATION, + ), + entity_type_specs=expected_entity_type_specs, + ) + + my_featurestore.batch_serve_to_gcs( + gcs_destination_output_uri_prefix=_TEST_GCS_OUTPUT_URI_PREFIX, + gcs_destination_type=_TEST_GCS_DESTINATION_TYPE_TFRECORD, + serving_feature_ids=_TEST_SERVING_FEATURE_IDS, + sync=sync, + ) + + if not sync: + my_featurestore.wait() + + batch_read_feature_values_mock.assert_called_once_with( + request=expected_batch_read_feature_values_request, + metadata=_TEST_REQUEST_METADATA, + ) + + @pytest.mark.usefixtures("get_featurestore_mock") + def test_batch_serve_to_gcs_with_invalid_gcs_destination_type(self): + + aiplatform.init(project=_TEST_PROJECT) + + my_featurestore = aiplatform.Featurestore( + featurestore_name=_TEST_FEATURESTORE_NAME + ) + with pytest.raises(ValueError): + my_featurestore.batch_serve_to_gcs( + gcs_destination_output_uri_prefix=_TEST_GCS_OUTPUT_URI_PREFIX, + gcs_destination_type=_TEST_GCS_DESTINATION_TYPE_INVALID, + serving_feature_ids=_TEST_SERVING_FEATURE_IDS, + ) + class TestEntityType: def setup_method(self): diff --git a/tests/unit/aiplatform/test_utils.py b/tests/unit/aiplatform/test_utils.py index d4840609b10..b47eb684d8b 100644 --- a/tests/unit/aiplatform/test_utils.py +++ b/tests/unit/aiplatform/test_utils.py @@ -19,7 +19,6 @@ import pytest from typing import Callable, Dict, Optional import datetime -from decimal import Decimal from google.protobuf import timestamp_pb2 @@ -326,28 +325,8 @@ def test_client_w_override_select_version(): @pytest.mark.parametrize( "year,month,day,hour,minute,second,microsecond,expected_seconds,expected_nanos", [ - ( - 2021, - 12, - 23, - 23, - 59, - 59, - 999999, - 1640303999, - int(str(Decimal(1640303999.999999)).split(".")[1][:9]), - ), - ( - 2013, - 1, - 1, - 1, - 1, - 1, - 199999, - 1357002061, - int(str(Decimal(1357002061.199999)).split(".")[1][:9]), - ), + (2021, 12, 23, 23, 59, 59, 999999, 
1640303999, 999000000,), + (2013, 1, 1, 1, 1, 1, 199999, 1357002061, 199000000,), ], ) def test_get_timestamp_proto( @@ -369,7 +348,6 @@ def test_get_timestamp_proto( minute=minute, second=second, microsecond=microsecond, - tzinfo=datetime.timezone.utc, ) true_timestamp_proto = timestamp_pb2.Timestamp( seconds=expected_seconds, nanos=expected_nanos From 9289f2d3ce424f3f9754a3dd23883e25dec1300f Mon Sep 17 00:00:00 2001 From: Morgan Du Date: Thu, 27 Jan 2022 07:56:05 -0800 Subject: [PATCH 3/6] feat: enable ingest from pd.DataFrame (#977) * feat: enable ingest from pd.DataFrame * fix: remove bq create_dataset, docstrings, mocks * fix: e2e_base project * fix: delete two optional args, add note for temp bq dataset, revert deleting bq dataset create, add featurestore_extra_require, update ic tests to use online read to validate feature value ingestionfrom df * fix: add a comment of call complete upon ingestion, update unit tests --- .../aiplatform/featurestore/entity_type.py | 192 ++++++++++++++---- setup.py | 2 + tests/system/aiplatform/test_featurestore.py | 174 ++++++++++++++-- tests/unit/aiplatform/test_featurestores.py | 161 ++++++++++++++- 4 files changed, 471 insertions(+), 58 deletions(-) diff --git a/google/cloud/aiplatform/featurestore/entity_type.py b/google/cloud/aiplatform/featurestore/entity_type.py index 6e993f26b5b..8a85b1aa7ad 100644 --- a/google/cloud/aiplatform/featurestore/entity_type.py +++ b/google/cloud/aiplatform/featurestore/entity_type.py @@ -17,6 +17,7 @@ import datetime from typing import Dict, List, Optional, Sequence, Tuple, Union +import uuid from google.auth import credentials as auth_credentials from google.protobuf import field_mask_pb2 @@ -34,6 +35,7 @@ from google.cloud.aiplatform import utils from google.cloud.aiplatform.utils import featurestore_utils +from google.cloud import bigquery _LOGGER = base.Logger(__name__) _ALL_FEATURE_IDS = "*" @@ -795,23 +797,16 @@ def _validate_and_get_import_feature_values_request( If not provided, the 
source column need to be the same as the Feature ID. Example: + feature_ids = ['my_feature_id_1', 'my_feature_id_2', 'my_feature_id_3'] - feature_ids = ['my_feature_id_1', 'my_feature_id_2', 'my_feature_id_3'] - - In case all features' source field and ID match: - feature_source_fields = None or {} - - In case all features' source field and ID do not match: - feature_source_fields = { + feature_source_fields = { 'my_feature_id_1': 'my_feature_id_1_source_field', - 'my_feature_id_2': 'my_feature_id_2_source_field', - 'my_feature_id_3': 'my_feature_id_3_source_field', - } + } + + Note: + The source column of 'my_feature_id_1' is 'my_feature_id_1_source_field', + The source column of 'my_feature_id_2' is the ID of the feature, same for 'my_feature_id_3'. - In case some features' source field and ID do not match: - feature_source_fields = { - 'my_feature_id_1': 'my_feature_id_1_source_field', - } entity_id_field (str): Optional. Source column that holds entity IDs. If not provided, entity IDs are extracted from the column named ``entity_id``. @@ -954,23 +949,16 @@ def ingest_from_bq( If not provided, the source column need to be the same as the Feature ID. Example: + feature_ids = ['my_feature_id_1', 'my_feature_id_2', 'my_feature_id_3'] - feature_ids = ['my_feature_id_1', 'my_feature_id_2', 'my_feature_id_3'] - - In case all features' source field and ID match: - feature_source_fields = None or {} - - In case all features' source field and ID do not match: - feature_source_fields = { + feature_source_fields = { 'my_feature_id_1': 'my_feature_id_1_source_field', - 'my_feature_id_2': 'my_feature_id_2_source_field', - 'my_feature_id_3': 'my_feature_id_3_source_field', - } + } + + Note: + The source column of 'my_feature_id_1' is 'my_feature_id_1_source_field', + The source column of 'my_feature_id_2' is the ID of the feature, same for 'my_feature_id_3'. 
- In case some features' source field and ID do not match: - feature_source_fields = { - 'my_feature_id_1': 'my_feature_id_1_source_field', - } entity_id_field (str): Optional. Source column that holds entity IDs. If not provided, entity IDs are extracted from the column named ``entity_id``. @@ -1000,6 +988,7 @@ def ingest_from_bq( EntityType - The entityType resource object with feature values imported. """ + bigquery_source = gca_io.BigQuerySource(input_uri=bq_source_uri) import_feature_values_request = self._validate_and_get_import_feature_values_request( @@ -1065,23 +1054,16 @@ def ingest_from_gcs( If not provided, the source column need to be the same as the Feature ID. Example: + feature_ids = ['my_feature_id_1', 'my_feature_id_2', 'my_feature_id_3'] - feature_ids = ['my_feature_id_1', 'my_feature_id_2', 'my_feature_id_3'] - - In case all features' source field and ID match: - feature_source_fields = None or {} - - In case all features' source field and ID do not match: - feature_source_fields = { + feature_source_fields = { 'my_feature_id_1': 'my_feature_id_1_source_field', - 'my_feature_id_2': 'my_feature_id_2_source_field', - 'my_feature_id_3': 'my_feature_id_3_source_field', - } + } + + Note: + The source column of 'my_feature_id_1' is 'my_feature_id_1_source_field', + The source column of 'my_feature_id_2' is the ID of the feature, same for 'my_feature_id_3'. - In case some features' source field and ID do not match: - feature_source_fields = { - 'my_feature_id_1': 'my_feature_id_1_source_field', - } entity_id_field (str): Optional. Source column that holds entity IDs. If not provided, entity IDs are extracted from the column named ``entity_id``. 
@@ -1146,6 +1128,132 @@ def ingest_from_gcs( request_metadata=request_metadata, ) + def ingest_from_df( + self, + feature_ids: List[str], + feature_time: Union[str, datetime.datetime], + df_source: "pd.DataFrame", # noqa: F821 - skip check for undefined name 'pd' + feature_source_fields: Optional[Dict[str, str]] = None, + entity_id_field: Optional[str] = None, + request_metadata: Optional[Sequence[Tuple[str, str]]] = (), + ) -> "EntityType": + """Ingest feature values from DataFrame. + + Note: + Calling this method will automatically create and delete a temporary + bigquery dataset in the same GCP project, which will be used + as the intermediary storage for ingesting feature values + from dataframe to featurestore. + + The call will return upon ingestion completes, where the + feature values will be ingested into the entity_type. + + Args: + feature_ids (List[str]): + Required. IDs of the Feature to import values + of. The Features must exist in the target + EntityType, or the request will fail. + feature_time (Union[str, datetime.datetime]): + Required. The feature_time can be one of: + - The source column that holds the Feature + timestamp for all Feature values in each entity. + + Note: + The dtype of the source column should be `datetime64`. + + - A single Feature timestamp for all entities + being imported. The timestamp must not have + higher than millisecond precision. + + Example: + feature_time = datetime.datetime(year=2022, month=1, day=1, hour=11, minute=59, second=59) + or + feature_time_str = datetime.datetime.now().isoformat(sep=" ", timespec="milliseconds") + feature_time = datetime.datetime.strptime(feature_time_str, "%Y-%m-%d %H:%M:%S.%f") + + df_source (pd.DataFrame): + Required. Pandas DataFrame containing the source data for ingestion. + feature_source_fields (Dict[str, str]): + Optional. User defined dictionary to map ID of the Feature for importing values + of to the source column for getting the Feature values from. 
+ + Specify the features whose ID and source column are not the same. + If not provided, the source column need to be the same as the Feature ID. + + Example: + feature_ids = ['my_feature_id_1', 'my_feature_id_2', 'my_feature_id_3'] + + feature_source_fields = { + 'my_feature_id_1': 'my_feature_id_1_source_field', + } + + Note: + The source column of 'my_feature_id_1' is 'my_feature_id_1_source_field', + The source column of 'my_feature_id_2' is the ID of the feature, same for 'my_feature_id_3'. + + entity_id_field (str): + Optional. Source column that holds entity IDs. If not provided, entity + IDs are extracted from the column named ``entity_id``. + request_metadata (Sequence[Tuple[str, str]]): + Optional. Strings which should be sent along with the request as metadata. + + Returns: + EntityType - The entityType resource object with feature values imported. + + """ + try: + import pyarrow # noqa: F401 - skip check for 'pyarrow' which is required when using 'google.cloud.bigquery' + except ImportError: + raise ImportError( + f"Pyarrow is not installed. 
Please install pyarrow to use " + f"{self.ingest_from_df.__name__}" + ) + + bigquery_client = bigquery.Client( + project=self.project, credentials=self.credentials + ) + + entity_type_name_components = self._parse_resource_name(self.resource_name) + featurestore_id, entity_type_id = ( + entity_type_name_components["featurestore"], + entity_type_name_components["entity_type"], + ) + + temp_bq_dataset_name = f"temp_{featurestore_id}_{uuid.uuid4()}".replace( + "-", "_" + ) + temp_bq_dataset_id = f"{initializer.global_config.project}.{temp_bq_dataset_name}"[ + :1024 + ] + temp_bq_table_id = f"{temp_bq_dataset_id}.{entity_type_id}" + + temp_bq_dataset = bigquery.Dataset(dataset_ref=temp_bq_dataset_id) + temp_bq_dataset.location = self.location + + temp_bq_dataset = bigquery_client.create_dataset(temp_bq_dataset) + + try: + job = bigquery_client.load_table_from_dataframe( + dataframe=df_source, destination=temp_bq_table_id + ) + job.result() + + entity_type_obj = self.ingest_from_bq( + feature_ids=feature_ids, + feature_time=feature_time, + bq_source_uri=f"bq://{temp_bq_table_id}", + feature_source_fields=feature_source_fields, + entity_id_field=entity_id_field, + request_metadata=request_metadata, + ) + + finally: + bigquery_client.delete_dataset( + dataset=temp_bq_dataset.dataset_id, delete_contents=True, + ) + + return entity_type_obj + @staticmethod def _instantiate_featurestore_online_client( location: Optional[str] = None, diff --git a/setup.py b/setup.py index 2b8c58f033e..011da8aea8e 100644 --- a/setup.py +++ b/setup.py @@ -47,6 +47,7 @@ "werkzeug >= 2.0.0", "tensorflow >=2.4.0", ] +featurestore_extra_require = ["pandas >= 1.0.0", "pyarrow >= 6.0.1"] full_extra_require = list( set( @@ -54,6 +55,7 @@ + metadata_extra_require + xai_extra_require + lit_extra_require + + featurestore_extra_require ) ) testing_extra_require = ( diff --git a/tests/system/aiplatform/test_featurestore.py b/tests/system/aiplatform/test_featurestore.py index b67dec6883f..03070eee3c9 100644 
--- a/tests/system/aiplatform/test_featurestore.py +++ b/tests/system/aiplatform/test_featurestore.py @@ -15,6 +15,7 @@ # limitations under the License. # +import datetime import logging import pytest @@ -201,22 +202,17 @@ def test_ingest_feature_values(self, shared_state, caplog): gcs_source_uris=_TEST_USERS_ENTITY_TYPE_GCS_SRC, gcs_source_type="avro", entity_id_field="user_id", - worker_count=2, + worker_count=1, ) assert "EntityType feature values imported." in caplog.text caplog.clear() - def test_batch_create_features_and_ingest_feature_values( - self, shared_state, caplog - ): - + def test_batch_create_features(self, shared_state): assert shared_state["movie_entity_type"] movie_entity_type = shared_state["movie_entity_type"] - caplog.set_level(logging.INFO) - aiplatform.init( project=e2e_base._PROJECT, location=e2e_base._LOCATION, ) @@ -232,21 +228,171 @@ def test_batch_create_features_and_ingest_feature_values( movie_entity_type.batch_create_features(feature_configs=movie_feature_configs) - movie_entity_type.ingest_from_gcs( + list_movie_features = movie_entity_type.list_features() + assert len(list_movie_features) == 3 + + def test_ingest_feature_values_from_df_using_feature_time_column_and_online_read_multiple_entities( + self, shared_state, caplog + ): + + assert shared_state["movie_entity_type"] + movie_entity_type = shared_state["movie_entity_type"] + + caplog.set_level(logging.INFO) + + aiplatform.init( + project=e2e_base._PROJECT, location=e2e_base._LOCATION, + ) + + read_feature_ids = ["average_rating", "title", "genres"] + + movie_entity_views_df_before_ingest = movie_entity_type.read( + entity_ids=["movie_01", "movie_02"], feature_ids=read_feature_ids, + ) + expected_data_before_ingest = [ + { + "entity_id": "movie_01", + "average_rating": None, + "title": None, + "genres": None, + }, + { + "entity_id": "movie_02", + "average_rating": None, + "title": None, + "genres": None, + }, + ] + expected_movie_entity_views_df_before_ingest = pd.DataFrame( + 
data=expected_data_before_ingest, columns=read_feature_ids + ) + + movie_entity_views_df_before_ingest.equals( + expected_movie_entity_views_df_before_ingest + ) + + movies_df = pd.DataFrame( + data=[ + { + "movie_id": "movie_01", + "average_rating": 4.9, + "title": "The Shawshank Redemption", + "genres": "Drama", + "update_time": "2021-08-20 20:44:11.094375+00:00", + }, + { + "movie_id": "movie_02", + "average_rating": 4.2, + "title": "The Shining", + "genres": "Horror", + "update_time": "2021-08-20 20:44:11.094375+00:00", + }, + ], + columns=["movie_id", "average_rating", "title", "genres", "update_time"], + ) + movies_df = movies_df.astype({"update_time": "datetime64"}) + feature_time_column = "update_time" + + movie_entity_type.ingest_from_df( feature_ids=[ _TEST_MOVIE_TITLE_FEATURE_ID, _TEST_MOVIE_GENRES_FEATURE_ID, _TEST_MOVIE_AVERAGE_RATING_FEATURE_ID, ], - feature_time="update_time", - gcs_source_uris=_TEST_MOVIES_ENTITY_TYPE_GCS_SRC, - gcs_source_type="avro", + feature_time=feature_time_column, + df_source=movies_df, entity_id_field="movie_id", - worker_count=2, ) - list_movie_features = movie_entity_type.list_features() - assert len(list_movie_features) == 3 + movie_entity_views_df_after_ingest = movie_entity_type.read( + entity_ids=["movie_01", "movie_02"], feature_ids=read_feature_ids, + ) + expected_data_after_ingest = [ + { + "movie_id": "movie_01", + "average_rating": 4.9, + "title": "The Shawshank Redemption", + "genres": "Drama", + }, + { + "movie_id": "movie_02", + "average_rating": 4.2, + "title": "The Shining", + "genres": "Horror", + }, + ] + expected_movie_entity_views_df_after_ingest = pd.DataFrame( + data=expected_data_after_ingest, columns=read_feature_ids + ) + + movie_entity_views_df_after_ingest.equals( + expected_movie_entity_views_df_after_ingest + ) + + assert "EntityType feature values imported." 
in caplog.text + caplog.clear() + + def test_ingest_feature_values_from_df_using_feature_time_datetime_and_online_read_single_entity( + self, shared_state, caplog + ): + assert shared_state["movie_entity_type"] + movie_entity_type = shared_state["movie_entity_type"] + + caplog.set_level(logging.INFO) + + aiplatform.init( + project=e2e_base._PROJECT, location=e2e_base._LOCATION, + ) + + movies_df = pd.DataFrame( + data=[ + { + "movie_id": "movie_03", + "average_rating": 4.5, + "title": "Cinema Paradiso", + "genres": "Romance", + }, + { + "movie_id": "movie_04", + "average_rating": 4.6, + "title": "The Dark Knight", + "genres": "Action", + }, + ], + columns=["movie_id", "average_rating", "title", "genres"], + ) + + feature_time_datetime_str = datetime.datetime.now().isoformat( + sep=" ", timespec="milliseconds" + ) + feature_time_datetime = datetime.datetime.strptime( + feature_time_datetime_str, "%Y-%m-%d %H:%M:%S.%f" + ) + + movie_entity_type.ingest_from_df( + feature_ids=[ + _TEST_MOVIE_TITLE_FEATURE_ID, + _TEST_MOVIE_GENRES_FEATURE_ID, + _TEST_MOVIE_AVERAGE_RATING_FEATURE_ID, + ], + feature_time=feature_time_datetime, + df_source=movies_df, + entity_id_field="movie_id", + ) + + movie_entity_views_df_avg_rating = movie_entity_type.read( + entity_ids="movie_04", feature_ids="average_rating", + ) + expected_data_avg_rating = [ + {"movie_id": "movie_04", "average_rating": 4.6}, + ] + expected_movie_entity_views_df_avg_rating = pd.DataFrame( + data=expected_data_avg_rating, columns=["average_rating"] + ) + + movie_entity_views_df_avg_rating.equals( + expected_movie_entity_views_df_avg_rating + ) assert "EntityType feature values imported." 
in caplog.text diff --git a/tests/unit/aiplatform/test_featurestores.py b/tests/unit/aiplatform/test_featurestores.py index 97cec0056f0..449d0348c31 100644 --- a/tests/unit/aiplatform/test_featurestores.py +++ b/tests/unit/aiplatform/test_featurestores.py @@ -19,13 +19,14 @@ import pytest import datetime import pandas as pd +import uuid from unittest import mock from importlib import reload -from unittest.mock import patch +from unittest.mock import MagicMock, patch from google.api_core import operation -from google.protobuf import field_mask_pb2 +from google.protobuf import field_mask_pb2, timestamp_pb2 from google.cloud import aiplatform from google.cloud.aiplatform import base @@ -51,11 +52,17 @@ types as gca_types, ) +from google.cloud import bigquery + # project _TEST_PROJECT = "test-project" _TEST_LOCATION = "us-central1" _TEST_PARENT = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}" +_TEST_FEATURE_TIME_DATETIME = datetime.datetime( + year=2022, month=1, day=1, hour=11, minute=59, second=59 +) + # featurestore _TEST_FEATURESTORE_ID = "featurestore_id" _TEST_FEATURESTORE_NAME = f"{_TEST_PARENT}/featurestores/{_TEST_FEATURESTORE_ID}" @@ -321,6 +328,57 @@ def _get_entity_view_proto(entity_id, feature_value_types, feature_values): return entity_view_proto +def uuid_mock(): + return uuid.UUID(int=1) + + +# All BigQuery Mocks +@pytest.fixture +def bq_client_mock(): + mock = MagicMock(bigquery.client.Client) + yield mock + + +@pytest.fixture +def bq_dataset_mock(): + mock = MagicMock(bigquery.dataset.Dataset) + yield mock + + +@pytest.fixture +def bq_init_client_mock(bq_client_mock): + with patch.object(bigquery, "Client") as bq_init_client_mock: + bq_init_client_mock.return_value = bq_client_mock + yield bq_init_client_mock + + +@pytest.fixture +def bq_init_dataset_mock(bq_dataset_mock): + with patch.object(bigquery, "Dataset") as bq_init_dataset_mock: + bq_init_dataset_mock.return_value = bq_dataset_mock + yield bq_init_dataset_mock + + +@pytest.fixture 
+def bq_create_dataset_mock(bq_init_client_mock): + with patch.object(bigquery.Client, "create_dataset") as bq_create_dataset_mock: + yield bq_create_dataset_mock + + +@pytest.fixture +def bq_load_table_from_dataframe_mock(bq_init_client_mock): + with patch.object( + bigquery.Client, "load_table_from_dataframe" + ) as bq_load_table_from_dataframe_mock: + yield bq_load_table_from_dataframe_mock + + +@pytest.fixture +def bq_delete_dataset_mock(bq_init_client_mock): + with patch.object(bigquery.Client, "delete_dataset") as bq_delete_dataset_mock: + yield bq_delete_dataset_mock + + # All Featurestore Mocks @pytest.fixture def get_featurestore_mock(): @@ -1552,6 +1610,105 @@ def test_ingest_from_gcs_with_invalid_gcs_source_type(self): gcs_source_type=_TEST_GCS_SOURCE_TYPE_INVALID, ) + @pytest.mark.usefixtures( + "get_entity_type_mock", + "bq_init_client_mock", + "bq_init_dataset_mock", + "bq_create_dataset_mock", + "bq_load_table_from_dataframe_mock", + "bq_delete_dataset_mock", + ) + @patch("uuid.uuid4", uuid_mock) + def test_ingest_from_df_using_column(self, import_feature_values_mock): + + aiplatform.init(project=_TEST_PROJECT) + + my_entity_type = aiplatform.EntityType(entity_type_name=_TEST_ENTITY_TYPE_NAME) + df_source = pd.DataFrame() + my_entity_type.ingest_from_df( + feature_ids=_TEST_IMPORTING_FEATURE_IDS, + feature_time=_TEST_FEATURE_TIME_FIELD, + df_source=df_source, + feature_source_fields=_TEST_IMPORTING_FEATURE_SOURCE_FIELDS, + ) + expected_temp_bq_dataset_name = f"temp_{_TEST_FEATURESTORE_ID}_{uuid.uuid4()}".replace( + "-", "_" + ) + expected_temp_bq_dataset_id = f"{initializer.global_config.project}.{expected_temp_bq_dataset_name}"[ + :1024 + ] + expected_temp_bq_table_id = ( + f"{expected_temp_bq_dataset_id}.{_TEST_ENTITY_TYPE_ID}" + ) + + true_import_feature_values_request = gca_featurestore_service.ImportFeatureValuesRequest( + entity_type=_TEST_ENTITY_TYPE_NAME, + feature_specs=[ + gca_featurestore_service.ImportFeatureValuesRequest.FeatureSpec( + 
id="my_feature_id_1", source_field="my_feature_id_1_source_field" + ), + ], + bigquery_source=gca_io.BigQuerySource( + input_uri=f"bq://{expected_temp_bq_table_id}" + ), + feature_time_field=_TEST_FEATURE_TIME_FIELD, + ) + + import_feature_values_mock.assert_called_once_with( + request=true_import_feature_values_request, metadata=_TEST_REQUEST_METADATA, + ) + + @pytest.mark.usefixtures( + "get_entity_type_mock", + "bq_init_client_mock", + "bq_init_dataset_mock", + "bq_create_dataset_mock", + "bq_load_table_from_dataframe_mock", + "bq_delete_dataset_mock", + ) + @patch("uuid.uuid4", uuid_mock) + def test_ingest_from_df_using_datetime(self, import_feature_values_mock): + aiplatform.init(project=_TEST_PROJECT) + + my_entity_type = aiplatform.EntityType(entity_type_name=_TEST_ENTITY_TYPE_NAME) + df_source = pd.DataFrame() + my_entity_type.ingest_from_df( + feature_ids=_TEST_IMPORTING_FEATURE_IDS, + feature_time=_TEST_FEATURE_TIME_DATETIME, + df_source=df_source, + feature_source_fields=_TEST_IMPORTING_FEATURE_SOURCE_FIELDS, + ) + + expected_temp_bq_dataset_name = f"temp_{_TEST_FEATURESTORE_ID}_{uuid.uuid4()}".replace( + "-", "_" + ) + expected_temp_bq_dataset_id = f"{initializer.global_config.project}.{expected_temp_bq_dataset_name}"[ + :1024 + ] + expected_temp_bq_table_id = ( + f"{expected_temp_bq_dataset_id}.{_TEST_ENTITY_TYPE_ID}" + ) + + timestamp_proto = timestamp_pb2.Timestamp() + timestamp_proto.FromDatetime(_TEST_FEATURE_TIME_DATETIME) + + true_import_feature_values_request = gca_featurestore_service.ImportFeatureValuesRequest( + entity_type=_TEST_ENTITY_TYPE_NAME, + feature_specs=[ + gca_featurestore_service.ImportFeatureValuesRequest.FeatureSpec( + id="my_feature_id_1", source_field="my_feature_id_1_source_field" + ), + ], + bigquery_source=gca_io.BigQuerySource( + input_uri=f"bq://{expected_temp_bq_table_id}" + ), + feature_time=timestamp_proto, + ) + + import_feature_values_mock.assert_called_once_with( + request=true_import_feature_values_request, 
metadata=_TEST_REQUEST_METADATA, + ) + @pytest.mark.usefixtures("get_entity_type_mock", "get_feature_mock") def test_read_single_entity(self, read_feature_values_mock): aiplatform.init(project=_TEST_PROJECT) From 0ca374769f922fd427c5b6f58c9ce1ab40f18d18 Mon Sep 17 00:00:00 2001 From: Ivan Cheung Date: Thu, 27 Jan 2022 13:46:18 -0500 Subject: [PATCH 4/6] fix: Fixed integration test for model.upload (#975) Fixed bug in integration test introduced by https://github.com/googleapis/python-aiplatform/pull/952 --- tests/system/aiplatform/test_model_upload.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/system/aiplatform/test_model_upload.py b/tests/system/aiplatform/test_model_upload.py index 625a3dc2c4b..70f78aec602 100644 --- a/tests/system/aiplatform/test_model_upload.py +++ b/tests/system/aiplatform/test_model_upload.py @@ -72,5 +72,5 @@ def test_upload_and_deploy_xgboost_model(self, shared_state): labels={"my_label": "updated"}, ) assert model.display_name == "new_name" - assert model.display_name == "new_description" + assert model.description == "new_description" assert model.labels == {"my_label": "updated"} From a8149233bcd857e75700c6ec7d29c0aabf1687c1 Mon Sep 17 00:00:00 2001 From: "gcf-owl-bot[bot]" <78513119+gcf-owl-bot[bot]@users.noreply.github.com> Date: Mon, 31 Jan 2022 14:02:23 -0800 Subject: [PATCH 5/6] feat: add dedicated_resources to DeployedIndex message in aiplatform v1 index_endpoint.proto chore: sort imports (#990) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat: add dedicated_resources to DeployedIndex message in aiplatform v1 index_endpoint.proto chore: sort imports PiperOrigin-RevId: 425394497 Source-Link: https://github.com/googleapis/googleapis/commit/bd97e467afc78c328426d2f06fa4d7cc2d0bfc51 Source-Link: https://github.com/googleapis/googleapis-gen/commit/13eed11051e1cf9f9ab43d174f23d35ffb32941c Copy-Tag: 
eyJwIjoiLmdpdGh1Yi8uT3dsQm90LnlhbWwiLCJoIjoiMTNlZWQxMTA1MWUxY2Y5ZjlhYjQzZDE3NGYyM2QzNWZmYjMyOTQxYyJ9 * 🦉 Updates from OwlBot See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md Co-authored-by: Owl Bot --- .../definition_v1/types/automl_tables.py | 16 +- .../definition_v1beta1/types/automl_tables.py | 16 +- .../services/dataset_service/async_client.py | 58 +- .../services/dataset_service/client.py | 147 +++-- .../dataset_service/transports/base.py | 1 - .../dataset_service/transports/grpc.py | 7 +- .../transports/grpc_asyncio.py | 7 +- .../services/endpoint_service/async_client.py | 52 +- .../services/endpoint_service/client.py | 141 +++-- .../endpoint_service/transports/base.py | 1 - .../endpoint_service/transports/grpc.py | 7 +- .../transports/grpc_asyncio.py | 7 +- .../async_client.py | 51 +- .../client.py | 131 ++-- .../transports/base.py | 1 - .../transports/grpc.py | 5 +- .../transports/grpc_asyncio.py | 5 +- .../featurestore_service/async_client.py | 78 ++- .../services/featurestore_service/client.py | 167 +++-- .../featurestore_service/transports/base.py | 1 - .../featurestore_service/transports/grpc.py | 7 +- .../transports/grpc_asyncio.py | 7 +- .../index_endpoint_service/async_client.py | 54 +- .../services/index_endpoint_service/client.py | 143 +++-- .../index_endpoint_service/transports/base.py | 1 - .../index_endpoint_service/transports/grpc.py | 7 +- .../transports/grpc_asyncio.py | 7 +- .../services/index_service/async_client.py | 48 +- .../services/index_service/client.py | 137 +++-- .../services/index_service/transports/base.py | 1 - .../services/index_service/transports/grpc.py | 7 +- .../index_service/transports/grpc_asyncio.py | 7 +- .../services/job_service/async_client.py | 94 ++- .../services/job_service/client.py | 183 +++--- .../services/job_service/transports/base.py | 1 - .../services/job_service/transports/grpc.py | 7 +- .../job_service/transports/grpc_asyncio.py | 7 +- 
.../services/metadata_service/async_client.py | 100 ++- .../services/metadata_service/client.py | 189 +++--- .../metadata_service/transports/base.py | 1 - .../metadata_service/transports/grpc.py | 7 +- .../transports/grpc_asyncio.py | 7 +- .../migration_service/async_client.py | 42 +- .../services/migration_service/client.py | 153 +++-- .../migration_service/transports/base.py | 1 - .../migration_service/transports/grpc.py | 7 +- .../transports/grpc_asyncio.py | 7 +- .../services/model_service/async_client.py | 58 +- .../services/model_service/client.py | 147 +++-- .../services/model_service/transports/base.py | 1 - .../services/model_service/transports/grpc.py | 7 +- .../model_service/transports/grpc_asyncio.py | 7 +- .../services/pipeline_service/async_client.py | 58 +- .../services/pipeline_service/client.py | 147 +++-- .../pipeline_service/transports/base.py | 1 - .../pipeline_service/transports/grpc.py | 7 +- .../transports/grpc_asyncio.py | 7 +- .../prediction_service/async_client.py | 44 +- .../services/prediction_service/client.py | 139 +++-- .../prediction_service/transports/base.py | 1 - .../prediction_service/transports/grpc.py | 5 +- .../transports/grpc_asyncio.py | 5 +- .../specialist_pool_service/async_client.py | 48 +- .../specialist_pool_service/client.py | 137 +++-- .../transports/base.py | 1 - .../transports/grpc.py | 7 +- .../transports/grpc_asyncio.py | 7 +- .../tensorboard_service/async_client.py | 103 +++- .../services/tensorboard_service/client.py | 183 +++--- .../tensorboard_service/transports/base.py | 1 - .../tensorboard_service/transports/grpc.py | 7 +- .../transports/grpc_asyncio.py | 7 +- .../services/vizier_service/async_client.py | 64 +- .../services/vizier_service/client.py | 153 +++-- .../vizier_service/transports/base.py | 1 - .../vizier_service/transports/grpc.py | 7 +- .../vizier_service/transports/grpc_asyncio.py | 7 +- .../cloud/aiplatform_v1/types/annotation.py | 4 +- .../aiplatform_v1/types/annotation_spec.py | 4 +- 
google/cloud/aiplatform_v1/types/artifact.py | 4 +- google/cloud/aiplatform_v1/types/context.py | 4 +- .../cloud/aiplatform_v1/types/custom_job.py | 1 + google/cloud/aiplatform_v1/types/data_item.py | 4 +- .../aiplatform_v1/types/data_labeling_job.py | 6 +- google/cloud/aiplatform_v1/types/dataset.py | 7 +- google/cloud/aiplatform_v1/types/endpoint.py | 6 +- .../cloud/aiplatform_v1/types/entity_type.py | 4 +- google/cloud/aiplatform_v1/types/execution.py | 4 +- .../cloud/aiplatform_v1/types/explanation.py | 8 +- .../types/explanation_metadata.py | 4 +- google/cloud/aiplatform_v1/types/feature.py | 6 +- .../cloud/aiplatform_v1/types/featurestore.py | 13 +- .../aiplatform_v1/types/index_endpoint.py | 12 + google/cloud/aiplatform_v1/types/model.py | 10 +- .../types/model_deployment_monitoring_job.py | 12 +- google/cloud/aiplatform_v1/types/study.py | 3 +- .../cloud/aiplatform_v1/types/tensorboard.py | 6 +- .../aiplatform_v1/types/tensorboard_run.py | 6 +- .../types/tensorboard_time_series.py | 6 +- .../aiplatform_v1/types/training_pipeline.py | 14 +- .../services/dataset_service/async_client.py | 58 +- .../services/dataset_service/client.py | 147 +++-- .../dataset_service/transports/base.py | 1 - .../dataset_service/transports/grpc.py | 7 +- .../transports/grpc_asyncio.py | 7 +- .../services/endpoint_service/async_client.py | 52 +- .../services/endpoint_service/client.py | 141 +++-- .../endpoint_service/transports/base.py | 1 - .../endpoint_service/transports/grpc.py | 7 +- .../transports/grpc_asyncio.py | 7 +- .../async_client.py | 51 +- .../client.py | 131 ++-- .../transports/base.py | 1 - .../transports/grpc.py | 5 +- .../transports/grpc_asyncio.py | 5 +- .../featurestore_service/async_client.py | 78 ++- .../services/featurestore_service/client.py | 167 +++-- .../featurestore_service/transports/base.py | 1 - .../featurestore_service/transports/grpc.py | 7 +- .../transports/grpc_asyncio.py | 7 +- .../index_endpoint_service/async_client.py | 54 +- 
.../services/index_endpoint_service/client.py | 143 +++-- .../index_endpoint_service/transports/base.py | 1 - .../index_endpoint_service/transports/grpc.py | 7 +- .../transports/grpc_asyncio.py | 7 +- .../services/index_service/async_client.py | 48 +- .../services/index_service/client.py | 137 +++-- .../services/index_service/transports/base.py | 1 - .../services/index_service/transports/grpc.py | 7 +- .../index_service/transports/grpc_asyncio.py | 7 +- .../services/job_service/async_client.py | 94 ++- .../services/job_service/client.py | 183 +++--- .../services/job_service/transports/base.py | 1 - .../services/job_service/transports/grpc.py | 7 +- .../job_service/transports/grpc_asyncio.py | 7 +- .../services/metadata_service/async_client.py | 100 ++- .../services/metadata_service/client.py | 189 +++--- .../metadata_service/transports/base.py | 1 - .../metadata_service/transports/grpc.py | 7 +- .../transports/grpc_asyncio.py | 7 +- .../migration_service/async_client.py | 42 +- .../services/migration_service/client.py | 153 +++-- .../migration_service/transports/base.py | 1 - .../migration_service/transports/grpc.py | 7 +- .../transports/grpc_asyncio.py | 7 +- .../services/model_service/async_client.py | 58 +- .../services/model_service/client.py | 147 +++-- .../services/model_service/transports/base.py | 1 - .../services/model_service/transports/grpc.py | 7 +- .../model_service/transports/grpc_asyncio.py | 7 +- .../services/pipeline_service/async_client.py | 58 +- .../services/pipeline_service/client.py | 147 +++-- .../pipeline_service/transports/base.py | 1 - .../pipeline_service/transports/grpc.py | 7 +- .../transports/grpc_asyncio.py | 7 +- .../prediction_service/async_client.py | 44 +- .../services/prediction_service/client.py | 133 ++-- .../prediction_service/transports/base.py | 1 - .../prediction_service/transports/grpc.py | 5 +- .../transports/grpc_asyncio.py | 5 +- .../specialist_pool_service/async_client.py | 48 +- .../specialist_pool_service/client.py | 
137 +++-- .../transports/base.py | 1 - .../transports/grpc.py | 7 +- .../transports/grpc_asyncio.py | 7 +- .../tensorboard_service/async_client.py | 103 +++- .../services/tensorboard_service/client.py | 183 +++--- .../tensorboard_service/transports/base.py | 1 - .../tensorboard_service/transports/grpc.py | 7 +- .../transports/grpc_asyncio.py | 7 +- .../services/vizier_service/async_client.py | 58 +- .../services/vizier_service/client.py | 147 +++-- .../vizier_service/transports/base.py | 1 - .../vizier_service/transports/grpc.py | 7 +- .../vizier_service/transports/grpc_asyncio.py | 7 +- .../aiplatform_v1beta1/types/annotation.py | 4 +- .../types/annotation_spec.py | 4 +- .../aiplatform_v1beta1/types/artifact.py | 4 +- .../cloud/aiplatform_v1beta1/types/context.py | 4 +- .../aiplatform_v1beta1/types/data_item.py | 4 +- .../types/data_labeling_job.py | 6 +- .../cloud/aiplatform_v1beta1/types/dataset.py | 7 +- .../aiplatform_v1beta1/types/endpoint.py | 6 +- .../aiplatform_v1beta1/types/entity_type.py | 4 +- .../aiplatform_v1beta1/types/execution.py | 4 +- .../aiplatform_v1beta1/types/explanation.py | 8 +- .../types/explanation_metadata.py | 4 +- .../cloud/aiplatform_v1beta1/types/feature.py | 6 +- .../aiplatform_v1beta1/types/featurestore.py | 4 +- .../types/model_deployment_monitoring_job.py | 7 +- .../aiplatform_v1beta1/types/tensorboard.py | 6 +- .../types/tensorboard_run.py | 6 +- .../types/tensorboard_time_series.py | 6 +- .../types/training_pipeline.py | 14 +- .../aiplatform_v1/test_dataset_service.py | 315 +++++++--- .../aiplatform_v1/test_endpoint_service.py | 266 ++++++-- ...est_featurestore_online_serving_service.py | 217 ++++++- .../test_featurestore_service.py | 454 +++++++++----- .../test_index_endpoint_service.py | 286 ++++++--- .../gapic/aiplatform_v1/test_index_service.py | 236 +++++-- .../gapic/aiplatform_v1/test_job_service.py | 559 +++++++++-------- .../aiplatform_v1/test_metadata_service.py | 572 +++++++++-------- 
.../aiplatform_v1/test_migration_service.py | 248 ++++++-- .../gapic/aiplatform_v1/test_model_service.py | 313 +++++++--- .../aiplatform_v1/test_pipeline_service.py | 319 +++++++--- .../aiplatform_v1/test_prediction_service.py | 214 ++++++- .../test_specialist_pool_service.py | 252 ++++++-- .../aiplatform_v1/test_tensorboard_service.py | 581 ++++++++++-------- .../aiplatform_v1/test_vizier_service.py | 351 +++++++---- .../test_dataset_service.py | 315 +++++++--- .../test_endpoint_service.py | 266 ++++++-- ...est_featurestore_online_serving_service.py | 217 ++++++- .../test_featurestore_service.py | 454 +++++++++----- .../test_index_endpoint_service.py | 285 ++++++--- .../aiplatform_v1beta1/test_index_service.py | 236 +++++-- .../aiplatform_v1beta1/test_job_service.py | 559 +++++++++-------- .../test_metadata_service.py | 572 +++++++++-------- .../test_migration_service.py | 248 ++++++-- .../aiplatform_v1beta1/test_model_service.py | 313 +++++++--- .../test_pipeline_service.py | 319 +++++++--- .../test_prediction_service.py | 214 ++++++- .../test_specialist_pool_service.py | 252 ++++++-- .../test_tensorboard_service.py | 578 +++++++++-------- .../aiplatform_v1beta1/test_vizier_service.py | 351 +++++++---- 224 files changed, 11938 insertions(+), 5626 deletions(-) diff --git a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_tables.py b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_tables.py index 0e42e77a865..64ff650ad06 100644 --- a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_tables.py +++ b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_tables.py @@ -99,9 +99,9 @@ class AutoMlTablesInputs(proto.Message): operating characteristic (ROC) curve. "minimize-log-loss" - Minimize log loss. "maximize-au-prc" - Maximize the area under - the precision-recall curve. "maximize- - precision-at-recall" - Maximize precision for a - specified + the precision-recall curve. 
+ "maximize-precision-at-recall" - Maximize + precision for a specified recall value. "maximize-recall-at-precision" - Maximize recall for a specified precision value. @@ -109,11 +109,11 @@ class AutoMlTablesInputs(proto.Message): "minimize-log-loss" (default) - Minimize log loss. regression: - "minimize-rmse" (default) - Minimize root- - mean-squared error (RMSE). "minimize-mae" - - Minimize mean-absolute error (MAE). "minimize- - rmsle" - Minimize root-mean-squared log error - (RMSLE). + "minimize-rmse" (default) - Minimize + root-mean-squared error (RMSE). "minimize-mae" + - Minimize mean-absolute error (MAE). + "minimize-rmsle" - Minimize root-mean-squared + log error (RMSLE). train_budget_milli_node_hours (int): Required. The train budget of creating this model, expressed in milli node hours i.e. 1,000 diff --git a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_tables.py b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_tables.py index aae2ec93483..44278051901 100644 --- a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_tables.py +++ b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_tables.py @@ -99,9 +99,9 @@ class AutoMlTablesInputs(proto.Message): operating characteristic (ROC) curve. "minimize-log-loss" - Minimize log loss. "maximize-au-prc" - Maximize the area under - the precision-recall curve. "maximize- - precision-at-recall" - Maximize precision for a - specified + the precision-recall curve. + "maximize-precision-at-recall" - Maximize + precision for a specified recall value. "maximize-recall-at-precision" - Maximize recall for a specified precision value. @@ -109,11 +109,11 @@ class AutoMlTablesInputs(proto.Message): "minimize-log-loss" (default) - Minimize log loss. regression: - "minimize-rmse" (default) - Minimize root- - mean-squared error (RMSE). "minimize-mae" - - Minimize mean-absolute error (MAE). 
"minimize- - rmsle" - Minimize root-mean-squared log error - (RMSLE). + "minimize-rmse" (default) - Minimize + root-mean-squared error (RMSE). "minimize-mae" + - Minimize mean-absolute error (MAE). + "minimize-rmsle" - Minimize root-mean-squared + log error (RMSLE). train_budget_milli_node_hours (int): Required. The train budget of creating this model, expressed in milli node hours i.e. 1,000 diff --git a/google/cloud/aiplatform_v1/services/dataset_service/async_client.py b/google/cloud/aiplatform_v1/services/dataset_service/async_client.py index 6a6c297f692..163d755b46e 100644 --- a/google/cloud/aiplatform_v1/services/dataset_service/async_client.py +++ b/google/cloud/aiplatform_v1/services/dataset_service/async_client.py @@ -16,7 +16,7 @@ from collections import OrderedDict import functools import re -from typing import Dict, Sequence, Tuple, Type, Union +from typing import Dict, Optional, Sequence, Tuple, Type, Union import pkg_resources from google.api_core.client_options import ClientOptions @@ -129,6 +129,42 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): from_service_account_json = from_service_account_file + @classmethod + def get_mtls_endpoint_and_cert_source( + cls, client_options: Optional[ClientOptions] = None + ): + """Return the API endpoint and client cert source for mutual TLS. + + The client cert source is determined in the following order: + (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the + client cert source is None. + (2) if `client_options.client_cert_source` is provided, use the provided one; if the + default client cert source exists, use the default one; otherwise the client cert + source is None. + + The API endpoint is determined in the following order: + (1) if `client_options.api_endpoint` if provided, use the provided one. 
+ (2) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is "always", use the + default mTLS endpoint; if the environment variabel is "never", use the default API + endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise + use the default API endpoint. + + More details can be found at https://google.aip.dev/auth/4114. + + Args: + client_options (google.api_core.client_options.ClientOptions): Custom options for the + client. Only the `api_endpoint` and `client_cert_source` properties may be used + in this method. + + Returns: + Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the + client cert source to use. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If any errors happen. + """ + return DatasetServiceClient.get_mtls_endpoint_and_cert_source(client_options) # type: ignore + @property def transport(self) -> DatasetServiceTransport: """Returns the transport used by the client instance. @@ -234,7 +270,7 @@ async def create_dataset( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, dataset]) if request is not None and has_flattened_params: @@ -315,7 +351,7 @@ async def get_dataset( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -400,7 +436,7 @@ async def update_dataset( """ # Create or coerce a protobuf request object. 
- # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([dataset, update_mask]) if request is not None and has_flattened_params: @@ -478,7 +514,7 @@ async def list_datasets( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -569,7 +605,7 @@ async def delete_dataset( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -661,7 +697,7 @@ async def import_data( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name, import_configs]) if request is not None and has_flattened_params: @@ -754,7 +790,7 @@ async def export_data( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name, export_config]) if request is not None and has_flattened_params: @@ -839,7 +875,7 @@ async def list_data_items( """ # Create or coerce a protobuf request object. 
- # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -917,7 +953,7 @@ async def get_annotation_spec( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -992,7 +1028,7 @@ async def list_annotations( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: diff --git a/google/cloud/aiplatform_v1/services/dataset_service/client.py b/google/cloud/aiplatform_v1/services/dataset_service/client.py index ea66e8921e3..a9c29180322 100644 --- a/google/cloud/aiplatform_v1/services/dataset_service/client.py +++ b/google/cloud/aiplatform_v1/services/dataset_service/client.py @@ -310,6 +310,73 @@ def parse_common_location_path(path: str) -> Dict[str, str]: m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) return m.groupdict() if m else {} + @classmethod + def get_mtls_endpoint_and_cert_source( + cls, client_options: Optional[client_options_lib.ClientOptions] = None + ): + """Return the API endpoint and client cert source for mutual TLS. + + The client cert source is determined in the following order: + (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the + client cert source is None. 
+ (2) if `client_options.client_cert_source` is provided, use the provided one; if the + default client cert source exists, use the default one; otherwise the client cert + source is None. + + The API endpoint is determined in the following order: + (1) if `client_options.api_endpoint` if provided, use the provided one. + (2) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is "always", use the + default mTLS endpoint; if the environment variabel is "never", use the default API + endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise + use the default API endpoint. + + More details can be found at https://google.aip.dev/auth/4114. + + Args: + client_options (google.api_core.client_options.ClientOptions): Custom options for the + client. Only the `api_endpoint` and `client_cert_source` properties may be used + in this method. + + Returns: + Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the + client cert source to use. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If any errors happen. + """ + if client_options is None: + client_options = client_options_lib.ClientOptions() + use_client_cert = os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") + use_mtls_endpoint = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") + if use_client_cert not in ("true", "false"): + raise ValueError( + "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" + ) + if use_mtls_endpoint not in ("auto", "never", "always"): + raise MutualTLSChannelError( + "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" + ) + + # Figure out the client cert source to use. 
+ client_cert_source = None + if use_client_cert == "true": + if client_options.client_cert_source: + client_cert_source = client_options.client_cert_source + elif mtls.has_default_client_cert_source(): + client_cert_source = mtls.default_client_cert_source() + + # Figure out which api endpoint to use. + if client_options.api_endpoint is not None: + api_endpoint = client_options.api_endpoint + elif use_mtls_endpoint == "always" or ( + use_mtls_endpoint == "auto" and client_cert_source + ): + api_endpoint = cls.DEFAULT_MTLS_ENDPOINT + else: + api_endpoint = cls.DEFAULT_ENDPOINT + + return api_endpoint, client_cert_source + def __init__( self, *, @@ -360,57 +427,22 @@ def __init__( if client_options is None: client_options = client_options_lib.ClientOptions() - # Create SSL credentials for mutual TLS if needed. - if os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") not in ( - "true", - "false", - ): - raise ValueError( - "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" - ) - use_client_cert = ( - os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") == "true" + api_endpoint, client_cert_source_func = self.get_mtls_endpoint_and_cert_source( + client_options ) - client_cert_source_func = None - is_mtls = False - if use_client_cert: - if client_options.client_cert_source: - is_mtls = True - client_cert_source_func = client_options.client_cert_source - else: - is_mtls = mtls.has_default_client_cert_source() - if is_mtls: - client_cert_source_func = mtls.default_client_cert_source() - else: - client_cert_source_func = None - - # Figure out which api endpoint to use. 
- if client_options.api_endpoint is not None: - api_endpoint = client_options.api_endpoint - else: - use_mtls_env = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") - if use_mtls_env == "never": - api_endpoint = self.DEFAULT_ENDPOINT - elif use_mtls_env == "always": - api_endpoint = self.DEFAULT_MTLS_ENDPOINT - elif use_mtls_env == "auto": - if is_mtls: - api_endpoint = self.DEFAULT_MTLS_ENDPOINT - else: - api_endpoint = self.DEFAULT_ENDPOINT - else: - raise MutualTLSChannelError( - "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted " - "values: never, auto, always" - ) + api_key_value = getattr(client_options, "api_key", None) + if api_key_value and credentials: + raise ValueError( + "client_options.api_key and credentials are mutually exclusive" + ) # Save or instantiate the transport. # Ordinarily, we provide the transport, but allowing a custom transport # instance provides an extensibility point for unusual situations. if isinstance(transport, DatasetServiceTransport): # transport is a DatasetServiceTransport instance. - if credentials or client_options.credentials_file: + if credentials or client_options.credentials_file or api_key_value: raise ValueError( "When providing a transport instance, " "provide its credentials directly." @@ -422,6 +454,15 @@ def __init__( ) self._transport = transport else: + import google.auth._default # type: ignore + + if api_key_value and hasattr( + google.auth._default, "get_api_key_credentials" + ): + credentials = google.auth._default.get_api_key_credentials( + api_key_value + ) + Transport = type(self).get_transport_class(transport) self._transport = Transport( credentials=credentials, @@ -479,7 +520,7 @@ def create_dataset( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. 
has_flattened_params = any([parent, dataset]) if request is not None and has_flattened_params: @@ -560,7 +601,7 @@ def get_dataset( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -645,7 +686,7 @@ def update_dataset( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([dataset, update_mask]) if request is not None and has_flattened_params: @@ -723,7 +764,7 @@ def list_datasets( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -814,7 +855,7 @@ def delete_dataset( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -906,7 +947,7 @@ def import_data( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. 
has_flattened_params = any([name, import_configs]) if request is not None and has_flattened_params: @@ -999,7 +1040,7 @@ def export_data( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name, export_config]) if request is not None and has_flattened_params: @@ -1084,7 +1125,7 @@ def list_data_items( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -1162,7 +1203,7 @@ def get_annotation_spec( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -1237,7 +1278,7 @@ def list_annotations( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. 
has_flattened_params = any([parent]) if request is not None and has_flattened_params: diff --git a/google/cloud/aiplatform_v1/services/dataset_service/transports/base.py b/google/cloud/aiplatform_v1/services/dataset_service/transports/base.py index 86eb4018858..0235c7dbd31 100644 --- a/google/cloud/aiplatform_v1/services/dataset_service/transports/base.py +++ b/google/cloud/aiplatform_v1/services/dataset_service/transports/base.py @@ -106,7 +106,6 @@ def __init__( credentials, _ = google.auth.load_credentials_from_file( credentials_file, **scopes_kwargs, quota_project_id=quota_project_id ) - elif credentials is None: credentials, _ = google.auth.default( **scopes_kwargs, quota_project_id=quota_project_id diff --git a/google/cloud/aiplatform_v1/services/dataset_service/transports/grpc.py b/google/cloud/aiplatform_v1/services/dataset_service/transports/grpc.py index f29b6bd3be6..b543fd79495 100644 --- a/google/cloud/aiplatform_v1/services/dataset_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1/services/dataset_service/transports/grpc.py @@ -165,8 +165,11 @@ def __init__( if not self._grpc_channel: self._grpc_channel = type(self).create_channel( self._host, + # use the credentials which are saved credentials=self._credentials, - credentials_file=credentials_file, + # Set ``credentials_file`` to ``None`` here as + # the credentials that we saved earlier should be used. + credentials_file=None, scopes=self._scopes, ssl_credentials=self._ssl_channel_credentials, quota_project_id=quota_project_id, @@ -239,7 +242,7 @@ def operations_client(self) -> operations_v1.OperationsClient: This property caches on the instance; repeated calls return the same client. """ - # Sanity check: Only create a new client if we do not already have one. + # Quick check: Only create a new client if we do not already have one. 
if self._operations_client is None: self._operations_client = operations_v1.OperationsClient(self.grpc_channel) diff --git a/google/cloud/aiplatform_v1/services/dataset_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1/services/dataset_service/transports/grpc_asyncio.py index 618a6434fa7..4f8c57d302b 100644 --- a/google/cloud/aiplatform_v1/services/dataset_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1/services/dataset_service/transports/grpc_asyncio.py @@ -210,8 +210,11 @@ def __init__( if not self._grpc_channel: self._grpc_channel = type(self).create_channel( self._host, + # use the credentials which are saved credentials=self._credentials, - credentials_file=credentials_file, + # Set ``credentials_file`` to ``None`` here as + # the credentials that we saved earlier should be used. + credentials_file=None, scopes=self._scopes, ssl_credentials=self._ssl_channel_credentials, quota_project_id=quota_project_id, @@ -241,7 +244,7 @@ def operations_client(self) -> operations_v1.OperationsAsyncClient: This property caches on the instance; repeated calls return the same client. """ - # Sanity check: Only create a new client if we do not already have one. + # Quick check: Only create a new client if we do not already have one. 
if self._operations_client is None: self._operations_client = operations_v1.OperationsAsyncClient( self.grpc_channel diff --git a/google/cloud/aiplatform_v1/services/endpoint_service/async_client.py b/google/cloud/aiplatform_v1/services/endpoint_service/async_client.py index 7ffd118709c..245b5047a9d 100644 --- a/google/cloud/aiplatform_v1/services/endpoint_service/async_client.py +++ b/google/cloud/aiplatform_v1/services/endpoint_service/async_client.py @@ -16,7 +16,7 @@ from collections import OrderedDict import functools import re -from typing import Dict, Sequence, Tuple, Type, Union +from typing import Dict, Optional, Sequence, Tuple, Type, Union import pkg_resources from google.api_core.client_options import ClientOptions @@ -125,6 +125,42 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): from_service_account_json = from_service_account_file + @classmethod + def get_mtls_endpoint_and_cert_source( + cls, client_options: Optional[ClientOptions] = None + ): + """Return the API endpoint and client cert source for mutual TLS. + + The client cert source is determined in the following order: + (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the + client cert source is None. + (2) if `client_options.client_cert_source` is provided, use the provided one; if the + default client cert source exists, use the default one; otherwise the client cert + source is None. + + The API endpoint is determined in the following order: + (1) if `client_options.api_endpoint` if provided, use the provided one. + (2) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is "always", use the + default mTLS endpoint; if the environment variabel is "never", use the default API + endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise + use the default API endpoint. + + More details can be found at https://google.aip.dev/auth/4114. 
+ + Args: + client_options (google.api_core.client_options.ClientOptions): Custom options for the + client. Only the `api_endpoint` and `client_cert_source` properties may be used + in this method. + + Returns: + Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the + client cert source to use. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If any errors happen. + """ + return EndpointServiceClient.get_mtls_endpoint_and_cert_source(client_options) # type: ignore + @property def transport(self) -> EndpointServiceTransport: """Returns the transport used by the client instance. @@ -245,7 +281,7 @@ async def create_endpoint( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, endpoint, endpoint_id]) if request is not None and has_flattened_params: @@ -329,7 +365,7 @@ async def get_endpoint( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -404,7 +440,7 @@ async def list_endpoints( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -490,7 +526,7 @@ async def update_endpoint( """ # Create or coerce a protobuf request object. 
- # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([endpoint, update_mask]) if request is not None and has_flattened_params: @@ -579,7 +615,7 @@ async def delete_endpoint( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -699,7 +735,7 @@ async def deploy_model( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([endpoint, deployed_model, traffic_split]) if request is not None and has_flattened_params: @@ -815,7 +851,7 @@ async def undeploy_model( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. 
has_flattened_params = any([endpoint, deployed_model_id, traffic_split]) if request is not None and has_flattened_params: diff --git a/google/cloud/aiplatform_v1/services/endpoint_service/client.py b/google/cloud/aiplatform_v1/services/endpoint_service/client.py index 3f20e28f330..37d9db65278 100644 --- a/google/cloud/aiplatform_v1/services/endpoint_service/client.py +++ b/google/cloud/aiplatform_v1/services/endpoint_service/client.py @@ -294,6 +294,73 @@ def parse_common_location_path(path: str) -> Dict[str, str]: m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) return m.groupdict() if m else {} + @classmethod + def get_mtls_endpoint_and_cert_source( + cls, client_options: Optional[client_options_lib.ClientOptions] = None + ): + """Return the API endpoint and client cert source for mutual TLS. + + The client cert source is determined in the following order: + (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the + client cert source is None. + (2) if `client_options.client_cert_source` is provided, use the provided one; if the + default client cert source exists, use the default one; otherwise the client cert + source is None. + + The API endpoint is determined in the following order: + (1) if `client_options.api_endpoint` if provided, use the provided one. + (2) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is "always", use the + default mTLS endpoint; if the environment variabel is "never", use the default API + endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise + use the default API endpoint. + + More details can be found at https://google.aip.dev/auth/4114. + + Args: + client_options (google.api_core.client_options.ClientOptions): Custom options for the + client. Only the `api_endpoint` and `client_cert_source` properties may be used + in this method. 
+ + Returns: + Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the + client cert source to use. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If any errors happen. + """ + if client_options is None: + client_options = client_options_lib.ClientOptions() + use_client_cert = os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") + use_mtls_endpoint = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") + if use_client_cert not in ("true", "false"): + raise ValueError( + "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" + ) + if use_mtls_endpoint not in ("auto", "never", "always"): + raise MutualTLSChannelError( + "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" + ) + + # Figure out the client cert source to use. + client_cert_source = None + if use_client_cert == "true": + if client_options.client_cert_source: + client_cert_source = client_options.client_cert_source + elif mtls.has_default_client_cert_source(): + client_cert_source = mtls.default_client_cert_source() + + # Figure out which api endpoint to use. + if client_options.api_endpoint is not None: + api_endpoint = client_options.api_endpoint + elif use_mtls_endpoint == "always" or ( + use_mtls_endpoint == "auto" and client_cert_source + ): + api_endpoint = cls.DEFAULT_MTLS_ENDPOINT + else: + api_endpoint = cls.DEFAULT_ENDPOINT + + return api_endpoint, client_cert_source + def __init__( self, *, @@ -344,57 +411,22 @@ def __init__( if client_options is None: client_options = client_options_lib.ClientOptions() - # Create SSL credentials for mutual TLS if needed. 
- if os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") not in ( - "true", - "false", - ): - raise ValueError( - "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" - ) - use_client_cert = ( - os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") == "true" + api_endpoint, client_cert_source_func = self.get_mtls_endpoint_and_cert_source( + client_options ) - client_cert_source_func = None - is_mtls = False - if use_client_cert: - if client_options.client_cert_source: - is_mtls = True - client_cert_source_func = client_options.client_cert_source - else: - is_mtls = mtls.has_default_client_cert_source() - if is_mtls: - client_cert_source_func = mtls.default_client_cert_source() - else: - client_cert_source_func = None - - # Figure out which api endpoint to use. - if client_options.api_endpoint is not None: - api_endpoint = client_options.api_endpoint - else: - use_mtls_env = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") - if use_mtls_env == "never": - api_endpoint = self.DEFAULT_ENDPOINT - elif use_mtls_env == "always": - api_endpoint = self.DEFAULT_MTLS_ENDPOINT - elif use_mtls_env == "auto": - if is_mtls: - api_endpoint = self.DEFAULT_MTLS_ENDPOINT - else: - api_endpoint = self.DEFAULT_ENDPOINT - else: - raise MutualTLSChannelError( - "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted " - "values: never, auto, always" - ) + api_key_value = getattr(client_options, "api_key", None) + if api_key_value and credentials: + raise ValueError( + "client_options.api_key and credentials are mutually exclusive" + ) # Save or instantiate the transport. # Ordinarily, we provide the transport, but allowing a custom transport # instance provides an extensibility point for unusual situations. if isinstance(transport, EndpointServiceTransport): # transport is a EndpointServiceTransport instance. 
- if credentials or client_options.credentials_file: + if credentials or client_options.credentials_file or api_key_value: raise ValueError( "When providing a transport instance, " "provide its credentials directly." @@ -406,6 +438,15 @@ def __init__( ) self._transport = transport else: + import google.auth._default # type: ignore + + if api_key_value and hasattr( + google.auth._default, "get_api_key_credentials" + ): + credentials = google.auth._default.get_api_key_credentials( + api_key_value + ) + Transport = type(self).get_transport_class(transport) self._transport = Transport( credentials=credentials, @@ -478,7 +519,7 @@ def create_endpoint( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, endpoint, endpoint_id]) if request is not None and has_flattened_params: @@ -562,7 +603,7 @@ def get_endpoint( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -637,7 +678,7 @@ def list_endpoints( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -723,7 +764,7 @@ def update_endpoint( """ # Create or coerce a protobuf request object. 
- # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([endpoint, update_mask]) if request is not None and has_flattened_params: @@ -812,7 +853,7 @@ def delete_endpoint( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -932,7 +973,7 @@ def deploy_model( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([endpoint, deployed_model, traffic_split]) if request is not None and has_flattened_params: @@ -1047,7 +1088,7 @@ def undeploy_model( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. 
has_flattened_params = any([endpoint, deployed_model_id, traffic_split]) if request is not None and has_flattened_params: diff --git a/google/cloud/aiplatform_v1/services/endpoint_service/transports/base.py b/google/cloud/aiplatform_v1/services/endpoint_service/transports/base.py index b523932066c..2e684a1b69b 100644 --- a/google/cloud/aiplatform_v1/services/endpoint_service/transports/base.py +++ b/google/cloud/aiplatform_v1/services/endpoint_service/transports/base.py @@ -105,7 +105,6 @@ def __init__( credentials, _ = google.auth.load_credentials_from_file( credentials_file, **scopes_kwargs, quota_project_id=quota_project_id ) - elif credentials is None: credentials, _ = google.auth.default( **scopes_kwargs, quota_project_id=quota_project_id diff --git a/google/cloud/aiplatform_v1/services/endpoint_service/transports/grpc.py b/google/cloud/aiplatform_v1/services/endpoint_service/transports/grpc.py index 8eff44667ce..13a59b9cad4 100644 --- a/google/cloud/aiplatform_v1/services/endpoint_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1/services/endpoint_service/transports/grpc.py @@ -163,8 +163,11 @@ def __init__( if not self._grpc_channel: self._grpc_channel = type(self).create_channel( self._host, + # use the credentials which are saved credentials=self._credentials, - credentials_file=credentials_file, + # Set ``credentials_file`` to ``None`` here as + # the credentials that we saved earlier should be used. + credentials_file=None, scopes=self._scopes, ssl_credentials=self._ssl_channel_credentials, quota_project_id=quota_project_id, @@ -237,7 +240,7 @@ def operations_client(self) -> operations_v1.OperationsClient: This property caches on the instance; repeated calls return the same client. """ - # Sanity check: Only create a new client if we do not already have one. + # Quick check: Only create a new client if we do not already have one. 
if self._operations_client is None: self._operations_client = operations_v1.OperationsClient(self.grpc_channel) diff --git a/google/cloud/aiplatform_v1/services/endpoint_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1/services/endpoint_service/transports/grpc_asyncio.py index 040b449ce21..cc68c244dfb 100644 --- a/google/cloud/aiplatform_v1/services/endpoint_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1/services/endpoint_service/transports/grpc_asyncio.py @@ -208,8 +208,11 @@ def __init__( if not self._grpc_channel: self._grpc_channel = type(self).create_channel( self._host, + # use the credentials which are saved credentials=self._credentials, - credentials_file=credentials_file, + # Set ``credentials_file`` to ``None`` here as + # the credentials that we saved earlier should be used. + credentials_file=None, scopes=self._scopes, ssl_credentials=self._ssl_channel_credentials, quota_project_id=quota_project_id, @@ -239,7 +242,7 @@ def operations_client(self) -> operations_v1.OperationsAsyncClient: This property caches on the instance; repeated calls return the same client. """ - # Sanity check: Only create a new client if we do not already have one. + # Quick check: Only create a new client if we do not already have one. 
if self._operations_client is None: self._operations_client = operations_v1.OperationsAsyncClient( self.grpc_channel diff --git a/google/cloud/aiplatform_v1/services/featurestore_online_serving_service/async_client.py b/google/cloud/aiplatform_v1/services/featurestore_online_serving_service/async_client.py index a758e9a0fe1..639d8122a4d 100644 --- a/google/cloud/aiplatform_v1/services/featurestore_online_serving_service/async_client.py +++ b/google/cloud/aiplatform_v1/services/featurestore_online_serving_service/async_client.py @@ -16,7 +16,16 @@ from collections import OrderedDict import functools import re -from typing import Dict, AsyncIterable, Awaitable, Sequence, Tuple, Type, Union +from typing import ( + Dict, + Optional, + AsyncIterable, + Awaitable, + Sequence, + Tuple, + Type, + Union, +) import pkg_resources from google.api_core.client_options import ClientOptions @@ -120,6 +129,42 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): from_service_account_json = from_service_account_file + @classmethod + def get_mtls_endpoint_and_cert_source( + cls, client_options: Optional[ClientOptions] = None + ): + """Return the API endpoint and client cert source for mutual TLS. + + The client cert source is determined in the following order: + (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the + client cert source is None. + (2) if `client_options.client_cert_source` is provided, use the provided one; if the + default client cert source exists, use the default one; otherwise the client cert + source is None. + + The API endpoint is determined in the following order: + (1) if `client_options.api_endpoint` if provided, use the provided one. 
+ (2) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is "always", use the + default mTLS endpoint; if the environment variabel is "never", use the default API + endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise + use the default API endpoint. + + More details can be found at https://google.aip.dev/auth/4114. + + Args: + client_options (google.api_core.client_options.ClientOptions): Custom options for the + client. Only the `api_endpoint` and `client_cert_source` properties may be used + in this method. + + Returns: + Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the + client cert source to use. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If any errors happen. + """ + return FeaturestoreOnlineServingServiceClient.get_mtls_endpoint_and_cert_source(client_options) # type: ignore + @property def transport(self) -> FeaturestoreOnlineServingServiceTransport: """Returns the transport used by the client instance. @@ -227,7 +272,7 @@ async def read_feature_values( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([entity_type]) if request is not None and has_flattened_params: @@ -310,7 +355,7 @@ def streaming_read_feature_values( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. 
has_flattened_params = any([entity_type]) if request is not None and has_flattened_params: diff --git a/google/cloud/aiplatform_v1/services/featurestore_online_serving_service/client.py b/google/cloud/aiplatform_v1/services/featurestore_online_serving_service/client.py index 8bbc0ab87e6..cd70627ef6e 100644 --- a/google/cloud/aiplatform_v1/services/featurestore_online_serving_service/client.py +++ b/google/cloud/aiplatform_v1/services/featurestore_online_serving_service/client.py @@ -249,6 +249,73 @@ def parse_common_location_path(path: str) -> Dict[str, str]: m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) return m.groupdict() if m else {} + @classmethod + def get_mtls_endpoint_and_cert_source( + cls, client_options: Optional[client_options_lib.ClientOptions] = None + ): + """Return the API endpoint and client cert source for mutual TLS. + + The client cert source is determined in the following order: + (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the + client cert source is None. + (2) if `client_options.client_cert_source` is provided, use the provided one; if the + default client cert source exists, use the default one; otherwise the client cert + source is None. + + The API endpoint is determined in the following order: + (1) if `client_options.api_endpoint` if provided, use the provided one. + (2) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is "always", use the + default mTLS endpoint; if the environment variabel is "never", use the default API + endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise + use the default API endpoint. + + More details can be found at https://google.aip.dev/auth/4114. + + Args: + client_options (google.api_core.client_options.ClientOptions): Custom options for the + client. Only the `api_endpoint` and `client_cert_source` properties may be used + in this method. 
+ + Returns: + Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the + client cert source to use. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If any errors happen. + """ + if client_options is None: + client_options = client_options_lib.ClientOptions() + use_client_cert = os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") + use_mtls_endpoint = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") + if use_client_cert not in ("true", "false"): + raise ValueError( + "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" + ) + if use_mtls_endpoint not in ("auto", "never", "always"): + raise MutualTLSChannelError( + "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" + ) + + # Figure out the client cert source to use. + client_cert_source = None + if use_client_cert == "true": + if client_options.client_cert_source: + client_cert_source = client_options.client_cert_source + elif mtls.has_default_client_cert_source(): + client_cert_source = mtls.default_client_cert_source() + + # Figure out which api endpoint to use. + if client_options.api_endpoint is not None: + api_endpoint = client_options.api_endpoint + elif use_mtls_endpoint == "always" or ( + use_mtls_endpoint == "auto" and client_cert_source + ): + api_endpoint = cls.DEFAULT_MTLS_ENDPOINT + else: + api_endpoint = cls.DEFAULT_ENDPOINT + + return api_endpoint, client_cert_source + def __init__( self, *, @@ -299,57 +366,22 @@ def __init__( if client_options is None: client_options = client_options_lib.ClientOptions() - # Create SSL credentials for mutual TLS if needed. 
- if os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") not in ( - "true", - "false", - ): - raise ValueError( - "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" - ) - use_client_cert = ( - os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") == "true" + api_endpoint, client_cert_source_func = self.get_mtls_endpoint_and_cert_source( + client_options ) - client_cert_source_func = None - is_mtls = False - if use_client_cert: - if client_options.client_cert_source: - is_mtls = True - client_cert_source_func = client_options.client_cert_source - else: - is_mtls = mtls.has_default_client_cert_source() - if is_mtls: - client_cert_source_func = mtls.default_client_cert_source() - else: - client_cert_source_func = None - - # Figure out which api endpoint to use. - if client_options.api_endpoint is not None: - api_endpoint = client_options.api_endpoint - else: - use_mtls_env = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") - if use_mtls_env == "never": - api_endpoint = self.DEFAULT_ENDPOINT - elif use_mtls_env == "always": - api_endpoint = self.DEFAULT_MTLS_ENDPOINT - elif use_mtls_env == "auto": - if is_mtls: - api_endpoint = self.DEFAULT_MTLS_ENDPOINT - else: - api_endpoint = self.DEFAULT_ENDPOINT - else: - raise MutualTLSChannelError( - "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted " - "values: never, auto, always" - ) + api_key_value = getattr(client_options, "api_key", None) + if api_key_value and credentials: + raise ValueError( + "client_options.api_key and credentials are mutually exclusive" + ) # Save or instantiate the transport. # Ordinarily, we provide the transport, but allowing a custom transport # instance provides an extensibility point for unusual situations. if isinstance(transport, FeaturestoreOnlineServingServiceTransport): # transport is a FeaturestoreOnlineServingServiceTransport instance. 
- if credentials or client_options.credentials_file: + if credentials or client_options.credentials_file or api_key_value: raise ValueError( "When providing a transport instance, " "provide its credentials directly." @@ -361,6 +393,15 @@ def __init__( ) self._transport = transport else: + import google.auth._default # type: ignore + + if api_key_value and hasattr( + google.auth._default, "get_api_key_credentials" + ): + credentials = google.auth._default.get_api_key_credentials( + api_key_value + ) + Transport = type(self).get_transport_class(transport) self._transport = Transport( credentials=credentials, @@ -417,7 +458,7 @@ def read_feature_values( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([entity_type]) if request is not None and has_flattened_params: @@ -500,7 +541,7 @@ def streaming_read_feature_values( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. 
has_flattened_params = any([entity_type]) if request is not None and has_flattened_params: diff --git a/google/cloud/aiplatform_v1/services/featurestore_online_serving_service/transports/base.py b/google/cloud/aiplatform_v1/services/featurestore_online_serving_service/transports/base.py index 88e5a03c54f..281cdc5ae9e 100644 --- a/google/cloud/aiplatform_v1/services/featurestore_online_serving_service/transports/base.py +++ b/google/cloud/aiplatform_v1/services/featurestore_online_serving_service/transports/base.py @@ -101,7 +101,6 @@ def __init__( credentials, _ = google.auth.load_credentials_from_file( credentials_file, **scopes_kwargs, quota_project_id=quota_project_id ) - elif credentials is None: credentials, _ = google.auth.default( **scopes_kwargs, quota_project_id=quota_project_id diff --git a/google/cloud/aiplatform_v1/services/featurestore_online_serving_service/transports/grpc.py b/google/cloud/aiplatform_v1/services/featurestore_online_serving_service/transports/grpc.py index 80d729714b1..31340c5f838 100644 --- a/google/cloud/aiplatform_v1/services/featurestore_online_serving_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1/services/featurestore_online_serving_service/transports/grpc.py @@ -160,8 +160,11 @@ def __init__( if not self._grpc_channel: self._grpc_channel = type(self).create_channel( self._host, + # use the credentials which are saved credentials=self._credentials, - credentials_file=credentials_file, + # Set ``credentials_file`` to ``None`` here as + # the credentials that we saved earlier should be used. 
+ credentials_file=None, scopes=self._scopes, ssl_credentials=self._ssl_channel_credentials, quota_project_id=quota_project_id, diff --git a/google/cloud/aiplatform_v1/services/featurestore_online_serving_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1/services/featurestore_online_serving_service/transports/grpc_asyncio.py index b22b01939ab..80ad22e39a3 100644 --- a/google/cloud/aiplatform_v1/services/featurestore_online_serving_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1/services/featurestore_online_serving_service/transports/grpc_asyncio.py @@ -205,8 +205,11 @@ def __init__( if not self._grpc_channel: self._grpc_channel = type(self).create_channel( self._host, + # use the credentials which are saved credentials=self._credentials, - credentials_file=credentials_file, + # Set ``credentials_file`` to ``None`` here as + # the credentials that we saved earlier should be used. + credentials_file=None, scopes=self._scopes, ssl_credentials=self._ssl_channel_credentials, quota_project_id=quota_project_id, diff --git a/google/cloud/aiplatform_v1/services/featurestore_service/async_client.py b/google/cloud/aiplatform_v1/services/featurestore_service/async_client.py index 4f85bb6b07a..89bcf46a87b 100644 --- a/google/cloud/aiplatform_v1/services/featurestore_service/async_client.py +++ b/google/cloud/aiplatform_v1/services/featurestore_service/async_client.py @@ -16,7 +16,7 @@ from collections import OrderedDict import functools import re -from typing import Dict, Sequence, Tuple, Type, Union +from typing import Dict, Optional, Sequence, Tuple, Type, Union import pkg_resources from google.api_core.client_options import ClientOptions @@ -129,6 +129,42 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): from_service_account_json = from_service_account_file + @classmethod + def get_mtls_endpoint_and_cert_source( + cls, client_options: Optional[ClientOptions] = None + ): + """Return the API endpoint and client cert source 
for mutual TLS. + + The client cert source is determined in the following order: + (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the + client cert source is None. + (2) if `client_options.client_cert_source` is provided, use the provided one; if the + default client cert source exists, use the default one; otherwise the client cert + source is None. + + The API endpoint is determined in the following order: + (1) if `client_options.api_endpoint` if provided, use the provided one. + (2) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is "always", use the + default mTLS endpoint; if the environment variabel is "never", use the default API + endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise + use the default API endpoint. + + More details can be found at https://google.aip.dev/auth/4114. + + Args: + client_options (google.api_core.client_options.ClientOptions): Custom options for the + client. Only the `api_endpoint` and `client_cert_source` properties may be used + in this method. + + Returns: + Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the + client cert source to use. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If any errors happen. + """ + return FeaturestoreServiceClient.get_mtls_endpoint_and_cert_source(client_options) # type: ignore + @property def transport(self) -> FeaturestoreServiceTransport: """Returns the transport used by the client instance. @@ -253,7 +289,7 @@ async def create_featurestore( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. 
has_flattened_params = any([parent, featurestore, featurestore_id]) if request is not None and has_flattened_params: @@ -339,7 +375,7 @@ async def get_featurestore( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -414,7 +450,7 @@ async def list_featurestores( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -515,7 +551,7 @@ async def update_featurestore( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([featurestore, update_mask]) if request is not None and has_flattened_params: @@ -625,7 +661,7 @@ async def delete_featurestore( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name, force]) if request is not None and has_flattened_params: @@ -732,7 +768,7 @@ async def create_entity_type( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. 
has_flattened_params = any([parent, entity_type, entity_type_id]) if request is not None and has_flattened_params: @@ -819,7 +855,7 @@ async def get_entity_type( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -894,7 +930,7 @@ async def list_entity_types( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -997,7 +1033,7 @@ async def update_entity_type( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([entity_type, update_mask]) if request is not None and has_flattened_params: @@ -1098,7 +1134,7 @@ async def delete_entity_type( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name, force]) if request is not None and has_flattened_params: @@ -1204,7 +1240,7 @@ async def create_feature( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. 
has_flattened_params = any([parent, feature, feature_id]) if request is not None and has_flattened_params: @@ -1304,7 +1340,7 @@ async def batch_create_features( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, requests]) if request is not None and has_flattened_params: @@ -1388,7 +1424,7 @@ async def get_feature( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -1463,7 +1499,7 @@ async def list_features( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -1565,7 +1601,7 @@ async def update_feature( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([feature, update_mask]) if request is not None and has_flattened_params: @@ -1654,7 +1690,7 @@ async def delete_feature( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. 
has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -1759,7 +1795,7 @@ async def import_feature_values( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([entity_type]) if request is not None and has_flattened_params: @@ -1850,7 +1886,7 @@ async def batch_read_feature_values( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([featurestore]) if request is not None and has_flattened_params: @@ -1937,7 +1973,7 @@ async def export_feature_values( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([entity_type]) if request is not None and has_flattened_params: @@ -2099,7 +2135,7 @@ async def search_features( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. 
has_flattened_params = any([location, query]) if request is not None and has_flattened_params: diff --git a/google/cloud/aiplatform_v1/services/featurestore_service/client.py b/google/cloud/aiplatform_v1/services/featurestore_service/client.py index ea6a73b4afc..510f1080e7f 100644 --- a/google/cloud/aiplatform_v1/services/featurestore_service/client.py +++ b/google/cloud/aiplatform_v1/services/featurestore_service/client.py @@ -294,6 +294,73 @@ def parse_common_location_path(path: str) -> Dict[str, str]: m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) return m.groupdict() if m else {} + @classmethod + def get_mtls_endpoint_and_cert_source( + cls, client_options: Optional[client_options_lib.ClientOptions] = None + ): + """Return the API endpoint and client cert source for mutual TLS. + + The client cert source is determined in the following order: + (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the + client cert source is None. + (2) if `client_options.client_cert_source` is provided, use the provided one; if the + default client cert source exists, use the default one; otherwise the client cert + source is None. + + The API endpoint is determined in the following order: + (1) if `client_options.api_endpoint` if provided, use the provided one. + (2) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is "always", use the + default mTLS endpoint; if the environment variabel is "never", use the default API + endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise + use the default API endpoint. + + More details can be found at https://google.aip.dev/auth/4114. + + Args: + client_options (google.api_core.client_options.ClientOptions): Custom options for the + client. Only the `api_endpoint` and `client_cert_source` properties may be used + in this method. + + Returns: + Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the + client cert source to use. 
+ + Raises: + google.auth.exceptions.MutualTLSChannelError: If any errors happen. + """ + if client_options is None: + client_options = client_options_lib.ClientOptions() + use_client_cert = os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") + use_mtls_endpoint = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") + if use_client_cert not in ("true", "false"): + raise ValueError( + "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" + ) + if use_mtls_endpoint not in ("auto", "never", "always"): + raise MutualTLSChannelError( + "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" + ) + + # Figure out the client cert source to use. + client_cert_source = None + if use_client_cert == "true": + if client_options.client_cert_source: + client_cert_source = client_options.client_cert_source + elif mtls.has_default_client_cert_source(): + client_cert_source = mtls.default_client_cert_source() + + # Figure out which api endpoint to use. + if client_options.api_endpoint is not None: + api_endpoint = client_options.api_endpoint + elif use_mtls_endpoint == "always" or ( + use_mtls_endpoint == "auto" and client_cert_source + ): + api_endpoint = cls.DEFAULT_MTLS_ENDPOINT + else: + api_endpoint = cls.DEFAULT_ENDPOINT + + return api_endpoint, client_cert_source + def __init__( self, *, @@ -344,57 +411,22 @@ def __init__( if client_options is None: client_options = client_options_lib.ClientOptions() - # Create SSL credentials for mutual TLS if needed. 
- if os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") not in ( - "true", - "false", - ): - raise ValueError( - "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" - ) - use_client_cert = ( - os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") == "true" + api_endpoint, client_cert_source_func = self.get_mtls_endpoint_and_cert_source( + client_options ) - client_cert_source_func = None - is_mtls = False - if use_client_cert: - if client_options.client_cert_source: - is_mtls = True - client_cert_source_func = client_options.client_cert_source - else: - is_mtls = mtls.has_default_client_cert_source() - if is_mtls: - client_cert_source_func = mtls.default_client_cert_source() - else: - client_cert_source_func = None - - # Figure out which api endpoint to use. - if client_options.api_endpoint is not None: - api_endpoint = client_options.api_endpoint - else: - use_mtls_env = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") - if use_mtls_env == "never": - api_endpoint = self.DEFAULT_ENDPOINT - elif use_mtls_env == "always": - api_endpoint = self.DEFAULT_MTLS_ENDPOINT - elif use_mtls_env == "auto": - if is_mtls: - api_endpoint = self.DEFAULT_MTLS_ENDPOINT - else: - api_endpoint = self.DEFAULT_ENDPOINT - else: - raise MutualTLSChannelError( - "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted " - "values: never, auto, always" - ) + api_key_value = getattr(client_options, "api_key", None) + if api_key_value and credentials: + raise ValueError( + "client_options.api_key and credentials are mutually exclusive" + ) # Save or instantiate the transport. # Ordinarily, we provide the transport, but allowing a custom transport # instance provides an extensibility point for unusual situations. if isinstance(transport, FeaturestoreServiceTransport): # transport is a FeaturestoreServiceTransport instance. 
- if credentials or client_options.credentials_file: + if credentials or client_options.credentials_file or api_key_value: raise ValueError( "When providing a transport instance, " "provide its credentials directly." @@ -406,6 +438,15 @@ def __init__( ) self._transport = transport else: + import google.auth._default # type: ignore + + if api_key_value and hasattr( + google.auth._default, "get_api_key_credentials" + ): + credentials = google.auth._default.get_api_key_credentials( + api_key_value + ) + Transport = type(self).get_transport_class(transport) self._transport = Transport( credentials=credentials, @@ -481,7 +522,7 @@ def create_featurestore( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, featurestore, featurestore_id]) if request is not None and has_flattened_params: @@ -567,7 +608,7 @@ def get_featurestore( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -642,7 +683,7 @@ def list_featurestores( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -743,7 +784,7 @@ def update_featurestore( """ # Create or coerce a protobuf request object. 
- # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([featurestore, update_mask]) if request is not None and has_flattened_params: @@ -853,7 +894,7 @@ def delete_featurestore( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name, force]) if request is not None and has_flattened_params: @@ -960,7 +1001,7 @@ def create_entity_type( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, entity_type, entity_type_id]) if request is not None and has_flattened_params: @@ -1047,7 +1088,7 @@ def get_entity_type( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -1122,7 +1163,7 @@ def list_entity_types( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -1225,7 +1266,7 @@ def update_entity_type( """ # Create or coerce a protobuf request object. 
- # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([entity_type, update_mask]) if request is not None and has_flattened_params: @@ -1326,7 +1367,7 @@ def delete_entity_type( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name, force]) if request is not None and has_flattened_params: @@ -1432,7 +1473,7 @@ def create_feature( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, feature, feature_id]) if request is not None and has_flattened_params: @@ -1532,7 +1573,7 @@ def batch_create_features( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, requests]) if request is not None and has_flattened_params: @@ -1616,7 +1657,7 @@ def get_feature( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -1691,7 +1732,7 @@ def list_features( """ # Create or coerce a protobuf request object. 
- # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -1793,7 +1834,7 @@ def update_feature( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([feature, update_mask]) if request is not None and has_flattened_params: @@ -1882,7 +1923,7 @@ def delete_feature( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -1987,7 +2028,7 @@ def import_feature_values( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([entity_type]) if request is not None and has_flattened_params: @@ -2078,7 +2119,7 @@ def batch_read_feature_values( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([featurestore]) if request is not None and has_flattened_params: @@ -2167,7 +2208,7 @@ def export_feature_values( """ # Create or coerce a protobuf request object. 
- # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([entity_type]) if request is not None and has_flattened_params: @@ -2329,7 +2370,7 @@ def search_features( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([location, query]) if request is not None and has_flattened_params: diff --git a/google/cloud/aiplatform_v1/services/featurestore_service/transports/base.py b/google/cloud/aiplatform_v1/services/featurestore_service/transports/base.py index f4f71d58c9e..58ce81dd4ac 100644 --- a/google/cloud/aiplatform_v1/services/featurestore_service/transports/base.py +++ b/google/cloud/aiplatform_v1/services/featurestore_service/transports/base.py @@ -108,7 +108,6 @@ def __init__( credentials, _ = google.auth.load_credentials_from_file( credentials_file, **scopes_kwargs, quota_project_id=quota_project_id ) - elif credentials is None: credentials, _ = google.auth.default( **scopes_kwargs, quota_project_id=quota_project_id diff --git a/google/cloud/aiplatform_v1/services/featurestore_service/transports/grpc.py b/google/cloud/aiplatform_v1/services/featurestore_service/transports/grpc.py index 45f0ed6e3bb..7d94516d69b 100644 --- a/google/cloud/aiplatform_v1/services/featurestore_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1/services/featurestore_service/transports/grpc.py @@ -167,8 +167,11 @@ def __init__( if not self._grpc_channel: self._grpc_channel = type(self).create_channel( self._host, + # use the credentials which are saved credentials=self._credentials, - credentials_file=credentials_file, + # Set ``credentials_file`` to ``None`` here as + # the credentials that we saved earlier 
should be used. + credentials_file=None, scopes=self._scopes, ssl_credentials=self._ssl_channel_credentials, quota_project_id=quota_project_id, @@ -241,7 +244,7 @@ def operations_client(self) -> operations_v1.OperationsClient: This property caches on the instance; repeated calls return the same client. """ - # Sanity check: Only create a new client if we do not already have one. + # Quick check: Only create a new client if we do not already have one. if self._operations_client is None: self._operations_client = operations_v1.OperationsClient(self.grpc_channel) diff --git a/google/cloud/aiplatform_v1/services/featurestore_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1/services/featurestore_service/transports/grpc_asyncio.py index ec11a1f4b80..cf43b548756 100644 --- a/google/cloud/aiplatform_v1/services/featurestore_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1/services/featurestore_service/transports/grpc_asyncio.py @@ -212,8 +212,11 @@ def __init__( if not self._grpc_channel: self._grpc_channel = type(self).create_channel( self._host, + # use the credentials which are saved credentials=self._credentials, - credentials_file=credentials_file, + # Set ``credentials_file`` to ``None`` here as + # the credentials that we saved earlier should be used. + credentials_file=None, scopes=self._scopes, ssl_credentials=self._ssl_channel_credentials, quota_project_id=quota_project_id, @@ -243,7 +246,7 @@ def operations_client(self) -> operations_v1.OperationsAsyncClient: This property caches on the instance; repeated calls return the same client. """ - # Sanity check: Only create a new client if we do not already have one. + # Quick check: Only create a new client if we do not already have one. 
if self._operations_client is None: self._operations_client = operations_v1.OperationsAsyncClient( self.grpc_channel diff --git a/google/cloud/aiplatform_v1/services/index_endpoint_service/async_client.py b/google/cloud/aiplatform_v1/services/index_endpoint_service/async_client.py index 57f3b519f93..61766a2fa35 100644 --- a/google/cloud/aiplatform_v1/services/index_endpoint_service/async_client.py +++ b/google/cloud/aiplatform_v1/services/index_endpoint_service/async_client.py @@ -16,7 +16,7 @@ from collections import OrderedDict import functools import re -from typing import Dict, Sequence, Tuple, Type, Union +from typing import Dict, Optional, Sequence, Tuple, Type, Union import pkg_resources from google.api_core.client_options import ClientOptions @@ -118,6 +118,42 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): from_service_account_json = from_service_account_file + @classmethod + def get_mtls_endpoint_and_cert_source( + cls, client_options: Optional[ClientOptions] = None + ): + """Return the API endpoint and client cert source for mutual TLS. + + The client cert source is determined in the following order: + (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the + client cert source is None. + (2) if `client_options.client_cert_source` is provided, use the provided one; if the + default client cert source exists, use the default one; otherwise the client cert + source is None. + + The API endpoint is determined in the following order: + (1) if `client_options.api_endpoint` if provided, use the provided one. + (2) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is "always", use the + default mTLS endpoint; if the environment variabel is "never", use the default API + endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise + use the default API endpoint. + + More details can be found at https://google.aip.dev/auth/4114. 
+ + Args: + client_options (google.api_core.client_options.ClientOptions): Custom options for the + client. Only the `api_endpoint` and `client_cert_source` properties may be used + in this method. + + Returns: + Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the + client cert source to use. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If any errors happen. + """ + return IndexEndpointServiceClient.get_mtls_endpoint_and_cert_source(client_options) # type: ignore + @property def transport(self) -> IndexEndpointServiceTransport: """Returns the transport used by the client instance. @@ -225,7 +261,7 @@ async def create_index_endpoint( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, index_endpoint]) if request is not None and has_flattened_params: @@ -308,7 +344,7 @@ async def get_index_endpoint( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -383,7 +419,7 @@ async def list_index_endpoints( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -469,7 +505,7 @@ async def update_index_endpoint( """ # Create or coerce a protobuf request object. 
- # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([index_endpoint, update_mask]) if request is not None and has_flattened_params: @@ -558,7 +594,7 @@ async def delete_index_endpoint( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -652,7 +688,7 @@ async def deploy_index( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([index_endpoint, deployed_index]) if request is not None and has_flattened_params: @@ -750,7 +786,7 @@ async def undeploy_index( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([index_endpoint, deployed_index_id]) if request is not None and has_flattened_params: @@ -849,7 +885,7 @@ async def mutate_deployed_index( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. 
has_flattened_params = any([index_endpoint, deployed_index]) if request is not None and has_flattened_params: diff --git a/google/cloud/aiplatform_v1/services/index_endpoint_service/client.py b/google/cloud/aiplatform_v1/services/index_endpoint_service/client.py index 770a7989e45..55df2398e59 100644 --- a/google/cloud/aiplatform_v1/services/index_endpoint_service/client.py +++ b/google/cloud/aiplatform_v1/services/index_endpoint_service/client.py @@ -260,6 +260,73 @@ def parse_common_location_path(path: str) -> Dict[str, str]: m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) return m.groupdict() if m else {} + @classmethod + def get_mtls_endpoint_and_cert_source( + cls, client_options: Optional[client_options_lib.ClientOptions] = None + ): + """Return the API endpoint and client cert source for mutual TLS. + + The client cert source is determined in the following order: + (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the + client cert source is None. + (2) if `client_options.client_cert_source` is provided, use the provided one; if the + default client cert source exists, use the default one; otherwise the client cert + source is None. + + The API endpoint is determined in the following order: + (1) if `client_options.api_endpoint` if provided, use the provided one. + (2) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is "always", use the + default mTLS endpoint; if the environment variabel is "never", use the default API + endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise + use the default API endpoint. + + More details can be found at https://google.aip.dev/auth/4114. + + Args: + client_options (google.api_core.client_options.ClientOptions): Custom options for the + client. Only the `api_endpoint` and `client_cert_source` properties may be used + in this method. 
+ + Returns: + Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the + client cert source to use. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If any errors happen. + """ + if client_options is None: + client_options = client_options_lib.ClientOptions() + use_client_cert = os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") + use_mtls_endpoint = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") + if use_client_cert not in ("true", "false"): + raise ValueError( + "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" + ) + if use_mtls_endpoint not in ("auto", "never", "always"): + raise MutualTLSChannelError( + "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" + ) + + # Figure out the client cert source to use. + client_cert_source = None + if use_client_cert == "true": + if client_options.client_cert_source: + client_cert_source = client_options.client_cert_source + elif mtls.has_default_client_cert_source(): + client_cert_source = mtls.default_client_cert_source() + + # Figure out which api endpoint to use. + if client_options.api_endpoint is not None: + api_endpoint = client_options.api_endpoint + elif use_mtls_endpoint == "always" or ( + use_mtls_endpoint == "auto" and client_cert_source + ): + api_endpoint = cls.DEFAULT_MTLS_ENDPOINT + else: + api_endpoint = cls.DEFAULT_ENDPOINT + + return api_endpoint, client_cert_source + def __init__( self, *, @@ -310,57 +377,22 @@ def __init__( if client_options is None: client_options = client_options_lib.ClientOptions() - # Create SSL credentials for mutual TLS if needed. 
- if os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") not in ( - "true", - "false", - ): - raise ValueError( - "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" - ) - use_client_cert = ( - os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") == "true" + api_endpoint, client_cert_source_func = self.get_mtls_endpoint_and_cert_source( + client_options ) - client_cert_source_func = None - is_mtls = False - if use_client_cert: - if client_options.client_cert_source: - is_mtls = True - client_cert_source_func = client_options.client_cert_source - else: - is_mtls = mtls.has_default_client_cert_source() - if is_mtls: - client_cert_source_func = mtls.default_client_cert_source() - else: - client_cert_source_func = None - - # Figure out which api endpoint to use. - if client_options.api_endpoint is not None: - api_endpoint = client_options.api_endpoint - else: - use_mtls_env = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") - if use_mtls_env == "never": - api_endpoint = self.DEFAULT_ENDPOINT - elif use_mtls_env == "always": - api_endpoint = self.DEFAULT_MTLS_ENDPOINT - elif use_mtls_env == "auto": - if is_mtls: - api_endpoint = self.DEFAULT_MTLS_ENDPOINT - else: - api_endpoint = self.DEFAULT_ENDPOINT - else: - raise MutualTLSChannelError( - "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted " - "values: never, auto, always" - ) + api_key_value = getattr(client_options, "api_key", None) + if api_key_value and credentials: + raise ValueError( + "client_options.api_key and credentials are mutually exclusive" + ) # Save or instantiate the transport. # Ordinarily, we provide the transport, but allowing a custom transport # instance provides an extensibility point for unusual situations. if isinstance(transport, IndexEndpointServiceTransport): # transport is a IndexEndpointServiceTransport instance. 
- if credentials or client_options.credentials_file: + if credentials or client_options.credentials_file or api_key_value: raise ValueError( "When providing a transport instance, " "provide its credentials directly." @@ -372,6 +404,15 @@ def __init__( ) self._transport = transport else: + import google.auth._default # type: ignore + + if api_key_value and hasattr( + google.auth._default, "get_api_key_credentials" + ): + credentials = google.auth._default.get_api_key_credentials( + api_key_value + ) + Transport = type(self).get_transport_class(transport) self._transport = Transport( credentials=credentials, @@ -430,7 +471,7 @@ def create_index_endpoint( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, index_endpoint]) if request is not None and has_flattened_params: @@ -513,7 +554,7 @@ def get_index_endpoint( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -588,7 +629,7 @@ def list_index_endpoints( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -674,7 +715,7 @@ def update_index_endpoint( """ # Create or coerce a protobuf request object. 
- # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([index_endpoint, update_mask]) if request is not None and has_flattened_params: @@ -763,7 +804,7 @@ def delete_index_endpoint( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -857,7 +898,7 @@ def deploy_index( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([index_endpoint, deployed_index]) if request is not None and has_flattened_params: @@ -955,7 +996,7 @@ def undeploy_index( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([index_endpoint, deployed_index_id]) if request is not None and has_flattened_params: @@ -1054,7 +1095,7 @@ def mutate_deployed_index( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. 
has_flattened_params = any([index_endpoint, deployed_index]) if request is not None and has_flattened_params: diff --git a/google/cloud/aiplatform_v1/services/index_endpoint_service/transports/base.py b/google/cloud/aiplatform_v1/services/index_endpoint_service/transports/base.py index 094ccb33c8b..2363e8864b3 100644 --- a/google/cloud/aiplatform_v1/services/index_endpoint_service/transports/base.py +++ b/google/cloud/aiplatform_v1/services/index_endpoint_service/transports/base.py @@ -105,7 +105,6 @@ def __init__( credentials, _ = google.auth.load_credentials_from_file( credentials_file, **scopes_kwargs, quota_project_id=quota_project_id ) - elif credentials is None: credentials, _ = google.auth.default( **scopes_kwargs, quota_project_id=quota_project_id diff --git a/google/cloud/aiplatform_v1/services/index_endpoint_service/transports/grpc.py b/google/cloud/aiplatform_v1/services/index_endpoint_service/transports/grpc.py index 340a5d20579..74f1a4b1c34 100644 --- a/google/cloud/aiplatform_v1/services/index_endpoint_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1/services/index_endpoint_service/transports/grpc.py @@ -163,8 +163,11 @@ def __init__( if not self._grpc_channel: self._grpc_channel = type(self).create_channel( self._host, + # use the credentials which are saved credentials=self._credentials, - credentials_file=credentials_file, + # Set ``credentials_file`` to ``None`` here as + # the credentials that we saved earlier should be used. + credentials_file=None, scopes=self._scopes, ssl_credentials=self._ssl_channel_credentials, quota_project_id=quota_project_id, @@ -237,7 +240,7 @@ def operations_client(self) -> operations_v1.OperationsClient: This property caches on the instance; repeated calls return the same client. """ - # Sanity check: Only create a new client if we do not already have one. + # Quick check: Only create a new client if we do not already have one. 
if self._operations_client is None: self._operations_client = operations_v1.OperationsClient(self.grpc_channel) diff --git a/google/cloud/aiplatform_v1/services/index_endpoint_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1/services/index_endpoint_service/transports/grpc_asyncio.py index 928f1e09e09..bf492c172e7 100644 --- a/google/cloud/aiplatform_v1/services/index_endpoint_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1/services/index_endpoint_service/transports/grpc_asyncio.py @@ -208,8 +208,11 @@ def __init__( if not self._grpc_channel: self._grpc_channel = type(self).create_channel( self._host, + # use the credentials which are saved credentials=self._credentials, - credentials_file=credentials_file, + # Set ``credentials_file`` to ``None`` here as + # the credentials that we saved earlier should be used. + credentials_file=None, scopes=self._scopes, ssl_credentials=self._ssl_channel_credentials, quota_project_id=quota_project_id, @@ -239,7 +242,7 @@ def operations_client(self) -> operations_v1.OperationsAsyncClient: This property caches on the instance; repeated calls return the same client. """ - # Sanity check: Only create a new client if we do not already have one. + # Quick check: Only create a new client if we do not already have one. 
if self._operations_client is None:
self._operations_client = operations_v1.OperationsAsyncClient(
self.grpc_channel
diff --git a/google/cloud/aiplatform_v1/services/index_service/async_client.py b/google/cloud/aiplatform_v1/services/index_service/async_client.py
index 4c46cc9e983..6b15fe0b3f6 100644
--- a/google/cloud/aiplatform_v1/services/index_service/async_client.py
+++ b/google/cloud/aiplatform_v1/services/index_service/async_client.py
@@ -16,7 +16,7 @@ from collections import OrderedDict
import functools
import re
-from typing import Dict, Sequence, Tuple, Type, Union
+from typing import Dict, Optional, Sequence, Tuple, Type, Union
import pkg_resources
from google.api_core.client_options import ClientOptions
@@ -118,6 +118,42 @@ def from_service_account_file(cls, filename: str, *args, **kwargs):
from_service_account_json = from_service_account_file
+ @classmethod
+ def get_mtls_endpoint_and_cert_source(
+ cls, client_options: Optional[ClientOptions] = None
+ ):
+ """Return the API endpoint and client cert source for mutual TLS.
+
+ The client cert source is determined in the following order:
+ (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the
+ client cert source is None.
+ (2) if `client_options.client_cert_source` is provided, use the provided one; if the
+ default client cert source exists, use the default one; otherwise the client cert
+ source is None.
+
+ The API endpoint is determined in the following order:
+ (1) if `client_options.api_endpoint` is provided, use the provided one.
+ (2) if `GOOGLE_API_USE_MTLS_ENDPOINT` environment variable is "always", use the
+ default mTLS endpoint; if the environment variable is "never", use the default API
+ endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise
+ use the default API endpoint.
+
+ More details can be found at https://google.aip.dev/auth/4114.
+ + Args: + client_options (google.api_core.client_options.ClientOptions): Custom options for the + client. Only the `api_endpoint` and `client_cert_source` properties may be used + in this method. + + Returns: + Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the + client cert source to use. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If any errors happen. + """ + return IndexServiceClient.get_mtls_endpoint_and_cert_source(client_options) # type: ignore + @property def transport(self) -> IndexServiceTransport: """Returns the transport used by the client instance. @@ -223,7 +259,7 @@ async def create_index( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, index]) if request is not None and has_flattened_params: @@ -306,7 +342,7 @@ async def get_index( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -381,7 +417,7 @@ async def list_indexes( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -470,7 +506,7 @@ async def update_index( """ # Create or coerce a protobuf request object. 
- # Sanity check: If we got a request object, we should *not* have
+ # Quick check: If we got a request object, we should *not* have
# gotten any keyword arguments that map to the request.
has_flattened_params = any([index, update_mask])
if request is not None and has_flattened_params:
@@ -569,7 +605,7 @@ async def delete_index(
"""
# Create or coerce a protobuf request object.
- # Sanity check: If we got a request object, we should *not* have
+ # Quick check: If we got a request object, we should *not* have
# gotten any keyword arguments that map to the request.
has_flattened_params = any([name])
if request is not None and has_flattened_params:
diff --git a/google/cloud/aiplatform_v1/services/index_service/client.py b/google/cloud/aiplatform_v1/services/index_service/client.py
index 8d37766d2ad..ac5370dc4c7 100644
--- a/google/cloud/aiplatform_v1/services/index_service/client.py
+++ b/google/cloud/aiplatform_v1/services/index_service/client.py
@@ -260,6 +260,73 @@ def parse_common_location_path(path: str) -> Dict[str, str]:
m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path)
return m.groupdict() if m else {}
+ @classmethod
+ def get_mtls_endpoint_and_cert_source(
+ cls, client_options: Optional[client_options_lib.ClientOptions] = None
+ ):
+ """Return the API endpoint and client cert source for mutual TLS.
+
+ The client cert source is determined in the following order:
+ (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the
+ client cert source is None.
+ (2) if `client_options.client_cert_source` is provided, use the provided one; if the
+ default client cert source exists, use the default one; otherwise the client cert
+ source is None.
+
+ The API endpoint is determined in the following order:
+ (1) if `client_options.api_endpoint` is provided, use the provided one.
+ (2) if `GOOGLE_API_USE_MTLS_ENDPOINT` environment variable is "always", use the
+ default mTLS endpoint; if the environment variable is "never", use the default API
+ endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise
+ use the default API endpoint.
+
+ More details can be found at https://google.aip.dev/auth/4114.
+
+ Args:
+ client_options (google.api_core.client_options.ClientOptions): Custom options for the
+ client. Only the `api_endpoint` and `client_cert_source` properties may be used
+ in this method.
+
+ Returns:
+ Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the
+ client cert source to use.
+
+ Raises:
+ google.auth.exceptions.MutualTLSChannelError: If any errors happen.
+ """
+ if client_options is None:
+ client_options = client_options_lib.ClientOptions()
+ use_client_cert = os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")
+ use_mtls_endpoint = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto")
+ if use_client_cert not in ("true", "false"):
+ raise ValueError(
+ "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`"
+ )
+ if use_mtls_endpoint not in ("auto", "never", "always"):
+ raise MutualTLSChannelError(
+ "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`"
+ )
+
+ # Figure out the client cert source to use.
+ client_cert_source = None
+ if use_client_cert == "true":
+ if client_options.client_cert_source:
+ client_cert_source = client_options.client_cert_source
+ elif mtls.has_default_client_cert_source():
+ client_cert_source = mtls.default_client_cert_source()
+
+ # Figure out which api endpoint to use.
+ if client_options.api_endpoint is not None: + api_endpoint = client_options.api_endpoint + elif use_mtls_endpoint == "always" or ( + use_mtls_endpoint == "auto" and client_cert_source + ): + api_endpoint = cls.DEFAULT_MTLS_ENDPOINT + else: + api_endpoint = cls.DEFAULT_ENDPOINT + + return api_endpoint, client_cert_source + def __init__( self, *, @@ -310,57 +377,22 @@ def __init__( if client_options is None: client_options = client_options_lib.ClientOptions() - # Create SSL credentials for mutual TLS if needed. - if os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") not in ( - "true", - "false", - ): - raise ValueError( - "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" - ) - use_client_cert = ( - os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") == "true" + api_endpoint, client_cert_source_func = self.get_mtls_endpoint_and_cert_source( + client_options ) - client_cert_source_func = None - is_mtls = False - if use_client_cert: - if client_options.client_cert_source: - is_mtls = True - client_cert_source_func = client_options.client_cert_source - else: - is_mtls = mtls.has_default_client_cert_source() - if is_mtls: - client_cert_source_func = mtls.default_client_cert_source() - else: - client_cert_source_func = None - - # Figure out which api endpoint to use. - if client_options.api_endpoint is not None: - api_endpoint = client_options.api_endpoint - else: - use_mtls_env = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") - if use_mtls_env == "never": - api_endpoint = self.DEFAULT_ENDPOINT - elif use_mtls_env == "always": - api_endpoint = self.DEFAULT_MTLS_ENDPOINT - elif use_mtls_env == "auto": - if is_mtls: - api_endpoint = self.DEFAULT_MTLS_ENDPOINT - else: - api_endpoint = self.DEFAULT_ENDPOINT - else: - raise MutualTLSChannelError( - "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. 
Accepted " - "values: never, auto, always" - ) + api_key_value = getattr(client_options, "api_key", None) + if api_key_value and credentials: + raise ValueError( + "client_options.api_key and credentials are mutually exclusive" + ) # Save or instantiate the transport. # Ordinarily, we provide the transport, but allowing a custom transport # instance provides an extensibility point for unusual situations. if isinstance(transport, IndexServiceTransport): # transport is a IndexServiceTransport instance. - if credentials or client_options.credentials_file: + if credentials or client_options.credentials_file or api_key_value: raise ValueError( "When providing a transport instance, " "provide its credentials directly." @@ -372,6 +404,15 @@ def __init__( ) self._transport = transport else: + import google.auth._default # type: ignore + + if api_key_value and hasattr( + google.auth._default, "get_api_key_credentials" + ): + credentials = google.auth._default.get_api_key_credentials( + api_key_value + ) + Transport = type(self).get_transport_class(transport) self._transport = Transport( credentials=credentials, @@ -429,7 +470,7 @@ def create_index( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, index]) if request is not None and has_flattened_params: @@ -512,7 +553,7 @@ def get_index( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -587,7 +628,7 @@ def list_indexes( """ # Create or coerce a protobuf request object. 
- # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -676,7 +717,7 @@ def update_index( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([index, update_mask]) if request is not None and has_flattened_params: @@ -775,7 +816,7 @@ def delete_index( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: diff --git a/google/cloud/aiplatform_v1/services/index_service/transports/base.py b/google/cloud/aiplatform_v1/services/index_service/transports/base.py index c6d5eded3b3..52aad2799c7 100644 --- a/google/cloud/aiplatform_v1/services/index_service/transports/base.py +++ b/google/cloud/aiplatform_v1/services/index_service/transports/base.py @@ -104,7 +104,6 @@ def __init__( credentials, _ = google.auth.load_credentials_from_file( credentials_file, **scopes_kwargs, quota_project_id=quota_project_id ) - elif credentials is None: credentials, _ = google.auth.default( **scopes_kwargs, quota_project_id=quota_project_id diff --git a/google/cloud/aiplatform_v1/services/index_service/transports/grpc.py b/google/cloud/aiplatform_v1/services/index_service/transports/grpc.py index e0aa6a22f83..b2126476218 100644 --- a/google/cloud/aiplatform_v1/services/index_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1/services/index_service/transports/grpc.py @@ -163,8 +163,11 @@ def 
__init__( if not self._grpc_channel: self._grpc_channel = type(self).create_channel( self._host, + # use the credentials which are saved credentials=self._credentials, - credentials_file=credentials_file, + # Set ``credentials_file`` to ``None`` here as + # the credentials that we saved earlier should be used. + credentials_file=None, scopes=self._scopes, ssl_credentials=self._ssl_channel_credentials, quota_project_id=quota_project_id, @@ -237,7 +240,7 @@ def operations_client(self) -> operations_v1.OperationsClient: This property caches on the instance; repeated calls return the same client. """ - # Sanity check: Only create a new client if we do not already have one. + # Quick check: Only create a new client if we do not already have one. if self._operations_client is None: self._operations_client = operations_v1.OperationsClient(self.grpc_channel) diff --git a/google/cloud/aiplatform_v1/services/index_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1/services/index_service/transports/grpc_asyncio.py index 2f1880a659e..7d61fd82029 100644 --- a/google/cloud/aiplatform_v1/services/index_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1/services/index_service/transports/grpc_asyncio.py @@ -208,8 +208,11 @@ def __init__( if not self._grpc_channel: self._grpc_channel = type(self).create_channel( self._host, + # use the credentials which are saved credentials=self._credentials, - credentials_file=credentials_file, + # Set ``credentials_file`` to ``None`` here as + # the credentials that we saved earlier should be used. + credentials_file=None, scopes=self._scopes, ssl_credentials=self._ssl_channel_credentials, quota_project_id=quota_project_id, @@ -239,7 +242,7 @@ def operations_client(self) -> operations_v1.OperationsAsyncClient: This property caches on the instance; repeated calls return the same client. """ - # Sanity check: Only create a new client if we do not already have one. 
+ # Quick check: Only create a new client if we do not already have one.
if self._operations_client is None:
self._operations_client = operations_v1.OperationsAsyncClient(
self.grpc_channel
diff --git a/google/cloud/aiplatform_v1/services/job_service/async_client.py b/google/cloud/aiplatform_v1/services/job_service/async_client.py
index 00a77a87883..96a4158a1b0 100644
--- a/google/cloud/aiplatform_v1/services/job_service/async_client.py
+++ b/google/cloud/aiplatform_v1/services/job_service/async_client.py
@@ -16,7 +16,7 @@ from collections import OrderedDict
import functools
import re
-from typing import Dict, Sequence, Tuple, Type, Union
+from typing import Dict, Optional, Sequence, Tuple, Type, Union
import pkg_resources
from google.api_core.client_options import ClientOptions
@@ -168,6 +168,42 @@ def from_service_account_file(cls, filename: str, *args, **kwargs):
from_service_account_json = from_service_account_file
+ @classmethod
+ def get_mtls_endpoint_and_cert_source(
+ cls, client_options: Optional[ClientOptions] = None
+ ):
+ """Return the API endpoint and client cert source for mutual TLS.
+
+ The client cert source is determined in the following order:
+ (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the
+ client cert source is None.
+ (2) if `client_options.client_cert_source` is provided, use the provided one; if the
+ default client cert source exists, use the default one; otherwise the client cert
+ source is None.
+
+ The API endpoint is determined in the following order:
+ (1) if `client_options.api_endpoint` is provided, use the provided one.
+ (2) if `GOOGLE_API_USE_MTLS_ENDPOINT` environment variable is "always", use the
+ default mTLS endpoint; if the environment variable is "never", use the default API
+ endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise
+ use the default API endpoint.
+
+ More details can be found at https://google.aip.dev/auth/4114.
+ + Args: + client_options (google.api_core.client_options.ClientOptions): Custom options for the + client. Only the `api_endpoint` and `client_cert_source` properties may be used + in this method. + + Returns: + Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the + client cert source to use. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If any errors happen. + """ + return JobServiceClient.get_mtls_endpoint_and_cert_source(client_options) # type: ignore + @property def transport(self) -> JobServiceTransport: """Returns the transport used by the client instance. @@ -277,7 +313,7 @@ async def create_custom_job( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, custom_job]) if request is not None and has_flattened_params: @@ -356,7 +392,7 @@ async def get_custom_job( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -431,7 +467,7 @@ async def list_custom_jobs( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -522,7 +558,7 @@ async def delete_custom_job( """ # Create or coerce a protobuf request object. 
- # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -607,7 +643,7 @@ async def cancel_custom_job( sent along with the request as metadata. """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -686,7 +722,7 @@ async def create_data_labeling_job( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, data_labeling_job]) if request is not None and has_flattened_params: @@ -760,7 +796,7 @@ async def get_data_labeling_job( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -834,7 +870,7 @@ async def list_data_labeling_jobs( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -925,7 +961,7 @@ async def delete_data_labeling_job( """ # Create or coerce a protobuf request object. 
- # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -999,7 +1035,7 @@ async def cancel_data_labeling_job( sent along with the request as metadata. """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -1080,7 +1116,7 @@ async def create_hyperparameter_tuning_job( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, hyperparameter_tuning_job]) if request is not None and has_flattened_params: @@ -1156,7 +1192,7 @@ async def get_hyperparameter_tuning_job( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -1231,7 +1267,7 @@ async def list_hyperparameter_tuning_jobs( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. 
has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -1322,7 +1358,7 @@ async def delete_hyperparameter_tuning_job( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -1409,7 +1445,7 @@ async def cancel_hyperparameter_tuning_job( sent along with the request as metadata. """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -1493,7 +1529,7 @@ async def create_batch_prediction_job( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, batch_prediction_job]) if request is not None and has_flattened_params: @@ -1571,7 +1607,7 @@ async def get_batch_prediction_job( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -1646,7 +1682,7 @@ async def list_batch_prediction_jobs( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. 
has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -1738,7 +1774,7 @@ async def delete_batch_prediction_job( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -1823,7 +1859,7 @@ async def cancel_batch_prediction_job( sent along with the request as metadata. """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -1908,7 +1944,7 @@ async def create_model_deployment_monitoring_job( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, model_deployment_monitoring_job]) if request is not None and has_flattened_params: @@ -1996,7 +2032,7 @@ async def search_model_deployment_monitoring_stats_anomalies( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([model_deployment_monitoring_job, deployed_model_id]) if request is not None and has_flattened_params: @@ -2088,7 +2124,7 @@ async def get_model_deployment_monitoring_job( """ # Create or coerce a protobuf request object. 
- # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -2165,7 +2201,7 @@ async def list_model_deployment_monitoring_jobs( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -2282,7 +2318,7 @@ async def update_model_deployment_monitoring_job( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([model_deployment_monitoring_job, update_mask]) if request is not None and has_flattened_params: @@ -2386,7 +2422,7 @@ async def delete_model_deployment_monitoring_job( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -2465,7 +2501,7 @@ async def pause_model_deployment_monitoring_job( sent along with the request as metadata. """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. 
has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -2534,7 +2570,7 @@ async def resume_model_deployment_monitoring_job( sent along with the request as metadata. """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: diff --git a/google/cloud/aiplatform_v1/services/job_service/client.py b/google/cloud/aiplatform_v1/services/job_service/client.py index 91e232ef97c..65571cda774 100644 --- a/google/cloud/aiplatform_v1/services/job_service/client.py +++ b/google/cloud/aiplatform_v1/services/job_service/client.py @@ -441,6 +441,73 @@ def parse_common_location_path(path: str) -> Dict[str, str]: m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) return m.groupdict() if m else {} + @classmethod + def get_mtls_endpoint_and_cert_source( + cls, client_options: Optional[client_options_lib.ClientOptions] = None + ): + """Return the API endpoint and client cert source for mutual TLS. + + The client cert source is determined in the following order: + (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the + client cert source is None. + (2) if `client_options.client_cert_source` is provided, use the provided one; if the + default client cert source exists, use the default one; otherwise the client cert + source is None. + + The API endpoint is determined in the following order: + (1) if `client_options.api_endpoint` is provided, use the provided one. + (2) if `GOOGLE_API_USE_MTLS_ENDPOINT` environment variable is "always", use the + default mTLS endpoint; if the environment variable is "never", use the default API + endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise + use the default API endpoint.
+ + More details can be found at https://google.aip.dev/auth/4114. + + Args: + client_options (google.api_core.client_options.ClientOptions): Custom options for the + client. Only the `api_endpoint` and `client_cert_source` properties may be used + in this method. + + Returns: + Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the + client cert source to use. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If any errors happen. + """ + if client_options is None: + client_options = client_options_lib.ClientOptions() + use_client_cert = os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") + use_mtls_endpoint = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") + if use_client_cert not in ("true", "false"): + raise ValueError( + "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" + ) + if use_mtls_endpoint not in ("auto", "never", "always"): + raise MutualTLSChannelError( + "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" + ) + + # Figure out the client cert source to use. + client_cert_source = None + if use_client_cert == "true": + if client_options.client_cert_source: + client_cert_source = client_options.client_cert_source + elif mtls.has_default_client_cert_source(): + client_cert_source = mtls.default_client_cert_source() + + # Figure out which api endpoint to use. + if client_options.api_endpoint is not None: + api_endpoint = client_options.api_endpoint + elif use_mtls_endpoint == "always" or ( + use_mtls_endpoint == "auto" and client_cert_source + ): + api_endpoint = cls.DEFAULT_MTLS_ENDPOINT + else: + api_endpoint = cls.DEFAULT_ENDPOINT + + return api_endpoint, client_cert_source + def __init__( self, *, @@ -491,57 +558,22 @@ def __init__( if client_options is None: client_options = client_options_lib.ClientOptions() - # Create SSL credentials for mutual TLS if needed. 
- if os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") not in ( - "true", - "false", - ): - raise ValueError( - "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" - ) - use_client_cert = ( - os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") == "true" + api_endpoint, client_cert_source_func = self.get_mtls_endpoint_and_cert_source( + client_options ) - client_cert_source_func = None - is_mtls = False - if use_client_cert: - if client_options.client_cert_source: - is_mtls = True - client_cert_source_func = client_options.client_cert_source - else: - is_mtls = mtls.has_default_client_cert_source() - if is_mtls: - client_cert_source_func = mtls.default_client_cert_source() - else: - client_cert_source_func = None - - # Figure out which api endpoint to use. - if client_options.api_endpoint is not None: - api_endpoint = client_options.api_endpoint - else: - use_mtls_env = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") - if use_mtls_env == "never": - api_endpoint = self.DEFAULT_ENDPOINT - elif use_mtls_env == "always": - api_endpoint = self.DEFAULT_MTLS_ENDPOINT - elif use_mtls_env == "auto": - if is_mtls: - api_endpoint = self.DEFAULT_MTLS_ENDPOINT - else: - api_endpoint = self.DEFAULT_ENDPOINT - else: - raise MutualTLSChannelError( - "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted " - "values: never, auto, always" - ) + api_key_value = getattr(client_options, "api_key", None) + if api_key_value and credentials: + raise ValueError( + "client_options.api_key and credentials are mutually exclusive" + ) # Save or instantiate the transport. # Ordinarily, we provide the transport, but allowing a custom transport # instance provides an extensibility point for unusual situations. if isinstance(transport, JobServiceTransport): # transport is a JobServiceTransport instance. 
- if credentials or client_options.credentials_file: + if credentials or client_options.credentials_file or api_key_value: raise ValueError( "When providing a transport instance, " "provide its credentials directly." @@ -553,6 +585,15 @@ def __init__( ) self._transport = transport else: + import google.auth._default # type: ignore + + if api_key_value and hasattr( + google.auth._default, "get_api_key_credentials" + ): + credentials = google.auth._default.get_api_key_credentials( + api_key_value + ) + Transport = type(self).get_transport_class(transport) self._transport = Transport( credentials=credentials, @@ -614,7 +655,7 @@ def create_custom_job( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, custom_job]) if request is not None and has_flattened_params: @@ -693,7 +734,7 @@ def get_custom_job( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -768,7 +809,7 @@ def list_custom_jobs( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -859,7 +900,7 @@ def delete_custom_job( """ # Create or coerce a protobuf request object. 
- # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -944,7 +985,7 @@ def cancel_custom_job( sent along with the request as metadata. """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -1023,7 +1064,7 @@ def create_data_labeling_job( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, data_labeling_job]) if request is not None and has_flattened_params: @@ -1097,7 +1138,7 @@ def get_data_labeling_job( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -1171,7 +1212,7 @@ def list_data_labeling_jobs( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -1262,7 +1303,7 @@ def delete_data_labeling_job( """ # Create or coerce a protobuf request object. 
- # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -1336,7 +1377,7 @@ def cancel_data_labeling_job( sent along with the request as metadata. """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -1417,7 +1458,7 @@ def create_hyperparameter_tuning_job( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, hyperparameter_tuning_job]) if request is not None and has_flattened_params: @@ -1495,7 +1536,7 @@ def get_hyperparameter_tuning_job( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -1572,7 +1613,7 @@ def list_hyperparameter_tuning_jobs( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -1665,7 +1706,7 @@ def delete_hyperparameter_tuning_job( """ # Create or coerce a protobuf request object. 
- # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -1754,7 +1795,7 @@ def cancel_hyperparameter_tuning_job( sent along with the request as metadata. """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -1840,7 +1881,7 @@ def create_batch_prediction_job( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, batch_prediction_job]) if request is not None and has_flattened_params: @@ -1920,7 +1961,7 @@ def get_batch_prediction_job( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -1995,7 +2036,7 @@ def list_batch_prediction_jobs( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -2089,7 +2130,7 @@ def delete_batch_prediction_job( """ # Create or coerce a protobuf request object. 
- # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -2176,7 +2217,7 @@ def cancel_batch_prediction_job( sent along with the request as metadata. """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -2263,7 +2304,7 @@ def create_model_deployment_monitoring_job( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, model_deployment_monitoring_job]) if request is not None and has_flattened_params: @@ -2357,7 +2398,7 @@ def search_model_deployment_monitoring_stats_anomalies( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([model_deployment_monitoring_job, deployed_model_id]) if request is not None and has_flattened_params: @@ -2455,7 +2496,7 @@ def get_model_deployment_monitoring_job( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. 
has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -2534,7 +2575,7 @@ def list_model_deployment_monitoring_jobs( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -2655,7 +2696,7 @@ def update_model_deployment_monitoring_job( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([model_deployment_monitoring_job, update_mask]) if request is not None and has_flattened_params: @@ -2765,7 +2806,7 @@ def delete_model_deployment_monitoring_job( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -2848,7 +2889,7 @@ def pause_model_deployment_monitoring_job( sent along with the request as metadata. """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -2921,7 +2962,7 @@ def resume_model_deployment_monitoring_job( sent along with the request as metadata. """ # Create or coerce a protobuf request object. 
- # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: diff --git a/google/cloud/aiplatform_v1/services/job_service/transports/base.py b/google/cloud/aiplatform_v1/services/job_service/transports/base.py index e94b1702483..fd11d859f31 100644 --- a/google/cloud/aiplatform_v1/services/job_service/transports/base.py +++ b/google/cloud/aiplatform_v1/services/job_service/transports/base.py @@ -120,7 +120,6 @@ def __init__( credentials, _ = google.auth.load_credentials_from_file( credentials_file, **scopes_kwargs, quota_project_id=quota_project_id ) - elif credentials is None: credentials, _ = google.auth.default( **scopes_kwargs, quota_project_id=quota_project_id diff --git a/google/cloud/aiplatform_v1/services/job_service/transports/grpc.py b/google/cloud/aiplatform_v1/services/job_service/transports/grpc.py index 8313d01b853..701d3f57e91 100644 --- a/google/cloud/aiplatform_v1/services/job_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1/services/job_service/transports/grpc.py @@ -178,8 +178,11 @@ def __init__( if not self._grpc_channel: self._grpc_channel = type(self).create_channel( self._host, + # use the credentials which are saved credentials=self._credentials, - credentials_file=credentials_file, + # Set ``credentials_file`` to ``None`` here as + # the credentials that we saved earlier should be used. + credentials_file=None, scopes=self._scopes, ssl_credentials=self._ssl_channel_credentials, quota_project_id=quota_project_id, @@ -252,7 +255,7 @@ def operations_client(self) -> operations_v1.OperationsClient: This property caches on the instance; repeated calls return the same client. """ - # Sanity check: Only create a new client if we do not already have one. + # Quick check: Only create a new client if we do not already have one. 
if self._operations_client is None: self._operations_client = operations_v1.OperationsClient(self.grpc_channel) diff --git a/google/cloud/aiplatform_v1/services/job_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1/services/job_service/transports/grpc_asyncio.py index 524f69d5fa2..9c0327ffad0 100644 --- a/google/cloud/aiplatform_v1/services/job_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1/services/job_service/transports/grpc_asyncio.py @@ -223,8 +223,11 @@ def __init__( if not self._grpc_channel: self._grpc_channel = type(self).create_channel( self._host, + # use the credentials which are saved credentials=self._credentials, - credentials_file=credentials_file, + # Set ``credentials_file`` to ``None`` here as + # the credentials that we saved earlier should be used. + credentials_file=None, scopes=self._scopes, ssl_credentials=self._ssl_channel_credentials, quota_project_id=quota_project_id, @@ -254,7 +257,7 @@ def operations_client(self) -> operations_v1.OperationsAsyncClient: This property caches on the instance; repeated calls return the same client. """ - # Sanity check: Only create a new client if we do not already have one. + # Quick check: Only create a new client if we do not already have one. 
if self._operations_client is None: self._operations_client = operations_v1.OperationsAsyncClient( self.grpc_channel diff --git a/google/cloud/aiplatform_v1/services/metadata_service/async_client.py b/google/cloud/aiplatform_v1/services/metadata_service/async_client.py index 3acff1b4086..a7ce7c5617e 100644 --- a/google/cloud/aiplatform_v1/services/metadata_service/async_client.py +++ b/google/cloud/aiplatform_v1/services/metadata_service/async_client.py @@ -16,7 +16,7 @@ from collections import OrderedDict import functools import re -from typing import Dict, Sequence, Tuple, Type, Union +from typing import Dict, Optional, Sequence, Tuple, Type, Union import pkg_resources from google.api_core.client_options import ClientOptions @@ -138,6 +138,42 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): from_service_account_json = from_service_account_file + @classmethod + def get_mtls_endpoint_and_cert_source( + cls, client_options: Optional[ClientOptions] = None + ): + """Return the API endpoint and client cert source for mutual TLS. + + The client cert source is determined in the following order: + (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the + client cert source is None. + (2) if `client_options.client_cert_source` is provided, use the provided one; if the + default client cert source exists, use the default one; otherwise the client cert + source is None. + + The API endpoint is determined in the following order: + (1) if `client_options.api_endpoint` is provided, use the provided one. + (2) if `GOOGLE_API_USE_MTLS_ENDPOINT` environment variable is "always", use the + default mTLS endpoint; if the environment variable is "never", use the default API + endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise + use the default API endpoint. + + More details can be found at https://google.aip.dev/auth/4114.
+ + Args: + client_options (google.api_core.client_options.ClientOptions): Custom options for the + client. Only the `api_endpoint` and `client_cert_source` properties may be used + in this method. + + Returns: + Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the + client cert source to use. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If any errors happen. + """ + return MetadataServiceClient.get_mtls_endpoint_and_cert_source(client_options) # type: ignore + @property def transport(self) -> MetadataServiceTransport: """Returns the transport used by the client instance. @@ -261,7 +297,7 @@ async def create_metadata_store( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, metadata_store, metadata_store_id]) if request is not None and has_flattened_params: @@ -346,7 +382,7 @@ async def get_metadata_store( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -421,7 +457,7 @@ async def list_metadata_stores( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -513,7 +549,7 @@ async def delete_metadata_store( """ # Create or coerce a protobuf request object. 
- # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -613,7 +649,7 @@ async def create_artifact( Instance of a general artifact. """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, artifact, artifact_id]) if request is not None and has_flattened_params: @@ -687,7 +723,7 @@ async def get_artifact( Instance of a general artifact. """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -762,7 +798,7 @@ async def list_artifacts( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -851,7 +887,7 @@ async def update_artifact( Instance of a general artifact. """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([artifact, update_mask]) if request is not None and has_flattened_params: @@ -940,7 +976,7 @@ async def delete_artifact( """ # Create or coerce a protobuf request object. 
- # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -1024,7 +1060,7 @@ async def purge_artifacts( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -1124,7 +1160,7 @@ async def create_context( Instance of a general context. """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, context, context_id]) if request is not None and has_flattened_params: @@ -1198,7 +1234,7 @@ async def get_context( Instance of a general context. """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -1273,7 +1309,7 @@ async def list_contexts( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -1361,7 +1397,7 @@ async def update_context( Instance of a general context. """ # Create or coerce a protobuf request object. 
- # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([context, update_mask]) if request is not None and has_flattened_params: @@ -1450,7 +1486,7 @@ async def delete_context( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -1534,7 +1570,7 @@ async def purge_contexts( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -1640,7 +1676,7 @@ async def add_context_artifacts_and_executions( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([context, artifacts, executions]) if request is not None and has_flattened_params: @@ -1729,7 +1765,7 @@ async def add_context_children( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([context, child_contexts]) if request is not None and has_flattened_params: @@ -1814,7 +1850,7 @@ async def query_context_lineage_subgraph( """ # Create or coerce a protobuf request object. 
- # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([context]) if request is not None and has_flattened_params: @@ -1906,7 +1942,7 @@ async def create_execution( Instance of a general execution. """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, execution, execution_id]) if request is not None and has_flattened_params: @@ -1980,7 +2016,7 @@ async def get_execution( Instance of a general execution. """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -2055,7 +2091,7 @@ async def list_executions( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -2144,7 +2180,7 @@ async def update_execution( Instance of a general execution. """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. 
has_flattened_params = any([execution, update_mask]) if request is not None and has_flattened_params: @@ -2233,7 +2269,7 @@ async def delete_execution( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -2317,7 +2353,7 @@ async def purge_executions( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -2407,7 +2443,7 @@ async def add_execution_events( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([execution, events]) if request is not None and has_flattened_params: @@ -2489,7 +2525,7 @@ async def query_execution_inputs_and_outputs( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([execution]) if request is not None and has_flattened_params: @@ -2585,7 +2621,7 @@ async def create_metadata_schema( Instance of a general MetadataSchema. """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. 
has_flattened_params = any([parent, metadata_schema, metadata_schema_id]) if request is not None and has_flattened_params: @@ -2659,7 +2695,7 @@ async def get_metadata_schema( Instance of a general MetadataSchema. """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -2734,7 +2770,7 @@ async def list_metadata_schemas( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -2823,7 +2859,7 @@ async def query_artifact_lineage_subgraph( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. 
has_flattened_params = any([artifact]) if request is not None and has_flattened_params: diff --git a/google/cloud/aiplatform_v1/services/metadata_service/client.py b/google/cloud/aiplatform_v1/services/metadata_service/client.py index d3f74607bf8..4b2f2ee956c 100644 --- a/google/cloud/aiplatform_v1/services/metadata_service/client.py +++ b/google/cloud/aiplatform_v1/services/metadata_service/client.py @@ -338,6 +338,73 @@ def parse_common_location_path(path: str) -> Dict[str, str]: m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) return m.groupdict() if m else {} + @classmethod + def get_mtls_endpoint_and_cert_source( + cls, client_options: Optional[client_options_lib.ClientOptions] = None + ): + """Return the API endpoint and client cert source for mutual TLS. + + The client cert source is determined in the following order: + (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the + client cert source is None. + (2) if `client_options.client_cert_source` is provided, use the provided one; if the + default client cert source exists, use the default one; otherwise the client cert + source is None. + + The API endpoint is determined in the following order: + (1) if `client_options.api_endpoint` is provided, use the provided one. + (2) if `GOOGLE_API_USE_MTLS_ENDPOINT` environment variable is "always", use the + default mTLS endpoint; if the environment variable is "never", use the default API + endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise + use the default API endpoint. + + More details can be found at https://google.aip.dev/auth/4114. + + Args: + client_options (google.api_core.client_options.ClientOptions): Custom options for the + client. Only the `api_endpoint` and `client_cert_source` properties may be used + in this method. + + Returns: + Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the + client cert source to use.
+ + Raises: + google.auth.exceptions.MutualTLSChannelError: If any errors happen. + """ + if client_options is None: + client_options = client_options_lib.ClientOptions() + use_client_cert = os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") + use_mtls_endpoint = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") + if use_client_cert not in ("true", "false"): + raise ValueError( + "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" + ) + if use_mtls_endpoint not in ("auto", "never", "always"): + raise MutualTLSChannelError( + "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" + ) + + # Figure out the client cert source to use. + client_cert_source = None + if use_client_cert == "true": + if client_options.client_cert_source: + client_cert_source = client_options.client_cert_source + elif mtls.has_default_client_cert_source(): + client_cert_source = mtls.default_client_cert_source() + + # Figure out which api endpoint to use. + if client_options.api_endpoint is not None: + api_endpoint = client_options.api_endpoint + elif use_mtls_endpoint == "always" or ( + use_mtls_endpoint == "auto" and client_cert_source + ): + api_endpoint = cls.DEFAULT_MTLS_ENDPOINT + else: + api_endpoint = cls.DEFAULT_ENDPOINT + + return api_endpoint, client_cert_source + def __init__( self, *, @@ -388,57 +455,22 @@ def __init__( if client_options is None: client_options = client_options_lib.ClientOptions() - # Create SSL credentials for mutual TLS if needed. 
- if os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") not in ( - "true", - "false", - ): - raise ValueError( - "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" - ) - use_client_cert = ( - os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") == "true" + api_endpoint, client_cert_source_func = self.get_mtls_endpoint_and_cert_source( + client_options ) - client_cert_source_func = None - is_mtls = False - if use_client_cert: - if client_options.client_cert_source: - is_mtls = True - client_cert_source_func = client_options.client_cert_source - else: - is_mtls = mtls.has_default_client_cert_source() - if is_mtls: - client_cert_source_func = mtls.default_client_cert_source() - else: - client_cert_source_func = None - - # Figure out which api endpoint to use. - if client_options.api_endpoint is not None: - api_endpoint = client_options.api_endpoint - else: - use_mtls_env = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") - if use_mtls_env == "never": - api_endpoint = self.DEFAULT_ENDPOINT - elif use_mtls_env == "always": - api_endpoint = self.DEFAULT_MTLS_ENDPOINT - elif use_mtls_env == "auto": - if is_mtls: - api_endpoint = self.DEFAULT_MTLS_ENDPOINT - else: - api_endpoint = self.DEFAULT_ENDPOINT - else: - raise MutualTLSChannelError( - "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted " - "values: never, auto, always" - ) + api_key_value = getattr(client_options, "api_key", None) + if api_key_value and credentials: + raise ValueError( + "client_options.api_key and credentials are mutually exclusive" + ) # Save or instantiate the transport. # Ordinarily, we provide the transport, but allowing a custom transport # instance provides an extensibility point for unusual situations. if isinstance(transport, MetadataServiceTransport): # transport is a MetadataServiceTransport instance. 
- if credentials or client_options.credentials_file: + if credentials or client_options.credentials_file or api_key_value: raise ValueError( "When providing a transport instance, " "provide its credentials directly." @@ -450,6 +482,15 @@ def __init__( ) self._transport = transport else: + import google.auth._default # type: ignore + + if api_key_value and hasattr( + google.auth._default, "get_api_key_credentials" + ): + credentials = google.auth._default.get_api_key_credentials( + api_key_value + ) + Transport = type(self).get_transport_class(transport) self._transport = Transport( credentials=credentials, @@ -525,7 +566,7 @@ def create_metadata_store( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, metadata_store, metadata_store_id]) if request is not None and has_flattened_params: @@ -610,7 +651,7 @@ def get_metadata_store( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -685,7 +726,7 @@ def list_metadata_stores( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -777,7 +818,7 @@ def delete_metadata_store( """ # Create or coerce a protobuf request object. 
- # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -877,7 +918,7 @@ def create_artifact( Instance of a general artifact. """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, artifact, artifact_id]) if request is not None and has_flattened_params: @@ -951,7 +992,7 @@ def get_artifact( Instance of a general artifact. """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -1026,7 +1067,7 @@ def list_artifacts( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -1115,7 +1156,7 @@ def update_artifact( Instance of a general artifact. """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([artifact, update_mask]) if request is not None and has_flattened_params: @@ -1204,7 +1245,7 @@ def delete_artifact( """ # Create or coerce a protobuf request object. 
- # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -1288,7 +1329,7 @@ def purge_artifacts( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -1388,7 +1429,7 @@ def create_context( Instance of a general context. """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, context, context_id]) if request is not None and has_flattened_params: @@ -1462,7 +1503,7 @@ def get_context( Instance of a general context. """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -1537,7 +1578,7 @@ def list_contexts( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -1625,7 +1666,7 @@ def update_context( Instance of a general context. """ # Create or coerce a protobuf request object. 
- # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([context, update_mask]) if request is not None and has_flattened_params: @@ -1714,7 +1755,7 @@ def delete_context( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -1798,7 +1839,7 @@ def purge_contexts( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -1904,7 +1945,7 @@ def add_context_artifacts_and_executions( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([context, artifacts, executions]) if request is not None and has_flattened_params: @@ -1997,7 +2038,7 @@ def add_context_children( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([context, child_contexts]) if request is not None and has_flattened_params: @@ -2082,7 +2123,7 @@ def query_context_lineage_subgraph( """ # Create or coerce a protobuf request object. 
- # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([context]) if request is not None and has_flattened_params: @@ -2176,7 +2217,7 @@ def create_execution( Instance of a general execution. """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, execution, execution_id]) if request is not None and has_flattened_params: @@ -2250,7 +2291,7 @@ def get_execution( Instance of a general execution. """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -2325,7 +2366,7 @@ def list_executions( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -2414,7 +2455,7 @@ def update_execution( Instance of a general execution. """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([execution, update_mask]) if request is not None and has_flattened_params: @@ -2503,7 +2544,7 @@ def delete_execution( """ # Create or coerce a protobuf request object. 
- # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -2587,7 +2628,7 @@ def purge_executions( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -2677,7 +2718,7 @@ def add_execution_events( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([execution, events]) if request is not None and has_flattened_params: @@ -2759,7 +2800,7 @@ def query_execution_inputs_and_outputs( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([execution]) if request is not None and has_flattened_params: @@ -2859,7 +2900,7 @@ def create_metadata_schema( Instance of a general MetadataSchema. """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, metadata_schema, metadata_schema_id]) if request is not None and has_flattened_params: @@ -2933,7 +2974,7 @@ def get_metadata_schema( Instance of a general MetadataSchema. 
""" # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -3008,7 +3049,7 @@ def list_metadata_schemas( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -3097,7 +3138,7 @@ def query_artifact_lineage_subgraph( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([artifact]) if request is not None and has_flattened_params: diff --git a/google/cloud/aiplatform_v1/services/metadata_service/transports/base.py b/google/cloud/aiplatform_v1/services/metadata_service/transports/base.py index 1680f3f1f94..b1c0e631770 100644 --- a/google/cloud/aiplatform_v1/services/metadata_service/transports/base.py +++ b/google/cloud/aiplatform_v1/services/metadata_service/transports/base.py @@ -113,7 +113,6 @@ def __init__( credentials, _ = google.auth.load_credentials_from_file( credentials_file, **scopes_kwargs, quota_project_id=quota_project_id ) - elif credentials is None: credentials, _ = google.auth.default( **scopes_kwargs, quota_project_id=quota_project_id diff --git a/google/cloud/aiplatform_v1/services/metadata_service/transports/grpc.py b/google/cloud/aiplatform_v1/services/metadata_service/transports/grpc.py index 3451898345f..20965f20ecc 100644 --- a/google/cloud/aiplatform_v1/services/metadata_service/transports/grpc.py +++ 
b/google/cloud/aiplatform_v1/services/metadata_service/transports/grpc.py @@ -171,8 +171,11 @@ def __init__( if not self._grpc_channel: self._grpc_channel = type(self).create_channel( self._host, + # use the credentials which are saved credentials=self._credentials, - credentials_file=credentials_file, + # Set ``credentials_file`` to ``None`` here as + # the credentials that we saved earlier should be used. + credentials_file=None, scopes=self._scopes, ssl_credentials=self._ssl_channel_credentials, quota_project_id=quota_project_id, @@ -245,7 +248,7 @@ def operations_client(self) -> operations_v1.OperationsClient: This property caches on the instance; repeated calls return the same client. """ - # Sanity check: Only create a new client if we do not already have one. + # Quick check: Only create a new client if we do not already have one. if self._operations_client is None: self._operations_client = operations_v1.OperationsClient(self.grpc_channel) diff --git a/google/cloud/aiplatform_v1/services/metadata_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1/services/metadata_service/transports/grpc_asyncio.py index e029a8b7826..4653c6b8b54 100644 --- a/google/cloud/aiplatform_v1/services/metadata_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1/services/metadata_service/transports/grpc_asyncio.py @@ -216,8 +216,11 @@ def __init__( if not self._grpc_channel: self._grpc_channel = type(self).create_channel( self._host, + # use the credentials which are saved credentials=self._credentials, - credentials_file=credentials_file, + # Set ``credentials_file`` to ``None`` here as + # the credentials that we saved earlier should be used. + credentials_file=None, scopes=self._scopes, ssl_credentials=self._ssl_channel_credentials, quota_project_id=quota_project_id, @@ -247,7 +250,7 @@ def operations_client(self) -> operations_v1.OperationsAsyncClient: This property caches on the instance; repeated calls return the same client. 
""" - # Sanity check: Only create a new client if we do not already have one. + # Quick check: Only create a new client if we do not already have one. if self._operations_client is None: self._operations_client = operations_v1.OperationsAsyncClient( self.grpc_channel diff --git a/google/cloud/aiplatform_v1/services/migration_service/async_client.py b/google/cloud/aiplatform_v1/services/migration_service/async_client.py index 8ff44ad8bdd..05fcd67854f 100644 --- a/google/cloud/aiplatform_v1/services/migration_service/async_client.py +++ b/google/cloud/aiplatform_v1/services/migration_service/async_client.py @@ -16,7 +16,7 @@ from collections import OrderedDict import functools import re -from typing import Dict, Sequence, Tuple, Type, Union +from typing import Dict, Optional, Sequence, Tuple, Type, Union import pkg_resources from google.api_core.client_options import ClientOptions @@ -125,6 +125,42 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): from_service_account_json = from_service_account_file + @classmethod + def get_mtls_endpoint_and_cert_source( + cls, client_options: Optional[ClientOptions] = None + ): + """Return the API endpoint and client cert source for mutual TLS. + + The client cert source is determined in the following order: + (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the + client cert source is None. + (2) if `client_options.client_cert_source` is provided, use the provided one; if the + default client cert source exists, use the default one; otherwise the client cert + source is None. + + The API endpoint is determined in the following order: + (1) if `client_options.api_endpoint` is provided, use the provided one.
+ (2) if `GOOGLE_API_USE_MTLS_ENDPOINT` environment variable is "always", use the + default mTLS endpoint; if the environment variable is "never", use the default API + endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise + use the default API endpoint. + + More details can be found at https://google.aip.dev/auth/4114. + + Args: + client_options (google.api_core.client_options.ClientOptions): Custom options for the + client. Only the `api_endpoint` and `client_cert_source` properties may be used + in this method. + + Returns: + Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the + client cert source to use. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If any errors happen. + """ + return MigrationServiceClient.get_mtls_endpoint_and_cert_source(client_options) # type: ignore + @property def transport(self) -> MigrationServiceTransport: """Returns the transport used by the client instance. @@ -229,7 +265,7 @@ async def search_migratable_resources( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -326,7 +362,7 @@ async def batch_migrate_resources( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request.
has_flattened_params = any([parent, migrate_resource_requests]) if request is not None and has_flattened_params: diff --git a/google/cloud/aiplatform_v1/services/migration_service/client.py b/google/cloud/aiplatform_v1/services/migration_service/client.py index 8e509a04a6e..f50dbfb2f31 100644 --- a/google/cloud/aiplatform_v1/services/migration_service/client.py +++ b/google/cloud/aiplatform_v1/services/migration_service/client.py @@ -199,32 +199,32 @@ def parse_dataset_path(path: str) -> Dict[str, str]: return m.groupdict() if m else {} @staticmethod - def dataset_path(project: str, dataset: str,) -> str: + def dataset_path(project: str, location: str, dataset: str,) -> str: """Returns a fully-qualified dataset string.""" - return "projects/{project}/datasets/{dataset}".format( - project=project, dataset=dataset, + return "projects/{project}/locations/{location}/datasets/{dataset}".format( + project=project, location=location, dataset=dataset, ) @staticmethod def parse_dataset_path(path: str) -> Dict[str, str]: """Parses a dataset path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/datasets/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def dataset_path(project: str, location: str, dataset: str,) -> str: + def dataset_path(project: str, dataset: str,) -> str: """Returns a fully-qualified dataset string.""" - return "projects/{project}/locations/{location}/datasets/{dataset}".format( - project=project, location=location, dataset=dataset, + return "projects/{project}/datasets/{dataset}".format( + project=project, dataset=dataset, ) @staticmethod def parse_dataset_path(path: str) -> Dict[str, str]: """Parses a dataset path into its component segments.""" - m = re.match( - r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)$", - path, - ) + m = re.match(r"^projects/(?P.+?)/datasets/(?P.+?)$", path) return m.groupdict() if m else {} 
@staticmethod @@ -334,6 +334,73 @@ def parse_common_location_path(path: str) -> Dict[str, str]: m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) return m.groupdict() if m else {} + @classmethod + def get_mtls_endpoint_and_cert_source( + cls, client_options: Optional[client_options_lib.ClientOptions] = None + ): + """Return the API endpoint and client cert source for mutual TLS. + + The client cert source is determined in the following order: + (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the + client cert source is None. + (2) if `client_options.client_cert_source` is provided, use the provided one; if the + default client cert source exists, use the default one; otherwise the client cert + source is None. + + The API endpoint is determined in the following order: + (1) if `client_options.api_endpoint` if provided, use the provided one. + (2) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is "always", use the + default mTLS endpoint; if the environment variabel is "never", use the default API + endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise + use the default API endpoint. + + More details can be found at https://google.aip.dev/auth/4114. + + Args: + client_options (google.api_core.client_options.ClientOptions): Custom options for the + client. Only the `api_endpoint` and `client_cert_source` properties may be used + in this method. + + Returns: + Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the + client cert source to use. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If any errors happen. 
+ """ + if client_options is None: + client_options = client_options_lib.ClientOptions() + use_client_cert = os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") + use_mtls_endpoint = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") + if use_client_cert not in ("true", "false"): + raise ValueError( + "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" + ) + if use_mtls_endpoint not in ("auto", "never", "always"): + raise MutualTLSChannelError( + "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" + ) + + # Figure out the client cert source to use. + client_cert_source = None + if use_client_cert == "true": + if client_options.client_cert_source: + client_cert_source = client_options.client_cert_source + elif mtls.has_default_client_cert_source(): + client_cert_source = mtls.default_client_cert_source() + + # Figure out which api endpoint to use. + if client_options.api_endpoint is not None: + api_endpoint = client_options.api_endpoint + elif use_mtls_endpoint == "always" or ( + use_mtls_endpoint == "auto" and client_cert_source + ): + api_endpoint = cls.DEFAULT_MTLS_ENDPOINT + else: + api_endpoint = cls.DEFAULT_ENDPOINT + + return api_endpoint, client_cert_source + def __init__( self, *, @@ -384,57 +451,22 @@ def __init__( if client_options is None: client_options = client_options_lib.ClientOptions() - # Create SSL credentials for mutual TLS if needed. 
- if os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") not in ( - "true", - "false", - ): - raise ValueError( - "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" - ) - use_client_cert = ( - os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") == "true" + api_endpoint, client_cert_source_func = self.get_mtls_endpoint_and_cert_source( + client_options ) - client_cert_source_func = None - is_mtls = False - if use_client_cert: - if client_options.client_cert_source: - is_mtls = True - client_cert_source_func = client_options.client_cert_source - else: - is_mtls = mtls.has_default_client_cert_source() - if is_mtls: - client_cert_source_func = mtls.default_client_cert_source() - else: - client_cert_source_func = None - - # Figure out which api endpoint to use. - if client_options.api_endpoint is not None: - api_endpoint = client_options.api_endpoint - else: - use_mtls_env = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") - if use_mtls_env == "never": - api_endpoint = self.DEFAULT_ENDPOINT - elif use_mtls_env == "always": - api_endpoint = self.DEFAULT_MTLS_ENDPOINT - elif use_mtls_env == "auto": - if is_mtls: - api_endpoint = self.DEFAULT_MTLS_ENDPOINT - else: - api_endpoint = self.DEFAULT_ENDPOINT - else: - raise MutualTLSChannelError( - "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted " - "values: never, auto, always" - ) + api_key_value = getattr(client_options, "api_key", None) + if api_key_value and credentials: + raise ValueError( + "client_options.api_key and credentials are mutually exclusive" + ) # Save or instantiate the transport. # Ordinarily, we provide the transport, but allowing a custom transport # instance provides an extensibility point for unusual situations. if isinstance(transport, MigrationServiceTransport): # transport is a MigrationServiceTransport instance. 
- if credentials or client_options.credentials_file: + if credentials or client_options.credentials_file or api_key_value: raise ValueError( "When providing a transport instance, " "provide its credentials directly." @@ -446,6 +478,15 @@ def __init__( ) self._transport = transport else: + import google.auth._default # type: ignore + + if api_key_value and hasattr( + google.auth._default, "get_api_key_credentials" + ): + credentials = google.auth._default.get_api_key_credentials( + api_key_value + ) + Transport = type(self).get_transport_class(transport) self._transport = Transport( credentials=credentials, @@ -502,7 +543,7 @@ def search_migratable_resources( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -601,7 +642,7 @@ def batch_migrate_resources( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. 
has_flattened_params = any([parent, migrate_resource_requests]) if request is not None and has_flattened_params: diff --git a/google/cloud/aiplatform_v1/services/migration_service/transports/base.py b/google/cloud/aiplatform_v1/services/migration_service/transports/base.py index 147276408a0..4eee97b2acb 100644 --- a/google/cloud/aiplatform_v1/services/migration_service/transports/base.py +++ b/google/cloud/aiplatform_v1/services/migration_service/transports/base.py @@ -103,7 +103,6 @@ def __init__( credentials, _ = google.auth.load_credentials_from_file( credentials_file, **scopes_kwargs, quota_project_id=quota_project_id ) - elif credentials is None: credentials, _ = google.auth.default( **scopes_kwargs, quota_project_id=quota_project_id diff --git a/google/cloud/aiplatform_v1/services/migration_service/transports/grpc.py b/google/cloud/aiplatform_v1/services/migration_service/transports/grpc.py index d705ebe72ce..2fe561b5967 100644 --- a/google/cloud/aiplatform_v1/services/migration_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1/services/migration_service/transports/grpc.py @@ -162,8 +162,11 @@ def __init__( if not self._grpc_channel: self._grpc_channel = type(self).create_channel( self._host, + # use the credentials which are saved credentials=self._credentials, - credentials_file=credentials_file, + # Set ``credentials_file`` to ``None`` here as + # the credentials that we saved earlier should be used. + credentials_file=None, scopes=self._scopes, ssl_credentials=self._ssl_channel_credentials, quota_project_id=quota_project_id, @@ -236,7 +239,7 @@ def operations_client(self) -> operations_v1.OperationsClient: This property caches on the instance; repeated calls return the same client. """ - # Sanity check: Only create a new client if we do not already have one. + # Quick check: Only create a new client if we do not already have one. 
if self._operations_client is None: self._operations_client = operations_v1.OperationsClient(self.grpc_channel) diff --git a/google/cloud/aiplatform_v1/services/migration_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1/services/migration_service/transports/grpc_asyncio.py index b355ae80c1b..1386630c9f5 100644 --- a/google/cloud/aiplatform_v1/services/migration_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1/services/migration_service/transports/grpc_asyncio.py @@ -207,8 +207,11 @@ def __init__( if not self._grpc_channel: self._grpc_channel = type(self).create_channel( self._host, + # use the credentials which are saved credentials=self._credentials, - credentials_file=credentials_file, + # Set ``credentials_file`` to ``None`` here as + # the credentials that we saved earlier should be used. + credentials_file=None, scopes=self._scopes, ssl_credentials=self._ssl_channel_credentials, quota_project_id=quota_project_id, @@ -238,7 +241,7 @@ def operations_client(self) -> operations_v1.OperationsAsyncClient: This property caches on the instance; repeated calls return the same client. """ - # Sanity check: Only create a new client if we do not already have one. + # Quick check: Only create a new client if we do not already have one. 
if self._operations_client is None: self._operations_client = operations_v1.OperationsAsyncClient( self.grpc_channel diff --git a/google/cloud/aiplatform_v1/services/model_service/async_client.py b/google/cloud/aiplatform_v1/services/model_service/async_client.py index 8c3caae9b22..ef8a6ea146d 100644 --- a/google/cloud/aiplatform_v1/services/model_service/async_client.py +++ b/google/cloud/aiplatform_v1/services/model_service/async_client.py @@ -16,7 +16,7 @@ from collections import OrderedDict import functools import re -from typing import Dict, Sequence, Tuple, Type, Union +from typing import Dict, Optional, Sequence, Tuple, Type, Union import pkg_resources from google.api_core.client_options import ClientOptions @@ -132,6 +132,42 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): from_service_account_json = from_service_account_file + @classmethod + def get_mtls_endpoint_and_cert_source( + cls, client_options: Optional[ClientOptions] = None + ): + """Return the API endpoint and client cert source for mutual TLS. + + The client cert source is determined in the following order: + (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the + client cert source is None. + (2) if `client_options.client_cert_source` is provided, use the provided one; if the + default client cert source exists, use the default one; otherwise the client cert + source is None. + + The API endpoint is determined in the following order: + (1) if `client_options.api_endpoint` if provided, use the provided one. + (2) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is "always", use the + default mTLS endpoint; if the environment variabel is "never", use the default API + endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise + use the default API endpoint. + + More details can be found at https://google.aip.dev/auth/4114. 
+ + Args: + client_options (google.api_core.client_options.ClientOptions): Custom options for the + client. Only the `api_endpoint` and `client_cert_source` properties may be used + in this method. + + Returns: + Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the + client cert source to use. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If any errors happen. + """ + return ModelServiceClient.get_mtls_endpoint_and_cert_source(client_options) # type: ignore + @property def transport(self) -> ModelServiceTransport: """Returns the transport used by the client instance. @@ -239,7 +275,7 @@ async def upload_model( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, model]) if request is not None and has_flattened_params: @@ -318,7 +354,7 @@ async def get_model( A trained machine learning Model. """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -393,7 +429,7 @@ async def list_models( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -477,7 +513,7 @@ async def update_model( A trained machine learning Model. """ # Create or coerce a protobuf request object. 
- # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([model, update_mask]) if request is not None and has_flattened_params: @@ -573,7 +609,7 @@ async def delete_model( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -668,7 +704,7 @@ async def export_model( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name, output_config]) if request is not None and has_flattened_params: @@ -752,7 +788,7 @@ async def get_model_evaluation( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -827,7 +863,7 @@ async def list_model_evaluations( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -907,7 +943,7 @@ async def get_model_evaluation_slice( """ # Create or coerce a protobuf request object. 
- # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -982,7 +1018,7 @@ async def list_model_evaluation_slices( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: diff --git a/google/cloud/aiplatform_v1/services/model_service/client.py b/google/cloud/aiplatform_v1/services/model_service/client.py index 60bdad30c54..c2ba20b31fa 100644 --- a/google/cloud/aiplatform_v1/services/model_service/client.py +++ b/google/cloud/aiplatform_v1/services/model_service/client.py @@ -320,6 +320,73 @@ def parse_common_location_path(path: str) -> Dict[str, str]: m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) return m.groupdict() if m else {} + @classmethod + def get_mtls_endpoint_and_cert_source( + cls, client_options: Optional[client_options_lib.ClientOptions] = None + ): + """Return the API endpoint and client cert source for mutual TLS. + + The client cert source is determined in the following order: + (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the + client cert source is None. + (2) if `client_options.client_cert_source` is provided, use the provided one; if the + default client cert source exists, use the default one; otherwise the client cert + source is None. + + The API endpoint is determined in the following order: + (1) if `client_options.api_endpoint` if provided, use the provided one. 
+ (2) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is "always", use the + default mTLS endpoint; if the environment variabel is "never", use the default API + endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise + use the default API endpoint. + + More details can be found at https://google.aip.dev/auth/4114. + + Args: + client_options (google.api_core.client_options.ClientOptions): Custom options for the + client. Only the `api_endpoint` and `client_cert_source` properties may be used + in this method. + + Returns: + Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the + client cert source to use. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If any errors happen. + """ + if client_options is None: + client_options = client_options_lib.ClientOptions() + use_client_cert = os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") + use_mtls_endpoint = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") + if use_client_cert not in ("true", "false"): + raise ValueError( + "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" + ) + if use_mtls_endpoint not in ("auto", "never", "always"): + raise MutualTLSChannelError( + "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" + ) + + # Figure out the client cert source to use. + client_cert_source = None + if use_client_cert == "true": + if client_options.client_cert_source: + client_cert_source = client_options.client_cert_source + elif mtls.has_default_client_cert_source(): + client_cert_source = mtls.default_client_cert_source() + + # Figure out which api endpoint to use. 
+ if client_options.api_endpoint is not None: + api_endpoint = client_options.api_endpoint + elif use_mtls_endpoint == "always" or ( + use_mtls_endpoint == "auto" and client_cert_source + ): + api_endpoint = cls.DEFAULT_MTLS_ENDPOINT + else: + api_endpoint = cls.DEFAULT_ENDPOINT + + return api_endpoint, client_cert_source + def __init__( self, *, @@ -370,57 +437,22 @@ def __init__( if client_options is None: client_options = client_options_lib.ClientOptions() - # Create SSL credentials for mutual TLS if needed. - if os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") not in ( - "true", - "false", - ): - raise ValueError( - "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" - ) - use_client_cert = ( - os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") == "true" + api_endpoint, client_cert_source_func = self.get_mtls_endpoint_and_cert_source( + client_options ) - client_cert_source_func = None - is_mtls = False - if use_client_cert: - if client_options.client_cert_source: - is_mtls = True - client_cert_source_func = client_options.client_cert_source - else: - is_mtls = mtls.has_default_client_cert_source() - if is_mtls: - client_cert_source_func = mtls.default_client_cert_source() - else: - client_cert_source_func = None - - # Figure out which api endpoint to use. - if client_options.api_endpoint is not None: - api_endpoint = client_options.api_endpoint - else: - use_mtls_env = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") - if use_mtls_env == "never": - api_endpoint = self.DEFAULT_ENDPOINT - elif use_mtls_env == "always": - api_endpoint = self.DEFAULT_MTLS_ENDPOINT - elif use_mtls_env == "auto": - if is_mtls: - api_endpoint = self.DEFAULT_MTLS_ENDPOINT - else: - api_endpoint = self.DEFAULT_ENDPOINT - else: - raise MutualTLSChannelError( - "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. 
Accepted " - "values: never, auto, always" - ) + api_key_value = getattr(client_options, "api_key", None) + if api_key_value and credentials: + raise ValueError( + "client_options.api_key and credentials are mutually exclusive" + ) # Save or instantiate the transport. # Ordinarily, we provide the transport, but allowing a custom transport # instance provides an extensibility point for unusual situations. if isinstance(transport, ModelServiceTransport): # transport is a ModelServiceTransport instance. - if credentials or client_options.credentials_file: + if credentials or client_options.credentials_file or api_key_value: raise ValueError( "When providing a transport instance, " "provide its credentials directly." @@ -432,6 +464,15 @@ def __init__( ) self._transport = transport else: + import google.auth._default # type: ignore + + if api_key_value and hasattr( + google.auth._default, "get_api_key_credentials" + ): + credentials = google.auth._default.get_api_key_credentials( + api_key_value + ) + Transport = type(self).get_transport_class(transport) self._transport = Transport( credentials=credentials, @@ -491,7 +532,7 @@ def upload_model( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, model]) if request is not None and has_flattened_params: @@ -570,7 +611,7 @@ def get_model( A trained machine learning Model. """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -645,7 +686,7 @@ def list_models( """ # Create or coerce a protobuf request object. 
- # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -729,7 +770,7 @@ def update_model( A trained machine learning Model. """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([model, update_mask]) if request is not None and has_flattened_params: @@ -825,7 +866,7 @@ def delete_model( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -920,7 +961,7 @@ def export_model( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name, output_config]) if request is not None and has_flattened_params: @@ -1004,7 +1045,7 @@ def get_model_evaluation( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -1079,7 +1120,7 @@ def list_model_evaluations( """ # Create or coerce a protobuf request object. 
- # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -1159,7 +1200,7 @@ def get_model_evaluation_slice( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -1236,7 +1277,7 @@ def list_model_evaluation_slices( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: diff --git a/google/cloud/aiplatform_v1/services/model_service/transports/base.py b/google/cloud/aiplatform_v1/services/model_service/transports/base.py index 5615336d38a..bf146699d00 100644 --- a/google/cloud/aiplatform_v1/services/model_service/transports/base.py +++ b/google/cloud/aiplatform_v1/services/model_service/transports/base.py @@ -107,7 +107,6 @@ def __init__( credentials, _ = google.auth.load_credentials_from_file( credentials_file, **scopes_kwargs, quota_project_id=quota_project_id ) - elif credentials is None: credentials, _ = google.auth.default( **scopes_kwargs, quota_project_id=quota_project_id diff --git a/google/cloud/aiplatform_v1/services/model_service/transports/grpc.py b/google/cloud/aiplatform_v1/services/model_service/transports/grpc.py index 91751eef777..ad3b1954d34 100644 --- a/google/cloud/aiplatform_v1/services/model_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1/services/model_service/transports/grpc.py @@ 
-165,8 +165,11 @@ def __init__( if not self._grpc_channel: self._grpc_channel = type(self).create_channel( self._host, + # use the credentials which are saved credentials=self._credentials, - credentials_file=credentials_file, + # Set ``credentials_file`` to ``None`` here as + # the credentials that we saved earlier should be used. + credentials_file=None, scopes=self._scopes, ssl_credentials=self._ssl_channel_credentials, quota_project_id=quota_project_id, @@ -239,7 +242,7 @@ def operations_client(self) -> operations_v1.OperationsClient: This property caches on the instance; repeated calls return the same client. """ - # Sanity check: Only create a new client if we do not already have one. + # Quick check: Only create a new client if we do not already have one. if self._operations_client is None: self._operations_client = operations_v1.OperationsClient(self.grpc_channel) diff --git a/google/cloud/aiplatform_v1/services/model_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1/services/model_service/transports/grpc_asyncio.py index 71ae4b287e9..171f2aada58 100644 --- a/google/cloud/aiplatform_v1/services/model_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1/services/model_service/transports/grpc_asyncio.py @@ -210,8 +210,11 @@ def __init__( if not self._grpc_channel: self._grpc_channel = type(self).create_channel( self._host, + # use the credentials which are saved credentials=self._credentials, - credentials_file=credentials_file, + # Set ``credentials_file`` to ``None`` here as + # the credentials that we saved earlier should be used. + credentials_file=None, scopes=self._scopes, ssl_credentials=self._ssl_channel_credentials, quota_project_id=quota_project_id, @@ -241,7 +244,7 @@ def operations_client(self) -> operations_v1.OperationsAsyncClient: This property caches on the instance; repeated calls return the same client. """ - # Sanity check: Only create a new client if we do not already have one. 
+ # Quick check: Only create a new client if we do not already have one. if self._operations_client is None: self._operations_client = operations_v1.OperationsAsyncClient( self.grpc_channel diff --git a/google/cloud/aiplatform_v1/services/pipeline_service/async_client.py b/google/cloud/aiplatform_v1/services/pipeline_service/async_client.py index c13faa5438b..b48636f058e 100644 --- a/google/cloud/aiplatform_v1/services/pipeline_service/async_client.py +++ b/google/cloud/aiplatform_v1/services/pipeline_service/async_client.py @@ -16,7 +16,7 @@ from collections import OrderedDict import functools import re -from typing import Dict, Sequence, Tuple, Type, Union +from typing import Dict, Optional, Sequence, Tuple, Type, Union import pkg_resources from google.api_core.client_options import ClientOptions @@ -144,6 +144,42 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): from_service_account_json = from_service_account_file + @classmethod + def get_mtls_endpoint_and_cert_source( + cls, client_options: Optional[ClientOptions] = None + ): + """Return the API endpoint and client cert source for mutual TLS. + + The client cert source is determined in the following order: + (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the + client cert source is None. + (2) if `client_options.client_cert_source` is provided, use the provided one; if the + default client cert source exists, use the default one; otherwise the client cert + source is None. + + The API endpoint is determined in the following order: + (1) if `client_options.api_endpoint` if provided, use the provided one. + (2) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is "always", use the + default mTLS endpoint; if the environment variabel is "never", use the default API + endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise + use the default API endpoint. 
+ + More details can be found at https://google.aip.dev/auth/4114. + + Args: + client_options (google.api_core.client_options.ClientOptions): Custom options for the + client. Only the `api_endpoint` and `client_cert_source` properties may be used + in this method. + + Returns: + Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the + client cert source to use. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If any errors happen. + """ + return PipelineServiceClient.get_mtls_endpoint_and_cert_source(client_options) # type: ignore + @property def transport(self) -> PipelineServiceTransport: """Returns the transport used by the client instance. @@ -253,7 +289,7 @@ async def create_training_pipeline( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, training_pipeline]) if request is not None and has_flattened_params: @@ -331,7 +367,7 @@ async def get_training_pipeline( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -406,7 +442,7 @@ async def list_training_pipelines( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -497,7 +533,7 @@ async def delete_training_pipeline( """ # Create or coerce a protobuf request object. 
- # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -583,7 +619,7 @@ async def cancel_training_pipeline( sent along with the request as metadata. """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -673,7 +709,7 @@ async def create_pipeline_job( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, pipeline_job, pipeline_job_id]) if request is not None and has_flattened_params: @@ -748,7 +784,7 @@ async def get_pipeline_job( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -823,7 +859,7 @@ async def list_pipeline_jobs( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -914,7 +950,7 @@ async def delete_pipeline_job( """ # Create or coerce a protobuf request object. 
- # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -999,7 +1035,7 @@ async def cancel_pipeline_job( sent along with the request as metadata. """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: diff --git a/google/cloud/aiplatform_v1/services/pipeline_service/client.py b/google/cloud/aiplatform_v1/services/pipeline_service/client.py index ee9754086e4..9781ff020a2 100644 --- a/google/cloud/aiplatform_v1/services/pipeline_service/client.py +++ b/google/cloud/aiplatform_v1/services/pipeline_service/client.py @@ -396,6 +396,73 @@ def parse_common_location_path(path: str) -> Dict[str, str]: m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) return m.groupdict() if m else {} + @classmethod + def get_mtls_endpoint_and_cert_source( + cls, client_options: Optional[client_options_lib.ClientOptions] = None + ): + """Return the API endpoint and client cert source for mutual TLS. + + The client cert source is determined in the following order: + (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the + client cert source is None. + (2) if `client_options.client_cert_source` is provided, use the provided one; if the + default client cert source exists, use the default one; otherwise the client cert + source is None. + + The API endpoint is determined in the following order: + (1) if `client_options.api_endpoint` if provided, use the provided one. 
+ (2) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is "always", use the + default mTLS endpoint; if the environment variabel is "never", use the default API + endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise + use the default API endpoint. + + More details can be found at https://google.aip.dev/auth/4114. + + Args: + client_options (google.api_core.client_options.ClientOptions): Custom options for the + client. Only the `api_endpoint` and `client_cert_source` properties may be used + in this method. + + Returns: + Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the + client cert source to use. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If any errors happen. + """ + if client_options is None: + client_options = client_options_lib.ClientOptions() + use_client_cert = os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") + use_mtls_endpoint = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") + if use_client_cert not in ("true", "false"): + raise ValueError( + "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" + ) + if use_mtls_endpoint not in ("auto", "never", "always"): + raise MutualTLSChannelError( + "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" + ) + + # Figure out the client cert source to use. + client_cert_source = None + if use_client_cert == "true": + if client_options.client_cert_source: + client_cert_source = client_options.client_cert_source + elif mtls.has_default_client_cert_source(): + client_cert_source = mtls.default_client_cert_source() + + # Figure out which api endpoint to use. 
+ if client_options.api_endpoint is not None: + api_endpoint = client_options.api_endpoint + elif use_mtls_endpoint == "always" or ( + use_mtls_endpoint == "auto" and client_cert_source + ): + api_endpoint = cls.DEFAULT_MTLS_ENDPOINT + else: + api_endpoint = cls.DEFAULT_ENDPOINT + + return api_endpoint, client_cert_source + def __init__( self, *, @@ -446,57 +513,22 @@ def __init__( if client_options is None: client_options = client_options_lib.ClientOptions() - # Create SSL credentials for mutual TLS if needed. - if os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") not in ( - "true", - "false", - ): - raise ValueError( - "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" - ) - use_client_cert = ( - os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") == "true" + api_endpoint, client_cert_source_func = self.get_mtls_endpoint_and_cert_source( + client_options ) - client_cert_source_func = None - is_mtls = False - if use_client_cert: - if client_options.client_cert_source: - is_mtls = True - client_cert_source_func = client_options.client_cert_source - else: - is_mtls = mtls.has_default_client_cert_source() - if is_mtls: - client_cert_source_func = mtls.default_client_cert_source() - else: - client_cert_source_func = None - - # Figure out which api endpoint to use. - if client_options.api_endpoint is not None: - api_endpoint = client_options.api_endpoint - else: - use_mtls_env = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") - if use_mtls_env == "never": - api_endpoint = self.DEFAULT_ENDPOINT - elif use_mtls_env == "always": - api_endpoint = self.DEFAULT_MTLS_ENDPOINT - elif use_mtls_env == "auto": - if is_mtls: - api_endpoint = self.DEFAULT_MTLS_ENDPOINT - else: - api_endpoint = self.DEFAULT_ENDPOINT - else: - raise MutualTLSChannelError( - "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. 
Accepted " - "values: never, auto, always" - ) + api_key_value = getattr(client_options, "api_key", None) + if api_key_value and credentials: + raise ValueError( + "client_options.api_key and credentials are mutually exclusive" + ) # Save or instantiate the transport. # Ordinarily, we provide the transport, but allowing a custom transport # instance provides an extensibility point for unusual situations. if isinstance(transport, PipelineServiceTransport): # transport is a PipelineServiceTransport instance. - if credentials or client_options.credentials_file: + if credentials or client_options.credentials_file or api_key_value: raise ValueError( "When providing a transport instance, " "provide its credentials directly." @@ -508,6 +540,15 @@ def __init__( ) self._transport = transport else: + import google.auth._default # type: ignore + + if api_key_value and hasattr( + google.auth._default, "get_api_key_credentials" + ): + credentials = google.auth._default.get_api_key_credentials( + api_key_value + ) + Transport = type(self).get_transport_class(transport) self._transport = Transport( credentials=credentials, @@ -569,7 +610,7 @@ def create_training_pipeline( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, training_pipeline]) if request is not None and has_flattened_params: @@ -647,7 +688,7 @@ def get_training_pipeline( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -722,7 +763,7 @@ def list_training_pipelines( """ # Create or coerce a protobuf request object. 
- # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -813,7 +854,7 @@ def delete_training_pipeline( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -899,7 +940,7 @@ def cancel_training_pipeline( sent along with the request as metadata. """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -989,7 +1030,7 @@ def create_pipeline_job( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, pipeline_job, pipeline_job_id]) if request is not None and has_flattened_params: @@ -1064,7 +1105,7 @@ def get_pipeline_job( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -1139,7 +1180,7 @@ def list_pipeline_jobs( """ # Create or coerce a protobuf request object. 
- # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -1230,7 +1271,7 @@ def delete_pipeline_job( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -1315,7 +1356,7 @@ def cancel_pipeline_job( sent along with the request as metadata. """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: diff --git a/google/cloud/aiplatform_v1/services/pipeline_service/transports/base.py b/google/cloud/aiplatform_v1/services/pipeline_service/transports/base.py index 3da341fbb76..b61c52dda08 100644 --- a/google/cloud/aiplatform_v1/services/pipeline_service/transports/base.py +++ b/google/cloud/aiplatform_v1/services/pipeline_service/transports/base.py @@ -108,7 +108,6 @@ def __init__( credentials, _ = google.auth.load_credentials_from_file( credentials_file, **scopes_kwargs, quota_project_id=quota_project_id ) - elif credentials is None: credentials, _ = google.auth.default( **scopes_kwargs, quota_project_id=quota_project_id diff --git a/google/cloud/aiplatform_v1/services/pipeline_service/transports/grpc.py b/google/cloud/aiplatform_v1/services/pipeline_service/transports/grpc.py index 87d0dafdb1b..19e4a67c9fc 100644 --- a/google/cloud/aiplatform_v1/services/pipeline_service/transports/grpc.py +++ 
b/google/cloud/aiplatform_v1/services/pipeline_service/transports/grpc.py @@ -169,8 +169,11 @@ def __init__( if not self._grpc_channel: self._grpc_channel = type(self).create_channel( self._host, + # use the credentials which are saved credentials=self._credentials, - credentials_file=credentials_file, + # Set ``credentials_file`` to ``None`` here as + # the credentials that we saved earlier should be used. + credentials_file=None, scopes=self._scopes, ssl_credentials=self._ssl_channel_credentials, quota_project_id=quota_project_id, @@ -243,7 +246,7 @@ def operations_client(self) -> operations_v1.OperationsClient: This property caches on the instance; repeated calls return the same client. """ - # Sanity check: Only create a new client if we do not already have one. + # Quick check: Only create a new client if we do not already have one. if self._operations_client is None: self._operations_client = operations_v1.OperationsClient(self.grpc_channel) diff --git a/google/cloud/aiplatform_v1/services/pipeline_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1/services/pipeline_service/transports/grpc_asyncio.py index 624014b2589..775ec6d14f4 100644 --- a/google/cloud/aiplatform_v1/services/pipeline_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1/services/pipeline_service/transports/grpc_asyncio.py @@ -214,8 +214,11 @@ def __init__( if not self._grpc_channel: self._grpc_channel = type(self).create_channel( self._host, + # use the credentials which are saved credentials=self._credentials, - credentials_file=credentials_file, + # Set ``credentials_file`` to ``None`` here as + # the credentials that we saved earlier should be used. + credentials_file=None, scopes=self._scopes, ssl_credentials=self._ssl_channel_credentials, quota_project_id=quota_project_id, @@ -245,7 +248,7 @@ def operations_client(self) -> operations_v1.OperationsAsyncClient: This property caches on the instance; repeated calls return the same client. 
""" - # Sanity check: Only create a new client if we do not already have one. + # Quick check: Only create a new client if we do not already have one. if self._operations_client is None: self._operations_client = operations_v1.OperationsAsyncClient( self.grpc_channel diff --git a/google/cloud/aiplatform_v1/services/prediction_service/async_client.py b/google/cloud/aiplatform_v1/services/prediction_service/async_client.py index ef0e1a655c2..e6126cd6538 100644 --- a/google/cloud/aiplatform_v1/services/prediction_service/async_client.py +++ b/google/cloud/aiplatform_v1/services/prediction_service/async_client.py @@ -16,7 +16,7 @@ from collections import OrderedDict import functools import re -from typing import Dict, Sequence, Tuple, Type, Union +from typing import Dict, Optional, Sequence, Tuple, Type, Union import pkg_resources from google.api_core.client_options import ClientOptions @@ -111,6 +111,42 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): from_service_account_json = from_service_account_file + @classmethod + def get_mtls_endpoint_and_cert_source( + cls, client_options: Optional[ClientOptions] = None + ): + """Return the API endpoint and client cert source for mutual TLS. + + The client cert source is determined in the following order: + (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the + client cert source is None. + (2) if `client_options.client_cert_source` is provided, use the provided one; if the + default client cert source exists, use the default one; otherwise the client cert + source is None. + + The API endpoint is determined in the following order: + (1) if `client_options.api_endpoint` if provided, use the provided one. 
+ (2) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is "always", use the + default mTLS endpoint; if the environment variabel is "never", use the default API + endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise + use the default API endpoint. + + More details can be found at https://google.aip.dev/auth/4114. + + Args: + client_options (google.api_core.client_options.ClientOptions): Custom options for the + client. Only the `api_endpoint` and `client_cert_source` properties may be used + in this method. + + Returns: + Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the + client cert source to use. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If any errors happen. + """ + return PredictionServiceClient.get_mtls_endpoint_and_cert_source(client_options) # type: ignore + @property def transport(self) -> PredictionServiceTransport: """Returns the transport used by the client instance. @@ -236,7 +272,7 @@ async def predict( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([endpoint, instances, parameters]) if request is not None and has_flattened_params: @@ -393,7 +429,7 @@ async def raw_predict( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([endpoint, http_body]) if request is not None and has_flattened_params: @@ -516,7 +552,7 @@ async def explain( """ # Create or coerce a protobuf request object. 
- # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([endpoint, instances, parameters, deployed_model_id]) if request is not None and has_flattened_params: diff --git a/google/cloud/aiplatform_v1/services/prediction_service/client.py b/google/cloud/aiplatform_v1/services/prediction_service/client.py index e3cb2c5a53b..b145cc7c326 100644 --- a/google/cloud/aiplatform_v1/services/prediction_service/client.py +++ b/google/cloud/aiplatform_v1/services/prediction_service/client.py @@ -255,6 +255,73 @@ def parse_common_location_path(path: str) -> Dict[str, str]: m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) return m.groupdict() if m else {} + @classmethod + def get_mtls_endpoint_and_cert_source( + cls, client_options: Optional[client_options_lib.ClientOptions] = None + ): + """Return the API endpoint and client cert source for mutual TLS. + + The client cert source is determined in the following order: + (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the + client cert source is None. + (2) if `client_options.client_cert_source` is provided, use the provided one; if the + default client cert source exists, use the default one; otherwise the client cert + source is None. + + The API endpoint is determined in the following order: + (1) if `client_options.api_endpoint` if provided, use the provided one. + (2) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is "always", use the + default mTLS endpoint; if the environment variabel is "never", use the default API + endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise + use the default API endpoint. + + More details can be found at https://google.aip.dev/auth/4114. + + Args: + client_options (google.api_core.client_options.ClientOptions): Custom options for the + client. 
Only the `api_endpoint` and `client_cert_source` properties may be used + in this method. + + Returns: + Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the + client cert source to use. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If any errors happen. + """ + if client_options is None: + client_options = client_options_lib.ClientOptions() + use_client_cert = os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") + use_mtls_endpoint = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") + if use_client_cert not in ("true", "false"): + raise ValueError( + "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" + ) + if use_mtls_endpoint not in ("auto", "never", "always"): + raise MutualTLSChannelError( + "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" + ) + + # Figure out the client cert source to use. + client_cert_source = None + if use_client_cert == "true": + if client_options.client_cert_source: + client_cert_source = client_options.client_cert_source + elif mtls.has_default_client_cert_source(): + client_cert_source = mtls.default_client_cert_source() + + # Figure out which api endpoint to use. + if client_options.api_endpoint is not None: + api_endpoint = client_options.api_endpoint + elif use_mtls_endpoint == "always" or ( + use_mtls_endpoint == "auto" and client_cert_source + ): + api_endpoint = cls.DEFAULT_MTLS_ENDPOINT + else: + api_endpoint = cls.DEFAULT_ENDPOINT + + return api_endpoint, client_cert_source + def __init__( self, *, @@ -305,57 +372,22 @@ def __init__( if client_options is None: client_options = client_options_lib.ClientOptions() - # Create SSL credentials for mutual TLS if needed. 
- if os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") not in ( - "true", - "false", - ): - raise ValueError( - "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" - ) - use_client_cert = ( - os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") == "true" + api_endpoint, client_cert_source_func = self.get_mtls_endpoint_and_cert_source( + client_options ) - client_cert_source_func = None - is_mtls = False - if use_client_cert: - if client_options.client_cert_source: - is_mtls = True - client_cert_source_func = client_options.client_cert_source - else: - is_mtls = mtls.has_default_client_cert_source() - if is_mtls: - client_cert_source_func = mtls.default_client_cert_source() - else: - client_cert_source_func = None - - # Figure out which api endpoint to use. - if client_options.api_endpoint is not None: - api_endpoint = client_options.api_endpoint - else: - use_mtls_env = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") - if use_mtls_env == "never": - api_endpoint = self.DEFAULT_ENDPOINT - elif use_mtls_env == "always": - api_endpoint = self.DEFAULT_MTLS_ENDPOINT - elif use_mtls_env == "auto": - if is_mtls: - api_endpoint = self.DEFAULT_MTLS_ENDPOINT - else: - api_endpoint = self.DEFAULT_ENDPOINT - else: - raise MutualTLSChannelError( - "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted " - "values: never, auto, always" - ) + api_key_value = getattr(client_options, "api_key", None) + if api_key_value and credentials: + raise ValueError( + "client_options.api_key and credentials are mutually exclusive" + ) # Save or instantiate the transport. # Ordinarily, we provide the transport, but allowing a custom transport # instance provides an extensibility point for unusual situations. if isinstance(transport, PredictionServiceTransport): # transport is a PredictionServiceTransport instance. 
- if credentials or client_options.credentials_file: + if credentials or client_options.credentials_file or api_key_value: raise ValueError( "When providing a transport instance, " "provide its credentials directly." @@ -367,6 +399,15 @@ def __init__( ) self._transport = transport else: + import google.auth._default # type: ignore + + if api_key_value and hasattr( + google.auth._default, "get_api_key_credentials" + ): + credentials = google.auth._default.get_api_key_credentials( + api_key_value + ) + Transport = type(self).get_transport_class(transport) self._transport = Transport( credentials=credentials, @@ -390,7 +431,7 @@ def predict( timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), ) -> prediction_service.PredictResponse: - """Perform an online prediction. + r"""Perform an online prediction. Args: request (Union[google.cloud.aiplatform_v1.types.PredictRequest, dict]): @@ -444,7 +485,7 @@ def predict( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([endpoint, instances, parameters]) if request is not None and has_flattened_params: @@ -494,7 +535,7 @@ def raw_predict( timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), ) -> httpbody_pb2.HttpBody: - """Perform an online prediction with an arbitrary HTTP payload. + r"""Perform an online prediction with an arbitrary HTTP payload. The response includes the following HTTP headers: @@ -601,7 +642,7 @@ def raw_predict( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. 
has_flattened_params = any([endpoint, http_body]) if request is not None and has_flattened_params: @@ -651,7 +692,7 @@ def explain( timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), ) -> prediction_service.ExplainResponse: - """Perform an online explanation. + r"""Perform an online explanation. If [deployed_model_id][google.cloud.aiplatform.v1.ExplainRequest.deployed_model_id] @@ -724,7 +765,7 @@ def explain( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([endpoint, instances, parameters, deployed_model_id]) if request is not None and has_flattened_params: diff --git a/google/cloud/aiplatform_v1/services/prediction_service/transports/base.py b/google/cloud/aiplatform_v1/services/prediction_service/transports/base.py index 4a1d1706f3a..eba0aa22f76 100644 --- a/google/cloud/aiplatform_v1/services/prediction_service/transports/base.py +++ b/google/cloud/aiplatform_v1/services/prediction_service/transports/base.py @@ -102,7 +102,6 @@ def __init__( credentials, _ = google.auth.load_credentials_from_file( credentials_file, **scopes_kwargs, quota_project_id=quota_project_id ) - elif credentials is None: credentials, _ = google.auth.default( **scopes_kwargs, quota_project_id=quota_project_id diff --git a/google/cloud/aiplatform_v1/services/prediction_service/transports/grpc.py b/google/cloud/aiplatform_v1/services/prediction_service/transports/grpc.py index 111c5897c5a..c9dcd9eae70 100644 --- a/google/cloud/aiplatform_v1/services/prediction_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1/services/prediction_service/transports/grpc.py @@ -159,8 +159,11 @@ def __init__( if not self._grpc_channel: self._grpc_channel = type(self).create_channel( self._host, + # use the credentials which are saved credentials=self._credentials, - 
credentials_file=credentials_file, + # Set ``credentials_file`` to ``None`` here as + # the credentials that we saved earlier should be used. + credentials_file=None, scopes=self._scopes, ssl_credentials=self._ssl_channel_credentials, quota_project_id=quota_project_id, diff --git a/google/cloud/aiplatform_v1/services/prediction_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1/services/prediction_service/transports/grpc_asyncio.py index f03d7c4aaa3..cb496645c59 100644 --- a/google/cloud/aiplatform_v1/services/prediction_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1/services/prediction_service/transports/grpc_asyncio.py @@ -204,8 +204,11 @@ def __init__( if not self._grpc_channel: self._grpc_channel = type(self).create_channel( self._host, + # use the credentials which are saved credentials=self._credentials, - credentials_file=credentials_file, + # Set ``credentials_file`` to ``None`` here as + # the credentials that we saved earlier should be used. + credentials_file=None, scopes=self._scopes, ssl_credentials=self._ssl_channel_credentials, quota_project_id=quota_project_id, diff --git a/google/cloud/aiplatform_v1/services/specialist_pool_service/async_client.py b/google/cloud/aiplatform_v1/services/specialist_pool_service/async_client.py index 9b36163ee17..a1142f9b2af 100644 --- a/google/cloud/aiplatform_v1/services/specialist_pool_service/async_client.py +++ b/google/cloud/aiplatform_v1/services/specialist_pool_service/async_client.py @@ -16,7 +16,7 @@ from collections import OrderedDict import functools import re -from typing import Dict, Sequence, Tuple, Type, Union +from typing import Dict, Optional, Sequence, Tuple, Type, Union import pkg_resources from google.api_core.client_options import ClientOptions @@ -125,6 +125,42 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): from_service_account_json = from_service_account_file + @classmethod + def get_mtls_endpoint_and_cert_source( + cls, client_options: 
Optional[ClientOptions] = None + ): + """Return the API endpoint and client cert source for mutual TLS. + + The client cert source is determined in the following order: + (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the + client cert source is None. + (2) if `client_options.client_cert_source` is provided, use the provided one; if the + default client cert source exists, use the default one; otherwise the client cert + source is None. + + The API endpoint is determined in the following order: + (1) if `client_options.api_endpoint` if provided, use the provided one. + (2) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is "always", use the + default mTLS endpoint; if the environment variabel is "never", use the default API + endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise + use the default API endpoint. + + More details can be found at https://google.aip.dev/auth/4114. + + Args: + client_options (google.api_core.client_options.ClientOptions): Custom options for the + client. Only the `api_endpoint` and `client_cert_source` properties may be used + in this method. + + Returns: + Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the + client cert source to use. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If any errors happen. + """ + return SpecialistPoolServiceClient.get_mtls_endpoint_and_cert_source(client_options) # type: ignore + @property def transport(self) -> SpecialistPoolServiceTransport: """Returns the transport used by the client instance. @@ -240,7 +276,7 @@ async def create_specialist_pool( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. 
has_flattened_params = any([parent, specialist_pool]) if request is not None and has_flattened_params: @@ -332,7 +368,7 @@ async def get_specialist_pool( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -407,7 +443,7 @@ async def list_specialist_pools( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -501,7 +537,7 @@ async def delete_specialist_pool( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -598,7 +634,7 @@ async def update_specialist_pool( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. 
has_flattened_params = any([specialist_pool, update_mask]) if request is not None and has_flattened_params: diff --git a/google/cloud/aiplatform_v1/services/specialist_pool_service/client.py b/google/cloud/aiplatform_v1/services/specialist_pool_service/client.py index 04030b804c0..374c584431c 100644 --- a/google/cloud/aiplatform_v1/services/specialist_pool_service/client.py +++ b/google/cloud/aiplatform_v1/services/specialist_pool_service/client.py @@ -249,6 +249,73 @@ def parse_common_location_path(path: str) -> Dict[str, str]: m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) return m.groupdict() if m else {} + @classmethod + def get_mtls_endpoint_and_cert_source( + cls, client_options: Optional[client_options_lib.ClientOptions] = None + ): + """Return the API endpoint and client cert source for mutual TLS. + + The client cert source is determined in the following order: + (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the + client cert source is None. + (2) if `client_options.client_cert_source` is provided, use the provided one; if the + default client cert source exists, use the default one; otherwise the client cert + source is None. + + The API endpoint is determined in the following order: + (1) if `client_options.api_endpoint` if provided, use the provided one. + (2) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is "always", use the + default mTLS endpoint; if the environment variabel is "never", use the default API + endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise + use the default API endpoint. + + More details can be found at https://google.aip.dev/auth/4114. + + Args: + client_options (google.api_core.client_options.ClientOptions): Custom options for the + client. Only the `api_endpoint` and `client_cert_source` properties may be used + in this method. 
+ + Returns: + Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the + client cert source to use. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If any errors happen. + """ + if client_options is None: + client_options = client_options_lib.ClientOptions() + use_client_cert = os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") + use_mtls_endpoint = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") + if use_client_cert not in ("true", "false"): + raise ValueError( + "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" + ) + if use_mtls_endpoint not in ("auto", "never", "always"): + raise MutualTLSChannelError( + "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" + ) + + # Figure out the client cert source to use. + client_cert_source = None + if use_client_cert == "true": + if client_options.client_cert_source: + client_cert_source = client_options.client_cert_source + elif mtls.has_default_client_cert_source(): + client_cert_source = mtls.default_client_cert_source() + + # Figure out which api endpoint to use. + if client_options.api_endpoint is not None: + api_endpoint = client_options.api_endpoint + elif use_mtls_endpoint == "always" or ( + use_mtls_endpoint == "auto" and client_cert_source + ): + api_endpoint = cls.DEFAULT_MTLS_ENDPOINT + else: + api_endpoint = cls.DEFAULT_ENDPOINT + + return api_endpoint, client_cert_source + def __init__( self, *, @@ -299,57 +366,22 @@ def __init__( if client_options is None: client_options = client_options_lib.ClientOptions() - # Create SSL credentials for mutual TLS if needed. 
- if os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") not in ( - "true", - "false", - ): - raise ValueError( - "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" - ) - use_client_cert = ( - os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") == "true" + api_endpoint, client_cert_source_func = self.get_mtls_endpoint_and_cert_source( + client_options ) - client_cert_source_func = None - is_mtls = False - if use_client_cert: - if client_options.client_cert_source: - is_mtls = True - client_cert_source_func = client_options.client_cert_source - else: - is_mtls = mtls.has_default_client_cert_source() - if is_mtls: - client_cert_source_func = mtls.default_client_cert_source() - else: - client_cert_source_func = None - - # Figure out which api endpoint to use. - if client_options.api_endpoint is not None: - api_endpoint = client_options.api_endpoint - else: - use_mtls_env = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") - if use_mtls_env == "never": - api_endpoint = self.DEFAULT_ENDPOINT - elif use_mtls_env == "always": - api_endpoint = self.DEFAULT_MTLS_ENDPOINT - elif use_mtls_env == "auto": - if is_mtls: - api_endpoint = self.DEFAULT_MTLS_ENDPOINT - else: - api_endpoint = self.DEFAULT_ENDPOINT - else: - raise MutualTLSChannelError( - "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted " - "values: never, auto, always" - ) + api_key_value = getattr(client_options, "api_key", None) + if api_key_value and credentials: + raise ValueError( + "client_options.api_key and credentials are mutually exclusive" + ) # Save or instantiate the transport. # Ordinarily, we provide the transport, but allowing a custom transport # instance provides an extensibility point for unusual situations. if isinstance(transport, SpecialistPoolServiceTransport): # transport is a SpecialistPoolServiceTransport instance. 
- if credentials or client_options.credentials_file: + if credentials or client_options.credentials_file or api_key_value: raise ValueError( "When providing a transport instance, " "provide its credentials directly." @@ -361,6 +393,15 @@ def __init__( ) self._transport = transport else: + import google.auth._default # type: ignore + + if api_key_value and hasattr( + google.auth._default, "get_api_key_credentials" + ): + credentials = google.auth._default.get_api_key_credentials( + api_key_value + ) + Transport = type(self).get_transport_class(transport) self._transport = Transport( credentials=credentials, @@ -427,7 +468,7 @@ def create_specialist_pool( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, specialist_pool]) if request is not None and has_flattened_params: @@ -519,7 +560,7 @@ def get_specialist_pool( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -594,7 +635,7 @@ def list_specialist_pools( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -688,7 +729,7 @@ def delete_specialist_pool( """ # Create or coerce a protobuf request object. 
- # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -785,7 +826,7 @@ def update_specialist_pool( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([specialist_pool, update_mask]) if request is not None and has_flattened_params: diff --git a/google/cloud/aiplatform_v1/services/specialist_pool_service/transports/base.py b/google/cloud/aiplatform_v1/services/specialist_pool_service/transports/base.py index 86f7871a6ff..74eb77b6ced 100644 --- a/google/cloud/aiplatform_v1/services/specialist_pool_service/transports/base.py +++ b/google/cloud/aiplatform_v1/services/specialist_pool_service/transports/base.py @@ -104,7 +104,6 @@ def __init__( credentials, _ = google.auth.load_credentials_from_file( credentials_file, **scopes_kwargs, quota_project_id=quota_project_id ) - elif credentials is None: credentials, _ = google.auth.default( **scopes_kwargs, quota_project_id=quota_project_id diff --git a/google/cloud/aiplatform_v1/services/specialist_pool_service/transports/grpc.py b/google/cloud/aiplatform_v1/services/specialist_pool_service/transports/grpc.py index f621f63d301..a22525269bd 100644 --- a/google/cloud/aiplatform_v1/services/specialist_pool_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1/services/specialist_pool_service/transports/grpc.py @@ -167,8 +167,11 @@ def __init__( if not self._grpc_channel: self._grpc_channel = type(self).create_channel( self._host, + # use the credentials which are saved credentials=self._credentials, - credentials_file=credentials_file, + # Set ``credentials_file`` to ``None`` here as + # the 
credentials that we saved earlier should be used. + credentials_file=None, scopes=self._scopes, ssl_credentials=self._ssl_channel_credentials, quota_project_id=quota_project_id, @@ -241,7 +244,7 @@ def operations_client(self) -> operations_v1.OperationsClient: This property caches on the instance; repeated calls return the same client. """ - # Sanity check: Only create a new client if we do not already have one. + # Quick check: Only create a new client if we do not already have one. if self._operations_client is None: self._operations_client = operations_v1.OperationsClient(self.grpc_channel) diff --git a/google/cloud/aiplatform_v1/services/specialist_pool_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1/services/specialist_pool_service/transports/grpc_asyncio.py index 2a45f210cad..8210feb2486 100644 --- a/google/cloud/aiplatform_v1/services/specialist_pool_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1/services/specialist_pool_service/transports/grpc_asyncio.py @@ -212,8 +212,11 @@ def __init__( if not self._grpc_channel: self._grpc_channel = type(self).create_channel( self._host, + # use the credentials which are saved credentials=self._credentials, - credentials_file=credentials_file, + # Set ``credentials_file`` to ``None`` here as + # the credentials that we saved earlier should be used. + credentials_file=None, scopes=self._scopes, ssl_credentials=self._ssl_channel_credentials, quota_project_id=quota_project_id, @@ -243,7 +246,7 @@ def operations_client(self) -> operations_v1.OperationsAsyncClient: This property caches on the instance; repeated calls return the same client. """ - # Sanity check: Only create a new client if we do not already have one. + # Quick check: Only create a new client if we do not already have one. 
if self._operations_client is None: self._operations_client = operations_v1.OperationsAsyncClient( self.grpc_channel diff --git a/google/cloud/aiplatform_v1/services/tensorboard_service/async_client.py b/google/cloud/aiplatform_v1/services/tensorboard_service/async_client.py index 3adea38fe5d..f2ea3ffc3bd 100644 --- a/google/cloud/aiplatform_v1/services/tensorboard_service/async_client.py +++ b/google/cloud/aiplatform_v1/services/tensorboard_service/async_client.py @@ -16,7 +16,16 @@ from collections import OrderedDict import functools import re -from typing import Dict, AsyncIterable, Awaitable, Sequence, Tuple, Type, Union +from typing import ( + Dict, + Optional, + AsyncIterable, + Awaitable, + Sequence, + Tuple, + Type, + Union, +) import pkg_resources from google.api_core.client_options import ClientOptions @@ -144,6 +153,42 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): from_service_account_json = from_service_account_file + @classmethod + def get_mtls_endpoint_and_cert_source( + cls, client_options: Optional[ClientOptions] = None + ): + """Return the API endpoint and client cert source for mutual TLS. + + The client cert source is determined in the following order: + (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the + client cert source is None. + (2) if `client_options.client_cert_source` is provided, use the provided one; if the + default client cert source exists, use the default one; otherwise the client cert + source is None. + + The API endpoint is determined in the following order: + (1) if `client_options.api_endpoint` if provided, use the provided one. + (2) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is "always", use the + default mTLS endpoint; if the environment variabel is "never", use the default API + endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise + use the default API endpoint. 
+ + More details can be found at https://google.aip.dev/auth/4114. + + Args: + client_options (google.api_core.client_options.ClientOptions): Custom options for the + client. Only the `api_endpoint` and `client_cert_source` properties may be used + in this method. + + Returns: + Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the + client cert source to use. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If any errors happen. + """ + return TensorboardServiceClient.get_mtls_endpoint_and_cert_source(client_options) # type: ignore + @property def transport(self) -> TensorboardServiceTransport: """Returns the transport used by the client instance. @@ -251,7 +296,7 @@ async def create_tensorboard( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, tensorboard]) if request is not None and has_flattened_params: @@ -336,7 +381,7 @@ async def get_tensorboard( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -425,7 +470,7 @@ async def update_tensorboard( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([tensorboard, update_mask]) if request is not None and has_flattened_params: @@ -512,7 +557,7 @@ async def list_tensorboards( """ # Create or coerce a protobuf request object. 
- # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -603,7 +648,7 @@ async def delete_tensorboard( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -705,7 +750,7 @@ async def create_tensorboard_experiment( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any( [parent, tensorboard_experiment, tensorboard_experiment_id] @@ -787,7 +832,7 @@ async def get_tensorboard_experiment( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -877,7 +922,7 @@ async def update_tensorboard_experiment( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([tensorboard_experiment, update_mask]) if request is not None and has_flattened_params: @@ -959,7 +1004,7 @@ async def list_tensorboard_experiments( """ # Create or coerce a protobuf request object. 
- # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -1052,7 +1097,7 @@ async def delete_tensorboard_experiment( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -1154,7 +1199,7 @@ async def create_tensorboard_run( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, tensorboard_run, tensorboard_run_id]) if request is not None and has_flattened_params: @@ -1245,7 +1290,7 @@ async def batch_create_tensorboard_runs( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, requests]) if request is not None and has_flattened_params: @@ -1321,7 +1366,7 @@ async def get_tensorboard_run( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -1408,7 +1453,7 @@ async def update_tensorboard_run( """ # Create or coerce a protobuf request object. 
- # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([tensorboard_run, update_mask]) if request is not None and has_flattened_params: @@ -1488,7 +1533,7 @@ async def list_tensorboard_runs( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -1579,7 +1624,7 @@ async def delete_tensorboard_run( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -1678,7 +1723,7 @@ async def batch_create_tensorboard_time_series( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, requests]) if request is not None and has_flattened_params: @@ -1762,7 +1807,7 @@ async def create_tensorboard_time_series( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, tensorboard_time_series]) if request is not None and has_flattened_params: @@ -1838,7 +1883,7 @@ async def get_tensorboard_time_series( """ # Create or coerce a protobuf request object. 
- # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -1926,7 +1971,7 @@ async def update_tensorboard_time_series( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([tensorboard_time_series, update_mask]) if request is not None and has_flattened_params: @@ -2013,7 +2058,7 @@ async def list_tensorboard_time_series( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -2106,7 +2151,7 @@ async def delete_tensorboard_time_series( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -2197,7 +2242,7 @@ async def batch_read_tensorboard_time_series_data( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([tensorboard]) if request is not None and has_flattened_params: @@ -2277,7 +2322,7 @@ async def read_tensorboard_time_series_data( """ # Create or coerce a protobuf request object. 
- # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([tensorboard_time_series]) if request is not None and has_flattened_params: @@ -2354,7 +2399,7 @@ def read_tensorboard_blob_data( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([time_series]) if request is not None and has_flattened_params: @@ -2442,7 +2487,7 @@ async def write_tensorboard_experiment_data( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([tensorboard_experiment, write_run_data_requests]) if request is not None and has_flattened_params: @@ -2534,7 +2579,7 @@ async def write_tensorboard_run_data( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([tensorboard_run, time_series_data]) if request is not None and has_flattened_params: @@ -2616,7 +2661,7 @@ async def export_tensorboard_time_series_data( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. 
has_flattened_params = any([tensorboard_time_series]) if request is not None and has_flattened_params: diff --git a/google/cloud/aiplatform_v1/services/tensorboard_service/client.py b/google/cloud/aiplatform_v1/services/tensorboard_service/client.py index d8d9c540103..70a4aa8fa45 100644 --- a/google/cloud/aiplatform_v1/services/tensorboard_service/client.py +++ b/google/cloud/aiplatform_v1/services/tensorboard_service/client.py @@ -327,6 +327,73 @@ def parse_common_location_path(path: str) -> Dict[str, str]: m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) return m.groupdict() if m else {} + @classmethod + def get_mtls_endpoint_and_cert_source( + cls, client_options: Optional[client_options_lib.ClientOptions] = None + ): + """Return the API endpoint and client cert source for mutual TLS. + + The client cert source is determined in the following order: + (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the + client cert source is None. + (2) if `client_options.client_cert_source` is provided, use the provided one; if the + default client cert source exists, use the default one; otherwise the client cert + source is None. + + The API endpoint is determined in the following order: + (1) if `client_options.api_endpoint` if provided, use the provided one. + (2) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is "always", use the + default mTLS endpoint; if the environment variabel is "never", use the default API + endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise + use the default API endpoint. + + More details can be found at https://google.aip.dev/auth/4114. + + Args: + client_options (google.api_core.client_options.ClientOptions): Custom options for the + client. Only the `api_endpoint` and `client_cert_source` properties may be used + in this method. + + Returns: + Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the + client cert source to use. 
+ + Raises: + google.auth.exceptions.MutualTLSChannelError: If any errors happen. + """ + if client_options is None: + client_options = client_options_lib.ClientOptions() + use_client_cert = os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") + use_mtls_endpoint = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") + if use_client_cert not in ("true", "false"): + raise ValueError( + "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" + ) + if use_mtls_endpoint not in ("auto", "never", "always"): + raise MutualTLSChannelError( + "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" + ) + + # Figure out the client cert source to use. + client_cert_source = None + if use_client_cert == "true": + if client_options.client_cert_source: + client_cert_source = client_options.client_cert_source + elif mtls.has_default_client_cert_source(): + client_cert_source = mtls.default_client_cert_source() + + # Figure out which api endpoint to use. + if client_options.api_endpoint is not None: + api_endpoint = client_options.api_endpoint + elif use_mtls_endpoint == "always" or ( + use_mtls_endpoint == "auto" and client_cert_source + ): + api_endpoint = cls.DEFAULT_MTLS_ENDPOINT + else: + api_endpoint = cls.DEFAULT_ENDPOINT + + return api_endpoint, client_cert_source + def __init__( self, *, @@ -377,57 +444,22 @@ def __init__( if client_options is None: client_options = client_options_lib.ClientOptions() - # Create SSL credentials for mutual TLS if needed. 
- if os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") not in ( - "true", - "false", - ): - raise ValueError( - "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" - ) - use_client_cert = ( - os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") == "true" + api_endpoint, client_cert_source_func = self.get_mtls_endpoint_and_cert_source( + client_options ) - client_cert_source_func = None - is_mtls = False - if use_client_cert: - if client_options.client_cert_source: - is_mtls = True - client_cert_source_func = client_options.client_cert_source - else: - is_mtls = mtls.has_default_client_cert_source() - if is_mtls: - client_cert_source_func = mtls.default_client_cert_source() - else: - client_cert_source_func = None - - # Figure out which api endpoint to use. - if client_options.api_endpoint is not None: - api_endpoint = client_options.api_endpoint - else: - use_mtls_env = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") - if use_mtls_env == "never": - api_endpoint = self.DEFAULT_ENDPOINT - elif use_mtls_env == "always": - api_endpoint = self.DEFAULT_MTLS_ENDPOINT - elif use_mtls_env == "auto": - if is_mtls: - api_endpoint = self.DEFAULT_MTLS_ENDPOINT - else: - api_endpoint = self.DEFAULT_ENDPOINT - else: - raise MutualTLSChannelError( - "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted " - "values: never, auto, always" - ) + api_key_value = getattr(client_options, "api_key", None) + if api_key_value and credentials: + raise ValueError( + "client_options.api_key and credentials are mutually exclusive" + ) # Save or instantiate the transport. # Ordinarily, we provide the transport, but allowing a custom transport # instance provides an extensibility point for unusual situations. if isinstance(transport, TensorboardServiceTransport): # transport is a TensorboardServiceTransport instance. 
- if credentials or client_options.credentials_file: + if credentials or client_options.credentials_file or api_key_value: raise ValueError( "When providing a transport instance, " "provide its credentials directly." @@ -439,6 +471,15 @@ def __init__( ) self._transport = transport else: + import google.auth._default # type: ignore + + if api_key_value and hasattr( + google.auth._default, "get_api_key_credentials" + ): + credentials = google.auth._default.get_api_key_credentials( + api_key_value + ) + Transport = type(self).get_transport_class(transport) self._transport = Transport( credentials=credentials, @@ -497,7 +538,7 @@ def create_tensorboard( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, tensorboard]) if request is not None and has_flattened_params: @@ -582,7 +623,7 @@ def get_tensorboard( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -671,7 +712,7 @@ def update_tensorboard( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([tensorboard, update_mask]) if request is not None and has_flattened_params: @@ -758,7 +799,7 @@ def list_tensorboards( """ # Create or coerce a protobuf request object. 
- # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -849,7 +890,7 @@ def delete_tensorboard( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -951,7 +992,7 @@ def create_tensorboard_experiment( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any( [parent, tensorboard_experiment, tensorboard_experiment_id] @@ -1037,7 +1078,7 @@ def get_tensorboard_experiment( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -1129,7 +1170,7 @@ def update_tensorboard_experiment( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([tensorboard_experiment, update_mask]) if request is not None and has_flattened_params: @@ -1215,7 +1256,7 @@ def list_tensorboard_experiments( """ # Create or coerce a protobuf request object. 
- # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -1312,7 +1353,7 @@ def delete_tensorboard_experiment( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -1418,7 +1459,7 @@ def create_tensorboard_run( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, tensorboard_run, tensorboard_run_id]) if request is not None and has_flattened_params: @@ -1509,7 +1550,7 @@ def batch_create_tensorboard_runs( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, requests]) if request is not None and has_flattened_params: @@ -1589,7 +1630,7 @@ def get_tensorboard_run( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -1676,7 +1717,7 @@ def update_tensorboard_run( """ # Create or coerce a protobuf request object. 
- # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([tensorboard_run, update_mask]) if request is not None and has_flattened_params: @@ -1756,7 +1797,7 @@ def list_tensorboard_runs( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -1847,7 +1888,7 @@ def delete_tensorboard_run( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -1946,7 +1987,7 @@ def batch_create_tensorboard_time_series( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, requests]) if request is not None and has_flattened_params: @@ -2036,7 +2077,7 @@ def create_tensorboard_time_series( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, tensorboard_time_series]) if request is not None and has_flattened_params: @@ -2116,7 +2157,7 @@ def get_tensorboard_time_series( """ # Create or coerce a protobuf request object. 
- # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -2206,7 +2247,7 @@ def update_tensorboard_time_series( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([tensorboard_time_series, update_mask]) if request is not None and has_flattened_params: @@ -2297,7 +2338,7 @@ def list_tensorboard_time_series( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -2394,7 +2435,7 @@ def delete_tensorboard_time_series( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -2489,7 +2530,7 @@ def batch_read_tensorboard_time_series_data( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([tensorboard]) if request is not None and has_flattened_params: @@ -2575,7 +2616,7 @@ def read_tensorboard_time_series_data( """ # Create or coerce a protobuf request object. 
- # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([tensorboard_time_series]) if request is not None and has_flattened_params: @@ -2656,7 +2697,7 @@ def read_tensorboard_blob_data( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([time_series]) if request is not None and has_flattened_params: @@ -2746,7 +2787,7 @@ def write_tensorboard_experiment_data( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([tensorboard_experiment, write_run_data_requests]) if request is not None and has_flattened_params: @@ -2842,7 +2883,7 @@ def write_tensorboard_run_data( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([tensorboard_run, time_series_data]) if request is not None and has_flattened_params: @@ -2926,7 +2967,7 @@ def export_tensorboard_time_series_data( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. 
has_flattened_params = any([tensorboard_time_series]) if request is not None and has_flattened_params: diff --git a/google/cloud/aiplatform_v1/services/tensorboard_service/transports/base.py b/google/cloud/aiplatform_v1/services/tensorboard_service/transports/base.py index af7f8f6d7a9..a13a7638dbd 100644 --- a/google/cloud/aiplatform_v1/services/tensorboard_service/transports/base.py +++ b/google/cloud/aiplatform_v1/services/tensorboard_service/transports/base.py @@ -117,7 +117,6 @@ def __init__( credentials, _ = google.auth.load_credentials_from_file( credentials_file, **scopes_kwargs, quota_project_id=quota_project_id ) - elif credentials is None: credentials, _ = google.auth.default( **scopes_kwargs, quota_project_id=quota_project_id diff --git a/google/cloud/aiplatform_v1/services/tensorboard_service/transports/grpc.py b/google/cloud/aiplatform_v1/services/tensorboard_service/transports/grpc.py index 84b28c5e567..1055f1035bf 100644 --- a/google/cloud/aiplatform_v1/services/tensorboard_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1/services/tensorboard_service/transports/grpc.py @@ -172,8 +172,11 @@ def __init__( if not self._grpc_channel: self._grpc_channel = type(self).create_channel( self._host, + # use the credentials which are saved credentials=self._credentials, - credentials_file=credentials_file, + # Set ``credentials_file`` to ``None`` here as + # the credentials that we saved earlier should be used. + credentials_file=None, scopes=self._scopes, ssl_credentials=self._ssl_channel_credentials, quota_project_id=quota_project_id, @@ -246,7 +249,7 @@ def operations_client(self) -> operations_v1.OperationsClient: This property caches on the instance; repeated calls return the same client. """ - # Sanity check: Only create a new client if we do not already have one. + # Quick check: Only create a new client if we do not already have one. 
if self._operations_client is None: self._operations_client = operations_v1.OperationsClient(self.grpc_channel) diff --git a/google/cloud/aiplatform_v1/services/tensorboard_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1/services/tensorboard_service/transports/grpc_asyncio.py index 68b73e92181..917fcadb018 100644 --- a/google/cloud/aiplatform_v1/services/tensorboard_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1/services/tensorboard_service/transports/grpc_asyncio.py @@ -217,8 +217,11 @@ def __init__( if not self._grpc_channel: self._grpc_channel = type(self).create_channel( self._host, + # use the credentials which are saved credentials=self._credentials, - credentials_file=credentials_file, + # Set ``credentials_file`` to ``None`` here as + # the credentials that we saved earlier should be used. + credentials_file=None, scopes=self._scopes, ssl_credentials=self._ssl_channel_credentials, quota_project_id=quota_project_id, @@ -248,7 +251,7 @@ def operations_client(self) -> operations_v1.OperationsAsyncClient: This property caches on the instance; repeated calls return the same client. """ - # Sanity check: Only create a new client if we do not already have one. + # Quick check: Only create a new client if we do not already have one. 
if self._operations_client is None: self._operations_client = operations_v1.OperationsAsyncClient( self.grpc_channel diff --git a/google/cloud/aiplatform_v1/services/vizier_service/async_client.py b/google/cloud/aiplatform_v1/services/vizier_service/async_client.py index ced1379087c..aa974f48ea0 100644 --- a/google/cloud/aiplatform_v1/services/vizier_service/async_client.py +++ b/google/cloud/aiplatform_v1/services/vizier_service/async_client.py @@ -16,7 +16,7 @@ from collections import OrderedDict import functools import re -from typing import Dict, Sequence, Tuple, Type, Union +from typing import Dict, Optional, Sequence, Tuple, Type, Union import pkg_resources from google.api_core.client_options import ClientOptions @@ -119,6 +119,42 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): from_service_account_json = from_service_account_file + @classmethod + def get_mtls_endpoint_and_cert_source( + cls, client_options: Optional[ClientOptions] = None + ): + """Return the API endpoint and client cert source for mutual TLS. + + The client cert source is determined in the following order: + (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the + client cert source is None. + (2) if `client_options.client_cert_source` is provided, use the provided one; if the + default client cert source exists, use the default one; otherwise the client cert + source is None. + + The API endpoint is determined in the following order: + (1) if `client_options.api_endpoint` if provided, use the provided one. + (2) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is "always", use the + default mTLS endpoint; if the environment variabel is "never", use the default API + endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise + use the default API endpoint. + + More details can be found at https://google.aip.dev/auth/4114. 
+ + Args: + client_options (google.api_core.client_options.ClientOptions): Custom options for the + client. Only the `api_endpoint` and `client_cert_source` properties may be used + in this method. + + Returns: + Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the + client cert source to use. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If any errors happen. + """ + return VizierServiceClient.get_mtls_endpoint_and_cert_source(client_options) # type: ignore + @property def transport(self) -> VizierServiceTransport: """Returns the transport used by the client instance. @@ -219,12 +255,10 @@ async def create_study( Returns: google.cloud.aiplatform_v1.types.Study: - LINT.IfChange A message representing a Study. - """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, study]) if request is not None and has_flattened_params: @@ -292,12 +326,10 @@ async def get_study( Returns: google.cloud.aiplatform_v1.types.Study: - LINT.IfChange A message representing a Study. - """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -373,7 +405,7 @@ async def list_studies( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. 
has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -445,7 +477,7 @@ async def delete_study( sent along with the request as metadata. """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -512,12 +544,10 @@ async def lookup_study( Returns: google.cloud.aiplatform_v1.types.Study: - LINT.IfChange A message representing a Study. - """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -663,7 +693,7 @@ async def create_trial( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, trial]) if request is not None and has_flattened_params: @@ -739,7 +769,7 @@ async def get_trial( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -814,7 +844,7 @@ async def list_trials( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. 
has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -993,7 +1023,7 @@ async def delete_trial( sent along with the request as metadata. """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -1184,7 +1214,7 @@ async def list_optimal_trials( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: diff --git a/google/cloud/aiplatform_v1/services/vizier_service/client.py b/google/cloud/aiplatform_v1/services/vizier_service/client.py index 66a9e425105..4c0423a036d 100644 --- a/google/cloud/aiplatform_v1/services/vizier_service/client.py +++ b/google/cloud/aiplatform_v1/services/vizier_service/client.py @@ -273,6 +273,73 @@ def parse_common_location_path(path: str) -> Dict[str, str]: m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) return m.groupdict() if m else {} + @classmethod + def get_mtls_endpoint_and_cert_source( + cls, client_options: Optional[client_options_lib.ClientOptions] = None + ): + """Return the API endpoint and client cert source for mutual TLS. + + The client cert source is determined in the following order: + (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the + client cert source is None. + (2) if `client_options.client_cert_source` is provided, use the provided one; if the + default client cert source exists, use the default one; otherwise the client cert + source is None. 
+ + The API endpoint is determined in the following order: + (1) if `client_options.api_endpoint` if provided, use the provided one. + (2) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is "always", use the + default mTLS endpoint; if the environment variabel is "never", use the default API + endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise + use the default API endpoint. + + More details can be found at https://google.aip.dev/auth/4114. + + Args: + client_options (google.api_core.client_options.ClientOptions): Custom options for the + client. Only the `api_endpoint` and `client_cert_source` properties may be used + in this method. + + Returns: + Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the + client cert source to use. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If any errors happen. + """ + if client_options is None: + client_options = client_options_lib.ClientOptions() + use_client_cert = os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") + use_mtls_endpoint = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") + if use_client_cert not in ("true", "false"): + raise ValueError( + "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" + ) + if use_mtls_endpoint not in ("auto", "never", "always"): + raise MutualTLSChannelError( + "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" + ) + + # Figure out the client cert source to use. + client_cert_source = None + if use_client_cert == "true": + if client_options.client_cert_source: + client_cert_source = client_options.client_cert_source + elif mtls.has_default_client_cert_source(): + client_cert_source = mtls.default_client_cert_source() + + # Figure out which api endpoint to use. 
+ if client_options.api_endpoint is not None: + api_endpoint = client_options.api_endpoint + elif use_mtls_endpoint == "always" or ( + use_mtls_endpoint == "auto" and client_cert_source + ): + api_endpoint = cls.DEFAULT_MTLS_ENDPOINT + else: + api_endpoint = cls.DEFAULT_ENDPOINT + + return api_endpoint, client_cert_source + def __init__( self, *, @@ -323,57 +390,22 @@ def __init__( if client_options is None: client_options = client_options_lib.ClientOptions() - # Create SSL credentials for mutual TLS if needed. - if os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") not in ( - "true", - "false", - ): - raise ValueError( - "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" - ) - use_client_cert = ( - os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") == "true" + api_endpoint, client_cert_source_func = self.get_mtls_endpoint_and_cert_source( + client_options ) - client_cert_source_func = None - is_mtls = False - if use_client_cert: - if client_options.client_cert_source: - is_mtls = True - client_cert_source_func = client_options.client_cert_source - else: - is_mtls = mtls.has_default_client_cert_source() - if is_mtls: - client_cert_source_func = mtls.default_client_cert_source() - else: - client_cert_source_func = None - - # Figure out which api endpoint to use. - if client_options.api_endpoint is not None: - api_endpoint = client_options.api_endpoint - else: - use_mtls_env = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") - if use_mtls_env == "never": - api_endpoint = self.DEFAULT_ENDPOINT - elif use_mtls_env == "always": - api_endpoint = self.DEFAULT_MTLS_ENDPOINT - elif use_mtls_env == "auto": - if is_mtls: - api_endpoint = self.DEFAULT_MTLS_ENDPOINT - else: - api_endpoint = self.DEFAULT_ENDPOINT - else: - raise MutualTLSChannelError( - "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. 
Accepted " - "values: never, auto, always" - ) + api_key_value = getattr(client_options, "api_key", None) + if api_key_value and credentials: + raise ValueError( + "client_options.api_key and credentials are mutually exclusive" + ) # Save or instantiate the transport. # Ordinarily, we provide the transport, but allowing a custom transport # instance provides an extensibility point for unusual situations. if isinstance(transport, VizierServiceTransport): # transport is a VizierServiceTransport instance. - if credentials or client_options.credentials_file: + if credentials or client_options.credentials_file or api_key_value: raise ValueError( "When providing a transport instance, " "provide its credentials directly." @@ -385,6 +417,15 @@ def __init__( ) self._transport = transport else: + import google.auth._default # type: ignore + + if api_key_value and hasattr( + google.auth._default, "get_api_key_credentials" + ): + credentials = google.auth._default.get_api_key_credentials( + api_key_value + ) + Transport = type(self).get_transport_class(transport) self._transport = Transport( credentials=credentials, @@ -437,12 +478,10 @@ def create_study( Returns: google.cloud.aiplatform_v1.types.Study: - LINT.IfChange A message representing a Study. - """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, study]) if request is not None and has_flattened_params: @@ -510,12 +549,10 @@ def get_study( Returns: google.cloud.aiplatform_v1.types.Study: - LINT.IfChange A message representing a Study. - """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. 
has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -591,7 +628,7 @@ def list_studies( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -663,7 +700,7 @@ def delete_study( sent along with the request as metadata. """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -730,12 +767,10 @@ def lookup_study( Returns: google.cloud.aiplatform_v1.types.Study: - LINT.IfChange A message representing a Study. - """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -882,7 +917,7 @@ def create_trial( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, trial]) if request is not None and has_flattened_params: @@ -958,7 +993,7 @@ def get_trial( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. 
has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -1033,7 +1068,7 @@ def list_trials( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -1214,7 +1249,7 @@ def delete_trial( sent along with the request as metadata. """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -1409,7 +1444,7 @@ def list_optimal_trials( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. 
has_flattened_params = any([parent]) if request is not None and has_flattened_params: diff --git a/google/cloud/aiplatform_v1/services/vizier_service/transports/base.py b/google/cloud/aiplatform_v1/services/vizier_service/transports/base.py index a31cdaf9ff4..80033df1ee0 100644 --- a/google/cloud/aiplatform_v1/services/vizier_service/transports/base.py +++ b/google/cloud/aiplatform_v1/services/vizier_service/transports/base.py @@ -106,7 +106,6 @@ def __init__( credentials, _ = google.auth.load_credentials_from_file( credentials_file, **scopes_kwargs, quota_project_id=quota_project_id ) - elif credentials is None: credentials, _ = google.auth.default( **scopes_kwargs, quota_project_id=quota_project_id diff --git a/google/cloud/aiplatform_v1/services/vizier_service/transports/grpc.py b/google/cloud/aiplatform_v1/services/vizier_service/transports/grpc.py index 9a3aab47d2f..f9f2a52ce75 100644 --- a/google/cloud/aiplatform_v1/services/vizier_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1/services/vizier_service/transports/grpc.py @@ -167,8 +167,11 @@ def __init__( if not self._grpc_channel: self._grpc_channel = type(self).create_channel( self._host, + # use the credentials which are saved credentials=self._credentials, - credentials_file=credentials_file, + # Set ``credentials_file`` to ``None`` here as + # the credentials that we saved earlier should be used. + credentials_file=None, scopes=self._scopes, ssl_credentials=self._ssl_channel_credentials, quota_project_id=quota_project_id, @@ -241,7 +244,7 @@ def operations_client(self) -> operations_v1.OperationsClient: This property caches on the instance; repeated calls return the same client. """ - # Sanity check: Only create a new client if we do not already have one. + # Quick check: Only create a new client if we do not already have one. 
if self._operations_client is None: self._operations_client = operations_v1.OperationsClient(self.grpc_channel) diff --git a/google/cloud/aiplatform_v1/services/vizier_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1/services/vizier_service/transports/grpc_asyncio.py index d9506d89029..e4c04f17373 100644 --- a/google/cloud/aiplatform_v1/services/vizier_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1/services/vizier_service/transports/grpc_asyncio.py @@ -212,8 +212,11 @@ def __init__( if not self._grpc_channel: self._grpc_channel = type(self).create_channel( self._host, + # use the credentials which are saved credentials=self._credentials, - credentials_file=credentials_file, + # Set ``credentials_file`` to ``None`` here as + # the credentials that we saved earlier should be used. + credentials_file=None, scopes=self._scopes, ssl_credentials=self._ssl_channel_credentials, quota_project_id=quota_project_id, @@ -243,7 +246,7 @@ def operations_client(self) -> operations_v1.OperationsAsyncClient: This property caches on the instance; repeated calls return the same client. """ - # Sanity check: Only create a new client if we do not already have one. + # Quick check: Only create a new client if we do not already have one. if self._operations_client is None: self._operations_client = operations_v1.OperationsAsyncClient( self.grpc_channel diff --git a/google/cloud/aiplatform_v1/types/annotation.py b/google/cloud/aiplatform_v1/types/annotation.py index c44c1dba0ef..96c8d48cabf 100644 --- a/google/cloud/aiplatform_v1/types/annotation.py +++ b/google/cloud/aiplatform_v1/types/annotation.py @@ -53,8 +53,8 @@ class Annotation(proto.Message): Output only. Timestamp when this Annotation was last updated. etag (str): - Optional. Used to perform consistent read- - odify-write updates. If not set, a blind + Optional. Used to perform consistent + read-modify-write updates. If not set, a blind "overwrite" update happens. 
annotation_source (google.cloud.aiplatform_v1.types.UserActionReference): Output only. The source of the Annotation. diff --git a/google/cloud/aiplatform_v1/types/annotation_spec.py b/google/cloud/aiplatform_v1/types/annotation_spec.py index 626db3df7eb..4b0beab800e 100644 --- a/google/cloud/aiplatform_v1/types/annotation_spec.py +++ b/google/cloud/aiplatform_v1/types/annotation_spec.py @@ -43,8 +43,8 @@ class AnnotationSpec(proto.Message): Output only. Timestamp when AnnotationSpec was last updated. etag (str): - Optional. Used to perform consistent read- - odify-write updates. If not set, a blind + Optional. Used to perform consistent + read-modify-write updates. If not set, a blind "overwrite" update happens. """ diff --git a/google/cloud/aiplatform_v1/types/artifact.py b/google/cloud/aiplatform_v1/types/artifact.py index 60426017909..6793c0fb8f7 100644 --- a/google/cloud/aiplatform_v1/types/artifact.py +++ b/google/cloud/aiplatform_v1/types/artifact.py @@ -39,8 +39,8 @@ class Artifact(proto.Message): artifact file. May be empty if there is no actual artifact file. etag (str): - An eTag used to perform consistent read- - odify-write updates. If not set, a blind + An eTag used to perform consistent + read-modify-write updates. If not set, a blind "overwrite" update happens. labels (Sequence[google.cloud.aiplatform_v1.types.Artifact.LabelsEntry]): The labels with user-defined metadata to diff --git a/google/cloud/aiplatform_v1/types/context.py b/google/cloud/aiplatform_v1/types/context.py index ac7285d3a35..e5b08a01d59 100644 --- a/google/cloud/aiplatform_v1/types/context.py +++ b/google/cloud/aiplatform_v1/types/context.py @@ -35,8 +35,8 @@ class Context(proto.Message): User provided display name of the Context. May be up to 128 Unicode characters. etag (str): - An eTag used to perform consistent read- - odify-write updates. If not set, a blind + An eTag used to perform consistent + read-modify-write updates. If not set, a blind "overwrite" update happens. 
labels (Sequence[google.cloud.aiplatform_v1.types.Context.LabelsEntry]): The labels with user-defined metadata to diff --git a/google/cloud/aiplatform_v1/types/custom_job.py b/google/cloud/aiplatform_v1/types/custom_job.py index 846de2f6221..3f280a9cdf9 100644 --- a/google/cloud/aiplatform_v1/types/custom_job.py +++ b/google/cloud/aiplatform_v1/types/custom_job.py @@ -121,6 +121,7 @@ class CustomJob(proto.Message): class CustomJobSpec(proto.Message): r"""Represents the spec of a CustomJob. + Next Id: 14 Attributes: worker_pool_specs (Sequence[google.cloud.aiplatform_v1.types.WorkerPoolSpec]): diff --git a/google/cloud/aiplatform_v1/types/data_item.py b/google/cloud/aiplatform_v1/types/data_item.py index 447850e95e8..b83996d4658 100644 --- a/google/cloud/aiplatform_v1/types/data_item.py +++ b/google/cloud/aiplatform_v1/types/data_item.py @@ -60,8 +60,8 @@ class DataItem(proto.Message): schema's][google.cloud.aiplatform.v1.Dataset.metadata_schema_uri] dataItemSchemaUri field. etag (str): - Optional. Used to perform consistent read- - odify-write updates. If not set, a blind + Optional. Used to perform consistent + read-modify-write updates. If not set, a blind "overwrite" update happens. """ diff --git a/google/cloud/aiplatform_v1/types/data_labeling_job.py b/google/cloud/aiplatform_v1/types/data_labeling_job.py index 63ff9a0b553..e908faecb56 100644 --- a/google/cloud/aiplatform_v1/types/data_labeling_job.py +++ b/google/cloud/aiplatform_v1/types/data_labeling_job.py @@ -76,9 +76,9 @@ class DataLabelingJob(proto.Message): Google Cloud Storage describing the config for a specific type of DataLabelingJob. The schema files that can be used here are found in the - https://storage.googleapis.com/google-cloud- - aiplatform bucket in the - /schema/datalabelingjob/inputs/ folder. + https://storage.googleapis.com/google-cloud-aiplatform + bucket in the /schema/datalabelingjob/inputs/ + folder. inputs (google.protobuf.struct_pb2.Value): Required. 
Input config parameters for the DataLabelingJob. diff --git a/google/cloud/aiplatform_v1/types/dataset.py b/google/cloud/aiplatform_v1/types/dataset.py index ef7a466d6b3..c7207d93e8b 100644 --- a/google/cloud/aiplatform_v1/types/dataset.py +++ b/google/cloud/aiplatform_v1/types/dataset.py @@ -46,8 +46,7 @@ class Dataset(proto.Message): information about the Dataset. The schema is defined as an OpenAPI 3.0.2 Schema Object. The schema files that can be used here are found in - gs://google-cloud- - aiplatform/schema/dataset/metadata/. + gs://google-cloud-aiplatform/schema/dataset/metadata/. metadata (google.protobuf.struct_pb2.Value): Required. Additional information about the Dataset. @@ -82,8 +81,8 @@ class Dataset(proto.Message): title. encryption_spec (google.cloud.aiplatform_v1.types.EncryptionSpec): Customer-managed encryption key spec for a - Dataset. If set, this Dataset and all sub- - resources of this Dataset will be secured by + Dataset. If set, this Dataset and all + sub-resources of this Dataset will be secured by this key. """ diff --git a/google/cloud/aiplatform_v1/types/endpoint.py b/google/cloud/aiplatform_v1/types/endpoint.py index 8ea223fc73d..66f0cfa79ef 100644 --- a/google/cloud/aiplatform_v1/types/endpoint.py +++ b/google/cloud/aiplatform_v1/types/endpoint.py @@ -80,9 +80,9 @@ class Endpoint(proto.Message): last updated. encryption_spec (google.cloud.aiplatform_v1.types.EncryptionSpec): Customer-managed encryption key spec for an - Endpoint. If set, this Endpoint and all sub- - resources of this Endpoint will be secured by - this key. + Endpoint. If set, this Endpoint and all + sub-resources of this Endpoint will be secured + by this key. 
network (str): The full name of the Google Compute Engine `network `__ diff --git a/google/cloud/aiplatform_v1/types/entity_type.py b/google/cloud/aiplatform_v1/types/entity_type.py index ce414207823..680b2b2f3cd 100644 --- a/google/cloud/aiplatform_v1/types/entity_type.py +++ b/google/cloud/aiplatform_v1/types/entity_type.py @@ -62,8 +62,8 @@ class EntityType(proto.Message): System reserved label keys are prefixed with "aiplatform.googleapis.com/" and are immutable. etag (str): - Optional. Used to perform a consistent read- - odify-write updates. If not set, a blind + Optional. Used to perform a consistent + read-modify-write updates. If not set, a blind "overwrite" update happens. """ diff --git a/google/cloud/aiplatform_v1/types/execution.py b/google/cloud/aiplatform_v1/types/execution.py index 52acfc61aa3..0522a6e612e 100644 --- a/google/cloud/aiplatform_v1/types/execution.py +++ b/google/cloud/aiplatform_v1/types/execution.py @@ -42,8 +42,8 @@ class Execution(proto.Message): and the system does not prescribe or check the validity of state transitions. etag (str): - An eTag used to perform consistent read- - odify-write updates. If not set, a blind + An eTag used to perform consistent + read-modify-write updates. If not set, a blind "overwrite" update happens. labels (Sequence[google.cloud.aiplatform_v1.types.Execution.LabelsEntry]): The labels with user-defined metadata to diff --git a/google/cloud/aiplatform_v1/types/explanation.py b/google/cloud/aiplatform_v1/types/explanation.py index 33a688fbe0a..c496a070e37 100644 --- a/google/cloud/aiplatform_v1/types/explanation.py +++ b/google/cloud/aiplatform_v1/types/explanation.py @@ -280,10 +280,10 @@ class ExplanationParameters(proto.Message): This field is a member of `oneof`_ ``method``. integrated_gradients_attribution (google.cloud.aiplatform_v1.types.IntegratedGradientsAttribution): - An attribution method that computes Aumann- - hapley values taking advantage of the model's - fully differentiable structure. 
Refer to this - paper for more details: + An attribution method that computes + Aumann-Shapley values taking advantage of the + model's fully differentiable structure. Refer to + this paper for more details: https://arxiv.org/abs/1703.01365 This field is a member of `oneof`_ ``method``. diff --git a/google/cloud/aiplatform_v1/types/explanation_metadata.py b/google/cloud/aiplatform_v1/types/explanation_metadata.py index 2c6c45cd82e..eb033ea3517 100644 --- a/google/cloud/aiplatform_v1/types/explanation_metadata.py +++ b/google/cloud/aiplatform_v1/types/explanation_metadata.py @@ -100,8 +100,8 @@ class InputMetadata(proto.Message): [instance_schema_uri][google.cloud.aiplatform.v1.PredictSchemata.instance_schema_uri]. input_tensor_name (str): Name of the input tensor for this feature. - Required and is only applicable to Vertex AI- - provided images for Tensorflow. + Required and is only applicable to Vertex + AI-provided images for Tensorflow. encoding (google.cloud.aiplatform_v1.types.ExplanationMetadata.InputMetadata.Encoding): Defines how the feature is encoded into the input tensor. Defaults to IDENTITY. diff --git a/google/cloud/aiplatform_v1/types/feature.py b/google/cloud/aiplatform_v1/types/feature.py index 9febc539a45..17b873538d2 100644 --- a/google/cloud/aiplatform_v1/types/feature.py +++ b/google/cloud/aiplatform_v1/types/feature.py @@ -63,9 +63,9 @@ class Feature(proto.Message): System reserved label keys are prefixed with "aiplatform.googleapis.com/" and are immutable. etag (str): - Used to perform a consistent read-modify- - rite updates. If not set, a blind "overwrite" - update happens. + Used to perform a consistent + read-modify-write updates. If not set, a blind + "overwrite" update happens. 
""" class ValueType(proto.Enum): diff --git a/google/cloud/aiplatform_v1/types/featurestore.py b/google/cloud/aiplatform_v1/types/featurestore.py index 0f706dcffc0..09600a2a09f 100644 --- a/google/cloud/aiplatform_v1/types/featurestore.py +++ b/google/cloud/aiplatform_v1/types/featurestore.py @@ -40,8 +40,8 @@ class Featurestore(proto.Message): Output only. Timestamp when this Featurestore was last updated. etag (str): - Optional. Used to perform consistent read- - odify-write updates. If not set, a blind + Optional. Used to perform consistent + read-modify-write updates. If not set, a blind "overwrite" update happens. labels (Sequence[google.cloud.aiplatform_v1.types.Featurestore.LabelsEntry]): Optional. The labels with user-defined @@ -81,10 +81,11 @@ class OnlineServingConfig(proto.Message): Attributes: fixed_node_count (int): - The number of nodes for each cluster. The - number of nodes will not scale automatically but - can be scaled manually by providing different - values when updating. + The number of nodes for each cluster. The number of nodes + will not scale automatically but can be scaled manually by + providing different values when updating. Only one of + ``fixed_node_count`` and ``scaling`` can be set. Setting one + will reset the other. """ fixed_node_count = proto.Field(proto.INT32, number=2,) diff --git a/google/cloud/aiplatform_v1/types/index_endpoint.py b/google/cloud/aiplatform_v1/types/index_endpoint.py index 6f2edb034e6..b5c8567825a 100644 --- a/google/cloud/aiplatform_v1/types/index_endpoint.py +++ b/google/cloud/aiplatform_v1/types/index_endpoint.py @@ -171,6 +171,15 @@ class DeployedIndex(proto.Message): don't provide SLA when min_replica_count=1). If max_replica_count is not set, the default value is min_replica_count. The max allowed replica count is 1000. + dedicated_resources (google.cloud.aiplatform_v1.types.DedicatedResources): + Optional. 
A description of resources that are dedicated to + the DeployedIndex, and that need a higher degree of manual + configuration. If min_replica_count is not set, the default + value is 2 (we don't provide SLA when min_replica_count=1). + If max_replica_count is not set, the default value is + min_replica_count. The max allowed replica count is 1000. + + Available machine types: n1-standard-16 n1-standard-32 enable_access_logging (bool): Optional. If true, private endpoint's access logs are sent to StackDriver Logging. @@ -227,6 +236,9 @@ class DeployedIndex(proto.Message): automatic_resources = proto.Field( proto.MESSAGE, number=7, message=machine_resources.AutomaticResources, ) + dedicated_resources = proto.Field( + proto.MESSAGE, number=16, message=machine_resources.DedicatedResources, + ) enable_access_logging = proto.Field(proto.BOOL, number=8,) deployed_index_auth_config = proto.Field( proto.MESSAGE, number=9, message="DeployedIndexAuthConfig", diff --git a/google/cloud/aiplatform_v1/types/model.py b/google/cloud/aiplatform_v1/types/model.py index ca3a9644479..8c97fd09425 100644 --- a/google/cloud/aiplatform_v1/types/model.py +++ b/google/cloud/aiplatform_v1/types/model.py @@ -397,7 +397,7 @@ class ModelContainerSpec(proto.Message): r"""Specification of a container for serving predictions. Some fields in this message correspond to fields in the `Kubernetes Container v1 core - specification `__. + specification `__. Attributes: image_uri (str): @@ -463,7 +463,7 @@ class ModelContainerSpec(proto.Message): this syntax with ``$$``; for example: $$(VARIABLE_NAME) This field corresponds to the ``command`` field of the Kubernetes Containers `v1 core - API `__. + API `__. args (Sequence[str]): Immutable. Specifies arguments for the command that runs when the container starts. 
This overrides the container's @@ -502,7 +502,7 @@ class ModelContainerSpec(proto.Message): this syntax with ``$$``; for example: $$(VARIABLE_NAME) This field corresponds to the ``args`` field of the Kubernetes Containers `v1 core - API `__. + API `__. env (Sequence[google.cloud.aiplatform_v1.types.EnvVar]): Immutable. List of environment variables to set in the container. After the container starts running, code running @@ -535,7 +535,7 @@ class ModelContainerSpec(proto.Message): This field corresponds to the ``env`` field of the Kubernetes Containers `v1 core - API `__. + API `__. ports (Sequence[google.cloud.aiplatform_v1.types.Port]): Immutable. List of ports to expose from the container. Vertex AI sends any prediction requests that it receives to @@ -558,7 +558,7 @@ class ModelContainerSpec(proto.Message): Vertex AI does not use ports other than the first one listed. This field corresponds to the ``ports`` field of the Kubernetes Containers `v1 core - API `__. + API `__. predict_route (str): Immutable. HTTP path on the container to send prediction requests to. Vertex AI forwards requests sent using diff --git a/google/cloud/aiplatform_v1/types/model_deployment_monitoring_job.py b/google/cloud/aiplatform_v1/types/model_deployment_monitoring_job.py index 21aba235b22..5a7a7bb061b 100644 --- a/google/cloud/aiplatform_v1/types/model_deployment_monitoring_job.py +++ b/google/cloud/aiplatform_v1/types/model_deployment_monitoring_job.py @@ -158,9 +158,10 @@ class ModelDeploymentMonitoringJob(proto.Message): encryption_spec (google.cloud.aiplatform_v1.types.EncryptionSpec): Customer-managed encryption key spec for a ModelDeploymentMonitoringJob. If set, this - ModelDeploymentMonitoringJob and all sub- - resources of this ModelDeploymentMonitoringJob - will be secured by this key. + ModelDeploymentMonitoringJob and all + sub-resources of this + ModelDeploymentMonitoringJob will be secured by + this key. 
enable_monitoring_pipeline_logs (bool): If true, the scheduled monitoring pipeline logs are sent to Google Cloud Logging, including pipeline status and @@ -283,9 +284,10 @@ class ModelDeploymentMonitoringScheduleConfig(proto.Message): Attributes: monitor_interval (google.protobuf.duration_pb2.Duration): - Required. The model monitoring job running + Required. The model monitoring job scheduling interval. It will be rounded up to next full - hour. + hour. This defines how often the monitoring jobs + are triggered. """ monitor_interval = proto.Field( diff --git a/google/cloud/aiplatform_v1/types/study.py b/google/cloud/aiplatform_v1/types/study.py index e8d807a2ddb..904d6f600af 100644 --- a/google/cloud/aiplatform_v1/types/study.py +++ b/google/cloud/aiplatform_v1/types/study.py @@ -27,8 +27,7 @@ class Study(proto.Message): - r"""LINT.IfChange - A message representing a Study. + r"""A message representing a Study. Attributes: name (str): diff --git a/google/cloud/aiplatform_v1/types/tensorboard.py b/google/cloud/aiplatform_v1/types/tensorboard.py index 68360118440..9628d01fa26 100644 --- a/google/cloud/aiplatform_v1/types/tensorboard.py +++ b/google/cloud/aiplatform_v1/types/tensorboard.py @@ -74,9 +74,9 @@ class Tensorboard(proto.Message): keys are prefixed with "aiplatform.googleapis.com/" and are immutable. etag (str): - Used to perform a consistent read-modify- - rite updates. If not set, a blind "overwrite" - update happens. + Used to perform a consistent + read-modify-write updates. If not set, a blind + "overwrite" update happens. """ name = proto.Field(proto.STRING, number=1,) diff --git a/google/cloud/aiplatform_v1/types/tensorboard_run.py b/google/cloud/aiplatform_v1/types/tensorboard_run.py index c127a2c8f44..aeaf1cee297 100644 --- a/google/cloud/aiplatform_v1/types/tensorboard_run.py +++ b/google/cloud/aiplatform_v1/types/tensorboard_run.py @@ -68,9 +68,9 @@ class TensorboardRun(proto.Message): of labels. 
System reserved label keys are prefixed with "aiplatform.googleapis.com/" and are immutable. etag (str): - Used to perform a consistent read-modify- - rite updates. If not set, a blind "overwrite" - update happens. + Used to perform a consistent + read-modify-write updates. If not set, a blind + "overwrite" update happens. """ name = proto.Field(proto.STRING, number=1,) diff --git a/google/cloud/aiplatform_v1/types/tensorboard_time_series.py b/google/cloud/aiplatform_v1/types/tensorboard_time_series.py index 04c0f661b7d..c9ea5aa998c 100644 --- a/google/cloud/aiplatform_v1/types/tensorboard_time_series.py +++ b/google/cloud/aiplatform_v1/types/tensorboard_time_series.py @@ -49,9 +49,9 @@ class TensorboardTimeSeries(proto.Message): Output only. Timestamp when this TensorboardTimeSeries was last updated. etag (str): - Used to perform a consistent read-modify- - rite updates. If not set, a blind "overwrite" - update happens. + Used to perform a consistent + read-modify-write updates. If not set, a blind + "overwrite" update happens. plugin_name (str): Immutable. Name of the plugin this time series pertain to. Such as Scalar, Tensor, Blob diff --git a/google/cloud/aiplatform_v1/types/training_pipeline.py b/google/cloud/aiplatform_v1/types/training_pipeline.py index 93f4449eced..994f614d8a4 100644 --- a/google/cloud/aiplatform_v1/types/training_pipeline.py +++ b/google/cloud/aiplatform_v1/types/training_pipeline.py @@ -68,13 +68,13 @@ class TrainingPipeline(proto.Message): is responsible for producing the model artifact, and may also include additional auxiliary work. The definition files that can be used here are - found in gs://google-cloud- - aiplatform/schema/trainingjob/definition/. Note: - The URI given on output will be immutable and - probably different, including the URI scheme, - than the one given on input. The output URI will - point to a location where the user only has a - read access. 
+ found in + gs://google-cloud-aiplatform/schema/trainingjob/definition/. + Note: The URI given on output will be immutable + and probably different, including the URI + scheme, than the one given on input. The output + URI will point to a location where the user only + has a read access. training_task_inputs (google.protobuf.struct_pb2.Value): Required. The training task's parameter(s), as specified in the diff --git a/google/cloud/aiplatform_v1beta1/services/dataset_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/dataset_service/async_client.py index b91bb39757d..452872c7fa0 100644 --- a/google/cloud/aiplatform_v1beta1/services/dataset_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/dataset_service/async_client.py @@ -16,7 +16,7 @@ from collections import OrderedDict import functools import re -from typing import Dict, Sequence, Tuple, Type, Union +from typing import Dict, Optional, Sequence, Tuple, Type, Union import pkg_resources from google.api_core.client_options import ClientOptions @@ -129,6 +129,42 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): from_service_account_json = from_service_account_file + @classmethod + def get_mtls_endpoint_and_cert_source( + cls, client_options: Optional[ClientOptions] = None + ): + """Return the API endpoint and client cert source for mutual TLS. + + The client cert source is determined in the following order: + (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the + client cert source is None. + (2) if `client_options.client_cert_source` is provided, use the provided one; if the + default client cert source exists, use the default one; otherwise the client cert + source is None. + + The API endpoint is determined in the following order: + (1) if `client_options.api_endpoint` if provided, use the provided one. 
+ (2) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is "always", use the + default mTLS endpoint; if the environment variabel is "never", use the default API + endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise + use the default API endpoint. + + More details can be found at https://google.aip.dev/auth/4114. + + Args: + client_options (google.api_core.client_options.ClientOptions): Custom options for the + client. Only the `api_endpoint` and `client_cert_source` properties may be used + in this method. + + Returns: + Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the + client cert source to use. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If any errors happen. + """ + return DatasetServiceClient.get_mtls_endpoint_and_cert_source(client_options) # type: ignore + @property def transport(self) -> DatasetServiceTransport: """Returns the transport used by the client instance. @@ -234,7 +270,7 @@ async def create_dataset( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, dataset]) if request is not None and has_flattened_params: @@ -315,7 +351,7 @@ async def get_dataset( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -400,7 +436,7 @@ async def update_dataset( """ # Create or coerce a protobuf request object. 
- # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([dataset, update_mask]) if request is not None and has_flattened_params: @@ -478,7 +514,7 @@ async def list_datasets( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -569,7 +605,7 @@ async def delete_dataset( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -661,7 +697,7 @@ async def import_data( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name, import_configs]) if request is not None and has_flattened_params: @@ -754,7 +790,7 @@ async def export_data( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name, export_config]) if request is not None and has_flattened_params: @@ -839,7 +875,7 @@ async def list_data_items( """ # Create or coerce a protobuf request object. 
- # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -917,7 +953,7 @@ async def get_annotation_spec( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -992,7 +1028,7 @@ async def list_annotations( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: diff --git a/google/cloud/aiplatform_v1beta1/services/dataset_service/client.py b/google/cloud/aiplatform_v1beta1/services/dataset_service/client.py index 8f94725a64e..3eb34817721 100644 --- a/google/cloud/aiplatform_v1beta1/services/dataset_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/dataset_service/client.py @@ -310,6 +310,73 @@ def parse_common_location_path(path: str) -> Dict[str, str]: m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) return m.groupdict() if m else {} + @classmethod + def get_mtls_endpoint_and_cert_source( + cls, client_options: Optional[client_options_lib.ClientOptions] = None + ): + """Return the API endpoint and client cert source for mutual TLS. + + The client cert source is determined in the following order: + (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the + client cert source is None. 
+ (2) if `client_options.client_cert_source` is provided, use the provided one; if the + default client cert source exists, use the default one; otherwise the client cert + source is None. + + The API endpoint is determined in the following order: + (1) if `client_options.api_endpoint` if provided, use the provided one. + (2) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is "always", use the + default mTLS endpoint; if the environment variabel is "never", use the default API + endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise + use the default API endpoint. + + More details can be found at https://google.aip.dev/auth/4114. + + Args: + client_options (google.api_core.client_options.ClientOptions): Custom options for the + client. Only the `api_endpoint` and `client_cert_source` properties may be used + in this method. + + Returns: + Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the + client cert source to use. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If any errors happen. + """ + if client_options is None: + client_options = client_options_lib.ClientOptions() + use_client_cert = os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") + use_mtls_endpoint = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") + if use_client_cert not in ("true", "false"): + raise ValueError( + "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" + ) + if use_mtls_endpoint not in ("auto", "never", "always"): + raise MutualTLSChannelError( + "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" + ) + + # Figure out the client cert source to use. 
+ client_cert_source = None + if use_client_cert == "true": + if client_options.client_cert_source: + client_cert_source = client_options.client_cert_source + elif mtls.has_default_client_cert_source(): + client_cert_source = mtls.default_client_cert_source() + + # Figure out which api endpoint to use. + if client_options.api_endpoint is not None: + api_endpoint = client_options.api_endpoint + elif use_mtls_endpoint == "always" or ( + use_mtls_endpoint == "auto" and client_cert_source + ): + api_endpoint = cls.DEFAULT_MTLS_ENDPOINT + else: + api_endpoint = cls.DEFAULT_ENDPOINT + + return api_endpoint, client_cert_source + def __init__( self, *, @@ -360,57 +427,22 @@ def __init__( if client_options is None: client_options = client_options_lib.ClientOptions() - # Create SSL credentials for mutual TLS if needed. - if os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") not in ( - "true", - "false", - ): - raise ValueError( - "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" - ) - use_client_cert = ( - os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") == "true" + api_endpoint, client_cert_source_func = self.get_mtls_endpoint_and_cert_source( + client_options ) - client_cert_source_func = None - is_mtls = False - if use_client_cert: - if client_options.client_cert_source: - is_mtls = True - client_cert_source_func = client_options.client_cert_source - else: - is_mtls = mtls.has_default_client_cert_source() - if is_mtls: - client_cert_source_func = mtls.default_client_cert_source() - else: - client_cert_source_func = None - - # Figure out which api endpoint to use. 
- if client_options.api_endpoint is not None: - api_endpoint = client_options.api_endpoint - else: - use_mtls_env = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") - if use_mtls_env == "never": - api_endpoint = self.DEFAULT_ENDPOINT - elif use_mtls_env == "always": - api_endpoint = self.DEFAULT_MTLS_ENDPOINT - elif use_mtls_env == "auto": - if is_mtls: - api_endpoint = self.DEFAULT_MTLS_ENDPOINT - else: - api_endpoint = self.DEFAULT_ENDPOINT - else: - raise MutualTLSChannelError( - "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted " - "values: never, auto, always" - ) + api_key_value = getattr(client_options, "api_key", None) + if api_key_value and credentials: + raise ValueError( + "client_options.api_key and credentials are mutually exclusive" + ) # Save or instantiate the transport. # Ordinarily, we provide the transport, but allowing a custom transport # instance provides an extensibility point for unusual situations. if isinstance(transport, DatasetServiceTransport): # transport is a DatasetServiceTransport instance. - if credentials or client_options.credentials_file: + if credentials or client_options.credentials_file or api_key_value: raise ValueError( "When providing a transport instance, " "provide its credentials directly." @@ -422,6 +454,15 @@ def __init__( ) self._transport = transport else: + import google.auth._default # type: ignore + + if api_key_value and hasattr( + google.auth._default, "get_api_key_credentials" + ): + credentials = google.auth._default.get_api_key_credentials( + api_key_value + ) + Transport = type(self).get_transport_class(transport) self._transport = Transport( credentials=credentials, @@ -479,7 +520,7 @@ def create_dataset( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. 
has_flattened_params = any([parent, dataset]) if request is not None and has_flattened_params: @@ -560,7 +601,7 @@ def get_dataset( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -645,7 +686,7 @@ def update_dataset( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([dataset, update_mask]) if request is not None and has_flattened_params: @@ -723,7 +764,7 @@ def list_datasets( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -814,7 +855,7 @@ def delete_dataset( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -906,7 +947,7 @@ def import_data( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. 
has_flattened_params = any([name, import_configs]) if request is not None and has_flattened_params: @@ -999,7 +1040,7 @@ def export_data( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name, export_config]) if request is not None and has_flattened_params: @@ -1084,7 +1125,7 @@ def list_data_items( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -1162,7 +1203,7 @@ def get_annotation_spec( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -1237,7 +1278,7 @@ def list_annotations( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. 
has_flattened_params = any([parent]) if request is not None and has_flattened_params: diff --git a/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/base.py index 50109d924a4..4efd26a64d6 100644 --- a/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/base.py +++ b/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/base.py @@ -106,7 +106,6 @@ def __init__( credentials, _ = google.auth.load_credentials_from_file( credentials_file, **scopes_kwargs, quota_project_id=quota_project_id ) - elif credentials is None: credentials, _ = google.auth.default( **scopes_kwargs, quota_project_id=quota_project_id diff --git a/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/grpc.py index 6d0916b8197..226c74e3b6a 100644 --- a/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/grpc.py @@ -165,8 +165,11 @@ def __init__( if not self._grpc_channel: self._grpc_channel = type(self).create_channel( self._host, + # use the credentials which are saved credentials=self._credentials, - credentials_file=credentials_file, + # Set ``credentials_file`` to ``None`` here as + # the credentials that we saved earlier should be used. + credentials_file=None, scopes=self._scopes, ssl_credentials=self._ssl_channel_credentials, quota_project_id=quota_project_id, @@ -239,7 +242,7 @@ def operations_client(self) -> operations_v1.OperationsClient: This property caches on the instance; repeated calls return the same client. """ - # Sanity check: Only create a new client if we do not already have one. + # Quick check: Only create a new client if we do not already have one. 
if self._operations_client is None: self._operations_client = operations_v1.OperationsClient(self.grpc_channel) diff --git a/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/grpc_asyncio.py index 2ace47cd68a..c80e6765856 100644 --- a/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/grpc_asyncio.py @@ -210,8 +210,11 @@ def __init__( if not self._grpc_channel: self._grpc_channel = type(self).create_channel( self._host, + # use the credentials which are saved credentials=self._credentials, - credentials_file=credentials_file, + # Set ``credentials_file`` to ``None`` here as + # the credentials that we saved earlier should be used. + credentials_file=None, scopes=self._scopes, ssl_credentials=self._ssl_channel_credentials, quota_project_id=quota_project_id, @@ -241,7 +244,7 @@ def operations_client(self) -> operations_v1.OperationsAsyncClient: This property caches on the instance; repeated calls return the same client. """ - # Sanity check: Only create a new client if we do not already have one. + # Quick check: Only create a new client if we do not already have one. 
if self._operations_client is None: self._operations_client = operations_v1.OperationsAsyncClient( self.grpc_channel diff --git a/google/cloud/aiplatform_v1beta1/services/endpoint_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/endpoint_service/async_client.py index 9511c58a96e..689570f6cea 100644 --- a/google/cloud/aiplatform_v1beta1/services/endpoint_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/endpoint_service/async_client.py @@ -16,7 +16,7 @@ from collections import OrderedDict import functools import re -from typing import Dict, Sequence, Tuple, Type, Union +from typing import Dict, Optional, Sequence, Tuple, Type, Union import pkg_resources from google.api_core.client_options import ClientOptions @@ -125,6 +125,42 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): from_service_account_json = from_service_account_file + @classmethod + def get_mtls_endpoint_and_cert_source( + cls, client_options: Optional[ClientOptions] = None + ): + """Return the API endpoint and client cert source for mutual TLS. + + The client cert source is determined in the following order: + (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the + client cert source is None. + (2) if `client_options.client_cert_source` is provided, use the provided one; if the + default client cert source exists, use the default one; otherwise the client cert + source is None. + + The API endpoint is determined in the following order: + (1) if `client_options.api_endpoint` if provided, use the provided one. + (2) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is "always", use the + default mTLS endpoint; if the environment variabel is "never", use the default API + endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise + use the default API endpoint. + + More details can be found at https://google.aip.dev/auth/4114. 
+ + Args: + client_options (google.api_core.client_options.ClientOptions): Custom options for the + client. Only the `api_endpoint` and `client_cert_source` properties may be used + in this method. + + Returns: + Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the + client cert source to use. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If any errors happen. + """ + return EndpointServiceClient.get_mtls_endpoint_and_cert_source(client_options) # type: ignore + @property def transport(self) -> EndpointServiceTransport: """Returns the transport used by the client instance. @@ -245,7 +281,7 @@ async def create_endpoint( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, endpoint, endpoint_id]) if request is not None and has_flattened_params: @@ -329,7 +365,7 @@ async def get_endpoint( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -404,7 +440,7 @@ async def list_endpoints( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -490,7 +526,7 @@ async def update_endpoint( """ # Create or coerce a protobuf request object. 
- # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([endpoint, update_mask]) if request is not None and has_flattened_params: @@ -579,7 +615,7 @@ async def delete_endpoint( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -699,7 +735,7 @@ async def deploy_model( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([endpoint, deployed_model, traffic_split]) if request is not None and has_flattened_params: @@ -815,7 +851,7 @@ async def undeploy_model( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. 
has_flattened_params = any([endpoint, deployed_model_id, traffic_split]) if request is not None and has_flattened_params: diff --git a/google/cloud/aiplatform_v1beta1/services/endpoint_service/client.py b/google/cloud/aiplatform_v1beta1/services/endpoint_service/client.py index 247fc94dcda..d9af51ae743 100644 --- a/google/cloud/aiplatform_v1beta1/services/endpoint_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/endpoint_service/client.py @@ -294,6 +294,73 @@ def parse_common_location_path(path: str) -> Dict[str, str]: m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) return m.groupdict() if m else {} + @classmethod + def get_mtls_endpoint_and_cert_source( + cls, client_options: Optional[client_options_lib.ClientOptions] = None + ): + """Return the API endpoint and client cert source for mutual TLS. + + The client cert source is determined in the following order: + (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the + client cert source is None. + (2) if `client_options.client_cert_source` is provided, use the provided one; if the + default client cert source exists, use the default one; otherwise the client cert + source is None. + + The API endpoint is determined in the following order: + (1) if `client_options.api_endpoint` if provided, use the provided one. + (2) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is "always", use the + default mTLS endpoint; if the environment variabel is "never", use the default API + endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise + use the default API endpoint. + + More details can be found at https://google.aip.dev/auth/4114. + + Args: + client_options (google.api_core.client_options.ClientOptions): Custom options for the + client. Only the `api_endpoint` and `client_cert_source` properties may be used + in this method. 
+ + Returns: + Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the + client cert source to use. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If any errors happen. + """ + if client_options is None: + client_options = client_options_lib.ClientOptions() + use_client_cert = os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") + use_mtls_endpoint = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") + if use_client_cert not in ("true", "false"): + raise ValueError( + "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" + ) + if use_mtls_endpoint not in ("auto", "never", "always"): + raise MutualTLSChannelError( + "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" + ) + + # Figure out the client cert source to use. + client_cert_source = None + if use_client_cert == "true": + if client_options.client_cert_source: + client_cert_source = client_options.client_cert_source + elif mtls.has_default_client_cert_source(): + client_cert_source = mtls.default_client_cert_source() + + # Figure out which api endpoint to use. + if client_options.api_endpoint is not None: + api_endpoint = client_options.api_endpoint + elif use_mtls_endpoint == "always" or ( + use_mtls_endpoint == "auto" and client_cert_source + ): + api_endpoint = cls.DEFAULT_MTLS_ENDPOINT + else: + api_endpoint = cls.DEFAULT_ENDPOINT + + return api_endpoint, client_cert_source + def __init__( self, *, @@ -344,57 +411,22 @@ def __init__( if client_options is None: client_options = client_options_lib.ClientOptions() - # Create SSL credentials for mutual TLS if needed. 
- if os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") not in ( - "true", - "false", - ): - raise ValueError( - "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" - ) - use_client_cert = ( - os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") == "true" + api_endpoint, client_cert_source_func = self.get_mtls_endpoint_and_cert_source( + client_options ) - client_cert_source_func = None - is_mtls = False - if use_client_cert: - if client_options.client_cert_source: - is_mtls = True - client_cert_source_func = client_options.client_cert_source - else: - is_mtls = mtls.has_default_client_cert_source() - if is_mtls: - client_cert_source_func = mtls.default_client_cert_source() - else: - client_cert_source_func = None - - # Figure out which api endpoint to use. - if client_options.api_endpoint is not None: - api_endpoint = client_options.api_endpoint - else: - use_mtls_env = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") - if use_mtls_env == "never": - api_endpoint = self.DEFAULT_ENDPOINT - elif use_mtls_env == "always": - api_endpoint = self.DEFAULT_MTLS_ENDPOINT - elif use_mtls_env == "auto": - if is_mtls: - api_endpoint = self.DEFAULT_MTLS_ENDPOINT - else: - api_endpoint = self.DEFAULT_ENDPOINT - else: - raise MutualTLSChannelError( - "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted " - "values: never, auto, always" - ) + api_key_value = getattr(client_options, "api_key", None) + if api_key_value and credentials: + raise ValueError( + "client_options.api_key and credentials are mutually exclusive" + ) # Save or instantiate the transport. # Ordinarily, we provide the transport, but allowing a custom transport # instance provides an extensibility point for unusual situations. if isinstance(transport, EndpointServiceTransport): # transport is a EndpointServiceTransport instance. 
- if credentials or client_options.credentials_file: + if credentials or client_options.credentials_file or api_key_value: raise ValueError( "When providing a transport instance, " "provide its credentials directly." @@ -406,6 +438,15 @@ def __init__( ) self._transport = transport else: + import google.auth._default # type: ignore + + if api_key_value and hasattr( + google.auth._default, "get_api_key_credentials" + ): + credentials = google.auth._default.get_api_key_credentials( + api_key_value + ) + Transport = type(self).get_transport_class(transport) self._transport = Transport( credentials=credentials, @@ -478,7 +519,7 @@ def create_endpoint( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, endpoint, endpoint_id]) if request is not None and has_flattened_params: @@ -562,7 +603,7 @@ def get_endpoint( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -637,7 +678,7 @@ def list_endpoints( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -723,7 +764,7 @@ def update_endpoint( """ # Create or coerce a protobuf request object. 
- # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([endpoint, update_mask]) if request is not None and has_flattened_params: @@ -812,7 +853,7 @@ def delete_endpoint( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -932,7 +973,7 @@ def deploy_model( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([endpoint, deployed_model, traffic_split]) if request is not None and has_flattened_params: @@ -1047,7 +1088,7 @@ def undeploy_model( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. 
has_flattened_params = any([endpoint, deployed_model_id, traffic_split]) if request is not None and has_flattened_params: diff --git a/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/base.py index 31126b6a54a..90302bcbb7f 100644 --- a/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/base.py +++ b/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/base.py @@ -105,7 +105,6 @@ def __init__( credentials, _ = google.auth.load_credentials_from_file( credentials_file, **scopes_kwargs, quota_project_id=quota_project_id ) - elif credentials is None: credentials, _ = google.auth.default( **scopes_kwargs, quota_project_id=quota_project_id diff --git a/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/grpc.py index 8892dff9c5d..aa4bd25678a 100644 --- a/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/grpc.py @@ -163,8 +163,11 @@ def __init__( if not self._grpc_channel: self._grpc_channel = type(self).create_channel( self._host, + # use the credentials which are saved credentials=self._credentials, - credentials_file=credentials_file, + # Set ``credentials_file`` to ``None`` here as + # the credentials that we saved earlier should be used. + credentials_file=None, scopes=self._scopes, ssl_credentials=self._ssl_channel_credentials, quota_project_id=quota_project_id, @@ -237,7 +240,7 @@ def operations_client(self) -> operations_v1.OperationsClient: This property caches on the instance; repeated calls return the same client. """ - # Sanity check: Only create a new client if we do not already have one. + # Quick check: Only create a new client if we do not already have one. 
if self._operations_client is None: self._operations_client = operations_v1.OperationsClient(self.grpc_channel) diff --git a/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/grpc_asyncio.py index 196f7665f3e..8afd47a1c79 100644 --- a/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/grpc_asyncio.py @@ -208,8 +208,11 @@ def __init__( if not self._grpc_channel: self._grpc_channel = type(self).create_channel( self._host, + # use the credentials which are saved credentials=self._credentials, - credentials_file=credentials_file, + # Set ``credentials_file`` to ``None`` here as + # the credentials that we saved earlier should be used. + credentials_file=None, scopes=self._scopes, ssl_credentials=self._ssl_channel_credentials, quota_project_id=quota_project_id, @@ -239,7 +242,7 @@ def operations_client(self) -> operations_v1.OperationsAsyncClient: This property caches on the instance; repeated calls return the same client. """ - # Sanity check: Only create a new client if we do not already have one. + # Quick check: Only create a new client if we do not already have one. 
if self._operations_client is None: self._operations_client = operations_v1.OperationsAsyncClient( self.grpc_channel diff --git a/google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/async_client.py index 73322758b55..c26fa4b401c 100644 --- a/google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/async_client.py @@ -16,7 +16,16 @@ from collections import OrderedDict import functools import re -from typing import Dict, AsyncIterable, Awaitable, Sequence, Tuple, Type, Union +from typing import ( + Dict, + Optional, + AsyncIterable, + Awaitable, + Sequence, + Tuple, + Type, + Union, +) import pkg_resources from google.api_core.client_options import ClientOptions @@ -120,6 +129,42 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): from_service_account_json = from_service_account_file + @classmethod + def get_mtls_endpoint_and_cert_source( + cls, client_options: Optional[ClientOptions] = None + ): + """Return the API endpoint and client cert source for mutual TLS. + + The client cert source is determined in the following order: + (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the + client cert source is None. + (2) if `client_options.client_cert_source` is provided, use the provided one; if the + default client cert source exists, use the default one; otherwise the client cert + source is None. + + The API endpoint is determined in the following order: + (1) if `client_options.api_endpoint` if provided, use the provided one. 
+ (2) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is "always", use the + default mTLS endpoint; if the environment variabel is "never", use the default API + endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise + use the default API endpoint. + + More details can be found at https://google.aip.dev/auth/4114. + + Args: + client_options (google.api_core.client_options.ClientOptions): Custom options for the + client. Only the `api_endpoint` and `client_cert_source` properties may be used + in this method. + + Returns: + Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the + client cert source to use. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If any errors happen. + """ + return FeaturestoreOnlineServingServiceClient.get_mtls_endpoint_and_cert_source(client_options) # type: ignore + @property def transport(self) -> FeaturestoreOnlineServingServiceTransport: """Returns the transport used by the client instance. @@ -227,7 +272,7 @@ async def read_feature_values( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([entity_type]) if request is not None and has_flattened_params: @@ -310,7 +355,7 @@ def streaming_read_feature_values( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. 
has_flattened_params = any([entity_type]) if request is not None and has_flattened_params: diff --git a/google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/client.py b/google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/client.py index 079786dd1d6..bbc30e3fb67 100644 --- a/google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/client.py @@ -249,6 +249,73 @@ def parse_common_location_path(path: str) -> Dict[str, str]: m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) return m.groupdict() if m else {} + @classmethod + def get_mtls_endpoint_and_cert_source( + cls, client_options: Optional[client_options_lib.ClientOptions] = None + ): + """Return the API endpoint and client cert source for mutual TLS. + + The client cert source is determined in the following order: + (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the + client cert source is None. + (2) if `client_options.client_cert_source` is provided, use the provided one; if the + default client cert source exists, use the default one; otherwise the client cert + source is None. + + The API endpoint is determined in the following order: + (1) if `client_options.api_endpoint` if provided, use the provided one. + (2) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is "always", use the + default mTLS endpoint; if the environment variabel is "never", use the default API + endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise + use the default API endpoint. + + More details can be found at https://google.aip.dev/auth/4114. + + Args: + client_options (google.api_core.client_options.ClientOptions): Custom options for the + client. Only the `api_endpoint` and `client_cert_source` properties may be used + in this method. 
+ + Returns: + Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the + client cert source to use. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If any errors happen. + """ + if client_options is None: + client_options = client_options_lib.ClientOptions() + use_client_cert = os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") + use_mtls_endpoint = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") + if use_client_cert not in ("true", "false"): + raise ValueError( + "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" + ) + if use_mtls_endpoint not in ("auto", "never", "always"): + raise MutualTLSChannelError( + "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" + ) + + # Figure out the client cert source to use. + client_cert_source = None + if use_client_cert == "true": + if client_options.client_cert_source: + client_cert_source = client_options.client_cert_source + elif mtls.has_default_client_cert_source(): + client_cert_source = mtls.default_client_cert_source() + + # Figure out which api endpoint to use. + if client_options.api_endpoint is not None: + api_endpoint = client_options.api_endpoint + elif use_mtls_endpoint == "always" or ( + use_mtls_endpoint == "auto" and client_cert_source + ): + api_endpoint = cls.DEFAULT_MTLS_ENDPOINT + else: + api_endpoint = cls.DEFAULT_ENDPOINT + + return api_endpoint, client_cert_source + def __init__( self, *, @@ -299,57 +366,22 @@ def __init__( if client_options is None: client_options = client_options_lib.ClientOptions() - # Create SSL credentials for mutual TLS if needed. 
- if os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") not in ( - "true", - "false", - ): - raise ValueError( - "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" - ) - use_client_cert = ( - os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") == "true" + api_endpoint, client_cert_source_func = self.get_mtls_endpoint_and_cert_source( + client_options ) - client_cert_source_func = None - is_mtls = False - if use_client_cert: - if client_options.client_cert_source: - is_mtls = True - client_cert_source_func = client_options.client_cert_source - else: - is_mtls = mtls.has_default_client_cert_source() - if is_mtls: - client_cert_source_func = mtls.default_client_cert_source() - else: - client_cert_source_func = None - - # Figure out which api endpoint to use. - if client_options.api_endpoint is not None: - api_endpoint = client_options.api_endpoint - else: - use_mtls_env = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") - if use_mtls_env == "never": - api_endpoint = self.DEFAULT_ENDPOINT - elif use_mtls_env == "always": - api_endpoint = self.DEFAULT_MTLS_ENDPOINT - elif use_mtls_env == "auto": - if is_mtls: - api_endpoint = self.DEFAULT_MTLS_ENDPOINT - else: - api_endpoint = self.DEFAULT_ENDPOINT - else: - raise MutualTLSChannelError( - "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted " - "values: never, auto, always" - ) + api_key_value = getattr(client_options, "api_key", None) + if api_key_value and credentials: + raise ValueError( + "client_options.api_key and credentials are mutually exclusive" + ) # Save or instantiate the transport. # Ordinarily, we provide the transport, but allowing a custom transport # instance provides an extensibility point for unusual situations. if isinstance(transport, FeaturestoreOnlineServingServiceTransport): # transport is a FeaturestoreOnlineServingServiceTransport instance. 
- if credentials or client_options.credentials_file: + if credentials or client_options.credentials_file or api_key_value: raise ValueError( "When providing a transport instance, " "provide its credentials directly." @@ -361,6 +393,15 @@ def __init__( ) self._transport = transport else: + import google.auth._default # type: ignore + + if api_key_value and hasattr( + google.auth._default, "get_api_key_credentials" + ): + credentials = google.auth._default.get_api_key_credentials( + api_key_value + ) + Transport = type(self).get_transport_class(transport) self._transport = Transport( credentials=credentials, @@ -417,7 +458,7 @@ def read_feature_values( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([entity_type]) if request is not None and has_flattened_params: @@ -500,7 +541,7 @@ def streaming_read_feature_values( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. 
has_flattened_params = any([entity_type]) if request is not None and has_flattened_params: diff --git a/google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/transports/base.py index 5ecb44843f1..7d775e8d43a 100644 --- a/google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/transports/base.py +++ b/google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/transports/base.py @@ -101,7 +101,6 @@ def __init__( credentials, _ = google.auth.load_credentials_from_file( credentials_file, **scopes_kwargs, quota_project_id=quota_project_id ) - elif credentials is None: credentials, _ = google.auth.default( **scopes_kwargs, quota_project_id=quota_project_id diff --git a/google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/transports/grpc.py index 68904736ac4..5c725ab45c0 100644 --- a/google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/transports/grpc.py @@ -160,8 +160,11 @@ def __init__( if not self._grpc_channel: self._grpc_channel = type(self).create_channel( self._host, + # use the credentials which are saved credentials=self._credentials, - credentials_file=credentials_file, + # Set ``credentials_file`` to ``None`` here as + # the credentials that we saved earlier should be used. 
+ credentials_file=None, scopes=self._scopes, ssl_credentials=self._ssl_channel_credentials, quota_project_id=quota_project_id, diff --git a/google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/transports/grpc_asyncio.py index 83a4779a2f5..99f12c483fe 100644 --- a/google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/transports/grpc_asyncio.py @@ -205,8 +205,11 @@ def __init__( if not self._grpc_channel: self._grpc_channel = type(self).create_channel( self._host, + # use the credentials which are saved credentials=self._credentials, - credentials_file=credentials_file, + # Set ``credentials_file`` to ``None`` here as + # the credentials that we saved earlier should be used. + credentials_file=None, scopes=self._scopes, ssl_credentials=self._ssl_channel_credentials, quota_project_id=quota_project_id, diff --git a/google/cloud/aiplatform_v1beta1/services/featurestore_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/featurestore_service/async_client.py index cbe4cf700b3..6d6a4bb32ee 100644 --- a/google/cloud/aiplatform_v1beta1/services/featurestore_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/featurestore_service/async_client.py @@ -16,7 +16,7 @@ from collections import OrderedDict import functools import re -from typing import Dict, Sequence, Tuple, Type, Union +from typing import Dict, Optional, Sequence, Tuple, Type, Union import pkg_resources from google.api_core.client_options import ClientOptions @@ -131,6 +131,42 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): from_service_account_json = from_service_account_file + @classmethod + def get_mtls_endpoint_and_cert_source( + cls, client_options: Optional[ClientOptions] = None + ): + """Return 
the API endpoint and client cert source for mutual TLS. + + The client cert source is determined in the following order: + (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the + client cert source is None. + (2) if `client_options.client_cert_source` is provided, use the provided one; if the + default client cert source exists, use the default one; otherwise the client cert + source is None. + + The API endpoint is determined in the following order: + (1) if `client_options.api_endpoint` if provided, use the provided one. + (2) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is "always", use the + default mTLS endpoint; if the environment variabel is "never", use the default API + endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise + use the default API endpoint. + + More details can be found at https://google.aip.dev/auth/4114. + + Args: + client_options (google.api_core.client_options.ClientOptions): Custom options for the + client. Only the `api_endpoint` and `client_cert_source` properties may be used + in this method. + + Returns: + Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the + client cert source to use. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If any errors happen. + """ + return FeaturestoreServiceClient.get_mtls_endpoint_and_cert_source(client_options) # type: ignore + @property def transport(self) -> FeaturestoreServiceTransport: """Returns the transport used by the client instance. @@ -255,7 +291,7 @@ async def create_featurestore( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. 
has_flattened_params = any([parent, featurestore, featurestore_id]) if request is not None and has_flattened_params: @@ -341,7 +377,7 @@ async def get_featurestore( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -416,7 +452,7 @@ async def list_featurestores( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -517,7 +553,7 @@ async def update_featurestore( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([featurestore, update_mask]) if request is not None and has_flattened_params: @@ -627,7 +663,7 @@ async def delete_featurestore( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name, force]) if request is not None and has_flattened_params: @@ -734,7 +770,7 @@ async def create_entity_type( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. 
has_flattened_params = any([parent, entity_type, entity_type_id]) if request is not None and has_flattened_params: @@ -821,7 +857,7 @@ async def get_entity_type( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -896,7 +932,7 @@ async def list_entity_types( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -999,7 +1035,7 @@ async def update_entity_type( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([entity_type, update_mask]) if request is not None and has_flattened_params: @@ -1100,7 +1136,7 @@ async def delete_entity_type( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name, force]) if request is not None and has_flattened_params: @@ -1206,7 +1242,7 @@ async def create_feature( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. 
has_flattened_params = any([parent, feature, feature_id]) if request is not None and has_flattened_params: @@ -1306,7 +1342,7 @@ async def batch_create_features( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, requests]) if request is not None and has_flattened_params: @@ -1390,7 +1426,7 @@ async def get_feature( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -1465,7 +1501,7 @@ async def list_features( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -1567,7 +1603,7 @@ async def update_feature( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([feature, update_mask]) if request is not None and has_flattened_params: @@ -1656,7 +1692,7 @@ async def delete_feature( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. 
has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -1761,7 +1797,7 @@ async def import_feature_values( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([entity_type]) if request is not None and has_flattened_params: @@ -1852,7 +1888,7 @@ async def batch_read_feature_values( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([featurestore]) if request is not None and has_flattened_params: @@ -1939,7 +1975,7 @@ async def export_feature_values( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([entity_type]) if request is not None and has_flattened_params: @@ -2101,7 +2137,7 @@ async def search_features( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. 
has_flattened_params = any([location, query]) if request is not None and has_flattened_params: diff --git a/google/cloud/aiplatform_v1beta1/services/featurestore_service/client.py b/google/cloud/aiplatform_v1beta1/services/featurestore_service/client.py index a149a9c4884..12b0ab724cb 100644 --- a/google/cloud/aiplatform_v1beta1/services/featurestore_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/featurestore_service/client.py @@ -296,6 +296,73 @@ def parse_common_location_path(path: str) -> Dict[str, str]: m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) return m.groupdict() if m else {} + @classmethod + def get_mtls_endpoint_and_cert_source( + cls, client_options: Optional[client_options_lib.ClientOptions] = None + ): + """Return the API endpoint and client cert source for mutual TLS. + + The client cert source is determined in the following order: + (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the + client cert source is None. + (2) if `client_options.client_cert_source` is provided, use the provided one; if the + default client cert source exists, use the default one; otherwise the client cert + source is None. + + The API endpoint is determined in the following order: + (1) if `client_options.api_endpoint` if provided, use the provided one. + (2) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is "always", use the + default mTLS endpoint; if the environment variabel is "never", use the default API + endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise + use the default API endpoint. + + More details can be found at https://google.aip.dev/auth/4114. + + Args: + client_options (google.api_core.client_options.ClientOptions): Custom options for the + client. Only the `api_endpoint` and `client_cert_source` properties may be used + in this method. 
+ + Returns: + Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the + client cert source to use. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If any errors happen. + """ + if client_options is None: + client_options = client_options_lib.ClientOptions() + use_client_cert = os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") + use_mtls_endpoint = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") + if use_client_cert not in ("true", "false"): + raise ValueError( + "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" + ) + if use_mtls_endpoint not in ("auto", "never", "always"): + raise MutualTLSChannelError( + "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" + ) + + # Figure out the client cert source to use. + client_cert_source = None + if use_client_cert == "true": + if client_options.client_cert_source: + client_cert_source = client_options.client_cert_source + elif mtls.has_default_client_cert_source(): + client_cert_source = mtls.default_client_cert_source() + + # Figure out which api endpoint to use. + if client_options.api_endpoint is not None: + api_endpoint = client_options.api_endpoint + elif use_mtls_endpoint == "always" or ( + use_mtls_endpoint == "auto" and client_cert_source + ): + api_endpoint = cls.DEFAULT_MTLS_ENDPOINT + else: + api_endpoint = cls.DEFAULT_ENDPOINT + + return api_endpoint, client_cert_source + def __init__( self, *, @@ -346,57 +413,22 @@ def __init__( if client_options is None: client_options = client_options_lib.ClientOptions() - # Create SSL credentials for mutual TLS if needed. 
- if os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") not in ( - "true", - "false", - ): - raise ValueError( - "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" - ) - use_client_cert = ( - os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") == "true" + api_endpoint, client_cert_source_func = self.get_mtls_endpoint_and_cert_source( + client_options ) - client_cert_source_func = None - is_mtls = False - if use_client_cert: - if client_options.client_cert_source: - is_mtls = True - client_cert_source_func = client_options.client_cert_source - else: - is_mtls = mtls.has_default_client_cert_source() - if is_mtls: - client_cert_source_func = mtls.default_client_cert_source() - else: - client_cert_source_func = None - - # Figure out which api endpoint to use. - if client_options.api_endpoint is not None: - api_endpoint = client_options.api_endpoint - else: - use_mtls_env = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") - if use_mtls_env == "never": - api_endpoint = self.DEFAULT_ENDPOINT - elif use_mtls_env == "always": - api_endpoint = self.DEFAULT_MTLS_ENDPOINT - elif use_mtls_env == "auto": - if is_mtls: - api_endpoint = self.DEFAULT_MTLS_ENDPOINT - else: - api_endpoint = self.DEFAULT_ENDPOINT - else: - raise MutualTLSChannelError( - "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted " - "values: never, auto, always" - ) + api_key_value = getattr(client_options, "api_key", None) + if api_key_value and credentials: + raise ValueError( + "client_options.api_key and credentials are mutually exclusive" + ) # Save or instantiate the transport. # Ordinarily, we provide the transport, but allowing a custom transport # instance provides an extensibility point for unusual situations. if isinstance(transport, FeaturestoreServiceTransport): # transport is a FeaturestoreServiceTransport instance. 
- if credentials or client_options.credentials_file: + if credentials or client_options.credentials_file or api_key_value: raise ValueError( "When providing a transport instance, " "provide its credentials directly." @@ -408,6 +440,15 @@ def __init__( ) self._transport = transport else: + import google.auth._default # type: ignore + + if api_key_value and hasattr( + google.auth._default, "get_api_key_credentials" + ): + credentials = google.auth._default.get_api_key_credentials( + api_key_value + ) + Transport = type(self).get_transport_class(transport) self._transport = Transport( credentials=credentials, @@ -483,7 +524,7 @@ def create_featurestore( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, featurestore, featurestore_id]) if request is not None and has_flattened_params: @@ -569,7 +610,7 @@ def get_featurestore( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -644,7 +685,7 @@ def list_featurestores( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -745,7 +786,7 @@ def update_featurestore( """ # Create or coerce a protobuf request object. 
- # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([featurestore, update_mask]) if request is not None and has_flattened_params: @@ -855,7 +896,7 @@ def delete_featurestore( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name, force]) if request is not None and has_flattened_params: @@ -962,7 +1003,7 @@ def create_entity_type( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, entity_type, entity_type_id]) if request is not None and has_flattened_params: @@ -1049,7 +1090,7 @@ def get_entity_type( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -1124,7 +1165,7 @@ def list_entity_types( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -1227,7 +1268,7 @@ def update_entity_type( """ # Create or coerce a protobuf request object. 
- # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([entity_type, update_mask]) if request is not None and has_flattened_params: @@ -1328,7 +1369,7 @@ def delete_entity_type( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name, force]) if request is not None and has_flattened_params: @@ -1434,7 +1475,7 @@ def create_feature( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, feature, feature_id]) if request is not None and has_flattened_params: @@ -1534,7 +1575,7 @@ def batch_create_features( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, requests]) if request is not None and has_flattened_params: @@ -1618,7 +1659,7 @@ def get_feature( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -1693,7 +1734,7 @@ def list_features( """ # Create or coerce a protobuf request object. 
- # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -1795,7 +1836,7 @@ def update_feature( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([feature, update_mask]) if request is not None and has_flattened_params: @@ -1884,7 +1925,7 @@ def delete_feature( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -1989,7 +2030,7 @@ def import_feature_values( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([entity_type]) if request is not None and has_flattened_params: @@ -2080,7 +2121,7 @@ def batch_read_feature_values( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([featurestore]) if request is not None and has_flattened_params: @@ -2169,7 +2210,7 @@ def export_feature_values( """ # Create or coerce a protobuf request object. 
- # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([entity_type]) if request is not None and has_flattened_params: @@ -2331,7 +2372,7 @@ def search_features( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([location, query]) if request is not None and has_flattened_params: diff --git a/google/cloud/aiplatform_v1beta1/services/featurestore_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/featurestore_service/transports/base.py index eb864548777..4c6c697cbd0 100644 --- a/google/cloud/aiplatform_v1beta1/services/featurestore_service/transports/base.py +++ b/google/cloud/aiplatform_v1beta1/services/featurestore_service/transports/base.py @@ -108,7 +108,6 @@ def __init__( credentials, _ = google.auth.load_credentials_from_file( credentials_file, **scopes_kwargs, quota_project_id=quota_project_id ) - elif credentials is None: credentials, _ = google.auth.default( **scopes_kwargs, quota_project_id=quota_project_id diff --git a/google/cloud/aiplatform_v1beta1/services/featurestore_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/featurestore_service/transports/grpc.py index 02c4ae728c9..26131b797ff 100644 --- a/google/cloud/aiplatform_v1beta1/services/featurestore_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/featurestore_service/transports/grpc.py @@ -167,8 +167,11 @@ def __init__( if not self._grpc_channel: self._grpc_channel = type(self).create_channel( self._host, + # use the credentials which are saved credentials=self._credentials, - credentials_file=credentials_file, + # Set ``credentials_file`` to ``None`` here as + # 
the credentials that we saved earlier should be used. + credentials_file=None, scopes=self._scopes, ssl_credentials=self._ssl_channel_credentials, quota_project_id=quota_project_id, @@ -241,7 +244,7 @@ def operations_client(self) -> operations_v1.OperationsClient: This property caches on the instance; repeated calls return the same client. """ - # Sanity check: Only create a new client if we do not already have one. + # Quick check: Only create a new client if we do not already have one. if self._operations_client is None: self._operations_client = operations_v1.OperationsClient(self.grpc_channel) diff --git a/google/cloud/aiplatform_v1beta1/services/featurestore_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/featurestore_service/transports/grpc_asyncio.py index 56d88792654..eb1b8e6fe61 100644 --- a/google/cloud/aiplatform_v1beta1/services/featurestore_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/featurestore_service/transports/grpc_asyncio.py @@ -212,8 +212,11 @@ def __init__( if not self._grpc_channel: self._grpc_channel = type(self).create_channel( self._host, + # use the credentials which are saved credentials=self._credentials, - credentials_file=credentials_file, + # Set ``credentials_file`` to ``None`` here as + # the credentials that we saved earlier should be used. + credentials_file=None, scopes=self._scopes, ssl_credentials=self._ssl_channel_credentials, quota_project_id=quota_project_id, @@ -243,7 +246,7 @@ def operations_client(self) -> operations_v1.OperationsAsyncClient: This property caches on the instance; repeated calls return the same client. """ - # Sanity check: Only create a new client if we do not already have one. + # Quick check: Only create a new client if we do not already have one. 
if self._operations_client is None: self._operations_client = operations_v1.OperationsAsyncClient( self.grpc_channel diff --git a/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/async_client.py index a1e5aa81b89..17d4aba21ee 100644 --- a/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/async_client.py @@ -16,7 +16,7 @@ from collections import OrderedDict import functools import re -from typing import Dict, Sequence, Tuple, Type, Union +from typing import Dict, Optional, Sequence, Tuple, Type, Union import pkg_resources from google.api_core.client_options import ClientOptions @@ -118,6 +118,42 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): from_service_account_json = from_service_account_file + @classmethod + def get_mtls_endpoint_and_cert_source( + cls, client_options: Optional[ClientOptions] = None + ): + """Return the API endpoint and client cert source for mutual TLS. + + The client cert source is determined in the following order: + (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the + client cert source is None. + (2) if `client_options.client_cert_source` is provided, use the provided one; if the + default client cert source exists, use the default one; otherwise the client cert + source is None. + + The API endpoint is determined in the following order: + (1) if `client_options.api_endpoint` if provided, use the provided one. + (2) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is "always", use the + default mTLS endpoint; if the environment variabel is "never", use the default API + endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise + use the default API endpoint. + + More details can be found at https://google.aip.dev/auth/4114. 
+ + Args: + client_options (google.api_core.client_options.ClientOptions): Custom options for the + client. Only the `api_endpoint` and `client_cert_source` properties may be used + in this method. + + Returns: + Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the + client cert source to use. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If any errors happen. + """ + return IndexEndpointServiceClient.get_mtls_endpoint_and_cert_source(client_options) # type: ignore + @property def transport(self) -> IndexEndpointServiceTransport: """Returns the transport used by the client instance. @@ -225,7 +261,7 @@ async def create_index_endpoint( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, index_endpoint]) if request is not None and has_flattened_params: @@ -308,7 +344,7 @@ async def get_index_endpoint( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -383,7 +419,7 @@ async def list_index_endpoints( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -469,7 +505,7 @@ async def update_index_endpoint( """ # Create or coerce a protobuf request object. 
- # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([index_endpoint, update_mask]) if request is not None and has_flattened_params: @@ -558,7 +594,7 @@ async def delete_index_endpoint( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -652,7 +688,7 @@ async def deploy_index( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([index_endpoint, deployed_index]) if request is not None and has_flattened_params: @@ -750,7 +786,7 @@ async def undeploy_index( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([index_endpoint, deployed_index_id]) if request is not None and has_flattened_params: @@ -849,7 +885,7 @@ async def mutate_deployed_index( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. 
has_flattened_params = any([index_endpoint, deployed_index]) if request is not None and has_flattened_params: diff --git a/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/client.py b/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/client.py index 49f4dc9da63..99f1480e274 100644 --- a/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/client.py @@ -260,6 +260,73 @@ def parse_common_location_path(path: str) -> Dict[str, str]: m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) return m.groupdict() if m else {} + @classmethod + def get_mtls_endpoint_and_cert_source( + cls, client_options: Optional[client_options_lib.ClientOptions] = None + ): + """Return the API endpoint and client cert source for mutual TLS. + + The client cert source is determined in the following order: + (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the + client cert source is None. + (2) if `client_options.client_cert_source` is provided, use the provided one; if the + default client cert source exists, use the default one; otherwise the client cert + source is None. + + The API endpoint is determined in the following order: + (1) if `client_options.api_endpoint` is provided, use the provided one. + (2) if `GOOGLE_API_USE_MTLS_ENDPOINT` environment variable is "always", use the + default mTLS endpoint; if the environment variable is "never", use the default API + endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise + use the default API endpoint. + + More details can be found at https://google.aip.dev/auth/4114. + + Args: + client_options (google.api_core.client_options.ClientOptions): Custom options for the + client. Only the `api_endpoint` and `client_cert_source` properties may be used + in this method.
+ + Returns: + Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the + client cert source to use. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If any errors happen. + """ + if client_options is None: + client_options = client_options_lib.ClientOptions() + use_client_cert = os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") + use_mtls_endpoint = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") + if use_client_cert not in ("true", "false"): + raise ValueError( + "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" + ) + if use_mtls_endpoint not in ("auto", "never", "always"): + raise MutualTLSChannelError( + "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" + ) + + # Figure out the client cert source to use. + client_cert_source = None + if use_client_cert == "true": + if client_options.client_cert_source: + client_cert_source = client_options.client_cert_source + elif mtls.has_default_client_cert_source(): + client_cert_source = mtls.default_client_cert_source() + + # Figure out which api endpoint to use. + if client_options.api_endpoint is not None: + api_endpoint = client_options.api_endpoint + elif use_mtls_endpoint == "always" or ( + use_mtls_endpoint == "auto" and client_cert_source + ): + api_endpoint = cls.DEFAULT_MTLS_ENDPOINT + else: + api_endpoint = cls.DEFAULT_ENDPOINT + + return api_endpoint, client_cert_source + def __init__( self, *, @@ -310,57 +377,22 @@ def __init__( if client_options is None: client_options = client_options_lib.ClientOptions() - # Create SSL credentials for mutual TLS if needed. 
- if os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") not in ( - "true", - "false", - ): - raise ValueError( - "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" - ) - use_client_cert = ( - os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") == "true" + api_endpoint, client_cert_source_func = self.get_mtls_endpoint_and_cert_source( + client_options ) - client_cert_source_func = None - is_mtls = False - if use_client_cert: - if client_options.client_cert_source: - is_mtls = True - client_cert_source_func = client_options.client_cert_source - else: - is_mtls = mtls.has_default_client_cert_source() - if is_mtls: - client_cert_source_func = mtls.default_client_cert_source() - else: - client_cert_source_func = None - - # Figure out which api endpoint to use. - if client_options.api_endpoint is not None: - api_endpoint = client_options.api_endpoint - else: - use_mtls_env = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") - if use_mtls_env == "never": - api_endpoint = self.DEFAULT_ENDPOINT - elif use_mtls_env == "always": - api_endpoint = self.DEFAULT_MTLS_ENDPOINT - elif use_mtls_env == "auto": - if is_mtls: - api_endpoint = self.DEFAULT_MTLS_ENDPOINT - else: - api_endpoint = self.DEFAULT_ENDPOINT - else: - raise MutualTLSChannelError( - "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted " - "values: never, auto, always" - ) + api_key_value = getattr(client_options, "api_key", None) + if api_key_value and credentials: + raise ValueError( + "client_options.api_key and credentials are mutually exclusive" + ) # Save or instantiate the transport. # Ordinarily, we provide the transport, but allowing a custom transport # instance provides an extensibility point for unusual situations. if isinstance(transport, IndexEndpointServiceTransport): # transport is a IndexEndpointServiceTransport instance. 
- if credentials or client_options.credentials_file: + if credentials or client_options.credentials_file or api_key_value: raise ValueError( "When providing a transport instance, " "provide its credentials directly." @@ -372,6 +404,15 @@ def __init__( ) self._transport = transport else: + import google.auth._default # type: ignore + + if api_key_value and hasattr( + google.auth._default, "get_api_key_credentials" + ): + credentials = google.auth._default.get_api_key_credentials( + api_key_value + ) + Transport = type(self).get_transport_class(transport) self._transport = Transport( credentials=credentials, @@ -430,7 +471,7 @@ def create_index_endpoint( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, index_endpoint]) if request is not None and has_flattened_params: @@ -513,7 +554,7 @@ def get_index_endpoint( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -588,7 +629,7 @@ def list_index_endpoints( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -674,7 +715,7 @@ def update_index_endpoint( """ # Create or coerce a protobuf request object. 
- # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([index_endpoint, update_mask]) if request is not None and has_flattened_params: @@ -763,7 +804,7 @@ def delete_index_endpoint( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -857,7 +898,7 @@ def deploy_index( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([index_endpoint, deployed_index]) if request is not None and has_flattened_params: @@ -955,7 +996,7 @@ def undeploy_index( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([index_endpoint, deployed_index_id]) if request is not None and has_flattened_params: @@ -1054,7 +1095,7 @@ def mutate_deployed_index( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. 
has_flattened_params = any([index_endpoint, deployed_index]) if request is not None and has_flattened_params: diff --git a/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/transports/base.py index 729e32879bf..cdf3473da1f 100644 --- a/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/transports/base.py +++ b/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/transports/base.py @@ -105,7 +105,6 @@ def __init__( credentials, _ = google.auth.load_credentials_from_file( credentials_file, **scopes_kwargs, quota_project_id=quota_project_id ) - elif credentials is None: credentials, _ = google.auth.default( **scopes_kwargs, quota_project_id=quota_project_id diff --git a/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/transports/grpc.py index 5704bc41f49..6c1d31633d9 100644 --- a/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/transports/grpc.py @@ -163,8 +163,11 @@ def __init__( if not self._grpc_channel: self._grpc_channel = type(self).create_channel( self._host, + # use the credentials which are saved credentials=self._credentials, - credentials_file=credentials_file, + # Set ``credentials_file`` to ``None`` here as + # the credentials that we saved earlier should be used. + credentials_file=None, scopes=self._scopes, ssl_credentials=self._ssl_channel_credentials, quota_project_id=quota_project_id, @@ -237,7 +240,7 @@ def operations_client(self) -> operations_v1.OperationsClient: This property caches on the instance; repeated calls return the same client. """ - # Sanity check: Only create a new client if we do not already have one. + # Quick check: Only create a new client if we do not already have one. 
if self._operations_client is None: self._operations_client = operations_v1.OperationsClient(self.grpc_channel) diff --git a/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/transports/grpc_asyncio.py index e8b2c2ccafd..8a17791620e 100644 --- a/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/transports/grpc_asyncio.py @@ -208,8 +208,11 @@ def __init__( if not self._grpc_channel: self._grpc_channel = type(self).create_channel( self._host, + # use the credentials which are saved credentials=self._credentials, - credentials_file=credentials_file, + # Set ``credentials_file`` to ``None`` here as + # the credentials that we saved earlier should be used. + credentials_file=None, scopes=self._scopes, ssl_credentials=self._ssl_channel_credentials, quota_project_id=quota_project_id, @@ -239,7 +242,7 @@ def operations_client(self) -> operations_v1.OperationsAsyncClient: This property caches on the instance; repeated calls return the same client. """ - # Sanity check: Only create a new client if we do not already have one. + # Quick check: Only create a new client if we do not already have one. 
if self._operations_client is None: self._operations_client = operations_v1.OperationsAsyncClient( self.grpc_channel diff --git a/google/cloud/aiplatform_v1beta1/services/index_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/index_service/async_client.py index 0838cdde225..87214d756d5 100644 --- a/google/cloud/aiplatform_v1beta1/services/index_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/index_service/async_client.py @@ -16,7 +16,7 @@ from collections import OrderedDict import functools import re -from typing import Dict, Sequence, Tuple, Type, Union +from typing import Dict, Optional, Sequence, Tuple, Type, Union import pkg_resources from google.api_core.client_options import ClientOptions @@ -118,6 +118,42 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): from_service_account_json = from_service_account_file + @classmethod + def get_mtls_endpoint_and_cert_source( + cls, client_options: Optional[ClientOptions] = None + ): + """Return the API endpoint and client cert source for mutual TLS. + + The client cert source is determined in the following order: + (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the + client cert source is None. + (2) if `client_options.client_cert_source` is provided, use the provided one; if the + default client cert source exists, use the default one; otherwise the client cert + source is None. + + The API endpoint is determined in the following order: + (1) if `client_options.api_endpoint` is provided, use the provided one. + (2) if `GOOGLE_API_USE_MTLS_ENDPOINT` environment variable is "always", use the + default mTLS endpoint; if the environment variable is "never", use the default API + endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise + use the default API endpoint. + + More details can be found at https://google.aip.dev/auth/4114.
+ + Args: + client_options (google.api_core.client_options.ClientOptions): Custom options for the + client. Only the `api_endpoint` and `client_cert_source` properties may be used + in this method. + + Returns: + Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the + client cert source to use. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If any errors happen. + """ + return IndexServiceClient.get_mtls_endpoint_and_cert_source(client_options) # type: ignore + @property def transport(self) -> IndexServiceTransport: """Returns the transport used by the client instance. @@ -223,7 +259,7 @@ async def create_index( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, index]) if request is not None and has_flattened_params: @@ -306,7 +342,7 @@ async def get_index( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -381,7 +417,7 @@ async def list_indexes( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -470,7 +506,7 @@ async def update_index( """ # Create or coerce a protobuf request object. 
- # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([index, update_mask]) if request is not None and has_flattened_params: @@ -569,7 +605,7 @@ async def delete_index( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: diff --git a/google/cloud/aiplatform_v1beta1/services/index_service/client.py b/google/cloud/aiplatform_v1beta1/services/index_service/client.py index 57a05df2551..da3dc18d1fd 100644 --- a/google/cloud/aiplatform_v1beta1/services/index_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/index_service/client.py @@ -260,6 +260,73 @@ def parse_common_location_path(path: str) -> Dict[str, str]: m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) return m.groupdict() if m else {} + @classmethod + def get_mtls_endpoint_and_cert_source( + cls, client_options: Optional[client_options_lib.ClientOptions] = None + ): + """Return the API endpoint and client cert source for mutual TLS. + + The client cert source is determined in the following order: + (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the + client cert source is None. + (2) if `client_options.client_cert_source` is provided, use the provided one; if the + default client cert source exists, use the default one; otherwise the client cert + source is None. + + The API endpoint is determined in the following order: + (1) if `client_options.api_endpoint` is provided, use the provided one.
+ (2) if `GOOGLE_API_USE_MTLS_ENDPOINT` environment variable is "always", use the + default mTLS endpoint; if the environment variable is "never", use the default API + endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise + use the default API endpoint. + + More details can be found at https://google.aip.dev/auth/4114. + + Args: + client_options (google.api_core.client_options.ClientOptions): Custom options for the + client. Only the `api_endpoint` and `client_cert_source` properties may be used + in this method. + + Returns: + Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the + client cert source to use. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If any errors happen. + """ + if client_options is None: + client_options = client_options_lib.ClientOptions() + use_client_cert = os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") + use_mtls_endpoint = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") + if use_client_cert not in ("true", "false"): + raise ValueError( + "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" + ) + if use_mtls_endpoint not in ("auto", "never", "always"): + raise MutualTLSChannelError( + "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" + ) + + # Figure out the client cert source to use. + client_cert_source = None + if use_client_cert == "true": + if client_options.client_cert_source: + client_cert_source = client_options.client_cert_source + elif mtls.has_default_client_cert_source(): + client_cert_source = mtls.default_client_cert_source() + + # Figure out which api endpoint to use.
+ if client_options.api_endpoint is not None: + api_endpoint = client_options.api_endpoint + elif use_mtls_endpoint == "always" or ( + use_mtls_endpoint == "auto" and client_cert_source + ): + api_endpoint = cls.DEFAULT_MTLS_ENDPOINT + else: + api_endpoint = cls.DEFAULT_ENDPOINT + + return api_endpoint, client_cert_source + def __init__( self, *, @@ -310,57 +377,22 @@ def __init__( if client_options is None: client_options = client_options_lib.ClientOptions() - # Create SSL credentials for mutual TLS if needed. - if os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") not in ( - "true", - "false", - ): - raise ValueError( - "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" - ) - use_client_cert = ( - os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") == "true" + api_endpoint, client_cert_source_func = self.get_mtls_endpoint_and_cert_source( + client_options ) - client_cert_source_func = None - is_mtls = False - if use_client_cert: - if client_options.client_cert_source: - is_mtls = True - client_cert_source_func = client_options.client_cert_source - else: - is_mtls = mtls.has_default_client_cert_source() - if is_mtls: - client_cert_source_func = mtls.default_client_cert_source() - else: - client_cert_source_func = None - - # Figure out which api endpoint to use. - if client_options.api_endpoint is not None: - api_endpoint = client_options.api_endpoint - else: - use_mtls_env = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") - if use_mtls_env == "never": - api_endpoint = self.DEFAULT_ENDPOINT - elif use_mtls_env == "always": - api_endpoint = self.DEFAULT_MTLS_ENDPOINT - elif use_mtls_env == "auto": - if is_mtls: - api_endpoint = self.DEFAULT_MTLS_ENDPOINT - else: - api_endpoint = self.DEFAULT_ENDPOINT - else: - raise MutualTLSChannelError( - "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. 
Accepted " - "values: never, auto, always" - ) + api_key_value = getattr(client_options, "api_key", None) + if api_key_value and credentials: + raise ValueError( + "client_options.api_key and credentials are mutually exclusive" + ) # Save or instantiate the transport. # Ordinarily, we provide the transport, but allowing a custom transport # instance provides an extensibility point for unusual situations. if isinstance(transport, IndexServiceTransport): # transport is a IndexServiceTransport instance. - if credentials or client_options.credentials_file: + if credentials or client_options.credentials_file or api_key_value: raise ValueError( "When providing a transport instance, " "provide its credentials directly." @@ -372,6 +404,15 @@ def __init__( ) self._transport = transport else: + import google.auth._default # type: ignore + + if api_key_value and hasattr( + google.auth._default, "get_api_key_credentials" + ): + credentials = google.auth._default.get_api_key_credentials( + api_key_value + ) + Transport = type(self).get_transport_class(transport) self._transport = Transport( credentials=credentials, @@ -429,7 +470,7 @@ def create_index( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, index]) if request is not None and has_flattened_params: @@ -512,7 +553,7 @@ def get_index( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -587,7 +628,7 @@ def list_indexes( """ # Create or coerce a protobuf request object. 
- # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -676,7 +717,7 @@ def update_index( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([index, update_mask]) if request is not None and has_flattened_params: @@ -775,7 +816,7 @@ def delete_index( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: diff --git a/google/cloud/aiplatform_v1beta1/services/index_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/index_service/transports/base.py index cfb92661218..4c99f598549 100644 --- a/google/cloud/aiplatform_v1beta1/services/index_service/transports/base.py +++ b/google/cloud/aiplatform_v1beta1/services/index_service/transports/base.py @@ -104,7 +104,6 @@ def __init__( credentials, _ = google.auth.load_credentials_from_file( credentials_file, **scopes_kwargs, quota_project_id=quota_project_id ) - elif credentials is None: credentials, _ = google.auth.default( **scopes_kwargs, quota_project_id=quota_project_id diff --git a/google/cloud/aiplatform_v1beta1/services/index_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/index_service/transports/grpc.py index 8173064eafe..e782bd8739b 100644 --- a/google/cloud/aiplatform_v1beta1/services/index_service/transports/grpc.py +++ 
b/google/cloud/aiplatform_v1beta1/services/index_service/transports/grpc.py @@ -163,8 +163,11 @@ def __init__( if not self._grpc_channel: self._grpc_channel = type(self).create_channel( self._host, + # use the credentials which are saved credentials=self._credentials, - credentials_file=credentials_file, + # Set ``credentials_file`` to ``None`` here as + # the credentials that we saved earlier should be used. + credentials_file=None, scopes=self._scopes, ssl_credentials=self._ssl_channel_credentials, quota_project_id=quota_project_id, @@ -237,7 +240,7 @@ def operations_client(self) -> operations_v1.OperationsClient: This property caches on the instance; repeated calls return the same client. """ - # Sanity check: Only create a new client if we do not already have one. + # Quick check: Only create a new client if we do not already have one. if self._operations_client is None: self._operations_client = operations_v1.OperationsClient(self.grpc_channel) diff --git a/google/cloud/aiplatform_v1beta1/services/index_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/index_service/transports/grpc_asyncio.py index af440edd6d2..23209aea24f 100644 --- a/google/cloud/aiplatform_v1beta1/services/index_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/index_service/transports/grpc_asyncio.py @@ -208,8 +208,11 @@ def __init__( if not self._grpc_channel: self._grpc_channel = type(self).create_channel( self._host, + # use the credentials which are saved credentials=self._credentials, - credentials_file=credentials_file, + # Set ``credentials_file`` to ``None`` here as + # the credentials that we saved earlier should be used. + credentials_file=None, scopes=self._scopes, ssl_credentials=self._ssl_channel_credentials, quota_project_id=quota_project_id, @@ -239,7 +242,7 @@ def operations_client(self) -> operations_v1.OperationsAsyncClient: This property caches on the instance; repeated calls return the same client. 
""" - # Sanity check: Only create a new client if we do not already have one. + # Quick check: Only create a new client if we do not already have one. if self._operations_client is None: self._operations_client = operations_v1.OperationsAsyncClient( self.grpc_channel diff --git a/google/cloud/aiplatform_v1beta1/services/job_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/job_service/async_client.py index 6022a7c16a0..ae2989cefe1 100644 --- a/google/cloud/aiplatform_v1beta1/services/job_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/job_service/async_client.py @@ -16,7 +16,7 @@ from collections import OrderedDict import functools import re -from typing import Dict, Sequence, Tuple, Type, Union +from typing import Dict, Optional, Sequence, Tuple, Type, Union import pkg_resources from google.api_core.client_options import ClientOptions @@ -170,6 +170,42 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): from_service_account_json = from_service_account_file + @classmethod + def get_mtls_endpoint_and_cert_source( + cls, client_options: Optional[ClientOptions] = None + ): + """Return the API endpoint and client cert source for mutual TLS. + + The client cert source is determined in the following order: + (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the + client cert source is None. + (2) if `client_options.client_cert_source` is provided, use the provided one; if the + default client cert source exists, use the default one; otherwise the client cert + source is None. + + The API endpoint is determined in the following order: + (1) if `client_options.api_endpoint` if provided, use the provided one. 
+ (2) if `GOOGLE_API_USE_MTLS_ENDPOINT` environment variable is "always", use the + default mTLS endpoint; if the environment variable is "never", use the default API + endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise + use the default API endpoint. + + More details can be found at https://google.aip.dev/auth/4114. + + Args: + client_options (google.api_core.client_options.ClientOptions): Custom options for the + client. Only the `api_endpoint` and `client_cert_source` properties may be used + in this method. + + Returns: + Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the + client cert source to use. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If any errors happen. + """ + return JobServiceClient.get_mtls_endpoint_and_cert_source(client_options) # type: ignore + @property def transport(self) -> JobServiceTransport: """Returns the transport used by the client instance. @@ -279,7 +315,7 @@ async def create_custom_job( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, custom_job]) if request is not None and has_flattened_params: @@ -358,7 +394,7 @@ async def get_custom_job( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -433,7 +469,7 @@ async def list_custom_jobs( """ # Create or coerce a protobuf request object.
- # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -524,7 +560,7 @@ async def delete_custom_job( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -609,7 +645,7 @@ async def cancel_custom_job( sent along with the request as metadata. """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -688,7 +724,7 @@ async def create_data_labeling_job( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, data_labeling_job]) if request is not None and has_flattened_params: @@ -762,7 +798,7 @@ async def get_data_labeling_job( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -836,7 +872,7 @@ async def list_data_labeling_jobs( """ # Create or coerce a protobuf request object. 
- # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -927,7 +963,7 @@ async def delete_data_labeling_job( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -1001,7 +1037,7 @@ async def cancel_data_labeling_job( sent along with the request as metadata. """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -1082,7 +1118,7 @@ async def create_hyperparameter_tuning_job( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, hyperparameter_tuning_job]) if request is not None and has_flattened_params: @@ -1158,7 +1194,7 @@ async def get_hyperparameter_tuning_job( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -1233,7 +1269,7 @@ async def list_hyperparameter_tuning_jobs( """ # Create or coerce a protobuf request object. 
- # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -1324,7 +1360,7 @@ async def delete_hyperparameter_tuning_job( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -1411,7 +1447,7 @@ async def cancel_hyperparameter_tuning_job( sent along with the request as metadata. """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -1495,7 +1531,7 @@ async def create_batch_prediction_job( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, batch_prediction_job]) if request is not None and has_flattened_params: @@ -1573,7 +1609,7 @@ async def get_batch_prediction_job( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -1648,7 +1684,7 @@ async def list_batch_prediction_jobs( """ # Create or coerce a protobuf request object. 
- # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -1740,7 +1776,7 @@ async def delete_batch_prediction_job( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -1825,7 +1861,7 @@ async def cancel_batch_prediction_job( sent along with the request as metadata. """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -1910,7 +1946,7 @@ async def create_model_deployment_monitoring_job( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, model_deployment_monitoring_job]) if request is not None and has_flattened_params: @@ -1998,7 +2034,7 @@ async def search_model_deployment_monitoring_stats_anomalies( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. 
has_flattened_params = any([model_deployment_monitoring_job, deployed_model_id]) if request is not None and has_flattened_params: @@ -2090,7 +2126,7 @@ async def get_model_deployment_monitoring_job( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -2167,7 +2203,7 @@ async def list_model_deployment_monitoring_jobs( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -2284,7 +2320,7 @@ async def update_model_deployment_monitoring_job( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([model_deployment_monitoring_job, update_mask]) if request is not None and has_flattened_params: @@ -2388,7 +2424,7 @@ async def delete_model_deployment_monitoring_job( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -2467,7 +2503,7 @@ async def pause_model_deployment_monitoring_job( sent along with the request as metadata. """ # Create or coerce a protobuf request object. 
- # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -2536,7 +2572,7 @@ async def resume_model_deployment_monitoring_job( sent along with the request as metadata. """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: diff --git a/google/cloud/aiplatform_v1beta1/services/job_service/client.py b/google/cloud/aiplatform_v1beta1/services/job_service/client.py index 57d840f33e3..85e571b49e8 100644 --- a/google/cloud/aiplatform_v1beta1/services/job_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/job_service/client.py @@ -443,6 +443,73 @@ def parse_common_location_path(path: str) -> Dict[str, str]: m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) return m.groupdict() if m else {} + @classmethod + def get_mtls_endpoint_and_cert_source( + cls, client_options: Optional[client_options_lib.ClientOptions] = None + ): + """Return the API endpoint and client cert source for mutual TLS. + + The client cert source is determined in the following order: + (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the + client cert source is None. + (2) if `client_options.client_cert_source` is provided, use the provided one; if the + default client cert source exists, use the default one; otherwise the client cert + source is None. + + The API endpoint is determined in the following order: + (1) if `client_options.api_endpoint` if provided, use the provided one. 
+ (2) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is "always", use the + default mTLS endpoint; if the environment variabel is "never", use the default API + endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise + use the default API endpoint. + + More details can be found at https://google.aip.dev/auth/4114. + + Args: + client_options (google.api_core.client_options.ClientOptions): Custom options for the + client. Only the `api_endpoint` and `client_cert_source` properties may be used + in this method. + + Returns: + Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the + client cert source to use. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If any errors happen. + """ + if client_options is None: + client_options = client_options_lib.ClientOptions() + use_client_cert = os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") + use_mtls_endpoint = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") + if use_client_cert not in ("true", "false"): + raise ValueError( + "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" + ) + if use_mtls_endpoint not in ("auto", "never", "always"): + raise MutualTLSChannelError( + "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" + ) + + # Figure out the client cert source to use. + client_cert_source = None + if use_client_cert == "true": + if client_options.client_cert_source: + client_cert_source = client_options.client_cert_source + elif mtls.has_default_client_cert_source(): + client_cert_source = mtls.default_client_cert_source() + + # Figure out which api endpoint to use. 
+ if client_options.api_endpoint is not None: + api_endpoint = client_options.api_endpoint + elif use_mtls_endpoint == "always" or ( + use_mtls_endpoint == "auto" and client_cert_source + ): + api_endpoint = cls.DEFAULT_MTLS_ENDPOINT + else: + api_endpoint = cls.DEFAULT_ENDPOINT + + return api_endpoint, client_cert_source + def __init__( self, *, @@ -493,57 +560,22 @@ def __init__( if client_options is None: client_options = client_options_lib.ClientOptions() - # Create SSL credentials for mutual TLS if needed. - if os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") not in ( - "true", - "false", - ): - raise ValueError( - "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" - ) - use_client_cert = ( - os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") == "true" + api_endpoint, client_cert_source_func = self.get_mtls_endpoint_and_cert_source( + client_options ) - client_cert_source_func = None - is_mtls = False - if use_client_cert: - if client_options.client_cert_source: - is_mtls = True - client_cert_source_func = client_options.client_cert_source - else: - is_mtls = mtls.has_default_client_cert_source() - if is_mtls: - client_cert_source_func = mtls.default_client_cert_source() - else: - client_cert_source_func = None - - # Figure out which api endpoint to use. - if client_options.api_endpoint is not None: - api_endpoint = client_options.api_endpoint - else: - use_mtls_env = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") - if use_mtls_env == "never": - api_endpoint = self.DEFAULT_ENDPOINT - elif use_mtls_env == "always": - api_endpoint = self.DEFAULT_MTLS_ENDPOINT - elif use_mtls_env == "auto": - if is_mtls: - api_endpoint = self.DEFAULT_MTLS_ENDPOINT - else: - api_endpoint = self.DEFAULT_ENDPOINT - else: - raise MutualTLSChannelError( - "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. 
Accepted " - "values: never, auto, always" - ) + api_key_value = getattr(client_options, "api_key", None) + if api_key_value and credentials: + raise ValueError( + "client_options.api_key and credentials are mutually exclusive" + ) # Save or instantiate the transport. # Ordinarily, we provide the transport, but allowing a custom transport # instance provides an extensibility point for unusual situations. if isinstance(transport, JobServiceTransport): # transport is a JobServiceTransport instance. - if credentials or client_options.credentials_file: + if credentials or client_options.credentials_file or api_key_value: raise ValueError( "When providing a transport instance, " "provide its credentials directly." @@ -555,6 +587,15 @@ def __init__( ) self._transport = transport else: + import google.auth._default # type: ignore + + if api_key_value and hasattr( + google.auth._default, "get_api_key_credentials" + ): + credentials = google.auth._default.get_api_key_credentials( + api_key_value + ) + Transport = type(self).get_transport_class(transport) self._transport = Transport( credentials=credentials, @@ -616,7 +657,7 @@ def create_custom_job( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, custom_job]) if request is not None and has_flattened_params: @@ -695,7 +736,7 @@ def get_custom_job( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -770,7 +811,7 @@ def list_custom_jobs( """ # Create or coerce a protobuf request object. 
- # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -861,7 +902,7 @@ def delete_custom_job( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -946,7 +987,7 @@ def cancel_custom_job( sent along with the request as metadata. """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -1025,7 +1066,7 @@ def create_data_labeling_job( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, data_labeling_job]) if request is not None and has_flattened_params: @@ -1099,7 +1140,7 @@ def get_data_labeling_job( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -1173,7 +1214,7 @@ def list_data_labeling_jobs( """ # Create or coerce a protobuf request object. 
- # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -1264,7 +1305,7 @@ def delete_data_labeling_job( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -1338,7 +1379,7 @@ def cancel_data_labeling_job( sent along with the request as metadata. """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -1419,7 +1460,7 @@ def create_hyperparameter_tuning_job( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, hyperparameter_tuning_job]) if request is not None and has_flattened_params: @@ -1497,7 +1538,7 @@ def get_hyperparameter_tuning_job( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -1574,7 +1615,7 @@ def list_hyperparameter_tuning_jobs( """ # Create or coerce a protobuf request object. 
- # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -1667,7 +1708,7 @@ def delete_hyperparameter_tuning_job( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -1756,7 +1797,7 @@ def cancel_hyperparameter_tuning_job( sent along with the request as metadata. """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -1842,7 +1883,7 @@ def create_batch_prediction_job( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, batch_prediction_job]) if request is not None and has_flattened_params: @@ -1922,7 +1963,7 @@ def get_batch_prediction_job( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -1997,7 +2038,7 @@ def list_batch_prediction_jobs( """ # Create or coerce a protobuf request object. 
- # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -2091,7 +2132,7 @@ def delete_batch_prediction_job( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -2178,7 +2219,7 @@ def cancel_batch_prediction_job( sent along with the request as metadata. """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -2265,7 +2306,7 @@ def create_model_deployment_monitoring_job( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, model_deployment_monitoring_job]) if request is not None and has_flattened_params: @@ -2359,7 +2400,7 @@ def search_model_deployment_monitoring_stats_anomalies( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. 
has_flattened_params = any([model_deployment_monitoring_job, deployed_model_id]) if request is not None and has_flattened_params: @@ -2457,7 +2498,7 @@ def get_model_deployment_monitoring_job( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -2536,7 +2577,7 @@ def list_model_deployment_monitoring_jobs( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -2657,7 +2698,7 @@ def update_model_deployment_monitoring_job( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([model_deployment_monitoring_job, update_mask]) if request is not None and has_flattened_params: @@ -2767,7 +2808,7 @@ def delete_model_deployment_monitoring_job( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -2850,7 +2891,7 @@ def pause_model_deployment_monitoring_job( sent along with the request as metadata. """ # Create or coerce a protobuf request object. 
- # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -2923,7 +2964,7 @@ def resume_model_deployment_monitoring_job( sent along with the request as metadata. """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: diff --git a/google/cloud/aiplatform_v1beta1/services/job_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/job_service/transports/base.py index 75f12262677..e5f24d7f136 100644 --- a/google/cloud/aiplatform_v1beta1/services/job_service/transports/base.py +++ b/google/cloud/aiplatform_v1beta1/services/job_service/transports/base.py @@ -122,7 +122,6 @@ def __init__( credentials, _ = google.auth.load_credentials_from_file( credentials_file, **scopes_kwargs, quota_project_id=quota_project_id ) - elif credentials is None: credentials, _ = google.auth.default( **scopes_kwargs, quota_project_id=quota_project_id diff --git a/google/cloud/aiplatform_v1beta1/services/job_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/job_service/transports/grpc.py index baaa922a450..b30b5ffb58a 100644 --- a/google/cloud/aiplatform_v1beta1/services/job_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/job_service/transports/grpc.py @@ -180,8 +180,11 @@ def __init__( if not self._grpc_channel: self._grpc_channel = type(self).create_channel( self._host, + # use the credentials which are saved credentials=self._credentials, - credentials_file=credentials_file, + # Set ``credentials_file`` to ``None`` here as + # the credentials that we 
saved earlier should be used. + credentials_file=None, scopes=self._scopes, ssl_credentials=self._ssl_channel_credentials, quota_project_id=quota_project_id, @@ -254,7 +257,7 @@ def operations_client(self) -> operations_v1.OperationsClient: This property caches on the instance; repeated calls return the same client. """ - # Sanity check: Only create a new client if we do not already have one. + # Quick check: Only create a new client if we do not already have one. if self._operations_client is None: self._operations_client = operations_v1.OperationsClient(self.grpc_channel) diff --git a/google/cloud/aiplatform_v1beta1/services/job_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/job_service/transports/grpc_asyncio.py index 02c6188be4b..4cf4e546924 100644 --- a/google/cloud/aiplatform_v1beta1/services/job_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/job_service/transports/grpc_asyncio.py @@ -225,8 +225,11 @@ def __init__( if not self._grpc_channel: self._grpc_channel = type(self).create_channel( self._host, + # use the credentials which are saved credentials=self._credentials, - credentials_file=credentials_file, + # Set ``credentials_file`` to ``None`` here as + # the credentials that we saved earlier should be used. + credentials_file=None, scopes=self._scopes, ssl_credentials=self._ssl_channel_credentials, quota_project_id=quota_project_id, @@ -256,7 +259,7 @@ def operations_client(self) -> operations_v1.OperationsAsyncClient: This property caches on the instance; repeated calls return the same client. """ - # Sanity check: Only create a new client if we do not already have one. + # Quick check: Only create a new client if we do not already have one. 
if self._operations_client is None: self._operations_client = operations_v1.OperationsAsyncClient( self.grpc_channel diff --git a/google/cloud/aiplatform_v1beta1/services/metadata_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/metadata_service/async_client.py index e58591c5417..6602a9becae 100644 --- a/google/cloud/aiplatform_v1beta1/services/metadata_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/metadata_service/async_client.py @@ -16,7 +16,7 @@ from collections import OrderedDict import functools import re -from typing import Dict, Sequence, Tuple, Type, Union +from typing import Dict, Optional, Sequence, Tuple, Type, Union import pkg_resources from google.api_core.client_options import ClientOptions @@ -138,6 +138,42 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): from_service_account_json = from_service_account_file + @classmethod + def get_mtls_endpoint_and_cert_source( + cls, client_options: Optional[ClientOptions] = None + ): + """Return the API endpoint and client cert source for mutual TLS. + + The client cert source is determined in the following order: + (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the + client cert source is None. + (2) if `client_options.client_cert_source` is provided, use the provided one; if the + default client cert source exists, use the default one; otherwise the client cert + source is None. + + The API endpoint is determined in the following order: + (1) if `client_options.api_endpoint` if provided, use the provided one. + (2) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is "always", use the + default mTLS endpoint; if the environment variabel is "never", use the default API + endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise + use the default API endpoint. + + More details can be found at https://google.aip.dev/auth/4114. 
+ + Args: + client_options (google.api_core.client_options.ClientOptions): Custom options for the + client. Only the `api_endpoint` and `client_cert_source` properties may be used + in this method. + + Returns: + Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the + client cert source to use. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If any errors happen. + """ + return MetadataServiceClient.get_mtls_endpoint_and_cert_source(client_options) # type: ignore + @property def transport(self) -> MetadataServiceTransport: """Returns the transport used by the client instance. @@ -261,7 +297,7 @@ async def create_metadata_store( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, metadata_store, metadata_store_id]) if request is not None and has_flattened_params: @@ -346,7 +382,7 @@ async def get_metadata_store( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -421,7 +457,7 @@ async def list_metadata_stores( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -513,7 +549,7 @@ async def delete_metadata_store( """ # Create or coerce a protobuf request object. 
- # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -613,7 +649,7 @@ async def create_artifact( Instance of a general artifact. """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, artifact, artifact_id]) if request is not None and has_flattened_params: @@ -687,7 +723,7 @@ async def get_artifact( Instance of a general artifact. """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -762,7 +798,7 @@ async def list_artifacts( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -851,7 +887,7 @@ async def update_artifact( Instance of a general artifact. """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([artifact, update_mask]) if request is not None and has_flattened_params: @@ -940,7 +976,7 @@ async def delete_artifact( """ # Create or coerce a protobuf request object. 
- # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -1024,7 +1060,7 @@ async def purge_artifacts( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -1124,7 +1160,7 @@ async def create_context( Instance of a general context. """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, context, context_id]) if request is not None and has_flattened_params: @@ -1198,7 +1234,7 @@ async def get_context( Instance of a general context. """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -1273,7 +1309,7 @@ async def list_contexts( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -1361,7 +1397,7 @@ async def update_context( Instance of a general context. """ # Create or coerce a protobuf request object. 
- # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([context, update_mask]) if request is not None and has_flattened_params: @@ -1450,7 +1486,7 @@ async def delete_context( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -1534,7 +1570,7 @@ async def purge_contexts( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -1640,7 +1676,7 @@ async def add_context_artifacts_and_executions( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([context, artifacts, executions]) if request is not None and has_flattened_params: @@ -1729,7 +1765,7 @@ async def add_context_children( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([context, child_contexts]) if request is not None and has_flattened_params: @@ -1814,7 +1850,7 @@ async def query_context_lineage_subgraph( """ # Create or coerce a protobuf request object. 
- # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([context]) if request is not None and has_flattened_params: @@ -1906,7 +1942,7 @@ async def create_execution( Instance of a general execution. """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, execution, execution_id]) if request is not None and has_flattened_params: @@ -1980,7 +2016,7 @@ async def get_execution( Instance of a general execution. """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -2055,7 +2091,7 @@ async def list_executions( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -2144,7 +2180,7 @@ async def update_execution( Instance of a general execution. """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. 
has_flattened_params = any([execution, update_mask]) if request is not None and has_flattened_params: @@ -2233,7 +2269,7 @@ async def delete_execution( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -2317,7 +2353,7 @@ async def purge_executions( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -2407,7 +2443,7 @@ async def add_execution_events( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([execution, events]) if request is not None and has_flattened_params: @@ -2489,7 +2525,7 @@ async def query_execution_inputs_and_outputs( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([execution]) if request is not None and has_flattened_params: @@ -2585,7 +2621,7 @@ async def create_metadata_schema( Instance of a general MetadataSchema. """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. 
has_flattened_params = any([parent, metadata_schema, metadata_schema_id]) if request is not None and has_flattened_params: @@ -2659,7 +2695,7 @@ async def get_metadata_schema( Instance of a general MetadataSchema. """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -2734,7 +2770,7 @@ async def list_metadata_schemas( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -2823,7 +2859,7 @@ async def query_artifact_lineage_subgraph( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. 
has_flattened_params = any([artifact]) if request is not None and has_flattened_params: diff --git a/google/cloud/aiplatform_v1beta1/services/metadata_service/client.py b/google/cloud/aiplatform_v1beta1/services/metadata_service/client.py index d0f873ad9e7..01fc9bf5d7f 100644 --- a/google/cloud/aiplatform_v1beta1/services/metadata_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/metadata_service/client.py @@ -338,6 +338,73 @@ def parse_common_location_path(path: str) -> Dict[str, str]: m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) return m.groupdict() if m else {} + @classmethod + def get_mtls_endpoint_and_cert_source( + cls, client_options: Optional[client_options_lib.ClientOptions] = None + ): + """Return the API endpoint and client cert source for mutual TLS. + + The client cert source is determined in the following order: + (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the + client cert source is None. + (2) if `client_options.client_cert_source` is provided, use the provided one; if the + default client cert source exists, use the default one; otherwise the client cert + source is None. + + The API endpoint is determined in the following order: + (1) if `client_options.api_endpoint` if provided, use the provided one. + (2) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is "always", use the + default mTLS endpoint; if the environment variabel is "never", use the default API + endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise + use the default API endpoint. + + More details can be found at https://google.aip.dev/auth/4114. + + Args: + client_options (google.api_core.client_options.ClientOptions): Custom options for the + client. Only the `api_endpoint` and `client_cert_source` properties may be used + in this method. + + Returns: + Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the + client cert source to use. 
+ + Raises: + google.auth.exceptions.MutualTLSChannelError: If any errors happen. + """ + if client_options is None: + client_options = client_options_lib.ClientOptions() + use_client_cert = os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") + use_mtls_endpoint = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") + if use_client_cert not in ("true", "false"): + raise ValueError( + "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" + ) + if use_mtls_endpoint not in ("auto", "never", "always"): + raise MutualTLSChannelError( + "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" + ) + + # Figure out the client cert source to use. + client_cert_source = None + if use_client_cert == "true": + if client_options.client_cert_source: + client_cert_source = client_options.client_cert_source + elif mtls.has_default_client_cert_source(): + client_cert_source = mtls.default_client_cert_source() + + # Figure out which api endpoint to use. + if client_options.api_endpoint is not None: + api_endpoint = client_options.api_endpoint + elif use_mtls_endpoint == "always" or ( + use_mtls_endpoint == "auto" and client_cert_source + ): + api_endpoint = cls.DEFAULT_MTLS_ENDPOINT + else: + api_endpoint = cls.DEFAULT_ENDPOINT + + return api_endpoint, client_cert_source + def __init__( self, *, @@ -388,57 +455,22 @@ def __init__( if client_options is None: client_options = client_options_lib.ClientOptions() - # Create SSL credentials for mutual TLS if needed. 
- if os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") not in ( - "true", - "false", - ): - raise ValueError( - "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" - ) - use_client_cert = ( - os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") == "true" + api_endpoint, client_cert_source_func = self.get_mtls_endpoint_and_cert_source( + client_options ) - client_cert_source_func = None - is_mtls = False - if use_client_cert: - if client_options.client_cert_source: - is_mtls = True - client_cert_source_func = client_options.client_cert_source - else: - is_mtls = mtls.has_default_client_cert_source() - if is_mtls: - client_cert_source_func = mtls.default_client_cert_source() - else: - client_cert_source_func = None - - # Figure out which api endpoint to use. - if client_options.api_endpoint is not None: - api_endpoint = client_options.api_endpoint - else: - use_mtls_env = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") - if use_mtls_env == "never": - api_endpoint = self.DEFAULT_ENDPOINT - elif use_mtls_env == "always": - api_endpoint = self.DEFAULT_MTLS_ENDPOINT - elif use_mtls_env == "auto": - if is_mtls: - api_endpoint = self.DEFAULT_MTLS_ENDPOINT - else: - api_endpoint = self.DEFAULT_ENDPOINT - else: - raise MutualTLSChannelError( - "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted " - "values: never, auto, always" - ) + api_key_value = getattr(client_options, "api_key", None) + if api_key_value and credentials: + raise ValueError( + "client_options.api_key and credentials are mutually exclusive" + ) # Save or instantiate the transport. # Ordinarily, we provide the transport, but allowing a custom transport # instance provides an extensibility point for unusual situations. if isinstance(transport, MetadataServiceTransport): # transport is a MetadataServiceTransport instance. 
- if credentials or client_options.credentials_file: + if credentials or client_options.credentials_file or api_key_value: raise ValueError( "When providing a transport instance, " "provide its credentials directly." @@ -450,6 +482,15 @@ def __init__( ) self._transport = transport else: + import google.auth._default # type: ignore + + if api_key_value and hasattr( + google.auth._default, "get_api_key_credentials" + ): + credentials = google.auth._default.get_api_key_credentials( + api_key_value + ) + Transport = type(self).get_transport_class(transport) self._transport = Transport( credentials=credentials, @@ -525,7 +566,7 @@ def create_metadata_store( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, metadata_store, metadata_store_id]) if request is not None and has_flattened_params: @@ -610,7 +651,7 @@ def get_metadata_store( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -685,7 +726,7 @@ def list_metadata_stores( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -777,7 +818,7 @@ def delete_metadata_store( """ # Create or coerce a protobuf request object. 
- # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -877,7 +918,7 @@ def create_artifact( Instance of a general artifact. """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, artifact, artifact_id]) if request is not None and has_flattened_params: @@ -951,7 +992,7 @@ def get_artifact( Instance of a general artifact. """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -1026,7 +1067,7 @@ def list_artifacts( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -1115,7 +1156,7 @@ def update_artifact( Instance of a general artifact. """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([artifact, update_mask]) if request is not None and has_flattened_params: @@ -1204,7 +1245,7 @@ def delete_artifact( """ # Create or coerce a protobuf request object. 
- # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -1288,7 +1329,7 @@ def purge_artifacts( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -1388,7 +1429,7 @@ def create_context( Instance of a general context. """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, context, context_id]) if request is not None and has_flattened_params: @@ -1462,7 +1503,7 @@ def get_context( Instance of a general context. """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -1537,7 +1578,7 @@ def list_contexts( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -1625,7 +1666,7 @@ def update_context( Instance of a general context. """ # Create or coerce a protobuf request object. 
- # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([context, update_mask]) if request is not None and has_flattened_params: @@ -1714,7 +1755,7 @@ def delete_context( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -1798,7 +1839,7 @@ def purge_contexts( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -1904,7 +1945,7 @@ def add_context_artifacts_and_executions( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([context, artifacts, executions]) if request is not None and has_flattened_params: @@ -1997,7 +2038,7 @@ def add_context_children( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([context, child_contexts]) if request is not None and has_flattened_params: @@ -2082,7 +2123,7 @@ def query_context_lineage_subgraph( """ # Create or coerce a protobuf request object. 
- # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([context]) if request is not None and has_flattened_params: @@ -2176,7 +2217,7 @@ def create_execution( Instance of a general execution. """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, execution, execution_id]) if request is not None and has_flattened_params: @@ -2250,7 +2291,7 @@ def get_execution( Instance of a general execution. """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -2325,7 +2366,7 @@ def list_executions( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -2414,7 +2455,7 @@ def update_execution( Instance of a general execution. """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([execution, update_mask]) if request is not None and has_flattened_params: @@ -2503,7 +2544,7 @@ def delete_execution( """ # Create or coerce a protobuf request object. 
- # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -2587,7 +2628,7 @@ def purge_executions( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -2677,7 +2718,7 @@ def add_execution_events( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([execution, events]) if request is not None and has_flattened_params: @@ -2759,7 +2800,7 @@ def query_execution_inputs_and_outputs( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([execution]) if request is not None and has_flattened_params: @@ -2859,7 +2900,7 @@ def create_metadata_schema( Instance of a general MetadataSchema. """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, metadata_schema, metadata_schema_id]) if request is not None and has_flattened_params: @@ -2933,7 +2974,7 @@ def get_metadata_schema( Instance of a general MetadataSchema. 
""" # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -3008,7 +3049,7 @@ def list_metadata_schemas( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -3097,7 +3138,7 @@ def query_artifact_lineage_subgraph( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([artifact]) if request is not None and has_flattened_params: diff --git a/google/cloud/aiplatform_v1beta1/services/metadata_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/metadata_service/transports/base.py index d2ceb126be9..64e5528bde0 100644 --- a/google/cloud/aiplatform_v1beta1/services/metadata_service/transports/base.py +++ b/google/cloud/aiplatform_v1beta1/services/metadata_service/transports/base.py @@ -113,7 +113,6 @@ def __init__( credentials, _ = google.auth.load_credentials_from_file( credentials_file, **scopes_kwargs, quota_project_id=quota_project_id ) - elif credentials is None: credentials, _ = google.auth.default( **scopes_kwargs, quota_project_id=quota_project_id diff --git a/google/cloud/aiplatform_v1beta1/services/metadata_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/metadata_service/transports/grpc.py index 2ec4b66c376..da626e07ed3 100644 --- 
a/google/cloud/aiplatform_v1beta1/services/metadata_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/metadata_service/transports/grpc.py @@ -171,8 +171,11 @@ def __init__( if not self._grpc_channel: self._grpc_channel = type(self).create_channel( self._host, + # use the credentials which are saved credentials=self._credentials, - credentials_file=credentials_file, + # Set ``credentials_file`` to ``None`` here as + # the credentials that we saved earlier should be used. + credentials_file=None, scopes=self._scopes, ssl_credentials=self._ssl_channel_credentials, quota_project_id=quota_project_id, @@ -245,7 +248,7 @@ def operations_client(self) -> operations_v1.OperationsClient: This property caches on the instance; repeated calls return the same client. """ - # Sanity check: Only create a new client if we do not already have one. + # Quick check: Only create a new client if we do not already have one. if self._operations_client is None: self._operations_client = operations_v1.OperationsClient(self.grpc_channel) diff --git a/google/cloud/aiplatform_v1beta1/services/metadata_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/metadata_service/transports/grpc_asyncio.py index d990fbbdad2..ffb4d9b1cd3 100644 --- a/google/cloud/aiplatform_v1beta1/services/metadata_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/metadata_service/transports/grpc_asyncio.py @@ -216,8 +216,11 @@ def __init__( if not self._grpc_channel: self._grpc_channel = type(self).create_channel( self._host, + # use the credentials which are saved credentials=self._credentials, - credentials_file=credentials_file, + # Set ``credentials_file`` to ``None`` here as + # the credentials that we saved earlier should be used. 
+ credentials_file=None, scopes=self._scopes, ssl_credentials=self._ssl_channel_credentials, quota_project_id=quota_project_id, @@ -247,7 +250,7 @@ def operations_client(self) -> operations_v1.OperationsAsyncClient: This property caches on the instance; repeated calls return the same client. """ - # Sanity check: Only create a new client if we do not already have one. + # Quick check: Only create a new client if we do not already have one. if self._operations_client is None: self._operations_client = operations_v1.OperationsAsyncClient( self.grpc_channel diff --git a/google/cloud/aiplatform_v1beta1/services/migration_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/migration_service/async_client.py index d16eea5d41d..6aaadc9020f 100644 --- a/google/cloud/aiplatform_v1beta1/services/migration_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/migration_service/async_client.py @@ -16,7 +16,7 @@ from collections import OrderedDict import functools import re -from typing import Dict, Sequence, Tuple, Type, Union +from typing import Dict, Optional, Sequence, Tuple, Type, Union import pkg_resources from google.api_core.client_options import ClientOptions @@ -125,6 +125,42 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): from_service_account_json = from_service_account_file + @classmethod + def get_mtls_endpoint_and_cert_source( + cls, client_options: Optional[ClientOptions] = None + ): + """Return the API endpoint and client cert source for mutual TLS. + + The client cert source is determined in the following order: + (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the + client cert source is None. + (2) if `client_options.client_cert_source` is provided, use the provided one; if the + default client cert source exists, use the default one; otherwise the client cert + source is None. 
+ + The API endpoint is determined in the following order: + (1) if `client_options.api_endpoint` is provided, use the provided one. + (2) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is "always", use the + default mTLS endpoint; if the environment variable is "never", use the default API + endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise + use the default API endpoint. + + More details can be found at https://google.aip.dev/auth/4114. + + Args: + client_options (google.api_core.client_options.ClientOptions): Custom options for the + client. Only the `api_endpoint` and `client_cert_source` properties may be used + in this method. + + Returns: + Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the + client cert source to use. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If any errors happen. + """ + return MigrationServiceClient.get_mtls_endpoint_and_cert_source(client_options) # type: ignore + @property def transport(self) -> MigrationServiceTransport: """Returns the transport used by the client instance. @@ -229,7 +265,7 @@ async def search_migratable_resources( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -326,7 +362,7 @@ async def batch_migrate_resources( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request.
has_flattened_params = any([parent, migrate_resource_requests]) if request is not None and has_flattened_params: diff --git a/google/cloud/aiplatform_v1beta1/services/migration_service/client.py b/google/cloud/aiplatform_v1beta1/services/migration_service/client.py index 5463d3db1c0..24b0de3a989 100644 --- a/google/cloud/aiplatform_v1beta1/services/migration_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/migration_service/client.py @@ -183,32 +183,32 @@ def parse_annotated_dataset_path(path: str) -> Dict[str, str]: return m.groupdict() if m else {} @staticmethod - def dataset_path(project: str, location: str, dataset: str,) -> str: + def dataset_path(project: str, dataset: str,) -> str: """Returns a fully-qualified dataset string.""" - return "projects/{project}/locations/{location}/datasets/{dataset}".format( - project=project, location=location, dataset=dataset, + return "projects/{project}/datasets/{dataset}".format( + project=project, dataset=dataset, ) @staticmethod def parse_dataset_path(path: str) -> Dict[str, str]: """Parses a dataset path into its component segments.""" - m = re.match( - r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)$", - path, - ) + m = re.match(r"^projects/(?P.+?)/datasets/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def dataset_path(project: str, dataset: str,) -> str: + def dataset_path(project: str, location: str, dataset: str,) -> str: """Returns a fully-qualified dataset string.""" - return "projects/{project}/datasets/{dataset}".format( - project=project, dataset=dataset, + return "projects/{project}/locations/{location}/datasets/{dataset}".format( + project=project, location=location, dataset=dataset, ) @staticmethod def parse_dataset_path(path: str) -> Dict[str, str]: """Parses a dataset path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/datasets/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)$", + path, + ) return 
m.groupdict() if m else {} @staticmethod @@ -334,6 +334,73 @@ def parse_common_location_path(path: str) -> Dict[str, str]: m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) return m.groupdict() if m else {} + @classmethod + def get_mtls_endpoint_and_cert_source( + cls, client_options: Optional[client_options_lib.ClientOptions] = None + ): + """Return the API endpoint and client cert source for mutual TLS. + + The client cert source is determined in the following order: + (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the + client cert source is None. + (2) if `client_options.client_cert_source` is provided, use the provided one; if the + default client cert source exists, use the default one; otherwise the client cert + source is None. + + The API endpoint is determined in the following order: + (1) if `client_options.api_endpoint` is provided, use the provided one. + (2) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is "always", use the + default mTLS endpoint; if the environment variable is "never", use the default API + endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise + use the default API endpoint. + + More details can be found at https://google.aip.dev/auth/4114. + + Args: + client_options (google.api_core.client_options.ClientOptions): Custom options for the + client. Only the `api_endpoint` and `client_cert_source` properties may be used + in this method. + + Returns: + Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the + client cert source to use. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If any errors happen.
+ """ + if client_options is None: + client_options = client_options_lib.ClientOptions() + use_client_cert = os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") + use_mtls_endpoint = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") + if use_client_cert not in ("true", "false"): + raise ValueError( + "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" + ) + if use_mtls_endpoint not in ("auto", "never", "always"): + raise MutualTLSChannelError( + "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" + ) + + # Figure out the client cert source to use. + client_cert_source = None + if use_client_cert == "true": + if client_options.client_cert_source: + client_cert_source = client_options.client_cert_source + elif mtls.has_default_client_cert_source(): + client_cert_source = mtls.default_client_cert_source() + + # Figure out which api endpoint to use. + if client_options.api_endpoint is not None: + api_endpoint = client_options.api_endpoint + elif use_mtls_endpoint == "always" or ( + use_mtls_endpoint == "auto" and client_cert_source + ): + api_endpoint = cls.DEFAULT_MTLS_ENDPOINT + else: + api_endpoint = cls.DEFAULT_ENDPOINT + + return api_endpoint, client_cert_source + def __init__( self, *, @@ -384,57 +451,22 @@ def __init__( if client_options is None: client_options = client_options_lib.ClientOptions() - # Create SSL credentials for mutual TLS if needed. 
- if os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") not in ( - "true", - "false", - ): - raise ValueError( - "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" - ) - use_client_cert = ( - os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") == "true" + api_endpoint, client_cert_source_func = self.get_mtls_endpoint_and_cert_source( + client_options ) - client_cert_source_func = None - is_mtls = False - if use_client_cert: - if client_options.client_cert_source: - is_mtls = True - client_cert_source_func = client_options.client_cert_source - else: - is_mtls = mtls.has_default_client_cert_source() - if is_mtls: - client_cert_source_func = mtls.default_client_cert_source() - else: - client_cert_source_func = None - - # Figure out which api endpoint to use. - if client_options.api_endpoint is not None: - api_endpoint = client_options.api_endpoint - else: - use_mtls_env = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") - if use_mtls_env == "never": - api_endpoint = self.DEFAULT_ENDPOINT - elif use_mtls_env == "always": - api_endpoint = self.DEFAULT_MTLS_ENDPOINT - elif use_mtls_env == "auto": - if is_mtls: - api_endpoint = self.DEFAULT_MTLS_ENDPOINT - else: - api_endpoint = self.DEFAULT_ENDPOINT - else: - raise MutualTLSChannelError( - "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted " - "values: never, auto, always" - ) + api_key_value = getattr(client_options, "api_key", None) + if api_key_value and credentials: + raise ValueError( + "client_options.api_key and credentials are mutually exclusive" + ) # Save or instantiate the transport. # Ordinarily, we provide the transport, but allowing a custom transport # instance provides an extensibility point for unusual situations. if isinstance(transport, MigrationServiceTransport): # transport is a MigrationServiceTransport instance. 
- if credentials or client_options.credentials_file: + if credentials or client_options.credentials_file or api_key_value: raise ValueError( "When providing a transport instance, " "provide its credentials directly." @@ -446,6 +478,15 @@ def __init__( ) self._transport = transport else: + import google.auth._default # type: ignore + + if api_key_value and hasattr( + google.auth._default, "get_api_key_credentials" + ): + credentials = google.auth._default.get_api_key_credentials( + api_key_value + ) + Transport = type(self).get_transport_class(transport) self._transport = Transport( credentials=credentials, @@ -502,7 +543,7 @@ def search_migratable_resources( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -601,7 +642,7 @@ def batch_migrate_resources( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. 
has_flattened_params = any([parent, migrate_resource_requests]) if request is not None and has_flattened_params: diff --git a/google/cloud/aiplatform_v1beta1/services/migration_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/migration_service/transports/base.py index 5557f97a465..b945690ed99 100644 --- a/google/cloud/aiplatform_v1beta1/services/migration_service/transports/base.py +++ b/google/cloud/aiplatform_v1beta1/services/migration_service/transports/base.py @@ -103,7 +103,6 @@ def __init__( credentials, _ = google.auth.load_credentials_from_file( credentials_file, **scopes_kwargs, quota_project_id=quota_project_id ) - elif credentials is None: credentials, _ = google.auth.default( **scopes_kwargs, quota_project_id=quota_project_id diff --git a/google/cloud/aiplatform_v1beta1/services/migration_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/migration_service/transports/grpc.py index 1c2935c97c5..265f9eb67c8 100644 --- a/google/cloud/aiplatform_v1beta1/services/migration_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/migration_service/transports/grpc.py @@ -162,8 +162,11 @@ def __init__( if not self._grpc_channel: self._grpc_channel = type(self).create_channel( self._host, + # use the credentials which are saved credentials=self._credentials, - credentials_file=credentials_file, + # Set ``credentials_file`` to ``None`` here as + # the credentials that we saved earlier should be used. + credentials_file=None, scopes=self._scopes, ssl_credentials=self._ssl_channel_credentials, quota_project_id=quota_project_id, @@ -236,7 +239,7 @@ def operations_client(self) -> operations_v1.OperationsClient: This property caches on the instance; repeated calls return the same client. """ - # Sanity check: Only create a new client if we do not already have one. + # Quick check: Only create a new client if we do not already have one. 
if self._operations_client is None: self._operations_client = operations_v1.OperationsClient(self.grpc_channel) diff --git a/google/cloud/aiplatform_v1beta1/services/migration_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/migration_service/transports/grpc_asyncio.py index a5f5e33a8bf..5add8ae01b5 100644 --- a/google/cloud/aiplatform_v1beta1/services/migration_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/migration_service/transports/grpc_asyncio.py @@ -207,8 +207,11 @@ def __init__( if not self._grpc_channel: self._grpc_channel = type(self).create_channel( self._host, + # use the credentials which are saved credentials=self._credentials, - credentials_file=credentials_file, + # Set ``credentials_file`` to ``None`` here as + # the credentials that we saved earlier should be used. + credentials_file=None, scopes=self._scopes, ssl_credentials=self._ssl_channel_credentials, quota_project_id=quota_project_id, @@ -238,7 +241,7 @@ def operations_client(self) -> operations_v1.OperationsAsyncClient: This property caches on the instance; repeated calls return the same client. """ - # Sanity check: Only create a new client if we do not already have one. + # Quick check: Only create a new client if we do not already have one. 
if self._operations_client is None: self._operations_client = operations_v1.OperationsAsyncClient( self.grpc_channel diff --git a/google/cloud/aiplatform_v1beta1/services/model_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/model_service/async_client.py index 028e8bb41ac..c7c086c876c 100644 --- a/google/cloud/aiplatform_v1beta1/services/model_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/model_service/async_client.py @@ -16,7 +16,7 @@ from collections import OrderedDict import functools import re -from typing import Dict, Sequence, Tuple, Type, Union +from typing import Dict, Optional, Sequence, Tuple, Type, Union import pkg_resources from google.api_core.client_options import ClientOptions @@ -132,6 +132,42 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): from_service_account_json = from_service_account_file + @classmethod + def get_mtls_endpoint_and_cert_source( + cls, client_options: Optional[ClientOptions] = None + ): + """Return the API endpoint and client cert source for mutual TLS. + + The client cert source is determined in the following order: + (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the + client cert source is None. + (2) if `client_options.client_cert_source` is provided, use the provided one; if the + default client cert source exists, use the default one; otherwise the client cert + source is None. + + The API endpoint is determined in the following order: + (1) if `client_options.api_endpoint` is provided, use the provided one. + (2) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is "always", use the + default mTLS endpoint; if the environment variable is "never", use the default API + endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise + use the default API endpoint. + + More details can be found at https://google.aip.dev/auth/4114.
+ + Args: + client_options (google.api_core.client_options.ClientOptions): Custom options for the + client. Only the `api_endpoint` and `client_cert_source` properties may be used + in this method. + + Returns: + Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the + client cert source to use. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If any errors happen. + """ + return ModelServiceClient.get_mtls_endpoint_and_cert_source(client_options) # type: ignore + @property def transport(self) -> ModelServiceTransport: """Returns the transport used by the client instance. @@ -239,7 +275,7 @@ async def upload_model( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, model]) if request is not None and has_flattened_params: @@ -318,7 +354,7 @@ async def get_model( A trained machine learning Model. """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -393,7 +429,7 @@ async def list_models( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -477,7 +513,7 @@ async def update_model( A trained machine learning Model. """ # Create or coerce a protobuf request object. 
- # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([model, update_mask]) if request is not None and has_flattened_params: @@ -574,7 +610,7 @@ async def delete_model( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -669,7 +705,7 @@ async def export_model( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name, output_config]) if request is not None and has_flattened_params: @@ -753,7 +789,7 @@ async def get_model_evaluation( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -828,7 +864,7 @@ async def list_model_evaluations( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -908,7 +944,7 @@ async def get_model_evaluation_slice( """ # Create or coerce a protobuf request object. 
- # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -983,7 +1019,7 @@ async def list_model_evaluation_slices( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: diff --git a/google/cloud/aiplatform_v1beta1/services/model_service/client.py b/google/cloud/aiplatform_v1beta1/services/model_service/client.py index 1ab512337a1..086243f7772 100644 --- a/google/cloud/aiplatform_v1beta1/services/model_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/model_service/client.py @@ -320,6 +320,73 @@ def parse_common_location_path(path: str) -> Dict[str, str]: m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) return m.groupdict() if m else {} + @classmethod + def get_mtls_endpoint_and_cert_source( + cls, client_options: Optional[client_options_lib.ClientOptions] = None + ): + """Return the API endpoint and client cert source for mutual TLS. + + The client cert source is determined in the following order: + (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the + client cert source is None. + (2) if `client_options.client_cert_source` is provided, use the provided one; if the + default client cert source exists, use the default one; otherwise the client cert + source is None. + + The API endpoint is determined in the following order: + (1) if `client_options.api_endpoint` is provided.
+ (2) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is "always", use the + default mTLS endpoint; if the environment variable is "never", use the default API + endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise + use the default API endpoint. + + More details can be found at https://google.aip.dev/auth/4114. + + Args: + client_options (google.api_core.client_options.ClientOptions): Custom options for the + client. Only the `api_endpoint` and `client_cert_source` properties may be used + in this method. + + Returns: + Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the + client cert source to use. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If any errors happen. + """ + if client_options is None: + client_options = client_options_lib.ClientOptions() + use_client_cert = os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") + use_mtls_endpoint = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") + if use_client_cert not in ("true", "false"): + raise ValueError( + "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" + ) + if use_mtls_endpoint not in ("auto", "never", "always"): + raise MutualTLSChannelError( + "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" + ) + + # Figure out the client cert source to use. + client_cert_source = None + if use_client_cert == "true": + if client_options.client_cert_source: + client_cert_source = client_options.client_cert_source + elif mtls.has_default_client_cert_source(): + client_cert_source = mtls.default_client_cert_source() + + # Figure out which api endpoint to use.
+ if client_options.api_endpoint is not None: + api_endpoint = client_options.api_endpoint + elif use_mtls_endpoint == "always" or ( + use_mtls_endpoint == "auto" and client_cert_source + ): + api_endpoint = cls.DEFAULT_MTLS_ENDPOINT + else: + api_endpoint = cls.DEFAULT_ENDPOINT + + return api_endpoint, client_cert_source + def __init__( self, *, @@ -370,57 +437,22 @@ def __init__( if client_options is None: client_options = client_options_lib.ClientOptions() - # Create SSL credentials for mutual TLS if needed. - if os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") not in ( - "true", - "false", - ): - raise ValueError( - "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" - ) - use_client_cert = ( - os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") == "true" + api_endpoint, client_cert_source_func = self.get_mtls_endpoint_and_cert_source( + client_options ) - client_cert_source_func = None - is_mtls = False - if use_client_cert: - if client_options.client_cert_source: - is_mtls = True - client_cert_source_func = client_options.client_cert_source - else: - is_mtls = mtls.has_default_client_cert_source() - if is_mtls: - client_cert_source_func = mtls.default_client_cert_source() - else: - client_cert_source_func = None - - # Figure out which api endpoint to use. - if client_options.api_endpoint is not None: - api_endpoint = client_options.api_endpoint - else: - use_mtls_env = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") - if use_mtls_env == "never": - api_endpoint = self.DEFAULT_ENDPOINT - elif use_mtls_env == "always": - api_endpoint = self.DEFAULT_MTLS_ENDPOINT - elif use_mtls_env == "auto": - if is_mtls: - api_endpoint = self.DEFAULT_MTLS_ENDPOINT - else: - api_endpoint = self.DEFAULT_ENDPOINT - else: - raise MutualTLSChannelError( - "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. 
Accepted " - "values: never, auto, always" - ) + api_key_value = getattr(client_options, "api_key", None) + if api_key_value and credentials: + raise ValueError( + "client_options.api_key and credentials are mutually exclusive" + ) # Save or instantiate the transport. # Ordinarily, we provide the transport, but allowing a custom transport # instance provides an extensibility point for unusual situations. if isinstance(transport, ModelServiceTransport): # transport is a ModelServiceTransport instance. - if credentials or client_options.credentials_file: + if credentials or client_options.credentials_file or api_key_value: raise ValueError( "When providing a transport instance, " "provide its credentials directly." @@ -432,6 +464,15 @@ def __init__( ) self._transport = transport else: + import google.auth._default # type: ignore + + if api_key_value and hasattr( + google.auth._default, "get_api_key_credentials" + ): + credentials = google.auth._default.get_api_key_credentials( + api_key_value + ) + Transport = type(self).get_transport_class(transport) self._transport = Transport( credentials=credentials, @@ -491,7 +532,7 @@ def upload_model( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, model]) if request is not None and has_flattened_params: @@ -570,7 +611,7 @@ def get_model( A trained machine learning Model. """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -645,7 +686,7 @@ def list_models( """ # Create or coerce a protobuf request object. 
- # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -729,7 +770,7 @@ def update_model( A trained machine learning Model. """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([model, update_mask]) if request is not None and has_flattened_params: @@ -826,7 +867,7 @@ def delete_model( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -921,7 +962,7 @@ def export_model( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name, output_config]) if request is not None and has_flattened_params: @@ -1005,7 +1046,7 @@ def get_model_evaluation( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -1080,7 +1121,7 @@ def list_model_evaluations( """ # Create or coerce a protobuf request object. 
- # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -1160,7 +1201,7 @@ def get_model_evaluation_slice( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -1237,7 +1278,7 @@ def list_model_evaluation_slices( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: diff --git a/google/cloud/aiplatform_v1beta1/services/model_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/model_service/transports/base.py index fa0da90bf07..0e9c8503f78 100644 --- a/google/cloud/aiplatform_v1beta1/services/model_service/transports/base.py +++ b/google/cloud/aiplatform_v1beta1/services/model_service/transports/base.py @@ -107,7 +107,6 @@ def __init__( credentials, _ = google.auth.load_credentials_from_file( credentials_file, **scopes_kwargs, quota_project_id=quota_project_id ) - elif credentials is None: credentials, _ = google.auth.default( **scopes_kwargs, quota_project_id=quota_project_id diff --git a/google/cloud/aiplatform_v1beta1/services/model_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/model_service/transports/grpc.py index 4e4d32fe03d..e318fa49c69 100644 --- a/google/cloud/aiplatform_v1beta1/services/model_service/transports/grpc.py +++ 
b/google/cloud/aiplatform_v1beta1/services/model_service/transports/grpc.py @@ -165,8 +165,11 @@ def __init__( if not self._grpc_channel: self._grpc_channel = type(self).create_channel( self._host, + # use the credentials which are saved credentials=self._credentials, - credentials_file=credentials_file, + # Set ``credentials_file`` to ``None`` here as + # the credentials that we saved earlier should be used. + credentials_file=None, scopes=self._scopes, ssl_credentials=self._ssl_channel_credentials, quota_project_id=quota_project_id, @@ -239,7 +242,7 @@ def operations_client(self) -> operations_v1.OperationsClient: This property caches on the instance; repeated calls return the same client. """ - # Sanity check: Only create a new client if we do not already have one. + # Quick check: Only create a new client if we do not already have one. if self._operations_client is None: self._operations_client = operations_v1.OperationsClient(self.grpc_channel) diff --git a/google/cloud/aiplatform_v1beta1/services/model_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/model_service/transports/grpc_asyncio.py index f3fa67b56d9..9cd568489b8 100644 --- a/google/cloud/aiplatform_v1beta1/services/model_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/model_service/transports/grpc_asyncio.py @@ -210,8 +210,11 @@ def __init__( if not self._grpc_channel: self._grpc_channel = type(self).create_channel( self._host, + # use the credentials which are saved credentials=self._credentials, - credentials_file=credentials_file, + # Set ``credentials_file`` to ``None`` here as + # the credentials that we saved earlier should be used. + credentials_file=None, scopes=self._scopes, ssl_credentials=self._ssl_channel_credentials, quota_project_id=quota_project_id, @@ -241,7 +244,7 @@ def operations_client(self) -> operations_v1.OperationsAsyncClient: This property caches on the instance; repeated calls return the same client. 
""" - # Sanity check: Only create a new client if we do not already have one. + # Quick check: Only create a new client if we do not already have one. if self._operations_client is None: self._operations_client = operations_v1.OperationsAsyncClient( self.grpc_channel diff --git a/google/cloud/aiplatform_v1beta1/services/pipeline_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/pipeline_service/async_client.py index c31048b9ebd..0f884c970b8 100644 --- a/google/cloud/aiplatform_v1beta1/services/pipeline_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/pipeline_service/async_client.py @@ -16,7 +16,7 @@ from collections import OrderedDict import functools import re -from typing import Dict, Sequence, Tuple, Type, Union +from typing import Dict, Optional, Sequence, Tuple, Type, Union import pkg_resources from google.api_core.client_options import ClientOptions @@ -146,6 +146,42 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): from_service_account_json = from_service_account_file + @classmethod + def get_mtls_endpoint_and_cert_source( + cls, client_options: Optional[ClientOptions] = None + ): + """Return the API endpoint and client cert source for mutual TLS. + + The client cert source is determined in the following order: + (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the + client cert source is None. + (2) if `client_options.client_cert_source` is provided, use the provided one; if the + default client cert source exists, use the default one; otherwise the client cert + source is None. + + The API endpoint is determined in the following order: + (1) if `client_options.api_endpoint` if provided, use the provided one. 
+ (2) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is "always", use the + default mTLS endpoint; if the environment variabel is "never", use the default API + endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise + use the default API endpoint. + + More details can be found at https://google.aip.dev/auth/4114. + + Args: + client_options (google.api_core.client_options.ClientOptions): Custom options for the + client. Only the `api_endpoint` and `client_cert_source` properties may be used + in this method. + + Returns: + Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the + client cert source to use. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If any errors happen. + """ + return PipelineServiceClient.get_mtls_endpoint_and_cert_source(client_options) # type: ignore + @property def transport(self) -> PipelineServiceTransport: """Returns the transport used by the client instance. @@ -255,7 +291,7 @@ async def create_training_pipeline( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, training_pipeline]) if request is not None and has_flattened_params: @@ -333,7 +369,7 @@ async def get_training_pipeline( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -408,7 +444,7 @@ async def list_training_pipelines( """ # Create or coerce a protobuf request object. 
- # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -499,7 +535,7 @@ async def delete_training_pipeline( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -585,7 +621,7 @@ async def cancel_training_pipeline( sent along with the request as metadata. """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -675,7 +711,7 @@ async def create_pipeline_job( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, pipeline_job, pipeline_job_id]) if request is not None and has_flattened_params: @@ -750,7 +786,7 @@ async def get_pipeline_job( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -825,7 +861,7 @@ async def list_pipeline_jobs( """ # Create or coerce a protobuf request object. 
- # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -916,7 +952,7 @@ async def delete_pipeline_job( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -1001,7 +1037,7 @@ async def cancel_pipeline_job( sent along with the request as metadata. """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: diff --git a/google/cloud/aiplatform_v1beta1/services/pipeline_service/client.py b/google/cloud/aiplatform_v1beta1/services/pipeline_service/client.py index 35c2ffab767..4a87274ada9 100644 --- a/google/cloud/aiplatform_v1beta1/services/pipeline_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/pipeline_service/client.py @@ -398,6 +398,73 @@ def parse_common_location_path(path: str) -> Dict[str, str]: m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) return m.groupdict() if m else {} + @classmethod + def get_mtls_endpoint_and_cert_source( + cls, client_options: Optional[client_options_lib.ClientOptions] = None + ): + """Return the API endpoint and client cert source for mutual TLS. + + The client cert source is determined in the following order: + (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the + client cert source is None. 
+ (2) if `client_options.client_cert_source` is provided, use the provided one; if the + default client cert source exists, use the default one; otherwise the client cert + source is None. + + The API endpoint is determined in the following order: + (1) if `client_options.api_endpoint` if provided, use the provided one. + (2) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is "always", use the + default mTLS endpoint; if the environment variabel is "never", use the default API + endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise + use the default API endpoint. + + More details can be found at https://google.aip.dev/auth/4114. + + Args: + client_options (google.api_core.client_options.ClientOptions): Custom options for the + client. Only the `api_endpoint` and `client_cert_source` properties may be used + in this method. + + Returns: + Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the + client cert source to use. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If any errors happen. + """ + if client_options is None: + client_options = client_options_lib.ClientOptions() + use_client_cert = os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") + use_mtls_endpoint = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") + if use_client_cert not in ("true", "false"): + raise ValueError( + "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" + ) + if use_mtls_endpoint not in ("auto", "never", "always"): + raise MutualTLSChannelError( + "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" + ) + + # Figure out the client cert source to use. 
+ client_cert_source = None + if use_client_cert == "true": + if client_options.client_cert_source: + client_cert_source = client_options.client_cert_source + elif mtls.has_default_client_cert_source(): + client_cert_source = mtls.default_client_cert_source() + + # Figure out which api endpoint to use. + if client_options.api_endpoint is not None: + api_endpoint = client_options.api_endpoint + elif use_mtls_endpoint == "always" or ( + use_mtls_endpoint == "auto" and client_cert_source + ): + api_endpoint = cls.DEFAULT_MTLS_ENDPOINT + else: + api_endpoint = cls.DEFAULT_ENDPOINT + + return api_endpoint, client_cert_source + def __init__( self, *, @@ -448,57 +515,22 @@ def __init__( if client_options is None: client_options = client_options_lib.ClientOptions() - # Create SSL credentials for mutual TLS if needed. - if os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") not in ( - "true", - "false", - ): - raise ValueError( - "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" - ) - use_client_cert = ( - os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") == "true" + api_endpoint, client_cert_source_func = self.get_mtls_endpoint_and_cert_source( + client_options ) - client_cert_source_func = None - is_mtls = False - if use_client_cert: - if client_options.client_cert_source: - is_mtls = True - client_cert_source_func = client_options.client_cert_source - else: - is_mtls = mtls.has_default_client_cert_source() - if is_mtls: - client_cert_source_func = mtls.default_client_cert_source() - else: - client_cert_source_func = None - - # Figure out which api endpoint to use. 
- if client_options.api_endpoint is not None: - api_endpoint = client_options.api_endpoint - else: - use_mtls_env = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") - if use_mtls_env == "never": - api_endpoint = self.DEFAULT_ENDPOINT - elif use_mtls_env == "always": - api_endpoint = self.DEFAULT_MTLS_ENDPOINT - elif use_mtls_env == "auto": - if is_mtls: - api_endpoint = self.DEFAULT_MTLS_ENDPOINT - else: - api_endpoint = self.DEFAULT_ENDPOINT - else: - raise MutualTLSChannelError( - "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted " - "values: never, auto, always" - ) + api_key_value = getattr(client_options, "api_key", None) + if api_key_value and credentials: + raise ValueError( + "client_options.api_key and credentials are mutually exclusive" + ) # Save or instantiate the transport. # Ordinarily, we provide the transport, but allowing a custom transport # instance provides an extensibility point for unusual situations. if isinstance(transport, PipelineServiceTransport): # transport is a PipelineServiceTransport instance. - if credentials or client_options.credentials_file: + if credentials or client_options.credentials_file or api_key_value: raise ValueError( "When providing a transport instance, " "provide its credentials directly." @@ -510,6 +542,15 @@ def __init__( ) self._transport = transport else: + import google.auth._default # type: ignore + + if api_key_value and hasattr( + google.auth._default, "get_api_key_credentials" + ): + credentials = google.auth._default.get_api_key_credentials( + api_key_value + ) + Transport = type(self).get_transport_class(transport) self._transport = Transport( credentials=credentials, @@ -571,7 +612,7 @@ def create_training_pipeline( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. 
has_flattened_params = any([parent, training_pipeline]) if request is not None and has_flattened_params: @@ -649,7 +690,7 @@ def get_training_pipeline( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -724,7 +765,7 @@ def list_training_pipelines( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -815,7 +856,7 @@ def delete_training_pipeline( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -901,7 +942,7 @@ def cancel_training_pipeline( sent along with the request as metadata. """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -991,7 +1032,7 @@ def create_pipeline_job( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. 
has_flattened_params = any([parent, pipeline_job, pipeline_job_id]) if request is not None and has_flattened_params: @@ -1066,7 +1107,7 @@ def get_pipeline_job( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -1141,7 +1182,7 @@ def list_pipeline_jobs( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -1232,7 +1273,7 @@ def delete_pipeline_job( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -1317,7 +1358,7 @@ def cancel_pipeline_job( sent along with the request as metadata. """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. 
has_flattened_params = any([name]) if request is not None and has_flattened_params: diff --git a/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/base.py index 3c1512e2cd6..cb061f614f8 100644 --- a/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/base.py +++ b/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/base.py @@ -110,7 +110,6 @@ def __init__( credentials, _ = google.auth.load_credentials_from_file( credentials_file, **scopes_kwargs, quota_project_id=quota_project_id ) - elif credentials is None: credentials, _ = google.auth.default( **scopes_kwargs, quota_project_id=quota_project_id diff --git a/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/grpc.py index 372e193e47e..df0b9c8152a 100644 --- a/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/grpc.py @@ -171,8 +171,11 @@ def __init__( if not self._grpc_channel: self._grpc_channel = type(self).create_channel( self._host, + # use the credentials which are saved credentials=self._credentials, - credentials_file=credentials_file, + # Set ``credentials_file`` to ``None`` here as + # the credentials that we saved earlier should be used. + credentials_file=None, scopes=self._scopes, ssl_credentials=self._ssl_channel_credentials, quota_project_id=quota_project_id, @@ -245,7 +248,7 @@ def operations_client(self) -> operations_v1.OperationsClient: This property caches on the instance; repeated calls return the same client. """ - # Sanity check: Only create a new client if we do not already have one. + # Quick check: Only create a new client if we do not already have one. 
if self._operations_client is None: self._operations_client = operations_v1.OperationsClient(self.grpc_channel) diff --git a/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/grpc_asyncio.py index d5c7e82b301..5bf038e01e8 100644 --- a/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/grpc_asyncio.py @@ -216,8 +216,11 @@ def __init__( if not self._grpc_channel: self._grpc_channel = type(self).create_channel( self._host, + # use the credentials which are saved credentials=self._credentials, - credentials_file=credentials_file, + # Set ``credentials_file`` to ``None`` here as + # the credentials that we saved earlier should be used. + credentials_file=None, scopes=self._scopes, ssl_credentials=self._ssl_channel_credentials, quota_project_id=quota_project_id, @@ -247,7 +250,7 @@ def operations_client(self) -> operations_v1.OperationsAsyncClient: This property caches on the instance; repeated calls return the same client. """ - # Sanity check: Only create a new client if we do not already have one. + # Quick check: Only create a new client if we do not already have one. 
if self._operations_client is None: self._operations_client = operations_v1.OperationsAsyncClient( self.grpc_channel diff --git a/google/cloud/aiplatform_v1beta1/services/prediction_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/prediction_service/async_client.py index fd0677b9224..c80296e6584 100644 --- a/google/cloud/aiplatform_v1beta1/services/prediction_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/prediction_service/async_client.py @@ -16,7 +16,7 @@ from collections import OrderedDict import functools import re -from typing import Dict, Sequence, Tuple, Type, Union +from typing import Dict, Optional, Sequence, Tuple, Type, Union import pkg_resources from google.api_core.client_options import ClientOptions @@ -111,6 +111,42 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): from_service_account_json = from_service_account_file + @classmethod + def get_mtls_endpoint_and_cert_source( + cls, client_options: Optional[ClientOptions] = None + ): + """Return the API endpoint and client cert source for mutual TLS. + + The client cert source is determined in the following order: + (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the + client cert source is None. + (2) if `client_options.client_cert_source` is provided, use the provided one; if the + default client cert source exists, use the default one; otherwise the client cert + source is None. + + The API endpoint is determined in the following order: + (1) if `client_options.api_endpoint` if provided, use the provided one. + (2) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is "always", use the + default mTLS endpoint; if the environment variabel is "never", use the default API + endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise + use the default API endpoint. + + More details can be found at https://google.aip.dev/auth/4114. 
+ + Args: + client_options (google.api_core.client_options.ClientOptions): Custom options for the + client. Only the `api_endpoint` and `client_cert_source` properties may be used + in this method. + + Returns: + Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the + client cert source to use. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If any errors happen. + """ + return PredictionServiceClient.get_mtls_endpoint_and_cert_source(client_options) # type: ignore + @property def transport(self) -> PredictionServiceTransport: """Returns the transport used by the client instance. @@ -236,7 +272,7 @@ async def predict( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([endpoint, instances, parameters]) if request is not None and has_flattened_params: @@ -393,7 +429,7 @@ async def raw_predict( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([endpoint, http_body]) if request is not None and has_flattened_params: @@ -516,7 +552,7 @@ async def explain( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. 
has_flattened_params = any([endpoint, instances, parameters, deployed_model_id]) if request is not None and has_flattened_params: diff --git a/google/cloud/aiplatform_v1beta1/services/prediction_service/client.py b/google/cloud/aiplatform_v1beta1/services/prediction_service/client.py index 020817c9866..4fa9fabd7a0 100644 --- a/google/cloud/aiplatform_v1beta1/services/prediction_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/prediction_service/client.py @@ -255,6 +255,73 @@ def parse_common_location_path(path: str) -> Dict[str, str]: m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) return m.groupdict() if m else {} + @classmethod + def get_mtls_endpoint_and_cert_source( + cls, client_options: Optional[client_options_lib.ClientOptions] = None + ): + """Return the API endpoint and client cert source for mutual TLS. + + The client cert source is determined in the following order: + (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the + client cert source is None. + (2) if `client_options.client_cert_source` is provided, use the provided one; if the + default client cert source exists, use the default one; otherwise the client cert + source is None. + + The API endpoint is determined in the following order: + (1) if `client_options.api_endpoint` if provided, use the provided one. + (2) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is "always", use the + default mTLS endpoint; if the environment variabel is "never", use the default API + endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise + use the default API endpoint. + + More details can be found at https://google.aip.dev/auth/4114. + + Args: + client_options (google.api_core.client_options.ClientOptions): Custom options for the + client. Only the `api_endpoint` and `client_cert_source` properties may be used + in this method. 
+ + Returns: + Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the + client cert source to use. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If any errors happen. + """ + if client_options is None: + client_options = client_options_lib.ClientOptions() + use_client_cert = os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") + use_mtls_endpoint = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") + if use_client_cert not in ("true", "false"): + raise ValueError( + "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" + ) + if use_mtls_endpoint not in ("auto", "never", "always"): + raise MutualTLSChannelError( + "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" + ) + + # Figure out the client cert source to use. + client_cert_source = None + if use_client_cert == "true": + if client_options.client_cert_source: + client_cert_source = client_options.client_cert_source + elif mtls.has_default_client_cert_source(): + client_cert_source = mtls.default_client_cert_source() + + # Figure out which api endpoint to use. + if client_options.api_endpoint is not None: + api_endpoint = client_options.api_endpoint + elif use_mtls_endpoint == "always" or ( + use_mtls_endpoint == "auto" and client_cert_source + ): + api_endpoint = cls.DEFAULT_MTLS_ENDPOINT + else: + api_endpoint = cls.DEFAULT_ENDPOINT + + return api_endpoint, client_cert_source + def __init__( self, *, @@ -305,57 +372,22 @@ def __init__( if client_options is None: client_options = client_options_lib.ClientOptions() - # Create SSL credentials for mutual TLS if needed. 
- if os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") not in ( - "true", - "false", - ): - raise ValueError( - "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" - ) - use_client_cert = ( - os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") == "true" + api_endpoint, client_cert_source_func = self.get_mtls_endpoint_and_cert_source( + client_options ) - client_cert_source_func = None - is_mtls = False - if use_client_cert: - if client_options.client_cert_source: - is_mtls = True - client_cert_source_func = client_options.client_cert_source - else: - is_mtls = mtls.has_default_client_cert_source() - if is_mtls: - client_cert_source_func = mtls.default_client_cert_source() - else: - client_cert_source_func = None - - # Figure out which api endpoint to use. - if client_options.api_endpoint is not None: - api_endpoint = client_options.api_endpoint - else: - use_mtls_env = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") - if use_mtls_env == "never": - api_endpoint = self.DEFAULT_ENDPOINT - elif use_mtls_env == "always": - api_endpoint = self.DEFAULT_MTLS_ENDPOINT - elif use_mtls_env == "auto": - if is_mtls: - api_endpoint = self.DEFAULT_MTLS_ENDPOINT - else: - api_endpoint = self.DEFAULT_ENDPOINT - else: - raise MutualTLSChannelError( - "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted " - "values: never, auto, always" - ) + api_key_value = getattr(client_options, "api_key", None) + if api_key_value and credentials: + raise ValueError( + "client_options.api_key and credentials are mutually exclusive" + ) # Save or instantiate the transport. # Ordinarily, we provide the transport, but allowing a custom transport # instance provides an extensibility point for unusual situations. if isinstance(transport, PredictionServiceTransport): # transport is a PredictionServiceTransport instance. 
- if credentials or client_options.credentials_file: + if credentials or client_options.credentials_file or api_key_value: raise ValueError( "When providing a transport instance, " "provide its credentials directly." @@ -367,6 +399,15 @@ def __init__( ) self._transport = transport else: + import google.auth._default # type: ignore + + if api_key_value and hasattr( + google.auth._default, "get_api_key_credentials" + ): + credentials = google.auth._default.get_api_key_credentials( + api_key_value + ) + Transport = type(self).get_transport_class(transport) self._transport = Transport( credentials=credentials, @@ -444,7 +485,7 @@ def predict( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([endpoint, instances, parameters]) if request is not None and has_flattened_params: @@ -601,7 +642,7 @@ def raw_predict( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([endpoint, http_body]) if request is not None and has_flattened_params: @@ -724,7 +765,7 @@ def explain( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. 
has_flattened_params = any([endpoint, instances, parameters, deployed_model_id]) if request is not None and has_flattened_params: diff --git a/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/base.py index 8ea55559e6f..c58c155c89e 100644 --- a/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/base.py +++ b/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/base.py @@ -102,7 +102,6 @@ def __init__( credentials, _ = google.auth.load_credentials_from_file( credentials_file, **scopes_kwargs, quota_project_id=quota_project_id ) - elif credentials is None: credentials, _ = google.auth.default( **scopes_kwargs, quota_project_id=quota_project_id diff --git a/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/grpc.py index e911abba1eb..4f64783a926 100644 --- a/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/grpc.py @@ -159,8 +159,11 @@ def __init__( if not self._grpc_channel: self._grpc_channel = type(self).create_channel( self._host, + # use the credentials which are saved credentials=self._credentials, - credentials_file=credentials_file, + # Set ``credentials_file`` to ``None`` here as + # the credentials that we saved earlier should be used. 
+ credentials_file=None, scopes=self._scopes, ssl_credentials=self._ssl_channel_credentials, quota_project_id=quota_project_id, diff --git a/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/grpc_asyncio.py index 4c288670dc2..038b2617344 100644 --- a/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/grpc_asyncio.py @@ -204,8 +204,11 @@ def __init__( if not self._grpc_channel: self._grpc_channel = type(self).create_channel( self._host, + # use the credentials which are saved credentials=self._credentials, - credentials_file=credentials_file, + # Set ``credentials_file`` to ``None`` here as + # the credentials that we saved earlier should be used. + credentials_file=None, scopes=self._scopes, ssl_credentials=self._ssl_channel_credentials, quota_project_id=quota_project_id, diff --git a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/async_client.py index 9d71b5eca56..cee1e510613 100644 --- a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/async_client.py @@ -16,7 +16,7 @@ from collections import OrderedDict import functools import re -from typing import Dict, Sequence, Tuple, Type, Union +from typing import Dict, Optional, Sequence, Tuple, Type, Union import pkg_resources from google.api_core.client_options import ClientOptions @@ -125,6 +125,42 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): from_service_account_json = from_service_account_file + @classmethod + def get_mtls_endpoint_and_cert_source( + cls, client_options: Optional[ClientOptions] = None + ): + """Return the API endpoint and client cert source for mutual TLS. 
+ + The client cert source is determined in the following order: + (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the + client cert source is None. + (2) if `client_options.client_cert_source` is provided, use the provided one; if the + default client cert source exists, use the default one; otherwise the client cert + source is None. + + The API endpoint is determined in the following order: + (1) if `client_options.api_endpoint` if provided, use the provided one. + (2) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is "always", use the + default mTLS endpoint; if the environment variabel is "never", use the default API + endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise + use the default API endpoint. + + More details can be found at https://google.aip.dev/auth/4114. + + Args: + client_options (google.api_core.client_options.ClientOptions): Custom options for the + client. Only the `api_endpoint` and `client_cert_source` properties may be used + in this method. + + Returns: + Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the + client cert source to use. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If any errors happen. + """ + return SpecialistPoolServiceClient.get_mtls_endpoint_and_cert_source(client_options) # type: ignore + @property def transport(self) -> SpecialistPoolServiceTransport: """Returns the transport used by the client instance. @@ -240,7 +276,7 @@ async def create_specialist_pool( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, specialist_pool]) if request is not None and has_flattened_params: @@ -332,7 +368,7 @@ async def get_specialist_pool( """ # Create or coerce a protobuf request object. 
- # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -407,7 +443,7 @@ async def list_specialist_pools( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -501,7 +537,7 @@ async def delete_specialist_pool( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -598,7 +634,7 @@ async def update_specialist_pool( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. 
has_flattened_params = any([specialist_pool, update_mask]) if request is not None and has_flattened_params: diff --git a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/client.py b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/client.py index 965b05b82bf..eb68ebde3b2 100644 --- a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/client.py @@ -249,6 +249,73 @@ def parse_common_location_path(path: str) -> Dict[str, str]: m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) return m.groupdict() if m else {} + @classmethod + def get_mtls_endpoint_and_cert_source( + cls, client_options: Optional[client_options_lib.ClientOptions] = None + ): + """Return the API endpoint and client cert source for mutual TLS. + + The client cert source is determined in the following order: + (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the + client cert source is None. + (2) if `client_options.client_cert_source` is provided, use the provided one; if the + default client cert source exists, use the default one; otherwise the client cert + source is None. + + The API endpoint is determined in the following order: + (1) if `client_options.api_endpoint` if provided, use the provided one. + (2) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is "always", use the + default mTLS endpoint; if the environment variabel is "never", use the default API + endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise + use the default API endpoint. + + More details can be found at https://google.aip.dev/auth/4114. + + Args: + client_options (google.api_core.client_options.ClientOptions): Custom options for the + client. Only the `api_endpoint` and `client_cert_source` properties may be used + in this method. 
+ + Returns: + Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the + client cert source to use. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If any errors happen. + """ + if client_options is None: + client_options = client_options_lib.ClientOptions() + use_client_cert = os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") + use_mtls_endpoint = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") + if use_client_cert not in ("true", "false"): + raise ValueError( + "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" + ) + if use_mtls_endpoint not in ("auto", "never", "always"): + raise MutualTLSChannelError( + "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" + ) + + # Figure out the client cert source to use. + client_cert_source = None + if use_client_cert == "true": + if client_options.client_cert_source: + client_cert_source = client_options.client_cert_source + elif mtls.has_default_client_cert_source(): + client_cert_source = mtls.default_client_cert_source() + + # Figure out which api endpoint to use. + if client_options.api_endpoint is not None: + api_endpoint = client_options.api_endpoint + elif use_mtls_endpoint == "always" or ( + use_mtls_endpoint == "auto" and client_cert_source + ): + api_endpoint = cls.DEFAULT_MTLS_ENDPOINT + else: + api_endpoint = cls.DEFAULT_ENDPOINT + + return api_endpoint, client_cert_source + def __init__( self, *, @@ -299,57 +366,22 @@ def __init__( if client_options is None: client_options = client_options_lib.ClientOptions() - # Create SSL credentials for mutual TLS if needed. 
- if os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") not in ( - "true", - "false", - ): - raise ValueError( - "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" - ) - use_client_cert = ( - os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") == "true" + api_endpoint, client_cert_source_func = self.get_mtls_endpoint_and_cert_source( + client_options ) - client_cert_source_func = None - is_mtls = False - if use_client_cert: - if client_options.client_cert_source: - is_mtls = True - client_cert_source_func = client_options.client_cert_source - else: - is_mtls = mtls.has_default_client_cert_source() - if is_mtls: - client_cert_source_func = mtls.default_client_cert_source() - else: - client_cert_source_func = None - - # Figure out which api endpoint to use. - if client_options.api_endpoint is not None: - api_endpoint = client_options.api_endpoint - else: - use_mtls_env = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") - if use_mtls_env == "never": - api_endpoint = self.DEFAULT_ENDPOINT - elif use_mtls_env == "always": - api_endpoint = self.DEFAULT_MTLS_ENDPOINT - elif use_mtls_env == "auto": - if is_mtls: - api_endpoint = self.DEFAULT_MTLS_ENDPOINT - else: - api_endpoint = self.DEFAULT_ENDPOINT - else: - raise MutualTLSChannelError( - "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted " - "values: never, auto, always" - ) + api_key_value = getattr(client_options, "api_key", None) + if api_key_value and credentials: + raise ValueError( + "client_options.api_key and credentials are mutually exclusive" + ) # Save or instantiate the transport. # Ordinarily, we provide the transport, but allowing a custom transport # instance provides an extensibility point for unusual situations. if isinstance(transport, SpecialistPoolServiceTransport): # transport is a SpecialistPoolServiceTransport instance. 
- if credentials or client_options.credentials_file: + if credentials or client_options.credentials_file or api_key_value: raise ValueError( "When providing a transport instance, " "provide its credentials directly." @@ -361,6 +393,15 @@ def __init__( ) self._transport = transport else: + import google.auth._default # type: ignore + + if api_key_value and hasattr( + google.auth._default, "get_api_key_credentials" + ): + credentials = google.auth._default.get_api_key_credentials( + api_key_value + ) + Transport = type(self).get_transport_class(transport) self._transport = Transport( credentials=credentials, @@ -427,7 +468,7 @@ def create_specialist_pool( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, specialist_pool]) if request is not None and has_flattened_params: @@ -519,7 +560,7 @@ def get_specialist_pool( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -594,7 +635,7 @@ def list_specialist_pools( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -688,7 +729,7 @@ def delete_specialist_pool( """ # Create or coerce a protobuf request object. 
- # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -785,7 +826,7 @@ def update_specialist_pool( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([specialist_pool, update_mask]) if request is not None and has_flattened_params: diff --git a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/base.py index 521dade5615..3cf9d59e6a2 100644 --- a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/base.py +++ b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/base.py @@ -104,7 +104,6 @@ def __init__( credentials, _ = google.auth.load_credentials_from_file( credentials_file, **scopes_kwargs, quota_project_id=quota_project_id ) - elif credentials is None: credentials, _ = google.auth.default( **scopes_kwargs, quota_project_id=quota_project_id diff --git a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/grpc.py index 6676da58657..53148f48d75 100644 --- a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/grpc.py @@ -167,8 +167,11 @@ def __init__( if not self._grpc_channel: self._grpc_channel = type(self).create_channel( self._host, + # use the credentials which are saved credentials=self._credentials, - credentials_file=credentials_file, + # Set 
``credentials_file`` to ``None`` here as + # the credentials that we saved earlier should be used. + credentials_file=None, scopes=self._scopes, ssl_credentials=self._ssl_channel_credentials, quota_project_id=quota_project_id, @@ -241,7 +244,7 @@ def operations_client(self) -> operations_v1.OperationsClient: This property caches on the instance; repeated calls return the same client. """ - # Sanity check: Only create a new client if we do not already have one. + # Quick check: Only create a new client if we do not already have one. if self._operations_client is None: self._operations_client = operations_v1.OperationsClient(self.grpc_channel) diff --git a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/grpc_asyncio.py index 83469e9cbc8..72975cc593b 100644 --- a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/grpc_asyncio.py @@ -212,8 +212,11 @@ def __init__( if not self._grpc_channel: self._grpc_channel = type(self).create_channel( self._host, + # use the credentials which are saved credentials=self._credentials, - credentials_file=credentials_file, + # Set ``credentials_file`` to ``None`` here as + # the credentials that we saved earlier should be used. + credentials_file=None, scopes=self._scopes, ssl_credentials=self._ssl_channel_credentials, quota_project_id=quota_project_id, @@ -243,7 +246,7 @@ def operations_client(self) -> operations_v1.OperationsAsyncClient: This property caches on the instance; repeated calls return the same client. """ - # Sanity check: Only create a new client if we do not already have one. + # Quick check: Only create a new client if we do not already have one. 
if self._operations_client is None: self._operations_client = operations_v1.OperationsAsyncClient( self.grpc_channel diff --git a/google/cloud/aiplatform_v1beta1/services/tensorboard_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/tensorboard_service/async_client.py index 605fca1a77a..8e31a5017b5 100644 --- a/google/cloud/aiplatform_v1beta1/services/tensorboard_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/tensorboard_service/async_client.py @@ -16,7 +16,16 @@ from collections import OrderedDict import functools import re -from typing import Dict, AsyncIterable, Awaitable, Sequence, Tuple, Type, Union +from typing import ( + Dict, + Optional, + AsyncIterable, + Awaitable, + Sequence, + Tuple, + Type, + Union, +) import pkg_resources from google.api_core.client_options import ClientOptions @@ -144,6 +153,42 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): from_service_account_json = from_service_account_file + @classmethod + def get_mtls_endpoint_and_cert_source( + cls, client_options: Optional[ClientOptions] = None + ): + """Return the API endpoint and client cert source for mutual TLS. + + The client cert source is determined in the following order: + (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the + client cert source is None. + (2) if `client_options.client_cert_source` is provided, use the provided one; if the + default client cert source exists, use the default one; otherwise the client cert + source is None. + + The API endpoint is determined in the following order: + (1) if `client_options.api_endpoint` if provided, use the provided one. + (2) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is "always", use the + default mTLS endpoint; if the environment variabel is "never", use the default API + endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise + use the default API endpoint. 
+ + More details can be found at https://google.aip.dev/auth/4114. + + Args: + client_options (google.api_core.client_options.ClientOptions): Custom options for the + client. Only the `api_endpoint` and `client_cert_source` properties may be used + in this method. + + Returns: + Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the + client cert source to use. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If any errors happen. + """ + return TensorboardServiceClient.get_mtls_endpoint_and_cert_source(client_options) # type: ignore + @property def transport(self) -> TensorboardServiceTransport: """Returns the transport used by the client instance. @@ -251,7 +296,7 @@ async def create_tensorboard( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, tensorboard]) if request is not None and has_flattened_params: @@ -336,7 +381,7 @@ async def get_tensorboard( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -425,7 +470,7 @@ async def update_tensorboard( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([tensorboard, update_mask]) if request is not None and has_flattened_params: @@ -512,7 +557,7 @@ async def list_tensorboards( """ # Create or coerce a protobuf request object. 
- # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -603,7 +648,7 @@ async def delete_tensorboard( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -705,7 +750,7 @@ async def create_tensorboard_experiment( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any( [parent, tensorboard_experiment, tensorboard_experiment_id] @@ -787,7 +832,7 @@ async def get_tensorboard_experiment( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -877,7 +922,7 @@ async def update_tensorboard_experiment( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([tensorboard_experiment, update_mask]) if request is not None and has_flattened_params: @@ -959,7 +1004,7 @@ async def list_tensorboard_experiments( """ # Create or coerce a protobuf request object. 
- # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -1052,7 +1097,7 @@ async def delete_tensorboard_experiment( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -1154,7 +1199,7 @@ async def create_tensorboard_run( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, tensorboard_run, tensorboard_run_id]) if request is not None and has_flattened_params: @@ -1245,7 +1290,7 @@ async def batch_create_tensorboard_runs( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, requests]) if request is not None and has_flattened_params: @@ -1321,7 +1366,7 @@ async def get_tensorboard_run( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -1408,7 +1453,7 @@ async def update_tensorboard_run( """ # Create or coerce a protobuf request object. 
- # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([tensorboard_run, update_mask]) if request is not None and has_flattened_params: @@ -1488,7 +1533,7 @@ async def list_tensorboard_runs( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -1579,7 +1624,7 @@ async def delete_tensorboard_run( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -1678,7 +1723,7 @@ async def batch_create_tensorboard_time_series( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, requests]) if request is not None and has_flattened_params: @@ -1762,7 +1807,7 @@ async def create_tensorboard_time_series( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, tensorboard_time_series]) if request is not None and has_flattened_params: @@ -1838,7 +1883,7 @@ async def get_tensorboard_time_series( """ # Create or coerce a protobuf request object. 
- # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -1926,7 +1971,7 @@ async def update_tensorboard_time_series( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([tensorboard_time_series, update_mask]) if request is not None and has_flattened_params: @@ -2013,7 +2058,7 @@ async def list_tensorboard_time_series( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -2106,7 +2151,7 @@ async def delete_tensorboard_time_series( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -2197,7 +2242,7 @@ async def batch_read_tensorboard_time_series_data( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([tensorboard]) if request is not None and has_flattened_params: @@ -2277,7 +2322,7 @@ async def read_tensorboard_time_series_data( """ # Create or coerce a protobuf request object. 
- # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([tensorboard_time_series]) if request is not None and has_flattened_params: @@ -2354,7 +2399,7 @@ def read_tensorboard_blob_data( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([time_series]) if request is not None and has_flattened_params: @@ -2442,7 +2487,7 @@ async def write_tensorboard_experiment_data( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([tensorboard_experiment, write_run_data_requests]) if request is not None and has_flattened_params: @@ -2534,7 +2579,7 @@ async def write_tensorboard_run_data( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([tensorboard_run, time_series_data]) if request is not None and has_flattened_params: @@ -2616,7 +2661,7 @@ async def export_tensorboard_time_series_data( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. 
has_flattened_params = any([tensorboard_time_series]) if request is not None and has_flattened_params: diff --git a/google/cloud/aiplatform_v1beta1/services/tensorboard_service/client.py b/google/cloud/aiplatform_v1beta1/services/tensorboard_service/client.py index 81fa3a6665b..a513ce188e5 100644 --- a/google/cloud/aiplatform_v1beta1/services/tensorboard_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/tensorboard_service/client.py @@ -327,6 +327,73 @@ def parse_common_location_path(path: str) -> Dict[str, str]: m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) return m.groupdict() if m else {} + @classmethod + def get_mtls_endpoint_and_cert_source( + cls, client_options: Optional[client_options_lib.ClientOptions] = None + ): + """Return the API endpoint and client cert source for mutual TLS. + + The client cert source is determined in the following order: + (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the + client cert source is None. + (2) if `client_options.client_cert_source` is provided, use the provided one; if the + default client cert source exists, use the default one; otherwise the client cert + source is None. + + The API endpoint is determined in the following order: + (1) if `client_options.api_endpoint` if provided, use the provided one. + (2) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is "always", use the + default mTLS endpoint; if the environment variabel is "never", use the default API + endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise + use the default API endpoint. + + More details can be found at https://google.aip.dev/auth/4114. + + Args: + client_options (google.api_core.client_options.ClientOptions): Custom options for the + client. Only the `api_endpoint` and `client_cert_source` properties may be used + in this method. 
+ + Returns: + Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the + client cert source to use. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If any errors happen. + """ + if client_options is None: + client_options = client_options_lib.ClientOptions() + use_client_cert = os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") + use_mtls_endpoint = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") + if use_client_cert not in ("true", "false"): + raise ValueError( + "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" + ) + if use_mtls_endpoint not in ("auto", "never", "always"): + raise MutualTLSChannelError( + "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" + ) + + # Figure out the client cert source to use. + client_cert_source = None + if use_client_cert == "true": + if client_options.client_cert_source: + client_cert_source = client_options.client_cert_source + elif mtls.has_default_client_cert_source(): + client_cert_source = mtls.default_client_cert_source() + + # Figure out which api endpoint to use. + if client_options.api_endpoint is not None: + api_endpoint = client_options.api_endpoint + elif use_mtls_endpoint == "always" or ( + use_mtls_endpoint == "auto" and client_cert_source + ): + api_endpoint = cls.DEFAULT_MTLS_ENDPOINT + else: + api_endpoint = cls.DEFAULT_ENDPOINT + + return api_endpoint, client_cert_source + def __init__( self, *, @@ -377,57 +444,22 @@ def __init__( if client_options is None: client_options = client_options_lib.ClientOptions() - # Create SSL credentials for mutual TLS if needed. 
- if os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") not in ( - "true", - "false", - ): - raise ValueError( - "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" - ) - use_client_cert = ( - os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") == "true" + api_endpoint, client_cert_source_func = self.get_mtls_endpoint_and_cert_source( + client_options ) - client_cert_source_func = None - is_mtls = False - if use_client_cert: - if client_options.client_cert_source: - is_mtls = True - client_cert_source_func = client_options.client_cert_source - else: - is_mtls = mtls.has_default_client_cert_source() - if is_mtls: - client_cert_source_func = mtls.default_client_cert_source() - else: - client_cert_source_func = None - - # Figure out which api endpoint to use. - if client_options.api_endpoint is not None: - api_endpoint = client_options.api_endpoint - else: - use_mtls_env = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") - if use_mtls_env == "never": - api_endpoint = self.DEFAULT_ENDPOINT - elif use_mtls_env == "always": - api_endpoint = self.DEFAULT_MTLS_ENDPOINT - elif use_mtls_env == "auto": - if is_mtls: - api_endpoint = self.DEFAULT_MTLS_ENDPOINT - else: - api_endpoint = self.DEFAULT_ENDPOINT - else: - raise MutualTLSChannelError( - "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted " - "values: never, auto, always" - ) + api_key_value = getattr(client_options, "api_key", None) + if api_key_value and credentials: + raise ValueError( + "client_options.api_key and credentials are mutually exclusive" + ) # Save or instantiate the transport. # Ordinarily, we provide the transport, but allowing a custom transport # instance provides an extensibility point for unusual situations. if isinstance(transport, TensorboardServiceTransport): # transport is a TensorboardServiceTransport instance. 
- if credentials or client_options.credentials_file: + if credentials or client_options.credentials_file or api_key_value: raise ValueError( "When providing a transport instance, " "provide its credentials directly." @@ -439,6 +471,15 @@ def __init__( ) self._transport = transport else: + import google.auth._default # type: ignore + + if api_key_value and hasattr( + google.auth._default, "get_api_key_credentials" + ): + credentials = google.auth._default.get_api_key_credentials( + api_key_value + ) + Transport = type(self).get_transport_class(transport) self._transport = Transport( credentials=credentials, @@ -497,7 +538,7 @@ def create_tensorboard( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, tensorboard]) if request is not None and has_flattened_params: @@ -582,7 +623,7 @@ def get_tensorboard( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -671,7 +712,7 @@ def update_tensorboard( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([tensorboard, update_mask]) if request is not None and has_flattened_params: @@ -758,7 +799,7 @@ def list_tensorboards( """ # Create or coerce a protobuf request object. 
- # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -849,7 +890,7 @@ def delete_tensorboard( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -951,7 +992,7 @@ def create_tensorboard_experiment( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any( [parent, tensorboard_experiment, tensorboard_experiment_id] @@ -1037,7 +1078,7 @@ def get_tensorboard_experiment( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -1129,7 +1170,7 @@ def update_tensorboard_experiment( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([tensorboard_experiment, update_mask]) if request is not None and has_flattened_params: @@ -1215,7 +1256,7 @@ def list_tensorboard_experiments( """ # Create or coerce a protobuf request object. 
- # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -1312,7 +1353,7 @@ def delete_tensorboard_experiment( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -1418,7 +1459,7 @@ def create_tensorboard_run( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, tensorboard_run, tensorboard_run_id]) if request is not None and has_flattened_params: @@ -1509,7 +1550,7 @@ def batch_create_tensorboard_runs( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, requests]) if request is not None and has_flattened_params: @@ -1589,7 +1630,7 @@ def get_tensorboard_run( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -1676,7 +1717,7 @@ def update_tensorboard_run( """ # Create or coerce a protobuf request object. 
- # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([tensorboard_run, update_mask]) if request is not None and has_flattened_params: @@ -1756,7 +1797,7 @@ def list_tensorboard_runs( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -1847,7 +1888,7 @@ def delete_tensorboard_run( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -1946,7 +1987,7 @@ def batch_create_tensorboard_time_series( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, requests]) if request is not None and has_flattened_params: @@ -2036,7 +2077,7 @@ def create_tensorboard_time_series( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, tensorboard_time_series]) if request is not None and has_flattened_params: @@ -2116,7 +2157,7 @@ def get_tensorboard_time_series( """ # Create or coerce a protobuf request object. 
- # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -2206,7 +2247,7 @@ def update_tensorboard_time_series( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([tensorboard_time_series, update_mask]) if request is not None and has_flattened_params: @@ -2297,7 +2338,7 @@ def list_tensorboard_time_series( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -2394,7 +2435,7 @@ def delete_tensorboard_time_series( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -2489,7 +2530,7 @@ def batch_read_tensorboard_time_series_data( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([tensorboard]) if request is not None and has_flattened_params: @@ -2575,7 +2616,7 @@ def read_tensorboard_time_series_data( """ # Create or coerce a protobuf request object. 
- # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([tensorboard_time_series]) if request is not None and has_flattened_params: @@ -2656,7 +2697,7 @@ def read_tensorboard_blob_data( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([time_series]) if request is not None and has_flattened_params: @@ -2746,7 +2787,7 @@ def write_tensorboard_experiment_data( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([tensorboard_experiment, write_run_data_requests]) if request is not None and has_flattened_params: @@ -2842,7 +2883,7 @@ def write_tensorboard_run_data( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([tensorboard_run, time_series_data]) if request is not None and has_flattened_params: @@ -2926,7 +2967,7 @@ def export_tensorboard_time_series_data( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. 
has_flattened_params = any([tensorboard_time_series]) if request is not None and has_flattened_params: diff --git a/google/cloud/aiplatform_v1beta1/services/tensorboard_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/tensorboard_service/transports/base.py index 96c05a9bcc1..f5d7a689ed0 100644 --- a/google/cloud/aiplatform_v1beta1/services/tensorboard_service/transports/base.py +++ b/google/cloud/aiplatform_v1beta1/services/tensorboard_service/transports/base.py @@ -114,7 +114,6 @@ def __init__( credentials, _ = google.auth.load_credentials_from_file( credentials_file, **scopes_kwargs, quota_project_id=quota_project_id ) - elif credentials is None: credentials, _ = google.auth.default( **scopes_kwargs, quota_project_id=quota_project_id diff --git a/google/cloud/aiplatform_v1beta1/services/tensorboard_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/tensorboard_service/transports/grpc.py index f01df1056aa..d20bec45f8c 100644 --- a/google/cloud/aiplatform_v1beta1/services/tensorboard_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/tensorboard_service/transports/grpc.py @@ -172,8 +172,11 @@ def __init__( if not self._grpc_channel: self._grpc_channel = type(self).create_channel( self._host, + # use the credentials which are saved credentials=self._credentials, - credentials_file=credentials_file, + # Set ``credentials_file`` to ``None`` here as + # the credentials that we saved earlier should be used. + credentials_file=None, scopes=self._scopes, ssl_credentials=self._ssl_channel_credentials, quota_project_id=quota_project_id, @@ -246,7 +249,7 @@ def operations_client(self) -> operations_v1.OperationsClient: This property caches on the instance; repeated calls return the same client. """ - # Sanity check: Only create a new client if we do not already have one. + # Quick check: Only create a new client if we do not already have one. 
if self._operations_client is None: self._operations_client = operations_v1.OperationsClient(self.grpc_channel) diff --git a/google/cloud/aiplatform_v1beta1/services/tensorboard_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/tensorboard_service/transports/grpc_asyncio.py index 6a913aee880..35c9cad2d16 100644 --- a/google/cloud/aiplatform_v1beta1/services/tensorboard_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/tensorboard_service/transports/grpc_asyncio.py @@ -217,8 +217,11 @@ def __init__( if not self._grpc_channel: self._grpc_channel = type(self).create_channel( self._host, + # use the credentials which are saved credentials=self._credentials, - credentials_file=credentials_file, + # Set ``credentials_file`` to ``None`` here as + # the credentials that we saved earlier should be used. + credentials_file=None, scopes=self._scopes, ssl_credentials=self._ssl_channel_credentials, quota_project_id=quota_project_id, @@ -248,7 +251,7 @@ def operations_client(self) -> operations_v1.OperationsAsyncClient: This property caches on the instance; repeated calls return the same client. """ - # Sanity check: Only create a new client if we do not already have one. + # Quick check: Only create a new client if we do not already have one. 
        if self._operations_client is None:
            self._operations_client = operations_v1.OperationsAsyncClient(
                self.grpc_channel
diff --git a/google/cloud/aiplatform_v1beta1/services/vizier_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/vizier_service/async_client.py
index 9ba5f46b147..4f4ff19fb18 100644
--- a/google/cloud/aiplatform_v1beta1/services/vizier_service/async_client.py
+++ b/google/cloud/aiplatform_v1beta1/services/vizier_service/async_client.py
@@ -16,7 +16,7 @@
 from collections import OrderedDict
 import functools
 import re
-from typing import Dict, Sequence, Tuple, Type, Union
+from typing import Dict, Optional, Sequence, Tuple, Type, Union
 import pkg_resources
 
 from google.api_core.client_options import ClientOptions
@@ -119,6 +119,42 @@ def from_service_account_file(cls, filename: str, *args, **kwargs):
 
     from_service_account_json = from_service_account_file
 
+    @classmethod
+    def get_mtls_endpoint_and_cert_source(
+        cls, client_options: Optional[ClientOptions] = None
+    ):
+        """Return the API endpoint and client cert source for mutual TLS.
+
+        The client cert source is determined in the following order:
+        (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the
+        client cert source is None.
+        (2) if `client_options.client_cert_source` is provided, use the provided one; if the
+        default client cert source exists, use the default one; otherwise the client cert
+        source is None.
+
+        The API endpoint is determined in the following order:
+        (1) if `client_options.api_endpoint` is provided, use the provided one.
+        (2) if `GOOGLE_API_USE_MTLS_ENDPOINT` environment variable is "always", use the
+        default mTLS endpoint; if the environment variable is "never", use the default API
+        endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise
+        use the default API endpoint.
+
+        More details can be found at https://google.aip.dev/auth/4114.
+ + Args: + client_options (google.api_core.client_options.ClientOptions): Custom options for the + client. Only the `api_endpoint` and `client_cert_source` properties may be used + in this method. + + Returns: + Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the + client cert source to use. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If any errors happen. + """ + return VizierServiceClient.get_mtls_endpoint_and_cert_source(client_options) # type: ignore + @property def transport(self) -> VizierServiceTransport: """Returns the transport used by the client instance. @@ -224,7 +260,7 @@ async def create_study( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, study]) if request is not None and has_flattened_params: @@ -297,7 +333,7 @@ async def get_study( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -373,7 +409,7 @@ async def list_studies( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -445,7 +481,7 @@ async def delete_study( sent along with the request as metadata. """ # Create or coerce a protobuf request object. 
- # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -517,7 +553,7 @@ async def lookup_study( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -663,7 +699,7 @@ async def create_trial( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, trial]) if request is not None and has_flattened_params: @@ -739,7 +775,7 @@ async def get_trial( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -814,7 +850,7 @@ async def list_trials( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -993,7 +1029,7 @@ async def delete_trial( sent along with the request as metadata. """ # Create or coerce a protobuf request object. 
- # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -1184,7 +1220,7 @@ async def list_optimal_trials( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: diff --git a/google/cloud/aiplatform_v1beta1/services/vizier_service/client.py b/google/cloud/aiplatform_v1beta1/services/vizier_service/client.py index 94f01344003..f4b0dc15046 100644 --- a/google/cloud/aiplatform_v1beta1/services/vizier_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/vizier_service/client.py @@ -273,6 +273,73 @@ def parse_common_location_path(path: str) -> Dict[str, str]: m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) return m.groupdict() if m else {} + @classmethod + def get_mtls_endpoint_and_cert_source( + cls, client_options: Optional[client_options_lib.ClientOptions] = None + ): + """Return the API endpoint and client cert source for mutual TLS. + + The client cert source is determined in the following order: + (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the + client cert source is None. + (2) if `client_options.client_cert_source` is provided, use the provided one; if the + default client cert source exists, use the default one; otherwise the client cert + source is None. + + The API endpoint is determined in the following order: + (1) if `client_options.api_endpoint` if provided, use the provided one. 
+        (2) if `GOOGLE_API_USE_MTLS_ENDPOINT` environment variable is "always", use the
+        default mTLS endpoint; if the environment variable is "never", use the default API
+        endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise
+        use the default API endpoint.
+
+        More details can be found at https://google.aip.dev/auth/4114.
+
+        Args:
+            client_options (google.api_core.client_options.ClientOptions): Custom options for the
+                client. Only the `api_endpoint` and `client_cert_source` properties may be used
+                in this method.
+
+        Returns:
+            Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the
+                client cert source to use.
+
+        Raises:
+            google.auth.exceptions.MutualTLSChannelError: If any errors happen.
+        """
+        if client_options is None:
+            client_options = client_options_lib.ClientOptions()
+        use_client_cert = os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")
+        use_mtls_endpoint = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto")
+        if use_client_cert not in ("true", "false"):
+            raise ValueError(
+                "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`"
+            )
+        if use_mtls_endpoint not in ("auto", "never", "always"):
+            raise MutualTLSChannelError(
+                "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`"
+            )
+
+        # Figure out the client cert source to use.
+        client_cert_source = None
+        if use_client_cert == "true":
+            if client_options.client_cert_source:
+                client_cert_source = client_options.client_cert_source
+            elif mtls.has_default_client_cert_source():
+                client_cert_source = mtls.default_client_cert_source()
+
+        # Figure out which api endpoint to use.
+ if client_options.api_endpoint is not None: + api_endpoint = client_options.api_endpoint + elif use_mtls_endpoint == "always" or ( + use_mtls_endpoint == "auto" and client_cert_source + ): + api_endpoint = cls.DEFAULT_MTLS_ENDPOINT + else: + api_endpoint = cls.DEFAULT_ENDPOINT + + return api_endpoint, client_cert_source + def __init__( self, *, @@ -323,57 +390,22 @@ def __init__( if client_options is None: client_options = client_options_lib.ClientOptions() - # Create SSL credentials for mutual TLS if needed. - if os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") not in ( - "true", - "false", - ): - raise ValueError( - "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" - ) - use_client_cert = ( - os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") == "true" + api_endpoint, client_cert_source_func = self.get_mtls_endpoint_and_cert_source( + client_options ) - client_cert_source_func = None - is_mtls = False - if use_client_cert: - if client_options.client_cert_source: - is_mtls = True - client_cert_source_func = client_options.client_cert_source - else: - is_mtls = mtls.has_default_client_cert_source() - if is_mtls: - client_cert_source_func = mtls.default_client_cert_source() - else: - client_cert_source_func = None - - # Figure out which api endpoint to use. - if client_options.api_endpoint is not None: - api_endpoint = client_options.api_endpoint - else: - use_mtls_env = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") - if use_mtls_env == "never": - api_endpoint = self.DEFAULT_ENDPOINT - elif use_mtls_env == "always": - api_endpoint = self.DEFAULT_MTLS_ENDPOINT - elif use_mtls_env == "auto": - if is_mtls: - api_endpoint = self.DEFAULT_MTLS_ENDPOINT - else: - api_endpoint = self.DEFAULT_ENDPOINT - else: - raise MutualTLSChannelError( - "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. 
Accepted " - "values: never, auto, always" - ) + api_key_value = getattr(client_options, "api_key", None) + if api_key_value and credentials: + raise ValueError( + "client_options.api_key and credentials are mutually exclusive" + ) # Save or instantiate the transport. # Ordinarily, we provide the transport, but allowing a custom transport # instance provides an extensibility point for unusual situations. if isinstance(transport, VizierServiceTransport): # transport is a VizierServiceTransport instance. - if credentials or client_options.credentials_file: + if credentials or client_options.credentials_file or api_key_value: raise ValueError( "When providing a transport instance, " "provide its credentials directly." @@ -385,6 +417,15 @@ def __init__( ) self._transport = transport else: + import google.auth._default # type: ignore + + if api_key_value and hasattr( + google.auth._default, "get_api_key_credentials" + ): + credentials = google.auth._default.get_api_key_credentials( + api_key_value + ) + Transport = type(self).get_transport_class(transport) self._transport = Transport( credentials=credentials, @@ -442,7 +483,7 @@ def create_study( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, study]) if request is not None and has_flattened_params: @@ -515,7 +556,7 @@ def get_study( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -591,7 +632,7 @@ def list_studies( """ # Create or coerce a protobuf request object. 
- # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -663,7 +704,7 @@ def delete_study( sent along with the request as metadata. """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -735,7 +776,7 @@ def lookup_study( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -882,7 +923,7 @@ def create_trial( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, trial]) if request is not None and has_flattened_params: @@ -958,7 +999,7 @@ def get_trial( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -1033,7 +1074,7 @@ def list_trials( """ # Create or coerce a protobuf request object. 
- # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: @@ -1214,7 +1255,7 @@ def delete_trial( sent along with the request as metadata. """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: @@ -1409,7 +1450,7 @@ def list_optimal_trials( """ # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have + # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: diff --git a/google/cloud/aiplatform_v1beta1/services/vizier_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/vizier_service/transports/base.py index f52d925d28f..a15615bc0c4 100644 --- a/google/cloud/aiplatform_v1beta1/services/vizier_service/transports/base.py +++ b/google/cloud/aiplatform_v1beta1/services/vizier_service/transports/base.py @@ -106,7 +106,6 @@ def __init__( credentials, _ = google.auth.load_credentials_from_file( credentials_file, **scopes_kwargs, quota_project_id=quota_project_id ) - elif credentials is None: credentials, _ = google.auth.default( **scopes_kwargs, quota_project_id=quota_project_id diff --git a/google/cloud/aiplatform_v1beta1/services/vizier_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/vizier_service/transports/grpc.py index 5cf7cbaee1c..17972096187 100644 --- a/google/cloud/aiplatform_v1beta1/services/vizier_service/transports/grpc.py +++ 
b/google/cloud/aiplatform_v1beta1/services/vizier_service/transports/grpc.py @@ -167,8 +167,11 @@ def __init__( if not self._grpc_channel: self._grpc_channel = type(self).create_channel( self._host, + # use the credentials which are saved credentials=self._credentials, - credentials_file=credentials_file, + # Set ``credentials_file`` to ``None`` here as + # the credentials that we saved earlier should be used. + credentials_file=None, scopes=self._scopes, ssl_credentials=self._ssl_channel_credentials, quota_project_id=quota_project_id, @@ -241,7 +244,7 @@ def operations_client(self) -> operations_v1.OperationsClient: This property caches on the instance; repeated calls return the same client. """ - # Sanity check: Only create a new client if we do not already have one. + # Quick check: Only create a new client if we do not already have one. if self._operations_client is None: self._operations_client = operations_v1.OperationsClient(self.grpc_channel) diff --git a/google/cloud/aiplatform_v1beta1/services/vizier_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/vizier_service/transports/grpc_asyncio.py index 7fdb740e870..7e7cddaad45 100644 --- a/google/cloud/aiplatform_v1beta1/services/vizier_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/vizier_service/transports/grpc_asyncio.py @@ -212,8 +212,11 @@ def __init__( if not self._grpc_channel: self._grpc_channel = type(self).create_channel( self._host, + # use the credentials which are saved credentials=self._credentials, - credentials_file=credentials_file, + # Set ``credentials_file`` to ``None`` here as + # the credentials that we saved earlier should be used. + credentials_file=None, scopes=self._scopes, ssl_credentials=self._ssl_channel_credentials, quota_project_id=quota_project_id, @@ -243,7 +246,7 @@ def operations_client(self) -> operations_v1.OperationsAsyncClient: This property caches on the instance; repeated calls return the same client. 
""" - # Sanity check: Only create a new client if we do not already have one. + # Quick check: Only create a new client if we do not already have one. if self._operations_client is None: self._operations_client = operations_v1.OperationsAsyncClient( self.grpc_channel diff --git a/google/cloud/aiplatform_v1beta1/types/annotation.py b/google/cloud/aiplatform_v1beta1/types/annotation.py index 4a89bb5682d..ae53601bbb3 100644 --- a/google/cloud/aiplatform_v1beta1/types/annotation.py +++ b/google/cloud/aiplatform_v1beta1/types/annotation.py @@ -53,8 +53,8 @@ class Annotation(proto.Message): Output only. Timestamp when this Annotation was last updated. etag (str): - Optional. Used to perform consistent read- - odify-write updates. If not set, a blind + Optional. Used to perform consistent + read-modify-write updates. If not set, a blind "overwrite" update happens. annotation_source (google.cloud.aiplatform_v1beta1.types.UserActionReference): Output only. The source of the Annotation. diff --git a/google/cloud/aiplatform_v1beta1/types/annotation_spec.py b/google/cloud/aiplatform_v1beta1/types/annotation_spec.py index 9cd90b4dc1c..54b60d4bf33 100644 --- a/google/cloud/aiplatform_v1beta1/types/annotation_spec.py +++ b/google/cloud/aiplatform_v1beta1/types/annotation_spec.py @@ -43,8 +43,8 @@ class AnnotationSpec(proto.Message): Output only. Timestamp when AnnotationSpec was last updated. etag (str): - Optional. Used to perform consistent read- - odify-write updates. If not set, a blind + Optional. Used to perform consistent + read-modify-write updates. If not set, a blind "overwrite" update happens. """ diff --git a/google/cloud/aiplatform_v1beta1/types/artifact.py b/google/cloud/aiplatform_v1beta1/types/artifact.py index d70d6ace76a..14f23a69a38 100644 --- a/google/cloud/aiplatform_v1beta1/types/artifact.py +++ b/google/cloud/aiplatform_v1beta1/types/artifact.py @@ -39,8 +39,8 @@ class Artifact(proto.Message): artifact file. May be empty if there is no actual artifact file. 
etag (str): - An eTag used to perform consistent read- - odify-write updates. If not set, a blind + An eTag used to perform consistent + read-modify-write updates. If not set, a blind "overwrite" update happens. labels (Sequence[google.cloud.aiplatform_v1beta1.types.Artifact.LabelsEntry]): The labels with user-defined metadata to diff --git a/google/cloud/aiplatform_v1beta1/types/context.py b/google/cloud/aiplatform_v1beta1/types/context.py index dcdd3dd2427..fc6aba8a2ba 100644 --- a/google/cloud/aiplatform_v1beta1/types/context.py +++ b/google/cloud/aiplatform_v1beta1/types/context.py @@ -35,8 +35,8 @@ class Context(proto.Message): User provided display name of the Context. May be up to 128 Unicode characters. etag (str): - An eTag used to perform consistent read- - odify-write updates. If not set, a blind + An eTag used to perform consistent + read-modify-write updates. If not set, a blind "overwrite" update happens. labels (Sequence[google.cloud.aiplatform_v1beta1.types.Context.LabelsEntry]): The labels with user-defined metadata to diff --git a/google/cloud/aiplatform_v1beta1/types/data_item.py b/google/cloud/aiplatform_v1beta1/types/data_item.py index 8d43b2b4782..ccd0bccf9a7 100644 --- a/google/cloud/aiplatform_v1beta1/types/data_item.py +++ b/google/cloud/aiplatform_v1beta1/types/data_item.py @@ -60,8 +60,8 @@ class DataItem(proto.Message): schema's][google.cloud.aiplatform.v1beta1.Dataset.metadata_schema_uri] dataItemSchemaUri field. etag (str): - Optional. Used to perform consistent read- - odify-write updates. If not set, a blind + Optional. Used to perform consistent + read-modify-write updates. If not set, a blind "overwrite" update happens. 
""" diff --git a/google/cloud/aiplatform_v1beta1/types/data_labeling_job.py b/google/cloud/aiplatform_v1beta1/types/data_labeling_job.py index c510216e344..47bc7d1c09c 100644 --- a/google/cloud/aiplatform_v1beta1/types/data_labeling_job.py +++ b/google/cloud/aiplatform_v1beta1/types/data_labeling_job.py @@ -76,9 +76,9 @@ class DataLabelingJob(proto.Message): Google Cloud Storage describing the config for a specific type of DataLabelingJob. The schema files that can be used here are found in the - https://storage.googleapis.com/google-cloud- - aiplatform bucket in the - /schema/datalabelingjob/inputs/ folder. + https://storage.googleapis.com/google-cloud-aiplatform + bucket in the /schema/datalabelingjob/inputs/ + folder. inputs (google.protobuf.struct_pb2.Value): Required. Input config parameters for the DataLabelingJob. diff --git a/google/cloud/aiplatform_v1beta1/types/dataset.py b/google/cloud/aiplatform_v1beta1/types/dataset.py index 1a3fc4fe8bd..cdf467cdd90 100644 --- a/google/cloud/aiplatform_v1beta1/types/dataset.py +++ b/google/cloud/aiplatform_v1beta1/types/dataset.py @@ -46,8 +46,7 @@ class Dataset(proto.Message): information about the Dataset. The schema is defined as an OpenAPI 3.0.2 Schema Object. The schema files that can be used here are found in - gs://google-cloud- - aiplatform/schema/dataset/metadata/. + gs://google-cloud-aiplatform/schema/dataset/metadata/. metadata (google.protobuf.struct_pb2.Value): Required. Additional information about the Dataset. @@ -82,8 +81,8 @@ class Dataset(proto.Message): title. encryption_spec (google.cloud.aiplatform_v1beta1.types.EncryptionSpec): Customer-managed encryption key spec for a - Dataset. If set, this Dataset and all sub- - resources of this Dataset will be secured by + Dataset. If set, this Dataset and all + sub-resources of this Dataset will be secured by this key. 
""" diff --git a/google/cloud/aiplatform_v1beta1/types/endpoint.py b/google/cloud/aiplatform_v1beta1/types/endpoint.py index 4e6981e44ee..e6813093363 100644 --- a/google/cloud/aiplatform_v1beta1/types/endpoint.py +++ b/google/cloud/aiplatform_v1beta1/types/endpoint.py @@ -80,9 +80,9 @@ class Endpoint(proto.Message): last updated. encryption_spec (google.cloud.aiplatform_v1beta1.types.EncryptionSpec): Customer-managed encryption key spec for an - Endpoint. If set, this Endpoint and all sub- - resources of this Endpoint will be secured by - this key. + Endpoint. If set, this Endpoint and all + sub-resources of this Endpoint will be secured + by this key. network (str): The full name of the Google Compute Engine `network `__ diff --git a/google/cloud/aiplatform_v1beta1/types/entity_type.py b/google/cloud/aiplatform_v1beta1/types/entity_type.py index 6860f29ac4a..62b5c9aad35 100644 --- a/google/cloud/aiplatform_v1beta1/types/entity_type.py +++ b/google/cloud/aiplatform_v1beta1/types/entity_type.py @@ -63,8 +63,8 @@ class EntityType(proto.Message): System reserved label keys are prefixed with "aiplatform.googleapis.com/" and are immutable. etag (str): - Optional. Used to perform a consistent read- - odify-write updates. If not set, a blind + Optional. Used to perform a consistent + read-modify-write updates. If not set, a blind "overwrite" update happens. monitoring_config (google.cloud.aiplatform_v1beta1.types.FeaturestoreMonitoringConfig): Optional. The default monitoring configuration for all diff --git a/google/cloud/aiplatform_v1beta1/types/execution.py b/google/cloud/aiplatform_v1beta1/types/execution.py index 85b824ac506..8562e75e72e 100644 --- a/google/cloud/aiplatform_v1beta1/types/execution.py +++ b/google/cloud/aiplatform_v1beta1/types/execution.py @@ -42,8 +42,8 @@ class Execution(proto.Message): and the system does not prescribe or check the validity of state transitions. etag (str): - An eTag used to perform consistent read- - odify-write updates. 
If not set, a blind + An eTag used to perform consistent + read-modify-write updates. If not set, a blind "overwrite" update happens. labels (Sequence[google.cloud.aiplatform_v1beta1.types.Execution.LabelsEntry]): The labels with user-defined metadata to diff --git a/google/cloud/aiplatform_v1beta1/types/explanation.py b/google/cloud/aiplatform_v1beta1/types/explanation.py index 2972aa21830..01aaa5e810c 100644 --- a/google/cloud/aiplatform_v1beta1/types/explanation.py +++ b/google/cloud/aiplatform_v1beta1/types/explanation.py @@ -282,10 +282,10 @@ class ExplanationParameters(proto.Message): This field is a member of `oneof`_ ``method``. integrated_gradients_attribution (google.cloud.aiplatform_v1beta1.types.IntegratedGradientsAttribution): - An attribution method that computes Aumann- - hapley values taking advantage of the model's - fully differentiable structure. Refer to this - paper for more details: + An attribution method that computes + Aumann-Shapley values taking advantage of the + model's fully differentiable structure. Refer to + this paper for more details: https://arxiv.org/abs/1703.01365 This field is a member of `oneof`_ ``method``. diff --git a/google/cloud/aiplatform_v1beta1/types/explanation_metadata.py b/google/cloud/aiplatform_v1beta1/types/explanation_metadata.py index a17467a9201..9f36ff0d775 100644 --- a/google/cloud/aiplatform_v1beta1/types/explanation_metadata.py +++ b/google/cloud/aiplatform_v1beta1/types/explanation_metadata.py @@ -100,8 +100,8 @@ class InputMetadata(proto.Message): [instance_schema_uri][google.cloud.aiplatform.v1beta1.PredictSchemata.instance_schema_uri]. input_tensor_name (str): Name of the input tensor for this feature. - Required and is only applicable to Vertex AI- - provided images for Tensorflow. + Required and is only applicable to Vertex + AI-provided images for Tensorflow. 
encoding (google.cloud.aiplatform_v1beta1.types.ExplanationMetadata.InputMetadata.Encoding): Defines how the feature is encoded into the input tensor. Defaults to IDENTITY. diff --git a/google/cloud/aiplatform_v1beta1/types/feature.py b/google/cloud/aiplatform_v1beta1/types/feature.py index 7e056694fe2..13adb8b7f20 100644 --- a/google/cloud/aiplatform_v1beta1/types/feature.py +++ b/google/cloud/aiplatform_v1beta1/types/feature.py @@ -65,9 +65,9 @@ class Feature(proto.Message): System reserved label keys are prefixed with "aiplatform.googleapis.com/" and are immutable. etag (str): - Used to perform a consistent read-modify- - rite updates. If not set, a blind "overwrite" - update happens. + Used to perform a consistent + read-modify-write updates. If not set, a blind + "overwrite" update happens. monitoring_config (google.cloud.aiplatform_v1beta1.types.FeaturestoreMonitoringConfig): Optional. The custom monitoring configuration for this Feature, if not set, use the monitoring_config defined for diff --git a/google/cloud/aiplatform_v1beta1/types/featurestore.py b/google/cloud/aiplatform_v1beta1/types/featurestore.py index 203dfc4a68a..5d19c1dee71 100644 --- a/google/cloud/aiplatform_v1beta1/types/featurestore.py +++ b/google/cloud/aiplatform_v1beta1/types/featurestore.py @@ -40,8 +40,8 @@ class Featurestore(proto.Message): Output only. Timestamp when this Featurestore was last updated. etag (str): - Optional. Used to perform consistent read- - odify-write updates. If not set, a blind + Optional. Used to perform consistent + read-modify-write updates. If not set, a blind "overwrite" update happens. labels (Sequence[google.cloud.aiplatform_v1beta1.types.Featurestore.LabelsEntry]): Optional. 
The labels with user-defined diff --git a/google/cloud/aiplatform_v1beta1/types/model_deployment_monitoring_job.py b/google/cloud/aiplatform_v1beta1/types/model_deployment_monitoring_job.py index b9a106c6392..146fba7fd63 100644 --- a/google/cloud/aiplatform_v1beta1/types/model_deployment_monitoring_job.py +++ b/google/cloud/aiplatform_v1beta1/types/model_deployment_monitoring_job.py @@ -158,9 +158,10 @@ class ModelDeploymentMonitoringJob(proto.Message): encryption_spec (google.cloud.aiplatform_v1beta1.types.EncryptionSpec): Customer-managed encryption key spec for a ModelDeploymentMonitoringJob. If set, this - ModelDeploymentMonitoringJob and all sub- - resources of this ModelDeploymentMonitoringJob - will be secured by this key. + ModelDeploymentMonitoringJob and all + sub-resources of this + ModelDeploymentMonitoringJob will be secured by + this key. enable_monitoring_pipeline_logs (bool): If true, the scheduled monitoring pipeline logs are sent to Google Cloud Logging, including pipeline status and diff --git a/google/cloud/aiplatform_v1beta1/types/tensorboard.py b/google/cloud/aiplatform_v1beta1/types/tensorboard.py index aa04ae2551a..a8bfafaa4fb 100644 --- a/google/cloud/aiplatform_v1beta1/types/tensorboard.py +++ b/google/cloud/aiplatform_v1beta1/types/tensorboard.py @@ -74,9 +74,9 @@ class Tensorboard(proto.Message): keys are prefixed with "aiplatform.googleapis.com/" and are immutable. etag (str): - Used to perform a consistent read-modify- - rite updates. If not set, a blind "overwrite" - update happens. + Used to perform a consistent + read-modify-write updates. If not set, a blind + "overwrite" update happens. 
""" name = proto.Field(proto.STRING, number=1,) diff --git a/google/cloud/aiplatform_v1beta1/types/tensorboard_run.py b/google/cloud/aiplatform_v1beta1/types/tensorboard_run.py index 1b6e250cc42..f072b08f007 100644 --- a/google/cloud/aiplatform_v1beta1/types/tensorboard_run.py +++ b/google/cloud/aiplatform_v1beta1/types/tensorboard_run.py @@ -68,9 +68,9 @@ class TensorboardRun(proto.Message): of labels. System reserved label keys are prefixed with "aiplatform.googleapis.com/" and are immutable. etag (str): - Used to perform a consistent read-modify- - rite updates. If not set, a blind "overwrite" - update happens. + Used to perform a consistent + read-modify-write updates. If not set, a blind + "overwrite" update happens. """ name = proto.Field(proto.STRING, number=1,) diff --git a/google/cloud/aiplatform_v1beta1/types/tensorboard_time_series.py b/google/cloud/aiplatform_v1beta1/types/tensorboard_time_series.py index 130d73f2661..a2d2e67f39e 100644 --- a/google/cloud/aiplatform_v1beta1/types/tensorboard_time_series.py +++ b/google/cloud/aiplatform_v1beta1/types/tensorboard_time_series.py @@ -49,9 +49,9 @@ class TensorboardTimeSeries(proto.Message): Output only. Timestamp when this TensorboardTimeSeries was last updated. etag (str): - Used to perform a consistent read-modify- - rite updates. If not set, a blind "overwrite" - update happens. + Used to perform a consistent + read-modify-write updates. If not set, a blind + "overwrite" update happens. plugin_name (str): Immutable. Name of the plugin this time series pertain to. 
Such as Scalar, Tensor, Blob diff --git a/google/cloud/aiplatform_v1beta1/types/training_pipeline.py b/google/cloud/aiplatform_v1beta1/types/training_pipeline.py index 6b9ee8c4dca..5c389b19fe9 100644 --- a/google/cloud/aiplatform_v1beta1/types/training_pipeline.py +++ b/google/cloud/aiplatform_v1beta1/types/training_pipeline.py @@ -69,13 +69,13 @@ class TrainingPipeline(proto.Message): is responsible for producing the model artifact, and may also include additional auxiliary work. The definition files that can be used here are - found in gs://google-cloud- - aiplatform/schema/trainingjob/definition/. Note: - The URI given on output will be immutable and - probably different, including the URI scheme, - than the one given on input. The output URI will - point to a location where the user only has a - read access. + found in + gs://google-cloud-aiplatform/schema/trainingjob/definition/. + Note: The URI given on output will be immutable + and probably different, including the URI + scheme, than the one given on input. The output + URI will point to a location where the user only + has a read access. training_task_inputs (google.protobuf.struct_pb2.Value): Required. The training task's parameter(s), as specified in the diff --git a/tests/unit/gapic/aiplatform_v1/test_dataset_service.py b/tests/unit/gapic/aiplatform_v1/test_dataset_service.py index f329e5dd81a..b8980f47209 100644 --- a/tests/unit/gapic/aiplatform_v1/test_dataset_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_dataset_service.py @@ -29,6 +29,7 @@ from google.api_core import gapic_v1 from google.api_core import grpc_helpers from google.api_core import grpc_helpers_async +from google.api_core import operation from google.api_core import operation_async # type: ignore from google.api_core import operations_v1 from google.api_core import path_template @@ -263,20 +264,20 @@ def test_dataset_service_client_client_options( # unsupported value. 
with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): with pytest.raises(MutualTLSChannelError): - client = client_class() + client = client_class(transport=transport_name) # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. with mock.patch.dict( os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} ): with pytest.raises(ValueError): - client = client_class() + client = client_class(transport=transport_name) # Check the case quota_project_id is provided options = client_options.ClientOptions(quota_project_id="octopus") with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( credentials=None, credentials_file=None, @@ -335,7 +336,7 @@ def test_dataset_service_client_mtls_env_auto( ) with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) if use_client_cert_env == "false": expected_client_cert_source = None @@ -412,6 +413,87 @@ def test_dataset_service_client_mtls_env_auto( ) +@pytest.mark.parametrize( + "client_class", [DatasetServiceClient, DatasetServiceAsyncClient] +) +@mock.patch.object( + DatasetServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(DatasetServiceClient), +) +@mock.patch.object( + DatasetServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(DatasetServiceAsyncClient), +) +def test_dataset_service_client_get_mtls_endpoint_and_cert_source(client_class): + mock_client_cert_source = mock.Mock() + + # Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is "true". 
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + mock_api_endpoint = "foo" + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint + ) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source( + options + ) + assert api_endpoint == mock_api_endpoint + assert cert_source == mock_client_cert_source + + # Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is "false". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "false"}): + mock_client_cert_source = mock.Mock() + mock_api_endpoint = "foo" + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint + ) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source( + options + ) + assert api_endpoint == mock_api_endpoint + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "never". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "always". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert doesn't exist. 
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert exists. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=mock_client_cert_source, + ): + ( + api_endpoint, + cert_source, + ) = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + assert cert_source == mock_client_cert_source + + @pytest.mark.parametrize( "client_class,transport_class,transport_name", [ @@ -430,7 +512,7 @@ def test_dataset_service_client_client_options_scopes( options = client_options.ClientOptions(scopes=["1", "2"],) with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( credentials=None, credentials_file=None, @@ -444,24 +526,31 @@ def test_dataset_service_client_client_options_scopes( @pytest.mark.parametrize( - "client_class,transport_class,transport_name", + "client_class,transport_class,transport_name,grpc_helpers", [ - (DatasetServiceClient, transports.DatasetServiceGrpcTransport, "grpc"), + ( + DatasetServiceClient, + transports.DatasetServiceGrpcTransport, + "grpc", + grpc_helpers, + ), ( DatasetServiceAsyncClient, transports.DatasetServiceGrpcAsyncIOTransport, "grpc_asyncio", + grpc_helpers_async, ), ], ) def 
test_dataset_service_client_client_options_credentials_file( - client_class, transport_class, transport_name + client_class, transport_class, transport_name, grpc_helpers ): # Check the case credentials file is provided. options = client_options.ClientOptions(credentials_file="credentials.json") + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( credentials=None, credentials_file="credentials.json", @@ -473,6 +562,35 @@ def test_dataset_service_client_client_options_credentials_file( always_use_jwt_access=True, ) + if "grpc" in transport_name: + # test that the credentials from file are saved and used as the credentials. + with mock.patch.object( + google.auth, "load_credentials_from_file", autospec=True + ) as load_creds, mock.patch.object( + google.auth, "default", autospec=True + ) as adc, mock.patch.object( + grpc_helpers, "create_channel" + ) as create_channel: + creds = ga_credentials.AnonymousCredentials() + file_creds = ga_credentials.AnonymousCredentials() + load_creds.return_value = (file_creds, None) + adc.return_value = (creds, None) + client = client_class(client_options=options, transport=transport_name) + create_channel.assert_called_with( + "aiplatform.googleapis.com:443", + credentials=file_creds, + credentials_file=None, + quota_project_id=None, + default_scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=None, + default_host="aiplatform.googleapis.com", + ssl_credentials=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + def test_dataset_service_client_client_options_from_dict(): with mock.patch( @@ -494,9 +612,8 @@ def test_dataset_service_client_client_options_from_dict(): ) -def test_create_dataset( - transport: str = "grpc", 
request_type=dataset_service.CreateDatasetRequest -): +@pytest.mark.parametrize("request_type", [dataset_service.CreateDatasetRequest, dict,]) +def test_create_dataset(request_type, transport: str = "grpc"): client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -520,10 +637,6 @@ def test_create_dataset( assert isinstance(response, future.Future) -def test_create_dataset_from_dict(): - test_create_dataset(request_type=dict) - - def test_create_dataset_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -712,9 +825,8 @@ async def test_create_dataset_flattened_error_async(): ) -def test_get_dataset( - transport: str = "grpc", request_type=dataset_service.GetDatasetRequest -): +@pytest.mark.parametrize("request_type", [dataset_service.GetDatasetRequest, dict,]) +def test_get_dataset(request_type, transport: str = "grpc"): client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -749,10 +861,6 @@ def test_get_dataset( assert response.etag == "etag_value" -def test_get_dataset_from_dict(): - test_get_dataset(request_type=dict) - - def test_get_dataset_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -934,9 +1042,8 @@ async def test_get_dataset_flattened_error_async(): ) -def test_update_dataset( - transport: str = "grpc", request_type=dataset_service.UpdateDatasetRequest -): +@pytest.mark.parametrize("request_type", [dataset_service.UpdateDatasetRequest, dict,]) +def test_update_dataset(request_type, transport: str = "grpc"): client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -971,10 +1078,6 @@ def test_update_dataset( assert response.etag == "etag_value" -def test_update_dataset_from_dict(): - test_update_dataset(request_type=dict) - - def test_update_dataset_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -1176,9 +1279,8 @@ async def test_update_dataset_flattened_error_async(): ) -def test_list_datasets( - transport: str = "grpc", request_type=dataset_service.ListDatasetsRequest -): +@pytest.mark.parametrize("request_type", [dataset_service.ListDatasetsRequest, dict,]) +def test_list_datasets(request_type, transport: str = "grpc"): client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -1205,10 +1307,6 @@ def test_list_datasets( assert response.next_page_token == "next_page_token_value" -def test_list_datasets_from_dict(): - test_list_datasets(request_type=dict) - - def test_list_datasets_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -1386,8 +1484,10 @@ async def test_list_datasets_flattened_error_async(): ) -def test_list_datasets_pager(): - client = DatasetServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_datasets_pager(transport_name: str = "grpc"): + client = DatasetServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_datasets), "__call__") as call: @@ -1420,8 +1520,10 @@ def test_list_datasets_pager(): assert all(isinstance(i, dataset.Dataset) for i in results) -def test_list_datasets_pages(): - client = DatasetServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_datasets_pages(transport_name: str = "grpc"): + client = DatasetServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_datasets), "__call__") as call: @@ -1508,9 +1610,8 @@ async def test_list_datasets_async_pages(): assert page_.raw_page.next_page_token == token -def test_delete_dataset( - transport: str = "grpc", request_type=dataset_service.DeleteDatasetRequest -): +@pytest.mark.parametrize("request_type", [dataset_service.DeleteDatasetRequest, dict,]) +def test_delete_dataset(request_type, transport: str = "grpc"): client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -1534,10 +1635,6 @@ def test_delete_dataset( assert isinstance(response, future.Future) -def test_delete_dataset_from_dict(): - test_delete_dataset(request_type=dict) - - def test_delete_dataset_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -1712,9 +1809,8 @@ async def test_delete_dataset_flattened_error_async(): ) -def test_import_data( - transport: str = "grpc", request_type=dataset_service.ImportDataRequest -): +@pytest.mark.parametrize("request_type", [dataset_service.ImportDataRequest, dict,]) +def test_import_data(request_type, transport: str = "grpc"): client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -1738,10 +1834,6 @@ def test_import_data( assert isinstance(response, future.Future) -def test_import_data_from_dict(): - test_import_data(request_type=dict) - - def test_import_data_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -1944,9 +2036,8 @@ async def test_import_data_flattened_error_async(): ) -def test_export_data( - transport: str = "grpc", request_type=dataset_service.ExportDataRequest -): +@pytest.mark.parametrize("request_type", [dataset_service.ExportDataRequest, dict,]) +def test_export_data(request_type, transport: str = "grpc"): client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -1970,10 +2061,6 @@ def test_export_data( assert isinstance(response, future.Future) -def test_export_data_from_dict(): - test_export_data(request_type=dict) - - def test_export_data_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -2188,9 +2275,8 @@ async def test_export_data_flattened_error_async(): ) -def test_list_data_items( - transport: str = "grpc", request_type=dataset_service.ListDataItemsRequest -): +@pytest.mark.parametrize("request_type", [dataset_service.ListDataItemsRequest, dict,]) +def test_list_data_items(request_type, transport: str = "grpc"): client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -2217,10 +2303,6 @@ def test_list_data_items( assert response.next_page_token == "next_page_token_value" -def test_list_data_items_from_dict(): - test_list_data_items(request_type=dict) - - def test_list_data_items_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -2398,8 +2480,10 @@ async def test_list_data_items_flattened_error_async(): ) -def test_list_data_items_pager(): - client = DatasetServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_data_items_pager(transport_name: str = "grpc"): + client = DatasetServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_data_items), "__call__") as call: @@ -2438,8 +2522,10 @@ def test_list_data_items_pager(): assert all(isinstance(i, data_item.DataItem) for i in results) -def test_list_data_items_pages(): - client = DatasetServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_data_items_pages(transport_name: str = "grpc"): + client = DatasetServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. 
with mock.patch.object(type(client.transport.list_data_items), "__call__") as call: @@ -2544,9 +2630,10 @@ async def test_list_data_items_async_pages(): assert page_.raw_page.next_page_token == token -def test_get_annotation_spec( - transport: str = "grpc", request_type=dataset_service.GetAnnotationSpecRequest -): +@pytest.mark.parametrize( + "request_type", [dataset_service.GetAnnotationSpecRequest, dict,] +) +def test_get_annotation_spec(request_type, transport: str = "grpc"): client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -2577,10 +2664,6 @@ def test_get_annotation_spec( assert response.etag == "etag_value" -def test_get_annotation_spec_from_dict(): - test_get_annotation_spec(request_type=dict) - - def test_get_annotation_spec_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -2773,9 +2856,10 @@ async def test_get_annotation_spec_flattened_error_async(): ) -def test_list_annotations( - transport: str = "grpc", request_type=dataset_service.ListAnnotationsRequest -): +@pytest.mark.parametrize( + "request_type", [dataset_service.ListAnnotationsRequest, dict,] +) +def test_list_annotations(request_type, transport: str = "grpc"): client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -2802,10 +2886,6 @@ def test_list_annotations( assert response.next_page_token == "next_page_token_value" -def test_list_annotations_from_dict(): - test_list_annotations(request_type=dict) - - def test_list_annotations_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -2983,8 +3063,10 @@ async def test_list_annotations_flattened_error_async(): ) -def test_list_annotations_pager(): - client = DatasetServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_annotations_pager(transport_name: str = "grpc"): + client = DatasetServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_annotations), "__call__") as call: @@ -3023,8 +3105,10 @@ def test_list_annotations_pager(): assert all(isinstance(i, annotation.Annotation) for i in results) -def test_list_annotations_pages(): - client = DatasetServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_annotations_pages(transport_name: str = "grpc"): + client = DatasetServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_annotations), "__call__") as call: @@ -3149,6 +3233,23 @@ def test_credentials_transport_error(): transport=transport, ) + # It is an error to provide an api_key and a transport instance. + transport = transports.DatasetServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + options = client_options.ClientOptions() + options.api_key = "api_key" + with pytest.raises(ValueError): + client = DatasetServiceClient(client_options=options, transport=transport,) + + # It is an error to provide an api_key and a credential. + options = mock.Mock() + options.api_key = "api_key" + with pytest.raises(ValueError): + client = DatasetServiceClient( + client_options=options, credentials=ga_credentials.AnonymousCredentials() + ) + # It is an error to provide scopes and a transport instance. 
transport = transports.DatasetServiceGrpcTransport( credentials=ga_credentials.AnonymousCredentials(), @@ -3774,7 +3875,7 @@ def test_parse_common_location_path(): assert expected == actual -def test_client_withDEFAULT_CLIENT_INFO(): +def test_client_with_default_client_info(): client_info = gapic_v1.client_info.ClientInfo() with mock.patch.object( @@ -3839,3 +3940,33 @@ def test_client_ctx(): with client: pass close.assert_called() + + +@pytest.mark.parametrize( + "client_class,transport_class", + [ + (DatasetServiceClient, transports.DatasetServiceGrpcTransport), + (DatasetServiceAsyncClient, transports.DatasetServiceGrpcAsyncIOTransport), + ], +) +def test_api_key_credentials(client_class, transport_class): + with mock.patch.object( + google.auth._default, "get_api_key_credentials", create=True + ) as get_api_key_credentials: + mock_cred = mock.Mock() + get_api_key_credentials.return_value = mock_cred + options = client_options.ClientOptions() + options.api_key = "api_key" + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=mock_cred, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + ) diff --git a/tests/unit/gapic/aiplatform_v1/test_endpoint_service.py b/tests/unit/gapic/aiplatform_v1/test_endpoint_service.py index 3b92c2af95c..a09ea225079 100644 --- a/tests/unit/gapic/aiplatform_v1/test_endpoint_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_endpoint_service.py @@ -29,6 +29,7 @@ from google.api_core import gapic_v1 from google.api_core import grpc_helpers from google.api_core import grpc_helpers_async +from google.api_core import operation from google.api_core import operation_async # type: ignore from google.api_core import operations_v1 from 
google.api_core import path_template @@ -263,20 +264,20 @@ def test_endpoint_service_client_client_options( # unsupported value. with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): with pytest.raises(MutualTLSChannelError): - client = client_class() + client = client_class(transport=transport_name) # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. with mock.patch.dict( os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} ): with pytest.raises(ValueError): - client = client_class() + client = client_class(transport=transport_name) # Check the case quota_project_id is provided options = client_options.ClientOptions(quota_project_id="octopus") with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( credentials=None, credentials_file=None, @@ -345,7 +346,7 @@ def test_endpoint_service_client_mtls_env_auto( ) with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) if use_client_cert_env == "false": expected_client_cert_source = None @@ -422,6 +423,87 @@ def test_endpoint_service_client_mtls_env_auto( ) +@pytest.mark.parametrize( + "client_class", [EndpointServiceClient, EndpointServiceAsyncClient] +) +@mock.patch.object( + EndpointServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(EndpointServiceClient), +) +@mock.patch.object( + EndpointServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(EndpointServiceAsyncClient), +) +def test_endpoint_service_client_get_mtls_endpoint_and_cert_source(client_class): + mock_client_cert_source = mock.Mock() + + # Test the case 
GOOGLE_API_USE_CLIENT_CERTIFICATE is "true". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + mock_api_endpoint = "foo" + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint + ) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source( + options + ) + assert api_endpoint == mock_api_endpoint + assert cert_source == mock_client_cert_source + + # Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is "false". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "false"}): + mock_client_cert_source = mock.Mock() + mock_api_endpoint = "foo" + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint + ) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source( + options + ) + assert api_endpoint == mock_api_endpoint + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "never". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "always". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert doesn't exist. 
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert exists. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=mock_client_cert_source, + ): + ( + api_endpoint, + cert_source, + ) = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + assert cert_source == mock_client_cert_source + + @pytest.mark.parametrize( "client_class,transport_class,transport_name", [ @@ -440,7 +522,7 @@ def test_endpoint_service_client_client_options_scopes( options = client_options.ClientOptions(scopes=["1", "2"],) with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( credentials=None, credentials_file=None, @@ -454,24 +536,31 @@ def test_endpoint_service_client_client_options_scopes( @pytest.mark.parametrize( - "client_class,transport_class,transport_name", + "client_class,transport_class,transport_name,grpc_helpers", [ - (EndpointServiceClient, transports.EndpointServiceGrpcTransport, "grpc"), + ( + EndpointServiceClient, + transports.EndpointServiceGrpcTransport, + "grpc", + grpc_helpers, + ), ( EndpointServiceAsyncClient, transports.EndpointServiceGrpcAsyncIOTransport, "grpc_asyncio", + grpc_helpers_async, ), ], ) 
def test_endpoint_service_client_client_options_credentials_file( - client_class, transport_class, transport_name + client_class, transport_class, transport_name, grpc_helpers ): # Check the case credentials file is provided. options = client_options.ClientOptions(credentials_file="credentials.json") + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( credentials=None, credentials_file="credentials.json", @@ -483,6 +572,35 @@ def test_endpoint_service_client_client_options_credentials_file( always_use_jwt_access=True, ) + if "grpc" in transport_name: + # test that the credentials from file are saved and used as the credentials. + with mock.patch.object( + google.auth, "load_credentials_from_file", autospec=True + ) as load_creds, mock.patch.object( + google.auth, "default", autospec=True + ) as adc, mock.patch.object( + grpc_helpers, "create_channel" + ) as create_channel: + creds = ga_credentials.AnonymousCredentials() + file_creds = ga_credentials.AnonymousCredentials() + load_creds.return_value = (file_creds, None) + adc.return_value = (creds, None) + client = client_class(client_options=options, transport=transport_name) + create_channel.assert_called_with( + "aiplatform.googleapis.com:443", + credentials=file_creds, + credentials_file=None, + quota_project_id=None, + default_scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=None, + default_host="aiplatform.googleapis.com", + ssl_credentials=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + def test_endpoint_service_client_client_options_from_dict(): with mock.patch( @@ -504,9 +622,10 @@ def test_endpoint_service_client_client_options_from_dict(): ) -def test_create_endpoint( - transport: str = "grpc", 
request_type=endpoint_service.CreateEndpointRequest -): +@pytest.mark.parametrize( + "request_type", [endpoint_service.CreateEndpointRequest, dict,] +) +def test_create_endpoint(request_type, transport: str = "grpc"): client = EndpointServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -530,10 +649,6 @@ def test_create_endpoint( assert isinstance(response, future.Future) -def test_create_endpoint_from_dict(): - test_create_endpoint(request_type=dict) - - def test_create_endpoint_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -734,9 +849,8 @@ async def test_create_endpoint_flattened_error_async(): ) -def test_get_endpoint( - transport: str = "grpc", request_type=endpoint_service.GetEndpointRequest -): +@pytest.mark.parametrize("request_type", [endpoint_service.GetEndpointRequest, dict,]) +def test_get_endpoint(request_type, transport: str = "grpc"): client = EndpointServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -778,10 +892,6 @@ def test_get_endpoint( ) -def test_get_endpoint_from_dict(): - test_get_endpoint(request_type=dict) - - def test_get_endpoint_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -970,9 +1080,8 @@ async def test_get_endpoint_flattened_error_async(): ) -def test_list_endpoints( - transport: str = "grpc", request_type=endpoint_service.ListEndpointsRequest -): +@pytest.mark.parametrize("request_type", [endpoint_service.ListEndpointsRequest, dict,]) +def test_list_endpoints(request_type, transport: str = "grpc"): client = EndpointServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -999,10 +1108,6 @@ def test_list_endpoints( assert response.next_page_token == "next_page_token_value" -def test_list_endpoints_from_dict(): - test_list_endpoints(request_type=dict) - - def test_list_endpoints_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -1180,8 +1285,10 @@ async def test_list_endpoints_flattened_error_async(): ) -def test_list_endpoints_pager(): - client = EndpointServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_endpoints_pager(transport_name: str = "grpc"): + client = EndpointServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_endpoints), "__call__") as call: @@ -1220,8 +1327,10 @@ def test_list_endpoints_pager(): assert all(isinstance(i, endpoint.Endpoint) for i in results) -def test_list_endpoints_pages(): - client = EndpointServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_endpoints_pages(transport_name: str = "grpc"): + client = EndpointServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. 
with mock.patch.object(type(client.transport.list_endpoints), "__call__") as call: @@ -1330,9 +1439,10 @@ async def test_list_endpoints_async_pages(): assert page_.raw_page.next_page_token == token -def test_update_endpoint( - transport: str = "grpc", request_type=endpoint_service.UpdateEndpointRequest -): +@pytest.mark.parametrize( + "request_type", [endpoint_service.UpdateEndpointRequest, dict,] +) +def test_update_endpoint(request_type, transport: str = "grpc"): client = EndpointServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -1374,10 +1484,6 @@ def test_update_endpoint( ) -def test_update_endpoint_from_dict(): - test_update_endpoint(request_type=dict) - - def test_update_endpoint_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -1590,9 +1696,10 @@ async def test_update_endpoint_flattened_error_async(): ) -def test_delete_endpoint( - transport: str = "grpc", request_type=endpoint_service.DeleteEndpointRequest -): +@pytest.mark.parametrize( + "request_type", [endpoint_service.DeleteEndpointRequest, dict,] +) +def test_delete_endpoint(request_type, transport: str = "grpc"): client = EndpointServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -1616,10 +1723,6 @@ def test_delete_endpoint( assert isinstance(response, future.Future) -def test_delete_endpoint_from_dict(): - test_delete_endpoint(request_type=dict) - - def test_delete_endpoint_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -1794,9 +1897,8 @@ async def test_delete_endpoint_flattened_error_async(): ) -def test_deploy_model( - transport: str = "grpc", request_type=endpoint_service.DeployModelRequest -): +@pytest.mark.parametrize("request_type", [endpoint_service.DeployModelRequest, dict,]) +def test_deploy_model(request_type, transport: str = "grpc"): client = EndpointServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -1820,10 +1922,6 @@ def test_deploy_model( assert isinstance(response, future.Future) -def test_deploy_model_from_dict(): - test_deploy_model(request_type=dict) - - def test_deploy_model_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -2060,9 +2158,8 @@ async def test_deploy_model_flattened_error_async(): ) -def test_undeploy_model( - transport: str = "grpc", request_type=endpoint_service.UndeployModelRequest -): +@pytest.mark.parametrize("request_type", [endpoint_service.UndeployModelRequest, dict,]) +def test_undeploy_model(request_type, transport: str = "grpc"): client = EndpointServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -2086,10 +2183,6 @@ def test_undeploy_model( assert isinstance(response, future.Future) -def test_undeploy_model_from_dict(): - test_undeploy_model(request_type=dict) - - def test_undeploy_model_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -2310,6 +2403,23 @@ def test_credentials_transport_error(): transport=transport, ) + # It is an error to provide an api_key and a transport instance. 
+ transport = transports.EndpointServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + options = client_options.ClientOptions() + options.api_key = "api_key" + with pytest.raises(ValueError): + client = EndpointServiceClient(client_options=options, transport=transport,) + + # It is an error to provide an api_key and a credential. + options = mock.Mock() + options.api_key = "api_key" + with pytest.raises(ValueError): + client = EndpointServiceClient( + client_options=options, credentials=ga_credentials.AnonymousCredentials() + ) + # It is an error to provide scopes and a transport instance. transport = transports.EndpointServiceGrpcTransport( credentials=ga_credentials.AnonymousCredentials(), @@ -2915,7 +3025,7 @@ def test_parse_common_location_path(): assert expected == actual -def test_client_withDEFAULT_CLIENT_INFO(): +def test_client_with_default_client_info(): client_info = gapic_v1.client_info.ClientInfo() with mock.patch.object( @@ -2980,3 +3090,33 @@ def test_client_ctx(): with client: pass close.assert_called() + + +@pytest.mark.parametrize( + "client_class,transport_class", + [ + (EndpointServiceClient, transports.EndpointServiceGrpcTransport), + (EndpointServiceAsyncClient, transports.EndpointServiceGrpcAsyncIOTransport), + ], +) +def test_api_key_credentials(client_class, transport_class): + with mock.patch.object( + google.auth._default, "get_api_key_credentials", create=True + ) as get_api_key_credentials: + mock_cred = mock.Mock() + get_api_key_credentials.return_value = mock_cred + options = client_options.ClientOptions() + options.api_key = "api_key" + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=mock_cred, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + 
client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + ) diff --git a/tests/unit/gapic/aiplatform_v1/test_featurestore_online_serving_service.py b/tests/unit/gapic/aiplatform_v1/test_featurestore_online_serving_service.py index 2a5deb42653..eb470a91df9 100644 --- a/tests/unit/gapic/aiplatform_v1/test_featurestore_online_serving_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_featurestore_online_serving_service.py @@ -284,20 +284,20 @@ def test_featurestore_online_serving_service_client_client_options( # unsupported value. with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): with pytest.raises(MutualTLSChannelError): - client = client_class() + client = client_class(transport=transport_name) # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. with mock.patch.dict( os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} ): with pytest.raises(ValueError): - client = client_class() + client = client_class(transport=transport_name) # Check the case quota_project_id is provided options = client_options.ClientOptions(quota_project_id="octopus") with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( credentials=None, credentials_file=None, @@ -366,7 +366,7 @@ def test_featurestore_online_serving_service_client_mtls_env_auto( ) with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) if use_client_cert_env == "false": expected_client_cert_source = None @@ -443,6 +443,93 @@ def test_featurestore_online_serving_service_client_mtls_env_auto( ) +@pytest.mark.parametrize( + "client_class", + [ + 
FeaturestoreOnlineServingServiceClient, + FeaturestoreOnlineServingServiceAsyncClient, + ], +) +@mock.patch.object( + FeaturestoreOnlineServingServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(FeaturestoreOnlineServingServiceClient), +) +@mock.patch.object( + FeaturestoreOnlineServingServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(FeaturestoreOnlineServingServiceAsyncClient), +) +def test_featurestore_online_serving_service_client_get_mtls_endpoint_and_cert_source( + client_class, +): + mock_client_cert_source = mock.Mock() + + # Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is "true". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + mock_api_endpoint = "foo" + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint + ) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source( + options + ) + assert api_endpoint == mock_api_endpoint + assert cert_source == mock_client_cert_source + + # Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is "false". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "false"}): + mock_client_cert_source = mock.Mock() + mock_api_endpoint = "foo" + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint + ) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source( + options + ) + assert api_endpoint == mock_api_endpoint + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "never". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "always". 
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert doesn't exist. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert exists. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=mock_client_cert_source, + ): + ( + api_endpoint, + cert_source, + ) = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + assert cert_source == mock_client_cert_source + + @pytest.mark.parametrize( "client_class,transport_class,transport_name", [ @@ -465,7 +552,7 @@ def test_featurestore_online_serving_service_client_client_options_scopes( options = client_options.ClientOptions(scopes=["1", "2"],) with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( credentials=None, credentials_file=None, @@ -479,28 +566,31 @@ def test_featurestore_online_serving_service_client_client_options_scopes( @pytest.mark.parametrize( - 
"client_class,transport_class,transport_name", + "client_class,transport_class,transport_name,grpc_helpers", [ ( FeaturestoreOnlineServingServiceClient, transports.FeaturestoreOnlineServingServiceGrpcTransport, "grpc", + grpc_helpers, ), ( FeaturestoreOnlineServingServiceAsyncClient, transports.FeaturestoreOnlineServingServiceGrpcAsyncIOTransport, "grpc_asyncio", + grpc_helpers_async, ), ], ) def test_featurestore_online_serving_service_client_client_options_credentials_file( - client_class, transport_class, transport_name + client_class, transport_class, transport_name, grpc_helpers ): # Check the case credentials file is provided. options = client_options.ClientOptions(credentials_file="credentials.json") + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( credentials=None, credentials_file="credentials.json", @@ -512,6 +602,35 @@ def test_featurestore_online_serving_service_client_client_options_credentials_f always_use_jwt_access=True, ) + if "grpc" in transport_name: + # test that the credentials from file are saved and used as the credentials. 
+ with mock.patch.object( + google.auth, "load_credentials_from_file", autospec=True + ) as load_creds, mock.patch.object( + google.auth, "default", autospec=True + ) as adc, mock.patch.object( + grpc_helpers, "create_channel" + ) as create_channel: + creds = ga_credentials.AnonymousCredentials() + file_creds = ga_credentials.AnonymousCredentials() + load_creds.return_value = (file_creds, None) + adc.return_value = (creds, None) + client = client_class(client_options=options, transport=transport_name) + create_channel.assert_called_with( + "aiplatform.googleapis.com:443", + credentials=file_creds, + credentials_file=None, + quota_project_id=None, + default_scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=None, + default_host="aiplatform.googleapis.com", + ssl_credentials=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + def test_featurestore_online_serving_service_client_client_options_from_dict(): with mock.patch( @@ -533,10 +652,10 @@ def test_featurestore_online_serving_service_client_client_options_from_dict(): ) -def test_read_feature_values( - transport: str = "grpc", - request_type=featurestore_online_service.ReadFeatureValuesRequest, -): +@pytest.mark.parametrize( + "request_type", [featurestore_online_service.ReadFeatureValuesRequest, dict,] +) +def test_read_feature_values(request_type, transport: str = "grpc"): client = FeaturestoreOnlineServingServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -562,10 +681,6 @@ def test_read_feature_values( assert isinstance(response, featurestore_online_service.ReadFeatureValuesResponse) -def test_read_feature_values_from_dict(): - test_read_feature_values(request_type=dict) - - def test_read_feature_values_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -761,10 +876,11 @@ async def test_read_feature_values_flattened_error_async(): ) -def test_streaming_read_feature_values( - transport: str = "grpc", - request_type=featurestore_online_service.StreamingReadFeatureValuesRequest, -): +@pytest.mark.parametrize( + "request_type", + [featurestore_online_service.StreamingReadFeatureValuesRequest, dict,], +) +def test_streaming_read_feature_values(request_type, transport: str = "grpc"): client = FeaturestoreOnlineServingServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -797,10 +913,6 @@ def test_streaming_read_feature_values( ) -def test_streaming_read_feature_values_from_dict(): - test_streaming_read_feature_values(request_type=dict) - - def test_streaming_read_feature_values_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -1029,6 +1141,25 @@ def test_credentials_transport_error(): transport=transport, ) + # It is an error to provide an api_key and a transport instance. + transport = transports.FeaturestoreOnlineServingServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + options = client_options.ClientOptions() + options.api_key = "api_key" + with pytest.raises(ValueError): + client = FeaturestoreOnlineServingServiceClient( + client_options=options, transport=transport, + ) + + # It is an error to provide an api_key and a credential. + options = mock.Mock() + options.api_key = "api_key" + with pytest.raises(ValueError): + client = FeaturestoreOnlineServingServiceClient( + client_options=options, credentials=ga_credentials.AnonymousCredentials() + ) + # It is an error to provide scopes and a transport instance. 
transport = transports.FeaturestoreOnlineServingServiceGrpcTransport( credentials=ga_credentials.AnonymousCredentials(), @@ -1554,7 +1685,7 @@ def test_parse_common_location_path(): assert expected == actual -def test_client_withDEFAULT_CLIENT_INFO(): +def test_client_with_default_client_info(): client_info = gapic_v1.client_info.ClientInfo() with mock.patch.object( @@ -1619,3 +1750,39 @@ def test_client_ctx(): with client: pass close.assert_called() + + +@pytest.mark.parametrize( + "client_class,transport_class", + [ + ( + FeaturestoreOnlineServingServiceClient, + transports.FeaturestoreOnlineServingServiceGrpcTransport, + ), + ( + FeaturestoreOnlineServingServiceAsyncClient, + transports.FeaturestoreOnlineServingServiceGrpcAsyncIOTransport, + ), + ], +) +def test_api_key_credentials(client_class, transport_class): + with mock.patch.object( + google.auth._default, "get_api_key_credentials", create=True + ) as get_api_key_credentials: + mock_cred = mock.Mock() + get_api_key_credentials.return_value = mock_cred + options = client_options.ClientOptions() + options.api_key = "api_key" + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=mock_cred, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + ) diff --git a/tests/unit/gapic/aiplatform_v1/test_featurestore_service.py b/tests/unit/gapic/aiplatform_v1/test_featurestore_service.py index 90402f0bcd1..85357a09ab7 100644 --- a/tests/unit/gapic/aiplatform_v1/test_featurestore_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_featurestore_service.py @@ -29,6 +29,7 @@ from google.api_core import gapic_v1 from google.api_core import grpc_helpers from google.api_core import grpc_helpers_async +from google.api_core import 
operation from google.api_core import operation_async # type: ignore from google.api_core import operations_v1 from google.api_core import path_template @@ -271,20 +272,20 @@ def test_featurestore_service_client_client_options( # unsupported value. with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): with pytest.raises(MutualTLSChannelError): - client = client_class() + client = client_class(transport=transport_name) # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. with mock.patch.dict( os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} ): with pytest.raises(ValueError): - client = client_class() + client = client_class(transport=transport_name) # Check the case quota_project_id is provided options = client_options.ClientOptions(quota_project_id="octopus") with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( credentials=None, credentials_file=None, @@ -353,7 +354,7 @@ def test_featurestore_service_client_mtls_env_auto( ) with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) if use_client_cert_env == "false": expected_client_cert_source = None @@ -430,6 +431,87 @@ def test_featurestore_service_client_mtls_env_auto( ) +@pytest.mark.parametrize( + "client_class", [FeaturestoreServiceClient, FeaturestoreServiceAsyncClient] +) +@mock.patch.object( + FeaturestoreServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(FeaturestoreServiceClient), +) +@mock.patch.object( + FeaturestoreServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(FeaturestoreServiceAsyncClient), +) +def 
test_featurestore_service_client_get_mtls_endpoint_and_cert_source(client_class): + mock_client_cert_source = mock.Mock() + + # Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is "true". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + mock_api_endpoint = "foo" + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint + ) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source( + options + ) + assert api_endpoint == mock_api_endpoint + assert cert_source == mock_client_cert_source + + # Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is "false". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "false"}): + mock_client_cert_source = mock.Mock() + mock_api_endpoint = "foo" + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint + ) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source( + options + ) + assert api_endpoint == mock_api_endpoint + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "never". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "always". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert doesn't exist. 
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert exists. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=mock_client_cert_source, + ): + ( + api_endpoint, + cert_source, + ) = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + assert cert_source == mock_client_cert_source + + @pytest.mark.parametrize( "client_class,transport_class,transport_name", [ @@ -452,7 +534,7 @@ def test_featurestore_service_client_client_options_scopes( options = client_options.ClientOptions(scopes=["1", "2"],) with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( credentials=None, credentials_file=None, @@ -466,28 +548,31 @@ def test_featurestore_service_client_client_options_scopes( @pytest.mark.parametrize( - "client_class,transport_class,transport_name", + "client_class,transport_class,transport_name,grpc_helpers", [ ( FeaturestoreServiceClient, transports.FeaturestoreServiceGrpcTransport, "grpc", + grpc_helpers, ), ( FeaturestoreServiceAsyncClient, transports.FeaturestoreServiceGrpcAsyncIOTransport, "grpc_asyncio", + grpc_helpers_async, ), ], ) def 
test_featurestore_service_client_client_options_credentials_file( - client_class, transport_class, transport_name + client_class, transport_class, transport_name, grpc_helpers ): # Check the case credentials file is provided. options = client_options.ClientOptions(credentials_file="credentials.json") + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( credentials=None, credentials_file="credentials.json", @@ -499,6 +584,35 @@ def test_featurestore_service_client_client_options_credentials_file( always_use_jwt_access=True, ) + if "grpc" in transport_name: + # test that the credentials from file are saved and used as the credentials. + with mock.patch.object( + google.auth, "load_credentials_from_file", autospec=True + ) as load_creds, mock.patch.object( + google.auth, "default", autospec=True + ) as adc, mock.patch.object( + grpc_helpers, "create_channel" + ) as create_channel: + creds = ga_credentials.AnonymousCredentials() + file_creds = ga_credentials.AnonymousCredentials() + load_creds.return_value = (file_creds, None) + adc.return_value = (creds, None) + client = client_class(client_options=options, transport=transport_name) + create_channel.assert_called_with( + "aiplatform.googleapis.com:443", + credentials=file_creds, + credentials_file=None, + quota_project_id=None, + default_scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=None, + default_host="aiplatform.googleapis.com", + ssl_credentials=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + def test_featurestore_service_client_client_options_from_dict(): with mock.patch( @@ -520,9 +634,10 @@ def test_featurestore_service_client_client_options_from_dict(): ) -def test_create_featurestore( - transport: str = 
"grpc", request_type=featurestore_service.CreateFeaturestoreRequest -): +@pytest.mark.parametrize( + "request_type", [featurestore_service.CreateFeaturestoreRequest, dict,] +) +def test_create_featurestore(request_type, transport: str = "grpc"): client = FeaturestoreServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -548,10 +663,6 @@ def test_create_featurestore( assert isinstance(response, future.Future) -def test_create_featurestore_from_dict(): - test_create_featurestore(request_type=dict) - - def test_create_featurestore_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -771,9 +882,10 @@ async def test_create_featurestore_flattened_error_async(): ) -def test_get_featurestore( - transport: str = "grpc", request_type=featurestore_service.GetFeaturestoreRequest -): +@pytest.mark.parametrize( + "request_type", [featurestore_service.GetFeaturestoreRequest, dict,] +) +def test_get_featurestore(request_type, transport: str = "grpc"): client = FeaturestoreServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -804,10 +916,6 @@ def test_get_featurestore( assert response.state == featurestore.Featurestore.State.STABLE -def test_get_featurestore_from_dict(): - test_get_featurestore(request_type=dict) - - def test_get_featurestore_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -996,9 +1104,10 @@ async def test_get_featurestore_flattened_error_async(): ) -def test_list_featurestores( - transport: str = "grpc", request_type=featurestore_service.ListFeaturestoresRequest -): +@pytest.mark.parametrize( + "request_type", [featurestore_service.ListFeaturestoresRequest, dict,] +) +def test_list_featurestores(request_type, transport: str = "grpc"): client = FeaturestoreServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -1027,10 +1136,6 @@ def test_list_featurestores( assert response.next_page_token == "next_page_token_value" -def test_list_featurestores_from_dict(): - test_list_featurestores(request_type=dict) - - def test_list_featurestores_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -1227,8 +1332,10 @@ async def test_list_featurestores_flattened_error_async(): ) -def test_list_featurestores_pager(): - client = FeaturestoreServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_featurestores_pager(transport_name: str = "grpc"): + client = FeaturestoreServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( @@ -1272,8 +1379,10 @@ def test_list_featurestores_pager(): assert all(isinstance(i, featurestore.Featurestore) for i in results) -def test_list_featurestores_pages(): - client = FeaturestoreServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_featurestores_pages(transport_name: str = "grpc"): + client = FeaturestoreServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. 
with mock.patch.object( @@ -1397,9 +1506,10 @@ async def test_list_featurestores_async_pages(): assert page_.raw_page.next_page_token == token -def test_update_featurestore( - transport: str = "grpc", request_type=featurestore_service.UpdateFeaturestoreRequest -): +@pytest.mark.parametrize( + "request_type", [featurestore_service.UpdateFeaturestoreRequest, dict,] +) +def test_update_featurestore(request_type, transport: str = "grpc"): client = FeaturestoreServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -1425,10 +1535,6 @@ def test_update_featurestore( assert isinstance(response, future.Future) -def test_update_featurestore_from_dict(): - test_update_featurestore(request_type=dict) - - def test_update_featurestore_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -1644,9 +1750,10 @@ async def test_update_featurestore_flattened_error_async(): ) -def test_delete_featurestore( - transport: str = "grpc", request_type=featurestore_service.DeleteFeaturestoreRequest -): +@pytest.mark.parametrize( + "request_type", [featurestore_service.DeleteFeaturestoreRequest, dict,] +) +def test_delete_featurestore(request_type, transport: str = "grpc"): client = FeaturestoreServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -1672,10 +1779,6 @@ def test_delete_featurestore( assert isinstance(response, future.Future) -def test_delete_featurestore_from_dict(): - test_delete_featurestore(request_type=dict) - - def test_delete_featurestore_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -1881,9 +1984,10 @@ async def test_delete_featurestore_flattened_error_async(): ) -def test_create_entity_type( - transport: str = "grpc", request_type=featurestore_service.CreateEntityTypeRequest -): +@pytest.mark.parametrize( + "request_type", [featurestore_service.CreateEntityTypeRequest, dict,] +) +def test_create_entity_type(request_type, transport: str = "grpc"): client = FeaturestoreServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -1909,10 +2013,6 @@ def test_create_entity_type( assert isinstance(response, future.Future) -def test_create_entity_type_from_dict(): - test_create_entity_type(request_type=dict) - - def test_create_entity_type_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -2132,9 +2232,10 @@ async def test_create_entity_type_flattened_error_async(): ) -def test_get_entity_type( - transport: str = "grpc", request_type=featurestore_service.GetEntityTypeRequest -): +@pytest.mark.parametrize( + "request_type", [featurestore_service.GetEntityTypeRequest, dict,] +) +def test_get_entity_type(request_type, transport: str = "grpc"): client = FeaturestoreServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -2163,10 +2264,6 @@ def test_get_entity_type( assert response.etag == "etag_value" -def test_get_entity_type_from_dict(): - test_get_entity_type(request_type=dict) - - def test_get_entity_type_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -2353,9 +2450,10 @@ async def test_get_entity_type_flattened_error_async(): ) -def test_list_entity_types( - transport: str = "grpc", request_type=featurestore_service.ListEntityTypesRequest -): +@pytest.mark.parametrize( + "request_type", [featurestore_service.ListEntityTypesRequest, dict,] +) +def test_list_entity_types(request_type, transport: str = "grpc"): client = FeaturestoreServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -2384,10 +2482,6 @@ def test_list_entity_types( assert response.next_page_token == "next_page_token_value" -def test_list_entity_types_from_dict(): - test_list_entity_types(request_type=dict) - - def test_list_entity_types_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -2584,8 +2678,10 @@ async def test_list_entity_types_flattened_error_async(): ) -def test_list_entity_types_pager(): - client = FeaturestoreServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_entity_types_pager(transport_name: str = "grpc"): + client = FeaturestoreServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( @@ -2626,8 +2722,10 @@ def test_list_entity_types_pager(): assert all(isinstance(i, entity_type.EntityType) for i in results) -def test_list_entity_types_pages(): - client = FeaturestoreServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_entity_types_pages(transport_name: str = "grpc"): + client = FeaturestoreServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. 
with mock.patch.object( @@ -2742,9 +2840,10 @@ async def test_list_entity_types_async_pages(): assert page_.raw_page.next_page_token == token -def test_update_entity_type( - transport: str = "grpc", request_type=featurestore_service.UpdateEntityTypeRequest -): +@pytest.mark.parametrize( + "request_type", [featurestore_service.UpdateEntityTypeRequest, dict,] +) +def test_update_entity_type(request_type, transport: str = "grpc"): client = FeaturestoreServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -2775,10 +2874,6 @@ def test_update_entity_type( assert response.etag == "etag_value" -def test_update_entity_type_from_dict(): - test_update_entity_type(request_type=dict) - - def test_update_entity_type_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -2997,9 +3092,10 @@ async def test_update_entity_type_flattened_error_async(): ) -def test_delete_entity_type( - transport: str = "grpc", request_type=featurestore_service.DeleteEntityTypeRequest -): +@pytest.mark.parametrize( + "request_type", [featurestore_service.DeleteEntityTypeRequest, dict,] +) +def test_delete_entity_type(request_type, transport: str = "grpc"): client = FeaturestoreServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -3025,10 +3121,6 @@ def test_delete_entity_type( assert isinstance(response, future.Future) -def test_delete_entity_type_from_dict(): - test_delete_entity_type(request_type=dict) - - def test_delete_entity_type_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -3234,9 +3326,10 @@ async def test_delete_entity_type_flattened_error_async(): ) -def test_create_feature( - transport: str = "grpc", request_type=featurestore_service.CreateFeatureRequest -): +@pytest.mark.parametrize( + "request_type", [featurestore_service.CreateFeatureRequest, dict,] +) +def test_create_feature(request_type, transport: str = "grpc"): client = FeaturestoreServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -3260,10 +3353,6 @@ def test_create_feature( assert isinstance(response, future.Future) -def test_create_feature_from_dict(): - test_create_feature(request_type=dict) - - def test_create_feature_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -3471,10 +3560,10 @@ async def test_create_feature_flattened_error_async(): ) -def test_batch_create_features( - transport: str = "grpc", - request_type=featurestore_service.BatchCreateFeaturesRequest, -): +@pytest.mark.parametrize( + "request_type", [featurestore_service.BatchCreateFeaturesRequest, dict,] +) +def test_batch_create_features(request_type, transport: str = "grpc"): client = FeaturestoreServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -3500,10 +3589,6 @@ def test_batch_create_features( assert isinstance(response, future.Future) -def test_batch_create_features_from_dict(): - test_batch_create_features(request_type=dict) - - def test_batch_create_features_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -3713,9 +3798,10 @@ async def test_batch_create_features_flattened_error_async(): ) -def test_get_feature( - transport: str = "grpc", request_type=featurestore_service.GetFeatureRequest -): +@pytest.mark.parametrize( + "request_type", [featurestore_service.GetFeatureRequest, dict,] +) +def test_get_feature(request_type, transport: str = "grpc"): client = FeaturestoreServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -3748,10 +3834,6 @@ def test_get_feature( assert response.etag == "etag_value" -def test_get_feature_from_dict(): - test_get_feature(request_type=dict) - - def test_get_feature_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -3937,9 +4019,10 @@ async def test_get_feature_flattened_error_async(): ) -def test_list_features( - transport: str = "grpc", request_type=featurestore_service.ListFeaturesRequest -): +@pytest.mark.parametrize( + "request_type", [featurestore_service.ListFeaturesRequest, dict,] +) +def test_list_features(request_type, transport: str = "grpc"): client = FeaturestoreServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -3966,10 +4049,6 @@ def test_list_features( assert response.next_page_token == "next_page_token_value" -def test_list_features_from_dict(): - test_list_features(request_type=dict) - - def test_list_features_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -4154,8 +4233,10 @@ async def test_list_features_flattened_error_async(): ) -def test_list_features_pager(): - client = FeaturestoreServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_features_pager(transport_name: str = "grpc"): + client = FeaturestoreServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_features), "__call__") as call: @@ -4190,8 +4271,10 @@ def test_list_features_pager(): assert all(isinstance(i, feature.Feature) for i in results) -def test_list_features_pages(): - client = FeaturestoreServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_features_pages(transport_name: str = "grpc"): + client = FeaturestoreServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_features), "__call__") as call: @@ -4288,9 +4371,10 @@ async def test_list_features_async_pages(): assert page_.raw_page.next_page_token == token -def test_update_feature( - transport: str = "grpc", request_type=featurestore_service.UpdateFeatureRequest -): +@pytest.mark.parametrize( + "request_type", [featurestore_service.UpdateFeatureRequest, dict,] +) +def test_update_feature(request_type, transport: str = "grpc"): client = FeaturestoreServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -4323,10 +4407,6 @@ def test_update_feature( assert response.etag == "etag_value" -def test_update_feature_from_dict(): - test_update_feature(request_type=dict) - - def test_update_feature_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -4533,9 +4613,10 @@ async def test_update_feature_flattened_error_async(): ) -def test_delete_feature( - transport: str = "grpc", request_type=featurestore_service.DeleteFeatureRequest -): +@pytest.mark.parametrize( + "request_type", [featurestore_service.DeleteFeatureRequest, dict,] +) +def test_delete_feature(request_type, transport: str = "grpc"): client = FeaturestoreServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -4559,10 +4640,6 @@ def test_delete_feature( assert isinstance(response, future.Future) -def test_delete_feature_from_dict(): - test_delete_feature(request_type=dict) - - def test_delete_feature_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -4744,10 +4821,10 @@ async def test_delete_feature_flattened_error_async(): ) -def test_import_feature_values( - transport: str = "grpc", - request_type=featurestore_service.ImportFeatureValuesRequest, -): +@pytest.mark.parametrize( + "request_type", [featurestore_service.ImportFeatureValuesRequest, dict,] +) +def test_import_feature_values(request_type, transport: str = "grpc"): client = FeaturestoreServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -4773,10 +4850,6 @@ def test_import_feature_values( assert isinstance(response, future.Future) -def test_import_feature_values_from_dict(): - test_import_feature_values(request_type=dict) - - def test_import_feature_values_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -4972,10 +5045,10 @@ async def test_import_feature_values_flattened_error_async(): ) -def test_batch_read_feature_values( - transport: str = "grpc", - request_type=featurestore_service.BatchReadFeatureValuesRequest, -): +@pytest.mark.parametrize( + "request_type", [featurestore_service.BatchReadFeatureValuesRequest, dict,] +) +def test_batch_read_feature_values(request_type, transport: str = "grpc"): client = FeaturestoreServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -5001,10 +5074,6 @@ def test_batch_read_feature_values( assert isinstance(response, future.Future) -def test_batch_read_feature_values_from_dict(): - test_batch_read_feature_values(request_type=dict) - - def test_batch_read_feature_values_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -5206,10 +5275,10 @@ async def test_batch_read_feature_values_flattened_error_async(): ) -def test_export_feature_values( - transport: str = "grpc", - request_type=featurestore_service.ExportFeatureValuesRequest, -): +@pytest.mark.parametrize( + "request_type", [featurestore_service.ExportFeatureValuesRequest, dict,] +) +def test_export_feature_values(request_type, transport: str = "grpc"): client = FeaturestoreServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -5235,10 +5304,6 @@ def test_export_feature_values( assert isinstance(response, future.Future) -def test_export_feature_values_from_dict(): - test_export_feature_values(request_type=dict) - - def test_export_feature_values_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -5434,9 +5499,10 @@ async def test_export_feature_values_flattened_error_async(): ) -def test_search_features( - transport: str = "grpc", request_type=featurestore_service.SearchFeaturesRequest -): +@pytest.mark.parametrize( + "request_type", [featurestore_service.SearchFeaturesRequest, dict,] +) +def test_search_features(request_type, transport: str = "grpc"): client = FeaturestoreServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -5463,10 +5529,6 @@ def test_search_features( assert response.next_page_token == "next_page_token_value" -def test_search_features_from_dict(): - test_search_features(request_type=dict) - - def test_search_features_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -5665,8 +5727,10 @@ async def test_search_features_flattened_error_async(): ) -def test_search_features_pager(): - client = FeaturestoreServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_search_features_pager(transport_name: str = "grpc"): + client = FeaturestoreServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.search_features), "__call__") as call: @@ -5701,8 +5765,10 @@ def test_search_features_pager(): assert all(isinstance(i, feature.Feature) for i in results) -def test_search_features_pages(): - client = FeaturestoreServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_search_features_pages(transport_name: str = "grpc"): + client = FeaturestoreServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. 
with mock.patch.object(type(client.transport.search_features), "__call__") as call: @@ -5819,6 +5885,23 @@ def test_credentials_transport_error(): transport=transport, ) + # It is an error to provide an api_key and a transport instance. + transport = transports.FeaturestoreServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + options = client_options.ClientOptions() + options.api_key = "api_key" + with pytest.raises(ValueError): + client = FeaturestoreServiceClient(client_options=options, transport=transport,) + + # It is an error to provide an api_key and a credential. + options = mock.Mock() + options.api_key = "api_key" + with pytest.raises(ValueError): + client = FeaturestoreServiceClient( + client_options=options, credentials=ga_credentials.AnonymousCredentials() + ) + # It is an error to provide scopes and a transport instance. transport = transports.FeaturestoreServiceGrpcTransport( credentials=ga_credentials.AnonymousCredentials(), @@ -6434,7 +6517,7 @@ def test_parse_common_location_path(): assert expected == actual -def test_client_withDEFAULT_CLIENT_INFO(): +def test_client_with_default_client_info(): client_info = gapic_v1.client_info.ClientInfo() with mock.patch.object( @@ -6499,3 +6582,36 @@ def test_client_ctx(): with client: pass close.assert_called() + + +@pytest.mark.parametrize( + "client_class,transport_class", + [ + (FeaturestoreServiceClient, transports.FeaturestoreServiceGrpcTransport), + ( + FeaturestoreServiceAsyncClient, + transports.FeaturestoreServiceGrpcAsyncIOTransport, + ), + ], +) +def test_api_key_credentials(client_class, transport_class): + with mock.patch.object( + google.auth._default, "get_api_key_credentials", create=True + ) as get_api_key_credentials: + mock_cred = mock.Mock() + get_api_key_credentials.return_value = mock_cred + options = client_options.ClientOptions() + options.api_key = "api_key" + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + 
client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=mock_cred, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + ) diff --git a/tests/unit/gapic/aiplatform_v1/test_index_endpoint_service.py b/tests/unit/gapic/aiplatform_v1/test_index_endpoint_service.py index 2917a37ed9b..2bd62d992eb 100644 --- a/tests/unit/gapic/aiplatform_v1/test_index_endpoint_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_index_endpoint_service.py @@ -29,6 +29,7 @@ from google.api_core import gapic_v1 from google.api_core import grpc_helpers from google.api_core import grpc_helpers_async +from google.api_core import operation from google.api_core import operation_async # type: ignore from google.api_core import operations_v1 from google.api_core import path_template @@ -42,6 +43,7 @@ ) from google.cloud.aiplatform_v1.services.index_endpoint_service import pagers from google.cloud.aiplatform_v1.services.index_endpoint_service import transports +from google.cloud.aiplatform_v1.types import accelerator_type from google.cloud.aiplatform_v1.types import index_endpoint from google.cloud.aiplatform_v1.types import index_endpoint as gca_index_endpoint from google.cloud.aiplatform_v1.types import index_endpoint_service @@ -265,20 +267,20 @@ def test_index_endpoint_service_client_client_options( # unsupported value. with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): with pytest.raises(MutualTLSChannelError): - client = client_class() + client = client_class(transport=transport_name) # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. 
with mock.patch.dict( os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} ): with pytest.raises(ValueError): - client = client_class() + client = client_class(transport=transport_name) # Check the case quota_project_id is provided options = client_options.ClientOptions(quota_project_id="octopus") with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( credentials=None, credentials_file=None, @@ -347,7 +349,7 @@ def test_index_endpoint_service_client_mtls_env_auto( ) with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) if use_client_cert_env == "false": expected_client_cert_source = None @@ -424,6 +426,87 @@ def test_index_endpoint_service_client_mtls_env_auto( ) +@pytest.mark.parametrize( + "client_class", [IndexEndpointServiceClient, IndexEndpointServiceAsyncClient] +) +@mock.patch.object( + IndexEndpointServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(IndexEndpointServiceClient), +) +@mock.patch.object( + IndexEndpointServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(IndexEndpointServiceAsyncClient), +) +def test_index_endpoint_service_client_get_mtls_endpoint_and_cert_source(client_class): + mock_client_cert_source = mock.Mock() + + # Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is "true". 
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + mock_api_endpoint = "foo" + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint + ) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source( + options + ) + assert api_endpoint == mock_api_endpoint + assert cert_source == mock_client_cert_source + + # Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is "false". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "false"}): + mock_client_cert_source = mock.Mock() + mock_api_endpoint = "foo" + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint + ) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source( + options + ) + assert api_endpoint == mock_api_endpoint + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "never". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "always". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert doesn't exist. 
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert exists. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=mock_client_cert_source, + ): + ( + api_endpoint, + cert_source, + ) = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + assert cert_source == mock_client_cert_source + + @pytest.mark.parametrize( "client_class,transport_class,transport_name", [ @@ -446,7 +529,7 @@ def test_index_endpoint_service_client_client_options_scopes( options = client_options.ClientOptions(scopes=["1", "2"],) with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( credentials=None, credentials_file=None, @@ -460,28 +543,31 @@ def test_index_endpoint_service_client_client_options_scopes( @pytest.mark.parametrize( - "client_class,transport_class,transport_name", + "client_class,transport_class,transport_name,grpc_helpers", [ ( IndexEndpointServiceClient, transports.IndexEndpointServiceGrpcTransport, "grpc", + grpc_helpers, ), ( IndexEndpointServiceAsyncClient, transports.IndexEndpointServiceGrpcAsyncIOTransport, "grpc_asyncio", + grpc_helpers_async, ), ], ) def 
test_index_endpoint_service_client_client_options_credentials_file( - client_class, transport_class, transport_name + client_class, transport_class, transport_name, grpc_helpers ): # Check the case credentials file is provided. options = client_options.ClientOptions(credentials_file="credentials.json") + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( credentials=None, credentials_file="credentials.json", @@ -493,6 +579,35 @@ def test_index_endpoint_service_client_client_options_credentials_file( always_use_jwt_access=True, ) + if "grpc" in transport_name: + # test that the credentials from file are saved and used as the credentials. + with mock.patch.object( + google.auth, "load_credentials_from_file", autospec=True + ) as load_creds, mock.patch.object( + google.auth, "default", autospec=True + ) as adc, mock.patch.object( + grpc_helpers, "create_channel" + ) as create_channel: + creds = ga_credentials.AnonymousCredentials() + file_creds = ga_credentials.AnonymousCredentials() + load_creds.return_value = (file_creds, None) + adc.return_value = (creds, None) + client = client_class(client_options=options, transport=transport_name) + create_channel.assert_called_with( + "aiplatform.googleapis.com:443", + credentials=file_creds, + credentials_file=None, + quota_project_id=None, + default_scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=None, + default_host="aiplatform.googleapis.com", + ssl_credentials=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + def test_index_endpoint_service_client_client_options_from_dict(): with mock.patch( @@ -514,10 +629,10 @@ def test_index_endpoint_service_client_client_options_from_dict(): ) -def test_create_index_endpoint( - 
transport: str = "grpc", - request_type=index_endpoint_service.CreateIndexEndpointRequest, -): +@pytest.mark.parametrize( + "request_type", [index_endpoint_service.CreateIndexEndpointRequest, dict,] +) +def test_create_index_endpoint(request_type, transport: str = "grpc"): client = IndexEndpointServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -543,10 +658,6 @@ def test_create_index_endpoint( assert isinstance(response, future.Future) -def test_create_index_endpoint_from_dict(): - test_create_index_endpoint(request_type=dict) - - def test_create_index_endpoint_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -756,9 +867,10 @@ async def test_create_index_endpoint_flattened_error_async(): ) -def test_get_index_endpoint( - transport: str = "grpc", request_type=index_endpoint_service.GetIndexEndpointRequest -): +@pytest.mark.parametrize( + "request_type", [index_endpoint_service.GetIndexEndpointRequest, dict,] +) +def test_get_index_endpoint(request_type, transport: str = "grpc"): client = IndexEndpointServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -797,10 +909,6 @@ def test_get_index_endpoint( assert response.enable_private_service_connect is True -def test_get_index_endpoint_from_dict(): - test_get_index_endpoint(request_type=dict) - - def test_get_index_endpoint_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -1007,10 +1115,10 @@ async def test_get_index_endpoint_flattened_error_async(): ) -def test_list_index_endpoints( - transport: str = "grpc", - request_type=index_endpoint_service.ListIndexEndpointsRequest, -): +@pytest.mark.parametrize( + "request_type", [index_endpoint_service.ListIndexEndpointsRequest, dict,] +) +def test_list_index_endpoints(request_type, transport: str = "grpc"): client = IndexEndpointServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -1039,10 +1147,6 @@ def test_list_index_endpoints( assert response.next_page_token == "next_page_token_value" -def test_list_index_endpoints_from_dict(): - test_list_index_endpoints(request_type=dict) - - def test_list_index_endpoints_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -1239,9 +1343,9 @@ async def test_list_index_endpoints_flattened_error_async(): ) -def test_list_index_endpoints_pager(): +def test_list_index_endpoints_pager(transport_name: str = "grpc"): client = IndexEndpointServiceClient( - credentials=ga_credentials.AnonymousCredentials, + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, ) # Mock the actual call within the gRPC stub, and fake the request. @@ -1287,9 +1391,9 @@ def test_list_index_endpoints_pager(): assert all(isinstance(i, index_endpoint.IndexEndpoint) for i in results) -def test_list_index_endpoints_pages(): +def test_list_index_endpoints_pages(transport_name: str = "grpc"): client = IndexEndpointServiceClient( - credentials=ga_credentials.AnonymousCredentials, + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, ) # Mock the actual call within the gRPC stub, and fake the request. 
@@ -1417,10 +1521,10 @@ async def test_list_index_endpoints_async_pages(): assert page_.raw_page.next_page_token == token -def test_update_index_endpoint( - transport: str = "grpc", - request_type=index_endpoint_service.UpdateIndexEndpointRequest, -): +@pytest.mark.parametrize( + "request_type", [index_endpoint_service.UpdateIndexEndpointRequest, dict,] +) +def test_update_index_endpoint(request_type, transport: str = "grpc"): client = IndexEndpointServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -1459,10 +1563,6 @@ def test_update_index_endpoint( assert response.enable_private_service_connect is True -def test_update_index_endpoint_from_dict(): - test_update_index_endpoint(request_type=dict) - - def test_update_index_endpoint_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -1691,10 +1791,10 @@ async def test_update_index_endpoint_flattened_error_async(): ) -def test_delete_index_endpoint( - transport: str = "grpc", - request_type=index_endpoint_service.DeleteIndexEndpointRequest, -): +@pytest.mark.parametrize( + "request_type", [index_endpoint_service.DeleteIndexEndpointRequest, dict,] +) +def test_delete_index_endpoint(request_type, transport: str = "grpc"): client = IndexEndpointServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -1720,10 +1820,6 @@ def test_delete_index_endpoint( assert isinstance(response, future.Future) -def test_delete_index_endpoint_from_dict(): - test_delete_index_endpoint(request_type=dict) - - def test_delete_index_endpoint_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -1917,9 +2013,10 @@ async def test_delete_index_endpoint_flattened_error_async(): ) -def test_deploy_index( - transport: str = "grpc", request_type=index_endpoint_service.DeployIndexRequest -): +@pytest.mark.parametrize( + "request_type", [index_endpoint_service.DeployIndexRequest, dict,] +) +def test_deploy_index(request_type, transport: str = "grpc"): client = IndexEndpointServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -1943,10 +2040,6 @@ def test_deploy_index( assert isinstance(response, future.Future) -def test_deploy_index_from_dict(): - test_deploy_index(request_type=dict) - - def test_deploy_index_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -2148,9 +2241,10 @@ async def test_deploy_index_flattened_error_async(): ) -def test_undeploy_index( - transport: str = "grpc", request_type=index_endpoint_service.UndeployIndexRequest -): +@pytest.mark.parametrize( + "request_type", [index_endpoint_service.UndeployIndexRequest, dict,] +) +def test_undeploy_index(request_type, transport: str = "grpc"): client = IndexEndpointServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -2174,10 +2268,6 @@ def test_undeploy_index( assert isinstance(response, future.Future) -def test_undeploy_index_from_dict(): - test_undeploy_index(request_type=dict) - - def test_undeploy_index_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -2379,10 +2469,10 @@ async def test_undeploy_index_flattened_error_async(): ) -def test_mutate_deployed_index( - transport: str = "grpc", - request_type=index_endpoint_service.MutateDeployedIndexRequest, -): +@pytest.mark.parametrize( + "request_type", [index_endpoint_service.MutateDeployedIndexRequest, dict,] +) +def test_mutate_deployed_index(request_type, transport: str = "grpc"): client = IndexEndpointServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -2408,10 +2498,6 @@ def test_mutate_deployed_index( assert isinstance(response, future.Future) -def test_mutate_deployed_index_from_dict(): - test_mutate_deployed_index(request_type=dict) - - def test_mutate_deployed_index_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -2645,6 +2731,25 @@ def test_credentials_transport_error(): transport=transport, ) + # It is an error to provide an api_key and a transport instance. + transport = transports.IndexEndpointServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + options = client_options.ClientOptions() + options.api_key = "api_key" + with pytest.raises(ValueError): + client = IndexEndpointServiceClient( + client_options=options, transport=transport, + ) + + # It is an error to provide an api_key and a credential. + options = mock.Mock() + options.api_key = "api_key" + with pytest.raises(ValueError): + client = IndexEndpointServiceClient( + client_options=options, credentials=ga_credentials.AnonymousCredentials() + ) + # It is an error to provide scopes and a transport instance. 
transport = transports.IndexEndpointServiceGrpcTransport( credentials=ga_credentials.AnonymousCredentials(), @@ -3207,7 +3312,7 @@ def test_parse_common_location_path(): assert expected == actual -def test_client_withDEFAULT_CLIENT_INFO(): +def test_client_with_default_client_info(): client_info = gapic_v1.client_info.ClientInfo() with mock.patch.object( @@ -3272,3 +3377,36 @@ def test_client_ctx(): with client: pass close.assert_called() + + +@pytest.mark.parametrize( + "client_class,transport_class", + [ + (IndexEndpointServiceClient, transports.IndexEndpointServiceGrpcTransport), + ( + IndexEndpointServiceAsyncClient, + transports.IndexEndpointServiceGrpcAsyncIOTransport, + ), + ], +) +def test_api_key_credentials(client_class, transport_class): + with mock.patch.object( + google.auth._default, "get_api_key_credentials", create=True + ) as get_api_key_credentials: + mock_cred = mock.Mock() + get_api_key_credentials.return_value = mock_cred + options = client_options.ClientOptions() + options.api_key = "api_key" + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=mock_cred, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + ) diff --git a/tests/unit/gapic/aiplatform_v1/test_index_service.py b/tests/unit/gapic/aiplatform_v1/test_index_service.py index 56ace1aee2b..d6d0a9a017f 100644 --- a/tests/unit/gapic/aiplatform_v1/test_index_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_index_service.py @@ -29,6 +29,7 @@ from google.api_core import gapic_v1 from google.api_core import grpc_helpers from google.api_core import grpc_helpers_async +from google.api_core import operation from google.api_core import operation_async # type: ignore from google.api_core import 
operations_v1 from google.api_core import path_template @@ -248,20 +249,20 @@ def test_index_service_client_client_options( # unsupported value. with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): with pytest.raises(MutualTLSChannelError): - client = client_class() + client = client_class(transport=transport_name) # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. with mock.patch.dict( os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} ): with pytest.raises(ValueError): - client = client_class() + client = client_class(transport=transport_name) # Check the case quota_project_id is provided options = client_options.ClientOptions(quota_project_id="octopus") with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( credentials=None, credentials_file=None, @@ -318,7 +319,7 @@ def test_index_service_client_mtls_env_auto( ) with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) if use_client_cert_env == "false": expected_client_cert_source = None @@ -395,6 +396,83 @@ def test_index_service_client_mtls_env_auto( ) +@pytest.mark.parametrize("client_class", [IndexServiceClient, IndexServiceAsyncClient]) +@mock.patch.object( + IndexServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(IndexServiceClient) +) +@mock.patch.object( + IndexServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(IndexServiceAsyncClient), +) +def test_index_service_client_get_mtls_endpoint_and_cert_source(client_class): + mock_client_cert_source = mock.Mock() + + # Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is "true". 
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + mock_api_endpoint = "foo" + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint + ) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source( + options + ) + assert api_endpoint == mock_api_endpoint + assert cert_source == mock_client_cert_source + + # Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is "false". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "false"}): + mock_client_cert_source = mock.Mock() + mock_api_endpoint = "foo" + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint + ) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source( + options + ) + assert api_endpoint == mock_api_endpoint + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "never". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "always". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert doesn't exist. 
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert exists. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=mock_client_cert_source, + ): + ( + api_endpoint, + cert_source, + ) = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + assert cert_source == mock_client_cert_source + + @pytest.mark.parametrize( "client_class,transport_class,transport_name", [ @@ -413,7 +491,7 @@ def test_index_service_client_client_options_scopes( options = client_options.ClientOptions(scopes=["1", "2"],) with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( credentials=None, credentials_file=None, @@ -427,24 +505,31 @@ def test_index_service_client_client_options_scopes( @pytest.mark.parametrize( - "client_class,transport_class,transport_name", + "client_class,transport_class,transport_name,grpc_helpers", [ - (IndexServiceClient, transports.IndexServiceGrpcTransport, "grpc"), + ( + IndexServiceClient, + transports.IndexServiceGrpcTransport, + "grpc", + grpc_helpers, + ), ( IndexServiceAsyncClient, transports.IndexServiceGrpcAsyncIOTransport, "grpc_asyncio", + grpc_helpers_async, ), ], ) def 
test_index_service_client_client_options_credentials_file( - client_class, transport_class, transport_name + client_class, transport_class, transport_name, grpc_helpers ): # Check the case credentials file is provided. options = client_options.ClientOptions(credentials_file="credentials.json") + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( credentials=None, credentials_file="credentials.json", @@ -456,6 +541,35 @@ def test_index_service_client_client_options_credentials_file( always_use_jwt_access=True, ) + if "grpc" in transport_name: + # test that the credentials from file are saved and used as the credentials. + with mock.patch.object( + google.auth, "load_credentials_from_file", autospec=True + ) as load_creds, mock.patch.object( + google.auth, "default", autospec=True + ) as adc, mock.patch.object( + grpc_helpers, "create_channel" + ) as create_channel: + creds = ga_credentials.AnonymousCredentials() + file_creds = ga_credentials.AnonymousCredentials() + load_creds.return_value = (file_creds, None) + adc.return_value = (creds, None) + client = client_class(client_options=options, transport=transport_name) + create_channel.assert_called_with( + "aiplatform.googleapis.com:443", + credentials=file_creds, + credentials_file=None, + quota_project_id=None, + default_scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=None, + default_host="aiplatform.googleapis.com", + ssl_credentials=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + def test_index_service_client_client_options_from_dict(): with mock.patch( @@ -475,9 +589,8 @@ def test_index_service_client_client_options_from_dict(): ) -def test_create_index( - transport: str = "grpc", 
request_type=index_service.CreateIndexRequest -): +@pytest.mark.parametrize("request_type", [index_service.CreateIndexRequest, dict,]) +def test_create_index(request_type, transport: str = "grpc"): client = IndexServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -501,10 +614,6 @@ def test_create_index( assert isinstance(response, future.Future) -def test_create_index_from_dict(): - test_create_index(request_type=dict) - - def test_create_index_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -687,7 +796,8 @@ async def test_create_index_flattened_error_async(): ) -def test_get_index(transport: str = "grpc", request_type=index_service.GetIndexRequest): +@pytest.mark.parametrize("request_type", [index_service.GetIndexRequest, dict,]) +def test_get_index(request_type, transport: str = "grpc"): client = IndexServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -722,10 +832,6 @@ def test_get_index(transport: str = "grpc", request_type=index_service.GetIndexR assert response.etag == "etag_value" -def test_get_index_from_dict(): - test_get_index(request_type=dict) - - def test_get_index_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -901,9 +1007,8 @@ async def test_get_index_flattened_error_async(): ) -def test_list_indexes( - transport: str = "grpc", request_type=index_service.ListIndexesRequest -): +@pytest.mark.parametrize("request_type", [index_service.ListIndexesRequest, dict,]) +def test_list_indexes(request_type, transport: str = "grpc"): client = IndexServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -930,10 +1035,6 @@ def test_list_indexes( assert response.next_page_token == "next_page_token_value" -def test_list_indexes_from_dict(): - test_list_indexes(request_type=dict) - - def test_list_indexes_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -1103,8 +1204,10 @@ async def test_list_indexes_flattened_error_async(): ) -def test_list_indexes_pager(): - client = IndexServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_indexes_pager(transport_name: str = "grpc"): + client = IndexServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_indexes), "__call__") as call: @@ -1135,8 +1238,10 @@ def test_list_indexes_pager(): assert all(isinstance(i, index.Index) for i in results) -def test_list_indexes_pages(): - client = IndexServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_indexes_pages(transport_name: str = "grpc"): + client = IndexServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. 
with mock.patch.object(type(client.transport.list_indexes), "__call__") as call: @@ -1217,9 +1322,8 @@ async def test_list_indexes_async_pages(): assert page_.raw_page.next_page_token == token -def test_update_index( - transport: str = "grpc", request_type=index_service.UpdateIndexRequest -): +@pytest.mark.parametrize("request_type", [index_service.UpdateIndexRequest, dict,]) +def test_update_index(request_type, transport: str = "grpc"): client = IndexServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -1243,10 +1347,6 @@ def test_update_index( assert isinstance(response, future.Future) -def test_update_index_from_dict(): - test_update_index(request_type=dict) - - def test_update_index_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -1431,9 +1531,8 @@ async def test_update_index_flattened_error_async(): ) -def test_delete_index( - transport: str = "grpc", request_type=index_service.DeleteIndexRequest -): +@pytest.mark.parametrize("request_type", [index_service.DeleteIndexRequest, dict,]) +def test_delete_index(request_type, transport: str = "grpc"): client = IndexServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -1457,10 +1556,6 @@ def test_delete_index( assert isinstance(response, future.Future) -def test_delete_index_from_dict(): - test_delete_index(request_type=dict) - - def test_delete_index_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -1649,6 +1744,23 @@ def test_credentials_transport_error(): transport=transport, ) + # It is an error to provide an api_key and a transport instance. 
+ transport = transports.IndexServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + options = client_options.ClientOptions() + options.api_key = "api_key" + with pytest.raises(ValueError): + client = IndexServiceClient(client_options=options, transport=transport,) + + # It is an error to provide an api_key and a credential. + options = mock.Mock() + options.api_key = "api_key" + with pytest.raises(ValueError): + client = IndexServiceClient( + client_options=options, credentials=ga_credentials.AnonymousCredentials() + ) + # It is an error to provide scopes and a transport instance. transport = transports.IndexServiceGrpcTransport( credentials=ga_credentials.AnonymousCredentials(), @@ -2191,7 +2303,7 @@ def test_parse_common_location_path(): assert expected == actual -def test_client_withDEFAULT_CLIENT_INFO(): +def test_client_with_default_client_info(): client_info = gapic_v1.client_info.ClientInfo() with mock.patch.object( @@ -2256,3 +2368,33 @@ def test_client_ctx(): with client: pass close.assert_called() + + +@pytest.mark.parametrize( + "client_class,transport_class", + [ + (IndexServiceClient, transports.IndexServiceGrpcTransport), + (IndexServiceAsyncClient, transports.IndexServiceGrpcAsyncIOTransport), + ], +) +def test_api_key_credentials(client_class, transport_class): + with mock.patch.object( + google.auth._default, "get_api_key_credentials", create=True + ) as get_api_key_credentials: + mock_cred = mock.Mock() + get_api_key_credentials.return_value = mock_cred + options = client_options.ClientOptions() + options.api_key = "api_key" + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=mock_cred, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + 
always_use_jwt_access=True, + ) diff --git a/tests/unit/gapic/aiplatform_v1/test_job_service.py b/tests/unit/gapic/aiplatform_v1/test_job_service.py index 0a0221dc282..3bd3c6d8a65 100644 --- a/tests/unit/gapic/aiplatform_v1/test_job_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_job_service.py @@ -29,6 +29,7 @@ from google.api_core import gapic_v1 from google.api_core import grpc_helpers from google.api_core import grpc_helpers_async +from google.api_core import operation from google.api_core import operation_async # type: ignore from google.api_core import operations_v1 from google.api_core import path_template @@ -279,20 +280,20 @@ def test_job_service_client_client_options( # unsupported value. with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): with pytest.raises(MutualTLSChannelError): - client = client_class() + client = client_class(transport=transport_name) # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. with mock.patch.dict( os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} ): with pytest.raises(ValueError): - client = client_class() + client = client_class(transport=transport_name) # Check the case quota_project_id is provided options = client_options.ClientOptions(quota_project_id="octopus") with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( credentials=None, credentials_file=None, @@ -349,7 +350,7 @@ def test_job_service_client_mtls_env_auto( ) with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) if use_client_cert_env == "false": expected_client_cert_source = None @@ -426,6 
+427,83 @@ def test_job_service_client_mtls_env_auto( ) +@pytest.mark.parametrize("client_class", [JobServiceClient, JobServiceAsyncClient]) +@mock.patch.object( + JobServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(JobServiceClient) +) +@mock.patch.object( + JobServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(JobServiceAsyncClient), +) +def test_job_service_client_get_mtls_endpoint_and_cert_source(client_class): + mock_client_cert_source = mock.Mock() + + # Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is "true". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + mock_api_endpoint = "foo" + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint + ) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source( + options + ) + assert api_endpoint == mock_api_endpoint + assert cert_source == mock_client_cert_source + + # Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is "false". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "false"}): + mock_client_cert_source = mock.Mock() + mock_api_endpoint = "foo" + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint + ) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source( + options + ) + assert api_endpoint == mock_api_endpoint + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "never". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "always". 
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert doesn't exist. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert exists. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=mock_client_cert_source, + ): + ( + api_endpoint, + cert_source, + ) = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + assert cert_source == mock_client_cert_source + + @pytest.mark.parametrize( "client_class,transport_class,transport_name", [ @@ -444,7 +522,7 @@ def test_job_service_client_client_options_scopes( options = client_options.ClientOptions(scopes=["1", "2"],) with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( credentials=None, credentials_file=None, @@ -458,24 +536,26 @@ def test_job_service_client_client_options_scopes( @pytest.mark.parametrize( - "client_class,transport_class,transport_name", + 
"client_class,transport_class,transport_name,grpc_helpers", [ - (JobServiceClient, transports.JobServiceGrpcTransport, "grpc"), + (JobServiceClient, transports.JobServiceGrpcTransport, "grpc", grpc_helpers), ( JobServiceAsyncClient, transports.JobServiceGrpcAsyncIOTransport, "grpc_asyncio", + grpc_helpers_async, ), ], ) def test_job_service_client_client_options_credentials_file( - client_class, transport_class, transport_name + client_class, transport_class, transport_name, grpc_helpers ): # Check the case credentials file is provided. options = client_options.ClientOptions(credentials_file="credentials.json") + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( credentials=None, credentials_file="credentials.json", @@ -487,6 +567,35 @@ def test_job_service_client_client_options_credentials_file( always_use_jwt_access=True, ) + if "grpc" in transport_name: + # test that the credentials from file are saved and used as the credentials. 
+ with mock.patch.object( + google.auth, "load_credentials_from_file", autospec=True + ) as load_creds, mock.patch.object( + google.auth, "default", autospec=True + ) as adc, mock.patch.object( + grpc_helpers, "create_channel" + ) as create_channel: + creds = ga_credentials.AnonymousCredentials() + file_creds = ga_credentials.AnonymousCredentials() + load_creds.return_value = (file_creds, None) + adc.return_value = (creds, None) + client = client_class(client_options=options, transport=transport_name) + create_channel.assert_called_with( + "aiplatform.googleapis.com:443", + credentials=file_creds, + credentials_file=None, + quota_project_id=None, + default_scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=None, + default_host="aiplatform.googleapis.com", + ssl_credentials=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + def test_job_service_client_client_options_from_dict(): with mock.patch( @@ -506,9 +615,8 @@ def test_job_service_client_client_options_from_dict(): ) -def test_create_custom_job( - transport: str = "grpc", request_type=job_service.CreateCustomJobRequest -): +@pytest.mark.parametrize("request_type", [job_service.CreateCustomJobRequest, dict,]) +def test_create_custom_job(request_type, transport: str = "grpc"): client = JobServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -541,10 +649,6 @@ def test_create_custom_job( assert response.state == job_state.JobState.JOB_STATE_QUEUED -def test_create_custom_job_from_dict(): - test_create_custom_job(request_type=dict) - - def test_create_custom_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -748,9 +852,8 @@ async def test_create_custom_job_flattened_error_async(): ) -def test_get_custom_job( - transport: str = "grpc", request_type=job_service.GetCustomJobRequest -): +@pytest.mark.parametrize("request_type", [job_service.GetCustomJobRequest, dict,]) +def test_get_custom_job(request_type, transport: str = "grpc"): client = JobServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -781,10 +884,6 @@ def test_get_custom_job( assert response.state == job_state.JobState.JOB_STATE_QUEUED -def test_get_custom_job_from_dict(): - test_get_custom_job(request_type=dict) - - def test_get_custom_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -960,9 +1059,8 @@ async def test_get_custom_job_flattened_error_async(): ) -def test_list_custom_jobs( - transport: str = "grpc", request_type=job_service.ListCustomJobsRequest -): +@pytest.mark.parametrize("request_type", [job_service.ListCustomJobsRequest, dict,]) +def test_list_custom_jobs(request_type, transport: str = "grpc"): client = JobServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -989,10 +1087,6 @@ def test_list_custom_jobs( assert response.next_page_token == "next_page_token_value" -def test_list_custom_jobs_from_dict(): - test_list_custom_jobs(request_type=dict) - - def test_list_custom_jobs_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -1162,8 +1256,10 @@ async def test_list_custom_jobs_flattened_error_async(): ) -def test_list_custom_jobs_pager(): - client = JobServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_custom_jobs_pager(transport_name: str = "grpc"): + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_custom_jobs), "__call__") as call: @@ -1200,8 +1296,10 @@ def test_list_custom_jobs_pager(): assert all(isinstance(i, custom_job.CustomJob) for i in results) -def test_list_custom_jobs_pages(): - client = JobServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_custom_jobs_pages(transport_name: str = "grpc"): + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_custom_jobs), "__call__") as call: @@ -1300,9 +1398,8 @@ async def test_list_custom_jobs_async_pages(): assert page_.raw_page.next_page_token == token -def test_delete_custom_job( - transport: str = "grpc", request_type=job_service.DeleteCustomJobRequest -): +@pytest.mark.parametrize("request_type", [job_service.DeleteCustomJobRequest, dict,]) +def test_delete_custom_job(request_type, transport: str = "grpc"): client = JobServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -1328,10 +1425,6 @@ def test_delete_custom_job( assert isinstance(response, future.Future) -def test_delete_custom_job_from_dict(): - test_delete_custom_job(request_type=dict) - - def test_delete_custom_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -1512,9 +1605,8 @@ async def test_delete_custom_job_flattened_error_async(): ) -def test_cancel_custom_job( - transport: str = "grpc", request_type=job_service.CancelCustomJobRequest -): +@pytest.mark.parametrize("request_type", [job_service.CancelCustomJobRequest, dict,]) +def test_cancel_custom_job(request_type, transport: str = "grpc"): client = JobServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -1540,10 +1632,6 @@ def test_cancel_custom_job( assert response is None -def test_cancel_custom_job_from_dict(): - test_cancel_custom_job(request_type=dict) - - def test_cancel_custom_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -1718,9 +1806,10 @@ async def test_cancel_custom_job_flattened_error_async(): ) -def test_create_data_labeling_job( - transport: str = "grpc", request_type=job_service.CreateDataLabelingJobRequest -): +@pytest.mark.parametrize( + "request_type", [job_service.CreateDataLabelingJobRequest, dict,] +) +def test_create_data_labeling_job(request_type, transport: str = "grpc"): client = JobServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -1765,10 +1854,6 @@ def test_create_data_labeling_job( assert response.specialist_pools == ["specialist_pools_value"] -def test_create_data_labeling_job_from_dict(): - test_create_data_labeling_job(request_type=dict) - - def test_create_data_labeling_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -1985,9 +2070,8 @@ async def test_create_data_labeling_job_flattened_error_async(): ) -def test_get_data_labeling_job( - transport: str = "grpc", request_type=job_service.GetDataLabelingJobRequest -): +@pytest.mark.parametrize("request_type", [job_service.GetDataLabelingJobRequest, dict,]) +def test_get_data_labeling_job(request_type, transport: str = "grpc"): client = JobServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -2032,10 +2116,6 @@ def test_get_data_labeling_job( assert response.specialist_pools == ["specialist_pools_value"] -def test_get_data_labeling_job_from_dict(): - test_get_data_labeling_job(request_type=dict) - - def test_get_data_labeling_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -2235,9 +2315,10 @@ async def test_get_data_labeling_job_flattened_error_async(): ) -def test_list_data_labeling_jobs( - transport: str = "grpc", request_type=job_service.ListDataLabelingJobsRequest -): +@pytest.mark.parametrize( + "request_type", [job_service.ListDataLabelingJobsRequest, dict,] +) +def test_list_data_labeling_jobs(request_type, transport: str = "grpc"): client = JobServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -2266,10 +2347,6 @@ def test_list_data_labeling_jobs( assert response.next_page_token == "next_page_token_value" -def test_list_data_labeling_jobs_from_dict(): - test_list_data_labeling_jobs(request_type=dict) - - def test_list_data_labeling_jobs_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -2454,8 +2531,10 @@ async def test_list_data_labeling_jobs_flattened_error_async(): ) -def test_list_data_labeling_jobs_pager(): - client = JobServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_data_labeling_jobs_pager(transport_name: str = "grpc"): + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( @@ -2500,8 +2579,10 @@ def test_list_data_labeling_jobs_pager(): assert all(isinstance(i, data_labeling_job.DataLabelingJob) for i in results) -def test_list_data_labeling_jobs_pages(): - client = JobServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_data_labeling_jobs_pages(transport_name: str = "grpc"): + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( @@ -2624,9 +2705,10 @@ async def test_list_data_labeling_jobs_async_pages(): assert page_.raw_page.next_page_token == token -def test_delete_data_labeling_job( - transport: str = "grpc", request_type=job_service.DeleteDataLabelingJobRequest -): +@pytest.mark.parametrize( + "request_type", [job_service.DeleteDataLabelingJobRequest, dict,] +) +def test_delete_data_labeling_job(request_type, transport: str = "grpc"): client = JobServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -2652,10 +2734,6 @@ def test_delete_data_labeling_job( assert isinstance(response, future.Future) -def test_delete_data_labeling_job_from_dict(): - test_delete_data_labeling_job(request_type=dict) - - def test_delete_data_labeling_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -2837,9 +2915,10 @@ async def test_delete_data_labeling_job_flattened_error_async(): ) -def test_cancel_data_labeling_job( - transport: str = "grpc", request_type=job_service.CancelDataLabelingJobRequest -): +@pytest.mark.parametrize( + "request_type", [job_service.CancelDataLabelingJobRequest, dict,] +) +def test_cancel_data_labeling_job(request_type, transport: str = "grpc"): client = JobServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -2865,10 +2944,6 @@ def test_cancel_data_labeling_job( assert response is None -def test_cancel_data_labeling_job_from_dict(): - test_cancel_data_labeling_job(request_type=dict) - - def test_cancel_data_labeling_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -3044,10 +3119,10 @@ async def test_cancel_data_labeling_job_flattened_error_async(): ) -def test_create_hyperparameter_tuning_job( - transport: str = "grpc", - request_type=job_service.CreateHyperparameterTuningJobRequest, -): +@pytest.mark.parametrize( + "request_type", [job_service.CreateHyperparameterTuningJobRequest, dict,] +) +def test_create_hyperparameter_tuning_job(request_type, transport: str = "grpc"): client = JobServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -3086,10 +3161,6 @@ def test_create_hyperparameter_tuning_job( assert response.state == job_state.JobState.JOB_STATE_QUEUED -def test_create_hyperparameter_tuning_job_from_dict(): - test_create_hyperparameter_tuning_job(request_type=dict) - - def test_create_hyperparameter_tuning_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -3312,9 +3383,10 @@ async def test_create_hyperparameter_tuning_job_flattened_error_async(): ) -def test_get_hyperparameter_tuning_job( - transport: str = "grpc", request_type=job_service.GetHyperparameterTuningJobRequest -): +@pytest.mark.parametrize( + "request_type", [job_service.GetHyperparameterTuningJobRequest, dict,] +) +def test_get_hyperparameter_tuning_job(request_type, transport: str = "grpc"): client = JobServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -3353,10 +3425,6 @@ def test_get_hyperparameter_tuning_job( assert response.state == job_state.JobState.JOB_STATE_QUEUED -def test_get_hyperparameter_tuning_job_from_dict(): - test_get_hyperparameter_tuning_job(request_type=dict) - - def test_get_hyperparameter_tuning_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -3551,10 +3619,10 @@ async def test_get_hyperparameter_tuning_job_flattened_error_async(): ) -def test_list_hyperparameter_tuning_jobs( - transport: str = "grpc", - request_type=job_service.ListHyperparameterTuningJobsRequest, -): +@pytest.mark.parametrize( + "request_type", [job_service.ListHyperparameterTuningJobsRequest, dict,] +) +def test_list_hyperparameter_tuning_jobs(request_type, transport: str = "grpc"): client = JobServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -3583,10 +3651,6 @@ def test_list_hyperparameter_tuning_jobs( assert response.next_page_token == "next_page_token_value" -def test_list_hyperparameter_tuning_jobs_from_dict(): - test_list_hyperparameter_tuning_jobs(request_type=dict) - - def test_list_hyperparameter_tuning_jobs_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -3771,8 +3835,10 @@ async def test_list_hyperparameter_tuning_jobs_flattened_error_async(): ) -def test_list_hyperparameter_tuning_jobs_pager(): - client = JobServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_hyperparameter_tuning_jobs_pager(transport_name: str = "grpc"): + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( @@ -3822,8 +3888,10 @@ def test_list_hyperparameter_tuning_jobs_pager(): ) -def test_list_hyperparameter_tuning_jobs_pages(): - client = JobServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_hyperparameter_tuning_jobs_pages(transport_name: str = "grpc"): + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( @@ -3957,10 +4025,10 @@ async def test_list_hyperparameter_tuning_jobs_async_pages(): assert page_.raw_page.next_page_token == token -def test_delete_hyperparameter_tuning_job( - transport: str = "grpc", - request_type=job_service.DeleteHyperparameterTuningJobRequest, -): +@pytest.mark.parametrize( + "request_type", [job_service.DeleteHyperparameterTuningJobRequest, dict,] +) +def test_delete_hyperparameter_tuning_job(request_type, transport: str = "grpc"): client = JobServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -3986,10 +4054,6 @@ def test_delete_hyperparameter_tuning_job( assert isinstance(response, future.Future) -def test_delete_hyperparameter_tuning_job_from_dict(): - test_delete_hyperparameter_tuning_job(request_type=dict) - - def test_delete_hyperparameter_tuning_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -4171,10 +4235,10 @@ async def test_delete_hyperparameter_tuning_job_flattened_error_async(): ) -def test_cancel_hyperparameter_tuning_job( - transport: str = "grpc", - request_type=job_service.CancelHyperparameterTuningJobRequest, -): +@pytest.mark.parametrize( + "request_type", [job_service.CancelHyperparameterTuningJobRequest, dict,] +) +def test_cancel_hyperparameter_tuning_job(request_type, transport: str = "grpc"): client = JobServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -4200,10 +4264,6 @@ def test_cancel_hyperparameter_tuning_job( assert response is None -def test_cancel_hyperparameter_tuning_job_from_dict(): - test_cancel_hyperparameter_tuning_job(request_type=dict) - - def test_cancel_hyperparameter_tuning_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -4379,9 +4439,10 @@ async def test_cancel_hyperparameter_tuning_job_flattened_error_async(): ) -def test_create_batch_prediction_job( - transport: str = "grpc", request_type=job_service.CreateBatchPredictionJobRequest -): +@pytest.mark.parametrize( + "request_type", [job_service.CreateBatchPredictionJobRequest, dict,] +) +def test_create_batch_prediction_job(request_type, transport: str = "grpc"): client = JobServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -4418,10 +4479,6 @@ def test_create_batch_prediction_job( assert response.state == job_state.JobState.JOB_STATE_QUEUED -def test_create_batch_prediction_job_from_dict(): - test_create_batch_prediction_job(request_type=dict) - - def test_create_batch_prediction_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -4638,9 +4695,10 @@ async def test_create_batch_prediction_job_flattened_error_async(): ) -def test_get_batch_prediction_job( - transport: str = "grpc", request_type=job_service.GetBatchPredictionJobRequest -): +@pytest.mark.parametrize( + "request_type", [job_service.GetBatchPredictionJobRequest, dict,] +) +def test_get_batch_prediction_job(request_type, transport: str = "grpc"): client = JobServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -4677,10 +4735,6 @@ def test_get_batch_prediction_job( assert response.state == job_state.JobState.JOB_STATE_QUEUED -def test_get_batch_prediction_job_from_dict(): - test_get_batch_prediction_job(request_type=dict) - - def test_get_batch_prediction_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -4873,9 +4927,10 @@ async def test_get_batch_prediction_job_flattened_error_async(): ) -def test_list_batch_prediction_jobs( - transport: str = "grpc", request_type=job_service.ListBatchPredictionJobsRequest -): +@pytest.mark.parametrize( + "request_type", [job_service.ListBatchPredictionJobsRequest, dict,] +) +def test_list_batch_prediction_jobs(request_type, transport: str = "grpc"): client = JobServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -4904,10 +4959,6 @@ def test_list_batch_prediction_jobs( assert response.next_page_token == "next_page_token_value" -def test_list_batch_prediction_jobs_from_dict(): - test_list_batch_prediction_jobs(request_type=dict) - - def test_list_batch_prediction_jobs_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -5092,8 +5143,10 @@ async def test_list_batch_prediction_jobs_flattened_error_async(): ) -def test_list_batch_prediction_jobs_pager(): - client = JobServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_batch_prediction_jobs_pager(transport_name: str = "grpc"): + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( @@ -5140,8 +5193,10 @@ def test_list_batch_prediction_jobs_pager(): ) -def test_list_batch_prediction_jobs_pages(): - client = JobServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_batch_prediction_jobs_pages(transport_name: str = "grpc"): + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( @@ -5266,9 +5321,10 @@ async def test_list_batch_prediction_jobs_async_pages(): assert page_.raw_page.next_page_token == token -def test_delete_batch_prediction_job( - transport: str = "grpc", request_type=job_service.DeleteBatchPredictionJobRequest -): +@pytest.mark.parametrize( + "request_type", [job_service.DeleteBatchPredictionJobRequest, dict,] +) +def test_delete_batch_prediction_job(request_type, transport: str = "grpc"): client = JobServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -5294,10 +5350,6 @@ def test_delete_batch_prediction_job( assert isinstance(response, future.Future) -def test_delete_batch_prediction_job_from_dict(): - test_delete_batch_prediction_job(request_type=dict) - - def test_delete_batch_prediction_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -5479,9 +5531,10 @@ async def test_delete_batch_prediction_job_flattened_error_async(): ) -def test_cancel_batch_prediction_job( - transport: str = "grpc", request_type=job_service.CancelBatchPredictionJobRequest -): +@pytest.mark.parametrize( + "request_type", [job_service.CancelBatchPredictionJobRequest, dict,] +) +def test_cancel_batch_prediction_job(request_type, transport: str = "grpc"): client = JobServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -5507,10 +5560,6 @@ def test_cancel_batch_prediction_job( assert response is None -def test_cancel_batch_prediction_job_from_dict(): - test_cancel_batch_prediction_job(request_type=dict) - - def test_cancel_batch_prediction_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -5686,10 +5735,10 @@ async def test_cancel_batch_prediction_job_flattened_error_async(): ) -def test_create_model_deployment_monitoring_job( - transport: str = "grpc", - request_type=job_service.CreateModelDeploymentMonitoringJobRequest, -): +@pytest.mark.parametrize( + "request_type", [job_service.CreateModelDeploymentMonitoringJobRequest, dict,] +) +def test_create_model_deployment_monitoring_job(request_type, transport: str = "grpc"): client = JobServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -5737,10 +5786,6 @@ def test_create_model_deployment_monitoring_job( assert response.enable_monitoring_pipeline_logs is True -def test_create_model_deployment_monitoring_job_from_dict(): - test_create_model_deployment_monitoring_job(request_type=dict) - - def test_create_model_deployment_monitoring_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -5978,9 +6023,12 @@ async def test_create_model_deployment_monitoring_job_flattened_error_async(): ) +@pytest.mark.parametrize( + "request_type", + [job_service.SearchModelDeploymentMonitoringStatsAnomaliesRequest, dict,], +) def test_search_model_deployment_monitoring_stats_anomalies( - transport: str = "grpc", - request_type=job_service.SearchModelDeploymentMonitoringStatsAnomaliesRequest, + request_type, transport: str = "grpc" ): client = JobServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -6016,10 +6064,6 @@ def test_search_model_deployment_monitoring_stats_anomalies( assert response.next_page_token == "next_page_token_value" -def test_search_model_deployment_monitoring_stats_anomalies_from_dict(): - test_search_model_deployment_monitoring_stats_anomalies(request_type=dict) - - def test_search_model_deployment_monitoring_stats_anomalies_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -6250,8 +6294,12 @@ async def test_search_model_deployment_monitoring_stats_anomalies_flattened_erro ) -def test_search_model_deployment_monitoring_stats_anomalies_pager(): - client = JobServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_search_model_deployment_monitoring_stats_anomalies_pager( + transport_name: str = "grpc", +): + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. 
with mock.patch.object( @@ -6306,8 +6354,12 @@ def test_search_model_deployment_monitoring_stats_anomalies_pager(): ) -def test_search_model_deployment_monitoring_stats_anomalies_pages(): - client = JobServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_search_model_deployment_monitoring_stats_anomalies_pages( + transport_name: str = "grpc", +): + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( @@ -6448,10 +6500,10 @@ async def test_search_model_deployment_monitoring_stats_anomalies_async_pages(): assert page_.raw_page.next_page_token == token -def test_get_model_deployment_monitoring_job( - transport: str = "grpc", - request_type=job_service.GetModelDeploymentMonitoringJobRequest, -): +@pytest.mark.parametrize( + "request_type", [job_service.GetModelDeploymentMonitoringJobRequest, dict,] +) +def test_get_model_deployment_monitoring_job(request_type, transport: str = "grpc"): client = JobServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -6499,10 +6551,6 @@ def test_get_model_deployment_monitoring_job( assert response.enable_monitoring_pipeline_logs is True -def test_get_model_deployment_monitoring_job_from_dict(): - test_get_model_deployment_monitoring_job(request_type=dict) - - def test_get_model_deployment_monitoring_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -6712,10 +6760,10 @@ async def test_get_model_deployment_monitoring_job_flattened_error_async(): ) -def test_list_model_deployment_monitoring_jobs( - transport: str = "grpc", - request_type=job_service.ListModelDeploymentMonitoringJobsRequest, -): +@pytest.mark.parametrize( + "request_type", [job_service.ListModelDeploymentMonitoringJobsRequest, dict,] +) +def test_list_model_deployment_monitoring_jobs(request_type, transport: str = "grpc"): client = JobServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -6744,10 +6792,6 @@ def test_list_model_deployment_monitoring_jobs( assert response.next_page_token == "next_page_token_value" -def test_list_model_deployment_monitoring_jobs_from_dict(): - test_list_model_deployment_monitoring_jobs(request_type=dict) - - def test_list_model_deployment_monitoring_jobs_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -6936,8 +6980,10 @@ async def test_list_model_deployment_monitoring_jobs_flattened_error_async(): ) -def test_list_model_deployment_monitoring_jobs_pager(): - client = JobServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_model_deployment_monitoring_jobs_pager(transport_name: str = "grpc"): + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. 
with mock.patch.object( @@ -6987,8 +7033,10 @@ def test_list_model_deployment_monitoring_jobs_pager(): ) -def test_list_model_deployment_monitoring_jobs_pages(): - client = JobServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_model_deployment_monitoring_jobs_pages(transport_name: str = "grpc"): + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( @@ -7122,10 +7170,10 @@ async def test_list_model_deployment_monitoring_jobs_async_pages(): assert page_.raw_page.next_page_token == token -def test_update_model_deployment_monitoring_job( - transport: str = "grpc", - request_type=job_service.UpdateModelDeploymentMonitoringJobRequest, -): +@pytest.mark.parametrize( + "request_type", [job_service.UpdateModelDeploymentMonitoringJobRequest, dict,] +) +def test_update_model_deployment_monitoring_job(request_type, transport: str = "grpc"): client = JobServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -7151,10 +7199,6 @@ def test_update_model_deployment_monitoring_job( assert isinstance(response, future.Future) -def test_update_model_deployment_monitoring_job_from_dict(): - test_update_model_deployment_monitoring_job(request_type=dict) - - def test_update_model_deployment_monitoring_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -7374,10 +7418,10 @@ async def test_update_model_deployment_monitoring_job_flattened_error_async(): ) -def test_delete_model_deployment_monitoring_job( - transport: str = "grpc", - request_type=job_service.DeleteModelDeploymentMonitoringJobRequest, -): +@pytest.mark.parametrize( + "request_type", [job_service.DeleteModelDeploymentMonitoringJobRequest, dict,] +) +def test_delete_model_deployment_monitoring_job(request_type, transport: str = "grpc"): client = JobServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -7403,10 +7447,6 @@ def test_delete_model_deployment_monitoring_job( assert isinstance(response, future.Future) -def test_delete_model_deployment_monitoring_job_from_dict(): - test_delete_model_deployment_monitoring_job(request_type=dict) - - def test_delete_model_deployment_monitoring_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -7590,10 +7630,10 @@ async def test_delete_model_deployment_monitoring_job_flattened_error_async(): ) -def test_pause_model_deployment_monitoring_job( - transport: str = "grpc", - request_type=job_service.PauseModelDeploymentMonitoringJobRequest, -): +@pytest.mark.parametrize( + "request_type", [job_service.PauseModelDeploymentMonitoringJobRequest, dict,] +) +def test_pause_model_deployment_monitoring_job(request_type, transport: str = "grpc"): client = JobServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -7619,10 +7659,6 @@ def test_pause_model_deployment_monitoring_job( assert response is None -def test_pause_model_deployment_monitoring_job_from_dict(): - test_pause_model_deployment_monitoring_job(request_type=dict) - - def test_pause_model_deployment_monitoring_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -7800,10 +7836,10 @@ async def test_pause_model_deployment_monitoring_job_flattened_error_async(): ) -def test_resume_model_deployment_monitoring_job( - transport: str = "grpc", - request_type=job_service.ResumeModelDeploymentMonitoringJobRequest, -): +@pytest.mark.parametrize( + "request_type", [job_service.ResumeModelDeploymentMonitoringJobRequest, dict,] +) +def test_resume_model_deployment_monitoring_job(request_type, transport: str = "grpc"): client = JobServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -7829,10 +7865,6 @@ def test_resume_model_deployment_monitoring_job( assert response is None -def test_resume_model_deployment_monitoring_job_from_dict(): - test_resume_model_deployment_monitoring_job(request_type=dict) - - def test_resume_model_deployment_monitoring_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -8030,6 +8062,23 @@ def test_credentials_transport_error(): transport=transport, ) + # It is an error to provide an api_key and a transport instance. + transport = transports.JobServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + options = client_options.ClientOptions() + options.api_key = "api_key" + with pytest.raises(ValueError): + client = JobServiceClient(client_options=options, transport=transport,) + + # It is an error to provide an api_key and a credential. + options = mock.Mock() + options.api_key = "api_key" + with pytest.raises(ValueError): + client = JobServiceClient( + client_options=options, credentials=ga_credentials.AnonymousCredentials() + ) + # It is an error to provide scopes and a transport instance. 
transport = transports.JobServiceGrpcTransport( credentials=ga_credentials.AnonymousCredentials(), @@ -8817,7 +8866,7 @@ def test_parse_common_location_path(): assert expected == actual -def test_client_withDEFAULT_CLIENT_INFO(): +def test_client_with_default_client_info(): client_info = gapic_v1.client_info.ClientInfo() with mock.patch.object( @@ -8882,3 +8931,33 @@ def test_client_ctx(): with client: pass close.assert_called() + + +@pytest.mark.parametrize( + "client_class,transport_class", + [ + (JobServiceClient, transports.JobServiceGrpcTransport), + (JobServiceAsyncClient, transports.JobServiceGrpcAsyncIOTransport), + ], +) +def test_api_key_credentials(client_class, transport_class): + with mock.patch.object( + google.auth._default, "get_api_key_credentials", create=True + ) as get_api_key_credentials: + mock_cred = mock.Mock() + get_api_key_credentials.return_value = mock_cred + options = client_options.ClientOptions() + options.api_key = "api_key" + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=mock_cred, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + ) diff --git a/tests/unit/gapic/aiplatform_v1/test_metadata_service.py b/tests/unit/gapic/aiplatform_v1/test_metadata_service.py index 472a3f0949a..325acfdb02c 100644 --- a/tests/unit/gapic/aiplatform_v1/test_metadata_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_metadata_service.py @@ -29,6 +29,7 @@ from google.api_core import gapic_v1 from google.api_core import grpc_helpers from google.api_core import grpc_helpers_async +from google.api_core import operation from google.api_core import operation_async # type: ignore from google.api_core import operations_v1 from google.api_core import 
path_template @@ -269,20 +270,20 @@ def test_metadata_service_client_client_options( # unsupported value. with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): with pytest.raises(MutualTLSChannelError): - client = client_class() + client = client_class(transport=transport_name) # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. with mock.patch.dict( os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} ): with pytest.raises(ValueError): - client = client_class() + client = client_class(transport=transport_name) # Check the case quota_project_id is provided options = client_options.ClientOptions(quota_project_id="octopus") with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( credentials=None, credentials_file=None, @@ -351,7 +352,7 @@ def test_metadata_service_client_mtls_env_auto( ) with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) if use_client_cert_env == "false": expected_client_cert_source = None @@ -428,6 +429,87 @@ def test_metadata_service_client_mtls_env_auto( ) +@pytest.mark.parametrize( + "client_class", [MetadataServiceClient, MetadataServiceAsyncClient] +) +@mock.patch.object( + MetadataServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(MetadataServiceClient), +) +@mock.patch.object( + MetadataServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(MetadataServiceAsyncClient), +) +def test_metadata_service_client_get_mtls_endpoint_and_cert_source(client_class): + mock_client_cert_source = mock.Mock() + + # Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is "true". 
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + mock_api_endpoint = "foo" + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint + ) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source( + options + ) + assert api_endpoint == mock_api_endpoint + assert cert_source == mock_client_cert_source + + # Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is "false". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "false"}): + mock_client_cert_source = mock.Mock() + mock_api_endpoint = "foo" + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint + ) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source( + options + ) + assert api_endpoint == mock_api_endpoint + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "never". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "always". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert doesn't exist. 
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert exists. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=mock_client_cert_source, + ): + ( + api_endpoint, + cert_source, + ) = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + assert cert_source == mock_client_cert_source + + @pytest.mark.parametrize( "client_class,transport_class,transport_name", [ @@ -446,7 +528,7 @@ def test_metadata_service_client_client_options_scopes( options = client_options.ClientOptions(scopes=["1", "2"],) with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( credentials=None, credentials_file=None, @@ -460,24 +542,31 @@ def test_metadata_service_client_client_options_scopes( @pytest.mark.parametrize( - "client_class,transport_class,transport_name", + "client_class,transport_class,transport_name,grpc_helpers", [ - (MetadataServiceClient, transports.MetadataServiceGrpcTransport, "grpc"), + ( + MetadataServiceClient, + transports.MetadataServiceGrpcTransport, + "grpc", + grpc_helpers, + ), ( MetadataServiceAsyncClient, transports.MetadataServiceGrpcAsyncIOTransport, "grpc_asyncio", + grpc_helpers_async, ), ], ) 
def test_metadata_service_client_client_options_credentials_file( - client_class, transport_class, transport_name + client_class, transport_class, transport_name, grpc_helpers ): # Check the case credentials file is provided. options = client_options.ClientOptions(credentials_file="credentials.json") + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( credentials=None, credentials_file="credentials.json", @@ -489,6 +578,35 @@ def test_metadata_service_client_client_options_credentials_file( always_use_jwt_access=True, ) + if "grpc" in transport_name: + # test that the credentials from file are saved and used as the credentials. + with mock.patch.object( + google.auth, "load_credentials_from_file", autospec=True + ) as load_creds, mock.patch.object( + google.auth, "default", autospec=True + ) as adc, mock.patch.object( + grpc_helpers, "create_channel" + ) as create_channel: + creds = ga_credentials.AnonymousCredentials() + file_creds = ga_credentials.AnonymousCredentials() + load_creds.return_value = (file_creds, None) + adc.return_value = (creds, None) + client = client_class(client_options=options, transport=transport_name) + create_channel.assert_called_with( + "aiplatform.googleapis.com:443", + credentials=file_creds, + credentials_file=None, + quota_project_id=None, + default_scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=None, + default_host="aiplatform.googleapis.com", + ssl_credentials=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + def test_metadata_service_client_client_options_from_dict(): with mock.patch( @@ -510,9 +628,10 @@ def test_metadata_service_client_client_options_from_dict(): ) -def test_create_metadata_store( - transport: str = "grpc", 
request_type=metadata_service.CreateMetadataStoreRequest -): +@pytest.mark.parametrize( + "request_type", [metadata_service.CreateMetadataStoreRequest, dict,] +) +def test_create_metadata_store(request_type, transport: str = "grpc"): client = MetadataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -538,10 +657,6 @@ def test_create_metadata_store( assert isinstance(response, future.Future) -def test_create_metadata_store_from_dict(): - test_create_metadata_store(request_type=dict) - - def test_create_metadata_store_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -755,9 +870,10 @@ async def test_create_metadata_store_flattened_error_async(): ) -def test_get_metadata_store( - transport: str = "grpc", request_type=metadata_service.GetMetadataStoreRequest -): +@pytest.mark.parametrize( + "request_type", [metadata_service.GetMetadataStoreRequest, dict,] +) +def test_get_metadata_store(request_type, transport: str = "grpc"): client = MetadataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -787,10 +903,6 @@ def test_get_metadata_store( assert response.description == "description_value" -def test_get_metadata_store_from_dict(): - test_get_metadata_store(request_type=dict) - - def test_get_metadata_store_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -982,9 +1094,10 @@ async def test_get_metadata_store_flattened_error_async(): ) -def test_list_metadata_stores( - transport: str = "grpc", request_type=metadata_service.ListMetadataStoresRequest -): +@pytest.mark.parametrize( + "request_type", [metadata_service.ListMetadataStoresRequest, dict,] +) +def test_list_metadata_stores(request_type, transport: str = "grpc"): client = MetadataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -1013,10 +1126,6 @@ def test_list_metadata_stores( assert response.next_page_token == "next_page_token_value" -def test_list_metadata_stores_from_dict(): - test_list_metadata_stores(request_type=dict) - - def test_list_metadata_stores_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -1207,8 +1316,10 @@ async def test_list_metadata_stores_flattened_error_async(): ) -def test_list_metadata_stores_pager(): - client = MetadataServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_metadata_stores_pager(transport_name: str = "grpc"): + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( @@ -1253,8 +1364,10 @@ def test_list_metadata_stores_pager(): assert all(isinstance(i, metadata_store.MetadataStore) for i in results) -def test_list_metadata_stores_pages(): - client = MetadataServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_metadata_stores_pages(transport_name: str = "grpc"): + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. 
with mock.patch.object( @@ -1381,9 +1494,10 @@ async def test_list_metadata_stores_async_pages(): assert page_.raw_page.next_page_token == token -def test_delete_metadata_store( - transport: str = "grpc", request_type=metadata_service.DeleteMetadataStoreRequest -): +@pytest.mark.parametrize( + "request_type", [metadata_service.DeleteMetadataStoreRequest, dict,] +) +def test_delete_metadata_store(request_type, transport: str = "grpc"): client = MetadataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -1409,10 +1523,6 @@ def test_delete_metadata_store( assert isinstance(response, future.Future) -def test_delete_metadata_store_from_dict(): - test_delete_metadata_store(request_type=dict) - - def test_delete_metadata_store_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -1600,9 +1710,10 @@ async def test_delete_metadata_store_flattened_error_async(): ) -def test_create_artifact( - transport: str = "grpc", request_type=metadata_service.CreateArtifactRequest -): +@pytest.mark.parametrize( + "request_type", [metadata_service.CreateArtifactRequest, dict,] +) +def test_create_artifact(request_type, transport: str = "grpc"): client = MetadataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -1643,10 +1754,6 @@ def test_create_artifact( assert response.description == "description_value" -def test_create_artifact_from_dict(): - test_create_artifact(request_type=dict) - - def test_create_artifact_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -1864,9 +1971,8 @@ async def test_create_artifact_flattened_error_async(): ) -def test_get_artifact( - transport: str = "grpc", request_type=metadata_service.GetArtifactRequest -): +@pytest.mark.parametrize("request_type", [metadata_service.GetArtifactRequest, dict,]) +def test_get_artifact(request_type, transport: str = "grpc"): client = MetadataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -1907,10 +2013,6 @@ def test_get_artifact( assert response.description == "description_value" -def test_get_artifact_from_dict(): - test_get_artifact(request_type=dict) - - def test_get_artifact_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -2098,9 +2200,8 @@ async def test_get_artifact_flattened_error_async(): ) -def test_list_artifacts( - transport: str = "grpc", request_type=metadata_service.ListArtifactsRequest -): +@pytest.mark.parametrize("request_type", [metadata_service.ListArtifactsRequest, dict,]) +def test_list_artifacts(request_type, transport: str = "grpc"): client = MetadataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -2127,10 +2228,6 @@ def test_list_artifacts( assert response.next_page_token == "next_page_token_value" -def test_list_artifacts_from_dict(): - test_list_artifacts(request_type=dict) - - def test_list_artifacts_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -2308,8 +2405,10 @@ async def test_list_artifacts_flattened_error_async(): ) -def test_list_artifacts_pager(): - client = MetadataServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_artifacts_pager(transport_name: str = "grpc"): + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_artifacts), "__call__") as call: @@ -2348,8 +2447,10 @@ def test_list_artifacts_pager(): assert all(isinstance(i, artifact.Artifact) for i in results) -def test_list_artifacts_pages(): - client = MetadataServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_artifacts_pages(transport_name: str = "grpc"): + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_artifacts), "__call__") as call: @@ -2458,9 +2559,10 @@ async def test_list_artifacts_async_pages(): assert page_.raw_page.next_page_token == token -def test_update_artifact( - transport: str = "grpc", request_type=metadata_service.UpdateArtifactRequest -): +@pytest.mark.parametrize( + "request_type", [metadata_service.UpdateArtifactRequest, dict,] +) +def test_update_artifact(request_type, transport: str = "grpc"): client = MetadataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -2501,10 +2603,6 @@ def test_update_artifact( assert response.description == "description_value" -def test_update_artifact_from_dict(): - test_update_artifact(request_type=dict) - - def test_update_artifact_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -2716,9 +2814,10 @@ async def test_update_artifact_flattened_error_async(): ) -def test_delete_artifact( - transport: str = "grpc", request_type=metadata_service.DeleteArtifactRequest -): +@pytest.mark.parametrize( + "request_type", [metadata_service.DeleteArtifactRequest, dict,] +) +def test_delete_artifact(request_type, transport: str = "grpc"): client = MetadataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -2742,10 +2841,6 @@ def test_delete_artifact( assert isinstance(response, future.Future) -def test_delete_artifact_from_dict(): - test_delete_artifact(request_type=dict) - - def test_delete_artifact_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -2920,9 +3015,10 @@ async def test_delete_artifact_flattened_error_async(): ) -def test_purge_artifacts( - transport: str = "grpc", request_type=metadata_service.PurgeArtifactsRequest -): +@pytest.mark.parametrize( + "request_type", [metadata_service.PurgeArtifactsRequest, dict,] +) +def test_purge_artifacts(request_type, transport: str = "grpc"): client = MetadataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -2946,10 +3042,6 @@ def test_purge_artifacts( assert isinstance(response, future.Future) -def test_purge_artifacts_from_dict(): - test_purge_artifacts(request_type=dict) - - def test_purge_artifacts_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -3124,9 +3216,8 @@ async def test_purge_artifacts_flattened_error_async(): ) -def test_create_context( - transport: str = "grpc", request_type=metadata_service.CreateContextRequest -): +@pytest.mark.parametrize("request_type", [metadata_service.CreateContextRequest, dict,]) +def test_create_context(request_type, transport: str = "grpc"): client = MetadataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -3165,10 +3256,6 @@ def test_create_context( assert response.description == "description_value" -def test_create_context_from_dict(): - test_create_context(request_type=dict) - - def test_create_context_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -3380,9 +3467,8 @@ async def test_create_context_flattened_error_async(): ) -def test_get_context( - transport: str = "grpc", request_type=metadata_service.GetContextRequest -): +@pytest.mark.parametrize("request_type", [metadata_service.GetContextRequest, dict,]) +def test_get_context(request_type, transport: str = "grpc"): client = MetadataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -3421,10 +3507,6 @@ def test_get_context( assert response.description == "description_value" -def test_get_context_from_dict(): - test_get_context(request_type=dict) - - def test_get_context_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -3610,9 +3692,8 @@ async def test_get_context_flattened_error_async(): ) -def test_list_contexts( - transport: str = "grpc", request_type=metadata_service.ListContextsRequest -): +@pytest.mark.parametrize("request_type", [metadata_service.ListContextsRequest, dict,]) +def test_list_contexts(request_type, transport: str = "grpc"): client = MetadataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -3639,10 +3720,6 @@ def test_list_contexts( assert response.next_page_token == "next_page_token_value" -def test_list_contexts_from_dict(): - test_list_contexts(request_type=dict) - - def test_list_contexts_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -3820,8 +3897,10 @@ async def test_list_contexts_flattened_error_async(): ) -def test_list_contexts_pager(): - client = MetadataServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_contexts_pager(transport_name: str = "grpc"): + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_contexts), "__call__") as call: @@ -3854,8 +3933,10 @@ def test_list_contexts_pager(): assert all(isinstance(i, context.Context) for i in results) -def test_list_contexts_pages(): - client = MetadataServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_contexts_pages(transport_name: str = "grpc"): + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. 
with mock.patch.object(type(client.transport.list_contexts), "__call__") as call: @@ -3946,9 +4027,8 @@ async def test_list_contexts_async_pages(): assert page_.raw_page.next_page_token == token -def test_update_context( - transport: str = "grpc", request_type=metadata_service.UpdateContextRequest -): +@pytest.mark.parametrize("request_type", [metadata_service.UpdateContextRequest, dict,]) +def test_update_context(request_type, transport: str = "grpc"): client = MetadataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -3987,10 +4067,6 @@ def test_update_context( assert response.description == "description_value" -def test_update_context_from_dict(): - test_update_context(request_type=dict) - - def test_update_context_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -4196,9 +4272,8 @@ async def test_update_context_flattened_error_async(): ) -def test_delete_context( - transport: str = "grpc", request_type=metadata_service.DeleteContextRequest -): +@pytest.mark.parametrize("request_type", [metadata_service.DeleteContextRequest, dict,]) +def test_delete_context(request_type, transport: str = "grpc"): client = MetadataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -4222,10 +4297,6 @@ def test_delete_context( assert isinstance(response, future.Future) -def test_delete_context_from_dict(): - test_delete_context(request_type=dict) - - def test_delete_context_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -4400,9 +4471,8 @@ async def test_delete_context_flattened_error_async(): ) -def test_purge_contexts( - transport: str = "grpc", request_type=metadata_service.PurgeContextsRequest -): +@pytest.mark.parametrize("request_type", [metadata_service.PurgeContextsRequest, dict,]) +def test_purge_contexts(request_type, transport: str = "grpc"): client = MetadataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -4426,10 +4496,6 @@ def test_purge_contexts( assert isinstance(response, future.Future) -def test_purge_contexts_from_dict(): - test_purge_contexts(request_type=dict) - - def test_purge_contexts_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -4604,10 +4670,10 @@ async def test_purge_contexts_flattened_error_async(): ) -def test_add_context_artifacts_and_executions( - transport: str = "grpc", - request_type=metadata_service.AddContextArtifactsAndExecutionsRequest, -): +@pytest.mark.parametrize( + "request_type", [metadata_service.AddContextArtifactsAndExecutionsRequest, dict,] +) +def test_add_context_artifacts_and_executions(request_type, transport: str = "grpc"): client = MetadataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -4635,10 +4701,6 @@ def test_add_context_artifacts_and_executions( ) -def test_add_context_artifacts_and_executions_from_dict(): - test_add_context_artifacts_and_executions(request_type=dict) - - def test_add_context_artifacts_and_executions_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -4854,9 +4916,10 @@ async def test_add_context_artifacts_and_executions_flattened_error_async(): ) -def test_add_context_children( - transport: str = "grpc", request_type=metadata_service.AddContextChildrenRequest -): +@pytest.mark.parametrize( + "request_type", [metadata_service.AddContextChildrenRequest, dict,] +) +def test_add_context_children(request_type, transport: str = "grpc"): client = MetadataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -4882,10 +4945,6 @@ def test_add_context_children( assert isinstance(response, metadata_service.AddContextChildrenResponse) -def test_add_context_children_from_dict(): - test_add_context_children(request_type=dict) - - def test_add_context_children_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -5087,10 +5146,10 @@ async def test_add_context_children_flattened_error_async(): ) -def test_query_context_lineage_subgraph( - transport: str = "grpc", - request_type=metadata_service.QueryContextLineageSubgraphRequest, -): +@pytest.mark.parametrize( + "request_type", [metadata_service.QueryContextLineageSubgraphRequest, dict,] +) +def test_query_context_lineage_subgraph(request_type, transport: str = "grpc"): client = MetadataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -5116,10 +5175,6 @@ def test_query_context_lineage_subgraph( assert isinstance(response, lineage_subgraph.LineageSubgraph) -def test_query_context_lineage_subgraph_from_dict(): - test_query_context_lineage_subgraph(request_type=dict) - - def test_query_context_lineage_subgraph_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -5309,9 +5364,10 @@ async def test_query_context_lineage_subgraph_flattened_error_async(): ) -def test_create_execution( - transport: str = "grpc", request_type=metadata_service.CreateExecutionRequest -): +@pytest.mark.parametrize( + "request_type", [metadata_service.CreateExecutionRequest, dict,] +) +def test_create_execution(request_type, transport: str = "grpc"): client = MetadataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -5350,10 +5406,6 @@ def test_create_execution( assert response.description == "description_value" -def test_create_execution_from_dict(): - test_create_execution(request_type=dict) - - def test_create_execution_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -5570,9 +5622,8 @@ async def test_create_execution_flattened_error_async(): ) -def test_get_execution( - transport: str = "grpc", request_type=metadata_service.GetExecutionRequest -): +@pytest.mark.parametrize("request_type", [metadata_service.GetExecutionRequest, dict,]) +def test_get_execution(request_type, transport: str = "grpc"): client = MetadataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -5611,10 +5662,6 @@ def test_get_execution( assert response.description == "description_value" -def test_get_execution_from_dict(): - test_get_execution(request_type=dict) - - def test_get_execution_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -5800,9 +5847,10 @@ async def test_get_execution_flattened_error_async(): ) -def test_list_executions( - transport: str = "grpc", request_type=metadata_service.ListExecutionsRequest -): +@pytest.mark.parametrize( + "request_type", [metadata_service.ListExecutionsRequest, dict,] +) +def test_list_executions(request_type, transport: str = "grpc"): client = MetadataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -5829,10 +5877,6 @@ def test_list_executions( assert response.next_page_token == "next_page_token_value" -def test_list_executions_from_dict(): - test_list_executions(request_type=dict) - - def test_list_executions_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -6010,8 +6054,10 @@ async def test_list_executions_flattened_error_async(): ) -def test_list_executions_pager(): - client = MetadataServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_executions_pager(transport_name: str = "grpc"): + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_executions), "__call__") as call: @@ -6050,8 +6096,10 @@ def test_list_executions_pager(): assert all(isinstance(i, execution.Execution) for i in results) -def test_list_executions_pages(): - client = MetadataServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_executions_pages(transport_name: str = "grpc"): + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. 
with mock.patch.object(type(client.transport.list_executions), "__call__") as call: @@ -6160,9 +6208,10 @@ async def test_list_executions_async_pages(): assert page_.raw_page.next_page_token == token -def test_update_execution( - transport: str = "grpc", request_type=metadata_service.UpdateExecutionRequest -): +@pytest.mark.parametrize( + "request_type", [metadata_service.UpdateExecutionRequest, dict,] +) +def test_update_execution(request_type, transport: str = "grpc"): client = MetadataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -6201,10 +6250,6 @@ def test_update_execution( assert response.description == "description_value" -def test_update_execution_from_dict(): - test_update_execution(request_type=dict) - - def test_update_execution_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -6415,9 +6460,10 @@ async def test_update_execution_flattened_error_async(): ) -def test_delete_execution( - transport: str = "grpc", request_type=metadata_service.DeleteExecutionRequest -): +@pytest.mark.parametrize( + "request_type", [metadata_service.DeleteExecutionRequest, dict,] +) +def test_delete_execution(request_type, transport: str = "grpc"): client = MetadataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -6441,10 +6487,6 @@ def test_delete_execution( assert isinstance(response, future.Future) -def test_delete_execution_from_dict(): - test_delete_execution(request_type=dict) - - def test_delete_execution_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -6620,9 +6662,10 @@ async def test_delete_execution_flattened_error_async(): ) -def test_purge_executions( - transport: str = "grpc", request_type=metadata_service.PurgeExecutionsRequest -): +@pytest.mark.parametrize( + "request_type", [metadata_service.PurgeExecutionsRequest, dict,] +) +def test_purge_executions(request_type, transport: str = "grpc"): client = MetadataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -6646,10 +6689,6 @@ def test_purge_executions( assert isinstance(response, future.Future) -def test_purge_executions_from_dict(): - test_purge_executions(request_type=dict) - - def test_purge_executions_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -6825,9 +6864,10 @@ async def test_purge_executions_flattened_error_async(): ) -def test_add_execution_events( - transport: str = "grpc", request_type=metadata_service.AddExecutionEventsRequest -): +@pytest.mark.parametrize( + "request_type", [metadata_service.AddExecutionEventsRequest, dict,] +) +def test_add_execution_events(request_type, transport: str = "grpc"): client = MetadataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -6853,10 +6893,6 @@ def test_add_execution_events( assert isinstance(response, metadata_service.AddExecutionEventsResponse) -def test_add_execution_events_from_dict(): - test_add_execution_events(request_type=dict) - - def test_add_execution_events_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -7060,10 +7096,10 @@ async def test_add_execution_events_flattened_error_async(): ) -def test_query_execution_inputs_and_outputs( - transport: str = "grpc", - request_type=metadata_service.QueryExecutionInputsAndOutputsRequest, -): +@pytest.mark.parametrize( + "request_type", [metadata_service.QueryExecutionInputsAndOutputsRequest, dict,] +) +def test_query_execution_inputs_and_outputs(request_type, transport: str = "grpc"): client = MetadataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -7089,10 +7125,6 @@ def test_query_execution_inputs_and_outputs( assert isinstance(response, lineage_subgraph.LineageSubgraph) -def test_query_execution_inputs_and_outputs_from_dict(): - test_query_execution_inputs_and_outputs(request_type=dict) - - def test_query_execution_inputs_and_outputs_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -7284,9 +7316,10 @@ async def test_query_execution_inputs_and_outputs_flattened_error_async(): ) -def test_create_metadata_schema( - transport: str = "grpc", request_type=metadata_service.CreateMetadataSchemaRequest -): +@pytest.mark.parametrize( + "request_type", [metadata_service.CreateMetadataSchemaRequest, dict,] +) +def test_create_metadata_schema(request_type, transport: str = "grpc"): client = MetadataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -7326,10 +7359,6 @@ def test_create_metadata_schema( assert response.description == "description_value" -def test_create_metadata_schema_from_dict(): - test_create_metadata_schema(request_type=dict) - - def test_create_metadata_schema_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -7557,9 +7586,10 @@ async def test_create_metadata_schema_flattened_error_async(): ) -def test_get_metadata_schema( - transport: str = "grpc", request_type=metadata_service.GetMetadataSchemaRequest -): +@pytest.mark.parametrize( + "request_type", [metadata_service.GetMetadataSchemaRequest, dict,] +) +def test_get_metadata_schema(request_type, transport: str = "grpc"): client = MetadataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -7599,10 +7629,6 @@ def test_get_metadata_schema( assert response.description == "description_value" -def test_get_metadata_schema_from_dict(): - test_get_metadata_schema(request_type=dict) - - def test_get_metadata_schema_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -7804,9 +7830,10 @@ async def test_get_metadata_schema_flattened_error_async(): ) -def test_list_metadata_schemas( - transport: str = "grpc", request_type=metadata_service.ListMetadataSchemasRequest -): +@pytest.mark.parametrize( + "request_type", [metadata_service.ListMetadataSchemasRequest, dict,] +) +def test_list_metadata_schemas(request_type, transport: str = "grpc"): client = MetadataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -7835,10 +7862,6 @@ def test_list_metadata_schemas( assert response.next_page_token == "next_page_token_value" -def test_list_metadata_schemas_from_dict(): - test_list_metadata_schemas(request_type=dict) - - def test_list_metadata_schemas_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -8029,8 +8052,10 @@ async def test_list_metadata_schemas_flattened_error_async(): ) -def test_list_metadata_schemas_pager(): - client = MetadataServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_metadata_schemas_pager(transport_name: str = "grpc"): + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( @@ -8075,8 +8100,10 @@ def test_list_metadata_schemas_pager(): assert all(isinstance(i, metadata_schema.MetadataSchema) for i in results) -def test_list_metadata_schemas_pages(): - client = MetadataServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_metadata_schemas_pages(transport_name: str = "grpc"): + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( @@ -8203,10 +8230,10 @@ async def test_list_metadata_schemas_async_pages(): assert page_.raw_page.next_page_token == token -def test_query_artifact_lineage_subgraph( - transport: str = "grpc", - request_type=metadata_service.QueryArtifactLineageSubgraphRequest, -): +@pytest.mark.parametrize( + "request_type", [metadata_service.QueryArtifactLineageSubgraphRequest, dict,] +) +def test_query_artifact_lineage_subgraph(request_type, transport: str = "grpc"): client = MetadataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -8232,10 +8259,6 @@ def test_query_artifact_lineage_subgraph( assert isinstance(response, lineage_subgraph.LineageSubgraph) -def test_query_artifact_lineage_subgraph_from_dict(): - test_query_artifact_lineage_subgraph(request_type=dict) - - def test_query_artifact_lineage_subgraph_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. 
request == None and no flattened fields passed, work. @@ -8447,6 +8470,23 @@ def test_credentials_transport_error(): transport=transport, ) + # It is an error to provide an api_key and a transport instance. + transport = transports.MetadataServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + options = client_options.ClientOptions() + options.api_key = "api_key" + with pytest.raises(ValueError): + client = MetadataServiceClient(client_options=options, transport=transport,) + + # It is an error to provide an api_key and a credential. + options = mock.Mock() + options.api_key = "api_key" + with pytest.raises(ValueError): + client = MetadataServiceClient( + client_options=options, credentials=ga_credentials.AnonymousCredentials() + ) + # It is an error to provide scopes and a transport instance. transport = transports.MetadataServiceGrpcTransport( credentials=ga_credentials.AnonymousCredentials(), @@ -9128,7 +9168,7 @@ def test_parse_common_location_path(): assert expected == actual -def test_client_withDEFAULT_CLIENT_INFO(): +def test_client_with_default_client_info(): client_info = gapic_v1.client_info.ClientInfo() with mock.patch.object( @@ -9193,3 +9233,33 @@ def test_client_ctx(): with client: pass close.assert_called() + + +@pytest.mark.parametrize( + "client_class,transport_class", + [ + (MetadataServiceClient, transports.MetadataServiceGrpcTransport), + (MetadataServiceAsyncClient, transports.MetadataServiceGrpcAsyncIOTransport), + ], +) +def test_api_key_credentials(client_class, transport_class): + with mock.patch.object( + google.auth._default, "get_api_key_credentials", create=True + ) as get_api_key_credentials: + mock_cred = mock.Mock() + get_api_key_credentials.return_value = mock_cred + options = client_options.ClientOptions() + options.api_key = "api_key" + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + 
patched.assert_called_once_with( + credentials=mock_cred, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + ) diff --git a/tests/unit/gapic/aiplatform_v1/test_migration_service.py b/tests/unit/gapic/aiplatform_v1/test_migration_service.py index 15277d30725..0ab11cc1e3d 100644 --- a/tests/unit/gapic/aiplatform_v1/test_migration_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_migration_service.py @@ -29,6 +29,7 @@ from google.api_core import gapic_v1 from google.api_core import grpc_helpers from google.api_core import grpc_helpers_async +from google.api_core import operation from google.api_core import operation_async # type: ignore from google.api_core import operations_v1 from google.api_core import path_template @@ -254,20 +255,20 @@ def test_migration_service_client_client_options( # unsupported value. with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): with pytest.raises(MutualTLSChannelError): - client = client_class() + client = client_class(transport=transport_name) # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. 
with mock.patch.dict( os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} ): with pytest.raises(ValueError): - client = client_class() + client = client_class(transport=transport_name) # Check the case quota_project_id is provided options = client_options.ClientOptions(quota_project_id="octopus") with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( credentials=None, credentials_file=None, @@ -336,7 +337,7 @@ def test_migration_service_client_mtls_env_auto( ) with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) if use_client_cert_env == "false": expected_client_cert_source = None @@ -413,6 +414,87 @@ def test_migration_service_client_mtls_env_auto( ) +@pytest.mark.parametrize( + "client_class", [MigrationServiceClient, MigrationServiceAsyncClient] +) +@mock.patch.object( + MigrationServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(MigrationServiceClient), +) +@mock.patch.object( + MigrationServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(MigrationServiceAsyncClient), +) +def test_migration_service_client_get_mtls_endpoint_and_cert_source(client_class): + mock_client_cert_source = mock.Mock() + + # Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is "true". 
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + mock_api_endpoint = "foo" + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint + ) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source( + options + ) + assert api_endpoint == mock_api_endpoint + assert cert_source == mock_client_cert_source + + # Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is "false". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "false"}): + mock_client_cert_source = mock.Mock() + mock_api_endpoint = "foo" + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint + ) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source( + options + ) + assert api_endpoint == mock_api_endpoint + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "never". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "always". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert doesn't exist. 
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert exists. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=mock_client_cert_source, + ): + ( + api_endpoint, + cert_source, + ) = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + assert cert_source == mock_client_cert_source + + @pytest.mark.parametrize( "client_class,transport_class,transport_name", [ @@ -431,7 +513,7 @@ def test_migration_service_client_client_options_scopes( options = client_options.ClientOptions(scopes=["1", "2"],) with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( credentials=None, credentials_file=None, @@ -445,24 +527,31 @@ def test_migration_service_client_client_options_scopes( @pytest.mark.parametrize( - "client_class,transport_class,transport_name", + "client_class,transport_class,transport_name,grpc_helpers", [ - (MigrationServiceClient, transports.MigrationServiceGrpcTransport, "grpc"), + ( + MigrationServiceClient, + transports.MigrationServiceGrpcTransport, + "grpc", + grpc_helpers, + ), ( MigrationServiceAsyncClient, transports.MigrationServiceGrpcAsyncIOTransport, "grpc_asyncio", + grpc_helpers_async, 
), ], ) def test_migration_service_client_client_options_credentials_file( - client_class, transport_class, transport_name + client_class, transport_class, transport_name, grpc_helpers ): # Check the case credentials file is provided. options = client_options.ClientOptions(credentials_file="credentials.json") + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( credentials=None, credentials_file="credentials.json", @@ -474,6 +563,35 @@ def test_migration_service_client_client_options_credentials_file( always_use_jwt_access=True, ) + if "grpc" in transport_name: + # test that the credentials from file are saved and used as the credentials. + with mock.patch.object( + google.auth, "load_credentials_from_file", autospec=True + ) as load_creds, mock.patch.object( + google.auth, "default", autospec=True + ) as adc, mock.patch.object( + grpc_helpers, "create_channel" + ) as create_channel: + creds = ga_credentials.AnonymousCredentials() + file_creds = ga_credentials.AnonymousCredentials() + load_creds.return_value = (file_creds, None) + adc.return_value = (creds, None) + client = client_class(client_options=options, transport=transport_name) + create_channel.assert_called_with( + "aiplatform.googleapis.com:443", + credentials=file_creds, + credentials_file=None, + quota_project_id=None, + default_scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=None, + default_host="aiplatform.googleapis.com", + ssl_credentials=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + def test_migration_service_client_client_options_from_dict(): with mock.patch( @@ -495,10 +613,10 @@ def test_migration_service_client_client_options_from_dict(): ) -def test_search_migratable_resources( - transport: 
str = "grpc", - request_type=migration_service.SearchMigratableResourcesRequest, -): +@pytest.mark.parametrize( + "request_type", [migration_service.SearchMigratableResourcesRequest, dict,] +) +def test_search_migratable_resources(request_type, transport: str = "grpc"): client = MigrationServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -527,10 +645,6 @@ def test_search_migratable_resources( assert response.next_page_token == "next_page_token_value" -def test_search_migratable_resources_from_dict(): - test_search_migratable_resources(request_type=dict) - - def test_search_migratable_resources_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -721,8 +835,10 @@ async def test_search_migratable_resources_flattened_error_async(): ) -def test_search_migratable_resources_pager(): - client = MigrationServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_search_migratable_resources_pager(transport_name: str = "grpc"): + client = MigrationServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( @@ -769,8 +885,10 @@ def test_search_migratable_resources_pager(): ) -def test_search_migratable_resources_pages(): - client = MigrationServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_search_migratable_resources_pages(transport_name: str = "grpc"): + client = MigrationServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. 
with mock.patch.object( @@ -899,9 +1017,10 @@ async def test_search_migratable_resources_async_pages(): assert page_.raw_page.next_page_token == token -def test_batch_migrate_resources( - transport: str = "grpc", request_type=migration_service.BatchMigrateResourcesRequest -): +@pytest.mark.parametrize( + "request_type", [migration_service.BatchMigrateResourcesRequest, dict,] +) +def test_batch_migrate_resources(request_type, transport: str = "grpc"): client = MigrationServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -927,10 +1046,6 @@ def test_batch_migrate_resources( assert isinstance(response, future.Future) -def test_batch_migrate_resources_from_dict(): - test_batch_migrate_resources(request_type=dict) - - def test_batch_migrate_resources_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -1190,6 +1305,23 @@ def test_credentials_transport_error(): transport=transport, ) + # It is an error to provide an api_key and a transport instance. + transport = transports.MigrationServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + options = client_options.ClientOptions() + options.api_key = "api_key" + with pytest.raises(ValueError): + client = MigrationServiceClient(client_options=options, transport=transport,) + + # It is an error to provide an api_key and a credential. + options = mock.Mock() + options.api_key = "api_key" + with pytest.raises(ValueError): + client = MigrationServiceClient( + client_options=options, credentials=ga_credentials.AnonymousCredentials() + ) + # It is an error to provide scopes and a transport instance. 
transport = transports.MigrationServiceGrpcTransport( credentials=ga_credentials.AnonymousCredentials(), @@ -1648,18 +1780,20 @@ def test_parse_dataset_path(): def test_dataset_path(): project = "squid" - dataset = "clam" - expected = "projects/{project}/datasets/{dataset}".format( - project=project, dataset=dataset, + location = "clam" + dataset = "whelk" + expected = "projects/{project}/locations/{location}/datasets/{dataset}".format( + project=project, location=location, dataset=dataset, ) - actual = MigrationServiceClient.dataset_path(project, dataset) + actual = MigrationServiceClient.dataset_path(project, location, dataset) assert expected == actual def test_parse_dataset_path(): expected = { - "project": "whelk", - "dataset": "octopus", + "project": "octopus", + "location": "oyster", + "dataset": "nudibranch", } path = MigrationServiceClient.dataset_path(**expected) @@ -1669,20 +1803,18 @@ def test_parse_dataset_path(): def test_dataset_path(): - project = "oyster" - location = "nudibranch" - dataset = "cuttlefish" - expected = "projects/{project}/locations/{location}/datasets/{dataset}".format( - project=project, location=location, dataset=dataset, + project = "cuttlefish" + dataset = "mussel" + expected = "projects/{project}/datasets/{dataset}".format( + project=project, dataset=dataset, ) - actual = MigrationServiceClient.dataset_path(project, location, dataset) + actual = MigrationServiceClient.dataset_path(project, dataset) assert expected == actual def test_parse_dataset_path(): expected = { - "project": "mussel", - "location": "winkle", + "project": "winkle", "dataset": "nautilus", } path = MigrationServiceClient.dataset_path(**expected) @@ -1860,7 +1992,7 @@ def test_parse_common_location_path(): assert expected == actual -def test_client_withDEFAULT_CLIENT_INFO(): +def test_client_with_default_client_info(): client_info = gapic_v1.client_info.ClientInfo() with mock.patch.object( @@ -1925,3 +2057,33 @@ def test_client_ctx(): with client: pass 
close.assert_called() + + +@pytest.mark.parametrize( + "client_class,transport_class", + [ + (MigrationServiceClient, transports.MigrationServiceGrpcTransport), + (MigrationServiceAsyncClient, transports.MigrationServiceGrpcAsyncIOTransport), + ], +) +def test_api_key_credentials(client_class, transport_class): + with mock.patch.object( + google.auth._default, "get_api_key_credentials", create=True + ) as get_api_key_credentials: + mock_cred = mock.Mock() + get_api_key_credentials.return_value = mock_cred + options = client_options.ClientOptions() + options.api_key = "api_key" + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=mock_cred, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + ) diff --git a/tests/unit/gapic/aiplatform_v1/test_model_service.py b/tests/unit/gapic/aiplatform_v1/test_model_service.py index d3252ec4ebf..3c848d8ff24 100644 --- a/tests/unit/gapic/aiplatform_v1/test_model_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_model_service.py @@ -29,6 +29,7 @@ from google.api_core import gapic_v1 from google.api_core import grpc_helpers from google.api_core import grpc_helpers_async +from google.api_core import operation from google.api_core import operation_async # type: ignore from google.api_core import operations_v1 from google.api_core import path_template @@ -255,20 +256,20 @@ def test_model_service_client_client_options( # unsupported value. with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): with pytest.raises(MutualTLSChannelError): - client = client_class() + client = client_class(transport=transport_name) # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. 
with mock.patch.dict( os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} ): with pytest.raises(ValueError): - client = client_class() + client = client_class(transport=transport_name) # Check the case quota_project_id is provided options = client_options.ClientOptions(quota_project_id="octopus") with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( credentials=None, credentials_file=None, @@ -325,7 +326,7 @@ def test_model_service_client_mtls_env_auto( ) with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) if use_client_cert_env == "false": expected_client_cert_source = None @@ -402,6 +403,83 @@ def test_model_service_client_mtls_env_auto( ) +@pytest.mark.parametrize("client_class", [ModelServiceClient, ModelServiceAsyncClient]) +@mock.patch.object( + ModelServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(ModelServiceClient) +) +@mock.patch.object( + ModelServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(ModelServiceAsyncClient), +) +def test_model_service_client_get_mtls_endpoint_and_cert_source(client_class): + mock_client_cert_source = mock.Mock() + + # Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is "true". 
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + mock_api_endpoint = "foo" + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint + ) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source( + options + ) + assert api_endpoint == mock_api_endpoint + assert cert_source == mock_client_cert_source + + # Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is "false". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "false"}): + mock_client_cert_source = mock.Mock() + mock_api_endpoint = "foo" + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint + ) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source( + options + ) + assert api_endpoint == mock_api_endpoint + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "never". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "always". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert doesn't exist. 
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert exists. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=mock_client_cert_source, + ): + ( + api_endpoint, + cert_source, + ) = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + assert cert_source == mock_client_cert_source + + @pytest.mark.parametrize( "client_class,transport_class,transport_name", [ @@ -420,7 +498,7 @@ def test_model_service_client_client_options_scopes( options = client_options.ClientOptions(scopes=["1", "2"],) with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( credentials=None, credentials_file=None, @@ -434,24 +512,31 @@ def test_model_service_client_client_options_scopes( @pytest.mark.parametrize( - "client_class,transport_class,transport_name", + "client_class,transport_class,transport_name,grpc_helpers", [ - (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc"), + ( + ModelServiceClient, + transports.ModelServiceGrpcTransport, + "grpc", + grpc_helpers, + ), ( ModelServiceAsyncClient, transports.ModelServiceGrpcAsyncIOTransport, "grpc_asyncio", + grpc_helpers_async, ), ], ) def 
test_model_service_client_client_options_credentials_file( - client_class, transport_class, transport_name + client_class, transport_class, transport_name, grpc_helpers ): # Check the case credentials file is provided. options = client_options.ClientOptions(credentials_file="credentials.json") + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( credentials=None, credentials_file="credentials.json", @@ -463,6 +548,35 @@ def test_model_service_client_client_options_credentials_file( always_use_jwt_access=True, ) + if "grpc" in transport_name: + # test that the credentials from file are saved and used as the credentials. + with mock.patch.object( + google.auth, "load_credentials_from_file", autospec=True + ) as load_creds, mock.patch.object( + google.auth, "default", autospec=True + ) as adc, mock.patch.object( + grpc_helpers, "create_channel" + ) as create_channel: + creds = ga_credentials.AnonymousCredentials() + file_creds = ga_credentials.AnonymousCredentials() + load_creds.return_value = (file_creds, None) + adc.return_value = (creds, None) + client = client_class(client_options=options, transport=transport_name) + create_channel.assert_called_with( + "aiplatform.googleapis.com:443", + credentials=file_creds, + credentials_file=None, + quota_project_id=None, + default_scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=None, + default_host="aiplatform.googleapis.com", + ssl_credentials=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + def test_model_service_client_client_options_from_dict(): with mock.patch( @@ -482,9 +596,8 @@ def test_model_service_client_client_options_from_dict(): ) -def test_upload_model( - transport: str = "grpc", 
request_type=model_service.UploadModelRequest -): +@pytest.mark.parametrize("request_type", [model_service.UploadModelRequest, dict,]) +def test_upload_model(request_type, transport: str = "grpc"): client = ModelServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -508,10 +621,6 @@ def test_upload_model( assert isinstance(response, future.Future) -def test_upload_model_from_dict(): - test_upload_model(request_type=dict) - - def test_upload_model_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -694,7 +803,8 @@ async def test_upload_model_flattened_error_async(): ) -def test_get_model(transport: str = "grpc", request_type=model_service.GetModelRequest): +@pytest.mark.parametrize("request_type", [model_service.GetModelRequest, dict,]) +def test_get_model(request_type, transport: str = "grpc"): client = ModelServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -747,10 +857,6 @@ def test_get_model(transport: str = "grpc", request_type=model_service.GetModelR assert response.etag == "etag_value" -def test_get_model_from_dict(): - test_get_model(request_type=dict) - - def test_get_model_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -948,9 +1054,8 @@ async def test_get_model_flattened_error_async(): ) -def test_list_models( - transport: str = "grpc", request_type=model_service.ListModelsRequest -): +@pytest.mark.parametrize("request_type", [model_service.ListModelsRequest, dict,]) +def test_list_models(request_type, transport: str = "grpc"): client = ModelServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -977,10 +1082,6 @@ def test_list_models( assert response.next_page_token == "next_page_token_value" -def test_list_models_from_dict(): - test_list_models(request_type=dict) - - def test_list_models_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -1150,8 +1251,10 @@ async def test_list_models_flattened_error_async(): ) -def test_list_models_pager(): - client = ModelServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_models_pager(transport_name: str = "grpc"): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_models), "__call__") as call: @@ -1182,8 +1285,10 @@ def test_list_models_pager(): assert all(isinstance(i, model.Model) for i in results) -def test_list_models_pages(): - client = ModelServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_models_pages(transport_name: str = "grpc"): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. 
with mock.patch.object(type(client.transport.list_models), "__call__") as call: @@ -1264,9 +1369,8 @@ async def test_list_models_async_pages(): assert page_.raw_page.next_page_token == token -def test_update_model( - transport: str = "grpc", request_type=model_service.UpdateModelRequest -): +@pytest.mark.parametrize("request_type", [model_service.UpdateModelRequest, dict,]) +def test_update_model(request_type, transport: str = "grpc"): client = ModelServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -1319,10 +1423,6 @@ def test_update_model( assert response.etag == "etag_value" -def test_update_model_from_dict(): - test_update_model(request_type=dict) - - def test_update_model_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -1536,9 +1636,8 @@ async def test_update_model_flattened_error_async(): ) -def test_delete_model( - transport: str = "grpc", request_type=model_service.DeleteModelRequest -): +@pytest.mark.parametrize("request_type", [model_service.DeleteModelRequest, dict,]) +def test_delete_model(request_type, transport: str = "grpc"): client = ModelServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -1562,10 +1661,6 @@ def test_delete_model( assert isinstance(response, future.Future) -def test_delete_model_from_dict(): - test_delete_model(request_type=dict) - - def test_delete_model_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -1734,9 +1829,8 @@ async def test_delete_model_flattened_error_async(): ) -def test_export_model( - transport: str = "grpc", request_type=model_service.ExportModelRequest -): +@pytest.mark.parametrize("request_type", [model_service.ExportModelRequest, dict,]) +def test_export_model(request_type, transport: str = "grpc"): client = ModelServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -1760,10 +1854,6 @@ def test_export_model( assert isinstance(response, future.Future) -def test_export_model_from_dict(): - test_export_model(request_type=dict) - - def test_export_model_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -1960,9 +2050,10 @@ async def test_export_model_flattened_error_async(): ) -def test_get_model_evaluation( - transport: str = "grpc", request_type=model_service.GetModelEvaluationRequest -): +@pytest.mark.parametrize( + "request_type", [model_service.GetModelEvaluationRequest, dict,] +) +def test_get_model_evaluation(request_type, transport: str = "grpc"): client = ModelServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -1995,10 +2086,6 @@ def test_get_model_evaluation( assert response.slice_dimensions == ["slice_dimensions_value"] -def test_get_model_evaluation_from_dict(): - test_get_model_evaluation(request_type=dict) - - def test_get_model_evaluation_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -2187,9 +2274,10 @@ async def test_get_model_evaluation_flattened_error_async(): ) -def test_list_model_evaluations( - transport: str = "grpc", request_type=model_service.ListModelEvaluationsRequest -): +@pytest.mark.parametrize( + "request_type", [model_service.ListModelEvaluationsRequest, dict,] +) +def test_list_model_evaluations(request_type, transport: str = "grpc"): client = ModelServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -2218,10 +2306,6 @@ def test_list_model_evaluations( assert response.next_page_token == "next_page_token_value" -def test_list_model_evaluations_from_dict(): - test_list_model_evaluations(request_type=dict) - - def test_list_model_evaluations_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -2406,8 +2490,10 @@ async def test_list_model_evaluations_flattened_error_async(): ) -def test_list_model_evaluations_pager(): - client = ModelServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_model_evaluations_pager(transport_name: str = "grpc"): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( @@ -2452,8 +2538,10 @@ def test_list_model_evaluations_pager(): assert all(isinstance(i, model_evaluation.ModelEvaluation) for i in results) -def test_list_model_evaluations_pages(): - client = ModelServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_model_evaluations_pages(transport_name: str = "grpc"): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. 
with mock.patch.object( @@ -2576,9 +2664,10 @@ async def test_list_model_evaluations_async_pages(): assert page_.raw_page.next_page_token == token -def test_get_model_evaluation_slice( - transport: str = "grpc", request_type=model_service.GetModelEvaluationSliceRequest -): +@pytest.mark.parametrize( + "request_type", [model_service.GetModelEvaluationSliceRequest, dict,] +) +def test_get_model_evaluation_slice(request_type, transport: str = "grpc"): client = ModelServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -2608,10 +2697,6 @@ def test_get_model_evaluation_slice( assert response.metrics_schema_uri == "metrics_schema_uri_value" -def test_get_model_evaluation_slice_from_dict(): - test_get_model_evaluation_slice(request_type=dict) - - def test_get_model_evaluation_slice_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -2797,9 +2882,10 @@ async def test_get_model_evaluation_slice_flattened_error_async(): ) -def test_list_model_evaluation_slices( - transport: str = "grpc", request_type=model_service.ListModelEvaluationSlicesRequest -): +@pytest.mark.parametrize( + "request_type", [model_service.ListModelEvaluationSlicesRequest, dict,] +) +def test_list_model_evaluation_slices(request_type, transport: str = "grpc"): client = ModelServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -2828,10 +2914,6 @@ def test_list_model_evaluation_slices( assert response.next_page_token == "next_page_token_value" -def test_list_model_evaluation_slices_from_dict(): - test_list_model_evaluation_slices(request_type=dict) - - def test_list_model_evaluation_slices_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -3016,8 +3098,10 @@ async def test_list_model_evaluation_slices_flattened_error_async(): ) -def test_list_model_evaluation_slices_pager(): - client = ModelServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_model_evaluation_slices_pager(transport_name: str = "grpc"): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( @@ -3066,8 +3150,10 @@ def test_list_model_evaluation_slices_pager(): ) -def test_list_model_evaluation_slices_pages(): - client = ModelServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_model_evaluation_slices_pages(transport_name: str = "grpc"): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( @@ -3221,6 +3307,23 @@ def test_credentials_transport_error(): transport=transport, ) + # It is an error to provide an api_key and a transport instance. + transport = transports.ModelServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + options = client_options.ClientOptions() + options.api_key = "api_key" + with pytest.raises(ValueError): + client = ModelServiceClient(client_options=options, transport=transport,) + + # It is an error to provide an api_key and a credential. + options = mock.Mock() + options.api_key = "api_key" + with pytest.raises(ValueError): + client = ModelServiceClient( + client_options=options, credentials=ga_credentials.AnonymousCredentials() + ) + # It is an error to provide scopes and a transport instance. 
transport = transports.ModelServiceGrpcTransport( credentials=ga_credentials.AnonymousCredentials(), @@ -3856,7 +3959,7 @@ def test_parse_common_location_path(): assert expected == actual -def test_client_withDEFAULT_CLIENT_INFO(): +def test_client_with_default_client_info(): client_info = gapic_v1.client_info.ClientInfo() with mock.patch.object( @@ -3921,3 +4024,33 @@ def test_client_ctx(): with client: pass close.assert_called() + + +@pytest.mark.parametrize( + "client_class,transport_class", + [ + (ModelServiceClient, transports.ModelServiceGrpcTransport), + (ModelServiceAsyncClient, transports.ModelServiceGrpcAsyncIOTransport), + ], +) +def test_api_key_credentials(client_class, transport_class): + with mock.patch.object( + google.auth._default, "get_api_key_credentials", create=True + ) as get_api_key_credentials: + mock_cred = mock.Mock() + get_api_key_credentials.return_value = mock_cred + options = client_options.ClientOptions() + options.api_key = "api_key" + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=mock_cred, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + ) diff --git a/tests/unit/gapic/aiplatform_v1/test_pipeline_service.py b/tests/unit/gapic/aiplatform_v1/test_pipeline_service.py index 88895884892..65169e18689 100644 --- a/tests/unit/gapic/aiplatform_v1/test_pipeline_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_pipeline_service.py @@ -29,6 +29,7 @@ from google.api_core import gapic_v1 from google.api_core import grpc_helpers from google.api_core import grpc_helpers_async +from google.api_core import operation from google.api_core import operation_async # type: ignore from google.api_core import operations_v1 from google.api_core 
import path_template @@ -274,20 +275,20 @@ def test_pipeline_service_client_client_options( # unsupported value. with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): with pytest.raises(MutualTLSChannelError): - client = client_class() + client = client_class(transport=transport_name) # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. with mock.patch.dict( os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} ): with pytest.raises(ValueError): - client = client_class() + client = client_class(transport=transport_name) # Check the case quota_project_id is provided options = client_options.ClientOptions(quota_project_id="octopus") with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( credentials=None, credentials_file=None, @@ -356,7 +357,7 @@ def test_pipeline_service_client_mtls_env_auto( ) with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) if use_client_cert_env == "false": expected_client_cert_source = None @@ -433,6 +434,87 @@ def test_pipeline_service_client_mtls_env_auto( ) +@pytest.mark.parametrize( + "client_class", [PipelineServiceClient, PipelineServiceAsyncClient] +) +@mock.patch.object( + PipelineServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(PipelineServiceClient), +) +@mock.patch.object( + PipelineServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(PipelineServiceAsyncClient), +) +def test_pipeline_service_client_get_mtls_endpoint_and_cert_source(client_class): + mock_client_cert_source = mock.Mock() + + # Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is 
"true". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + mock_api_endpoint = "foo" + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint + ) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source( + options + ) + assert api_endpoint == mock_api_endpoint + assert cert_source == mock_client_cert_source + + # Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is "false". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "false"}): + mock_client_cert_source = mock.Mock() + mock_api_endpoint = "foo" + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint + ) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source( + options + ) + assert api_endpoint == mock_api_endpoint + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "never". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "always". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert doesn't exist. 
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert exists. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=mock_client_cert_source, + ): + ( + api_endpoint, + cert_source, + ) = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + assert cert_source == mock_client_cert_source + + @pytest.mark.parametrize( "client_class,transport_class,transport_name", [ @@ -451,7 +533,7 @@ def test_pipeline_service_client_client_options_scopes( options = client_options.ClientOptions(scopes=["1", "2"],) with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( credentials=None, credentials_file=None, @@ -465,24 +547,31 @@ def test_pipeline_service_client_client_options_scopes( @pytest.mark.parametrize( - "client_class,transport_class,transport_name", + "client_class,transport_class,transport_name,grpc_helpers", [ - (PipelineServiceClient, transports.PipelineServiceGrpcTransport, "grpc"), + ( + PipelineServiceClient, + transports.PipelineServiceGrpcTransport, + "grpc", + grpc_helpers, + ), ( PipelineServiceAsyncClient, transports.PipelineServiceGrpcAsyncIOTransport, "grpc_asyncio", + grpc_helpers_async, ), ], ) 
def test_pipeline_service_client_client_options_credentials_file( - client_class, transport_class, transport_name + client_class, transport_class, transport_name, grpc_helpers ): # Check the case credentials file is provided. options = client_options.ClientOptions(credentials_file="credentials.json") + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( credentials=None, credentials_file="credentials.json", @@ -494,6 +583,35 @@ def test_pipeline_service_client_client_options_credentials_file( always_use_jwt_access=True, ) + if "grpc" in transport_name: + # test that the credentials from file are saved and used as the credentials. + with mock.patch.object( + google.auth, "load_credentials_from_file", autospec=True + ) as load_creds, mock.patch.object( + google.auth, "default", autospec=True + ) as adc, mock.patch.object( + grpc_helpers, "create_channel" + ) as create_channel: + creds = ga_credentials.AnonymousCredentials() + file_creds = ga_credentials.AnonymousCredentials() + load_creds.return_value = (file_creds, None) + adc.return_value = (creds, None) + client = client_class(client_options=options, transport=transport_name) + create_channel.assert_called_with( + "aiplatform.googleapis.com:443", + credentials=file_creds, + credentials_file=None, + quota_project_id=None, + default_scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=None, + default_host="aiplatform.googleapis.com", + ssl_credentials=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + def test_pipeline_service_client_client_options_from_dict(): with mock.patch( @@ -515,9 +633,10 @@ def test_pipeline_service_client_client_options_from_dict(): ) -def test_create_training_pipeline( - transport: str = "grpc", 
request_type=pipeline_service.CreateTrainingPipelineRequest -): +@pytest.mark.parametrize( + "request_type", [pipeline_service.CreateTrainingPipelineRequest, dict,] +) +def test_create_training_pipeline(request_type, transport: str = "grpc"): client = PipelineServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -552,10 +671,6 @@ def test_create_training_pipeline( assert response.state == pipeline_state.PipelineState.PIPELINE_STATE_QUEUED -def test_create_training_pipeline_from_dict(): - test_create_training_pipeline(request_type=dict) - - def test_create_training_pipeline_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -768,9 +883,10 @@ async def test_create_training_pipeline_flattened_error_async(): ) -def test_get_training_pipeline( - transport: str = "grpc", request_type=pipeline_service.GetTrainingPipelineRequest -): +@pytest.mark.parametrize( + "request_type", [pipeline_service.GetTrainingPipelineRequest, dict,] +) +def test_get_training_pipeline(request_type, transport: str = "grpc"): client = PipelineServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -805,10 +921,6 @@ def test_get_training_pipeline( assert response.state == pipeline_state.PipelineState.PIPELINE_STATE_QUEUED -def test_get_training_pipeline_from_dict(): - test_get_training_pipeline(request_type=dict) - - def test_get_training_pipeline_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -1005,9 +1117,10 @@ async def test_get_training_pipeline_flattened_error_async(): ) -def test_list_training_pipelines( - transport: str = "grpc", request_type=pipeline_service.ListTrainingPipelinesRequest -): +@pytest.mark.parametrize( + "request_type", [pipeline_service.ListTrainingPipelinesRequest, dict,] +) +def test_list_training_pipelines(request_type, transport: str = "grpc"): client = PipelineServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -1036,10 +1149,6 @@ def test_list_training_pipelines( assert response.next_page_token == "next_page_token_value" -def test_list_training_pipelines_from_dict(): - test_list_training_pipelines(request_type=dict) - - def test_list_training_pipelines_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -1230,8 +1339,10 @@ async def test_list_training_pipelines_flattened_error_async(): ) -def test_list_training_pipelines_pager(): - client = PipelineServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_training_pipelines_pager(transport_name: str = "grpc"): + client = PipelineServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( @@ -1276,8 +1387,10 @@ def test_list_training_pipelines_pager(): assert all(isinstance(i, training_pipeline.TrainingPipeline) for i in results) -def test_list_training_pipelines_pages(): - client = PipelineServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_training_pipelines_pages(transport_name: str = "grpc"): + client = PipelineServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. 
with mock.patch.object( @@ -1404,9 +1517,10 @@ async def test_list_training_pipelines_async_pages(): assert page_.raw_page.next_page_token == token -def test_delete_training_pipeline( - transport: str = "grpc", request_type=pipeline_service.DeleteTrainingPipelineRequest -): +@pytest.mark.parametrize( + "request_type", [pipeline_service.DeleteTrainingPipelineRequest, dict,] +) +def test_delete_training_pipeline(request_type, transport: str = "grpc"): client = PipelineServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -1432,10 +1546,6 @@ def test_delete_training_pipeline( assert isinstance(response, future.Future) -def test_delete_training_pipeline_from_dict(): - test_delete_training_pipeline(request_type=dict) - - def test_delete_training_pipeline_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -1623,9 +1733,10 @@ async def test_delete_training_pipeline_flattened_error_async(): ) -def test_cancel_training_pipeline( - transport: str = "grpc", request_type=pipeline_service.CancelTrainingPipelineRequest -): +@pytest.mark.parametrize( + "request_type", [pipeline_service.CancelTrainingPipelineRequest, dict,] +) +def test_cancel_training_pipeline(request_type, transport: str = "grpc"): client = PipelineServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -1651,10 +1762,6 @@ def test_cancel_training_pipeline( assert response is None -def test_cancel_training_pipeline_from_dict(): - test_cancel_training_pipeline(request_type=dict) - - def test_cancel_training_pipeline_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -1836,9 +1943,10 @@ async def test_cancel_training_pipeline_flattened_error_async(): ) -def test_create_pipeline_job( - transport: str = "grpc", request_type=pipeline_service.CreatePipelineJobRequest -): +@pytest.mark.parametrize( + "request_type", [pipeline_service.CreatePipelineJobRequest, dict,] +) +def test_create_pipeline_job(request_type, transport: str = "grpc"): client = PipelineServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -1875,10 +1983,6 @@ def test_create_pipeline_job( assert response.network == "network_value" -def test_create_pipeline_job_from_dict(): - test_create_pipeline_job(request_type=dict) - - def test_create_pipeline_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -2103,9 +2207,10 @@ async def test_create_pipeline_job_flattened_error_async(): ) -def test_get_pipeline_job( - transport: str = "grpc", request_type=pipeline_service.GetPipelineJobRequest -): +@pytest.mark.parametrize( + "request_type", [pipeline_service.GetPipelineJobRequest, dict,] +) +def test_get_pipeline_job(request_type, transport: str = "grpc"): client = PipelineServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -2140,10 +2245,6 @@ def test_get_pipeline_job( assert response.network == "network_value" -def test_get_pipeline_job_from_dict(): - test_get_pipeline_job(request_type=dict) - - def test_get_pipeline_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -2329,9 +2430,10 @@ async def test_get_pipeline_job_flattened_error_async(): ) -def test_list_pipeline_jobs( - transport: str = "grpc", request_type=pipeline_service.ListPipelineJobsRequest -): +@pytest.mark.parametrize( + "request_type", [pipeline_service.ListPipelineJobsRequest, dict,] +) +def test_list_pipeline_jobs(request_type, transport: str = "grpc"): client = PipelineServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -2360,10 +2462,6 @@ def test_list_pipeline_jobs( assert response.next_page_token == "next_page_token_value" -def test_list_pipeline_jobs_from_dict(): - test_list_pipeline_jobs(request_type=dict) - - def test_list_pipeline_jobs_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -2554,8 +2652,10 @@ async def test_list_pipeline_jobs_flattened_error_async(): ) -def test_list_pipeline_jobs_pager(): - client = PipelineServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_pipeline_jobs_pager(transport_name: str = "grpc"): + client = PipelineServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( @@ -2596,8 +2696,10 @@ def test_list_pipeline_jobs_pager(): assert all(isinstance(i, pipeline_job.PipelineJob) for i in results) -def test_list_pipeline_jobs_pages(): - client = PipelineServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_pipeline_jobs_pages(transport_name: str = "grpc"): + client = PipelineServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. 
with mock.patch.object( @@ -2712,9 +2814,10 @@ async def test_list_pipeline_jobs_async_pages(): assert page_.raw_page.next_page_token == token -def test_delete_pipeline_job( - transport: str = "grpc", request_type=pipeline_service.DeletePipelineJobRequest -): +@pytest.mark.parametrize( + "request_type", [pipeline_service.DeletePipelineJobRequest, dict,] +) +def test_delete_pipeline_job(request_type, transport: str = "grpc"): client = PipelineServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -2740,10 +2843,6 @@ def test_delete_pipeline_job( assert isinstance(response, future.Future) -def test_delete_pipeline_job_from_dict(): - test_delete_pipeline_job(request_type=dict) - - def test_delete_pipeline_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -2931,9 +3030,10 @@ async def test_delete_pipeline_job_flattened_error_async(): ) -def test_cancel_pipeline_job( - transport: str = "grpc", request_type=pipeline_service.CancelPipelineJobRequest -): +@pytest.mark.parametrize( + "request_type", [pipeline_service.CancelPipelineJobRequest, dict,] +) +def test_cancel_pipeline_job(request_type, transport: str = "grpc"): client = PipelineServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -2959,10 +3059,6 @@ def test_cancel_pipeline_job( assert response is None -def test_cancel_pipeline_job_from_dict(): - test_cancel_pipeline_job(request_type=dict) - - def test_cancel_pipeline_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -3164,6 +3260,23 @@ def test_credentials_transport_error(): transport=transport, ) + # It is an error to provide an api_key and a transport instance. 
+ transport = transports.PipelineServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + options = client_options.ClientOptions() + options.api_key = "api_key" + with pytest.raises(ValueError): + client = PipelineServiceClient(client_options=options, transport=transport,) + + # It is an error to provide an api_key and a credential. + options = mock.Mock() + options.api_key = "api_key" + with pytest.raises(ValueError): + client = PipelineServiceClient( + client_options=options, credentials=ga_credentials.AnonymousCredentials() + ) + # It is an error to provide scopes and a transport instance. transport = transports.PipelineServiceGrpcTransport( credentials=ga_credentials.AnonymousCredentials(), @@ -3911,7 +4024,7 @@ def test_parse_common_location_path(): assert expected == actual -def test_client_withDEFAULT_CLIENT_INFO(): +def test_client_with_default_client_info(): client_info = gapic_v1.client_info.ClientInfo() with mock.patch.object( @@ -3976,3 +4089,33 @@ def test_client_ctx(): with client: pass close.assert_called() + + +@pytest.mark.parametrize( + "client_class,transport_class", + [ + (PipelineServiceClient, transports.PipelineServiceGrpcTransport), + (PipelineServiceAsyncClient, transports.PipelineServiceGrpcAsyncIOTransport), + ], +) +def test_api_key_credentials(client_class, transport_class): + with mock.patch.object( + google.auth._default, "get_api_key_credentials", create=True + ) as get_api_key_credentials: + mock_cred = mock.Mock() + get_api_key_credentials.return_value = mock_cred + options = client_options.ClientOptions() + options.api_key = "api_key" + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=mock_cred, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + 
client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + ) diff --git a/tests/unit/gapic/aiplatform_v1/test_prediction_service.py b/tests/unit/gapic/aiplatform_v1/test_prediction_service.py index 8278921b195..bc3bf66974e 100644 --- a/tests/unit/gapic/aiplatform_v1/test_prediction_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_prediction_service.py @@ -254,20 +254,20 @@ def test_prediction_service_client_client_options( # unsupported value. with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): with pytest.raises(MutualTLSChannelError): - client = client_class() + client = client_class(transport=transport_name) # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. with mock.patch.dict( os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} ): with pytest.raises(ValueError): - client = client_class() + client = client_class(transport=transport_name) # Check the case quota_project_id is provided options = client_options.ClientOptions(quota_project_id="octopus") with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( credentials=None, credentials_file=None, @@ -336,7 +336,7 @@ def test_prediction_service_client_mtls_env_auto( ) with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) if use_client_cert_env == "false": expected_client_cert_source = None @@ -413,6 +413,87 @@ def test_prediction_service_client_mtls_env_auto( ) +@pytest.mark.parametrize( + "client_class", [PredictionServiceClient, PredictionServiceAsyncClient] +) +@mock.patch.object( + PredictionServiceClient, + 
"DEFAULT_ENDPOINT", + modify_default_endpoint(PredictionServiceClient), +) +@mock.patch.object( + PredictionServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(PredictionServiceAsyncClient), +) +def test_prediction_service_client_get_mtls_endpoint_and_cert_source(client_class): + mock_client_cert_source = mock.Mock() + + # Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is "true". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + mock_api_endpoint = "foo" + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint + ) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source( + options + ) + assert api_endpoint == mock_api_endpoint + assert cert_source == mock_client_cert_source + + # Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is "false". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "false"}): + mock_client_cert_source = mock.Mock() + mock_api_endpoint = "foo" + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint + ) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source( + options + ) + assert api_endpoint == mock_api_endpoint + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "never". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "always". 
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert doesn't exist. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert exists. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=mock_client_cert_source, + ): + ( + api_endpoint, + cert_source, + ) = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + assert cert_source == mock_client_cert_source + + @pytest.mark.parametrize( "client_class,transport_class,transport_name", [ @@ -431,7 +512,7 @@ def test_prediction_service_client_client_options_scopes( options = client_options.ClientOptions(scopes=["1", "2"],) with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( credentials=None, credentials_file=None, @@ -445,24 +526,31 @@ def test_prediction_service_client_client_options_scopes( @pytest.mark.parametrize( - "client_class,transport_class,transport_name", + 
"client_class,transport_class,transport_name,grpc_helpers", [ - (PredictionServiceClient, transports.PredictionServiceGrpcTransport, "grpc"), + ( + PredictionServiceClient, + transports.PredictionServiceGrpcTransport, + "grpc", + grpc_helpers, + ), ( PredictionServiceAsyncClient, transports.PredictionServiceGrpcAsyncIOTransport, "grpc_asyncio", + grpc_helpers_async, ), ], ) def test_prediction_service_client_client_options_credentials_file( - client_class, transport_class, transport_name + client_class, transport_class, transport_name, grpc_helpers ): # Check the case credentials file is provided. options = client_options.ClientOptions(credentials_file="credentials.json") + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( credentials=None, credentials_file="credentials.json", @@ -474,6 +562,35 @@ def test_prediction_service_client_client_options_credentials_file( always_use_jwt_access=True, ) + if "grpc" in transport_name: + # test that the credentials from file are saved and used as the credentials. 
+ with mock.patch.object( + google.auth, "load_credentials_from_file", autospec=True + ) as load_creds, mock.patch.object( + google.auth, "default", autospec=True + ) as adc, mock.patch.object( + grpc_helpers, "create_channel" + ) as create_channel: + creds = ga_credentials.AnonymousCredentials() + file_creds = ga_credentials.AnonymousCredentials() + load_creds.return_value = (file_creds, None) + adc.return_value = (creds, None) + client = client_class(client_options=options, transport=transport_name) + create_channel.assert_called_with( + "aiplatform.googleapis.com:443", + credentials=file_creds, + credentials_file=None, + quota_project_id=None, + default_scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=None, + default_host="aiplatform.googleapis.com", + ssl_credentials=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + def test_prediction_service_client_client_options_from_dict(): with mock.patch( @@ -495,9 +612,8 @@ def test_prediction_service_client_client_options_from_dict(): ) -def test_predict( - transport: str = "grpc", request_type=prediction_service.PredictRequest -): +@pytest.mark.parametrize("request_type", [prediction_service.PredictRequest, dict,]) +def test_predict(request_type, transport: str = "grpc"): client = PredictionServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -528,10 +644,6 @@ def test_predict( assert response.model_display_name == "model_display_name_value" -def test_predict_from_dict(): - test_predict(request_type=dict) - - def test_predict_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -672,9 +784,8 @@ async def test_predict_flattened_error_async(): ) -def test_raw_predict( - transport: str = "grpc", request_type=prediction_service.RawPredictRequest -): +@pytest.mark.parametrize("request_type", [prediction_service.RawPredictRequest, dict,]) +def test_raw_predict(request_type, transport: str = "grpc"): client = PredictionServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -702,10 +813,6 @@ def test_raw_predict( assert response.data == b"data_blob" -def test_raw_predict_from_dict(): - test_raw_predict(request_type=dict) - - def test_raw_predict_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -898,9 +1005,8 @@ async def test_raw_predict_flattened_error_async(): ) -def test_explain( - transport: str = "grpc", request_type=prediction_service.ExplainRequest -): +@pytest.mark.parametrize("request_type", [prediction_service.ExplainRequest, dict,]) +def test_explain(request_type, transport: str = "grpc"): client = PredictionServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -927,10 +1033,6 @@ def test_explain( assert response.deployed_model_id == "deployed_model_id_value" -def test_explain_from_dict(): - test_explain(request_type=dict) - - def test_explain_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -1089,6 +1191,23 @@ def test_credentials_transport_error(): transport=transport, ) + # It is an error to provide an api_key and a transport instance. 
+ transport = transports.PredictionServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + options = client_options.ClientOptions() + options.api_key = "api_key" + with pytest.raises(ValueError): + client = PredictionServiceClient(client_options=options, transport=transport,) + + # It is an error to provide an api_key and a credential. + options = mock.Mock() + options.api_key = "api_key" + with pytest.raises(ValueError): + client = PredictionServiceClient( + client_options=options, credentials=ga_credentials.AnonymousCredentials() + ) + # It is an error to provide scopes and a transport instance. transport = transports.PredictionServiceGrpcTransport( credentials=ga_credentials.AnonymousCredentials(), @@ -1609,7 +1728,7 @@ def test_parse_common_location_path(): assert expected == actual -def test_client_withDEFAULT_CLIENT_INFO(): +def test_client_with_default_client_info(): client_info = gapic_v1.client_info.ClientInfo() with mock.patch.object( @@ -1674,3 +1793,36 @@ def test_client_ctx(): with client: pass close.assert_called() + + +@pytest.mark.parametrize( + "client_class,transport_class", + [ + (PredictionServiceClient, transports.PredictionServiceGrpcTransport), + ( + PredictionServiceAsyncClient, + transports.PredictionServiceGrpcAsyncIOTransport, + ), + ], +) +def test_api_key_credentials(client_class, transport_class): + with mock.patch.object( + google.auth._default, "get_api_key_credentials", create=True + ) as get_api_key_credentials: + mock_cred = mock.Mock() + get_api_key_credentials.return_value = mock_cred + options = client_options.ClientOptions() + options.api_key = "api_key" + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=mock_cred, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + 
client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + ) diff --git a/tests/unit/gapic/aiplatform_v1/test_specialist_pool_service.py b/tests/unit/gapic/aiplatform_v1/test_specialist_pool_service.py index 4bf34049391..22f9df32409 100644 --- a/tests/unit/gapic/aiplatform_v1/test_specialist_pool_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_specialist_pool_service.py @@ -29,6 +29,7 @@ from google.api_core import gapic_v1 from google.api_core import grpc_helpers from google.api_core import grpc_helpers_async +from google.api_core import operation from google.api_core import operation_async # type: ignore from google.api_core import operations_v1 from google.api_core import path_template @@ -263,20 +264,20 @@ def test_specialist_pool_service_client_client_options( # unsupported value. with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): with pytest.raises(MutualTLSChannelError): - client = client_class() + client = client_class(transport=transport_name) # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. 
with mock.patch.dict( os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} ): with pytest.raises(ValueError): - client = client_class() + client = client_class(transport=transport_name) # Check the case quota_project_id is provided options = client_options.ClientOptions(quota_project_id="octopus") with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( credentials=None, credentials_file=None, @@ -345,7 +346,7 @@ def test_specialist_pool_service_client_mtls_env_auto( ) with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) if use_client_cert_env == "false": expected_client_cert_source = None @@ -422,6 +423,87 @@ def test_specialist_pool_service_client_mtls_env_auto( ) +@pytest.mark.parametrize( + "client_class", [SpecialistPoolServiceClient, SpecialistPoolServiceAsyncClient] +) +@mock.patch.object( + SpecialistPoolServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(SpecialistPoolServiceClient), +) +@mock.patch.object( + SpecialistPoolServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(SpecialistPoolServiceAsyncClient), +) +def test_specialist_pool_service_client_get_mtls_endpoint_and_cert_source(client_class): + mock_client_cert_source = mock.Mock() + + # Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is "true". 
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + mock_api_endpoint = "foo" + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint + ) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source( + options + ) + assert api_endpoint == mock_api_endpoint + assert cert_source == mock_client_cert_source + + # Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is "false". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "false"}): + mock_client_cert_source = mock.Mock() + mock_api_endpoint = "foo" + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint + ) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source( + options + ) + assert api_endpoint == mock_api_endpoint + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "never". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "always". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert doesn't exist. 
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert exists. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=mock_client_cert_source, + ): + ( + api_endpoint, + cert_source, + ) = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + assert cert_source == mock_client_cert_source + + @pytest.mark.parametrize( "client_class,transport_class,transport_name", [ @@ -444,7 +526,7 @@ def test_specialist_pool_service_client_client_options_scopes( options = client_options.ClientOptions(scopes=["1", "2"],) with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( credentials=None, credentials_file=None, @@ -458,28 +540,31 @@ def test_specialist_pool_service_client_client_options_scopes( @pytest.mark.parametrize( - "client_class,transport_class,transport_name", + "client_class,transport_class,transport_name,grpc_helpers", [ ( SpecialistPoolServiceClient, transports.SpecialistPoolServiceGrpcTransport, "grpc", + grpc_helpers, ), ( SpecialistPoolServiceAsyncClient, transports.SpecialistPoolServiceGrpcAsyncIOTransport, "grpc_asyncio", + grpc_helpers_async, ), ], ) def 
test_specialist_pool_service_client_client_options_credentials_file( - client_class, transport_class, transport_name + client_class, transport_class, transport_name, grpc_helpers ): # Check the case credentials file is provided. options = client_options.ClientOptions(credentials_file="credentials.json") + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( credentials=None, credentials_file="credentials.json", @@ -491,6 +576,35 @@ def test_specialist_pool_service_client_client_options_credentials_file( always_use_jwt_access=True, ) + if "grpc" in transport_name: + # test that the credentials from file are saved and used as the credentials. + with mock.patch.object( + google.auth, "load_credentials_from_file", autospec=True + ) as load_creds, mock.patch.object( + google.auth, "default", autospec=True + ) as adc, mock.patch.object( + grpc_helpers, "create_channel" + ) as create_channel: + creds = ga_credentials.AnonymousCredentials() + file_creds = ga_credentials.AnonymousCredentials() + load_creds.return_value = (file_creds, None) + adc.return_value = (creds, None) + client = client_class(client_options=options, transport=transport_name) + create_channel.assert_called_with( + "aiplatform.googleapis.com:443", + credentials=file_creds, + credentials_file=None, + quota_project_id=None, + default_scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=None, + default_host="aiplatform.googleapis.com", + ssl_credentials=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + def test_specialist_pool_service_client_client_options_from_dict(): with mock.patch( @@ -512,10 +626,10 @@ def test_specialist_pool_service_client_client_options_from_dict(): ) -def test_create_specialist_pool( - 
transport: str = "grpc", - request_type=specialist_pool_service.CreateSpecialistPoolRequest, -): +@pytest.mark.parametrize( + "request_type", [specialist_pool_service.CreateSpecialistPoolRequest, dict,] +) +def test_create_specialist_pool(request_type, transport: str = "grpc"): client = SpecialistPoolServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -541,10 +655,6 @@ def test_create_specialist_pool( assert isinstance(response, future.Future) -def test_create_specialist_pool_from_dict(): - test_create_specialist_pool(request_type=dict) - - def test_create_specialist_pool_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -754,10 +864,10 @@ async def test_create_specialist_pool_flattened_error_async(): ) -def test_get_specialist_pool( - transport: str = "grpc", - request_type=specialist_pool_service.GetSpecialistPoolRequest, -): +@pytest.mark.parametrize( + "request_type", [specialist_pool_service.GetSpecialistPoolRequest, dict,] +) +def test_get_specialist_pool(request_type, transport: str = "grpc"): client = SpecialistPoolServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -796,10 +906,6 @@ def test_get_specialist_pool( assert response.specialist_worker_emails == ["specialist_worker_emails_value"] -def test_get_specialist_pool_from_dict(): - test_get_specialist_pool(request_type=dict) - - def test_get_specialist_pool_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -1006,10 +1112,10 @@ async def test_get_specialist_pool_flattened_error_async(): ) -def test_list_specialist_pools( - transport: str = "grpc", - request_type=specialist_pool_service.ListSpecialistPoolsRequest, -): +@pytest.mark.parametrize( + "request_type", [specialist_pool_service.ListSpecialistPoolsRequest, dict,] +) +def test_list_specialist_pools(request_type, transport: str = "grpc"): client = SpecialistPoolServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -1038,10 +1144,6 @@ def test_list_specialist_pools( assert response.next_page_token == "next_page_token_value" -def test_list_specialist_pools_from_dict(): - test_list_specialist_pools(request_type=dict) - - def test_list_specialist_pools_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -1238,9 +1340,9 @@ async def test_list_specialist_pools_flattened_error_async(): ) -def test_list_specialist_pools_pager(): +def test_list_specialist_pools_pager(transport_name: str = "grpc"): client = SpecialistPoolServiceClient( - credentials=ga_credentials.AnonymousCredentials, + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, ) # Mock the actual call within the gRPC stub, and fake the request. @@ -1286,9 +1388,9 @@ def test_list_specialist_pools_pager(): assert all(isinstance(i, specialist_pool.SpecialistPool) for i in results) -def test_list_specialist_pools_pages(): +def test_list_specialist_pools_pages(transport_name: str = "grpc"): client = SpecialistPoolServiceClient( - credentials=ga_credentials.AnonymousCredentials, + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, ) # Mock the actual call within the gRPC stub, and fake the request. 
@@ -1416,10 +1518,10 @@ async def test_list_specialist_pools_async_pages(): assert page_.raw_page.next_page_token == token -def test_delete_specialist_pool( - transport: str = "grpc", - request_type=specialist_pool_service.DeleteSpecialistPoolRequest, -): +@pytest.mark.parametrize( + "request_type", [specialist_pool_service.DeleteSpecialistPoolRequest, dict,] +) +def test_delete_specialist_pool(request_type, transport: str = "grpc"): client = SpecialistPoolServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -1445,10 +1547,6 @@ def test_delete_specialist_pool( assert isinstance(response, future.Future) -def test_delete_specialist_pool_from_dict(): - test_delete_specialist_pool(request_type=dict) - - def test_delete_specialist_pool_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -1642,10 +1740,10 @@ async def test_delete_specialist_pool_flattened_error_async(): ) -def test_update_specialist_pool( - transport: str = "grpc", - request_type=specialist_pool_service.UpdateSpecialistPoolRequest, -): +@pytest.mark.parametrize( + "request_type", [specialist_pool_service.UpdateSpecialistPoolRequest, dict,] +) +def test_update_specialist_pool(request_type, transport: str = "grpc"): client = SpecialistPoolServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -1671,10 +1769,6 @@ def test_update_specialist_pool( assert isinstance(response, future.Future) -def test_update_specialist_pool_from_dict(): - test_update_specialist_pool(request_type=dict) - - def test_update_specialist_pool_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -1910,6 +2004,25 @@ def test_credentials_transport_error(): transport=transport, ) + # It is an error to provide an api_key and a transport instance. 
+ transport = transports.SpecialistPoolServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + options = client_options.ClientOptions() + options.api_key = "api_key" + with pytest.raises(ValueError): + client = SpecialistPoolServiceClient( + client_options=options, transport=transport, + ) + + # It is an error to provide an api_key and a credential. + options = mock.Mock() + options.api_key = "api_key" + with pytest.raises(ValueError): + client = SpecialistPoolServiceClient( + client_options=options, credentials=ga_credentials.AnonymousCredentials() + ) + # It is an error to provide scopes and a transport instance. transport = transports.SpecialistPoolServiceGrpcTransport( credentials=ga_credentials.AnonymousCredentials(), @@ -2447,7 +2560,7 @@ def test_parse_common_location_path(): assert expected == actual -def test_client_withDEFAULT_CLIENT_INFO(): +def test_client_with_default_client_info(): client_info = gapic_v1.client_info.ClientInfo() with mock.patch.object( @@ -2512,3 +2625,36 @@ def test_client_ctx(): with client: pass close.assert_called() + + +@pytest.mark.parametrize( + "client_class,transport_class", + [ + (SpecialistPoolServiceClient, transports.SpecialistPoolServiceGrpcTransport), + ( + SpecialistPoolServiceAsyncClient, + transports.SpecialistPoolServiceGrpcAsyncIOTransport, + ), + ], +) +def test_api_key_credentials(client_class, transport_class): + with mock.patch.object( + google.auth._default, "get_api_key_credentials", create=True + ) as get_api_key_credentials: + mock_cred = mock.Mock() + get_api_key_credentials.return_value = mock_cred + options = client_options.ClientOptions() + options.api_key = "api_key" + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=mock_cred, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + 
quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + ) diff --git a/tests/unit/gapic/aiplatform_v1/test_tensorboard_service.py b/tests/unit/gapic/aiplatform_v1/test_tensorboard_service.py index 0a5b116fe3c..27782cd2d9e 100644 --- a/tests/unit/gapic/aiplatform_v1/test_tensorboard_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_tensorboard_service.py @@ -29,6 +29,7 @@ from google.api_core import gapic_v1 from google.api_core import grpc_helpers from google.api_core import grpc_helpers_async +from google.api_core import operation from google.api_core import operation_async # type: ignore from google.api_core import operations_v1 from google.api_core import path_template @@ -272,20 +273,20 @@ def test_tensorboard_service_client_client_options( # unsupported value. with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): with pytest.raises(MutualTLSChannelError): - client = client_class() + client = client_class(transport=transport_name) # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. 
with mock.patch.dict( os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} ): with pytest.raises(ValueError): - client = client_class() + client = client_class(transport=transport_name) # Check the case quota_project_id is provided options = client_options.ClientOptions(quota_project_id="octopus") with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( credentials=None, credentials_file=None, @@ -354,7 +355,7 @@ def test_tensorboard_service_client_mtls_env_auto( ) with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) if use_client_cert_env == "false": expected_client_cert_source = None @@ -431,6 +432,87 @@ def test_tensorboard_service_client_mtls_env_auto( ) +@pytest.mark.parametrize( + "client_class", [TensorboardServiceClient, TensorboardServiceAsyncClient] +) +@mock.patch.object( + TensorboardServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(TensorboardServiceClient), +) +@mock.patch.object( + TensorboardServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(TensorboardServiceAsyncClient), +) +def test_tensorboard_service_client_get_mtls_endpoint_and_cert_source(client_class): + mock_client_cert_source = mock.Mock() + + # Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is "true". 
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + mock_api_endpoint = "foo" + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint + ) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source( + options + ) + assert api_endpoint == mock_api_endpoint + assert cert_source == mock_client_cert_source + + # Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is "false". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "false"}): + mock_client_cert_source = mock.Mock() + mock_api_endpoint = "foo" + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint + ) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source( + options + ) + assert api_endpoint == mock_api_endpoint + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "never". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "always". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert doesn't exist. 
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert exists. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=mock_client_cert_source, + ): + ( + api_endpoint, + cert_source, + ) = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + assert cert_source == mock_client_cert_source + + @pytest.mark.parametrize( "client_class,transport_class,transport_name", [ @@ -449,7 +531,7 @@ def test_tensorboard_service_client_client_options_scopes( options = client_options.ClientOptions(scopes=["1", "2"],) with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( credentials=None, credentials_file=None, @@ -463,24 +545,31 @@ def test_tensorboard_service_client_client_options_scopes( @pytest.mark.parametrize( - "client_class,transport_class,transport_name", + "client_class,transport_class,transport_name,grpc_helpers", [ - (TensorboardServiceClient, transports.TensorboardServiceGrpcTransport, "grpc"), + ( + TensorboardServiceClient, + transports.TensorboardServiceGrpcTransport, + "grpc", + grpc_helpers, + ), ( TensorboardServiceAsyncClient, transports.TensorboardServiceGrpcAsyncIOTransport, "grpc_asyncio", + 
grpc_helpers_async, ), ], ) def test_tensorboard_service_client_client_options_credentials_file( - client_class, transport_class, transport_name + client_class, transport_class, transport_name, grpc_helpers ): # Check the case credentials file is provided. options = client_options.ClientOptions(credentials_file="credentials.json") + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( credentials=None, credentials_file="credentials.json", @@ -492,6 +581,38 @@ def test_tensorboard_service_client_client_options_credentials_file( always_use_jwt_access=True, ) + if "grpc" in transport_name: + # test that the credentials from file are saved and used as the credentials. + with mock.patch.object( + google.auth, "load_credentials_from_file", autospec=True + ) as load_creds, mock.patch.object( + google.auth, "default", autospec=True + ) as adc, mock.patch.object( + grpc_helpers, "create_channel" + ) as create_channel: + creds = ga_credentials.AnonymousCredentials() + file_creds = ga_credentials.AnonymousCredentials() + load_creds.return_value = (file_creds, None) + adc.return_value = (creds, None) + client = client_class(client_options=options, transport=transport_name) + create_channel.assert_called_with( + "aiplatform.googleapis.com:443", + credentials=file_creds, + credentials_file=None, + quota_project_id=None, + default_scopes=( + "https://www.googleapis.com/auth/cloud-platform", + "https://www.googleapis.com/auth/cloud-platform.read-only", + ), + scopes=None, + default_host="aiplatform.googleapis.com", + ssl_credentials=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + def test_tensorboard_service_client_client_options_from_dict(): with mock.patch( @@ -513,9 +634,10 @@ def 
test_tensorboard_service_client_client_options_from_dict(): ) -def test_create_tensorboard( - transport: str = "grpc", request_type=tensorboard_service.CreateTensorboardRequest -): +@pytest.mark.parametrize( + "request_type", [tensorboard_service.CreateTensorboardRequest, dict,] +) +def test_create_tensorboard(request_type, transport: str = "grpc"): client = TensorboardServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -541,10 +663,6 @@ def test_create_tensorboard( assert isinstance(response, future.Future) -def test_create_tensorboard_from_dict(): - test_create_tensorboard(request_type=dict) - - def test_create_tensorboard_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -754,9 +872,10 @@ async def test_create_tensorboard_flattened_error_async(): ) -def test_get_tensorboard( - transport: str = "grpc", request_type=tensorboard_service.GetTensorboardRequest -): +@pytest.mark.parametrize( + "request_type", [tensorboard_service.GetTensorboardRequest, dict,] +) +def test_get_tensorboard(request_type, transport: str = "grpc"): client = TensorboardServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -793,10 +912,6 @@ def test_get_tensorboard( assert response.etag == "etag_value" -def test_get_tensorboard_from_dict(): - test_get_tensorboard(request_type=dict) - - def test_get_tensorboard_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -991,9 +1106,10 @@ async def test_get_tensorboard_flattened_error_async(): ) -def test_update_tensorboard( - transport: str = "grpc", request_type=tensorboard_service.UpdateTensorboardRequest -): +@pytest.mark.parametrize( + "request_type", [tensorboard_service.UpdateTensorboardRequest, dict,] +) +def test_update_tensorboard(request_type, transport: str = "grpc"): client = TensorboardServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -1019,10 +1135,6 @@ def test_update_tensorboard( assert isinstance(response, future.Future) -def test_update_tensorboard_from_dict(): - test_update_tensorboard(request_type=dict) - - def test_update_tensorboard_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -1236,9 +1348,10 @@ async def test_update_tensorboard_flattened_error_async(): ) -def test_list_tensorboards( - transport: str = "grpc", request_type=tensorboard_service.ListTensorboardsRequest -): +@pytest.mark.parametrize( + "request_type", [tensorboard_service.ListTensorboardsRequest, dict,] +) +def test_list_tensorboards(request_type, transport: str = "grpc"): client = TensorboardServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -1267,10 +1380,6 @@ def test_list_tensorboards( assert response.next_page_token == "next_page_token_value" -def test_list_tensorboards_from_dict(): - test_list_tensorboards(request_type=dict) - - def test_list_tensorboards_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -1467,8 +1576,10 @@ async def test_list_tensorboards_flattened_error_async(): ) -def test_list_tensorboards_pager(): - client = TensorboardServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_tensorboards_pager(transport_name: str = "grpc"): + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( @@ -1509,8 +1620,10 @@ def test_list_tensorboards_pager(): assert all(isinstance(i, tensorboard.Tensorboard) for i in results) -def test_list_tensorboards_pages(): - client = TensorboardServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_tensorboards_pages(transport_name: str = "grpc"): + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( @@ -1625,9 +1738,10 @@ async def test_list_tensorboards_async_pages(): assert page_.raw_page.next_page_token == token -def test_delete_tensorboard( - transport: str = "grpc", request_type=tensorboard_service.DeleteTensorboardRequest -): +@pytest.mark.parametrize( + "request_type", [tensorboard_service.DeleteTensorboardRequest, dict,] +) +def test_delete_tensorboard(request_type, transport: str = "grpc"): client = TensorboardServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -1653,10 +1767,6 @@ def test_delete_tensorboard( assert isinstance(response, future.Future) -def test_delete_tensorboard_from_dict(): - test_delete_tensorboard(request_type=dict) - - def test_delete_tensorboard_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -1850,10 +1960,10 @@ async def test_delete_tensorboard_flattened_error_async(): ) -def test_create_tensorboard_experiment( - transport: str = "grpc", - request_type=tensorboard_service.CreateTensorboardExperimentRequest, -): +@pytest.mark.parametrize( + "request_type", [tensorboard_service.CreateTensorboardExperimentRequest, dict,] +) +def test_create_tensorboard_experiment(request_type, transport: str = "grpc"): client = TensorboardServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -1890,10 +2000,6 @@ def test_create_tensorboard_experiment( assert response.source == "source_value" -def test_create_tensorboard_experiment_from_dict(): - test_create_tensorboard_experiment(request_type=dict) - - def test_create_tensorboard_experiment_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -2132,10 +2238,10 @@ async def test_create_tensorboard_experiment_flattened_error_async(): ) -def test_get_tensorboard_experiment( - transport: str = "grpc", - request_type=tensorboard_service.GetTensorboardExperimentRequest, -): +@pytest.mark.parametrize( + "request_type", [tensorboard_service.GetTensorboardExperimentRequest, dict,] +) +def test_get_tensorboard_experiment(request_type, transport: str = "grpc"): client = TensorboardServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -2172,10 +2278,6 @@ def test_get_tensorboard_experiment( assert response.source == "source_value" -def test_get_tensorboard_experiment_from_dict(): - test_get_tensorboard_experiment(request_type=dict) - - def test_get_tensorboard_experiment_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -2380,10 +2482,10 @@ async def test_get_tensorboard_experiment_flattened_error_async(): ) -def test_update_tensorboard_experiment( - transport: str = "grpc", - request_type=tensorboard_service.UpdateTensorboardExperimentRequest, -): +@pytest.mark.parametrize( + "request_type", [tensorboard_service.UpdateTensorboardExperimentRequest, dict,] +) +def test_update_tensorboard_experiment(request_type, transport: str = "grpc"): client = TensorboardServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -2420,10 +2522,6 @@ def test_update_tensorboard_experiment( assert response.source == "source_value" -def test_update_tensorboard_experiment_from_dict(): - test_update_tensorboard_experiment(request_type=dict) - - def test_update_tensorboard_experiment_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -2658,10 +2756,10 @@ async def test_update_tensorboard_experiment_flattened_error_async(): ) -def test_list_tensorboard_experiments( - transport: str = "grpc", - request_type=tensorboard_service.ListTensorboardExperimentsRequest, -): +@pytest.mark.parametrize( + "request_type", [tensorboard_service.ListTensorboardExperimentsRequest, dict,] +) +def test_list_tensorboard_experiments(request_type, transport: str = "grpc"): client = TensorboardServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -2690,10 +2788,6 @@ def test_list_tensorboard_experiments( assert response.next_page_token == "next_page_token_value" -def test_list_tensorboard_experiments_from_dict(): - test_list_tensorboard_experiments(request_type=dict) - - def test_list_tensorboard_experiments_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -2892,8 +2986,10 @@ async def test_list_tensorboard_experiments_flattened_error_async(): ) -def test_list_tensorboard_experiments_pager(): - client = TensorboardServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_tensorboard_experiments_pager(transport_name: str = "grpc"): + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( @@ -2942,8 +3038,10 @@ def test_list_tensorboard_experiments_pager(): ) -def test_list_tensorboard_experiments_pages(): - client = TensorboardServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_tensorboard_experiments_pages(transport_name: str = "grpc"): + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( @@ -3081,10 +3179,10 @@ async def test_list_tensorboard_experiments_async_pages(): assert page_.raw_page.next_page_token == token -def test_delete_tensorboard_experiment( - transport: str = "grpc", - request_type=tensorboard_service.DeleteTensorboardExperimentRequest, -): +@pytest.mark.parametrize( + "request_type", [tensorboard_service.DeleteTensorboardExperimentRequest, dict,] +) +def test_delete_tensorboard_experiment(request_type, transport: str = "grpc"): client = TensorboardServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -3110,10 +3208,6 @@ def test_delete_tensorboard_experiment( assert isinstance(response, future.Future) -def test_delete_tensorboard_experiment_from_dict(): - test_delete_tensorboard_experiment(request_type=dict) - - def test_delete_tensorboard_experiment_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -3307,10 +3401,10 @@ async def test_delete_tensorboard_experiment_flattened_error_async(): ) -def test_create_tensorboard_run( - transport: str = "grpc", - request_type=tensorboard_service.CreateTensorboardRunRequest, -): +@pytest.mark.parametrize( + "request_type", [tensorboard_service.CreateTensorboardRunRequest, dict,] +) +def test_create_tensorboard_run(request_type, transport: str = "grpc"): client = TensorboardServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -3345,10 +3439,6 @@ def test_create_tensorboard_run( assert response.etag == "etag_value" -def test_create_tensorboard_run_from_dict(): - test_create_tensorboard_run(request_type=dict) - - def test_create_tensorboard_run_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -3577,10 +3667,10 @@ async def test_create_tensorboard_run_flattened_error_async(): ) -def test_batch_create_tensorboard_runs( - transport: str = "grpc", - request_type=tensorboard_service.BatchCreateTensorboardRunsRequest, -): +@pytest.mark.parametrize( + "request_type", [tensorboard_service.BatchCreateTensorboardRunsRequest, dict,] +) +def test_batch_create_tensorboard_runs(request_type, transport: str = "grpc"): client = TensorboardServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -3606,10 +3696,6 @@ def test_batch_create_tensorboard_runs( assert isinstance(response, tensorboard_service.BatchCreateTensorboardRunsResponse) -def test_batch_create_tensorboard_runs_from_dict(): - test_batch_create_tensorboard_runs(request_type=dict) - - def test_batch_create_tensorboard_runs_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -3831,9 +3917,10 @@ async def test_batch_create_tensorboard_runs_flattened_error_async(): ) -def test_get_tensorboard_run( - transport: str = "grpc", request_type=tensorboard_service.GetTensorboardRunRequest -): +@pytest.mark.parametrize( + "request_type", [tensorboard_service.GetTensorboardRunRequest, dict,] +) +def test_get_tensorboard_run(request_type, transport: str = "grpc"): client = TensorboardServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -3868,10 +3955,6 @@ def test_get_tensorboard_run( assert response.etag == "etag_value" -def test_get_tensorboard_run_from_dict(): - test_get_tensorboard_run(request_type=dict) - - def test_get_tensorboard_run_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -4074,10 +4157,10 @@ async def test_get_tensorboard_run_flattened_error_async(): ) -def test_update_tensorboard_run( - transport: str = "grpc", - request_type=tensorboard_service.UpdateTensorboardRunRequest, -): +@pytest.mark.parametrize( + "request_type", [tensorboard_service.UpdateTensorboardRunRequest, dict,] +) +def test_update_tensorboard_run(request_type, transport: str = "grpc"): client = TensorboardServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -4112,10 +4195,6 @@ def test_update_tensorboard_run( assert response.etag == "etag_value" -def test_update_tensorboard_run_from_dict(): - test_update_tensorboard_run(request_type=dict) - - def test_update_tensorboard_run_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -4340,9 +4419,10 @@ async def test_update_tensorboard_run_flattened_error_async(): ) -def test_list_tensorboard_runs( - transport: str = "grpc", request_type=tensorboard_service.ListTensorboardRunsRequest -): +@pytest.mark.parametrize( + "request_type", [tensorboard_service.ListTensorboardRunsRequest, dict,] +) +def test_list_tensorboard_runs(request_type, transport: str = "grpc"): client = TensorboardServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -4371,10 +4451,6 @@ def test_list_tensorboard_runs( assert response.next_page_token == "next_page_token_value" -def test_list_tensorboard_runs_from_dict(): - test_list_tensorboard_runs(request_type=dict) - - def test_list_tensorboard_runs_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -4571,8 +4647,10 @@ async def test_list_tensorboard_runs_flattened_error_async(): ) -def test_list_tensorboard_runs_pager(): - client = TensorboardServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_tensorboard_runs_pager(transport_name: str = "grpc"): + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( @@ -4617,8 +4695,10 @@ def test_list_tensorboard_runs_pager(): assert all(isinstance(i, tensorboard_run.TensorboardRun) for i in results) -def test_list_tensorboard_runs_pages(): - client = TensorboardServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_tensorboard_runs_pages(transport_name: str = "grpc"): + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. 
with mock.patch.object( @@ -4745,10 +4825,10 @@ async def test_list_tensorboard_runs_async_pages(): assert page_.raw_page.next_page_token == token -def test_delete_tensorboard_run( - transport: str = "grpc", - request_type=tensorboard_service.DeleteTensorboardRunRequest, -): +@pytest.mark.parametrize( + "request_type", [tensorboard_service.DeleteTensorboardRunRequest, dict,] +) +def test_delete_tensorboard_run(request_type, transport: str = "grpc"): client = TensorboardServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -4774,10 +4854,6 @@ def test_delete_tensorboard_run( assert isinstance(response, future.Future) -def test_delete_tensorboard_run_from_dict(): - test_delete_tensorboard_run(request_type=dict) - - def test_delete_tensorboard_run_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -4971,10 +5047,10 @@ async def test_delete_tensorboard_run_flattened_error_async(): ) -def test_batch_create_tensorboard_time_series( - transport: str = "grpc", - request_type=tensorboard_service.BatchCreateTensorboardTimeSeriesRequest, -): +@pytest.mark.parametrize( + "request_type", [tensorboard_service.BatchCreateTensorboardTimeSeriesRequest, dict,] +) +def test_batch_create_tensorboard_time_series(request_type, transport: str = "grpc"): client = TensorboardServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -5004,10 +5080,6 @@ def test_batch_create_tensorboard_time_series( ) -def test_batch_create_tensorboard_time_series_from_dict(): - test_batch_create_tensorboard_time_series(request_type=dict) - - def test_batch_create_tensorboard_time_series_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -5249,10 +5321,10 @@ async def test_batch_create_tensorboard_time_series_flattened_error_async(): ) -def test_create_tensorboard_time_series( - transport: str = "grpc", - request_type=tensorboard_service.CreateTensorboardTimeSeriesRequest, -): +@pytest.mark.parametrize( + "request_type", [tensorboard_service.CreateTensorboardTimeSeriesRequest, dict,] +) +def test_create_tensorboard_time_series(request_type, transport: str = "grpc"): client = TensorboardServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -5296,10 +5368,6 @@ def test_create_tensorboard_time_series( assert response.plugin_data == b"plugin_data_blob" -def test_create_tensorboard_time_series_from_dict(): - test_create_tensorboard_time_series(request_type=dict) - - def test_create_tensorboard_time_series_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -5535,10 +5603,10 @@ async def test_create_tensorboard_time_series_flattened_error_async(): ) -def test_get_tensorboard_time_series( - transport: str = "grpc", - request_type=tensorboard_service.GetTensorboardTimeSeriesRequest, -): +@pytest.mark.parametrize( + "request_type", [tensorboard_service.GetTensorboardTimeSeriesRequest, dict,] +) +def test_get_tensorboard_time_series(request_type, transport: str = "grpc"): client = TensorboardServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -5582,10 +5650,6 @@ def test_get_tensorboard_time_series( assert response.plugin_data == b"plugin_data_blob" -def test_get_tensorboard_time_series_from_dict(): - test_get_tensorboard_time_series(request_type=dict) - - def test_get_tensorboard_time_series_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -5797,10 +5861,10 @@ async def test_get_tensorboard_time_series_flattened_error_async(): ) -def test_update_tensorboard_time_series( - transport: str = "grpc", - request_type=tensorboard_service.UpdateTensorboardTimeSeriesRequest, -): +@pytest.mark.parametrize( + "request_type", [tensorboard_service.UpdateTensorboardTimeSeriesRequest, dict,] +) +def test_update_tensorboard_time_series(request_type, transport: str = "grpc"): client = TensorboardServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -5844,10 +5908,6 @@ def test_update_tensorboard_time_series( assert response.plugin_data == b"plugin_data_blob" -def test_update_tensorboard_time_series_from_dict(): - test_update_tensorboard_time_series(request_type=dict) - - def test_update_tensorboard_time_series_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -6089,10 +6149,10 @@ async def test_update_tensorboard_time_series_flattened_error_async(): ) -def test_list_tensorboard_time_series( - transport: str = "grpc", - request_type=tensorboard_service.ListTensorboardTimeSeriesRequest, -): +@pytest.mark.parametrize( + "request_type", [tensorboard_service.ListTensorboardTimeSeriesRequest, dict,] +) +def test_list_tensorboard_time_series(request_type, transport: str = "grpc"): client = TensorboardServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -6121,10 +6181,6 @@ def test_list_tensorboard_time_series( assert response.next_page_token == "next_page_token_value" -def test_list_tensorboard_time_series_from_dict(): - test_list_tensorboard_time_series(request_type=dict) - - def test_list_tensorboard_time_series_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -6323,8 +6379,10 @@ async def test_list_tensorboard_time_series_flattened_error_async(): ) -def test_list_tensorboard_time_series_pager(): - client = TensorboardServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_tensorboard_time_series_pager(transport_name: str = "grpc"): + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( @@ -6374,8 +6432,10 @@ def test_list_tensorboard_time_series_pager(): ) -def test_list_tensorboard_time_series_pages(): - client = TensorboardServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_tensorboard_time_series_pages(transport_name: str = "grpc"): + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( @@ -6513,10 +6573,10 @@ async def test_list_tensorboard_time_series_async_pages(): assert page_.raw_page.next_page_token == token -def test_delete_tensorboard_time_series( - transport: str = "grpc", - request_type=tensorboard_service.DeleteTensorboardTimeSeriesRequest, -): +@pytest.mark.parametrize( + "request_type", [tensorboard_service.DeleteTensorboardTimeSeriesRequest, dict,] +) +def test_delete_tensorboard_time_series(request_type, transport: str = "grpc"): client = TensorboardServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -6542,10 +6602,6 @@ def test_delete_tensorboard_time_series( assert isinstance(response, future.Future) -def test_delete_tensorboard_time_series_from_dict(): - test_delete_tensorboard_time_series(request_type=dict) - - def test_delete_tensorboard_time_series_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -6739,10 +6795,11 @@ async def test_delete_tensorboard_time_series_flattened_error_async(): ) -def test_batch_read_tensorboard_time_series_data( - transport: str = "grpc", - request_type=tensorboard_service.BatchReadTensorboardTimeSeriesDataRequest, -): +@pytest.mark.parametrize( + "request_type", + [tensorboard_service.BatchReadTensorboardTimeSeriesDataRequest, dict,], +) +def test_batch_read_tensorboard_time_series_data(request_type, transport: str = "grpc"): client = TensorboardServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -6774,10 +6831,6 @@ def test_batch_read_tensorboard_time_series_data( ) -def test_batch_read_tensorboard_time_series_data_from_dict(): - test_batch_read_tensorboard_time_series_data(request_type=dict) - - def test_batch_read_tensorboard_time_series_data_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -6987,10 +7040,10 @@ async def test_batch_read_tensorboard_time_series_data_flattened_error_async(): ) -def test_read_tensorboard_time_series_data( - transport: str = "grpc", - request_type=tensorboard_service.ReadTensorboardTimeSeriesDataRequest, -): +@pytest.mark.parametrize( + "request_type", [tensorboard_service.ReadTensorboardTimeSeriesDataRequest, dict,] +) +def test_read_tensorboard_time_series_data(request_type, transport: str = "grpc"): client = TensorboardServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -7018,10 +7071,6 @@ def test_read_tensorboard_time_series_data( ) -def test_read_tensorboard_time_series_data_from_dict(): - test_read_tensorboard_time_series_data(request_type=dict) - - def test_read_tensorboard_time_series_data_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -7229,10 +7278,10 @@ async def test_read_tensorboard_time_series_data_flattened_error_async(): ) -def test_read_tensorboard_blob_data( - transport: str = "grpc", - request_type=tensorboard_service.ReadTensorboardBlobDataRequest, -): +@pytest.mark.parametrize( + "request_type", [tensorboard_service.ReadTensorboardBlobDataRequest, dict,] +) +def test_read_tensorboard_blob_data(request_type, transport: str = "grpc"): client = TensorboardServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -7261,10 +7310,6 @@ def test_read_tensorboard_blob_data( assert isinstance(message, tensorboard_service.ReadTensorboardBlobDataResponse) -def test_read_tensorboard_blob_data_from_dict(): - test_read_tensorboard_blob_data(request_type=dict) - - def test_read_tensorboard_blob_data_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -7469,10 +7514,10 @@ async def test_read_tensorboard_blob_data_flattened_error_async(): ) -def test_write_tensorboard_experiment_data( - transport: str = "grpc", - request_type=tensorboard_service.WriteTensorboardExperimentDataRequest, -): +@pytest.mark.parametrize( + "request_type", [tensorboard_service.WriteTensorboardExperimentDataRequest, dict,] +) +def test_write_tensorboard_experiment_data(request_type, transport: str = "grpc"): client = TensorboardServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -7500,10 +7545,6 @@ def test_write_tensorboard_experiment_data( ) -def test_write_tensorboard_experiment_data_from_dict(): - test_write_tensorboard_experiment_data(request_type=dict) - - def test_write_tensorboard_experiment_data_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -7745,10 +7786,10 @@ async def test_write_tensorboard_experiment_data_flattened_error_async(): ) -def test_write_tensorboard_run_data( - transport: str = "grpc", - request_type=tensorboard_service.WriteTensorboardRunDataRequest, -): +@pytest.mark.parametrize( + "request_type", [tensorboard_service.WriteTensorboardRunDataRequest, dict,] +) +def test_write_tensorboard_run_data(request_type, transport: str = "grpc"): client = TensorboardServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -7774,10 +7815,6 @@ def test_write_tensorboard_run_data( assert isinstance(response, tensorboard_service.WriteTensorboardRunDataResponse) -def test_write_tensorboard_run_data_from_dict(): - test_write_tensorboard_run_data(request_type=dict) - - def test_write_tensorboard_run_data_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -8015,10 +8052,10 @@ async def test_write_tensorboard_run_data_flattened_error_async(): ) -def test_export_tensorboard_time_series_data( - transport: str = "grpc", - request_type=tensorboard_service.ExportTensorboardTimeSeriesDataRequest, -): +@pytest.mark.parametrize( + "request_type", [tensorboard_service.ExportTensorboardTimeSeriesDataRequest, dict,] +) +def test_export_tensorboard_time_series_data(request_type, transport: str = "grpc"): client = TensorboardServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -8047,10 +8084,6 @@ def test_export_tensorboard_time_series_data( assert response.next_page_token == "next_page_token_value" -def test_export_tensorboard_time_series_data_from_dict(): - test_export_tensorboard_time_series_data(request_type=dict) - - def test_export_tensorboard_time_series_data_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -8265,8 +8298,10 @@ async def test_export_tensorboard_time_series_data_flattened_error_async(): ) -def test_export_tensorboard_time_series_data_pager(): - client = TensorboardServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_export_tensorboard_time_series_data_pager(transport_name: str = "grpc"): + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( @@ -8313,8 +8348,10 @@ def test_export_tensorboard_time_series_data_pager(): assert all(isinstance(i, tensorboard_data.TimeSeriesDataPoint) for i in results) -def test_export_tensorboard_time_series_data_pages(): - client = TensorboardServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_export_tensorboard_time_series_data_pages(transport_name: str = "grpc"): + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( @@ -8465,6 +8502,23 @@ def test_credentials_transport_error(): transport=transport, ) + # It is an error to provide an api_key and a transport instance. + transport = transports.TensorboardServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + options = client_options.ClientOptions() + options.api_key = "api_key" + with pytest.raises(ValueError): + client = TensorboardServiceClient(client_options=options, transport=transport,) + + # It is an error to provide an api_key and a credential. + options = mock.Mock() + options.api_key = "api_key" + with pytest.raises(ValueError): + client = TensorboardServiceClient( + client_options=options, credentials=ga_credentials.AnonymousCredentials() + ) + # It is an error to provide scopes and a transport instance. 
transport = transports.TensorboardServiceGrpcTransport( credentials=ga_credentials.AnonymousCredentials(), @@ -9135,7 +9189,7 @@ def test_parse_common_location_path(): assert expected == actual -def test_client_withDEFAULT_CLIENT_INFO(): +def test_client_with_default_client_info(): client_info = gapic_v1.client_info.ClientInfo() with mock.patch.object( @@ -9200,3 +9254,36 @@ def test_client_ctx(): with client: pass close.assert_called() + + +@pytest.mark.parametrize( + "client_class,transport_class", + [ + (TensorboardServiceClient, transports.TensorboardServiceGrpcTransport), + ( + TensorboardServiceAsyncClient, + transports.TensorboardServiceGrpcAsyncIOTransport, + ), + ], +) +def test_api_key_credentials(client_class, transport_class): + with mock.patch.object( + google.auth._default, "get_api_key_credentials", create=True + ) as get_api_key_credentials: + mock_cred = mock.Mock() + get_api_key_credentials.return_value = mock_cred + options = client_options.ClientOptions() + options.api_key = "api_key" + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=mock_cred, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + ) diff --git a/tests/unit/gapic/aiplatform_v1/test_vizier_service.py b/tests/unit/gapic/aiplatform_v1/test_vizier_service.py index 5050a308071..b8fca485def 100644 --- a/tests/unit/gapic/aiplatform_v1/test_vizier_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_vizier_service.py @@ -29,6 +29,7 @@ from google.api_core import gapic_v1 from google.api_core import grpc_helpers from google.api_core import grpc_helpers_async +from google.api_core import operation from google.api_core import operation_async # type: ignore from google.api_core import 
operations_v1 from google.api_core import path_template @@ -255,20 +256,20 @@ def test_vizier_service_client_client_options( # unsupported value. with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): with pytest.raises(MutualTLSChannelError): - client = client_class() + client = client_class(transport=transport_name) # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. with mock.patch.dict( os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} ): with pytest.raises(ValueError): - client = client_class() + client = client_class(transport=transport_name) # Check the case quota_project_id is provided options = client_options.ClientOptions(quota_project_id="octopus") with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( credentials=None, credentials_file=None, @@ -327,7 +328,7 @@ def test_vizier_service_client_mtls_env_auto( ) with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) if use_client_cert_env == "false": expected_client_cert_source = None @@ -404,6 +405,87 @@ def test_vizier_service_client_mtls_env_auto( ) +@pytest.mark.parametrize( + "client_class", [VizierServiceClient, VizierServiceAsyncClient] +) +@mock.patch.object( + VizierServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(VizierServiceClient), +) +@mock.patch.object( + VizierServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(VizierServiceAsyncClient), +) +def test_vizier_service_client_get_mtls_endpoint_and_cert_source(client_class): + mock_client_cert_source = mock.Mock() + + # Test the case 
GOOGLE_API_USE_CLIENT_CERTIFICATE is "true". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + mock_api_endpoint = "foo" + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint + ) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source( + options + ) + assert api_endpoint == mock_api_endpoint + assert cert_source == mock_client_cert_source + + # Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is "false". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "false"}): + mock_client_cert_source = mock.Mock() + mock_api_endpoint = "foo" + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint + ) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source( + options + ) + assert api_endpoint == mock_api_endpoint + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "never". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "always". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert doesn't exist. 
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert exists. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=mock_client_cert_source, + ): + ( + api_endpoint, + cert_source, + ) = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + assert cert_source == mock_client_cert_source + + @pytest.mark.parametrize( "client_class,transport_class,transport_name", [ @@ -422,7 +504,7 @@ def test_vizier_service_client_client_options_scopes( options = client_options.ClientOptions(scopes=["1", "2"],) with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( credentials=None, credentials_file=None, @@ -436,24 +518,31 @@ def test_vizier_service_client_client_options_scopes( @pytest.mark.parametrize( - "client_class,transport_class,transport_name", + "client_class,transport_class,transport_name,grpc_helpers", [ - (VizierServiceClient, transports.VizierServiceGrpcTransport, "grpc"), + ( + VizierServiceClient, + transports.VizierServiceGrpcTransport, + "grpc", + grpc_helpers, + ), ( VizierServiceAsyncClient, transports.VizierServiceGrpcAsyncIOTransport, "grpc_asyncio", + grpc_helpers_async, ), ], ) def 
test_vizier_service_client_client_options_credentials_file( - client_class, transport_class, transport_name + client_class, transport_class, transport_name, grpc_helpers ): # Check the case credentials file is provided. options = client_options.ClientOptions(credentials_file="credentials.json") + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( credentials=None, credentials_file="credentials.json", @@ -465,6 +554,35 @@ def test_vizier_service_client_client_options_credentials_file( always_use_jwt_access=True, ) + if "grpc" in transport_name: + # test that the credentials from file are saved and used as the credentials. + with mock.patch.object( + google.auth, "load_credentials_from_file", autospec=True + ) as load_creds, mock.patch.object( + google.auth, "default", autospec=True + ) as adc, mock.patch.object( + grpc_helpers, "create_channel" + ) as create_channel: + creds = ga_credentials.AnonymousCredentials() + file_creds = ga_credentials.AnonymousCredentials() + load_creds.return_value = (file_creds, None) + adc.return_value = (creds, None) + client = client_class(client_options=options, transport=transport_name) + create_channel.assert_called_with( + "aiplatform.googleapis.com:443", + credentials=file_creds, + credentials_file=None, + quota_project_id=None, + default_scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=None, + default_host="aiplatform.googleapis.com", + ssl_credentials=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + def test_vizier_service_client_client_options_from_dict(): with mock.patch( @@ -486,9 +604,8 @@ def test_vizier_service_client_client_options_from_dict(): ) -def test_create_study( - transport: str = "grpc", 
request_type=vizier_service.CreateStudyRequest -): +@pytest.mark.parametrize("request_type", [vizier_service.CreateStudyRequest, dict,]) +def test_create_study(request_type, transport: str = "grpc"): client = VizierServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -521,10 +638,6 @@ def test_create_study( assert response.inactive_reason == "inactive_reason_value" -def test_create_study_from_dict(): - test_create_study(request_type=dict) - - def test_create_study_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -718,9 +831,8 @@ async def test_create_study_flattened_error_async(): ) -def test_get_study( - transport: str = "grpc", request_type=vizier_service.GetStudyRequest -): +@pytest.mark.parametrize("request_type", [vizier_service.GetStudyRequest, dict,]) +def test_get_study(request_type, transport: str = "grpc"): client = VizierServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -753,10 +865,6 @@ def test_get_study( assert response.inactive_reason == "inactive_reason_value" -def test_get_study_from_dict(): - test_get_study(request_type=dict) - - def test_get_study_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -936,9 +1044,8 @@ async def test_get_study_flattened_error_async(): ) -def test_list_studies( - transport: str = "grpc", request_type=vizier_service.ListStudiesRequest -): +@pytest.mark.parametrize("request_type", [vizier_service.ListStudiesRequest, dict,]) +def test_list_studies(request_type, transport: str = "grpc"): client = VizierServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -965,10 +1072,6 @@ def test_list_studies( assert response.next_page_token == "next_page_token_value" -def test_list_studies_from_dict(): - test_list_studies(request_type=dict) - - def test_list_studies_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -1144,8 +1247,10 @@ async def test_list_studies_flattened_error_async(): ) -def test_list_studies_pager(): - client = VizierServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_studies_pager(transport_name: str = "grpc"): + client = VizierServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_studies), "__call__") as call: @@ -1178,8 +1283,10 @@ def test_list_studies_pager(): assert all(isinstance(i, study.Study) for i in results) -def test_list_studies_pages(): - client = VizierServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_studies_pages(transport_name: str = "grpc"): + client = VizierServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. 
with mock.patch.object(type(client.transport.list_studies), "__call__") as call: @@ -1266,9 +1373,8 @@ async def test_list_studies_async_pages(): assert page_.raw_page.next_page_token == token -def test_delete_study( - transport: str = "grpc", request_type=vizier_service.DeleteStudyRequest -): +@pytest.mark.parametrize("request_type", [vizier_service.DeleteStudyRequest, dict,]) +def test_delete_study(request_type, transport: str = "grpc"): client = VizierServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -1292,10 +1398,6 @@ def test_delete_study( assert response is None -def test_delete_study_from_dict(): - test_delete_study(request_type=dict) - - def test_delete_study_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -1464,9 +1566,8 @@ async def test_delete_study_flattened_error_async(): ) -def test_lookup_study( - transport: str = "grpc", request_type=vizier_service.LookupStudyRequest -): +@pytest.mark.parametrize("request_type", [vizier_service.LookupStudyRequest, dict,]) +def test_lookup_study(request_type, transport: str = "grpc"): client = VizierServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -1499,10 +1600,6 @@ def test_lookup_study( assert response.inactive_reason == "inactive_reason_value" -def test_lookup_study_from_dict(): - test_lookup_study(request_type=dict) - - def test_lookup_study_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -1682,9 +1779,8 @@ async def test_lookup_study_flattened_error_async(): ) -def test_suggest_trials( - transport: str = "grpc", request_type=vizier_service.SuggestTrialsRequest -): +@pytest.mark.parametrize("request_type", [vizier_service.SuggestTrialsRequest, dict,]) +def test_suggest_trials(request_type, transport: str = "grpc"): client = VizierServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -1708,10 +1804,6 @@ def test_suggest_trials( assert isinstance(response, future.Future) -def test_suggest_trials_from_dict(): - test_suggest_trials(request_type=dict) - - def test_suggest_trials_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -1814,9 +1906,8 @@ async def test_suggest_trials_field_headers_async(): assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] -def test_create_trial( - transport: str = "grpc", request_type=vizier_service.CreateTrialRequest -): +@pytest.mark.parametrize("request_type", [vizier_service.CreateTrialRequest, dict,]) +def test_create_trial(request_type, transport: str = "grpc"): client = VizierServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -1853,10 +1944,6 @@ def test_create_trial( assert response.custom_job == "custom_job_value" -def test_create_trial_from_dict(): - test_create_trial(request_type=dict) - - def test_create_trial_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -2054,9 +2141,8 @@ async def test_create_trial_flattened_error_async(): ) -def test_get_trial( - transport: str = "grpc", request_type=vizier_service.GetTrialRequest -): +@pytest.mark.parametrize("request_type", [vizier_service.GetTrialRequest, dict,]) +def test_get_trial(request_type, transport: str = "grpc"): client = VizierServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -2093,10 +2179,6 @@ def test_get_trial( assert response.custom_job == "custom_job_value" -def test_get_trial_from_dict(): - test_get_trial(request_type=dict) - - def test_get_trial_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -2280,9 +2362,8 @@ async def test_get_trial_flattened_error_async(): ) -def test_list_trials( - transport: str = "grpc", request_type=vizier_service.ListTrialsRequest -): +@pytest.mark.parametrize("request_type", [vizier_service.ListTrialsRequest, dict,]) +def test_list_trials(request_type, transport: str = "grpc"): client = VizierServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -2309,10 +2390,6 @@ def test_list_trials( assert response.next_page_token == "next_page_token_value" -def test_list_trials_from_dict(): - test_list_trials(request_type=dict) - - def test_list_trials_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -2488,8 +2565,10 @@ async def test_list_trials_flattened_error_async(): ) -def test_list_trials_pager(): - client = VizierServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_trials_pager(transport_name: str = "grpc"): + client = VizierServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. 
with mock.patch.object(type(client.transport.list_trials), "__call__") as call: @@ -2520,8 +2599,10 @@ def test_list_trials_pager(): assert all(isinstance(i, study.Trial) for i in results) -def test_list_trials_pages(): - client = VizierServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_trials_pages(transport_name: str = "grpc"): + client = VizierServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_trials), "__call__") as call: @@ -2602,9 +2683,10 @@ async def test_list_trials_async_pages(): assert page_.raw_page.next_page_token == token -def test_add_trial_measurement( - transport: str = "grpc", request_type=vizier_service.AddTrialMeasurementRequest -): +@pytest.mark.parametrize( + "request_type", [vizier_service.AddTrialMeasurementRequest, dict,] +) +def test_add_trial_measurement(request_type, transport: str = "grpc"): client = VizierServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -2643,10 +2725,6 @@ def test_add_trial_measurement( assert response.custom_job == "custom_job_value" -def test_add_trial_measurement_from_dict(): - test_add_trial_measurement(request_type=dict) - - def test_add_trial_measurement_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -2769,9 +2847,8 @@ async def test_add_trial_measurement_field_headers_async(): assert ("x-goog-request-params", "trial_name=trial_name/value",) in kw["metadata"] -def test_complete_trial( - transport: str = "grpc", request_type=vizier_service.CompleteTrialRequest -): +@pytest.mark.parametrize("request_type", [vizier_service.CompleteTrialRequest, dict,]) +def test_complete_trial(request_type, transport: str = "grpc"): client = VizierServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -2808,10 +2885,6 @@ def test_complete_trial( assert response.custom_job == "custom_job_value" -def test_complete_trial_from_dict(): - test_complete_trial(request_type=dict) - - def test_complete_trial_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -2925,9 +2998,8 @@ async def test_complete_trial_field_headers_async(): assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] -def test_delete_trial( - transport: str = "grpc", request_type=vizier_service.DeleteTrialRequest -): +@pytest.mark.parametrize("request_type", [vizier_service.DeleteTrialRequest, dict,]) +def test_delete_trial(request_type, transport: str = "grpc"): client = VizierServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -2951,10 +3023,6 @@ def test_delete_trial( assert response is None -def test_delete_trial_from_dict(): - test_delete_trial(request_type=dict) - - def test_delete_trial_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -3123,10 +3191,10 @@ async def test_delete_trial_flattened_error_async(): ) -def test_check_trial_early_stopping_state( - transport: str = "grpc", - request_type=vizier_service.CheckTrialEarlyStoppingStateRequest, -): +@pytest.mark.parametrize( + "request_type", [vizier_service.CheckTrialEarlyStoppingStateRequest, dict,] +) +def test_check_trial_early_stopping_state(request_type, transport: str = "grpc"): client = VizierServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -3152,10 +3220,6 @@ def test_check_trial_early_stopping_state( assert isinstance(response, future.Future) -def test_check_trial_early_stopping_state_from_dict(): - test_check_trial_early_stopping_state(request_type=dict) - - def test_check_trial_early_stopping_state_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -3267,9 +3331,8 @@ async def test_check_trial_early_stopping_state_field_headers_async(): assert ("x-goog-request-params", "trial_name=trial_name/value",) in kw["metadata"] -def test_stop_trial( - transport: str = "grpc", request_type=vizier_service.StopTrialRequest -): +@pytest.mark.parametrize("request_type", [vizier_service.StopTrialRequest, dict,]) +def test_stop_trial(request_type, transport: str = "grpc"): client = VizierServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -3306,10 +3369,6 @@ def test_stop_trial( assert response.custom_job == "custom_job_value" -def test_stop_trial_from_dict(): - test_stop_trial(request_type=dict) - - def test_stop_trial_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -3423,9 +3482,10 @@ async def test_stop_trial_field_headers_async(): assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] -def test_list_optimal_trials( - transport: str = "grpc", request_type=vizier_service.ListOptimalTrialsRequest -): +@pytest.mark.parametrize( + "request_type", [vizier_service.ListOptimalTrialsRequest, dict,] +) +def test_list_optimal_trials(request_type, transport: str = "grpc"): client = VizierServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -3451,10 +3511,6 @@ def test_list_optimal_trials( assert isinstance(response, vizier_service.ListOptimalTrialsResponse) -def test_list_optimal_trials_from_dict(): - test_list_optimal_trials(request_type=dict) - - def test_list_optimal_trials_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -3662,6 +3718,23 @@ def test_credentials_transport_error(): transport=transport, ) + # It is an error to provide an api_key and a transport instance. + transport = transports.VizierServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + options = client_options.ClientOptions() + options.api_key = "api_key" + with pytest.raises(ValueError): + client = VizierServiceClient(client_options=options, transport=transport,) + + # It is an error to provide an api_key and a credential. + options = mock.Mock() + options.api_key = "api_key" + with pytest.raises(ValueError): + client = VizierServiceClient( + client_options=options, credentials=ga_credentials.AnonymousCredentials() + ) + # It is an error to provide scopes and a transport instance. 
transport = transports.VizierServiceGrpcTransport( credentials=ga_credentials.AnonymousCredentials(), @@ -4249,7 +4322,7 @@ def test_parse_common_location_path(): assert expected == actual -def test_client_withDEFAULT_CLIENT_INFO(): +def test_client_with_default_client_info(): client_info = gapic_v1.client_info.ClientInfo() with mock.patch.object( @@ -4314,3 +4387,33 @@ def test_client_ctx(): with client: pass close.assert_called() + + +@pytest.mark.parametrize( + "client_class,transport_class", + [ + (VizierServiceClient, transports.VizierServiceGrpcTransport), + (VizierServiceAsyncClient, transports.VizierServiceGrpcAsyncIOTransport), + ], +) +def test_api_key_credentials(client_class, transport_class): + with mock.patch.object( + google.auth._default, "get_api_key_credentials", create=True + ) as get_api_key_credentials: + mock_cred = mock.Mock() + get_api_key_credentials.return_value = mock_cred + options = client_options.ClientOptions() + options.api_key = "api_key" + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=mock_cred, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + ) diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_dataset_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_dataset_service.py index b3c85785078..6951f7cb3f9 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_dataset_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_dataset_service.py @@ -29,6 +29,7 @@ from google.api_core import gapic_v1 from google.api_core import grpc_helpers from google.api_core import grpc_helpers_async +from google.api_core import operation from google.api_core import operation_async # type: ignore from google.api_core import operations_v1 
from google.api_core import path_template @@ -265,20 +266,20 @@ def test_dataset_service_client_client_options( # unsupported value. with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): with pytest.raises(MutualTLSChannelError): - client = client_class() + client = client_class(transport=transport_name) # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. with mock.patch.dict( os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} ): with pytest.raises(ValueError): - client = client_class() + client = client_class(transport=transport_name) # Check the case quota_project_id is provided options = client_options.ClientOptions(quota_project_id="octopus") with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( credentials=None, credentials_file=None, @@ -337,7 +338,7 @@ def test_dataset_service_client_mtls_env_auto( ) with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) if use_client_cert_env == "false": expected_client_cert_source = None @@ -414,6 +415,87 @@ def test_dataset_service_client_mtls_env_auto( ) +@pytest.mark.parametrize( + "client_class", [DatasetServiceClient, DatasetServiceAsyncClient] +) +@mock.patch.object( + DatasetServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(DatasetServiceClient), +) +@mock.patch.object( + DatasetServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(DatasetServiceAsyncClient), +) +def test_dataset_service_client_get_mtls_endpoint_and_cert_source(client_class): + mock_client_cert_source = mock.Mock() + + # Test the case 
GOOGLE_API_USE_CLIENT_CERTIFICATE is "true". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + mock_api_endpoint = "foo" + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint + ) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source( + options + ) + assert api_endpoint == mock_api_endpoint + assert cert_source == mock_client_cert_source + + # Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is "false". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "false"}): + mock_client_cert_source = mock.Mock() + mock_api_endpoint = "foo" + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint + ) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source( + options + ) + assert api_endpoint == mock_api_endpoint + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "never". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "always". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert doesn't exist. 
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert exists. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=mock_client_cert_source, + ): + ( + api_endpoint, + cert_source, + ) = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + assert cert_source == mock_client_cert_source + + @pytest.mark.parametrize( "client_class,transport_class,transport_name", [ @@ -432,7 +514,7 @@ def test_dataset_service_client_client_options_scopes( options = client_options.ClientOptions(scopes=["1", "2"],) with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( credentials=None, credentials_file=None, @@ -446,24 +528,31 @@ def test_dataset_service_client_client_options_scopes( @pytest.mark.parametrize( - "client_class,transport_class,transport_name", + "client_class,transport_class,transport_name,grpc_helpers", [ - (DatasetServiceClient, transports.DatasetServiceGrpcTransport, "grpc"), + ( + DatasetServiceClient, + transports.DatasetServiceGrpcTransport, + "grpc", + grpc_helpers, + ), ( DatasetServiceAsyncClient, transports.DatasetServiceGrpcAsyncIOTransport, "grpc_asyncio", + grpc_helpers_async, ), ], ) def 
test_dataset_service_client_client_options_credentials_file( - client_class, transport_class, transport_name + client_class, transport_class, transport_name, grpc_helpers ): # Check the case credentials file is provided. options = client_options.ClientOptions(credentials_file="credentials.json") + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( credentials=None, credentials_file="credentials.json", @@ -475,6 +564,35 @@ def test_dataset_service_client_client_options_credentials_file( always_use_jwt_access=True, ) + if "grpc" in transport_name: + # test that the credentials from file are saved and used as the credentials. + with mock.patch.object( + google.auth, "load_credentials_from_file", autospec=True + ) as load_creds, mock.patch.object( + google.auth, "default", autospec=True + ) as adc, mock.patch.object( + grpc_helpers, "create_channel" + ) as create_channel: + creds = ga_credentials.AnonymousCredentials() + file_creds = ga_credentials.AnonymousCredentials() + load_creds.return_value = (file_creds, None) + adc.return_value = (creds, None) + client = client_class(client_options=options, transport=transport_name) + create_channel.assert_called_with( + "aiplatform.googleapis.com:443", + credentials=file_creds, + credentials_file=None, + quota_project_id=None, + default_scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=None, + default_host="aiplatform.googleapis.com", + ssl_credentials=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + def test_dataset_service_client_client_options_from_dict(): with mock.patch( @@ -496,9 +614,8 @@ def test_dataset_service_client_client_options_from_dict(): ) -def test_create_dataset( - transport: str = "grpc", 
request_type=dataset_service.CreateDatasetRequest -): +@pytest.mark.parametrize("request_type", [dataset_service.CreateDatasetRequest, dict,]) +def test_create_dataset(request_type, transport: str = "grpc"): client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -522,10 +639,6 @@ def test_create_dataset( assert isinstance(response, future.Future) -def test_create_dataset_from_dict(): - test_create_dataset(request_type=dict) - - def test_create_dataset_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -714,9 +827,8 @@ async def test_create_dataset_flattened_error_async(): ) -def test_get_dataset( - transport: str = "grpc", request_type=dataset_service.GetDatasetRequest -): +@pytest.mark.parametrize("request_type", [dataset_service.GetDatasetRequest, dict,]) +def test_get_dataset(request_type, transport: str = "grpc"): client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -751,10 +863,6 @@ def test_get_dataset( assert response.etag == "etag_value" -def test_get_dataset_from_dict(): - test_get_dataset(request_type=dict) - - def test_get_dataset_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -936,9 +1044,8 @@ async def test_get_dataset_flattened_error_async(): ) -def test_update_dataset( - transport: str = "grpc", request_type=dataset_service.UpdateDatasetRequest -): +@pytest.mark.parametrize("request_type", [dataset_service.UpdateDatasetRequest, dict,]) +def test_update_dataset(request_type, transport: str = "grpc"): client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -973,10 +1080,6 @@ def test_update_dataset( assert response.etag == "etag_value" -def test_update_dataset_from_dict(): - test_update_dataset(request_type=dict) - - def test_update_dataset_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -1178,9 +1281,8 @@ async def test_update_dataset_flattened_error_async(): ) -def test_list_datasets( - transport: str = "grpc", request_type=dataset_service.ListDatasetsRequest -): +@pytest.mark.parametrize("request_type", [dataset_service.ListDatasetsRequest, dict,]) +def test_list_datasets(request_type, transport: str = "grpc"): client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -1207,10 +1309,6 @@ def test_list_datasets( assert response.next_page_token == "next_page_token_value" -def test_list_datasets_from_dict(): - test_list_datasets(request_type=dict) - - def test_list_datasets_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -1388,8 +1486,10 @@ async def test_list_datasets_flattened_error_async(): ) -def test_list_datasets_pager(): - client = DatasetServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_datasets_pager(transport_name: str = "grpc"): + client = DatasetServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_datasets), "__call__") as call: @@ -1422,8 +1522,10 @@ def test_list_datasets_pager(): assert all(isinstance(i, dataset.Dataset) for i in results) -def test_list_datasets_pages(): - client = DatasetServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_datasets_pages(transport_name: str = "grpc"): + client = DatasetServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_datasets), "__call__") as call: @@ -1510,9 +1612,8 @@ async def test_list_datasets_async_pages(): assert page_.raw_page.next_page_token == token -def test_delete_dataset( - transport: str = "grpc", request_type=dataset_service.DeleteDatasetRequest -): +@pytest.mark.parametrize("request_type", [dataset_service.DeleteDatasetRequest, dict,]) +def test_delete_dataset(request_type, transport: str = "grpc"): client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -1536,10 +1637,6 @@ def test_delete_dataset( assert isinstance(response, future.Future) -def test_delete_dataset_from_dict(): - test_delete_dataset(request_type=dict) - - def test_delete_dataset_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -1714,9 +1811,8 @@ async def test_delete_dataset_flattened_error_async(): ) -def test_import_data( - transport: str = "grpc", request_type=dataset_service.ImportDataRequest -): +@pytest.mark.parametrize("request_type", [dataset_service.ImportDataRequest, dict,]) +def test_import_data(request_type, transport: str = "grpc"): client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -1740,10 +1836,6 @@ def test_import_data( assert isinstance(response, future.Future) -def test_import_data_from_dict(): - test_import_data(request_type=dict) - - def test_import_data_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -1946,9 +2038,8 @@ async def test_import_data_flattened_error_async(): ) -def test_export_data( - transport: str = "grpc", request_type=dataset_service.ExportDataRequest -): +@pytest.mark.parametrize("request_type", [dataset_service.ExportDataRequest, dict,]) +def test_export_data(request_type, transport: str = "grpc"): client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -1972,10 +2063,6 @@ def test_export_data( assert isinstance(response, future.Future) -def test_export_data_from_dict(): - test_export_data(request_type=dict) - - def test_export_data_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -2190,9 +2277,8 @@ async def test_export_data_flattened_error_async(): ) -def test_list_data_items( - transport: str = "grpc", request_type=dataset_service.ListDataItemsRequest -): +@pytest.mark.parametrize("request_type", [dataset_service.ListDataItemsRequest, dict,]) +def test_list_data_items(request_type, transport: str = "grpc"): client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -2219,10 +2305,6 @@ def test_list_data_items( assert response.next_page_token == "next_page_token_value" -def test_list_data_items_from_dict(): - test_list_data_items(request_type=dict) - - def test_list_data_items_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -2400,8 +2482,10 @@ async def test_list_data_items_flattened_error_async(): ) -def test_list_data_items_pager(): - client = DatasetServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_data_items_pager(transport_name: str = "grpc"): + client = DatasetServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_data_items), "__call__") as call: @@ -2440,8 +2524,10 @@ def test_list_data_items_pager(): assert all(isinstance(i, data_item.DataItem) for i in results) -def test_list_data_items_pages(): - client = DatasetServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_data_items_pages(transport_name: str = "grpc"): + client = DatasetServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. 
with mock.patch.object(type(client.transport.list_data_items), "__call__") as call: @@ -2546,9 +2632,10 @@ async def test_list_data_items_async_pages(): assert page_.raw_page.next_page_token == token -def test_get_annotation_spec( - transport: str = "grpc", request_type=dataset_service.GetAnnotationSpecRequest -): +@pytest.mark.parametrize( + "request_type", [dataset_service.GetAnnotationSpecRequest, dict,] +) +def test_get_annotation_spec(request_type, transport: str = "grpc"): client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -2579,10 +2666,6 @@ def test_get_annotation_spec( assert response.etag == "etag_value" -def test_get_annotation_spec_from_dict(): - test_get_annotation_spec(request_type=dict) - - def test_get_annotation_spec_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -2775,9 +2858,10 @@ async def test_get_annotation_spec_flattened_error_async(): ) -def test_list_annotations( - transport: str = "grpc", request_type=dataset_service.ListAnnotationsRequest -): +@pytest.mark.parametrize( + "request_type", [dataset_service.ListAnnotationsRequest, dict,] +) +def test_list_annotations(request_type, transport: str = "grpc"): client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -2804,10 +2888,6 @@ def test_list_annotations( assert response.next_page_token == "next_page_token_value" -def test_list_annotations_from_dict(): - test_list_annotations(request_type=dict) - - def test_list_annotations_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -2985,8 +3065,10 @@ async def test_list_annotations_flattened_error_async(): ) -def test_list_annotations_pager(): - client = DatasetServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_annotations_pager(transport_name: str = "grpc"): + client = DatasetServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_annotations), "__call__") as call: @@ -3025,8 +3107,10 @@ def test_list_annotations_pager(): assert all(isinstance(i, annotation.Annotation) for i in results) -def test_list_annotations_pages(): - client = DatasetServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_annotations_pages(transport_name: str = "grpc"): + client = DatasetServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_annotations), "__call__") as call: @@ -3151,6 +3235,23 @@ def test_credentials_transport_error(): transport=transport, ) + # It is an error to provide an api_key and a transport instance. + transport = transports.DatasetServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + options = client_options.ClientOptions() + options.api_key = "api_key" + with pytest.raises(ValueError): + client = DatasetServiceClient(client_options=options, transport=transport,) + + # It is an error to provide an api_key and a credential. + options = mock.Mock() + options.api_key = "api_key" + with pytest.raises(ValueError): + client = DatasetServiceClient( + client_options=options, credentials=ga_credentials.AnonymousCredentials() + ) + # It is an error to provide scopes and a transport instance. 
transport = transports.DatasetServiceGrpcTransport( credentials=ga_credentials.AnonymousCredentials(), @@ -3776,7 +3877,7 @@ def test_parse_common_location_path(): assert expected == actual -def test_client_withDEFAULT_CLIENT_INFO(): +def test_client_with_default_client_info(): client_info = gapic_v1.client_info.ClientInfo() with mock.patch.object( @@ -3841,3 +3942,33 @@ def test_client_ctx(): with client: pass close.assert_called() + + +@pytest.mark.parametrize( + "client_class,transport_class", + [ + (DatasetServiceClient, transports.DatasetServiceGrpcTransport), + (DatasetServiceAsyncClient, transports.DatasetServiceGrpcAsyncIOTransport), + ], +) +def test_api_key_credentials(client_class, transport_class): + with mock.patch.object( + google.auth._default, "get_api_key_credentials", create=True + ) as get_api_key_credentials: + mock_cred = mock.Mock() + get_api_key_credentials.return_value = mock_cred + options = client_options.ClientOptions() + options.api_key = "api_key" + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=mock_cred, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + ) diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_endpoint_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_endpoint_service.py index 7c59752cf7f..504fd06f2f9 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_endpoint_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_endpoint_service.py @@ -29,6 +29,7 @@ from google.api_core import gapic_v1 from google.api_core import grpc_helpers from google.api_core import grpc_helpers_async +from google.api_core import operation from google.api_core import operation_async # type: ignore from google.api_core import 
operations_v1 from google.api_core import path_template @@ -266,20 +267,20 @@ def test_endpoint_service_client_client_options( # unsupported value. with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): with pytest.raises(MutualTLSChannelError): - client = client_class() + client = client_class(transport=transport_name) # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. with mock.patch.dict( os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} ): with pytest.raises(ValueError): - client = client_class() + client = client_class(transport=transport_name) # Check the case quota_project_id is provided options = client_options.ClientOptions(quota_project_id="octopus") with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( credentials=None, credentials_file=None, @@ -348,7 +349,7 @@ def test_endpoint_service_client_mtls_env_auto( ) with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) if use_client_cert_env == "false": expected_client_cert_source = None @@ -425,6 +426,87 @@ def test_endpoint_service_client_mtls_env_auto( ) +@pytest.mark.parametrize( + "client_class", [EndpointServiceClient, EndpointServiceAsyncClient] +) +@mock.patch.object( + EndpointServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(EndpointServiceClient), +) +@mock.patch.object( + EndpointServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(EndpointServiceAsyncClient), +) +def test_endpoint_service_client_get_mtls_endpoint_and_cert_source(client_class): + mock_client_cert_source = mock.Mock() + + # Test the case 
GOOGLE_API_USE_CLIENT_CERTIFICATE is "true". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + mock_api_endpoint = "foo" + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint + ) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source( + options + ) + assert api_endpoint == mock_api_endpoint + assert cert_source == mock_client_cert_source + + # Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is "false". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "false"}): + mock_client_cert_source = mock.Mock() + mock_api_endpoint = "foo" + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint + ) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source( + options + ) + assert api_endpoint == mock_api_endpoint + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "never". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "always". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert doesn't exist. 
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert exists. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=mock_client_cert_source, + ): + ( + api_endpoint, + cert_source, + ) = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + assert cert_source == mock_client_cert_source + + @pytest.mark.parametrize( "client_class,transport_class,transport_name", [ @@ -443,7 +525,7 @@ def test_endpoint_service_client_client_options_scopes( options = client_options.ClientOptions(scopes=["1", "2"],) with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( credentials=None, credentials_file=None, @@ -457,24 +539,31 @@ def test_endpoint_service_client_client_options_scopes( @pytest.mark.parametrize( - "client_class,transport_class,transport_name", + "client_class,transport_class,transport_name,grpc_helpers", [ - (EndpointServiceClient, transports.EndpointServiceGrpcTransport, "grpc"), + ( + EndpointServiceClient, + transports.EndpointServiceGrpcTransport, + "grpc", + grpc_helpers, + ), ( EndpointServiceAsyncClient, transports.EndpointServiceGrpcAsyncIOTransport, "grpc_asyncio", + grpc_helpers_async, ), ], ) 
def test_endpoint_service_client_client_options_credentials_file( - client_class, transport_class, transport_name + client_class, transport_class, transport_name, grpc_helpers ): # Check the case credentials file is provided. options = client_options.ClientOptions(credentials_file="credentials.json") + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( credentials=None, credentials_file="credentials.json", @@ -486,6 +575,35 @@ def test_endpoint_service_client_client_options_credentials_file( always_use_jwt_access=True, ) + if "grpc" in transport_name: + # test that the credentials from file are saved and used as the credentials. + with mock.patch.object( + google.auth, "load_credentials_from_file", autospec=True + ) as load_creds, mock.patch.object( + google.auth, "default", autospec=True + ) as adc, mock.patch.object( + grpc_helpers, "create_channel" + ) as create_channel: + creds = ga_credentials.AnonymousCredentials() + file_creds = ga_credentials.AnonymousCredentials() + load_creds.return_value = (file_creds, None) + adc.return_value = (creds, None) + client = client_class(client_options=options, transport=transport_name) + create_channel.assert_called_with( + "aiplatform.googleapis.com:443", + credentials=file_creds, + credentials_file=None, + quota_project_id=None, + default_scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=None, + default_host="aiplatform.googleapis.com", + ssl_credentials=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + def test_endpoint_service_client_client_options_from_dict(): with mock.patch( @@ -507,9 +625,10 @@ def test_endpoint_service_client_client_options_from_dict(): ) -def test_create_endpoint( - transport: str = "grpc", 
request_type=endpoint_service.CreateEndpointRequest -): +@pytest.mark.parametrize( + "request_type", [endpoint_service.CreateEndpointRequest, dict,] +) +def test_create_endpoint(request_type, transport: str = "grpc"): client = EndpointServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -533,10 +652,6 @@ def test_create_endpoint( assert isinstance(response, future.Future) -def test_create_endpoint_from_dict(): - test_create_endpoint(request_type=dict) - - def test_create_endpoint_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -737,9 +852,8 @@ async def test_create_endpoint_flattened_error_async(): ) -def test_get_endpoint( - transport: str = "grpc", request_type=endpoint_service.GetEndpointRequest -): +@pytest.mark.parametrize("request_type", [endpoint_service.GetEndpointRequest, dict,]) +def test_get_endpoint(request_type, transport: str = "grpc"): client = EndpointServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -781,10 +895,6 @@ def test_get_endpoint( ) -def test_get_endpoint_from_dict(): - test_get_endpoint(request_type=dict) - - def test_get_endpoint_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -973,9 +1083,8 @@ async def test_get_endpoint_flattened_error_async(): ) -def test_list_endpoints( - transport: str = "grpc", request_type=endpoint_service.ListEndpointsRequest -): +@pytest.mark.parametrize("request_type", [endpoint_service.ListEndpointsRequest, dict,]) +def test_list_endpoints(request_type, transport: str = "grpc"): client = EndpointServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -1002,10 +1111,6 @@ def test_list_endpoints( assert response.next_page_token == "next_page_token_value" -def test_list_endpoints_from_dict(): - test_list_endpoints(request_type=dict) - - def test_list_endpoints_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -1183,8 +1288,10 @@ async def test_list_endpoints_flattened_error_async(): ) -def test_list_endpoints_pager(): - client = EndpointServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_endpoints_pager(transport_name: str = "grpc"): + client = EndpointServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_endpoints), "__call__") as call: @@ -1223,8 +1330,10 @@ def test_list_endpoints_pager(): assert all(isinstance(i, endpoint.Endpoint) for i in results) -def test_list_endpoints_pages(): - client = EndpointServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_endpoints_pages(transport_name: str = "grpc"): + client = EndpointServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. 
with mock.patch.object(type(client.transport.list_endpoints), "__call__") as call: @@ -1333,9 +1442,10 @@ async def test_list_endpoints_async_pages(): assert page_.raw_page.next_page_token == token -def test_update_endpoint( - transport: str = "grpc", request_type=endpoint_service.UpdateEndpointRequest -): +@pytest.mark.parametrize( + "request_type", [endpoint_service.UpdateEndpointRequest, dict,] +) +def test_update_endpoint(request_type, transport: str = "grpc"): client = EndpointServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -1377,10 +1487,6 @@ def test_update_endpoint( ) -def test_update_endpoint_from_dict(): - test_update_endpoint(request_type=dict) - - def test_update_endpoint_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -1593,9 +1699,10 @@ async def test_update_endpoint_flattened_error_async(): ) -def test_delete_endpoint( - transport: str = "grpc", request_type=endpoint_service.DeleteEndpointRequest -): +@pytest.mark.parametrize( + "request_type", [endpoint_service.DeleteEndpointRequest, dict,] +) +def test_delete_endpoint(request_type, transport: str = "grpc"): client = EndpointServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -1619,10 +1726,6 @@ def test_delete_endpoint( assert isinstance(response, future.Future) -def test_delete_endpoint_from_dict(): - test_delete_endpoint(request_type=dict) - - def test_delete_endpoint_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -1797,9 +1900,8 @@ async def test_delete_endpoint_flattened_error_async(): ) -def test_deploy_model( - transport: str = "grpc", request_type=endpoint_service.DeployModelRequest -): +@pytest.mark.parametrize("request_type", [endpoint_service.DeployModelRequest, dict,]) +def test_deploy_model(request_type, transport: str = "grpc"): client = EndpointServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -1823,10 +1925,6 @@ def test_deploy_model( assert isinstance(response, future.Future) -def test_deploy_model_from_dict(): - test_deploy_model(request_type=dict) - - def test_deploy_model_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -2063,9 +2161,8 @@ async def test_deploy_model_flattened_error_async(): ) -def test_undeploy_model( - transport: str = "grpc", request_type=endpoint_service.UndeployModelRequest -): +@pytest.mark.parametrize("request_type", [endpoint_service.UndeployModelRequest, dict,]) +def test_undeploy_model(request_type, transport: str = "grpc"): client = EndpointServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -2089,10 +2186,6 @@ def test_undeploy_model( assert isinstance(response, future.Future) -def test_undeploy_model_from_dict(): - test_undeploy_model(request_type=dict) - - def test_undeploy_model_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -2313,6 +2406,23 @@ def test_credentials_transport_error(): transport=transport, ) + # It is an error to provide an api_key and a transport instance. 
+ transport = transports.EndpointServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + options = client_options.ClientOptions() + options.api_key = "api_key" + with pytest.raises(ValueError): + client = EndpointServiceClient(client_options=options, transport=transport,) + + # It is an error to provide an api_key and a credential. + options = mock.Mock() + options.api_key = "api_key" + with pytest.raises(ValueError): + client = EndpointServiceClient( + client_options=options, credentials=ga_credentials.AnonymousCredentials() + ) + # It is an error to provide scopes and a transport instance. transport = transports.EndpointServiceGrpcTransport( credentials=ga_credentials.AnonymousCredentials(), @@ -2918,7 +3028,7 @@ def test_parse_common_location_path(): assert expected == actual -def test_client_withDEFAULT_CLIENT_INFO(): +def test_client_with_default_client_info(): client_info = gapic_v1.client_info.ClientInfo() with mock.patch.object( @@ -2983,3 +3093,33 @@ def test_client_ctx(): with client: pass close.assert_called() + + +@pytest.mark.parametrize( + "client_class,transport_class", + [ + (EndpointServiceClient, transports.EndpointServiceGrpcTransport), + (EndpointServiceAsyncClient, transports.EndpointServiceGrpcAsyncIOTransport), + ], +) +def test_api_key_credentials(client_class, transport_class): + with mock.patch.object( + google.auth._default, "get_api_key_credentials", create=True + ) as get_api_key_credentials: + mock_cred = mock.Mock() + get_api_key_credentials.return_value = mock_cred + options = client_options.ClientOptions() + options.api_key = "api_key" + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=mock_cred, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + 
client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + ) diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_featurestore_online_serving_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_featurestore_online_serving_service.py index 364971192df..1b74c435637 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_featurestore_online_serving_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_featurestore_online_serving_service.py @@ -284,20 +284,20 @@ def test_featurestore_online_serving_service_client_client_options( # unsupported value. with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): with pytest.raises(MutualTLSChannelError): - client = client_class() + client = client_class(transport=transport_name) # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. with mock.patch.dict( os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} ): with pytest.raises(ValueError): - client = client_class() + client = client_class(transport=transport_name) # Check the case quota_project_id is provided options = client_options.ClientOptions(quota_project_id="octopus") with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( credentials=None, credentials_file=None, @@ -366,7 +366,7 @@ def test_featurestore_online_serving_service_client_mtls_env_auto( ) with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) if use_client_cert_env == "false": expected_client_cert_source = None @@ -443,6 +443,93 @@ def test_featurestore_online_serving_service_client_mtls_env_auto( ) +@pytest.mark.parametrize( + 
"client_class", + [ + FeaturestoreOnlineServingServiceClient, + FeaturestoreOnlineServingServiceAsyncClient, + ], +) +@mock.patch.object( + FeaturestoreOnlineServingServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(FeaturestoreOnlineServingServiceClient), +) +@mock.patch.object( + FeaturestoreOnlineServingServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(FeaturestoreOnlineServingServiceAsyncClient), +) +def test_featurestore_online_serving_service_client_get_mtls_endpoint_and_cert_source( + client_class, +): + mock_client_cert_source = mock.Mock() + + # Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is "true". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + mock_api_endpoint = "foo" + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint + ) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source( + options + ) + assert api_endpoint == mock_api_endpoint + assert cert_source == mock_client_cert_source + + # Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is "false". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "false"}): + mock_client_cert_source = mock.Mock() + mock_api_endpoint = "foo" + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint + ) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source( + options + ) + assert api_endpoint == mock_api_endpoint + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "never". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "always". 
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert doesn't exist. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert exists. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=mock_client_cert_source, + ): + ( + api_endpoint, + cert_source, + ) = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + assert cert_source == mock_client_cert_source + + @pytest.mark.parametrize( "client_class,transport_class,transport_name", [ @@ -465,7 +552,7 @@ def test_featurestore_online_serving_service_client_client_options_scopes( options = client_options.ClientOptions(scopes=["1", "2"],) with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( credentials=None, credentials_file=None, @@ -479,28 +566,31 @@ def test_featurestore_online_serving_service_client_client_options_scopes( @pytest.mark.parametrize( - 
"client_class,transport_class,transport_name", + "client_class,transport_class,transport_name,grpc_helpers", [ ( FeaturestoreOnlineServingServiceClient, transports.FeaturestoreOnlineServingServiceGrpcTransport, "grpc", + grpc_helpers, ), ( FeaturestoreOnlineServingServiceAsyncClient, transports.FeaturestoreOnlineServingServiceGrpcAsyncIOTransport, "grpc_asyncio", + grpc_helpers_async, ), ], ) def test_featurestore_online_serving_service_client_client_options_credentials_file( - client_class, transport_class, transport_name + client_class, transport_class, transport_name, grpc_helpers ): # Check the case credentials file is provided. options = client_options.ClientOptions(credentials_file="credentials.json") + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( credentials=None, credentials_file="credentials.json", @@ -512,6 +602,35 @@ def test_featurestore_online_serving_service_client_client_options_credentials_f always_use_jwt_access=True, ) + if "grpc" in transport_name: + # test that the credentials from file are saved and used as the credentials. 
+ with mock.patch.object( + google.auth, "load_credentials_from_file", autospec=True + ) as load_creds, mock.patch.object( + google.auth, "default", autospec=True + ) as adc, mock.patch.object( + grpc_helpers, "create_channel" + ) as create_channel: + creds = ga_credentials.AnonymousCredentials() + file_creds = ga_credentials.AnonymousCredentials() + load_creds.return_value = (file_creds, None) + adc.return_value = (creds, None) + client = client_class(client_options=options, transport=transport_name) + create_channel.assert_called_with( + "aiplatform.googleapis.com:443", + credentials=file_creds, + credentials_file=None, + quota_project_id=None, + default_scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=None, + default_host="aiplatform.googleapis.com", + ssl_credentials=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + def test_featurestore_online_serving_service_client_client_options_from_dict(): with mock.patch( @@ -533,10 +652,10 @@ def test_featurestore_online_serving_service_client_client_options_from_dict(): ) -def test_read_feature_values( - transport: str = "grpc", - request_type=featurestore_online_service.ReadFeatureValuesRequest, -): +@pytest.mark.parametrize( + "request_type", [featurestore_online_service.ReadFeatureValuesRequest, dict,] +) +def test_read_feature_values(request_type, transport: str = "grpc"): client = FeaturestoreOnlineServingServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -562,10 +681,6 @@ def test_read_feature_values( assert isinstance(response, featurestore_online_service.ReadFeatureValuesResponse) -def test_read_feature_values_from_dict(): - test_read_feature_values(request_type=dict) - - def test_read_feature_values_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -761,10 +876,11 @@ async def test_read_feature_values_flattened_error_async(): ) -def test_streaming_read_feature_values( - transport: str = "grpc", - request_type=featurestore_online_service.StreamingReadFeatureValuesRequest, -): +@pytest.mark.parametrize( + "request_type", + [featurestore_online_service.StreamingReadFeatureValuesRequest, dict,], +) +def test_streaming_read_feature_values(request_type, transport: str = "grpc"): client = FeaturestoreOnlineServingServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -797,10 +913,6 @@ def test_streaming_read_feature_values( ) -def test_streaming_read_feature_values_from_dict(): - test_streaming_read_feature_values(request_type=dict) - - def test_streaming_read_feature_values_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -1029,6 +1141,25 @@ def test_credentials_transport_error(): transport=transport, ) + # It is an error to provide an api_key and a transport instance. + transport = transports.FeaturestoreOnlineServingServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + options = client_options.ClientOptions() + options.api_key = "api_key" + with pytest.raises(ValueError): + client = FeaturestoreOnlineServingServiceClient( + client_options=options, transport=transport, + ) + + # It is an error to provide an api_key and a credential. + options = mock.Mock() + options.api_key = "api_key" + with pytest.raises(ValueError): + client = FeaturestoreOnlineServingServiceClient( + client_options=options, credentials=ga_credentials.AnonymousCredentials() + ) + # It is an error to provide scopes and a transport instance. 
transport = transports.FeaturestoreOnlineServingServiceGrpcTransport( credentials=ga_credentials.AnonymousCredentials(), @@ -1554,7 +1685,7 @@ def test_parse_common_location_path(): assert expected == actual -def test_client_withDEFAULT_CLIENT_INFO(): +def test_client_with_default_client_info(): client_info = gapic_v1.client_info.ClientInfo() with mock.patch.object( @@ -1619,3 +1750,39 @@ def test_client_ctx(): with client: pass close.assert_called() + + +@pytest.mark.parametrize( + "client_class,transport_class", + [ + ( + FeaturestoreOnlineServingServiceClient, + transports.FeaturestoreOnlineServingServiceGrpcTransport, + ), + ( + FeaturestoreOnlineServingServiceAsyncClient, + transports.FeaturestoreOnlineServingServiceGrpcAsyncIOTransport, + ), + ], +) +def test_api_key_credentials(client_class, transport_class): + with mock.patch.object( + google.auth._default, "get_api_key_credentials", create=True + ) as get_api_key_credentials: + mock_cred = mock.Mock() + get_api_key_credentials.return_value = mock_cred + options = client_options.ClientOptions() + options.api_key = "api_key" + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=mock_cred, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + ) diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_featurestore_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_featurestore_service.py index 392021dcb0a..f1c3dbc2b77 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_featurestore_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_featurestore_service.py @@ -29,6 +29,7 @@ from google.api_core import gapic_v1 from google.api_core import grpc_helpers from google.api_core import grpc_helpers_async +from 
google.api_core import operation from google.api_core import operation_async # type: ignore from google.api_core import operations_v1 from google.api_core import path_template @@ -274,20 +275,20 @@ def test_featurestore_service_client_client_options( # unsupported value. with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): with pytest.raises(MutualTLSChannelError): - client = client_class() + client = client_class(transport=transport_name) # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. with mock.patch.dict( os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} ): with pytest.raises(ValueError): - client = client_class() + client = client_class(transport=transport_name) # Check the case quota_project_id is provided options = client_options.ClientOptions(quota_project_id="octopus") with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( credentials=None, credentials_file=None, @@ -356,7 +357,7 @@ def test_featurestore_service_client_mtls_env_auto( ) with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) if use_client_cert_env == "false": expected_client_cert_source = None @@ -433,6 +434,87 @@ def test_featurestore_service_client_mtls_env_auto( ) +@pytest.mark.parametrize( + "client_class", [FeaturestoreServiceClient, FeaturestoreServiceAsyncClient] +) +@mock.patch.object( + FeaturestoreServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(FeaturestoreServiceClient), +) +@mock.patch.object( + FeaturestoreServiceAsyncClient, + "DEFAULT_ENDPOINT", + 
modify_default_endpoint(FeaturestoreServiceAsyncClient), +) +def test_featurestore_service_client_get_mtls_endpoint_and_cert_source(client_class): + mock_client_cert_source = mock.Mock() + + # Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is "true". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + mock_api_endpoint = "foo" + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint + ) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source( + options + ) + assert api_endpoint == mock_api_endpoint + assert cert_source == mock_client_cert_source + + # Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is "false". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "false"}): + mock_client_cert_source = mock.Mock() + mock_api_endpoint = "foo" + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint + ) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source( + options + ) + assert api_endpoint == mock_api_endpoint + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "never". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "always". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert doesn't exist. 
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert exists. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=mock_client_cert_source, + ): + ( + api_endpoint, + cert_source, + ) = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + assert cert_source == mock_client_cert_source + + @pytest.mark.parametrize( "client_class,transport_class,transport_name", [ @@ -455,7 +537,7 @@ def test_featurestore_service_client_client_options_scopes( options = client_options.ClientOptions(scopes=["1", "2"],) with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( credentials=None, credentials_file=None, @@ -469,28 +551,31 @@ def test_featurestore_service_client_client_options_scopes( @pytest.mark.parametrize( - "client_class,transport_class,transport_name", + "client_class,transport_class,transport_name,grpc_helpers", [ ( FeaturestoreServiceClient, transports.FeaturestoreServiceGrpcTransport, "grpc", + grpc_helpers, ), ( FeaturestoreServiceAsyncClient, transports.FeaturestoreServiceGrpcAsyncIOTransport, "grpc_asyncio", + grpc_helpers_async, ), ], ) def 
test_featurestore_service_client_client_options_credentials_file( - client_class, transport_class, transport_name + client_class, transport_class, transport_name, grpc_helpers ): # Check the case credentials file is provided. options = client_options.ClientOptions(credentials_file="credentials.json") + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( credentials=None, credentials_file="credentials.json", @@ -502,6 +587,35 @@ def test_featurestore_service_client_client_options_credentials_file( always_use_jwt_access=True, ) + if "grpc" in transport_name: + # test that the credentials from file are saved and used as the credentials. + with mock.patch.object( + google.auth, "load_credentials_from_file", autospec=True + ) as load_creds, mock.patch.object( + google.auth, "default", autospec=True + ) as adc, mock.patch.object( + grpc_helpers, "create_channel" + ) as create_channel: + creds = ga_credentials.AnonymousCredentials() + file_creds = ga_credentials.AnonymousCredentials() + load_creds.return_value = (file_creds, None) + adc.return_value = (creds, None) + client = client_class(client_options=options, transport=transport_name) + create_channel.assert_called_with( + "aiplatform.googleapis.com:443", + credentials=file_creds, + credentials_file=None, + quota_project_id=None, + default_scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=None, + default_host="aiplatform.googleapis.com", + ssl_credentials=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + def test_featurestore_service_client_client_options_from_dict(): with mock.patch( @@ -523,9 +637,10 @@ def test_featurestore_service_client_client_options_from_dict(): ) -def test_create_featurestore( - transport: str = 
"grpc", request_type=featurestore_service.CreateFeaturestoreRequest -): +@pytest.mark.parametrize( + "request_type", [featurestore_service.CreateFeaturestoreRequest, dict,] +) +def test_create_featurestore(request_type, transport: str = "grpc"): client = FeaturestoreServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -551,10 +666,6 @@ def test_create_featurestore( assert isinstance(response, future.Future) -def test_create_featurestore_from_dict(): - test_create_featurestore(request_type=dict) - - def test_create_featurestore_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -774,9 +885,10 @@ async def test_create_featurestore_flattened_error_async(): ) -def test_get_featurestore( - transport: str = "grpc", request_type=featurestore_service.GetFeaturestoreRequest -): +@pytest.mark.parametrize( + "request_type", [featurestore_service.GetFeaturestoreRequest, dict,] +) +def test_get_featurestore(request_type, transport: str = "grpc"): client = FeaturestoreServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -807,10 +919,6 @@ def test_get_featurestore( assert response.state == featurestore.Featurestore.State.STABLE -def test_get_featurestore_from_dict(): - test_get_featurestore(request_type=dict) - - def test_get_featurestore_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -999,9 +1107,10 @@ async def test_get_featurestore_flattened_error_async(): ) -def test_list_featurestores( - transport: str = "grpc", request_type=featurestore_service.ListFeaturestoresRequest -): +@pytest.mark.parametrize( + "request_type", [featurestore_service.ListFeaturestoresRequest, dict,] +) +def test_list_featurestores(request_type, transport: str = "grpc"): client = FeaturestoreServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -1030,10 +1139,6 @@ def test_list_featurestores( assert response.next_page_token == "next_page_token_value" -def test_list_featurestores_from_dict(): - test_list_featurestores(request_type=dict) - - def test_list_featurestores_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -1230,8 +1335,10 @@ async def test_list_featurestores_flattened_error_async(): ) -def test_list_featurestores_pager(): - client = FeaturestoreServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_featurestores_pager(transport_name: str = "grpc"): + client = FeaturestoreServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( @@ -1275,8 +1382,10 @@ def test_list_featurestores_pager(): assert all(isinstance(i, featurestore.Featurestore) for i in results) -def test_list_featurestores_pages(): - client = FeaturestoreServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_featurestores_pages(transport_name: str = "grpc"): + client = FeaturestoreServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. 
with mock.patch.object( @@ -1400,9 +1509,10 @@ async def test_list_featurestores_async_pages(): assert page_.raw_page.next_page_token == token -def test_update_featurestore( - transport: str = "grpc", request_type=featurestore_service.UpdateFeaturestoreRequest -): +@pytest.mark.parametrize( + "request_type", [featurestore_service.UpdateFeaturestoreRequest, dict,] +) +def test_update_featurestore(request_type, transport: str = "grpc"): client = FeaturestoreServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -1428,10 +1538,6 @@ def test_update_featurestore( assert isinstance(response, future.Future) -def test_update_featurestore_from_dict(): - test_update_featurestore(request_type=dict) - - def test_update_featurestore_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -1647,9 +1753,10 @@ async def test_update_featurestore_flattened_error_async(): ) -def test_delete_featurestore( - transport: str = "grpc", request_type=featurestore_service.DeleteFeaturestoreRequest -): +@pytest.mark.parametrize( + "request_type", [featurestore_service.DeleteFeaturestoreRequest, dict,] +) +def test_delete_featurestore(request_type, transport: str = "grpc"): client = FeaturestoreServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -1675,10 +1782,6 @@ def test_delete_featurestore( assert isinstance(response, future.Future) -def test_delete_featurestore_from_dict(): - test_delete_featurestore(request_type=dict) - - def test_delete_featurestore_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -1884,9 +1987,10 @@ async def test_delete_featurestore_flattened_error_async(): ) -def test_create_entity_type( - transport: str = "grpc", request_type=featurestore_service.CreateEntityTypeRequest -): +@pytest.mark.parametrize( + "request_type", [featurestore_service.CreateEntityTypeRequest, dict,] +) +def test_create_entity_type(request_type, transport: str = "grpc"): client = FeaturestoreServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -1912,10 +2016,6 @@ def test_create_entity_type( assert isinstance(response, future.Future) -def test_create_entity_type_from_dict(): - test_create_entity_type(request_type=dict) - - def test_create_entity_type_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -2135,9 +2235,10 @@ async def test_create_entity_type_flattened_error_async(): ) -def test_get_entity_type( - transport: str = "grpc", request_type=featurestore_service.GetEntityTypeRequest -): +@pytest.mark.parametrize( + "request_type", [featurestore_service.GetEntityTypeRequest, dict,] +) +def test_get_entity_type(request_type, transport: str = "grpc"): client = FeaturestoreServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -2166,10 +2267,6 @@ def test_get_entity_type( assert response.etag == "etag_value" -def test_get_entity_type_from_dict(): - test_get_entity_type(request_type=dict) - - def test_get_entity_type_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -2356,9 +2453,10 @@ async def test_get_entity_type_flattened_error_async(): ) -def test_list_entity_types( - transport: str = "grpc", request_type=featurestore_service.ListEntityTypesRequest -): +@pytest.mark.parametrize( + "request_type", [featurestore_service.ListEntityTypesRequest, dict,] +) +def test_list_entity_types(request_type, transport: str = "grpc"): client = FeaturestoreServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -2387,10 +2485,6 @@ def test_list_entity_types( assert response.next_page_token == "next_page_token_value" -def test_list_entity_types_from_dict(): - test_list_entity_types(request_type=dict) - - def test_list_entity_types_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -2587,8 +2681,10 @@ async def test_list_entity_types_flattened_error_async(): ) -def test_list_entity_types_pager(): - client = FeaturestoreServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_entity_types_pager(transport_name: str = "grpc"): + client = FeaturestoreServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( @@ -2629,8 +2725,10 @@ def test_list_entity_types_pager(): assert all(isinstance(i, entity_type.EntityType) for i in results) -def test_list_entity_types_pages(): - client = FeaturestoreServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_entity_types_pages(transport_name: str = "grpc"): + client = FeaturestoreServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. 
with mock.patch.object( @@ -2745,9 +2843,10 @@ async def test_list_entity_types_async_pages(): assert page_.raw_page.next_page_token == token -def test_update_entity_type( - transport: str = "grpc", request_type=featurestore_service.UpdateEntityTypeRequest -): +@pytest.mark.parametrize( + "request_type", [featurestore_service.UpdateEntityTypeRequest, dict,] +) +def test_update_entity_type(request_type, transport: str = "grpc"): client = FeaturestoreServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -2778,10 +2877,6 @@ def test_update_entity_type( assert response.etag == "etag_value" -def test_update_entity_type_from_dict(): - test_update_entity_type(request_type=dict) - - def test_update_entity_type_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -3000,9 +3095,10 @@ async def test_update_entity_type_flattened_error_async(): ) -def test_delete_entity_type( - transport: str = "grpc", request_type=featurestore_service.DeleteEntityTypeRequest -): +@pytest.mark.parametrize( + "request_type", [featurestore_service.DeleteEntityTypeRequest, dict,] +) +def test_delete_entity_type(request_type, transport: str = "grpc"): client = FeaturestoreServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -3028,10 +3124,6 @@ def test_delete_entity_type( assert isinstance(response, future.Future) -def test_delete_entity_type_from_dict(): - test_delete_entity_type(request_type=dict) - - def test_delete_entity_type_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -3237,9 +3329,10 @@ async def test_delete_entity_type_flattened_error_async(): ) -def test_create_feature( - transport: str = "grpc", request_type=featurestore_service.CreateFeatureRequest -): +@pytest.mark.parametrize( + "request_type", [featurestore_service.CreateFeatureRequest, dict,] +) +def test_create_feature(request_type, transport: str = "grpc"): client = FeaturestoreServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -3263,10 +3356,6 @@ def test_create_feature( assert isinstance(response, future.Future) -def test_create_feature_from_dict(): - test_create_feature(request_type=dict) - - def test_create_feature_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -3474,10 +3563,10 @@ async def test_create_feature_flattened_error_async(): ) -def test_batch_create_features( - transport: str = "grpc", - request_type=featurestore_service.BatchCreateFeaturesRequest, -): +@pytest.mark.parametrize( + "request_type", [featurestore_service.BatchCreateFeaturesRequest, dict,] +) +def test_batch_create_features(request_type, transport: str = "grpc"): client = FeaturestoreServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -3503,10 +3592,6 @@ def test_batch_create_features( assert isinstance(response, future.Future) -def test_batch_create_features_from_dict(): - test_batch_create_features(request_type=dict) - - def test_batch_create_features_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -3716,9 +3801,10 @@ async def test_batch_create_features_flattened_error_async(): ) -def test_get_feature( - transport: str = "grpc", request_type=featurestore_service.GetFeatureRequest -): +@pytest.mark.parametrize( + "request_type", [featurestore_service.GetFeatureRequest, dict,] +) +def test_get_feature(request_type, transport: str = "grpc"): client = FeaturestoreServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -3751,10 +3837,6 @@ def test_get_feature( assert response.etag == "etag_value" -def test_get_feature_from_dict(): - test_get_feature(request_type=dict) - - def test_get_feature_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -3940,9 +4022,10 @@ async def test_get_feature_flattened_error_async(): ) -def test_list_features( - transport: str = "grpc", request_type=featurestore_service.ListFeaturesRequest -): +@pytest.mark.parametrize( + "request_type", [featurestore_service.ListFeaturesRequest, dict,] +) +def test_list_features(request_type, transport: str = "grpc"): client = FeaturestoreServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -3969,10 +4052,6 @@ def test_list_features( assert response.next_page_token == "next_page_token_value" -def test_list_features_from_dict(): - test_list_features(request_type=dict) - - def test_list_features_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -4157,8 +4236,10 @@ async def test_list_features_flattened_error_async(): ) -def test_list_features_pager(): - client = FeaturestoreServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_features_pager(transport_name: str = "grpc"): + client = FeaturestoreServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_features), "__call__") as call: @@ -4193,8 +4274,10 @@ def test_list_features_pager(): assert all(isinstance(i, feature.Feature) for i in results) -def test_list_features_pages(): - client = FeaturestoreServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_features_pages(transport_name: str = "grpc"): + client = FeaturestoreServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_features), "__call__") as call: @@ -4291,9 +4374,10 @@ async def test_list_features_async_pages(): assert page_.raw_page.next_page_token == token -def test_update_feature( - transport: str = "grpc", request_type=featurestore_service.UpdateFeatureRequest -): +@pytest.mark.parametrize( + "request_type", [featurestore_service.UpdateFeatureRequest, dict,] +) +def test_update_feature(request_type, transport: str = "grpc"): client = FeaturestoreServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -4326,10 +4410,6 @@ def test_update_feature( assert response.etag == "etag_value" -def test_update_feature_from_dict(): - test_update_feature(request_type=dict) - - def test_update_feature_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -4536,9 +4616,10 @@ async def test_update_feature_flattened_error_async(): ) -def test_delete_feature( - transport: str = "grpc", request_type=featurestore_service.DeleteFeatureRequest -): +@pytest.mark.parametrize( + "request_type", [featurestore_service.DeleteFeatureRequest, dict,] +) +def test_delete_feature(request_type, transport: str = "grpc"): client = FeaturestoreServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -4562,10 +4643,6 @@ def test_delete_feature( assert isinstance(response, future.Future) -def test_delete_feature_from_dict(): - test_delete_feature(request_type=dict) - - def test_delete_feature_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -4747,10 +4824,10 @@ async def test_delete_feature_flattened_error_async(): ) -def test_import_feature_values( - transport: str = "grpc", - request_type=featurestore_service.ImportFeatureValuesRequest, -): +@pytest.mark.parametrize( + "request_type", [featurestore_service.ImportFeatureValuesRequest, dict,] +) +def test_import_feature_values(request_type, transport: str = "grpc"): client = FeaturestoreServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -4776,10 +4853,6 @@ def test_import_feature_values( assert isinstance(response, future.Future) -def test_import_feature_values_from_dict(): - test_import_feature_values(request_type=dict) - - def test_import_feature_values_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -4975,10 +5048,10 @@ async def test_import_feature_values_flattened_error_async(): ) -def test_batch_read_feature_values( - transport: str = "grpc", - request_type=featurestore_service.BatchReadFeatureValuesRequest, -): +@pytest.mark.parametrize( + "request_type", [featurestore_service.BatchReadFeatureValuesRequest, dict,] +) +def test_batch_read_feature_values(request_type, transport: str = "grpc"): client = FeaturestoreServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -5004,10 +5077,6 @@ def test_batch_read_feature_values( assert isinstance(response, future.Future) -def test_batch_read_feature_values_from_dict(): - test_batch_read_feature_values(request_type=dict) - - def test_batch_read_feature_values_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -5209,10 +5278,10 @@ async def test_batch_read_feature_values_flattened_error_async(): ) -def test_export_feature_values( - transport: str = "grpc", - request_type=featurestore_service.ExportFeatureValuesRequest, -): +@pytest.mark.parametrize( + "request_type", [featurestore_service.ExportFeatureValuesRequest, dict,] +) +def test_export_feature_values(request_type, transport: str = "grpc"): client = FeaturestoreServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -5238,10 +5307,6 @@ def test_export_feature_values( assert isinstance(response, future.Future) -def test_export_feature_values_from_dict(): - test_export_feature_values(request_type=dict) - - def test_export_feature_values_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -5437,9 +5502,10 @@ async def test_export_feature_values_flattened_error_async(): ) -def test_search_features( - transport: str = "grpc", request_type=featurestore_service.SearchFeaturesRequest -): +@pytest.mark.parametrize( + "request_type", [featurestore_service.SearchFeaturesRequest, dict,] +) +def test_search_features(request_type, transport: str = "grpc"): client = FeaturestoreServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -5466,10 +5532,6 @@ def test_search_features( assert response.next_page_token == "next_page_token_value" -def test_search_features_from_dict(): - test_search_features(request_type=dict) - - def test_search_features_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -5668,8 +5730,10 @@ async def test_search_features_flattened_error_async(): ) -def test_search_features_pager(): - client = FeaturestoreServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_search_features_pager(transport_name: str = "grpc"): + client = FeaturestoreServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.search_features), "__call__") as call: @@ -5704,8 +5768,10 @@ def test_search_features_pager(): assert all(isinstance(i, feature.Feature) for i in results) -def test_search_features_pages(): - client = FeaturestoreServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_search_features_pages(transport_name: str = "grpc"): + client = FeaturestoreServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. 
with mock.patch.object(type(client.transport.search_features), "__call__") as call: @@ -5822,6 +5888,23 @@ def test_credentials_transport_error(): transport=transport, ) + # It is an error to provide an api_key and a transport instance. + transport = transports.FeaturestoreServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + options = client_options.ClientOptions() + options.api_key = "api_key" + with pytest.raises(ValueError): + client = FeaturestoreServiceClient(client_options=options, transport=transport,) + + # It is an error to provide an api_key and a credential. + options = mock.Mock() + options.api_key = "api_key" + with pytest.raises(ValueError): + client = FeaturestoreServiceClient( + client_options=options, credentials=ga_credentials.AnonymousCredentials() + ) + # It is an error to provide scopes and a transport instance. transport = transports.FeaturestoreServiceGrpcTransport( credentials=ga_credentials.AnonymousCredentials(), @@ -6437,7 +6520,7 @@ def test_parse_common_location_path(): assert expected == actual -def test_client_withDEFAULT_CLIENT_INFO(): +def test_client_with_default_client_info(): client_info = gapic_v1.client_info.ClientInfo() with mock.patch.object( @@ -6502,3 +6585,36 @@ def test_client_ctx(): with client: pass close.assert_called() + + +@pytest.mark.parametrize( + "client_class,transport_class", + [ + (FeaturestoreServiceClient, transports.FeaturestoreServiceGrpcTransport), + ( + FeaturestoreServiceAsyncClient, + transports.FeaturestoreServiceGrpcAsyncIOTransport, + ), + ], +) +def test_api_key_credentials(client_class, transport_class): + with mock.patch.object( + google.auth._default, "get_api_key_credentials", create=True + ) as get_api_key_credentials: + mock_cred = mock.Mock() + get_api_key_credentials.return_value = mock_cred + options = client_options.ClientOptions() + options.api_key = "api_key" + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + 
client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=mock_cred, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + ) diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_index_endpoint_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_index_endpoint_service.py index 9e96140fe4d..b8abbab31d1 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_index_endpoint_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_index_endpoint_service.py @@ -29,6 +29,7 @@ from google.api_core import gapic_v1 from google.api_core import grpc_helpers from google.api_core import grpc_helpers_async +from google.api_core import operation from google.api_core import operation_async # type: ignore from google.api_core import operations_v1 from google.api_core import path_template @@ -265,20 +266,20 @@ def test_index_endpoint_service_client_client_options( # unsupported value. with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): with pytest.raises(MutualTLSChannelError): - client = client_class() + client = client_class(transport=transport_name) # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. 
with mock.patch.dict( os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} ): with pytest.raises(ValueError): - client = client_class() + client = client_class(transport=transport_name) # Check the case quota_project_id is provided options = client_options.ClientOptions(quota_project_id="octopus") with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( credentials=None, credentials_file=None, @@ -347,7 +348,7 @@ def test_index_endpoint_service_client_mtls_env_auto( ) with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) if use_client_cert_env == "false": expected_client_cert_source = None @@ -424,6 +425,87 @@ def test_index_endpoint_service_client_mtls_env_auto( ) +@pytest.mark.parametrize( + "client_class", [IndexEndpointServiceClient, IndexEndpointServiceAsyncClient] +) +@mock.patch.object( + IndexEndpointServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(IndexEndpointServiceClient), +) +@mock.patch.object( + IndexEndpointServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(IndexEndpointServiceAsyncClient), +) +def test_index_endpoint_service_client_get_mtls_endpoint_and_cert_source(client_class): + mock_client_cert_source = mock.Mock() + + # Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is "true". 
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + mock_api_endpoint = "foo" + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint + ) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source( + options + ) + assert api_endpoint == mock_api_endpoint + assert cert_source == mock_client_cert_source + + # Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is "false". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "false"}): + mock_client_cert_source = mock.Mock() + mock_api_endpoint = "foo" + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint + ) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source( + options + ) + assert api_endpoint == mock_api_endpoint + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "never". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "always". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert doesn't exist. 
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert exists. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=mock_client_cert_source, + ): + ( + api_endpoint, + cert_source, + ) = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + assert cert_source == mock_client_cert_source + + @pytest.mark.parametrize( "client_class,transport_class,transport_name", [ @@ -446,7 +528,7 @@ def test_index_endpoint_service_client_client_options_scopes( options = client_options.ClientOptions(scopes=["1", "2"],) with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( credentials=None, credentials_file=None, @@ -460,28 +542,31 @@ def test_index_endpoint_service_client_client_options_scopes( @pytest.mark.parametrize( - "client_class,transport_class,transport_name", + "client_class,transport_class,transport_name,grpc_helpers", [ ( IndexEndpointServiceClient, transports.IndexEndpointServiceGrpcTransport, "grpc", + grpc_helpers, ), ( IndexEndpointServiceAsyncClient, transports.IndexEndpointServiceGrpcAsyncIOTransport, "grpc_asyncio", + grpc_helpers_async, ), ], ) def 
test_index_endpoint_service_client_client_options_credentials_file( - client_class, transport_class, transport_name + client_class, transport_class, transport_name, grpc_helpers ): # Check the case credentials file is provided. options = client_options.ClientOptions(credentials_file="credentials.json") + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( credentials=None, credentials_file="credentials.json", @@ -493,6 +578,35 @@ def test_index_endpoint_service_client_client_options_credentials_file( always_use_jwt_access=True, ) + if "grpc" in transport_name: + # test that the credentials from file are saved and used as the credentials. + with mock.patch.object( + google.auth, "load_credentials_from_file", autospec=True + ) as load_creds, mock.patch.object( + google.auth, "default", autospec=True + ) as adc, mock.patch.object( + grpc_helpers, "create_channel" + ) as create_channel: + creds = ga_credentials.AnonymousCredentials() + file_creds = ga_credentials.AnonymousCredentials() + load_creds.return_value = (file_creds, None) + adc.return_value = (creds, None) + client = client_class(client_options=options, transport=transport_name) + create_channel.assert_called_with( + "aiplatform.googleapis.com:443", + credentials=file_creds, + credentials_file=None, + quota_project_id=None, + default_scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=None, + default_host="aiplatform.googleapis.com", + ssl_credentials=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + def test_index_endpoint_service_client_client_options_from_dict(): with mock.patch( @@ -514,10 +628,10 @@ def test_index_endpoint_service_client_client_options_from_dict(): ) -def test_create_index_endpoint( - 
transport: str = "grpc", - request_type=index_endpoint_service.CreateIndexEndpointRequest, -): +@pytest.mark.parametrize( + "request_type", [index_endpoint_service.CreateIndexEndpointRequest, dict,] +) +def test_create_index_endpoint(request_type, transport: str = "grpc"): client = IndexEndpointServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -543,10 +657,6 @@ def test_create_index_endpoint( assert isinstance(response, future.Future) -def test_create_index_endpoint_from_dict(): - test_create_index_endpoint(request_type=dict) - - def test_create_index_endpoint_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -756,9 +866,10 @@ async def test_create_index_endpoint_flattened_error_async(): ) -def test_get_index_endpoint( - transport: str = "grpc", request_type=index_endpoint_service.GetIndexEndpointRequest -): +@pytest.mark.parametrize( + "request_type", [index_endpoint_service.GetIndexEndpointRequest, dict,] +) +def test_get_index_endpoint(request_type, transport: str = "grpc"): client = IndexEndpointServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -797,10 +908,6 @@ def test_get_index_endpoint( assert response.enable_private_service_connect is True -def test_get_index_endpoint_from_dict(): - test_get_index_endpoint(request_type=dict) - - def test_get_index_endpoint_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -1007,10 +1114,10 @@ async def test_get_index_endpoint_flattened_error_async(): ) -def test_list_index_endpoints( - transport: str = "grpc", - request_type=index_endpoint_service.ListIndexEndpointsRequest, -): +@pytest.mark.parametrize( + "request_type", [index_endpoint_service.ListIndexEndpointsRequest, dict,] +) +def test_list_index_endpoints(request_type, transport: str = "grpc"): client = IndexEndpointServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -1039,10 +1146,6 @@ def test_list_index_endpoints( assert response.next_page_token == "next_page_token_value" -def test_list_index_endpoints_from_dict(): - test_list_index_endpoints(request_type=dict) - - def test_list_index_endpoints_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -1239,9 +1342,9 @@ async def test_list_index_endpoints_flattened_error_async(): ) -def test_list_index_endpoints_pager(): +def test_list_index_endpoints_pager(transport_name: str = "grpc"): client = IndexEndpointServiceClient( - credentials=ga_credentials.AnonymousCredentials, + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, ) # Mock the actual call within the gRPC stub, and fake the request. @@ -1287,9 +1390,9 @@ def test_list_index_endpoints_pager(): assert all(isinstance(i, index_endpoint.IndexEndpoint) for i in results) -def test_list_index_endpoints_pages(): +def test_list_index_endpoints_pages(transport_name: str = "grpc"): client = IndexEndpointServiceClient( - credentials=ga_credentials.AnonymousCredentials, + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, ) # Mock the actual call within the gRPC stub, and fake the request. 
@@ -1417,10 +1520,10 @@ async def test_list_index_endpoints_async_pages(): assert page_.raw_page.next_page_token == token -def test_update_index_endpoint( - transport: str = "grpc", - request_type=index_endpoint_service.UpdateIndexEndpointRequest, -): +@pytest.mark.parametrize( + "request_type", [index_endpoint_service.UpdateIndexEndpointRequest, dict,] +) +def test_update_index_endpoint(request_type, transport: str = "grpc"): client = IndexEndpointServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -1459,10 +1562,6 @@ def test_update_index_endpoint( assert response.enable_private_service_connect is True -def test_update_index_endpoint_from_dict(): - test_update_index_endpoint(request_type=dict) - - def test_update_index_endpoint_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -1691,10 +1790,10 @@ async def test_update_index_endpoint_flattened_error_async(): ) -def test_delete_index_endpoint( - transport: str = "grpc", - request_type=index_endpoint_service.DeleteIndexEndpointRequest, -): +@pytest.mark.parametrize( + "request_type", [index_endpoint_service.DeleteIndexEndpointRequest, dict,] +) +def test_delete_index_endpoint(request_type, transport: str = "grpc"): client = IndexEndpointServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -1720,10 +1819,6 @@ def test_delete_index_endpoint( assert isinstance(response, future.Future) -def test_delete_index_endpoint_from_dict(): - test_delete_index_endpoint(request_type=dict) - - def test_delete_index_endpoint_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -1917,9 +2012,10 @@ async def test_delete_index_endpoint_flattened_error_async(): ) -def test_deploy_index( - transport: str = "grpc", request_type=index_endpoint_service.DeployIndexRequest -): +@pytest.mark.parametrize( + "request_type", [index_endpoint_service.DeployIndexRequest, dict,] +) +def test_deploy_index(request_type, transport: str = "grpc"): client = IndexEndpointServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -1943,10 +2039,6 @@ def test_deploy_index( assert isinstance(response, future.Future) -def test_deploy_index_from_dict(): - test_deploy_index(request_type=dict) - - def test_deploy_index_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -2148,9 +2240,10 @@ async def test_deploy_index_flattened_error_async(): ) -def test_undeploy_index( - transport: str = "grpc", request_type=index_endpoint_service.UndeployIndexRequest -): +@pytest.mark.parametrize( + "request_type", [index_endpoint_service.UndeployIndexRequest, dict,] +) +def test_undeploy_index(request_type, transport: str = "grpc"): client = IndexEndpointServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -2174,10 +2267,6 @@ def test_undeploy_index( assert isinstance(response, future.Future) -def test_undeploy_index_from_dict(): - test_undeploy_index(request_type=dict) - - def test_undeploy_index_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -2379,10 +2468,10 @@ async def test_undeploy_index_flattened_error_async(): ) -def test_mutate_deployed_index( - transport: str = "grpc", - request_type=index_endpoint_service.MutateDeployedIndexRequest, -): +@pytest.mark.parametrize( + "request_type", [index_endpoint_service.MutateDeployedIndexRequest, dict,] +) +def test_mutate_deployed_index(request_type, transport: str = "grpc"): client = IndexEndpointServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -2408,10 +2497,6 @@ def test_mutate_deployed_index( assert isinstance(response, future.Future) -def test_mutate_deployed_index_from_dict(): - test_mutate_deployed_index(request_type=dict) - - def test_mutate_deployed_index_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -2645,6 +2730,25 @@ def test_credentials_transport_error(): transport=transport, ) + # It is an error to provide an api_key and a transport instance. + transport = transports.IndexEndpointServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + options = client_options.ClientOptions() + options.api_key = "api_key" + with pytest.raises(ValueError): + client = IndexEndpointServiceClient( + client_options=options, transport=transport, + ) + + # It is an error to provide an api_key and a credential. + options = mock.Mock() + options.api_key = "api_key" + with pytest.raises(ValueError): + client = IndexEndpointServiceClient( + client_options=options, credentials=ga_credentials.AnonymousCredentials() + ) + # It is an error to provide scopes and a transport instance. 
transport = transports.IndexEndpointServiceGrpcTransport( credentials=ga_credentials.AnonymousCredentials(), @@ -3207,7 +3311,7 @@ def test_parse_common_location_path(): assert expected == actual -def test_client_withDEFAULT_CLIENT_INFO(): +def test_client_with_default_client_info(): client_info = gapic_v1.client_info.ClientInfo() with mock.patch.object( @@ -3272,3 +3376,36 @@ def test_client_ctx(): with client: pass close.assert_called() + + +@pytest.mark.parametrize( + "client_class,transport_class", + [ + (IndexEndpointServiceClient, transports.IndexEndpointServiceGrpcTransport), + ( + IndexEndpointServiceAsyncClient, + transports.IndexEndpointServiceGrpcAsyncIOTransport, + ), + ], +) +def test_api_key_credentials(client_class, transport_class): + with mock.patch.object( + google.auth._default, "get_api_key_credentials", create=True + ) as get_api_key_credentials: + mock_cred = mock.Mock() + get_api_key_credentials.return_value = mock_cred + options = client_options.ClientOptions() + options.api_key = "api_key" + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=mock_cred, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + ) diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_index_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_index_service.py index f54ca8e4ca3..4babf78b589 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_index_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_index_service.py @@ -29,6 +29,7 @@ from google.api_core import gapic_v1 from google.api_core import grpc_helpers from google.api_core import grpc_helpers_async +from google.api_core import operation from google.api_core import operation_async # type: ignore from 
google.api_core import operations_v1 from google.api_core import path_template @@ -250,20 +251,20 @@ def test_index_service_client_client_options( # unsupported value. with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): with pytest.raises(MutualTLSChannelError): - client = client_class() + client = client_class(transport=transport_name) # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. with mock.patch.dict( os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} ): with pytest.raises(ValueError): - client = client_class() + client = client_class(transport=transport_name) # Check the case quota_project_id is provided options = client_options.ClientOptions(quota_project_id="octopus") with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( credentials=None, credentials_file=None, @@ -320,7 +321,7 @@ def test_index_service_client_mtls_env_auto( ) with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) if use_client_cert_env == "false": expected_client_cert_source = None @@ -397,6 +398,83 @@ def test_index_service_client_mtls_env_auto( ) +@pytest.mark.parametrize("client_class", [IndexServiceClient, IndexServiceAsyncClient]) +@mock.patch.object( + IndexServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(IndexServiceClient) +) +@mock.patch.object( + IndexServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(IndexServiceAsyncClient), +) +def test_index_service_client_get_mtls_endpoint_and_cert_source(client_class): + mock_client_cert_source = mock.Mock() + + # Test the case 
GOOGLE_API_USE_CLIENT_CERTIFICATE is "true". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + mock_api_endpoint = "foo" + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint + ) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source( + options + ) + assert api_endpoint == mock_api_endpoint + assert cert_source == mock_client_cert_source + + # Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is "false". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "false"}): + mock_client_cert_source = mock.Mock() + mock_api_endpoint = "foo" + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint + ) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source( + options + ) + assert api_endpoint == mock_api_endpoint + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "never". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "always". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert doesn't exist. 
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert exists. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=mock_client_cert_source, + ): + ( + api_endpoint, + cert_source, + ) = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + assert cert_source == mock_client_cert_source + + @pytest.mark.parametrize( "client_class,transport_class,transport_name", [ @@ -415,7 +493,7 @@ def test_index_service_client_client_options_scopes( options = client_options.ClientOptions(scopes=["1", "2"],) with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( credentials=None, credentials_file=None, @@ -429,24 +507,31 @@ def test_index_service_client_client_options_scopes( @pytest.mark.parametrize( - "client_class,transport_class,transport_name", + "client_class,transport_class,transport_name,grpc_helpers", [ - (IndexServiceClient, transports.IndexServiceGrpcTransport, "grpc"), + ( + IndexServiceClient, + transports.IndexServiceGrpcTransport, + "grpc", + grpc_helpers, + ), ( IndexServiceAsyncClient, transports.IndexServiceGrpcAsyncIOTransport, "grpc_asyncio", + grpc_helpers_async, ), ], ) def 
test_index_service_client_client_options_credentials_file( - client_class, transport_class, transport_name + client_class, transport_class, transport_name, grpc_helpers ): # Check the case credentials file is provided. options = client_options.ClientOptions(credentials_file="credentials.json") + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( credentials=None, credentials_file="credentials.json", @@ -458,6 +543,35 @@ def test_index_service_client_client_options_credentials_file( always_use_jwt_access=True, ) + if "grpc" in transport_name: + # test that the credentials from file are saved and used as the credentials. + with mock.patch.object( + google.auth, "load_credentials_from_file", autospec=True + ) as load_creds, mock.patch.object( + google.auth, "default", autospec=True + ) as adc, mock.patch.object( + grpc_helpers, "create_channel" + ) as create_channel: + creds = ga_credentials.AnonymousCredentials() + file_creds = ga_credentials.AnonymousCredentials() + load_creds.return_value = (file_creds, None) + adc.return_value = (creds, None) + client = client_class(client_options=options, transport=transport_name) + create_channel.assert_called_with( + "aiplatform.googleapis.com:443", + credentials=file_creds, + credentials_file=None, + quota_project_id=None, + default_scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=None, + default_host="aiplatform.googleapis.com", + ssl_credentials=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + def test_index_service_client_client_options_from_dict(): with mock.patch( @@ -477,9 +591,8 @@ def test_index_service_client_client_options_from_dict(): ) -def test_create_index( - transport: str = "grpc", 
request_type=index_service.CreateIndexRequest -): +@pytest.mark.parametrize("request_type", [index_service.CreateIndexRequest, dict,]) +def test_create_index(request_type, transport: str = "grpc"): client = IndexServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -503,10 +616,6 @@ def test_create_index( assert isinstance(response, future.Future) -def test_create_index_from_dict(): - test_create_index(request_type=dict) - - def test_create_index_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -689,7 +798,8 @@ async def test_create_index_flattened_error_async(): ) -def test_get_index(transport: str = "grpc", request_type=index_service.GetIndexRequest): +@pytest.mark.parametrize("request_type", [index_service.GetIndexRequest, dict,]) +def test_get_index(request_type, transport: str = "grpc"): client = IndexServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -724,10 +834,6 @@ def test_get_index(transport: str = "grpc", request_type=index_service.GetIndexR assert response.etag == "etag_value" -def test_get_index_from_dict(): - test_get_index(request_type=dict) - - def test_get_index_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -903,9 +1009,8 @@ async def test_get_index_flattened_error_async(): ) -def test_list_indexes( - transport: str = "grpc", request_type=index_service.ListIndexesRequest -): +@pytest.mark.parametrize("request_type", [index_service.ListIndexesRequest, dict,]) +def test_list_indexes(request_type, transport: str = "grpc"): client = IndexServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -932,10 +1037,6 @@ def test_list_indexes( assert response.next_page_token == "next_page_token_value" -def test_list_indexes_from_dict(): - test_list_indexes(request_type=dict) - - def test_list_indexes_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -1105,8 +1206,10 @@ async def test_list_indexes_flattened_error_async(): ) -def test_list_indexes_pager(): - client = IndexServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_indexes_pager(transport_name: str = "grpc"): + client = IndexServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_indexes), "__call__") as call: @@ -1137,8 +1240,10 @@ def test_list_indexes_pager(): assert all(isinstance(i, index.Index) for i in results) -def test_list_indexes_pages(): - client = IndexServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_indexes_pages(transport_name: str = "grpc"): + client = IndexServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. 
with mock.patch.object(type(client.transport.list_indexes), "__call__") as call: @@ -1219,9 +1324,8 @@ async def test_list_indexes_async_pages(): assert page_.raw_page.next_page_token == token -def test_update_index( - transport: str = "grpc", request_type=index_service.UpdateIndexRequest -): +@pytest.mark.parametrize("request_type", [index_service.UpdateIndexRequest, dict,]) +def test_update_index(request_type, transport: str = "grpc"): client = IndexServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -1245,10 +1349,6 @@ def test_update_index( assert isinstance(response, future.Future) -def test_update_index_from_dict(): - test_update_index(request_type=dict) - - def test_update_index_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -1433,9 +1533,8 @@ async def test_update_index_flattened_error_async(): ) -def test_delete_index( - transport: str = "grpc", request_type=index_service.DeleteIndexRequest -): +@pytest.mark.parametrize("request_type", [index_service.DeleteIndexRequest, dict,]) +def test_delete_index(request_type, transport: str = "grpc"): client = IndexServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -1459,10 +1558,6 @@ def test_delete_index( assert isinstance(response, future.Future) -def test_delete_index_from_dict(): - test_delete_index(request_type=dict) - - def test_delete_index_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -1651,6 +1746,23 @@ def test_credentials_transport_error(): transport=transport, ) + # It is an error to provide an api_key and a transport instance. 
+ transport = transports.IndexServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + options = client_options.ClientOptions() + options.api_key = "api_key" + with pytest.raises(ValueError): + client = IndexServiceClient(client_options=options, transport=transport,) + + # It is an error to provide an api_key and a credential. + options = mock.Mock() + options.api_key = "api_key" + with pytest.raises(ValueError): + client = IndexServiceClient( + client_options=options, credentials=ga_credentials.AnonymousCredentials() + ) + # It is an error to provide scopes and a transport instance. transport = transports.IndexServiceGrpcTransport( credentials=ga_credentials.AnonymousCredentials(), @@ -2193,7 +2305,7 @@ def test_parse_common_location_path(): assert expected == actual -def test_client_withDEFAULT_CLIENT_INFO(): +def test_client_with_default_client_info(): client_info = gapic_v1.client_info.ClientInfo() with mock.patch.object( @@ -2258,3 +2370,33 @@ def test_client_ctx(): with client: pass close.assert_called() + + +@pytest.mark.parametrize( + "client_class,transport_class", + [ + (IndexServiceClient, transports.IndexServiceGrpcTransport), + (IndexServiceAsyncClient, transports.IndexServiceGrpcAsyncIOTransport), + ], +) +def test_api_key_credentials(client_class, transport_class): + with mock.patch.object( + google.auth._default, "get_api_key_credentials", create=True + ) as get_api_key_credentials: + mock_cred = mock.Mock() + get_api_key_credentials.return_value = mock_cred + options = client_options.ClientOptions() + options.api_key = "api_key" + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=mock_cred, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + 
always_use_jwt_access=True, + ) diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_job_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_job_service.py index 4538fb5fbe2..1190c197577 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_job_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_job_service.py @@ -29,6 +29,7 @@ from google.api_core import gapic_v1 from google.api_core import grpc_helpers from google.api_core import grpc_helpers_async +from google.api_core import operation from google.api_core import operation_async # type: ignore from google.api_core import operations_v1 from google.api_core import path_template @@ -281,20 +282,20 @@ def test_job_service_client_client_options( # unsupported value. with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): with pytest.raises(MutualTLSChannelError): - client = client_class() + client = client_class(transport=transport_name) # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. 
with mock.patch.dict( os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} ): with pytest.raises(ValueError): - client = client_class() + client = client_class(transport=transport_name) # Check the case quota_project_id is provided options = client_options.ClientOptions(quota_project_id="octopus") with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( credentials=None, credentials_file=None, @@ -351,7 +352,7 @@ def test_job_service_client_mtls_env_auto( ) with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) if use_client_cert_env == "false": expected_client_cert_source = None @@ -428,6 +429,83 @@ def test_job_service_client_mtls_env_auto( ) +@pytest.mark.parametrize("client_class", [JobServiceClient, JobServiceAsyncClient]) +@mock.patch.object( + JobServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(JobServiceClient) +) +@mock.patch.object( + JobServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(JobServiceAsyncClient), +) +def test_job_service_client_get_mtls_endpoint_and_cert_source(client_class): + mock_client_cert_source = mock.Mock() + + # Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is "true". 
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + mock_api_endpoint = "foo" + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint + ) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source( + options + ) + assert api_endpoint == mock_api_endpoint + assert cert_source == mock_client_cert_source + + # Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is "false". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "false"}): + mock_client_cert_source = mock.Mock() + mock_api_endpoint = "foo" + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint + ) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source( + options + ) + assert api_endpoint == mock_api_endpoint + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "never". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "always". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert doesn't exist. 
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert exists. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=mock_client_cert_source, + ): + ( + api_endpoint, + cert_source, + ) = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + assert cert_source == mock_client_cert_source + + @pytest.mark.parametrize( "client_class,transport_class,transport_name", [ @@ -446,7 +524,7 @@ def test_job_service_client_client_options_scopes( options = client_options.ClientOptions(scopes=["1", "2"],) with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( credentials=None, credentials_file=None, @@ -460,24 +538,26 @@ def test_job_service_client_client_options_scopes( @pytest.mark.parametrize( - "client_class,transport_class,transport_name", + "client_class,transport_class,transport_name,grpc_helpers", [ - (JobServiceClient, transports.JobServiceGrpcTransport, "grpc"), + (JobServiceClient, transports.JobServiceGrpcTransport, "grpc", grpc_helpers), ( JobServiceAsyncClient, transports.JobServiceGrpcAsyncIOTransport, "grpc_asyncio", + grpc_helpers_async, ), ], ) def 
test_job_service_client_client_options_credentials_file( - client_class, transport_class, transport_name + client_class, transport_class, transport_name, grpc_helpers ): # Check the case credentials file is provided. options = client_options.ClientOptions(credentials_file="credentials.json") + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( credentials=None, credentials_file="credentials.json", @@ -489,6 +569,35 @@ def test_job_service_client_client_options_credentials_file( always_use_jwt_access=True, ) + if "grpc" in transport_name: + # test that the credentials from file are saved and used as the credentials. + with mock.patch.object( + google.auth, "load_credentials_from_file", autospec=True + ) as load_creds, mock.patch.object( + google.auth, "default", autospec=True + ) as adc, mock.patch.object( + grpc_helpers, "create_channel" + ) as create_channel: + creds = ga_credentials.AnonymousCredentials() + file_creds = ga_credentials.AnonymousCredentials() + load_creds.return_value = (file_creds, None) + adc.return_value = (creds, None) + client = client_class(client_options=options, transport=transport_name) + create_channel.assert_called_with( + "aiplatform.googleapis.com:443", + credentials=file_creds, + credentials_file=None, + quota_project_id=None, + default_scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=None, + default_host="aiplatform.googleapis.com", + ssl_credentials=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + def test_job_service_client_client_options_from_dict(): with mock.patch( @@ -508,9 +617,8 @@ def test_job_service_client_client_options_from_dict(): ) -def test_create_custom_job( - transport: str = "grpc", 
request_type=job_service.CreateCustomJobRequest -): +@pytest.mark.parametrize("request_type", [job_service.CreateCustomJobRequest, dict,]) +def test_create_custom_job(request_type, transport: str = "grpc"): client = JobServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -543,10 +651,6 @@ def test_create_custom_job( assert response.state == job_state.JobState.JOB_STATE_QUEUED -def test_create_custom_job_from_dict(): - test_create_custom_job(request_type=dict) - - def test_create_custom_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -750,9 +854,8 @@ async def test_create_custom_job_flattened_error_async(): ) -def test_get_custom_job( - transport: str = "grpc", request_type=job_service.GetCustomJobRequest -): +@pytest.mark.parametrize("request_type", [job_service.GetCustomJobRequest, dict,]) +def test_get_custom_job(request_type, transport: str = "grpc"): client = JobServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -783,10 +886,6 @@ def test_get_custom_job( assert response.state == job_state.JobState.JOB_STATE_QUEUED -def test_get_custom_job_from_dict(): - test_get_custom_job(request_type=dict) - - def test_get_custom_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -962,9 +1061,8 @@ async def test_get_custom_job_flattened_error_async(): ) -def test_list_custom_jobs( - transport: str = "grpc", request_type=job_service.ListCustomJobsRequest -): +@pytest.mark.parametrize("request_type", [job_service.ListCustomJobsRequest, dict,]) +def test_list_custom_jobs(request_type, transport: str = "grpc"): client = JobServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -991,10 +1089,6 @@ def test_list_custom_jobs( assert response.next_page_token == "next_page_token_value" -def test_list_custom_jobs_from_dict(): - test_list_custom_jobs(request_type=dict) - - def test_list_custom_jobs_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -1164,8 +1258,10 @@ async def test_list_custom_jobs_flattened_error_async(): ) -def test_list_custom_jobs_pager(): - client = JobServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_custom_jobs_pager(transport_name: str = "grpc"): + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_custom_jobs), "__call__") as call: @@ -1202,8 +1298,10 @@ def test_list_custom_jobs_pager(): assert all(isinstance(i, custom_job.CustomJob) for i in results) -def test_list_custom_jobs_pages(): - client = JobServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_custom_jobs_pages(transport_name: str = "grpc"): + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. 
with mock.patch.object(type(client.transport.list_custom_jobs), "__call__") as call: @@ -1302,9 +1400,8 @@ async def test_list_custom_jobs_async_pages(): assert page_.raw_page.next_page_token == token -def test_delete_custom_job( - transport: str = "grpc", request_type=job_service.DeleteCustomJobRequest -): +@pytest.mark.parametrize("request_type", [job_service.DeleteCustomJobRequest, dict,]) +def test_delete_custom_job(request_type, transport: str = "grpc"): client = JobServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -1330,10 +1427,6 @@ def test_delete_custom_job( assert isinstance(response, future.Future) -def test_delete_custom_job_from_dict(): - test_delete_custom_job(request_type=dict) - - def test_delete_custom_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -1514,9 +1607,8 @@ async def test_delete_custom_job_flattened_error_async(): ) -def test_cancel_custom_job( - transport: str = "grpc", request_type=job_service.CancelCustomJobRequest -): +@pytest.mark.parametrize("request_type", [job_service.CancelCustomJobRequest, dict,]) +def test_cancel_custom_job(request_type, transport: str = "grpc"): client = JobServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -1542,10 +1634,6 @@ def test_cancel_custom_job( assert response is None -def test_cancel_custom_job_from_dict(): - test_cancel_custom_job(request_type=dict) - - def test_cancel_custom_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -1720,9 +1808,10 @@ async def test_cancel_custom_job_flattened_error_async(): ) -def test_create_data_labeling_job( - transport: str = "grpc", request_type=job_service.CreateDataLabelingJobRequest -): +@pytest.mark.parametrize( + "request_type", [job_service.CreateDataLabelingJobRequest, dict,] +) +def test_create_data_labeling_job(request_type, transport: str = "grpc"): client = JobServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -1767,10 +1856,6 @@ def test_create_data_labeling_job( assert response.specialist_pools == ["specialist_pools_value"] -def test_create_data_labeling_job_from_dict(): - test_create_data_labeling_job(request_type=dict) - - def test_create_data_labeling_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -1987,9 +2072,8 @@ async def test_create_data_labeling_job_flattened_error_async(): ) -def test_get_data_labeling_job( - transport: str = "grpc", request_type=job_service.GetDataLabelingJobRequest -): +@pytest.mark.parametrize("request_type", [job_service.GetDataLabelingJobRequest, dict,]) +def test_get_data_labeling_job(request_type, transport: str = "grpc"): client = JobServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -2034,10 +2118,6 @@ def test_get_data_labeling_job( assert response.specialist_pools == ["specialist_pools_value"] -def test_get_data_labeling_job_from_dict(): - test_get_data_labeling_job(request_type=dict) - - def test_get_data_labeling_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -2237,9 +2317,10 @@ async def test_get_data_labeling_job_flattened_error_async(): ) -def test_list_data_labeling_jobs( - transport: str = "grpc", request_type=job_service.ListDataLabelingJobsRequest -): +@pytest.mark.parametrize( + "request_type", [job_service.ListDataLabelingJobsRequest, dict,] +) +def test_list_data_labeling_jobs(request_type, transport: str = "grpc"): client = JobServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -2268,10 +2349,6 @@ def test_list_data_labeling_jobs( assert response.next_page_token == "next_page_token_value" -def test_list_data_labeling_jobs_from_dict(): - test_list_data_labeling_jobs(request_type=dict) - - def test_list_data_labeling_jobs_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -2456,8 +2533,10 @@ async def test_list_data_labeling_jobs_flattened_error_async(): ) -def test_list_data_labeling_jobs_pager(): - client = JobServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_data_labeling_jobs_pager(transport_name: str = "grpc"): + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( @@ -2502,8 +2581,10 @@ def test_list_data_labeling_jobs_pager(): assert all(isinstance(i, data_labeling_job.DataLabelingJob) for i in results) -def test_list_data_labeling_jobs_pages(): - client = JobServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_data_labeling_jobs_pages(transport_name: str = "grpc"): + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. 
with mock.patch.object( @@ -2626,9 +2707,10 @@ async def test_list_data_labeling_jobs_async_pages(): assert page_.raw_page.next_page_token == token -def test_delete_data_labeling_job( - transport: str = "grpc", request_type=job_service.DeleteDataLabelingJobRequest -): +@pytest.mark.parametrize( + "request_type", [job_service.DeleteDataLabelingJobRequest, dict,] +) +def test_delete_data_labeling_job(request_type, transport: str = "grpc"): client = JobServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -2654,10 +2736,6 @@ def test_delete_data_labeling_job( assert isinstance(response, future.Future) -def test_delete_data_labeling_job_from_dict(): - test_delete_data_labeling_job(request_type=dict) - - def test_delete_data_labeling_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -2839,9 +2917,10 @@ async def test_delete_data_labeling_job_flattened_error_async(): ) -def test_cancel_data_labeling_job( - transport: str = "grpc", request_type=job_service.CancelDataLabelingJobRequest -): +@pytest.mark.parametrize( + "request_type", [job_service.CancelDataLabelingJobRequest, dict,] +) +def test_cancel_data_labeling_job(request_type, transport: str = "grpc"): client = JobServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -2867,10 +2946,6 @@ def test_cancel_data_labeling_job( assert response is None -def test_cancel_data_labeling_job_from_dict(): - test_cancel_data_labeling_job(request_type=dict) - - def test_cancel_data_labeling_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -3046,10 +3121,10 @@ async def test_cancel_data_labeling_job_flattened_error_async(): ) -def test_create_hyperparameter_tuning_job( - transport: str = "grpc", - request_type=job_service.CreateHyperparameterTuningJobRequest, -): +@pytest.mark.parametrize( + "request_type", [job_service.CreateHyperparameterTuningJobRequest, dict,] +) +def test_create_hyperparameter_tuning_job(request_type, transport: str = "grpc"): client = JobServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -3088,10 +3163,6 @@ def test_create_hyperparameter_tuning_job( assert response.state == job_state.JobState.JOB_STATE_QUEUED -def test_create_hyperparameter_tuning_job_from_dict(): - test_create_hyperparameter_tuning_job(request_type=dict) - - def test_create_hyperparameter_tuning_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -3314,9 +3385,10 @@ async def test_create_hyperparameter_tuning_job_flattened_error_async(): ) -def test_get_hyperparameter_tuning_job( - transport: str = "grpc", request_type=job_service.GetHyperparameterTuningJobRequest -): +@pytest.mark.parametrize( + "request_type", [job_service.GetHyperparameterTuningJobRequest, dict,] +) +def test_get_hyperparameter_tuning_job(request_type, transport: str = "grpc"): client = JobServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -3355,10 +3427,6 @@ def test_get_hyperparameter_tuning_job( assert response.state == job_state.JobState.JOB_STATE_QUEUED -def test_get_hyperparameter_tuning_job_from_dict(): - test_get_hyperparameter_tuning_job(request_type=dict) - - def test_get_hyperparameter_tuning_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -3553,10 +3621,10 @@ async def test_get_hyperparameter_tuning_job_flattened_error_async(): ) -def test_list_hyperparameter_tuning_jobs( - transport: str = "grpc", - request_type=job_service.ListHyperparameterTuningJobsRequest, -): +@pytest.mark.parametrize( + "request_type", [job_service.ListHyperparameterTuningJobsRequest, dict,] +) +def test_list_hyperparameter_tuning_jobs(request_type, transport: str = "grpc"): client = JobServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -3585,10 +3653,6 @@ def test_list_hyperparameter_tuning_jobs( assert response.next_page_token == "next_page_token_value" -def test_list_hyperparameter_tuning_jobs_from_dict(): - test_list_hyperparameter_tuning_jobs(request_type=dict) - - def test_list_hyperparameter_tuning_jobs_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -3773,8 +3837,10 @@ async def test_list_hyperparameter_tuning_jobs_flattened_error_async(): ) -def test_list_hyperparameter_tuning_jobs_pager(): - client = JobServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_hyperparameter_tuning_jobs_pager(transport_name: str = "grpc"): + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( @@ -3824,8 +3890,10 @@ def test_list_hyperparameter_tuning_jobs_pager(): ) -def test_list_hyperparameter_tuning_jobs_pages(): - client = JobServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_hyperparameter_tuning_jobs_pages(transport_name: str = "grpc"): + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. 
with mock.patch.object( @@ -3959,10 +4027,10 @@ async def test_list_hyperparameter_tuning_jobs_async_pages(): assert page_.raw_page.next_page_token == token -def test_delete_hyperparameter_tuning_job( - transport: str = "grpc", - request_type=job_service.DeleteHyperparameterTuningJobRequest, -): +@pytest.mark.parametrize( + "request_type", [job_service.DeleteHyperparameterTuningJobRequest, dict,] +) +def test_delete_hyperparameter_tuning_job(request_type, transport: str = "grpc"): client = JobServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -3988,10 +4056,6 @@ def test_delete_hyperparameter_tuning_job( assert isinstance(response, future.Future) -def test_delete_hyperparameter_tuning_job_from_dict(): - test_delete_hyperparameter_tuning_job(request_type=dict) - - def test_delete_hyperparameter_tuning_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -4173,10 +4237,10 @@ async def test_delete_hyperparameter_tuning_job_flattened_error_async(): ) -def test_cancel_hyperparameter_tuning_job( - transport: str = "grpc", - request_type=job_service.CancelHyperparameterTuningJobRequest, -): +@pytest.mark.parametrize( + "request_type", [job_service.CancelHyperparameterTuningJobRequest, dict,] +) +def test_cancel_hyperparameter_tuning_job(request_type, transport: str = "grpc"): client = JobServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -4202,10 +4266,6 @@ def test_cancel_hyperparameter_tuning_job( assert response is None -def test_cancel_hyperparameter_tuning_job_from_dict(): - test_cancel_hyperparameter_tuning_job(request_type=dict) - - def test_cancel_hyperparameter_tuning_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -4381,9 +4441,10 @@ async def test_cancel_hyperparameter_tuning_job_flattened_error_async(): ) -def test_create_batch_prediction_job( - transport: str = "grpc", request_type=job_service.CreateBatchPredictionJobRequest -): +@pytest.mark.parametrize( + "request_type", [job_service.CreateBatchPredictionJobRequest, dict,] +) +def test_create_batch_prediction_job(request_type, transport: str = "grpc"): client = JobServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -4420,10 +4481,6 @@ def test_create_batch_prediction_job( assert response.state == job_state.JobState.JOB_STATE_QUEUED -def test_create_batch_prediction_job_from_dict(): - test_create_batch_prediction_job(request_type=dict) - - def test_create_batch_prediction_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -4640,9 +4697,10 @@ async def test_create_batch_prediction_job_flattened_error_async(): ) -def test_get_batch_prediction_job( - transport: str = "grpc", request_type=job_service.GetBatchPredictionJobRequest -): +@pytest.mark.parametrize( + "request_type", [job_service.GetBatchPredictionJobRequest, dict,] +) +def test_get_batch_prediction_job(request_type, transport: str = "grpc"): client = JobServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -4679,10 +4737,6 @@ def test_get_batch_prediction_job( assert response.state == job_state.JobState.JOB_STATE_QUEUED -def test_get_batch_prediction_job_from_dict(): - test_get_batch_prediction_job(request_type=dict) - - def test_get_batch_prediction_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -4875,9 +4929,10 @@ async def test_get_batch_prediction_job_flattened_error_async(): ) -def test_list_batch_prediction_jobs( - transport: str = "grpc", request_type=job_service.ListBatchPredictionJobsRequest -): +@pytest.mark.parametrize( + "request_type", [job_service.ListBatchPredictionJobsRequest, dict,] +) +def test_list_batch_prediction_jobs(request_type, transport: str = "grpc"): client = JobServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -4906,10 +4961,6 @@ def test_list_batch_prediction_jobs( assert response.next_page_token == "next_page_token_value" -def test_list_batch_prediction_jobs_from_dict(): - test_list_batch_prediction_jobs(request_type=dict) - - def test_list_batch_prediction_jobs_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -5094,8 +5145,10 @@ async def test_list_batch_prediction_jobs_flattened_error_async(): ) -def test_list_batch_prediction_jobs_pager(): - client = JobServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_batch_prediction_jobs_pager(transport_name: str = "grpc"): + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( @@ -5142,8 +5195,10 @@ def test_list_batch_prediction_jobs_pager(): ) -def test_list_batch_prediction_jobs_pages(): - client = JobServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_batch_prediction_jobs_pages(transport_name: str = "grpc"): + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. 
with mock.patch.object( @@ -5268,9 +5323,10 @@ async def test_list_batch_prediction_jobs_async_pages(): assert page_.raw_page.next_page_token == token -def test_delete_batch_prediction_job( - transport: str = "grpc", request_type=job_service.DeleteBatchPredictionJobRequest -): +@pytest.mark.parametrize( + "request_type", [job_service.DeleteBatchPredictionJobRequest, dict,] +) +def test_delete_batch_prediction_job(request_type, transport: str = "grpc"): client = JobServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -5296,10 +5352,6 @@ def test_delete_batch_prediction_job( assert isinstance(response, future.Future) -def test_delete_batch_prediction_job_from_dict(): - test_delete_batch_prediction_job(request_type=dict) - - def test_delete_batch_prediction_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -5481,9 +5533,10 @@ async def test_delete_batch_prediction_job_flattened_error_async(): ) -def test_cancel_batch_prediction_job( - transport: str = "grpc", request_type=job_service.CancelBatchPredictionJobRequest -): +@pytest.mark.parametrize( + "request_type", [job_service.CancelBatchPredictionJobRequest, dict,] +) +def test_cancel_batch_prediction_job(request_type, transport: str = "grpc"): client = JobServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -5509,10 +5562,6 @@ def test_cancel_batch_prediction_job( assert response is None -def test_cancel_batch_prediction_job_from_dict(): - test_cancel_batch_prediction_job(request_type=dict) - - def test_cancel_batch_prediction_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -5688,10 +5737,10 @@ async def test_cancel_batch_prediction_job_flattened_error_async(): ) -def test_create_model_deployment_monitoring_job( - transport: str = "grpc", - request_type=job_service.CreateModelDeploymentMonitoringJobRequest, -): +@pytest.mark.parametrize( + "request_type", [job_service.CreateModelDeploymentMonitoringJobRequest, dict,] +) +def test_create_model_deployment_monitoring_job(request_type, transport: str = "grpc"): client = JobServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -5739,10 +5788,6 @@ def test_create_model_deployment_monitoring_job( assert response.enable_monitoring_pipeline_logs is True -def test_create_model_deployment_monitoring_job_from_dict(): - test_create_model_deployment_monitoring_job(request_type=dict) - - def test_create_model_deployment_monitoring_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -5980,9 +6025,12 @@ async def test_create_model_deployment_monitoring_job_flattened_error_async(): ) +@pytest.mark.parametrize( + "request_type", + [job_service.SearchModelDeploymentMonitoringStatsAnomaliesRequest, dict,], +) def test_search_model_deployment_monitoring_stats_anomalies( - transport: str = "grpc", - request_type=job_service.SearchModelDeploymentMonitoringStatsAnomaliesRequest, + request_type, transport: str = "grpc" ): client = JobServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -6018,10 +6066,6 @@ def test_search_model_deployment_monitoring_stats_anomalies( assert response.next_page_token == "next_page_token_value" -def test_search_model_deployment_monitoring_stats_anomalies_from_dict(): - test_search_model_deployment_monitoring_stats_anomalies(request_type=dict) - - def test_search_model_deployment_monitoring_stats_anomalies_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. 
request == None and no flattened fields passed, work. @@ -6252,8 +6296,12 @@ async def test_search_model_deployment_monitoring_stats_anomalies_flattened_erro ) -def test_search_model_deployment_monitoring_stats_anomalies_pager(): - client = JobServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_search_model_deployment_monitoring_stats_anomalies_pager( + transport_name: str = "grpc", +): + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( @@ -6308,8 +6356,12 @@ def test_search_model_deployment_monitoring_stats_anomalies_pager(): ) -def test_search_model_deployment_monitoring_stats_anomalies_pages(): - client = JobServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_search_model_deployment_monitoring_stats_anomalies_pages( + transport_name: str = "grpc", +): + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. 
with mock.patch.object( @@ -6450,10 +6502,10 @@ async def test_search_model_deployment_monitoring_stats_anomalies_async_pages(): assert page_.raw_page.next_page_token == token -def test_get_model_deployment_monitoring_job( - transport: str = "grpc", - request_type=job_service.GetModelDeploymentMonitoringJobRequest, -): +@pytest.mark.parametrize( + "request_type", [job_service.GetModelDeploymentMonitoringJobRequest, dict,] +) +def test_get_model_deployment_monitoring_job(request_type, transport: str = "grpc"): client = JobServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -6501,10 +6553,6 @@ def test_get_model_deployment_monitoring_job( assert response.enable_monitoring_pipeline_logs is True -def test_get_model_deployment_monitoring_job_from_dict(): - test_get_model_deployment_monitoring_job(request_type=dict) - - def test_get_model_deployment_monitoring_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -6714,10 +6762,10 @@ async def test_get_model_deployment_monitoring_job_flattened_error_async(): ) -def test_list_model_deployment_monitoring_jobs( - transport: str = "grpc", - request_type=job_service.ListModelDeploymentMonitoringJobsRequest, -): +@pytest.mark.parametrize( + "request_type", [job_service.ListModelDeploymentMonitoringJobsRequest, dict,] +) +def test_list_model_deployment_monitoring_jobs(request_type, transport: str = "grpc"): client = JobServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -6746,10 +6794,6 @@ def test_list_model_deployment_monitoring_jobs( assert response.next_page_token == "next_page_token_value" -def test_list_model_deployment_monitoring_jobs_from_dict(): - test_list_model_deployment_monitoring_jobs(request_type=dict) - - def test_list_model_deployment_monitoring_jobs_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -6938,8 +6982,10 @@ async def test_list_model_deployment_monitoring_jobs_flattened_error_async(): ) -def test_list_model_deployment_monitoring_jobs_pager(): - client = JobServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_model_deployment_monitoring_jobs_pager(transport_name: str = "grpc"): + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. 
with mock.patch.object( @@ -6989,8 +7035,10 @@ def test_list_model_deployment_monitoring_jobs_pager(): ) -def test_list_model_deployment_monitoring_jobs_pages(): - client = JobServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_model_deployment_monitoring_jobs_pages(transport_name: str = "grpc"): + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( @@ -7124,10 +7172,10 @@ async def test_list_model_deployment_monitoring_jobs_async_pages(): assert page_.raw_page.next_page_token == token -def test_update_model_deployment_monitoring_job( - transport: str = "grpc", - request_type=job_service.UpdateModelDeploymentMonitoringJobRequest, -): +@pytest.mark.parametrize( + "request_type", [job_service.UpdateModelDeploymentMonitoringJobRequest, dict,] +) +def test_update_model_deployment_monitoring_job(request_type, transport: str = "grpc"): client = JobServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -7153,10 +7201,6 @@ def test_update_model_deployment_monitoring_job( assert isinstance(response, future.Future) -def test_update_model_deployment_monitoring_job_from_dict(): - test_update_model_deployment_monitoring_job(request_type=dict) - - def test_update_model_deployment_monitoring_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -7376,10 +7420,10 @@ async def test_update_model_deployment_monitoring_job_flattened_error_async(): ) -def test_delete_model_deployment_monitoring_job( - transport: str = "grpc", - request_type=job_service.DeleteModelDeploymentMonitoringJobRequest, -): +@pytest.mark.parametrize( + "request_type", [job_service.DeleteModelDeploymentMonitoringJobRequest, dict,] +) +def test_delete_model_deployment_monitoring_job(request_type, transport: str = "grpc"): client = JobServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -7405,10 +7449,6 @@ def test_delete_model_deployment_monitoring_job( assert isinstance(response, future.Future) -def test_delete_model_deployment_monitoring_job_from_dict(): - test_delete_model_deployment_monitoring_job(request_type=dict) - - def test_delete_model_deployment_monitoring_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -7592,10 +7632,10 @@ async def test_delete_model_deployment_monitoring_job_flattened_error_async(): ) -def test_pause_model_deployment_monitoring_job( - transport: str = "grpc", - request_type=job_service.PauseModelDeploymentMonitoringJobRequest, -): +@pytest.mark.parametrize( + "request_type", [job_service.PauseModelDeploymentMonitoringJobRequest, dict,] +) +def test_pause_model_deployment_monitoring_job(request_type, transport: str = "grpc"): client = JobServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -7621,10 +7661,6 @@ def test_pause_model_deployment_monitoring_job( assert response is None -def test_pause_model_deployment_monitoring_job_from_dict(): - test_pause_model_deployment_monitoring_job(request_type=dict) - - def test_pause_model_deployment_monitoring_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -7802,10 +7838,10 @@ async def test_pause_model_deployment_monitoring_job_flattened_error_async(): ) -def test_resume_model_deployment_monitoring_job( - transport: str = "grpc", - request_type=job_service.ResumeModelDeploymentMonitoringJobRequest, -): +@pytest.mark.parametrize( + "request_type", [job_service.ResumeModelDeploymentMonitoringJobRequest, dict,] +) +def test_resume_model_deployment_monitoring_job(request_type, transport: str = "grpc"): client = JobServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -7831,10 +7867,6 @@ def test_resume_model_deployment_monitoring_job( assert response is None -def test_resume_model_deployment_monitoring_job_from_dict(): - test_resume_model_deployment_monitoring_job(request_type=dict) - - def test_resume_model_deployment_monitoring_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -8032,6 +8064,23 @@ def test_credentials_transport_error(): transport=transport, ) + # It is an error to provide an api_key and a transport instance. + transport = transports.JobServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + options = client_options.ClientOptions() + options.api_key = "api_key" + with pytest.raises(ValueError): + client = JobServiceClient(client_options=options, transport=transport,) + + # It is an error to provide an api_key and a credential. + options = mock.Mock() + options.api_key = "api_key" + with pytest.raises(ValueError): + client = JobServiceClient( + client_options=options, credentials=ga_credentials.AnonymousCredentials() + ) + # It is an error to provide scopes and a transport instance. 
transport = transports.JobServiceGrpcTransport( credentials=ga_credentials.AnonymousCredentials(), @@ -8819,7 +8868,7 @@ def test_parse_common_location_path(): assert expected == actual -def test_client_withDEFAULT_CLIENT_INFO(): +def test_client_with_default_client_info(): client_info = gapic_v1.client_info.ClientInfo() with mock.patch.object( @@ -8884,3 +8933,33 @@ def test_client_ctx(): with client: pass close.assert_called() + + +@pytest.mark.parametrize( + "client_class,transport_class", + [ + (JobServiceClient, transports.JobServiceGrpcTransport), + (JobServiceAsyncClient, transports.JobServiceGrpcAsyncIOTransport), + ], +) +def test_api_key_credentials(client_class, transport_class): + with mock.patch.object( + google.auth._default, "get_api_key_credentials", create=True + ) as get_api_key_credentials: + mock_cred = mock.Mock() + get_api_key_credentials.return_value = mock_cred + options = client_options.ClientOptions() + options.api_key = "api_key" + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=mock_cred, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + ) diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_metadata_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_metadata_service.py index 1407bbd8178..9768bec18bc 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_metadata_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_metadata_service.py @@ -29,6 +29,7 @@ from google.api_core import gapic_v1 from google.api_core import grpc_helpers from google.api_core import grpc_helpers_async +from google.api_core import operation from google.api_core import operation_async # type: ignore from google.api_core import operations_v1 from 
google.api_core import path_template @@ -271,20 +272,20 @@ def test_metadata_service_client_client_options( # unsupported value. with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): with pytest.raises(MutualTLSChannelError): - client = client_class() + client = client_class(transport=transport_name) # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. with mock.patch.dict( os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} ): with pytest.raises(ValueError): - client = client_class() + client = client_class(transport=transport_name) # Check the case quota_project_id is provided options = client_options.ClientOptions(quota_project_id="octopus") with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( credentials=None, credentials_file=None, @@ -353,7 +354,7 @@ def test_metadata_service_client_mtls_env_auto( ) with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) if use_client_cert_env == "false": expected_client_cert_source = None @@ -430,6 +431,87 @@ def test_metadata_service_client_mtls_env_auto( ) +@pytest.mark.parametrize( + "client_class", [MetadataServiceClient, MetadataServiceAsyncClient] +) +@mock.patch.object( + MetadataServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(MetadataServiceClient), +) +@mock.patch.object( + MetadataServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(MetadataServiceAsyncClient), +) +def test_metadata_service_client_get_mtls_endpoint_and_cert_source(client_class): + mock_client_cert_source = mock.Mock() + + # Test the case 
GOOGLE_API_USE_CLIENT_CERTIFICATE is "true". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + mock_api_endpoint = "foo" + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint + ) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source( + options + ) + assert api_endpoint == mock_api_endpoint + assert cert_source == mock_client_cert_source + + # Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is "false". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "false"}): + mock_client_cert_source = mock.Mock() + mock_api_endpoint = "foo" + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint + ) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source( + options + ) + assert api_endpoint == mock_api_endpoint + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "never". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "always". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert doesn't exist. 
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert exists. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=mock_client_cert_source, + ): + ( + api_endpoint, + cert_source, + ) = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + assert cert_source == mock_client_cert_source + + @pytest.mark.parametrize( "client_class,transport_class,transport_name", [ @@ -448,7 +530,7 @@ def test_metadata_service_client_client_options_scopes( options = client_options.ClientOptions(scopes=["1", "2"],) with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( credentials=None, credentials_file=None, @@ -462,24 +544,31 @@ def test_metadata_service_client_client_options_scopes( @pytest.mark.parametrize( - "client_class,transport_class,transport_name", + "client_class,transport_class,transport_name,grpc_helpers", [ - (MetadataServiceClient, transports.MetadataServiceGrpcTransport, "grpc"), + ( + MetadataServiceClient, + transports.MetadataServiceGrpcTransport, + "grpc", + grpc_helpers, + ), ( MetadataServiceAsyncClient, transports.MetadataServiceGrpcAsyncIOTransport, "grpc_asyncio", + grpc_helpers_async, ), ], ) 
def test_metadata_service_client_client_options_credentials_file( - client_class, transport_class, transport_name + client_class, transport_class, transport_name, grpc_helpers ): # Check the case credentials file is provided. options = client_options.ClientOptions(credentials_file="credentials.json") + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( credentials=None, credentials_file="credentials.json", @@ -491,6 +580,35 @@ def test_metadata_service_client_client_options_credentials_file( always_use_jwt_access=True, ) + if "grpc" in transport_name: + # test that the credentials from file are saved and used as the credentials. + with mock.patch.object( + google.auth, "load_credentials_from_file", autospec=True + ) as load_creds, mock.patch.object( + google.auth, "default", autospec=True + ) as adc, mock.patch.object( + grpc_helpers, "create_channel" + ) as create_channel: + creds = ga_credentials.AnonymousCredentials() + file_creds = ga_credentials.AnonymousCredentials() + load_creds.return_value = (file_creds, None) + adc.return_value = (creds, None) + client = client_class(client_options=options, transport=transport_name) + create_channel.assert_called_with( + "aiplatform.googleapis.com:443", + credentials=file_creds, + credentials_file=None, + quota_project_id=None, + default_scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=None, + default_host="aiplatform.googleapis.com", + ssl_credentials=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + def test_metadata_service_client_client_options_from_dict(): with mock.patch( @@ -512,9 +630,10 @@ def test_metadata_service_client_client_options_from_dict(): ) -def test_create_metadata_store( - transport: str = "grpc", 
request_type=metadata_service.CreateMetadataStoreRequest -): +@pytest.mark.parametrize( + "request_type", [metadata_service.CreateMetadataStoreRequest, dict,] +) +def test_create_metadata_store(request_type, transport: str = "grpc"): client = MetadataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -540,10 +659,6 @@ def test_create_metadata_store( assert isinstance(response, future.Future) -def test_create_metadata_store_from_dict(): - test_create_metadata_store(request_type=dict) - - def test_create_metadata_store_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -757,9 +872,10 @@ async def test_create_metadata_store_flattened_error_async(): ) -def test_get_metadata_store( - transport: str = "grpc", request_type=metadata_service.GetMetadataStoreRequest -): +@pytest.mark.parametrize( + "request_type", [metadata_service.GetMetadataStoreRequest, dict,] +) +def test_get_metadata_store(request_type, transport: str = "grpc"): client = MetadataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -789,10 +905,6 @@ def test_get_metadata_store( assert response.description == "description_value" -def test_get_metadata_store_from_dict(): - test_get_metadata_store(request_type=dict) - - def test_get_metadata_store_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -984,9 +1096,10 @@ async def test_get_metadata_store_flattened_error_async(): ) -def test_list_metadata_stores( - transport: str = "grpc", request_type=metadata_service.ListMetadataStoresRequest -): +@pytest.mark.parametrize( + "request_type", [metadata_service.ListMetadataStoresRequest, dict,] +) +def test_list_metadata_stores(request_type, transport: str = "grpc"): client = MetadataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -1015,10 +1128,6 @@ def test_list_metadata_stores( assert response.next_page_token == "next_page_token_value" -def test_list_metadata_stores_from_dict(): - test_list_metadata_stores(request_type=dict) - - def test_list_metadata_stores_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -1209,8 +1318,10 @@ async def test_list_metadata_stores_flattened_error_async(): ) -def test_list_metadata_stores_pager(): - client = MetadataServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_metadata_stores_pager(transport_name: str = "grpc"): + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( @@ -1255,8 +1366,10 @@ def test_list_metadata_stores_pager(): assert all(isinstance(i, metadata_store.MetadataStore) for i in results) -def test_list_metadata_stores_pages(): - client = MetadataServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_metadata_stores_pages(transport_name: str = "grpc"): + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. 
with mock.patch.object( @@ -1383,9 +1496,10 @@ async def test_list_metadata_stores_async_pages(): assert page_.raw_page.next_page_token == token -def test_delete_metadata_store( - transport: str = "grpc", request_type=metadata_service.DeleteMetadataStoreRequest -): +@pytest.mark.parametrize( + "request_type", [metadata_service.DeleteMetadataStoreRequest, dict,] +) +def test_delete_metadata_store(request_type, transport: str = "grpc"): client = MetadataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -1411,10 +1525,6 @@ def test_delete_metadata_store( assert isinstance(response, future.Future) -def test_delete_metadata_store_from_dict(): - test_delete_metadata_store(request_type=dict) - - def test_delete_metadata_store_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -1602,9 +1712,10 @@ async def test_delete_metadata_store_flattened_error_async(): ) -def test_create_artifact( - transport: str = "grpc", request_type=metadata_service.CreateArtifactRequest -): +@pytest.mark.parametrize( + "request_type", [metadata_service.CreateArtifactRequest, dict,] +) +def test_create_artifact(request_type, transport: str = "grpc"): client = MetadataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -1645,10 +1756,6 @@ def test_create_artifact( assert response.description == "description_value" -def test_create_artifact_from_dict(): - test_create_artifact(request_type=dict) - - def test_create_artifact_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -1866,9 +1973,8 @@ async def test_create_artifact_flattened_error_async(): ) -def test_get_artifact( - transport: str = "grpc", request_type=metadata_service.GetArtifactRequest -): +@pytest.mark.parametrize("request_type", [metadata_service.GetArtifactRequest, dict,]) +def test_get_artifact(request_type, transport: str = "grpc"): client = MetadataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -1909,10 +2015,6 @@ def test_get_artifact( assert response.description == "description_value" -def test_get_artifact_from_dict(): - test_get_artifact(request_type=dict) - - def test_get_artifact_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -2100,9 +2202,8 @@ async def test_get_artifact_flattened_error_async(): ) -def test_list_artifacts( - transport: str = "grpc", request_type=metadata_service.ListArtifactsRequest -): +@pytest.mark.parametrize("request_type", [metadata_service.ListArtifactsRequest, dict,]) +def test_list_artifacts(request_type, transport: str = "grpc"): client = MetadataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -2129,10 +2230,6 @@ def test_list_artifacts( assert response.next_page_token == "next_page_token_value" -def test_list_artifacts_from_dict(): - test_list_artifacts(request_type=dict) - - def test_list_artifacts_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -2310,8 +2407,10 @@ async def test_list_artifacts_flattened_error_async(): ) -def test_list_artifacts_pager(): - client = MetadataServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_artifacts_pager(transport_name: str = "grpc"): + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_artifacts), "__call__") as call: @@ -2350,8 +2449,10 @@ def test_list_artifacts_pager(): assert all(isinstance(i, artifact.Artifact) for i in results) -def test_list_artifacts_pages(): - client = MetadataServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_artifacts_pages(transport_name: str = "grpc"): + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_artifacts), "__call__") as call: @@ -2460,9 +2561,10 @@ async def test_list_artifacts_async_pages(): assert page_.raw_page.next_page_token == token -def test_update_artifact( - transport: str = "grpc", request_type=metadata_service.UpdateArtifactRequest -): +@pytest.mark.parametrize( + "request_type", [metadata_service.UpdateArtifactRequest, dict,] +) +def test_update_artifact(request_type, transport: str = "grpc"): client = MetadataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -2503,10 +2605,6 @@ def test_update_artifact( assert response.description == "description_value" -def test_update_artifact_from_dict(): - test_update_artifact(request_type=dict) - - def test_update_artifact_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -2718,9 +2816,10 @@ async def test_update_artifact_flattened_error_async(): ) -def test_delete_artifact( - transport: str = "grpc", request_type=metadata_service.DeleteArtifactRequest -): +@pytest.mark.parametrize( + "request_type", [metadata_service.DeleteArtifactRequest, dict,] +) +def test_delete_artifact(request_type, transport: str = "grpc"): client = MetadataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -2744,10 +2843,6 @@ def test_delete_artifact( assert isinstance(response, future.Future) -def test_delete_artifact_from_dict(): - test_delete_artifact(request_type=dict) - - def test_delete_artifact_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -2922,9 +3017,10 @@ async def test_delete_artifact_flattened_error_async(): ) -def test_purge_artifacts( - transport: str = "grpc", request_type=metadata_service.PurgeArtifactsRequest -): +@pytest.mark.parametrize( + "request_type", [metadata_service.PurgeArtifactsRequest, dict,] +) +def test_purge_artifacts(request_type, transport: str = "grpc"): client = MetadataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -2948,10 +3044,6 @@ def test_purge_artifacts( assert isinstance(response, future.Future) -def test_purge_artifacts_from_dict(): - test_purge_artifacts(request_type=dict) - - def test_purge_artifacts_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -3126,9 +3218,8 @@ async def test_purge_artifacts_flattened_error_async(): ) -def test_create_context( - transport: str = "grpc", request_type=metadata_service.CreateContextRequest -): +@pytest.mark.parametrize("request_type", [metadata_service.CreateContextRequest, dict,]) +def test_create_context(request_type, transport: str = "grpc"): client = MetadataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -3167,10 +3258,6 @@ def test_create_context( assert response.description == "description_value" -def test_create_context_from_dict(): - test_create_context(request_type=dict) - - def test_create_context_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -3382,9 +3469,8 @@ async def test_create_context_flattened_error_async(): ) -def test_get_context( - transport: str = "grpc", request_type=metadata_service.GetContextRequest -): +@pytest.mark.parametrize("request_type", [metadata_service.GetContextRequest, dict,]) +def test_get_context(request_type, transport: str = "grpc"): client = MetadataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -3423,10 +3509,6 @@ def test_get_context( assert response.description == "description_value" -def test_get_context_from_dict(): - test_get_context(request_type=dict) - - def test_get_context_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -3612,9 +3694,8 @@ async def test_get_context_flattened_error_async(): ) -def test_list_contexts( - transport: str = "grpc", request_type=metadata_service.ListContextsRequest -): +@pytest.mark.parametrize("request_type", [metadata_service.ListContextsRequest, dict,]) +def test_list_contexts(request_type, transport: str = "grpc"): client = MetadataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -3641,10 +3722,6 @@ def test_list_contexts( assert response.next_page_token == "next_page_token_value" -def test_list_contexts_from_dict(): - test_list_contexts(request_type=dict) - - def test_list_contexts_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -3822,8 +3899,10 @@ async def test_list_contexts_flattened_error_async(): ) -def test_list_contexts_pager(): - client = MetadataServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_contexts_pager(transport_name: str = "grpc"): + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_contexts), "__call__") as call: @@ -3856,8 +3935,10 @@ def test_list_contexts_pager(): assert all(isinstance(i, context.Context) for i in results) -def test_list_contexts_pages(): - client = MetadataServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_contexts_pages(transport_name: str = "grpc"): + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. 
with mock.patch.object(type(client.transport.list_contexts), "__call__") as call: @@ -3948,9 +4029,8 @@ async def test_list_contexts_async_pages(): assert page_.raw_page.next_page_token == token -def test_update_context( - transport: str = "grpc", request_type=metadata_service.UpdateContextRequest -): +@pytest.mark.parametrize("request_type", [metadata_service.UpdateContextRequest, dict,]) +def test_update_context(request_type, transport: str = "grpc"): client = MetadataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -3989,10 +4069,6 @@ def test_update_context( assert response.description == "description_value" -def test_update_context_from_dict(): - test_update_context(request_type=dict) - - def test_update_context_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -4198,9 +4274,8 @@ async def test_update_context_flattened_error_async(): ) -def test_delete_context( - transport: str = "grpc", request_type=metadata_service.DeleteContextRequest -): +@pytest.mark.parametrize("request_type", [metadata_service.DeleteContextRequest, dict,]) +def test_delete_context(request_type, transport: str = "grpc"): client = MetadataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -4224,10 +4299,6 @@ def test_delete_context( assert isinstance(response, future.Future) -def test_delete_context_from_dict(): - test_delete_context(request_type=dict) - - def test_delete_context_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -4402,9 +4473,8 @@ async def test_delete_context_flattened_error_async(): ) -def test_purge_contexts( - transport: str = "grpc", request_type=metadata_service.PurgeContextsRequest -): +@pytest.mark.parametrize("request_type", [metadata_service.PurgeContextsRequest, dict,]) +def test_purge_contexts(request_type, transport: str = "grpc"): client = MetadataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -4428,10 +4498,6 @@ def test_purge_contexts( assert isinstance(response, future.Future) -def test_purge_contexts_from_dict(): - test_purge_contexts(request_type=dict) - - def test_purge_contexts_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -4606,10 +4672,10 @@ async def test_purge_contexts_flattened_error_async(): ) -def test_add_context_artifacts_and_executions( - transport: str = "grpc", - request_type=metadata_service.AddContextArtifactsAndExecutionsRequest, -): +@pytest.mark.parametrize( + "request_type", [metadata_service.AddContextArtifactsAndExecutionsRequest, dict,] +) +def test_add_context_artifacts_and_executions(request_type, transport: str = "grpc"): client = MetadataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -4637,10 +4703,6 @@ def test_add_context_artifacts_and_executions( ) -def test_add_context_artifacts_and_executions_from_dict(): - test_add_context_artifacts_and_executions(request_type=dict) - - def test_add_context_artifacts_and_executions_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -4856,9 +4918,10 @@ async def test_add_context_artifacts_and_executions_flattened_error_async(): ) -def test_add_context_children( - transport: str = "grpc", request_type=metadata_service.AddContextChildrenRequest -): +@pytest.mark.parametrize( + "request_type", [metadata_service.AddContextChildrenRequest, dict,] +) +def test_add_context_children(request_type, transport: str = "grpc"): client = MetadataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -4884,10 +4947,6 @@ def test_add_context_children( assert isinstance(response, metadata_service.AddContextChildrenResponse) -def test_add_context_children_from_dict(): - test_add_context_children(request_type=dict) - - def test_add_context_children_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -5089,10 +5148,10 @@ async def test_add_context_children_flattened_error_async(): ) -def test_query_context_lineage_subgraph( - transport: str = "grpc", - request_type=metadata_service.QueryContextLineageSubgraphRequest, -): +@pytest.mark.parametrize( + "request_type", [metadata_service.QueryContextLineageSubgraphRequest, dict,] +) +def test_query_context_lineage_subgraph(request_type, transport: str = "grpc"): client = MetadataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -5118,10 +5177,6 @@ def test_query_context_lineage_subgraph( assert isinstance(response, lineage_subgraph.LineageSubgraph) -def test_query_context_lineage_subgraph_from_dict(): - test_query_context_lineage_subgraph(request_type=dict) - - def test_query_context_lineage_subgraph_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -5311,9 +5366,10 @@ async def test_query_context_lineage_subgraph_flattened_error_async(): ) -def test_create_execution( - transport: str = "grpc", request_type=metadata_service.CreateExecutionRequest -): +@pytest.mark.parametrize( + "request_type", [metadata_service.CreateExecutionRequest, dict,] +) +def test_create_execution(request_type, transport: str = "grpc"): client = MetadataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -5352,10 +5408,6 @@ def test_create_execution( assert response.description == "description_value" -def test_create_execution_from_dict(): - test_create_execution(request_type=dict) - - def test_create_execution_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -5572,9 +5624,8 @@ async def test_create_execution_flattened_error_async(): ) -def test_get_execution( - transport: str = "grpc", request_type=metadata_service.GetExecutionRequest -): +@pytest.mark.parametrize("request_type", [metadata_service.GetExecutionRequest, dict,]) +def test_get_execution(request_type, transport: str = "grpc"): client = MetadataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -5613,10 +5664,6 @@ def test_get_execution( assert response.description == "description_value" -def test_get_execution_from_dict(): - test_get_execution(request_type=dict) - - def test_get_execution_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -5802,9 +5849,10 @@ async def test_get_execution_flattened_error_async(): ) -def test_list_executions( - transport: str = "grpc", request_type=metadata_service.ListExecutionsRequest -): +@pytest.mark.parametrize( + "request_type", [metadata_service.ListExecutionsRequest, dict,] +) +def test_list_executions(request_type, transport: str = "grpc"): client = MetadataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -5831,10 +5879,6 @@ def test_list_executions( assert response.next_page_token == "next_page_token_value" -def test_list_executions_from_dict(): - test_list_executions(request_type=dict) - - def test_list_executions_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -6012,8 +6056,10 @@ async def test_list_executions_flattened_error_async(): ) -def test_list_executions_pager(): - client = MetadataServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_executions_pager(transport_name: str = "grpc"): + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_executions), "__call__") as call: @@ -6052,8 +6098,10 @@ def test_list_executions_pager(): assert all(isinstance(i, execution.Execution) for i in results) -def test_list_executions_pages(): - client = MetadataServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_executions_pages(transport_name: str = "grpc"): + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. 
with mock.patch.object(type(client.transport.list_executions), "__call__") as call: @@ -6162,9 +6210,10 @@ async def test_list_executions_async_pages(): assert page_.raw_page.next_page_token == token -def test_update_execution( - transport: str = "grpc", request_type=metadata_service.UpdateExecutionRequest -): +@pytest.mark.parametrize( + "request_type", [metadata_service.UpdateExecutionRequest, dict,] +) +def test_update_execution(request_type, transport: str = "grpc"): client = MetadataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -6203,10 +6252,6 @@ def test_update_execution( assert response.description == "description_value" -def test_update_execution_from_dict(): - test_update_execution(request_type=dict) - - def test_update_execution_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -6417,9 +6462,10 @@ async def test_update_execution_flattened_error_async(): ) -def test_delete_execution( - transport: str = "grpc", request_type=metadata_service.DeleteExecutionRequest -): +@pytest.mark.parametrize( + "request_type", [metadata_service.DeleteExecutionRequest, dict,] +) +def test_delete_execution(request_type, transport: str = "grpc"): client = MetadataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -6443,10 +6489,6 @@ def test_delete_execution( assert isinstance(response, future.Future) -def test_delete_execution_from_dict(): - test_delete_execution(request_type=dict) - - def test_delete_execution_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -6622,9 +6664,10 @@ async def test_delete_execution_flattened_error_async(): ) -def test_purge_executions( - transport: str = "grpc", request_type=metadata_service.PurgeExecutionsRequest -): +@pytest.mark.parametrize( + "request_type", [metadata_service.PurgeExecutionsRequest, dict,] +) +def test_purge_executions(request_type, transport: str = "grpc"): client = MetadataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -6648,10 +6691,6 @@ def test_purge_executions( assert isinstance(response, future.Future) -def test_purge_executions_from_dict(): - test_purge_executions(request_type=dict) - - def test_purge_executions_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -6827,9 +6866,10 @@ async def test_purge_executions_flattened_error_async(): ) -def test_add_execution_events( - transport: str = "grpc", request_type=metadata_service.AddExecutionEventsRequest -): +@pytest.mark.parametrize( + "request_type", [metadata_service.AddExecutionEventsRequest, dict,] +) +def test_add_execution_events(request_type, transport: str = "grpc"): client = MetadataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -6855,10 +6895,6 @@ def test_add_execution_events( assert isinstance(response, metadata_service.AddExecutionEventsResponse) -def test_add_execution_events_from_dict(): - test_add_execution_events(request_type=dict) - - def test_add_execution_events_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -7062,10 +7098,10 @@ async def test_add_execution_events_flattened_error_async(): ) -def test_query_execution_inputs_and_outputs( - transport: str = "grpc", - request_type=metadata_service.QueryExecutionInputsAndOutputsRequest, -): +@pytest.mark.parametrize( + "request_type", [metadata_service.QueryExecutionInputsAndOutputsRequest, dict,] +) +def test_query_execution_inputs_and_outputs(request_type, transport: str = "grpc"): client = MetadataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -7091,10 +7127,6 @@ def test_query_execution_inputs_and_outputs( assert isinstance(response, lineage_subgraph.LineageSubgraph) -def test_query_execution_inputs_and_outputs_from_dict(): - test_query_execution_inputs_and_outputs(request_type=dict) - - def test_query_execution_inputs_and_outputs_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -7286,9 +7318,10 @@ async def test_query_execution_inputs_and_outputs_flattened_error_async(): ) -def test_create_metadata_schema( - transport: str = "grpc", request_type=metadata_service.CreateMetadataSchemaRequest -): +@pytest.mark.parametrize( + "request_type", [metadata_service.CreateMetadataSchemaRequest, dict,] +) +def test_create_metadata_schema(request_type, transport: str = "grpc"): client = MetadataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -7328,10 +7361,6 @@ def test_create_metadata_schema( assert response.description == "description_value" -def test_create_metadata_schema_from_dict(): - test_create_metadata_schema(request_type=dict) - - def test_create_metadata_schema_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -7559,9 +7588,10 @@ async def test_create_metadata_schema_flattened_error_async(): ) -def test_get_metadata_schema( - transport: str = "grpc", request_type=metadata_service.GetMetadataSchemaRequest -): +@pytest.mark.parametrize( + "request_type", [metadata_service.GetMetadataSchemaRequest, dict,] +) +def test_get_metadata_schema(request_type, transport: str = "grpc"): client = MetadataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -7601,10 +7631,6 @@ def test_get_metadata_schema( assert response.description == "description_value" -def test_get_metadata_schema_from_dict(): - test_get_metadata_schema(request_type=dict) - - def test_get_metadata_schema_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -7806,9 +7832,10 @@ async def test_get_metadata_schema_flattened_error_async(): ) -def test_list_metadata_schemas( - transport: str = "grpc", request_type=metadata_service.ListMetadataSchemasRequest -): +@pytest.mark.parametrize( + "request_type", [metadata_service.ListMetadataSchemasRequest, dict,] +) +def test_list_metadata_schemas(request_type, transport: str = "grpc"): client = MetadataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -7837,10 +7864,6 @@ def test_list_metadata_schemas( assert response.next_page_token == "next_page_token_value" -def test_list_metadata_schemas_from_dict(): - test_list_metadata_schemas(request_type=dict) - - def test_list_metadata_schemas_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -8031,8 +8054,10 @@ async def test_list_metadata_schemas_flattened_error_async(): ) -def test_list_metadata_schemas_pager(): - client = MetadataServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_metadata_schemas_pager(transport_name: str = "grpc"): + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( @@ -8077,8 +8102,10 @@ def test_list_metadata_schemas_pager(): assert all(isinstance(i, metadata_schema.MetadataSchema) for i in results) -def test_list_metadata_schemas_pages(): - client = MetadataServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_metadata_schemas_pages(transport_name: str = "grpc"): + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( @@ -8205,10 +8232,10 @@ async def test_list_metadata_schemas_async_pages(): assert page_.raw_page.next_page_token == token -def test_query_artifact_lineage_subgraph( - transport: str = "grpc", - request_type=metadata_service.QueryArtifactLineageSubgraphRequest, -): +@pytest.mark.parametrize( + "request_type", [metadata_service.QueryArtifactLineageSubgraphRequest, dict,] +) +def test_query_artifact_lineage_subgraph(request_type, transport: str = "grpc"): client = MetadataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -8234,10 +8261,6 @@ def test_query_artifact_lineage_subgraph( assert isinstance(response, lineage_subgraph.LineageSubgraph) -def test_query_artifact_lineage_subgraph_from_dict(): - test_query_artifact_lineage_subgraph(request_type=dict) - - def test_query_artifact_lineage_subgraph_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. 
request == None and no flattened fields passed, work. @@ -8449,6 +8472,23 @@ def test_credentials_transport_error(): transport=transport, ) + # It is an error to provide an api_key and a transport instance. + transport = transports.MetadataServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + options = client_options.ClientOptions() + options.api_key = "api_key" + with pytest.raises(ValueError): + client = MetadataServiceClient(client_options=options, transport=transport,) + + # It is an error to provide an api_key and a credential. + options = mock.Mock() + options.api_key = "api_key" + with pytest.raises(ValueError): + client = MetadataServiceClient( + client_options=options, credentials=ga_credentials.AnonymousCredentials() + ) + # It is an error to provide scopes and a transport instance. transport = transports.MetadataServiceGrpcTransport( credentials=ga_credentials.AnonymousCredentials(), @@ -9130,7 +9170,7 @@ def test_parse_common_location_path(): assert expected == actual -def test_client_withDEFAULT_CLIENT_INFO(): +def test_client_with_default_client_info(): client_info = gapic_v1.client_info.ClientInfo() with mock.patch.object( @@ -9195,3 +9235,33 @@ def test_client_ctx(): with client: pass close.assert_called() + + +@pytest.mark.parametrize( + "client_class,transport_class", + [ + (MetadataServiceClient, transports.MetadataServiceGrpcTransport), + (MetadataServiceAsyncClient, transports.MetadataServiceGrpcAsyncIOTransport), + ], +) +def test_api_key_credentials(client_class, transport_class): + with mock.patch.object( + google.auth._default, "get_api_key_credentials", create=True + ) as get_api_key_credentials: + mock_cred = mock.Mock() + get_api_key_credentials.return_value = mock_cred + options = client_options.ClientOptions() + options.api_key = "api_key" + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + 
patched.assert_called_once_with( + credentials=mock_cred, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + ) diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_migration_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_migration_service.py index 176dc81e619..7f4b3a2060a 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_migration_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_migration_service.py @@ -29,6 +29,7 @@ from google.api_core import gapic_v1 from google.api_core import grpc_helpers from google.api_core import grpc_helpers_async +from google.api_core import operation from google.api_core import operation_async # type: ignore from google.api_core import operations_v1 from google.api_core import path_template @@ -256,20 +257,20 @@ def test_migration_service_client_client_options( # unsupported value. with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): with pytest.raises(MutualTLSChannelError): - client = client_class() + client = client_class(transport=transport_name) # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. 
with mock.patch.dict( os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} ): with pytest.raises(ValueError): - client = client_class() + client = client_class(transport=transport_name) # Check the case quota_project_id is provided options = client_options.ClientOptions(quota_project_id="octopus") with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( credentials=None, credentials_file=None, @@ -338,7 +339,7 @@ def test_migration_service_client_mtls_env_auto( ) with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) if use_client_cert_env == "false": expected_client_cert_source = None @@ -415,6 +416,87 @@ def test_migration_service_client_mtls_env_auto( ) +@pytest.mark.parametrize( + "client_class", [MigrationServiceClient, MigrationServiceAsyncClient] +) +@mock.patch.object( + MigrationServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(MigrationServiceClient), +) +@mock.patch.object( + MigrationServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(MigrationServiceAsyncClient), +) +def test_migration_service_client_get_mtls_endpoint_and_cert_source(client_class): + mock_client_cert_source = mock.Mock() + + # Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is "true". 
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + mock_api_endpoint = "foo" + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint + ) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source( + options + ) + assert api_endpoint == mock_api_endpoint + assert cert_source == mock_client_cert_source + + # Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is "false". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "false"}): + mock_client_cert_source = mock.Mock() + mock_api_endpoint = "foo" + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint + ) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source( + options + ) + assert api_endpoint == mock_api_endpoint + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "never". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "always". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert doesn't exist. 
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert exists. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=mock_client_cert_source, + ): + ( + api_endpoint, + cert_source, + ) = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + assert cert_source == mock_client_cert_source + + @pytest.mark.parametrize( "client_class,transport_class,transport_name", [ @@ -433,7 +515,7 @@ def test_migration_service_client_client_options_scopes( options = client_options.ClientOptions(scopes=["1", "2"],) with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( credentials=None, credentials_file=None, @@ -447,24 +529,31 @@ def test_migration_service_client_client_options_scopes( @pytest.mark.parametrize( - "client_class,transport_class,transport_name", + "client_class,transport_class,transport_name,grpc_helpers", [ - (MigrationServiceClient, transports.MigrationServiceGrpcTransport, "grpc"), + ( + MigrationServiceClient, + transports.MigrationServiceGrpcTransport, + "grpc", + grpc_helpers, + ), ( MigrationServiceAsyncClient, transports.MigrationServiceGrpcAsyncIOTransport, "grpc_asyncio", + grpc_helpers_async, 
), ], ) def test_migration_service_client_client_options_credentials_file( - client_class, transport_class, transport_name + client_class, transport_class, transport_name, grpc_helpers ): # Check the case credentials file is provided. options = client_options.ClientOptions(credentials_file="credentials.json") + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( credentials=None, credentials_file="credentials.json", @@ -476,6 +565,35 @@ def test_migration_service_client_client_options_credentials_file( always_use_jwt_access=True, ) + if "grpc" in transport_name: + # test that the credentials from file are saved and used as the credentials. + with mock.patch.object( + google.auth, "load_credentials_from_file", autospec=True + ) as load_creds, mock.patch.object( + google.auth, "default", autospec=True + ) as adc, mock.patch.object( + grpc_helpers, "create_channel" + ) as create_channel: + creds = ga_credentials.AnonymousCredentials() + file_creds = ga_credentials.AnonymousCredentials() + load_creds.return_value = (file_creds, None) + adc.return_value = (creds, None) + client = client_class(client_options=options, transport=transport_name) + create_channel.assert_called_with( + "aiplatform.googleapis.com:443", + credentials=file_creds, + credentials_file=None, + quota_project_id=None, + default_scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=None, + default_host="aiplatform.googleapis.com", + ssl_credentials=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + def test_migration_service_client_client_options_from_dict(): with mock.patch( @@ -497,10 +615,10 @@ def test_migration_service_client_client_options_from_dict(): ) -def test_search_migratable_resources( - transport: 
str = "grpc", - request_type=migration_service.SearchMigratableResourcesRequest, -): +@pytest.mark.parametrize( + "request_type", [migration_service.SearchMigratableResourcesRequest, dict,] +) +def test_search_migratable_resources(request_type, transport: str = "grpc"): client = MigrationServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -529,10 +647,6 @@ def test_search_migratable_resources( assert response.next_page_token == "next_page_token_value" -def test_search_migratable_resources_from_dict(): - test_search_migratable_resources(request_type=dict) - - def test_search_migratable_resources_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -723,8 +837,10 @@ async def test_search_migratable_resources_flattened_error_async(): ) -def test_search_migratable_resources_pager(): - client = MigrationServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_search_migratable_resources_pager(transport_name: str = "grpc"): + client = MigrationServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( @@ -771,8 +887,10 @@ def test_search_migratable_resources_pager(): ) -def test_search_migratable_resources_pages(): - client = MigrationServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_search_migratable_resources_pages(transport_name: str = "grpc"): + client = MigrationServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. 
with mock.patch.object( @@ -901,9 +1019,10 @@ async def test_search_migratable_resources_async_pages(): assert page_.raw_page.next_page_token == token -def test_batch_migrate_resources( - transport: str = "grpc", request_type=migration_service.BatchMigrateResourcesRequest -): +@pytest.mark.parametrize( + "request_type", [migration_service.BatchMigrateResourcesRequest, dict,] +) +def test_batch_migrate_resources(request_type, transport: str = "grpc"): client = MigrationServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -929,10 +1048,6 @@ def test_batch_migrate_resources( assert isinstance(response, future.Future) -def test_batch_migrate_resources_from_dict(): - test_batch_migrate_resources(request_type=dict) - - def test_batch_migrate_resources_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -1192,6 +1307,23 @@ def test_credentials_transport_error(): transport=transport, ) + # It is an error to provide an api_key and a transport instance. + transport = transports.MigrationServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + options = client_options.ClientOptions() + options.api_key = "api_key" + with pytest.raises(ValueError): + client = MigrationServiceClient(client_options=options, transport=transport,) + + # It is an error to provide an api_key and a credential. + options = mock.Mock() + options.api_key = "api_key" + with pytest.raises(ValueError): + client = MigrationServiceClient( + client_options=options, credentials=ga_credentials.AnonymousCredentials() + ) + # It is an error to provide scopes and a transport instance. 
transport = transports.MigrationServiceGrpcTransport( credentials=ga_credentials.AnonymousCredentials(), @@ -1626,20 +1758,18 @@ def test_parse_annotated_dataset_path(): def test_dataset_path(): project = "cuttlefish" - location = "mussel" - dataset = "winkle" - expected = "projects/{project}/locations/{location}/datasets/{dataset}".format( - project=project, location=location, dataset=dataset, + dataset = "mussel" + expected = "projects/{project}/datasets/{dataset}".format( + project=project, dataset=dataset, ) - actual = MigrationServiceClient.dataset_path(project, location, dataset) + actual = MigrationServiceClient.dataset_path(project, dataset) assert expected == actual def test_parse_dataset_path(): expected = { - "project": "nautilus", - "location": "scallop", - "dataset": "abalone", + "project": "winkle", + "dataset": "nautilus", } path = MigrationServiceClient.dataset_path(**expected) @@ -1649,18 +1779,20 @@ def test_parse_dataset_path(): def test_dataset_path(): - project = "squid" - dataset = "clam" - expected = "projects/{project}/datasets/{dataset}".format( - project=project, dataset=dataset, + project = "scallop" + location = "abalone" + dataset = "squid" + expected = "projects/{project}/locations/{location}/datasets/{dataset}".format( + project=project, location=location, dataset=dataset, ) - actual = MigrationServiceClient.dataset_path(project, dataset) + actual = MigrationServiceClient.dataset_path(project, location, dataset) assert expected == actual def test_parse_dataset_path(): expected = { - "project": "whelk", + "project": "clam", + "location": "whelk", "dataset": "octopus", } path = MigrationServiceClient.dataset_path(**expected) @@ -1862,7 +1994,7 @@ def test_parse_common_location_path(): assert expected == actual -def test_client_withDEFAULT_CLIENT_INFO(): +def test_client_with_default_client_info(): client_info = gapic_v1.client_info.ClientInfo() with mock.patch.object( @@ -1927,3 +2059,33 @@ def test_client_ctx(): with client: pass 
close.assert_called() + + +@pytest.mark.parametrize( + "client_class,transport_class", + [ + (MigrationServiceClient, transports.MigrationServiceGrpcTransport), + (MigrationServiceAsyncClient, transports.MigrationServiceGrpcAsyncIOTransport), + ], +) +def test_api_key_credentials(client_class, transport_class): + with mock.patch.object( + google.auth._default, "get_api_key_credentials", create=True + ) as get_api_key_credentials: + mock_cred = mock.Mock() + get_api_key_credentials.return_value = mock_cred + options = client_options.ClientOptions() + options.api_key = "api_key" + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=mock_cred, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + ) diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_model_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_model_service.py index 31bb04d2253..0e20af1d5fc 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_model_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_model_service.py @@ -29,6 +29,7 @@ from google.api_core import gapic_v1 from google.api_core import grpc_helpers from google.api_core import grpc_helpers_async +from google.api_core import operation from google.api_core import operation_async # type: ignore from google.api_core import operations_v1 from google.api_core import path_template @@ -257,20 +258,20 @@ def test_model_service_client_client_options( # unsupported value. with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): with pytest.raises(MutualTLSChannelError): - client = client_class() + client = client_class(transport=transport_name) # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. 
with mock.patch.dict( os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} ): with pytest.raises(ValueError): - client = client_class() + client = client_class(transport=transport_name) # Check the case quota_project_id is provided options = client_options.ClientOptions(quota_project_id="octopus") with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( credentials=None, credentials_file=None, @@ -327,7 +328,7 @@ def test_model_service_client_mtls_env_auto( ) with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) if use_client_cert_env == "false": expected_client_cert_source = None @@ -404,6 +405,83 @@ def test_model_service_client_mtls_env_auto( ) +@pytest.mark.parametrize("client_class", [ModelServiceClient, ModelServiceAsyncClient]) +@mock.patch.object( + ModelServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(ModelServiceClient) +) +@mock.patch.object( + ModelServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(ModelServiceAsyncClient), +) +def test_model_service_client_get_mtls_endpoint_and_cert_source(client_class): + mock_client_cert_source = mock.Mock() + + # Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is "true". 
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + mock_api_endpoint = "foo" + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint + ) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source( + options + ) + assert api_endpoint == mock_api_endpoint + assert cert_source == mock_client_cert_source + + # Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is "false". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "false"}): + mock_client_cert_source = mock.Mock() + mock_api_endpoint = "foo" + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint + ) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source( + options + ) + assert api_endpoint == mock_api_endpoint + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "never". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "always". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert doesn't exist. 
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert exists. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=mock_client_cert_source, + ): + ( + api_endpoint, + cert_source, + ) = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + assert cert_source == mock_client_cert_source + + @pytest.mark.parametrize( "client_class,transport_class,transport_name", [ @@ -422,7 +500,7 @@ def test_model_service_client_client_options_scopes( options = client_options.ClientOptions(scopes=["1", "2"],) with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( credentials=None, credentials_file=None, @@ -436,24 +514,31 @@ def test_model_service_client_client_options_scopes( @pytest.mark.parametrize( - "client_class,transport_class,transport_name", + "client_class,transport_class,transport_name,grpc_helpers", [ - (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc"), + ( + ModelServiceClient, + transports.ModelServiceGrpcTransport, + "grpc", + grpc_helpers, + ), ( ModelServiceAsyncClient, transports.ModelServiceGrpcAsyncIOTransport, "grpc_asyncio", + grpc_helpers_async, ), ], ) def 
test_model_service_client_client_options_credentials_file( - client_class, transport_class, transport_name + client_class, transport_class, transport_name, grpc_helpers ): # Check the case credentials file is provided. options = client_options.ClientOptions(credentials_file="credentials.json") + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( credentials=None, credentials_file="credentials.json", @@ -465,6 +550,35 @@ def test_model_service_client_client_options_credentials_file( always_use_jwt_access=True, ) + if "grpc" in transport_name: + # test that the credentials from file are saved and used as the credentials. + with mock.patch.object( + google.auth, "load_credentials_from_file", autospec=True + ) as load_creds, mock.patch.object( + google.auth, "default", autospec=True + ) as adc, mock.patch.object( + grpc_helpers, "create_channel" + ) as create_channel: + creds = ga_credentials.AnonymousCredentials() + file_creds = ga_credentials.AnonymousCredentials() + load_creds.return_value = (file_creds, None) + adc.return_value = (creds, None) + client = client_class(client_options=options, transport=transport_name) + create_channel.assert_called_with( + "aiplatform.googleapis.com:443", + credentials=file_creds, + credentials_file=None, + quota_project_id=None, + default_scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=None, + default_host="aiplatform.googleapis.com", + ssl_credentials=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + def test_model_service_client_client_options_from_dict(): with mock.patch( @@ -484,9 +598,8 @@ def test_model_service_client_client_options_from_dict(): ) -def test_upload_model( - transport: str = "grpc", 
request_type=model_service.UploadModelRequest -): +@pytest.mark.parametrize("request_type", [model_service.UploadModelRequest, dict,]) +def test_upload_model(request_type, transport: str = "grpc"): client = ModelServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -510,10 +623,6 @@ def test_upload_model( assert isinstance(response, future.Future) -def test_upload_model_from_dict(): - test_upload_model(request_type=dict) - - def test_upload_model_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -696,7 +805,8 @@ async def test_upload_model_flattened_error_async(): ) -def test_get_model(transport: str = "grpc", request_type=model_service.GetModelRequest): +@pytest.mark.parametrize("request_type", [model_service.GetModelRequest, dict,]) +def test_get_model(request_type, transport: str = "grpc"): client = ModelServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -749,10 +859,6 @@ def test_get_model(transport: str = "grpc", request_type=model_service.GetModelR assert response.etag == "etag_value" -def test_get_model_from_dict(): - test_get_model(request_type=dict) - - def test_get_model_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -950,9 +1056,8 @@ async def test_get_model_flattened_error_async(): ) -def test_list_models( - transport: str = "grpc", request_type=model_service.ListModelsRequest -): +@pytest.mark.parametrize("request_type", [model_service.ListModelsRequest, dict,]) +def test_list_models(request_type, transport: str = "grpc"): client = ModelServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -979,10 +1084,6 @@ def test_list_models( assert response.next_page_token == "next_page_token_value" -def test_list_models_from_dict(): - test_list_models(request_type=dict) - - def test_list_models_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -1152,8 +1253,10 @@ async def test_list_models_flattened_error_async(): ) -def test_list_models_pager(): - client = ModelServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_models_pager(transport_name: str = "grpc"): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_models), "__call__") as call: @@ -1184,8 +1287,10 @@ def test_list_models_pager(): assert all(isinstance(i, model.Model) for i in results) -def test_list_models_pages(): - client = ModelServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_models_pages(transport_name: str = "grpc"): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. 
with mock.patch.object(type(client.transport.list_models), "__call__") as call: @@ -1266,9 +1371,8 @@ async def test_list_models_async_pages(): assert page_.raw_page.next_page_token == token -def test_update_model( - transport: str = "grpc", request_type=model_service.UpdateModelRequest -): +@pytest.mark.parametrize("request_type", [model_service.UpdateModelRequest, dict,]) +def test_update_model(request_type, transport: str = "grpc"): client = ModelServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -1321,10 +1425,6 @@ def test_update_model( assert response.etag == "etag_value" -def test_update_model_from_dict(): - test_update_model(request_type=dict) - - def test_update_model_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -1538,9 +1638,8 @@ async def test_update_model_flattened_error_async(): ) -def test_delete_model( - transport: str = "grpc", request_type=model_service.DeleteModelRequest -): +@pytest.mark.parametrize("request_type", [model_service.DeleteModelRequest, dict,]) +def test_delete_model(request_type, transport: str = "grpc"): client = ModelServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -1564,10 +1663,6 @@ def test_delete_model( assert isinstance(response, future.Future) -def test_delete_model_from_dict(): - test_delete_model(request_type=dict) - - def test_delete_model_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -1736,9 +1831,8 @@ async def test_delete_model_flattened_error_async(): ) -def test_export_model( - transport: str = "grpc", request_type=model_service.ExportModelRequest -): +@pytest.mark.parametrize("request_type", [model_service.ExportModelRequest, dict,]) +def test_export_model(request_type, transport: str = "grpc"): client = ModelServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -1762,10 +1856,6 @@ def test_export_model( assert isinstance(response, future.Future) -def test_export_model_from_dict(): - test_export_model(request_type=dict) - - def test_export_model_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -1962,9 +2052,10 @@ async def test_export_model_flattened_error_async(): ) -def test_get_model_evaluation( - transport: str = "grpc", request_type=model_service.GetModelEvaluationRequest -): +@pytest.mark.parametrize( + "request_type", [model_service.GetModelEvaluationRequest, dict,] +) +def test_get_model_evaluation(request_type, transport: str = "grpc"): client = ModelServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -1997,10 +2088,6 @@ def test_get_model_evaluation( assert response.slice_dimensions == ["slice_dimensions_value"] -def test_get_model_evaluation_from_dict(): - test_get_model_evaluation(request_type=dict) - - def test_get_model_evaluation_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -2189,9 +2276,10 @@ async def test_get_model_evaluation_flattened_error_async(): ) -def test_list_model_evaluations( - transport: str = "grpc", request_type=model_service.ListModelEvaluationsRequest -): +@pytest.mark.parametrize( + "request_type", [model_service.ListModelEvaluationsRequest, dict,] +) +def test_list_model_evaluations(request_type, transport: str = "grpc"): client = ModelServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -2220,10 +2308,6 @@ def test_list_model_evaluations( assert response.next_page_token == "next_page_token_value" -def test_list_model_evaluations_from_dict(): - test_list_model_evaluations(request_type=dict) - - def test_list_model_evaluations_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -2408,8 +2492,10 @@ async def test_list_model_evaluations_flattened_error_async(): ) -def test_list_model_evaluations_pager(): - client = ModelServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_model_evaluations_pager(transport_name: str = "grpc"): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( @@ -2454,8 +2540,10 @@ def test_list_model_evaluations_pager(): assert all(isinstance(i, model_evaluation.ModelEvaluation) for i in results) -def test_list_model_evaluations_pages(): - client = ModelServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_model_evaluations_pages(transport_name: str = "grpc"): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. 
with mock.patch.object( @@ -2578,9 +2666,10 @@ async def test_list_model_evaluations_async_pages(): assert page_.raw_page.next_page_token == token -def test_get_model_evaluation_slice( - transport: str = "grpc", request_type=model_service.GetModelEvaluationSliceRequest -): +@pytest.mark.parametrize( + "request_type", [model_service.GetModelEvaluationSliceRequest, dict,] +) +def test_get_model_evaluation_slice(request_type, transport: str = "grpc"): client = ModelServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -2610,10 +2699,6 @@ def test_get_model_evaluation_slice( assert response.metrics_schema_uri == "metrics_schema_uri_value" -def test_get_model_evaluation_slice_from_dict(): - test_get_model_evaluation_slice(request_type=dict) - - def test_get_model_evaluation_slice_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -2799,9 +2884,10 @@ async def test_get_model_evaluation_slice_flattened_error_async(): ) -def test_list_model_evaluation_slices( - transport: str = "grpc", request_type=model_service.ListModelEvaluationSlicesRequest -): +@pytest.mark.parametrize( + "request_type", [model_service.ListModelEvaluationSlicesRequest, dict,] +) +def test_list_model_evaluation_slices(request_type, transport: str = "grpc"): client = ModelServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -2830,10 +2916,6 @@ def test_list_model_evaluation_slices( assert response.next_page_token == "next_page_token_value" -def test_list_model_evaluation_slices_from_dict(): - test_list_model_evaluation_slices(request_type=dict) - - def test_list_model_evaluation_slices_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -3018,8 +3100,10 @@ async def test_list_model_evaluation_slices_flattened_error_async(): ) -def test_list_model_evaluation_slices_pager(): - client = ModelServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_model_evaluation_slices_pager(transport_name: str = "grpc"): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( @@ -3068,8 +3152,10 @@ def test_list_model_evaluation_slices_pager(): ) -def test_list_model_evaluation_slices_pages(): - client = ModelServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_model_evaluation_slices_pages(transport_name: str = "grpc"): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( @@ -3223,6 +3309,23 @@ def test_credentials_transport_error(): transport=transport, ) + # It is an error to provide an api_key and a transport instance. + transport = transports.ModelServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + options = client_options.ClientOptions() + options.api_key = "api_key" + with pytest.raises(ValueError): + client = ModelServiceClient(client_options=options, transport=transport,) + + # It is an error to provide an api_key and a credential. + options = mock.Mock() + options.api_key = "api_key" + with pytest.raises(ValueError): + client = ModelServiceClient( + client_options=options, credentials=ga_credentials.AnonymousCredentials() + ) + # It is an error to provide scopes and a transport instance. 
transport = transports.ModelServiceGrpcTransport( credentials=ga_credentials.AnonymousCredentials(), @@ -3858,7 +3961,7 @@ def test_parse_common_location_path(): assert expected == actual -def test_client_withDEFAULT_CLIENT_INFO(): +def test_client_with_default_client_info(): client_info = gapic_v1.client_info.ClientInfo() with mock.patch.object( @@ -3923,3 +4026,33 @@ def test_client_ctx(): with client: pass close.assert_called() + + +@pytest.mark.parametrize( + "client_class,transport_class", + [ + (ModelServiceClient, transports.ModelServiceGrpcTransport), + (ModelServiceAsyncClient, transports.ModelServiceGrpcAsyncIOTransport), + ], +) +def test_api_key_credentials(client_class, transport_class): + with mock.patch.object( + google.auth._default, "get_api_key_credentials", create=True + ) as get_api_key_credentials: + mock_cred = mock.Mock() + get_api_key_credentials.return_value = mock_cred + options = client_options.ClientOptions() + options.api_key = "api_key" + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=mock_cred, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + ) diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_pipeline_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_pipeline_service.py index 2f587a9d04b..5d42510a2ac 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_pipeline_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_pipeline_service.py @@ -29,6 +29,7 @@ from google.api_core import gapic_v1 from google.api_core import grpc_helpers from google.api_core import grpc_helpers_async +from google.api_core import operation from google.api_core import operation_async # type: ignore from google.api_core import operations_v1 from 
google.api_core import path_template @@ -278,20 +279,20 @@ def test_pipeline_service_client_client_options( # unsupported value. with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): with pytest.raises(MutualTLSChannelError): - client = client_class() + client = client_class(transport=transport_name) # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. with mock.patch.dict( os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} ): with pytest.raises(ValueError): - client = client_class() + client = client_class(transport=transport_name) # Check the case quota_project_id is provided options = client_options.ClientOptions(quota_project_id="octopus") with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( credentials=None, credentials_file=None, @@ -360,7 +361,7 @@ def test_pipeline_service_client_mtls_env_auto( ) with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) if use_client_cert_env == "false": expected_client_cert_source = None @@ -437,6 +438,87 @@ def test_pipeline_service_client_mtls_env_auto( ) +@pytest.mark.parametrize( + "client_class", [PipelineServiceClient, PipelineServiceAsyncClient] +) +@mock.patch.object( + PipelineServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(PipelineServiceClient), +) +@mock.patch.object( + PipelineServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(PipelineServiceAsyncClient), +) +def test_pipeline_service_client_get_mtls_endpoint_and_cert_source(client_class): + mock_client_cert_source = mock.Mock() + + # Test the case 
GOOGLE_API_USE_CLIENT_CERTIFICATE is "true". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + mock_api_endpoint = "foo" + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint + ) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source( + options + ) + assert api_endpoint == mock_api_endpoint + assert cert_source == mock_client_cert_source + + # Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is "false". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "false"}): + mock_client_cert_source = mock.Mock() + mock_api_endpoint = "foo" + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint + ) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source( + options + ) + assert api_endpoint == mock_api_endpoint + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "never". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "always". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert doesn't exist. 
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert exists. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=mock_client_cert_source, + ): + ( + api_endpoint, + cert_source, + ) = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + assert cert_source == mock_client_cert_source + + @pytest.mark.parametrize( "client_class,transport_class,transport_name", [ @@ -455,7 +537,7 @@ def test_pipeline_service_client_client_options_scopes( options = client_options.ClientOptions(scopes=["1", "2"],) with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( credentials=None, credentials_file=None, @@ -469,24 +551,31 @@ def test_pipeline_service_client_client_options_scopes( @pytest.mark.parametrize( - "client_class,transport_class,transport_name", + "client_class,transport_class,transport_name,grpc_helpers", [ - (PipelineServiceClient, transports.PipelineServiceGrpcTransport, "grpc"), + ( + PipelineServiceClient, + transports.PipelineServiceGrpcTransport, + "grpc", + grpc_helpers, + ), ( PipelineServiceAsyncClient, transports.PipelineServiceGrpcAsyncIOTransport, "grpc_asyncio", + grpc_helpers_async, ), ], ) 
def test_pipeline_service_client_client_options_credentials_file( - client_class, transport_class, transport_name + client_class, transport_class, transport_name, grpc_helpers ): # Check the case credentials file is provided. options = client_options.ClientOptions(credentials_file="credentials.json") + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( credentials=None, credentials_file="credentials.json", @@ -498,6 +587,35 @@ def test_pipeline_service_client_client_options_credentials_file( always_use_jwt_access=True, ) + if "grpc" in transport_name: + # test that the credentials from file are saved and used as the credentials. + with mock.patch.object( + google.auth, "load_credentials_from_file", autospec=True + ) as load_creds, mock.patch.object( + google.auth, "default", autospec=True + ) as adc, mock.patch.object( + grpc_helpers, "create_channel" + ) as create_channel: + creds = ga_credentials.AnonymousCredentials() + file_creds = ga_credentials.AnonymousCredentials() + load_creds.return_value = (file_creds, None) + adc.return_value = (creds, None) + client = client_class(client_options=options, transport=transport_name) + create_channel.assert_called_with( + "aiplatform.googleapis.com:443", + credentials=file_creds, + credentials_file=None, + quota_project_id=None, + default_scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=None, + default_host="aiplatform.googleapis.com", + ssl_credentials=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + def test_pipeline_service_client_client_options_from_dict(): with mock.patch( @@ -519,9 +637,10 @@ def test_pipeline_service_client_client_options_from_dict(): ) -def test_create_training_pipeline( - transport: str = "grpc", 
request_type=pipeline_service.CreateTrainingPipelineRequest -): +@pytest.mark.parametrize( + "request_type", [pipeline_service.CreateTrainingPipelineRequest, dict,] +) +def test_create_training_pipeline(request_type, transport: str = "grpc"): client = PipelineServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -556,10 +675,6 @@ def test_create_training_pipeline( assert response.state == pipeline_state.PipelineState.PIPELINE_STATE_QUEUED -def test_create_training_pipeline_from_dict(): - test_create_training_pipeline(request_type=dict) - - def test_create_training_pipeline_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -772,9 +887,10 @@ async def test_create_training_pipeline_flattened_error_async(): ) -def test_get_training_pipeline( - transport: str = "grpc", request_type=pipeline_service.GetTrainingPipelineRequest -): +@pytest.mark.parametrize( + "request_type", [pipeline_service.GetTrainingPipelineRequest, dict,] +) +def test_get_training_pipeline(request_type, transport: str = "grpc"): client = PipelineServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -809,10 +925,6 @@ def test_get_training_pipeline( assert response.state == pipeline_state.PipelineState.PIPELINE_STATE_QUEUED -def test_get_training_pipeline_from_dict(): - test_get_training_pipeline(request_type=dict) - - def test_get_training_pipeline_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -1009,9 +1121,10 @@ async def test_get_training_pipeline_flattened_error_async(): ) -def test_list_training_pipelines( - transport: str = "grpc", request_type=pipeline_service.ListTrainingPipelinesRequest -): +@pytest.mark.parametrize( + "request_type", [pipeline_service.ListTrainingPipelinesRequest, dict,] +) +def test_list_training_pipelines(request_type, transport: str = "grpc"): client = PipelineServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -1040,10 +1153,6 @@ def test_list_training_pipelines( assert response.next_page_token == "next_page_token_value" -def test_list_training_pipelines_from_dict(): - test_list_training_pipelines(request_type=dict) - - def test_list_training_pipelines_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -1234,8 +1343,10 @@ async def test_list_training_pipelines_flattened_error_async(): ) -def test_list_training_pipelines_pager(): - client = PipelineServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_training_pipelines_pager(transport_name: str = "grpc"): + client = PipelineServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( @@ -1280,8 +1391,10 @@ def test_list_training_pipelines_pager(): assert all(isinstance(i, training_pipeline.TrainingPipeline) for i in results) -def test_list_training_pipelines_pages(): - client = PipelineServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_training_pipelines_pages(transport_name: str = "grpc"): + client = PipelineServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. 
with mock.patch.object( @@ -1408,9 +1521,10 @@ async def test_list_training_pipelines_async_pages(): assert page_.raw_page.next_page_token == token -def test_delete_training_pipeline( - transport: str = "grpc", request_type=pipeline_service.DeleteTrainingPipelineRequest -): +@pytest.mark.parametrize( + "request_type", [pipeline_service.DeleteTrainingPipelineRequest, dict,] +) +def test_delete_training_pipeline(request_type, transport: str = "grpc"): client = PipelineServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -1436,10 +1550,6 @@ def test_delete_training_pipeline( assert isinstance(response, future.Future) -def test_delete_training_pipeline_from_dict(): - test_delete_training_pipeline(request_type=dict) - - def test_delete_training_pipeline_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -1627,9 +1737,10 @@ async def test_delete_training_pipeline_flattened_error_async(): ) -def test_cancel_training_pipeline( - transport: str = "grpc", request_type=pipeline_service.CancelTrainingPipelineRequest -): +@pytest.mark.parametrize( + "request_type", [pipeline_service.CancelTrainingPipelineRequest, dict,] +) +def test_cancel_training_pipeline(request_type, transport: str = "grpc"): client = PipelineServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -1655,10 +1766,6 @@ def test_cancel_training_pipeline( assert response is None -def test_cancel_training_pipeline_from_dict(): - test_cancel_training_pipeline(request_type=dict) - - def test_cancel_training_pipeline_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -1840,9 +1947,10 @@ async def test_cancel_training_pipeline_flattened_error_async(): ) -def test_create_pipeline_job( - transport: str = "grpc", request_type=pipeline_service.CreatePipelineJobRequest -): +@pytest.mark.parametrize( + "request_type", [pipeline_service.CreatePipelineJobRequest, dict,] +) +def test_create_pipeline_job(request_type, transport: str = "grpc"): client = PipelineServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -1879,10 +1987,6 @@ def test_create_pipeline_job( assert response.network == "network_value" -def test_create_pipeline_job_from_dict(): - test_create_pipeline_job(request_type=dict) - - def test_create_pipeline_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -2107,9 +2211,10 @@ async def test_create_pipeline_job_flattened_error_async(): ) -def test_get_pipeline_job( - transport: str = "grpc", request_type=pipeline_service.GetPipelineJobRequest -): +@pytest.mark.parametrize( + "request_type", [pipeline_service.GetPipelineJobRequest, dict,] +) +def test_get_pipeline_job(request_type, transport: str = "grpc"): client = PipelineServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -2144,10 +2249,6 @@ def test_get_pipeline_job( assert response.network == "network_value" -def test_get_pipeline_job_from_dict(): - test_get_pipeline_job(request_type=dict) - - def test_get_pipeline_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -2333,9 +2434,10 @@ async def test_get_pipeline_job_flattened_error_async(): ) -def test_list_pipeline_jobs( - transport: str = "grpc", request_type=pipeline_service.ListPipelineJobsRequest -): +@pytest.mark.parametrize( + "request_type", [pipeline_service.ListPipelineJobsRequest, dict,] +) +def test_list_pipeline_jobs(request_type, transport: str = "grpc"): client = PipelineServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -2364,10 +2466,6 @@ def test_list_pipeline_jobs( assert response.next_page_token == "next_page_token_value" -def test_list_pipeline_jobs_from_dict(): - test_list_pipeline_jobs(request_type=dict) - - def test_list_pipeline_jobs_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -2558,8 +2656,10 @@ async def test_list_pipeline_jobs_flattened_error_async(): ) -def test_list_pipeline_jobs_pager(): - client = PipelineServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_pipeline_jobs_pager(transport_name: str = "grpc"): + client = PipelineServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( @@ -2600,8 +2700,10 @@ def test_list_pipeline_jobs_pager(): assert all(isinstance(i, pipeline_job.PipelineJob) for i in results) -def test_list_pipeline_jobs_pages(): - client = PipelineServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_pipeline_jobs_pages(transport_name: str = "grpc"): + client = PipelineServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. 
with mock.patch.object( @@ -2716,9 +2818,10 @@ async def test_list_pipeline_jobs_async_pages(): assert page_.raw_page.next_page_token == token -def test_delete_pipeline_job( - transport: str = "grpc", request_type=pipeline_service.DeletePipelineJobRequest -): +@pytest.mark.parametrize( + "request_type", [pipeline_service.DeletePipelineJobRequest, dict,] +) +def test_delete_pipeline_job(request_type, transport: str = "grpc"): client = PipelineServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -2744,10 +2847,6 @@ def test_delete_pipeline_job( assert isinstance(response, future.Future) -def test_delete_pipeline_job_from_dict(): - test_delete_pipeline_job(request_type=dict) - - def test_delete_pipeline_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -2935,9 +3034,10 @@ async def test_delete_pipeline_job_flattened_error_async(): ) -def test_cancel_pipeline_job( - transport: str = "grpc", request_type=pipeline_service.CancelPipelineJobRequest -): +@pytest.mark.parametrize( + "request_type", [pipeline_service.CancelPipelineJobRequest, dict,] +) +def test_cancel_pipeline_job(request_type, transport: str = "grpc"): client = PipelineServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -2963,10 +3063,6 @@ def test_cancel_pipeline_job( assert response is None -def test_cancel_pipeline_job_from_dict(): - test_cancel_pipeline_job(request_type=dict) - - def test_cancel_pipeline_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -3168,6 +3264,23 @@ def test_credentials_transport_error(): transport=transport, ) + # It is an error to provide an api_key and a transport instance. 
+ transport = transports.PipelineServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + options = client_options.ClientOptions() + options.api_key = "api_key" + with pytest.raises(ValueError): + client = PipelineServiceClient(client_options=options, transport=transport,) + + # It is an error to provide an api_key and a credential. + options = mock.Mock() + options.api_key = "api_key" + with pytest.raises(ValueError): + client = PipelineServiceClient( + client_options=options, credentials=ga_credentials.AnonymousCredentials() + ) + # It is an error to provide scopes and a transport instance. transport = transports.PipelineServiceGrpcTransport( credentials=ga_credentials.AnonymousCredentials(), @@ -3915,7 +4028,7 @@ def test_parse_common_location_path(): assert expected == actual -def test_client_withDEFAULT_CLIENT_INFO(): +def test_client_with_default_client_info(): client_info = gapic_v1.client_info.ClientInfo() with mock.patch.object( @@ -3980,3 +4093,33 @@ def test_client_ctx(): with client: pass close.assert_called() + + +@pytest.mark.parametrize( + "client_class,transport_class", + [ + (PipelineServiceClient, transports.PipelineServiceGrpcTransport), + (PipelineServiceAsyncClient, transports.PipelineServiceGrpcAsyncIOTransport), + ], +) +def test_api_key_credentials(client_class, transport_class): + with mock.patch.object( + google.auth._default, "get_api_key_credentials", create=True + ) as get_api_key_credentials: + mock_cred = mock.Mock() + get_api_key_credentials.return_value = mock_cred + options = client_options.ClientOptions() + options.api_key = "api_key" + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=mock_cred, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + 
client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + ) diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_prediction_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_prediction_service.py index aa0f4c76b5a..fa04aed2a7e 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_prediction_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_prediction_service.py @@ -255,20 +255,20 @@ def test_prediction_service_client_client_options( # unsupported value. with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): with pytest.raises(MutualTLSChannelError): - client = client_class() + client = client_class(transport=transport_name) # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. with mock.patch.dict( os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} ): with pytest.raises(ValueError): - client = client_class() + client = client_class(transport=transport_name) # Check the case quota_project_id is provided options = client_options.ClientOptions(quota_project_id="octopus") with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( credentials=None, credentials_file=None, @@ -337,7 +337,7 @@ def test_prediction_service_client_mtls_env_auto( ) with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) if use_client_cert_env == "false": expected_client_cert_source = None @@ -414,6 +414,87 @@ def test_prediction_service_client_mtls_env_auto( ) +@pytest.mark.parametrize( + "client_class", [PredictionServiceClient, PredictionServiceAsyncClient] +) +@mock.patch.object( + 
PredictionServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(PredictionServiceClient), +) +@mock.patch.object( + PredictionServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(PredictionServiceAsyncClient), +) +def test_prediction_service_client_get_mtls_endpoint_and_cert_source(client_class): + mock_client_cert_source = mock.Mock() + + # Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is "true". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + mock_api_endpoint = "foo" + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint + ) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source( + options + ) + assert api_endpoint == mock_api_endpoint + assert cert_source == mock_client_cert_source + + # Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is "false". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "false"}): + mock_client_cert_source = mock.Mock() + mock_api_endpoint = "foo" + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint + ) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source( + options + ) + assert api_endpoint == mock_api_endpoint + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "never". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "always". 
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert doesn't exist. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert exists. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=mock_client_cert_source, + ): + ( + api_endpoint, + cert_source, + ) = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + assert cert_source == mock_client_cert_source + + @pytest.mark.parametrize( "client_class,transport_class,transport_name", [ @@ -432,7 +513,7 @@ def test_prediction_service_client_client_options_scopes( options = client_options.ClientOptions(scopes=["1", "2"],) with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( credentials=None, credentials_file=None, @@ -446,24 +527,31 @@ def test_prediction_service_client_client_options_scopes( @pytest.mark.parametrize( - "client_class,transport_class,transport_name", + 
"client_class,transport_class,transport_name,grpc_helpers", [ - (PredictionServiceClient, transports.PredictionServiceGrpcTransport, "grpc"), + ( + PredictionServiceClient, + transports.PredictionServiceGrpcTransport, + "grpc", + grpc_helpers, + ), ( PredictionServiceAsyncClient, transports.PredictionServiceGrpcAsyncIOTransport, "grpc_asyncio", + grpc_helpers_async, ), ], ) def test_prediction_service_client_client_options_credentials_file( - client_class, transport_class, transport_name + client_class, transport_class, transport_name, grpc_helpers ): # Check the case credentials file is provided. options = client_options.ClientOptions(credentials_file="credentials.json") + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( credentials=None, credentials_file="credentials.json", @@ -475,6 +563,35 @@ def test_prediction_service_client_client_options_credentials_file( always_use_jwt_access=True, ) + if "grpc" in transport_name: + # test that the credentials from file are saved and used as the credentials. 
+ with mock.patch.object( + google.auth, "load_credentials_from_file", autospec=True + ) as load_creds, mock.patch.object( + google.auth, "default", autospec=True + ) as adc, mock.patch.object( + grpc_helpers, "create_channel" + ) as create_channel: + creds = ga_credentials.AnonymousCredentials() + file_creds = ga_credentials.AnonymousCredentials() + load_creds.return_value = (file_creds, None) + adc.return_value = (creds, None) + client = client_class(client_options=options, transport=transport_name) + create_channel.assert_called_with( + "aiplatform.googleapis.com:443", + credentials=file_creds, + credentials_file=None, + quota_project_id=None, + default_scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=None, + default_host="aiplatform.googleapis.com", + ssl_credentials=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + def test_prediction_service_client_client_options_from_dict(): with mock.patch( @@ -496,9 +613,8 @@ def test_prediction_service_client_client_options_from_dict(): ) -def test_predict( - transport: str = "grpc", request_type=prediction_service.PredictRequest -): +@pytest.mark.parametrize("request_type", [prediction_service.PredictRequest, dict,]) +def test_predict(request_type, transport: str = "grpc"): client = PredictionServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -529,10 +645,6 @@ def test_predict( assert response.model_display_name == "model_display_name_value" -def test_predict_from_dict(): - test_predict(request_type=dict) - - def test_predict_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -673,9 +785,8 @@ async def test_predict_flattened_error_async(): ) -def test_raw_predict( - transport: str = "grpc", request_type=prediction_service.RawPredictRequest -): +@pytest.mark.parametrize("request_type", [prediction_service.RawPredictRequest, dict,]) +def test_raw_predict(request_type, transport: str = "grpc"): client = PredictionServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -703,10 +814,6 @@ def test_raw_predict( assert response.data == b"data_blob" -def test_raw_predict_from_dict(): - test_raw_predict(request_type=dict) - - def test_raw_predict_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -899,9 +1006,8 @@ async def test_raw_predict_flattened_error_async(): ) -def test_explain( - transport: str = "grpc", request_type=prediction_service.ExplainRequest -): +@pytest.mark.parametrize("request_type", [prediction_service.ExplainRequest, dict,]) +def test_explain(request_type, transport: str = "grpc"): client = PredictionServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -928,10 +1034,6 @@ def test_explain( assert response.deployed_model_id == "deployed_model_id_value" -def test_explain_from_dict(): - test_explain(request_type=dict) - - def test_explain_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -1090,6 +1192,23 @@ def test_credentials_transport_error(): transport=transport, ) + # It is an error to provide an api_key and a transport instance. 
+ transport = transports.PredictionServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + options = client_options.ClientOptions() + options.api_key = "api_key" + with pytest.raises(ValueError): + client = PredictionServiceClient(client_options=options, transport=transport,) + + # It is an error to provide an api_key and a credential. + options = mock.Mock() + options.api_key = "api_key" + with pytest.raises(ValueError): + client = PredictionServiceClient( + client_options=options, credentials=ga_credentials.AnonymousCredentials() + ) + # It is an error to provide scopes and a transport instance. transport = transports.PredictionServiceGrpcTransport( credentials=ga_credentials.AnonymousCredentials(), @@ -1610,7 +1729,7 @@ def test_parse_common_location_path(): assert expected == actual -def test_client_withDEFAULT_CLIENT_INFO(): +def test_client_with_default_client_info(): client_info = gapic_v1.client_info.ClientInfo() with mock.patch.object( @@ -1675,3 +1794,36 @@ def test_client_ctx(): with client: pass close.assert_called() + + +@pytest.mark.parametrize( + "client_class,transport_class", + [ + (PredictionServiceClient, transports.PredictionServiceGrpcTransport), + ( + PredictionServiceAsyncClient, + transports.PredictionServiceGrpcAsyncIOTransport, + ), + ], +) +def test_api_key_credentials(client_class, transport_class): + with mock.patch.object( + google.auth._default, "get_api_key_credentials", create=True + ) as get_api_key_credentials: + mock_cred = mock.Mock() + get_api_key_credentials.return_value = mock_cred + options = client_options.ClientOptions() + options.api_key = "api_key" + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=mock_cred, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + 
client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + ) diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_specialist_pool_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_specialist_pool_service.py index 490e6d5f97c..c55119cdd45 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_specialist_pool_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_specialist_pool_service.py @@ -29,6 +29,7 @@ from google.api_core import gapic_v1 from google.api_core import grpc_helpers from google.api_core import grpc_helpers_async +from google.api_core import operation from google.api_core import operation_async # type: ignore from google.api_core import operations_v1 from google.api_core import path_template @@ -263,20 +264,20 @@ def test_specialist_pool_service_client_client_options( # unsupported value. with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): with pytest.raises(MutualTLSChannelError): - client = client_class() + client = client_class(transport=transport_name) # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. 
with mock.patch.dict( os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} ): with pytest.raises(ValueError): - client = client_class() + client = client_class(transport=transport_name) # Check the case quota_project_id is provided options = client_options.ClientOptions(quota_project_id="octopus") with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( credentials=None, credentials_file=None, @@ -345,7 +346,7 @@ def test_specialist_pool_service_client_mtls_env_auto( ) with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) if use_client_cert_env == "false": expected_client_cert_source = None @@ -422,6 +423,87 @@ def test_specialist_pool_service_client_mtls_env_auto( ) +@pytest.mark.parametrize( + "client_class", [SpecialistPoolServiceClient, SpecialistPoolServiceAsyncClient] +) +@mock.patch.object( + SpecialistPoolServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(SpecialistPoolServiceClient), +) +@mock.patch.object( + SpecialistPoolServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(SpecialistPoolServiceAsyncClient), +) +def test_specialist_pool_service_client_get_mtls_endpoint_and_cert_source(client_class): + mock_client_cert_source = mock.Mock() + + # Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is "true". 
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + mock_api_endpoint = "foo" + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint + ) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source( + options + ) + assert api_endpoint == mock_api_endpoint + assert cert_source == mock_client_cert_source + + # Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is "false". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "false"}): + mock_client_cert_source = mock.Mock() + mock_api_endpoint = "foo" + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint + ) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source( + options + ) + assert api_endpoint == mock_api_endpoint + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "never". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "always". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert doesn't exist. 
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert exists. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=mock_client_cert_source, + ): + ( + api_endpoint, + cert_source, + ) = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + assert cert_source == mock_client_cert_source + + @pytest.mark.parametrize( "client_class,transport_class,transport_name", [ @@ -444,7 +526,7 @@ def test_specialist_pool_service_client_client_options_scopes( options = client_options.ClientOptions(scopes=["1", "2"],) with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( credentials=None, credentials_file=None, @@ -458,28 +540,31 @@ def test_specialist_pool_service_client_client_options_scopes( @pytest.mark.parametrize( - "client_class,transport_class,transport_name", + "client_class,transport_class,transport_name,grpc_helpers", [ ( SpecialistPoolServiceClient, transports.SpecialistPoolServiceGrpcTransport, "grpc", + grpc_helpers, ), ( SpecialistPoolServiceAsyncClient, transports.SpecialistPoolServiceGrpcAsyncIOTransport, "grpc_asyncio", + grpc_helpers_async, ), ], ) def 
test_specialist_pool_service_client_client_options_credentials_file( - client_class, transport_class, transport_name + client_class, transport_class, transport_name, grpc_helpers ): # Check the case credentials file is provided. options = client_options.ClientOptions(credentials_file="credentials.json") + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( credentials=None, credentials_file="credentials.json", @@ -491,6 +576,35 @@ def test_specialist_pool_service_client_client_options_credentials_file( always_use_jwt_access=True, ) + if "grpc" in transport_name: + # test that the credentials from file are saved and used as the credentials. + with mock.patch.object( + google.auth, "load_credentials_from_file", autospec=True + ) as load_creds, mock.patch.object( + google.auth, "default", autospec=True + ) as adc, mock.patch.object( + grpc_helpers, "create_channel" + ) as create_channel: + creds = ga_credentials.AnonymousCredentials() + file_creds = ga_credentials.AnonymousCredentials() + load_creds.return_value = (file_creds, None) + adc.return_value = (creds, None) + client = client_class(client_options=options, transport=transport_name) + create_channel.assert_called_with( + "aiplatform.googleapis.com:443", + credentials=file_creds, + credentials_file=None, + quota_project_id=None, + default_scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=None, + default_host="aiplatform.googleapis.com", + ssl_credentials=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + def test_specialist_pool_service_client_client_options_from_dict(): with mock.patch( @@ -512,10 +626,10 @@ def test_specialist_pool_service_client_client_options_from_dict(): ) -def test_create_specialist_pool( - 
transport: str = "grpc", - request_type=specialist_pool_service.CreateSpecialistPoolRequest, -): +@pytest.mark.parametrize( + "request_type", [specialist_pool_service.CreateSpecialistPoolRequest, dict,] +) +def test_create_specialist_pool(request_type, transport: str = "grpc"): client = SpecialistPoolServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -541,10 +655,6 @@ def test_create_specialist_pool( assert isinstance(response, future.Future) -def test_create_specialist_pool_from_dict(): - test_create_specialist_pool(request_type=dict) - - def test_create_specialist_pool_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -754,10 +864,10 @@ async def test_create_specialist_pool_flattened_error_async(): ) -def test_get_specialist_pool( - transport: str = "grpc", - request_type=specialist_pool_service.GetSpecialistPoolRequest, -): +@pytest.mark.parametrize( + "request_type", [specialist_pool_service.GetSpecialistPoolRequest, dict,] +) +def test_get_specialist_pool(request_type, transport: str = "grpc"): client = SpecialistPoolServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -796,10 +906,6 @@ def test_get_specialist_pool( assert response.specialist_worker_emails == ["specialist_worker_emails_value"] -def test_get_specialist_pool_from_dict(): - test_get_specialist_pool(request_type=dict) - - def test_get_specialist_pool_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -1006,10 +1112,10 @@ async def test_get_specialist_pool_flattened_error_async(): ) -def test_list_specialist_pools( - transport: str = "grpc", - request_type=specialist_pool_service.ListSpecialistPoolsRequest, -): +@pytest.mark.parametrize( + "request_type", [specialist_pool_service.ListSpecialistPoolsRequest, dict,] +) +def test_list_specialist_pools(request_type, transport: str = "grpc"): client = SpecialistPoolServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -1038,10 +1144,6 @@ def test_list_specialist_pools( assert response.next_page_token == "next_page_token_value" -def test_list_specialist_pools_from_dict(): - test_list_specialist_pools(request_type=dict) - - def test_list_specialist_pools_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -1238,9 +1340,9 @@ async def test_list_specialist_pools_flattened_error_async(): ) -def test_list_specialist_pools_pager(): +def test_list_specialist_pools_pager(transport_name: str = "grpc"): client = SpecialistPoolServiceClient( - credentials=ga_credentials.AnonymousCredentials, + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, ) # Mock the actual call within the gRPC stub, and fake the request. @@ -1286,9 +1388,9 @@ def test_list_specialist_pools_pager(): assert all(isinstance(i, specialist_pool.SpecialistPool) for i in results) -def test_list_specialist_pools_pages(): +def test_list_specialist_pools_pages(transport_name: str = "grpc"): client = SpecialistPoolServiceClient( - credentials=ga_credentials.AnonymousCredentials, + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, ) # Mock the actual call within the gRPC stub, and fake the request. 
@@ -1416,10 +1518,10 @@ async def test_list_specialist_pools_async_pages(): assert page_.raw_page.next_page_token == token -def test_delete_specialist_pool( - transport: str = "grpc", - request_type=specialist_pool_service.DeleteSpecialistPoolRequest, -): +@pytest.mark.parametrize( + "request_type", [specialist_pool_service.DeleteSpecialistPoolRequest, dict,] +) +def test_delete_specialist_pool(request_type, transport: str = "grpc"): client = SpecialistPoolServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -1445,10 +1547,6 @@ def test_delete_specialist_pool( assert isinstance(response, future.Future) -def test_delete_specialist_pool_from_dict(): - test_delete_specialist_pool(request_type=dict) - - def test_delete_specialist_pool_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -1642,10 +1740,10 @@ async def test_delete_specialist_pool_flattened_error_async(): ) -def test_update_specialist_pool( - transport: str = "grpc", - request_type=specialist_pool_service.UpdateSpecialistPoolRequest, -): +@pytest.mark.parametrize( + "request_type", [specialist_pool_service.UpdateSpecialistPoolRequest, dict,] +) +def test_update_specialist_pool(request_type, transport: str = "grpc"): client = SpecialistPoolServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -1671,10 +1769,6 @@ def test_update_specialist_pool( assert isinstance(response, future.Future) -def test_update_specialist_pool_from_dict(): - test_update_specialist_pool(request_type=dict) - - def test_update_specialist_pool_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -1910,6 +2004,25 @@ def test_credentials_transport_error(): transport=transport, ) + # It is an error to provide an api_key and a transport instance. 
+ transport = transports.SpecialistPoolServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + options = client_options.ClientOptions() + options.api_key = "api_key" + with pytest.raises(ValueError): + client = SpecialistPoolServiceClient( + client_options=options, transport=transport, + ) + + # It is an error to provide an api_key and a credential. + options = mock.Mock() + options.api_key = "api_key" + with pytest.raises(ValueError): + client = SpecialistPoolServiceClient( + client_options=options, credentials=ga_credentials.AnonymousCredentials() + ) + # It is an error to provide scopes and a transport instance. transport = transports.SpecialistPoolServiceGrpcTransport( credentials=ga_credentials.AnonymousCredentials(), @@ -2447,7 +2560,7 @@ def test_parse_common_location_path(): assert expected == actual -def test_client_withDEFAULT_CLIENT_INFO(): +def test_client_with_default_client_info(): client_info = gapic_v1.client_info.ClientInfo() with mock.patch.object( @@ -2512,3 +2625,36 @@ def test_client_ctx(): with client: pass close.assert_called() + + +@pytest.mark.parametrize( + "client_class,transport_class", + [ + (SpecialistPoolServiceClient, transports.SpecialistPoolServiceGrpcTransport), + ( + SpecialistPoolServiceAsyncClient, + transports.SpecialistPoolServiceGrpcAsyncIOTransport, + ), + ], +) +def test_api_key_credentials(client_class, transport_class): + with mock.patch.object( + google.auth._default, "get_api_key_credentials", create=True + ) as get_api_key_credentials: + mock_cred = mock.Mock() + get_api_key_credentials.return_value = mock_cred + options = client_options.ClientOptions() + options.api_key = "api_key" + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=mock_cred, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + 
quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + ) diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_tensorboard_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_tensorboard_service.py index 332bf28c9a6..2fb4c03afc4 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_tensorboard_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_tensorboard_service.py @@ -29,6 +29,7 @@ from google.api_core import gapic_v1 from google.api_core import grpc_helpers from google.api_core import grpc_helpers_async +from google.api_core import operation from google.api_core import operation_async # type: ignore from google.api_core import operations_v1 from google.api_core import path_template @@ -272,20 +273,20 @@ def test_tensorboard_service_client_client_options( # unsupported value. with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): with pytest.raises(MutualTLSChannelError): - client = client_class() + client = client_class(transport=transport_name) # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. 
with mock.patch.dict( os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} ): with pytest.raises(ValueError): - client = client_class() + client = client_class(transport=transport_name) # Check the case quota_project_id is provided options = client_options.ClientOptions(quota_project_id="octopus") with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( credentials=None, credentials_file=None, @@ -354,7 +355,7 @@ def test_tensorboard_service_client_mtls_env_auto( ) with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) if use_client_cert_env == "false": expected_client_cert_source = None @@ -431,6 +432,87 @@ def test_tensorboard_service_client_mtls_env_auto( ) +@pytest.mark.parametrize( + "client_class", [TensorboardServiceClient, TensorboardServiceAsyncClient] +) +@mock.patch.object( + TensorboardServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(TensorboardServiceClient), +) +@mock.patch.object( + TensorboardServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(TensorboardServiceAsyncClient), +) +def test_tensorboard_service_client_get_mtls_endpoint_and_cert_source(client_class): + mock_client_cert_source = mock.Mock() + + # Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is "true". 
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + mock_api_endpoint = "foo" + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint + ) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source( + options + ) + assert api_endpoint == mock_api_endpoint + assert cert_source == mock_client_cert_source + + # Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is "false". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "false"}): + mock_client_cert_source = mock.Mock() + mock_api_endpoint = "foo" + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint + ) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source( + options + ) + assert api_endpoint == mock_api_endpoint + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "never". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "always". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert doesn't exist. 
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert exists. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=mock_client_cert_source, + ): + ( + api_endpoint, + cert_source, + ) = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + assert cert_source == mock_client_cert_source + + @pytest.mark.parametrize( "client_class,transport_class,transport_name", [ @@ -449,7 +531,7 @@ def test_tensorboard_service_client_client_options_scopes( options = client_options.ClientOptions(scopes=["1", "2"],) with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( credentials=None, credentials_file=None, @@ -463,24 +545,31 @@ def test_tensorboard_service_client_client_options_scopes( @pytest.mark.parametrize( - "client_class,transport_class,transport_name", + "client_class,transport_class,transport_name,grpc_helpers", [ - (TensorboardServiceClient, transports.TensorboardServiceGrpcTransport, "grpc"), + ( + TensorboardServiceClient, + transports.TensorboardServiceGrpcTransport, + "grpc", + grpc_helpers, + ), ( TensorboardServiceAsyncClient, transports.TensorboardServiceGrpcAsyncIOTransport, "grpc_asyncio", + 
grpc_helpers_async, ), ], ) def test_tensorboard_service_client_client_options_credentials_file( - client_class, transport_class, transport_name + client_class, transport_class, transport_name, grpc_helpers ): # Check the case credentials file is provided. options = client_options.ClientOptions(credentials_file="credentials.json") + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( credentials=None, credentials_file="credentials.json", @@ -492,6 +581,35 @@ def test_tensorboard_service_client_client_options_credentials_file( always_use_jwt_access=True, ) + if "grpc" in transport_name: + # test that the credentials from file are saved and used as the credentials. + with mock.patch.object( + google.auth, "load_credentials_from_file", autospec=True + ) as load_creds, mock.patch.object( + google.auth, "default", autospec=True + ) as adc, mock.patch.object( + grpc_helpers, "create_channel" + ) as create_channel: + creds = ga_credentials.AnonymousCredentials() + file_creds = ga_credentials.AnonymousCredentials() + load_creds.return_value = (file_creds, None) + adc.return_value = (creds, None) + client = client_class(client_options=options, transport=transport_name) + create_channel.assert_called_with( + "aiplatform.googleapis.com:443", + credentials=file_creds, + credentials_file=None, + quota_project_id=None, + default_scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=None, + default_host="aiplatform.googleapis.com", + ssl_credentials=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + def test_tensorboard_service_client_client_options_from_dict(): with mock.patch( @@ -513,9 +631,10 @@ def test_tensorboard_service_client_client_options_from_dict(): ) -def 
test_create_tensorboard( - transport: str = "grpc", request_type=tensorboard_service.CreateTensorboardRequest -): +@pytest.mark.parametrize( + "request_type", [tensorboard_service.CreateTensorboardRequest, dict,] +) +def test_create_tensorboard(request_type, transport: str = "grpc"): client = TensorboardServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -541,10 +660,6 @@ def test_create_tensorboard( assert isinstance(response, future.Future) -def test_create_tensorboard_from_dict(): - test_create_tensorboard(request_type=dict) - - def test_create_tensorboard_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -754,9 +869,10 @@ async def test_create_tensorboard_flattened_error_async(): ) -def test_get_tensorboard( - transport: str = "grpc", request_type=tensorboard_service.GetTensorboardRequest -): +@pytest.mark.parametrize( + "request_type", [tensorboard_service.GetTensorboardRequest, dict,] +) +def test_get_tensorboard(request_type, transport: str = "grpc"): client = TensorboardServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -793,10 +909,6 @@ def test_get_tensorboard( assert response.etag == "etag_value" -def test_get_tensorboard_from_dict(): - test_get_tensorboard(request_type=dict) - - def test_get_tensorboard_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -991,9 +1103,10 @@ async def test_get_tensorboard_flattened_error_async(): ) -def test_update_tensorboard( - transport: str = "grpc", request_type=tensorboard_service.UpdateTensorboardRequest -): +@pytest.mark.parametrize( + "request_type", [tensorboard_service.UpdateTensorboardRequest, dict,] +) +def test_update_tensorboard(request_type, transport: str = "grpc"): client = TensorboardServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -1019,10 +1132,6 @@ def test_update_tensorboard( assert isinstance(response, future.Future) -def test_update_tensorboard_from_dict(): - test_update_tensorboard(request_type=dict) - - def test_update_tensorboard_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -1236,9 +1345,10 @@ async def test_update_tensorboard_flattened_error_async(): ) -def test_list_tensorboards( - transport: str = "grpc", request_type=tensorboard_service.ListTensorboardsRequest -): +@pytest.mark.parametrize( + "request_type", [tensorboard_service.ListTensorboardsRequest, dict,] +) +def test_list_tensorboards(request_type, transport: str = "grpc"): client = TensorboardServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -1267,10 +1377,6 @@ def test_list_tensorboards( assert response.next_page_token == "next_page_token_value" -def test_list_tensorboards_from_dict(): - test_list_tensorboards(request_type=dict) - - def test_list_tensorboards_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -1467,8 +1573,10 @@ async def test_list_tensorboards_flattened_error_async(): ) -def test_list_tensorboards_pager(): - client = TensorboardServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_tensorboards_pager(transport_name: str = "grpc"): + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( @@ -1509,8 +1617,10 @@ def test_list_tensorboards_pager(): assert all(isinstance(i, tensorboard.Tensorboard) for i in results) -def test_list_tensorboards_pages(): - client = TensorboardServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_tensorboards_pages(transport_name: str = "grpc"): + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( @@ -1625,9 +1735,10 @@ async def test_list_tensorboards_async_pages(): assert page_.raw_page.next_page_token == token -def test_delete_tensorboard( - transport: str = "grpc", request_type=tensorboard_service.DeleteTensorboardRequest -): +@pytest.mark.parametrize( + "request_type", [tensorboard_service.DeleteTensorboardRequest, dict,] +) +def test_delete_tensorboard(request_type, transport: str = "grpc"): client = TensorboardServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -1653,10 +1764,6 @@ def test_delete_tensorboard( assert isinstance(response, future.Future) -def test_delete_tensorboard_from_dict(): - test_delete_tensorboard(request_type=dict) - - def test_delete_tensorboard_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -1850,10 +1957,10 @@ async def test_delete_tensorboard_flattened_error_async(): ) -def test_create_tensorboard_experiment( - transport: str = "grpc", - request_type=tensorboard_service.CreateTensorboardExperimentRequest, -): +@pytest.mark.parametrize( + "request_type", [tensorboard_service.CreateTensorboardExperimentRequest, dict,] +) +def test_create_tensorboard_experiment(request_type, transport: str = "grpc"): client = TensorboardServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -1890,10 +1997,6 @@ def test_create_tensorboard_experiment( assert response.source == "source_value" -def test_create_tensorboard_experiment_from_dict(): - test_create_tensorboard_experiment(request_type=dict) - - def test_create_tensorboard_experiment_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -2132,10 +2235,10 @@ async def test_create_tensorboard_experiment_flattened_error_async(): ) -def test_get_tensorboard_experiment( - transport: str = "grpc", - request_type=tensorboard_service.GetTensorboardExperimentRequest, -): +@pytest.mark.parametrize( + "request_type", [tensorboard_service.GetTensorboardExperimentRequest, dict,] +) +def test_get_tensorboard_experiment(request_type, transport: str = "grpc"): client = TensorboardServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -2172,10 +2275,6 @@ def test_get_tensorboard_experiment( assert response.source == "source_value" -def test_get_tensorboard_experiment_from_dict(): - test_get_tensorboard_experiment(request_type=dict) - - def test_get_tensorboard_experiment_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -2380,10 +2479,10 @@ async def test_get_tensorboard_experiment_flattened_error_async(): ) -def test_update_tensorboard_experiment( - transport: str = "grpc", - request_type=tensorboard_service.UpdateTensorboardExperimentRequest, -): +@pytest.mark.parametrize( + "request_type", [tensorboard_service.UpdateTensorboardExperimentRequest, dict,] +) +def test_update_tensorboard_experiment(request_type, transport: str = "grpc"): client = TensorboardServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -2420,10 +2519,6 @@ def test_update_tensorboard_experiment( assert response.source == "source_value" -def test_update_tensorboard_experiment_from_dict(): - test_update_tensorboard_experiment(request_type=dict) - - def test_update_tensorboard_experiment_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -2658,10 +2753,10 @@ async def test_update_tensorboard_experiment_flattened_error_async(): ) -def test_list_tensorboard_experiments( - transport: str = "grpc", - request_type=tensorboard_service.ListTensorboardExperimentsRequest, -): +@pytest.mark.parametrize( + "request_type", [tensorboard_service.ListTensorboardExperimentsRequest, dict,] +) +def test_list_tensorboard_experiments(request_type, transport: str = "grpc"): client = TensorboardServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -2690,10 +2785,6 @@ def test_list_tensorboard_experiments( assert response.next_page_token == "next_page_token_value" -def test_list_tensorboard_experiments_from_dict(): - test_list_tensorboard_experiments(request_type=dict) - - def test_list_tensorboard_experiments_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -2892,8 +2983,10 @@ async def test_list_tensorboard_experiments_flattened_error_async(): ) -def test_list_tensorboard_experiments_pager(): - client = TensorboardServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_tensorboard_experiments_pager(transport_name: str = "grpc"): + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( @@ -2942,8 +3035,10 @@ def test_list_tensorboard_experiments_pager(): ) -def test_list_tensorboard_experiments_pages(): - client = TensorboardServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_tensorboard_experiments_pages(transport_name: str = "grpc"): + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( @@ -3081,10 +3176,10 @@ async def test_list_tensorboard_experiments_async_pages(): assert page_.raw_page.next_page_token == token -def test_delete_tensorboard_experiment( - transport: str = "grpc", - request_type=tensorboard_service.DeleteTensorboardExperimentRequest, -): +@pytest.mark.parametrize( + "request_type", [tensorboard_service.DeleteTensorboardExperimentRequest, dict,] +) +def test_delete_tensorboard_experiment(request_type, transport: str = "grpc"): client = TensorboardServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -3110,10 +3205,6 @@ def test_delete_tensorboard_experiment( assert isinstance(response, future.Future) -def test_delete_tensorboard_experiment_from_dict(): - test_delete_tensorboard_experiment(request_type=dict) - - def test_delete_tensorboard_experiment_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -3307,10 +3398,10 @@ async def test_delete_tensorboard_experiment_flattened_error_async(): ) -def test_create_tensorboard_run( - transport: str = "grpc", - request_type=tensorboard_service.CreateTensorboardRunRequest, -): +@pytest.mark.parametrize( + "request_type", [tensorboard_service.CreateTensorboardRunRequest, dict,] +) +def test_create_tensorboard_run(request_type, transport: str = "grpc"): client = TensorboardServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -3345,10 +3436,6 @@ def test_create_tensorboard_run( assert response.etag == "etag_value" -def test_create_tensorboard_run_from_dict(): - test_create_tensorboard_run(request_type=dict) - - def test_create_tensorboard_run_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -3577,10 +3664,10 @@ async def test_create_tensorboard_run_flattened_error_async(): ) -def test_batch_create_tensorboard_runs( - transport: str = "grpc", - request_type=tensorboard_service.BatchCreateTensorboardRunsRequest, -): +@pytest.mark.parametrize( + "request_type", [tensorboard_service.BatchCreateTensorboardRunsRequest, dict,] +) +def test_batch_create_tensorboard_runs(request_type, transport: str = "grpc"): client = TensorboardServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -3606,10 +3693,6 @@ def test_batch_create_tensorboard_runs( assert isinstance(response, tensorboard_service.BatchCreateTensorboardRunsResponse) -def test_batch_create_tensorboard_runs_from_dict(): - test_batch_create_tensorboard_runs(request_type=dict) - - def test_batch_create_tensorboard_runs_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -3831,9 +3914,10 @@ async def test_batch_create_tensorboard_runs_flattened_error_async(): ) -def test_get_tensorboard_run( - transport: str = "grpc", request_type=tensorboard_service.GetTensorboardRunRequest -): +@pytest.mark.parametrize( + "request_type", [tensorboard_service.GetTensorboardRunRequest, dict,] +) +def test_get_tensorboard_run(request_type, transport: str = "grpc"): client = TensorboardServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -3868,10 +3952,6 @@ def test_get_tensorboard_run( assert response.etag == "etag_value" -def test_get_tensorboard_run_from_dict(): - test_get_tensorboard_run(request_type=dict) - - def test_get_tensorboard_run_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -4074,10 +4154,10 @@ async def test_get_tensorboard_run_flattened_error_async(): ) -def test_update_tensorboard_run( - transport: str = "grpc", - request_type=tensorboard_service.UpdateTensorboardRunRequest, -): +@pytest.mark.parametrize( + "request_type", [tensorboard_service.UpdateTensorboardRunRequest, dict,] +) +def test_update_tensorboard_run(request_type, transport: str = "grpc"): client = TensorboardServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -4112,10 +4192,6 @@ def test_update_tensorboard_run( assert response.etag == "etag_value" -def test_update_tensorboard_run_from_dict(): - test_update_tensorboard_run(request_type=dict) - - def test_update_tensorboard_run_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -4340,9 +4416,10 @@ async def test_update_tensorboard_run_flattened_error_async(): ) -def test_list_tensorboard_runs( - transport: str = "grpc", request_type=tensorboard_service.ListTensorboardRunsRequest -): +@pytest.mark.parametrize( + "request_type", [tensorboard_service.ListTensorboardRunsRequest, dict,] +) +def test_list_tensorboard_runs(request_type, transport: str = "grpc"): client = TensorboardServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -4371,10 +4448,6 @@ def test_list_tensorboard_runs( assert response.next_page_token == "next_page_token_value" -def test_list_tensorboard_runs_from_dict(): - test_list_tensorboard_runs(request_type=dict) - - def test_list_tensorboard_runs_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -4571,8 +4644,10 @@ async def test_list_tensorboard_runs_flattened_error_async(): ) -def test_list_tensorboard_runs_pager(): - client = TensorboardServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_tensorboard_runs_pager(transport_name: str = "grpc"): + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( @@ -4617,8 +4692,10 @@ def test_list_tensorboard_runs_pager(): assert all(isinstance(i, tensorboard_run.TensorboardRun) for i in results) -def test_list_tensorboard_runs_pages(): - client = TensorboardServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_tensorboard_runs_pages(transport_name: str = "grpc"): + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. 
with mock.patch.object( @@ -4745,10 +4822,10 @@ async def test_list_tensorboard_runs_async_pages(): assert page_.raw_page.next_page_token == token -def test_delete_tensorboard_run( - transport: str = "grpc", - request_type=tensorboard_service.DeleteTensorboardRunRequest, -): +@pytest.mark.parametrize( + "request_type", [tensorboard_service.DeleteTensorboardRunRequest, dict,] +) +def test_delete_tensorboard_run(request_type, transport: str = "grpc"): client = TensorboardServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -4774,10 +4851,6 @@ def test_delete_tensorboard_run( assert isinstance(response, future.Future) -def test_delete_tensorboard_run_from_dict(): - test_delete_tensorboard_run(request_type=dict) - - def test_delete_tensorboard_run_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -4971,10 +5044,10 @@ async def test_delete_tensorboard_run_flattened_error_async(): ) -def test_batch_create_tensorboard_time_series( - transport: str = "grpc", - request_type=tensorboard_service.BatchCreateTensorboardTimeSeriesRequest, -): +@pytest.mark.parametrize( + "request_type", [tensorboard_service.BatchCreateTensorboardTimeSeriesRequest, dict,] +) +def test_batch_create_tensorboard_time_series(request_type, transport: str = "grpc"): client = TensorboardServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -5004,10 +5077,6 @@ def test_batch_create_tensorboard_time_series( ) -def test_batch_create_tensorboard_time_series_from_dict(): - test_batch_create_tensorboard_time_series(request_type=dict) - - def test_batch_create_tensorboard_time_series_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -5249,10 +5318,10 @@ async def test_batch_create_tensorboard_time_series_flattened_error_async(): ) -def test_create_tensorboard_time_series( - transport: str = "grpc", - request_type=tensorboard_service.CreateTensorboardTimeSeriesRequest, -): +@pytest.mark.parametrize( + "request_type", [tensorboard_service.CreateTensorboardTimeSeriesRequest, dict,] +) +def test_create_tensorboard_time_series(request_type, transport: str = "grpc"): client = TensorboardServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -5296,10 +5365,6 @@ def test_create_tensorboard_time_series( assert response.plugin_data == b"plugin_data_blob" -def test_create_tensorboard_time_series_from_dict(): - test_create_tensorboard_time_series(request_type=dict) - - def test_create_tensorboard_time_series_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -5535,10 +5600,10 @@ async def test_create_tensorboard_time_series_flattened_error_async(): ) -def test_get_tensorboard_time_series( - transport: str = "grpc", - request_type=tensorboard_service.GetTensorboardTimeSeriesRequest, -): +@pytest.mark.parametrize( + "request_type", [tensorboard_service.GetTensorboardTimeSeriesRequest, dict,] +) +def test_get_tensorboard_time_series(request_type, transport: str = "grpc"): client = TensorboardServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -5582,10 +5647,6 @@ def test_get_tensorboard_time_series( assert response.plugin_data == b"plugin_data_blob" -def test_get_tensorboard_time_series_from_dict(): - test_get_tensorboard_time_series(request_type=dict) - - def test_get_tensorboard_time_series_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -5797,10 +5858,10 @@ async def test_get_tensorboard_time_series_flattened_error_async(): ) -def test_update_tensorboard_time_series( - transport: str = "grpc", - request_type=tensorboard_service.UpdateTensorboardTimeSeriesRequest, -): +@pytest.mark.parametrize( + "request_type", [tensorboard_service.UpdateTensorboardTimeSeriesRequest, dict,] +) +def test_update_tensorboard_time_series(request_type, transport: str = "grpc"): client = TensorboardServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -5844,10 +5905,6 @@ def test_update_tensorboard_time_series( assert response.plugin_data == b"plugin_data_blob" -def test_update_tensorboard_time_series_from_dict(): - test_update_tensorboard_time_series(request_type=dict) - - def test_update_tensorboard_time_series_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -6089,10 +6146,10 @@ async def test_update_tensorboard_time_series_flattened_error_async(): ) -def test_list_tensorboard_time_series( - transport: str = "grpc", - request_type=tensorboard_service.ListTensorboardTimeSeriesRequest, -): +@pytest.mark.parametrize( + "request_type", [tensorboard_service.ListTensorboardTimeSeriesRequest, dict,] +) +def test_list_tensorboard_time_series(request_type, transport: str = "grpc"): client = TensorboardServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -6121,10 +6178,6 @@ def test_list_tensorboard_time_series( assert response.next_page_token == "next_page_token_value" -def test_list_tensorboard_time_series_from_dict(): - test_list_tensorboard_time_series(request_type=dict) - - def test_list_tensorboard_time_series_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -6323,8 +6376,10 @@ async def test_list_tensorboard_time_series_flattened_error_async(): ) -def test_list_tensorboard_time_series_pager(): - client = TensorboardServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_tensorboard_time_series_pager(transport_name: str = "grpc"): + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( @@ -6374,8 +6429,10 @@ def test_list_tensorboard_time_series_pager(): ) -def test_list_tensorboard_time_series_pages(): - client = TensorboardServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_tensorboard_time_series_pages(transport_name: str = "grpc"): + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( @@ -6513,10 +6570,10 @@ async def test_list_tensorboard_time_series_async_pages(): assert page_.raw_page.next_page_token == token -def test_delete_tensorboard_time_series( - transport: str = "grpc", - request_type=tensorboard_service.DeleteTensorboardTimeSeriesRequest, -): +@pytest.mark.parametrize( + "request_type", [tensorboard_service.DeleteTensorboardTimeSeriesRequest, dict,] +) +def test_delete_tensorboard_time_series(request_type, transport: str = "grpc"): client = TensorboardServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -6542,10 +6599,6 @@ def test_delete_tensorboard_time_series( assert isinstance(response, future.Future) -def test_delete_tensorboard_time_series_from_dict(): - test_delete_tensorboard_time_series(request_type=dict) - - def test_delete_tensorboard_time_series_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -6739,10 +6792,11 @@ async def test_delete_tensorboard_time_series_flattened_error_async(): ) -def test_batch_read_tensorboard_time_series_data( - transport: str = "grpc", - request_type=tensorboard_service.BatchReadTensorboardTimeSeriesDataRequest, -): +@pytest.mark.parametrize( + "request_type", + [tensorboard_service.BatchReadTensorboardTimeSeriesDataRequest, dict,], +) +def test_batch_read_tensorboard_time_series_data(request_type, transport: str = "grpc"): client = TensorboardServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -6774,10 +6828,6 @@ def test_batch_read_tensorboard_time_series_data( ) -def test_batch_read_tensorboard_time_series_data_from_dict(): - test_batch_read_tensorboard_time_series_data(request_type=dict) - - def test_batch_read_tensorboard_time_series_data_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -6987,10 +7037,10 @@ async def test_batch_read_tensorboard_time_series_data_flattened_error_async(): ) -def test_read_tensorboard_time_series_data( - transport: str = "grpc", - request_type=tensorboard_service.ReadTensorboardTimeSeriesDataRequest, -): +@pytest.mark.parametrize( + "request_type", [tensorboard_service.ReadTensorboardTimeSeriesDataRequest, dict,] +) +def test_read_tensorboard_time_series_data(request_type, transport: str = "grpc"): client = TensorboardServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -7018,10 +7068,6 @@ def test_read_tensorboard_time_series_data( ) -def test_read_tensorboard_time_series_data_from_dict(): - test_read_tensorboard_time_series_data(request_type=dict) - - def test_read_tensorboard_time_series_data_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -7229,10 +7275,10 @@ async def test_read_tensorboard_time_series_data_flattened_error_async(): ) -def test_read_tensorboard_blob_data( - transport: str = "grpc", - request_type=tensorboard_service.ReadTensorboardBlobDataRequest, -): +@pytest.mark.parametrize( + "request_type", [tensorboard_service.ReadTensorboardBlobDataRequest, dict,] +) +def test_read_tensorboard_blob_data(request_type, transport: str = "grpc"): client = TensorboardServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -7261,10 +7307,6 @@ def test_read_tensorboard_blob_data( assert isinstance(message, tensorboard_service.ReadTensorboardBlobDataResponse) -def test_read_tensorboard_blob_data_from_dict(): - test_read_tensorboard_blob_data(request_type=dict) - - def test_read_tensorboard_blob_data_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -7469,10 +7511,10 @@ async def test_read_tensorboard_blob_data_flattened_error_async(): ) -def test_write_tensorboard_experiment_data( - transport: str = "grpc", - request_type=tensorboard_service.WriteTensorboardExperimentDataRequest, -): +@pytest.mark.parametrize( + "request_type", [tensorboard_service.WriteTensorboardExperimentDataRequest, dict,] +) +def test_write_tensorboard_experiment_data(request_type, transport: str = "grpc"): client = TensorboardServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -7500,10 +7542,6 @@ def test_write_tensorboard_experiment_data( ) -def test_write_tensorboard_experiment_data_from_dict(): - test_write_tensorboard_experiment_data(request_type=dict) - - def test_write_tensorboard_experiment_data_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -7745,10 +7783,10 @@ async def test_write_tensorboard_experiment_data_flattened_error_async(): ) -def test_write_tensorboard_run_data( - transport: str = "grpc", - request_type=tensorboard_service.WriteTensorboardRunDataRequest, -): +@pytest.mark.parametrize( + "request_type", [tensorboard_service.WriteTensorboardRunDataRequest, dict,] +) +def test_write_tensorboard_run_data(request_type, transport: str = "grpc"): client = TensorboardServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -7774,10 +7812,6 @@ def test_write_tensorboard_run_data( assert isinstance(response, tensorboard_service.WriteTensorboardRunDataResponse) -def test_write_tensorboard_run_data_from_dict(): - test_write_tensorboard_run_data(request_type=dict) - - def test_write_tensorboard_run_data_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -8015,10 +8049,10 @@ async def test_write_tensorboard_run_data_flattened_error_async(): ) -def test_export_tensorboard_time_series_data( - transport: str = "grpc", - request_type=tensorboard_service.ExportTensorboardTimeSeriesDataRequest, -): +@pytest.mark.parametrize( + "request_type", [tensorboard_service.ExportTensorboardTimeSeriesDataRequest, dict,] +) +def test_export_tensorboard_time_series_data(request_type, transport: str = "grpc"): client = TensorboardServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -8047,10 +8081,6 @@ def test_export_tensorboard_time_series_data( assert response.next_page_token == "next_page_token_value" -def test_export_tensorboard_time_series_data_from_dict(): - test_export_tensorboard_time_series_data(request_type=dict) - - def test_export_tensorboard_time_series_data_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -8265,8 +8295,10 @@ async def test_export_tensorboard_time_series_data_flattened_error_async(): ) -def test_export_tensorboard_time_series_data_pager(): - client = TensorboardServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_export_tensorboard_time_series_data_pager(transport_name: str = "grpc"): + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( @@ -8313,8 +8345,10 @@ def test_export_tensorboard_time_series_data_pager(): assert all(isinstance(i, tensorboard_data.TimeSeriesDataPoint) for i in results) -def test_export_tensorboard_time_series_data_pages(): - client = TensorboardServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_export_tensorboard_time_series_data_pages(transport_name: str = "grpc"): + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( @@ -8465,6 +8499,23 @@ def test_credentials_transport_error(): transport=transport, ) + # It is an error to provide an api_key and a transport instance. + transport = transports.TensorboardServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + options = client_options.ClientOptions() + options.api_key = "api_key" + with pytest.raises(ValueError): + client = TensorboardServiceClient(client_options=options, transport=transport,) + + # It is an error to provide an api_key and a credential. + options = mock.Mock() + options.api_key = "api_key" + with pytest.raises(ValueError): + client = TensorboardServiceClient( + client_options=options, credentials=ga_credentials.AnonymousCredentials() + ) + # It is an error to provide scopes and a transport instance. 
transport = transports.TensorboardServiceGrpcTransport( credentials=ga_credentials.AnonymousCredentials(), @@ -9123,7 +9174,7 @@ def test_parse_common_location_path(): assert expected == actual -def test_client_withDEFAULT_CLIENT_INFO(): +def test_client_with_default_client_info(): client_info = gapic_v1.client_info.ClientInfo() with mock.patch.object( @@ -9188,3 +9239,36 @@ def test_client_ctx(): with client: pass close.assert_called() + + +@pytest.mark.parametrize( + "client_class,transport_class", + [ + (TensorboardServiceClient, transports.TensorboardServiceGrpcTransport), + ( + TensorboardServiceAsyncClient, + transports.TensorboardServiceGrpcAsyncIOTransport, + ), + ], +) +def test_api_key_credentials(client_class, transport_class): + with mock.patch.object( + google.auth._default, "get_api_key_credentials", create=True + ) as get_api_key_credentials: + mock_cred = mock.Mock() + get_api_key_credentials.return_value = mock_cred + options = client_options.ClientOptions() + options.api_key = "api_key" + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=mock_cred, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + ) diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_vizier_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_vizier_service.py index 06287c35d13..a0331ed3fc0 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_vizier_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_vizier_service.py @@ -29,6 +29,7 @@ from google.api_core import gapic_v1 from google.api_core import grpc_helpers from google.api_core import grpc_helpers_async +from google.api_core import operation from google.api_core import operation_async # type: ignore from 
google.api_core import operations_v1 from google.api_core import path_template @@ -257,20 +258,20 @@ def test_vizier_service_client_client_options( # unsupported value. with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): with pytest.raises(MutualTLSChannelError): - client = client_class() + client = client_class(transport=transport_name) # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. with mock.patch.dict( os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} ): with pytest.raises(ValueError): - client = client_class() + client = client_class(transport=transport_name) # Check the case quota_project_id is provided options = client_options.ClientOptions(quota_project_id="octopus") with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( credentials=None, credentials_file=None, @@ -329,7 +330,7 @@ def test_vizier_service_client_mtls_env_auto( ) with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) if use_client_cert_env == "false": expected_client_cert_source = None @@ -406,6 +407,87 @@ def test_vizier_service_client_mtls_env_auto( ) +@pytest.mark.parametrize( + "client_class", [VizierServiceClient, VizierServiceAsyncClient] +) +@mock.patch.object( + VizierServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(VizierServiceClient), +) +@mock.patch.object( + VizierServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(VizierServiceAsyncClient), +) +def test_vizier_service_client_get_mtls_endpoint_and_cert_source(client_class): + mock_client_cert_source = mock.Mock() + + # Test the case 
GOOGLE_API_USE_CLIENT_CERTIFICATE is "true". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + mock_api_endpoint = "foo" + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint + ) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source( + options + ) + assert api_endpoint == mock_api_endpoint + assert cert_source == mock_client_cert_source + + # Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is "false". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "false"}): + mock_client_cert_source = mock.Mock() + mock_api_endpoint = "foo" + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint + ) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source( + options + ) + assert api_endpoint == mock_api_endpoint + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "never". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "always". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert doesn't exist. 
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert exists. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=mock_client_cert_source, + ): + ( + api_endpoint, + cert_source, + ) = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + assert cert_source == mock_client_cert_source + + @pytest.mark.parametrize( "client_class,transport_class,transport_name", [ @@ -424,7 +506,7 @@ def test_vizier_service_client_client_options_scopes( options = client_options.ClientOptions(scopes=["1", "2"],) with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( credentials=None, credentials_file=None, @@ -438,24 +520,31 @@ def test_vizier_service_client_client_options_scopes( @pytest.mark.parametrize( - "client_class,transport_class,transport_name", + "client_class,transport_class,transport_name,grpc_helpers", [ - (VizierServiceClient, transports.VizierServiceGrpcTransport, "grpc"), + ( + VizierServiceClient, + transports.VizierServiceGrpcTransport, + "grpc", + grpc_helpers, + ), ( VizierServiceAsyncClient, transports.VizierServiceGrpcAsyncIOTransport, "grpc_asyncio", + grpc_helpers_async, ), ], ) def 
test_vizier_service_client_client_options_credentials_file( - client_class, transport_class, transport_name + client_class, transport_class, transport_name, grpc_helpers ): # Check the case credentials file is provided. options = client_options.ClientOptions(credentials_file="credentials.json") + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(transport=transport_name, client_options=options) + client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( credentials=None, credentials_file="credentials.json", @@ -467,6 +556,35 @@ def test_vizier_service_client_client_options_credentials_file( always_use_jwt_access=True, ) + if "grpc" in transport_name: + # test that the credentials from file are saved and used as the credentials. + with mock.patch.object( + google.auth, "load_credentials_from_file", autospec=True + ) as load_creds, mock.patch.object( + google.auth, "default", autospec=True + ) as adc, mock.patch.object( + grpc_helpers, "create_channel" + ) as create_channel: + creds = ga_credentials.AnonymousCredentials() + file_creds = ga_credentials.AnonymousCredentials() + load_creds.return_value = (file_creds, None) + adc.return_value = (creds, None) + client = client_class(client_options=options, transport=transport_name) + create_channel.assert_called_with( + "aiplatform.googleapis.com:443", + credentials=file_creds, + credentials_file=None, + quota_project_id=None, + default_scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=None, + default_host="aiplatform.googleapis.com", + ssl_credentials=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + def test_vizier_service_client_client_options_from_dict(): with mock.patch( @@ -488,9 +606,8 @@ def test_vizier_service_client_client_options_from_dict(): ) -def test_create_study( - transport: str = "grpc", 
request_type=vizier_service.CreateStudyRequest -): +@pytest.mark.parametrize("request_type", [vizier_service.CreateStudyRequest, dict,]) +def test_create_study(request_type, transport: str = "grpc"): client = VizierServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -523,10 +640,6 @@ def test_create_study( assert response.inactive_reason == "inactive_reason_value" -def test_create_study_from_dict(): - test_create_study(request_type=dict) - - def test_create_study_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -720,9 +833,8 @@ async def test_create_study_flattened_error_async(): ) -def test_get_study( - transport: str = "grpc", request_type=vizier_service.GetStudyRequest -): +@pytest.mark.parametrize("request_type", [vizier_service.GetStudyRequest, dict,]) +def test_get_study(request_type, transport: str = "grpc"): client = VizierServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -755,10 +867,6 @@ def test_get_study( assert response.inactive_reason == "inactive_reason_value" -def test_get_study_from_dict(): - test_get_study(request_type=dict) - - def test_get_study_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -938,9 +1046,8 @@ async def test_get_study_flattened_error_async(): ) -def test_list_studies( - transport: str = "grpc", request_type=vizier_service.ListStudiesRequest -): +@pytest.mark.parametrize("request_type", [vizier_service.ListStudiesRequest, dict,]) +def test_list_studies(request_type, transport: str = "grpc"): client = VizierServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -967,10 +1074,6 @@ def test_list_studies( assert response.next_page_token == "next_page_token_value" -def test_list_studies_from_dict(): - test_list_studies(request_type=dict) - - def test_list_studies_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -1146,8 +1249,10 @@ async def test_list_studies_flattened_error_async(): ) -def test_list_studies_pager(): - client = VizierServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_studies_pager(transport_name: str = "grpc"): + client = VizierServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_studies), "__call__") as call: @@ -1180,8 +1285,10 @@ def test_list_studies_pager(): assert all(isinstance(i, study.Study) for i in results) -def test_list_studies_pages(): - client = VizierServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_studies_pages(transport_name: str = "grpc"): + client = VizierServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. 
with mock.patch.object(type(client.transport.list_studies), "__call__") as call: @@ -1268,9 +1375,8 @@ async def test_list_studies_async_pages(): assert page_.raw_page.next_page_token == token -def test_delete_study( - transport: str = "grpc", request_type=vizier_service.DeleteStudyRequest -): +@pytest.mark.parametrize("request_type", [vizier_service.DeleteStudyRequest, dict,]) +def test_delete_study(request_type, transport: str = "grpc"): client = VizierServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -1294,10 +1400,6 @@ def test_delete_study( assert response is None -def test_delete_study_from_dict(): - test_delete_study(request_type=dict) - - def test_delete_study_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -1466,9 +1568,8 @@ async def test_delete_study_flattened_error_async(): ) -def test_lookup_study( - transport: str = "grpc", request_type=vizier_service.LookupStudyRequest -): +@pytest.mark.parametrize("request_type", [vizier_service.LookupStudyRequest, dict,]) +def test_lookup_study(request_type, transport: str = "grpc"): client = VizierServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -1501,10 +1602,6 @@ def test_lookup_study( assert response.inactive_reason == "inactive_reason_value" -def test_lookup_study_from_dict(): - test_lookup_study(request_type=dict) - - def test_lookup_study_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -1684,9 +1781,8 @@ async def test_lookup_study_flattened_error_async(): ) -def test_suggest_trials( - transport: str = "grpc", request_type=vizier_service.SuggestTrialsRequest -): +@pytest.mark.parametrize("request_type", [vizier_service.SuggestTrialsRequest, dict,]) +def test_suggest_trials(request_type, transport: str = "grpc"): client = VizierServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -1710,10 +1806,6 @@ def test_suggest_trials( assert isinstance(response, future.Future) -def test_suggest_trials_from_dict(): - test_suggest_trials(request_type=dict) - - def test_suggest_trials_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -1816,9 +1908,8 @@ async def test_suggest_trials_field_headers_async(): assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] -def test_create_trial( - transport: str = "grpc", request_type=vizier_service.CreateTrialRequest -): +@pytest.mark.parametrize("request_type", [vizier_service.CreateTrialRequest, dict,]) +def test_create_trial(request_type, transport: str = "grpc"): client = VizierServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -1855,10 +1946,6 @@ def test_create_trial( assert response.custom_job == "custom_job_value" -def test_create_trial_from_dict(): - test_create_trial(request_type=dict) - - def test_create_trial_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -2056,9 +2143,8 @@ async def test_create_trial_flattened_error_async(): ) -def test_get_trial( - transport: str = "grpc", request_type=vizier_service.GetTrialRequest -): +@pytest.mark.parametrize("request_type", [vizier_service.GetTrialRequest, dict,]) +def test_get_trial(request_type, transport: str = "grpc"): client = VizierServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -2095,10 +2181,6 @@ def test_get_trial( assert response.custom_job == "custom_job_value" -def test_get_trial_from_dict(): - test_get_trial(request_type=dict) - - def test_get_trial_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -2282,9 +2364,8 @@ async def test_get_trial_flattened_error_async(): ) -def test_list_trials( - transport: str = "grpc", request_type=vizier_service.ListTrialsRequest -): +@pytest.mark.parametrize("request_type", [vizier_service.ListTrialsRequest, dict,]) +def test_list_trials(request_type, transport: str = "grpc"): client = VizierServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -2311,10 +2392,6 @@ def test_list_trials( assert response.next_page_token == "next_page_token_value" -def test_list_trials_from_dict(): - test_list_trials(request_type=dict) - - def test_list_trials_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -2490,8 +2567,10 @@ async def test_list_trials_flattened_error_async(): ) -def test_list_trials_pager(): - client = VizierServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_trials_pager(transport_name: str = "grpc"): + client = VizierServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. 
with mock.patch.object(type(client.transport.list_trials), "__call__") as call: @@ -2522,8 +2601,10 @@ def test_list_trials_pager(): assert all(isinstance(i, study.Trial) for i in results) -def test_list_trials_pages(): - client = VizierServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_list_trials_pages(transport_name: str = "grpc"): + client = VizierServiceClient( + credentials=ga_credentials.AnonymousCredentials, transport=transport_name, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_trials), "__call__") as call: @@ -2604,9 +2685,10 @@ async def test_list_trials_async_pages(): assert page_.raw_page.next_page_token == token -def test_add_trial_measurement( - transport: str = "grpc", request_type=vizier_service.AddTrialMeasurementRequest -): +@pytest.mark.parametrize( + "request_type", [vizier_service.AddTrialMeasurementRequest, dict,] +) +def test_add_trial_measurement(request_type, transport: str = "grpc"): client = VizierServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -2645,10 +2727,6 @@ def test_add_trial_measurement( assert response.custom_job == "custom_job_value" -def test_add_trial_measurement_from_dict(): - test_add_trial_measurement(request_type=dict) - - def test_add_trial_measurement_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -2771,9 +2849,8 @@ async def test_add_trial_measurement_field_headers_async(): assert ("x-goog-request-params", "trial_name=trial_name/value",) in kw["metadata"] -def test_complete_trial( - transport: str = "grpc", request_type=vizier_service.CompleteTrialRequest -): +@pytest.mark.parametrize("request_type", [vizier_service.CompleteTrialRequest, dict,]) +def test_complete_trial(request_type, transport: str = "grpc"): client = VizierServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -2810,10 +2887,6 @@ def test_complete_trial( assert response.custom_job == "custom_job_value" -def test_complete_trial_from_dict(): - test_complete_trial(request_type=dict) - - def test_complete_trial_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -2927,9 +3000,8 @@ async def test_complete_trial_field_headers_async(): assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] -def test_delete_trial( - transport: str = "grpc", request_type=vizier_service.DeleteTrialRequest -): +@pytest.mark.parametrize("request_type", [vizier_service.DeleteTrialRequest, dict,]) +def test_delete_trial(request_type, transport: str = "grpc"): client = VizierServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -2953,10 +3025,6 @@ def test_delete_trial( assert response is None -def test_delete_trial_from_dict(): - test_delete_trial(request_type=dict) - - def test_delete_trial_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -3125,10 +3193,10 @@ async def test_delete_trial_flattened_error_async(): ) -def test_check_trial_early_stopping_state( - transport: str = "grpc", - request_type=vizier_service.CheckTrialEarlyStoppingStateRequest, -): +@pytest.mark.parametrize( + "request_type", [vizier_service.CheckTrialEarlyStoppingStateRequest, dict,] +) +def test_check_trial_early_stopping_state(request_type, transport: str = "grpc"): client = VizierServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -3154,10 +3222,6 @@ def test_check_trial_early_stopping_state( assert isinstance(response, future.Future) -def test_check_trial_early_stopping_state_from_dict(): - test_check_trial_early_stopping_state(request_type=dict) - - def test_check_trial_early_stopping_state_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -3269,9 +3333,8 @@ async def test_check_trial_early_stopping_state_field_headers_async(): assert ("x-goog-request-params", "trial_name=trial_name/value",) in kw["metadata"] -def test_stop_trial( - transport: str = "grpc", request_type=vizier_service.StopTrialRequest -): +@pytest.mark.parametrize("request_type", [vizier_service.StopTrialRequest, dict,]) +def test_stop_trial(request_type, transport: str = "grpc"): client = VizierServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -3308,10 +3371,6 @@ def test_stop_trial( assert response.custom_job == "custom_job_value" -def test_stop_trial_from_dict(): - test_stop_trial(request_type=dict) - - def test_stop_trial_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
@@ -3425,9 +3484,10 @@ async def test_stop_trial_field_headers_async(): assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] -def test_list_optimal_trials( - transport: str = "grpc", request_type=vizier_service.ListOptimalTrialsRequest -): +@pytest.mark.parametrize( + "request_type", [vizier_service.ListOptimalTrialsRequest, dict,] +) +def test_list_optimal_trials(request_type, transport: str = "grpc"): client = VizierServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -3453,10 +3513,6 @@ def test_list_optimal_trials( assert isinstance(response, vizier_service.ListOptimalTrialsResponse) -def test_list_optimal_trials_from_dict(): - test_list_optimal_trials(request_type=dict) - - def test_list_optimal_trials_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @@ -3664,6 +3720,23 @@ def test_credentials_transport_error(): transport=transport, ) + # It is an error to provide an api_key and a transport instance. + transport = transports.VizierServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + options = client_options.ClientOptions() + options.api_key = "api_key" + with pytest.raises(ValueError): + client = VizierServiceClient(client_options=options, transport=transport,) + + # It is an error to provide an api_key and a credential. + options = mock.Mock() + options.api_key = "api_key" + with pytest.raises(ValueError): + client = VizierServiceClient( + client_options=options, credentials=ga_credentials.AnonymousCredentials() + ) + # It is an error to provide scopes and a transport instance. 
transport = transports.VizierServiceGrpcTransport( credentials=ga_credentials.AnonymousCredentials(), @@ -4251,7 +4324,7 @@ def test_parse_common_location_path(): assert expected == actual -def test_client_withDEFAULT_CLIENT_INFO(): +def test_client_with_default_client_info(): client_info = gapic_v1.client_info.ClientInfo() with mock.patch.object( @@ -4316,3 +4389,33 @@ def test_client_ctx(): with client: pass close.assert_called() + + +@pytest.mark.parametrize( + "client_class,transport_class", + [ + (VizierServiceClient, transports.VizierServiceGrpcTransport), + (VizierServiceAsyncClient, transports.VizierServiceGrpcAsyncIOTransport), + ], +) +def test_api_key_credentials(client_class, transport_class): + with mock.patch.object( + google.auth._default, "get_api_key_credentials", create=True + ) as get_api_key_credentials: + mock_cred = mock.Mock() + get_api_key_credentials.return_value = mock_cred + options = client_options.ClientOptions() + options.api_key = "api_key" + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=mock_cred, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + ) From 7a7f0d45f3d08c93b11fcd2c5a265a8db4b0c890 Mon Sep 17 00:00:00 2001 From: "gcf-owl-bot[bot]" <78513119+gcf-owl-bot[bot]@users.noreply.github.com> Date: Mon, 31 Jan 2022 15:15:34 -0800 Subject: [PATCH 6/6] feat: add dedicated_resources to DeployedIndex in aiplatform v1beta1 index_endpoint.proto feat: add Scaling to OnlineServingConfig in aiplatform v1beta1 featurestore.proto chore: sort imports (#991) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat: add dedicated_resources to DeployedIndex in aiplatform v1beta1 index_endpoint.proto 
feat: add Scaling to OnlineServingConfig in aiplatform v1beta1 featurestore.proto chore: sort imports PiperOrigin-RevId: 425395202 Source-Link: https://github.com/googleapis/googleapis/commit/e3bcc1ee4f7bdeeaad2c947dea7f4388b7d864e7 Source-Link: https://github.com/googleapis/googleapis-gen/commit/62beef78559c8bab47ecd36eed1aa6b678db6088 Copy-Tag: eyJwIjoiLmdpdGh1Yi8uT3dsQm90LnlhbWwiLCJoIjoiNjJiZWVmNzg1NTljOGJhYjQ3ZWNkMzZlZWQxYWE2YjY3OGRiNjA4OCJ9 * 🦉 Updates from OwlBot See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md Co-authored-by: Owl Bot Co-authored-by: Yu-Han Liu --- .../services/vizier_service/async_client.py | 6 ---- .../services/vizier_service/client.py | 6 ---- .../aiplatform_v1beta1/types/custom_job.py | 1 + .../aiplatform_v1beta1/types/featurestore.py | 34 ++++++++++++++++--- .../types/index_endpoint.py | 12 +++++++ .../cloud/aiplatform_v1beta1/types/model.py | 10 +++--- .../types/model_deployment_monitoring_job.py | 5 +-- .../cloud/aiplatform_v1beta1/types/study.py | 3 +- .../test_index_endpoint_service.py | 1 + 9 files changed, 53 insertions(+), 25 deletions(-) diff --git a/google/cloud/aiplatform_v1beta1/services/vizier_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/vizier_service/async_client.py index 4f4ff19fb18..c55267c36f1 100644 --- a/google/cloud/aiplatform_v1beta1/services/vizier_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/vizier_service/async_client.py @@ -255,9 +255,7 @@ async def create_study( Returns: google.cloud.aiplatform_v1beta1.types.Study: - LINT.IfChange A message representing a Study. - """ # Create or coerce a protobuf request object. # Quick check: If we got a request object, we should *not* have @@ -328,9 +326,7 @@ async def get_study( Returns: google.cloud.aiplatform_v1beta1.types.Study: - LINT.IfChange A message representing a Study. - """ # Create or coerce a protobuf request object. 
# Quick check: If we got a request object, we should *not* have @@ -548,9 +544,7 @@ async def lookup_study( Returns: google.cloud.aiplatform_v1beta1.types.Study: - LINT.IfChange A message representing a Study. - """ # Create or coerce a protobuf request object. # Quick check: If we got a request object, we should *not* have diff --git a/google/cloud/aiplatform_v1beta1/services/vizier_service/client.py b/google/cloud/aiplatform_v1beta1/services/vizier_service/client.py index f4b0dc15046..0d8d1b9700e 100644 --- a/google/cloud/aiplatform_v1beta1/services/vizier_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/vizier_service/client.py @@ -478,9 +478,7 @@ def create_study( Returns: google.cloud.aiplatform_v1beta1.types.Study: - LINT.IfChange A message representing a Study. - """ # Create or coerce a protobuf request object. # Quick check: If we got a request object, we should *not* have @@ -551,9 +549,7 @@ def get_study( Returns: google.cloud.aiplatform_v1beta1.types.Study: - LINT.IfChange A message representing a Study. - """ # Create or coerce a protobuf request object. # Quick check: If we got a request object, we should *not* have @@ -771,9 +767,7 @@ def lookup_study( Returns: google.cloud.aiplatform_v1beta1.types.Study: - LINT.IfChange A message representing a Study. - """ # Create or coerce a protobuf request object. # Quick check: If we got a request object, we should *not* have diff --git a/google/cloud/aiplatform_v1beta1/types/custom_job.py b/google/cloud/aiplatform_v1beta1/types/custom_job.py index 7ec2a92c038..7d7edd9e37e 100644 --- a/google/cloud/aiplatform_v1beta1/types/custom_job.py +++ b/google/cloud/aiplatform_v1beta1/types/custom_job.py @@ -120,6 +120,7 @@ class CustomJob(proto.Message): class CustomJobSpec(proto.Message): r"""Represents the spec of a CustomJob. 
+ Next Id: 14 Attributes: worker_pool_specs (Sequence[google.cloud.aiplatform_v1beta1.types.WorkerPoolSpec]): diff --git a/google/cloud/aiplatform_v1beta1/types/featurestore.py b/google/cloud/aiplatform_v1beta1/types/featurestore.py index 5d19c1dee71..f28059c6e9b 100644 --- a/google/cloud/aiplatform_v1beta1/types/featurestore.py +++ b/google/cloud/aiplatform_v1beta1/types/featurestore.py @@ -81,13 +81,39 @@ class OnlineServingConfig(proto.Message): Attributes: fixed_node_count (int): - The number of nodes for each cluster. The - number of nodes will not scale automatically but - can be scaled manually by providing different - values when updating. + The number of nodes for each cluster. The number of nodes + will not scale automatically but can be scaled manually by + providing different values when updating. Only one of + ``fixed_node_count`` and ``scaling`` can be set. Setting one + will reset the other. + scaling (google.cloud.aiplatform_v1beta1.types.Featurestore.OnlineServingConfig.Scaling): + Online serving scaling configuration. Only one of + ``fixed_node_count`` and ``scaling`` can be set. Setting one + will reset the other. """ + class Scaling(proto.Message): + r"""Online serving scaling configuration. If min_node_count and + max_node_count are set to the same value, the cluster will be + configured with the fixed number of node (no auto-scaling). + + Attributes: + min_node_count (int): + Required. The minimum number of nodes to + scale down to. Must be greater than or equal to + 1. + max_node_count (int): + The maximum number of nodes to scale up to. Must be greater + or equal to min_node_count. 
+ """ + + min_node_count = proto.Field(proto.INT32, number=1,) + max_node_count = proto.Field(proto.INT32, number=2,) + fixed_node_count = proto.Field(proto.INT32, number=2,) + scaling = proto.Field( + proto.MESSAGE, number=4, message="Featurestore.OnlineServingConfig.Scaling", + ) name = proto.Field(proto.STRING, number=1,) create_time = proto.Field(proto.MESSAGE, number=3, message=timestamp_pb2.Timestamp,) diff --git a/google/cloud/aiplatform_v1beta1/types/index_endpoint.py b/google/cloud/aiplatform_v1beta1/types/index_endpoint.py index c4239573754..4bb17029db9 100644 --- a/google/cloud/aiplatform_v1beta1/types/index_endpoint.py +++ b/google/cloud/aiplatform_v1beta1/types/index_endpoint.py @@ -171,6 +171,15 @@ class DeployedIndex(proto.Message): don't provide SLA when min_replica_count=1). If max_replica_count is not set, the default value is min_replica_count. The max allowed replica count is 1000. + dedicated_resources (google.cloud.aiplatform_v1beta1.types.DedicatedResources): + Optional. A description of resources that are dedicated to + the DeployedIndex, and that need a higher degree of manual + configuration. If min_replica_count is not set, the default + value is 2 (we don't provide SLA when min_replica_count=1). + If max_replica_count is not set, the default value is + min_replica_count. The max allowed replica count is 1000. + + Available machine types: n1-standard-16 n1-standard-32 enable_access_logging (bool): Optional. If true, private endpoint's access logs are sent to StackDriver Logging. 
@@ -227,6 +236,9 @@ class DeployedIndex(proto.Message): automatic_resources = proto.Field( proto.MESSAGE, number=7, message=machine_resources.AutomaticResources, ) + dedicated_resources = proto.Field( + proto.MESSAGE, number=16, message=machine_resources.DedicatedResources, + ) enable_access_logging = proto.Field(proto.BOOL, number=8,) deployed_index_auth_config = proto.Field( proto.MESSAGE, number=9, message="DeployedIndexAuthConfig", diff --git a/google/cloud/aiplatform_v1beta1/types/model.py b/google/cloud/aiplatform_v1beta1/types/model.py index 4120486a5e2..f40e74a3b28 100644 --- a/google/cloud/aiplatform_v1beta1/types/model.py +++ b/google/cloud/aiplatform_v1beta1/types/model.py @@ -397,7 +397,7 @@ class ModelContainerSpec(proto.Message): r"""Specification of a container for serving predictions. Some fields in this message correspond to fields in the `Kubernetes Container v1 core - specification `__. + specification `__. Attributes: image_uri (str): @@ -463,7 +463,7 @@ class ModelContainerSpec(proto.Message): this syntax with ``$$``; for example: $$(VARIABLE_NAME) This field corresponds to the ``command`` field of the Kubernetes Containers `v1 core - API `__. + API `__. args (Sequence[str]): Immutable. Specifies arguments for the command that runs when the container starts. This overrides the container's @@ -502,7 +502,7 @@ class ModelContainerSpec(proto.Message): this syntax with ``$$``; for example: $$(VARIABLE_NAME) This field corresponds to the ``args`` field of the Kubernetes Containers `v1 core - API `__. + API `__. env (Sequence[google.cloud.aiplatform_v1beta1.types.EnvVar]): Immutable. List of environment variables to set in the container. After the container starts running, code running @@ -535,7 +535,7 @@ class ModelContainerSpec(proto.Message): This field corresponds to the ``env`` field of the Kubernetes Containers `v1 core - API `__. + API `__. ports (Sequence[google.cloud.aiplatform_v1beta1.types.Port]): Immutable. 
List of ports to expose from the container. Vertex AI sends any prediction requests that it receives to @@ -558,7 +558,7 @@ class ModelContainerSpec(proto.Message): Vertex AI does not use ports other than the first one listed. This field corresponds to the ``ports`` field of the Kubernetes Containers `v1 core - API `__. + API `__. predict_route (str): Immutable. HTTP path on the container to send prediction requests to. Vertex AI forwards requests sent using diff --git a/google/cloud/aiplatform_v1beta1/types/model_deployment_monitoring_job.py b/google/cloud/aiplatform_v1beta1/types/model_deployment_monitoring_job.py index 146fba7fd63..7fa47a9c66d 100644 --- a/google/cloud/aiplatform_v1beta1/types/model_deployment_monitoring_job.py +++ b/google/cloud/aiplatform_v1beta1/types/model_deployment_monitoring_job.py @@ -284,9 +284,10 @@ class ModelDeploymentMonitoringScheduleConfig(proto.Message): Attributes: monitor_interval (google.protobuf.duration_pb2.Duration): - Required. The model monitoring job running + Required. The model monitoring job scheduling interval. It will be rounded up to next full - hour. + hour. This defines how often the monitoring jobs + are triggered. """ monitor_interval = proto.Field( diff --git a/google/cloud/aiplatform_v1beta1/types/study.py b/google/cloud/aiplatform_v1beta1/types/study.py index 77032803f90..beccba62e96 100644 --- a/google/cloud/aiplatform_v1beta1/types/study.py +++ b/google/cloud/aiplatform_v1beta1/types/study.py @@ -27,8 +27,7 @@ class Study(proto.Message): - r"""LINT.IfChange - A message representing a Study. + r"""A message representing a Study. 
Attributes: name (str): diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_index_endpoint_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_index_endpoint_service.py index b8abbab31d1..c90d804f646 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_index_endpoint_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_index_endpoint_service.py @@ -43,6 +43,7 @@ ) from google.cloud.aiplatform_v1beta1.services.index_endpoint_service import pagers from google.cloud.aiplatform_v1beta1.services.index_endpoint_service import transports +from google.cloud.aiplatform_v1beta1.types import accelerator_type from google.cloud.aiplatform_v1beta1.types import index_endpoint from google.cloud.aiplatform_v1beta1.types import index_endpoint as gca_index_endpoint from google.cloud.aiplatform_v1beta1.types import index_endpoint_service