From 1645045b0c098a5012293761cbd2f645b0cf1eaf Mon Sep 17 00:00:00 2001 From: Vinny Senthil Date: Fri, 25 Jun 2021 12:20:24 -0500 Subject: [PATCH 1/4] fix: Set prediction client when listing Endpoints --- google/cloud/aiplatform/models.py | 39 +++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/google/cloud/aiplatform/models.py b/google/cloud/aiplatform/models.py index b287581431..76ff66d30c 100644 --- a/google/cloud/aiplatform/models.py +++ b/google/cloud/aiplatform/models.py @@ -324,6 +324,45 @@ def _create( credentials=credentials, ) + def _construct_sdk_resource_from_gapic( + self, + gapic_resource: proto.Message, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + ) -> "Endpoint": + """Given a GAPIC Endpoint object, return the SDK representation. + + Args: + gapic_resource (proto.Message): + A GAPIC representation of a Endpoint resource, usually + retrieved by a get_* or in a list_* API call. + project (str): + Optional. Project to construct Endpoint object from. If not set, + project set in aiplatform.init will be used. + location (str): + Optional. Location to construct Endpoint object from. If not set, + location set in aiplatform.init will be used. + credentials (auth_credentials.Credentials): + Optional. Custom credentials to use to construct Endpoint. + Overrides credentials set in aiplatform.init. + + Returns: + Endpoint: + An initialized Endpoint resource. + """ + endpoint = self._empty_constructor( + project=project, location=location, credentials=credentials + ) + + endpoint._gca_resource = gapic_resource + endpoint._prediction_client = self._instantiate_prediction_client( + location=location or initializer.global_config.location, + credentials=credentials, + ) + + return endpoint + @staticmethod def _allocate_traffic( traffic_split: Dict[str, int], traffic_percentage: int, From 7c4665f1f8d7db9e3fd56dc8eff34d09ec92ebea Mon Sep 17 00:00:00 2001 From: Vinny Senthil Date: Tue, 29 Jun 2021 13:17:49 -0500 Subject: [PATCH 2/4] Address reviewer comments --- google/cloud/aiplatform/base.py | 11 ++-- google/cloud/aiplatform/models.py | 21 ++++++- tests/unit/aiplatform/test_endpoints.py | 73 +++++++++++++++++++++---- 3 files changed, 86 insertions(+), 19 deletions(-) diff --git a/google/cloud/aiplatform/base.py b/google/cloud/aiplatform/base.py index 732e6b9acf..2df7d7234e 100644 --- a/google/cloud/aiplatform/base.py +++ b/google/cloud/aiplatform/base.py @@ -819,8 +819,9 @@ def _sync_object_with_future_result( if value: setattr(self, attribute, value) + @classmethod def _construct_sdk_resource_from_gapic( - self, + cls, gapic_resource: proto.Message, project: Optional[str] = None, location: Optional[str] = None, @@ -846,7 +847,7 @@ def _construct_sdk_resource_from_gapic( VertexAiResourceNoun: An initialized SDK object that represents GAPIC type. """ - sdk_resource = self._empty_constructor( + sdk_resource = cls._empty_constructor( project=project, location=location, credentials=credentials ) sdk_resource._gca_resource = gapic_resource @@ -894,14 +895,14 @@ def _list( Returns: List[VertexAiResourceNoun] - A list of SDK resource objects """ - self = cls._empty_constructor( + resource = cls._empty_constructor( project=project, location=location, credentials=credentials ) # Fetch credentials once and re-use for all `_empty_constructor()` calls creds = initializer.global_config.credentials - resource_list_method = getattr(self.api_client, self._list_method) + resource_list_method = getattr(resource.api_client, resource._list_method) list_request = { "parent": initializer.global_config.common_location_path( @@ -916,7 +917,7 @@ def _list( resource_list = resource_list_method(request=list_request) or [] return [ - self._construct_sdk_resource_from_gapic( + cls._construct_sdk_resource_from_gapic( gapic_resource, project=project, location=location, credentials=creds ) for gapic_resource in resource_list diff --git a/google/cloud/aiplatform/models.py b/google/cloud/aiplatform/models.py index 76ff66d30c..129efe8415 100644 --- a/google/cloud/aiplatform/models.py +++ b/google/cloud/aiplatform/models.py @@ -116,6 +116,11 @@ def __init__( resource_name=endpoint_name, ) self._gca_resource = self._get_gca_resource(resource_name=endpoint_name) + + project, location = self._get_and_validate_project_location( + resource_name=self._gca_resource.name, project=project, location=location + ) + self._prediction_client = self._instantiate_prediction_client( location=location or initializer.global_config.location, credentials=credentials, @@ -324,8 +329,9 @@ def _create( credentials=credentials, ) + @classmethod def _construct_sdk_resource_from_gapic( - self, + cls, gapic_resource: proto.Message, project: Optional[str] = None, location: Optional[str] = None, @@ -351,12 +357,21 @@ def _construct_sdk_resource_from_gapic( Endpoint: An initialized Endpoint resource. """ - endpoint = self._empty_constructor( + endpoint = cls._empty_constructor( project=project, location=location, credentials=credentials ) endpoint._gca_resource = gapic_resource - endpoint._prediction_client = self._instantiate_prediction_client( + + project, location = endpoint._get_and_validate_project_location( + resource_name=gapic_resource.name, project=project, location=location + ) + + endpoint.project = project or initializer.global_config.project + endpoint.location = location or initializer.global_config.location + endpoint.credentials = credentials or initializer.global_config.credentials + + endpoint._prediction_client = cls._instantiate_prediction_client( location=location or initializer.global_config.location, credentials=credentials, ) diff --git a/tests/unit/aiplatform/test_endpoints.py b/tests/unit/aiplatform/test_endpoints.py index 03c3f38667..5303d9d803 100644 --- a/tests/unit/aiplatform/test_endpoints.py +++ b/tests/unit/aiplatform/test_endpoints.py @@ -66,7 +66,6 @@ _TEST_LOCATION = "us-central1" _TEST_LOCATION_2 = "europe-west4" -_TEST_ENDPOINT_NAME = "test-endpoint" _TEST_DISPLAY_NAME = "test-display-name" _TEST_DISPLAY_NAME_2 = "test-display-name-2" _TEST_ID = "1028944691210842416" @@ -76,6 +75,9 @@ _TEST_ENDPOINT_NAME = ( f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/endpoints/{_TEST_ID}" ) +_TEST_ENDPOINT_NAME_ALT_LOCATION = ( + f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION_2}/endpoints/{_TEST_ID}" +) _TEST_PARENT = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}" _TEST_MODEL_NAME = ( f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/models/{_TEST_ID}" @@ -139,6 +141,9 @@ kms_key_name=_TEST_ENCRYPTION_KEY_NAME ) +_TEST_ENDPOINT_GAPIC = gca_endpoint.Endpoint( + display_name=_TEST_DISPLAY_NAME, name=_TEST_ENDPOINT_NAME +) _TEST_ENDPOINT_LIST = [ gca_endpoint.Endpoint( @@ -170,6 +175,19 @@ def get_endpoint_mock(): yield get_endpoint_mock +@pytest.fixture +def get_endpoint_alt_location_mock(): + with mock.patch.object( + endpoint_service_client.EndpointServiceClient, "get_endpoint" + ) as get_endpoint_mock: + get_endpoint_mock.return_value = gca_endpoint.Endpoint( + display_name=_TEST_DISPLAY_NAME, + name=_TEST_ENDPOINT_NAME_ALT_LOCATION, + encryption_spec=_TEST_ENCRYPTION_SPEC, + ) + yield get_endpoint_mock + + @pytest.fixture def get_endpoint_with_models_mock(): with mock.patch.object( @@ -293,14 +311,16 @@ def list_endpoints_mock(): @pytest.fixture -def create_client_mock(): +def create_endpoint_client_mock(): with mock.patch.object( initializer.global_config, "create_client", autospec=True, - ) as create_client_mock: - create_client_mock.return_value = mock.Mock( + ) as create_endpoint_client_mock: + endpoint_client_mock = mock.Mock( spec=endpoint_service_client.EndpointServiceClient ) - yield create_client_mock + endpoint_client_mock.get_endpoint.return_value = _TEST_ENDPOINT_GAPIC + create_endpoint_client_mock.return_value = endpoint_client_mock + yield create_endpoint_client_mock @pytest.fixture @@ -340,14 +360,14 @@ def setup_method(self): def teardown_method(self): initializer.global_pool.shutdown(wait=True) - def test_constructor(self, create_client_mock): + def test_constructor(self, create_endpoint_client_mock): aiplatform.init( project=_TEST_PROJECT, location=_TEST_LOCATION, credentials=_TEST_CREDENTIALS, ) models.Endpoint(_TEST_ENDPOINT_NAME) - create_client_mock.assert_has_calls( + create_endpoint_client_mock.assert_has_calls( [ mock.call( client_class=utils.EndpointClientWithOverride, @@ -382,20 +402,34 @@ def test_constructor_with_custom_project(self, get_endpoint_mock): ) get_endpoint_mock.assert_called_with(name=test_endpoint_resource_name) - def test_constructor_with_custom_location(self, get_endpoint_mock): + @pytest.mark.usefixtures("get_endpoint_mock") + def test_constructor_with_conflicting_location(self): + """get_endpoint_mock returns resource name with `_TEST_LOCATION` instead of `_TEST_LOCATION_2`""" + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + with pytest.raises(RuntimeError) as err: + models.Endpoint(endpoint_name=_TEST_ID, location=_TEST_LOCATION_2) + + assert err.match( + regexp=r"is provided, but different from the resource location" + ) + + def test_constructor_with_custom_location(self, get_endpoint_alt_location_mock): aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) models.Endpoint(endpoint_name=_TEST_ID, location=_TEST_LOCATION_2) test_endpoint_resource_name = endpoint_service_client.EndpointServiceClient.endpoint_path( _TEST_PROJECT, _TEST_LOCATION_2, _TEST_ID ) - get_endpoint_mock.assert_called_with(name=test_endpoint_resource_name) + get_endpoint_alt_location_mock.assert_called_with( + name=test_endpoint_resource_name + ) - def test_constructor_with_custom_credentials(self, create_client_mock): + def test_constructor_with_custom_credentials(self, create_endpoint_client_mock): aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) creds = auth_credentials.AnonymousCredentials() models.Endpoint(_TEST_ENDPOINT_NAME, credentials=creds) - create_client_mock.assert_has_calls( + create_endpoint_client_mock.assert_has_calls( [ mock.call( client_class=utils.EndpointClientWithOverride, @@ -1005,6 +1039,23 @@ def test_undeploy_all(self, sdk_private_undeploy_mock, sync): any_order=True, ) + @pytest.mark.usefixtures("list_endpoints_mock") + def test_list_endpoint_has_prediction_client(self): + """Test call to Endpoint.list() and ensure Endpoints have prediction client set""" + ep_list = aiplatform.Endpoint.list(order_by=_TEST_LIST_ORDER_BY_CREATE_TIME) + + assert ep_list # Ensure list is not empty + + # Confirm every Endpoint object in the list has a prediction client + assert all( + [ + isinstance( + e._prediction_client, aiplatform.utils.PredictionClientWithOverride + ) + for e in ep_list + ] + ) + def test_list_endpoint_order_by_time(self, list_endpoints_mock): """Test call to Endpoint.list() and ensure list is returned in descending order of create_time""" From 69b80f508d5ebc9c40d1290ec4d9743e92befc07 Mon Sep 17 00:00:00 2001 From: Vinny Senthil Date: Tue, 29 Jun 2021 13:21:23 -0500 Subject: [PATCH 3/4] Remove redundant init() in TestEndpoints --- tests/unit/aiplatform/test_endpoints.py | 32 ------------------------- 1 file changed, 32 deletions(-) diff --git a/tests/unit/aiplatform/test_endpoints.py b/tests/unit/aiplatform/test_endpoints.py index 5303d9d803..1d683715da 100644 --- a/tests/unit/aiplatform/test_endpoints.py +++ b/tests/unit/aiplatform/test_endpoints.py @@ -385,17 +385,14 @@ def test_constructor(self, create_endpoint_client_mock): ) def test_constructor_with_endpoint_id(self, get_endpoint_mock): - aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) models.Endpoint(_TEST_ID) get_endpoint_mock.assert_called_with(name=_TEST_ENDPOINT_NAME) def test_constructor_with_endpoint_name(self, get_endpoint_mock): - aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) models.Endpoint(_TEST_ENDPOINT_NAME) get_endpoint_mock.assert_called_with(name=_TEST_ENDPOINT_NAME) def test_constructor_with_custom_project(self, get_endpoint_mock): - aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) models.Endpoint(endpoint_name=_TEST_ID, project=_TEST_PROJECT_2) test_endpoint_resource_name = endpoint_service_client.EndpointServiceClient.endpoint_path( _TEST_PROJECT_2, _TEST_LOCATION, _TEST_ID @@ -405,7 +402,6 @@ def test_constructor_with_custom_project(self, get_endpoint_mock): @pytest.mark.usefixtures("get_endpoint_mock") def test_constructor_with_conflicting_location(self): """get_endpoint_mock returns resource name with `_TEST_LOCATION` instead of `_TEST_LOCATION_2`""" - aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) with pytest.raises(RuntimeError) as err: models.Endpoint(endpoint_name=_TEST_ID, location=_TEST_LOCATION_2) @@ -415,7 +411,6 @@ def test_constructor_with_conflicting_location(self): ) def test_constructor_with_custom_location(self, get_endpoint_alt_location_mock): - aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) models.Endpoint(endpoint_name=_TEST_ID, location=_TEST_LOCATION_2) test_endpoint_resource_name = endpoint_service_client.EndpointServiceClient.endpoint_path( _TEST_PROJECT, _TEST_LOCATION_2, _TEST_ID @@ -425,7 +420,6 @@ def test_constructor_with_custom_location(self, get_endpoint_alt_location_mock): ) def test_constructor_with_custom_credentials(self, create_endpoint_client_mock): - aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) creds = auth_credentials.AnonymousCredentials() models.Endpoint(_TEST_ENDPOINT_NAME, credentials=creds) @@ -474,7 +468,6 @@ def test_init_aiplatform_with_encryption_key_name_and_create_endpoint( @pytest.mark.usefixtures("get_endpoint_mock") @pytest.mark.parametrize("sync", [True, False]) def test_create(self, create_endpoint_mock, sync): - aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) my_endpoint = models.Endpoint.create( display_name=_TEST_DISPLAY_NAME, encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME, @@ -497,7 +490,6 @@ def test_create(self, create_endpoint_mock, sync): @pytest.mark.usefixtures("get_endpoint_mock") @pytest.mark.parametrize("sync", [True, False]) def test_create_with_description(self, create_endpoint_mock, sync): - aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) my_endpoint = models.Endpoint.create( display_name=_TEST_DISPLAY_NAME, description=_TEST_DESCRIPTION, sync=sync ) @@ -514,7 +506,6 @@ def test_create_with_description(self, create_endpoint_mock, sync): @pytest.mark.usefixtures("get_endpoint_mock", "get_model_mock") @pytest.mark.parametrize("sync", [True, False]) def test_deploy(self, deploy_model_mock, sync): - aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME) test_model = models.Model(_TEST_ID) test_endpoint.deploy(test_model, sync=sync) @@ -540,7 +531,6 @@ def test_deploy(self, deploy_model_mock, sync): @pytest.mark.usefixtures("get_endpoint_mock", "get_model_mock") @pytest.mark.parametrize("sync", [True, False]) def test_deploy_with_display_name(self, deploy_model_mock, sync): - aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME) test_model = models.Model(_TEST_ID) test_endpoint.deploy( @@ -569,7 +559,6 @@ def test_deploy_with_display_name(self, deploy_model_mock, sync): @pytest.mark.parametrize("sync", [True, False]) def test_deploy_raise_error_traffic_80(self, sync): with pytest.raises(ValueError): - aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME) test_model = models.Model(_TEST_ID) test_endpoint.deploy(model=test_model, traffic_percentage=80, sync=sync) @@ -581,7 +570,6 @@ def test_deploy_raise_error_traffic_80(self, sync): @pytest.mark.parametrize("sync", [True, False]) def test_deploy_raise_error_traffic_120(self, sync): with pytest.raises(ValueError): - aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME) test_model = models.Model(_TEST_ID) test_endpoint.deploy(model=test_model, traffic_percentage=120, sync=sync) @@ -590,7 +578,6 @@ def test_deploy_raise_error_traffic_120(self, sync): @pytest.mark.parametrize("sync", [True, False]) def test_deploy_raise_error_traffic_negative(self, sync): with pytest.raises(ValueError): - aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME) test_model = models.Model(_TEST_ID) test_endpoint.deploy(model=test_model, traffic_percentage=-18, sync=sync) @@ -599,7 +586,6 @@ def test_deploy_raise_error_traffic_negative(self, sync): @pytest.mark.parametrize("sync", [True, False]) def test_deploy_raise_error_min_replica(self, sync): with pytest.raises(ValueError): - aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME) test_model = models.Model(_TEST_ID) test_endpoint.deploy(model=test_model, min_replica_count=-1, sync=sync) @@ -608,7 +594,6 @@ def test_deploy_raise_error_min_replica(self, sync): @pytest.mark.parametrize("sync", [True, False]) def test_deploy_raise_error_max_replica(self, sync): with pytest.raises(ValueError): - aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME) test_model = models.Model(_TEST_ID) test_endpoint.deploy(model=test_model, max_replica_count=-2, sync=sync) @@ -617,7 +602,6 @@ def test_deploy_raise_error_max_replica(self, sync): @pytest.mark.parametrize("sync", [True, False]) def test_deploy_raise_error_traffic_split(self, sync): with pytest.raises(ValueError): - aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME) test_model = models.Model(_TEST_ID) test_endpoint.deploy(model=test_model, traffic_split={"a": 99}, sync=sync) @@ -625,7 +609,6 @@ def test_deploy_raise_error_traffic_split(self, sync): @pytest.mark.usefixtures("get_model_mock") @pytest.mark.parametrize("sync", [True, False]) def test_deploy_with_traffic_percent(self, deploy_model_mock, sync): - aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) with mock.patch.object( endpoint_service_client.EndpointServiceClient, "get_endpoint" ) as get_endpoint_mock: @@ -658,7 +641,6 @@ def test_deploy_with_traffic_percent(self, deploy_model_mock, sync): @pytest.mark.usefixtures("get_model_mock") @pytest.mark.parametrize("sync", [True, False]) def test_deploy_with_traffic_split(self, deploy_model_mock, sync): - aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) with mock.patch.object( endpoint_service_client.EndpointServiceClient, "get_endpoint" ) as get_endpoint_mock: @@ -694,7 +676,6 @@ def test_deploy_with_traffic_split(self, deploy_model_mock, sync): @pytest.mark.usefixtures("get_endpoint_mock", "get_model_mock") @pytest.mark.parametrize("sync", [True, False]) def test_deploy_with_dedicated_resources(self, deploy_model_mock, sync): - aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME) test_model = models.Model(_TEST_ID) test_endpoint.deploy( @@ -735,7 +716,6 @@ def test_deploy_with_dedicated_resources(self, deploy_model_mock, sync): @pytest.mark.usefixtures("get_endpoint_mock", "get_model_mock") @pytest.mark.parametrize("sync", [True, False]) def test_deploy_with_explanations(self, deploy_model_with_explanations_mock, sync): - aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME) test_model = models.Model(_TEST_ID) test_endpoint.deploy( @@ -780,7 +760,6 @@ def test_deploy_with_explanations(self, deploy_model_with_explanations_mock, syn @pytest.mark.usefixtures("get_endpoint_mock", "get_model_mock") @pytest.mark.parametrize("sync", [True, False]) def test_deploy_with_min_replica_count(self, deploy_model_mock, sync): - aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME) test_model = models.Model(_TEST_ID) test_endpoint.deploy(model=test_model, min_replica_count=2, sync=sync) @@ -805,7 +784,6 @@ def test_deploy_with_min_replica_count(self, deploy_model_mock, sync): @pytest.mark.usefixtures("get_endpoint_mock", "get_model_mock") @pytest.mark.parametrize("sync", [True, False]) def test_deploy_with_max_replica_count(self, deploy_model_mock, sync): - aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME) test_model = models.Model(_TEST_ID) test_endpoint.deploy(model=test_model, max_replica_count=2, sync=sync) @@ -894,7 +872,6 @@ def test_unallocate_traffic(self, model1, model2, model3, deployed_model): @pytest.mark.parametrize("sync", [True, False]) def test_undeploy(self, undeploy_model_mock, sync): - aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) with mock.patch.object( endpoint_service_client.EndpointServiceClient, "get_endpoint" ) as get_endpoint_mock: @@ -918,7 +895,6 @@ def test_undeploy(self, undeploy_model_mock, sync): @pytest.mark.parametrize("sync", [True, False]) def test_undeploy_with_traffic_split(self, undeploy_model_mock, sync): - aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) with mock.patch.object( endpoint_service_client.EndpointServiceClient, "get_endpoint" ) as get_endpoint_mock: @@ -948,7 +924,6 @@ def test_undeploy_with_traffic_split(self, undeploy_model_mock, sync): @pytest.mark.parametrize("sync", [True, False]) def test_undeploy_raise_error_traffic_split_total(self, sync): with pytest.raises(ValueError): - aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME) test_endpoint.undeploy( deployed_model_id="model1", traffic_split={"model2": 99}, sync=sync @@ -958,7 +933,6 @@ def test_undeploy_raise_error_traffic_split_total(self, sync): @pytest.mark.parametrize("sync", [True, False]) def test_undeploy_raise_error_undeployed_model_traffic(self, sync): with pytest.raises(ValueError): - aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME) test_endpoint.undeploy( deployed_model_id="model1", @@ -967,7 +941,6 @@ def test_undeploy_raise_error_undeployed_model_traffic(self, sync): ) def test_predict(self, get_endpoint_mock, predict_client_predict_mock): - aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) test_endpoint = models.Endpoint(_TEST_ID) test_prediction = test_endpoint.predict( @@ -986,7 +959,6 @@ def test_predict(self, get_endpoint_mock, predict_client_predict_mock): ) def test_explain(self, get_endpoint_mock, predict_client_explain_mock): - aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) test_endpoint = models.Endpoint(_TEST_ID) test_prediction = test_endpoint.explain( @@ -1012,7 +984,6 @@ def test_explain(self, get_endpoint_mock, predict_client_explain_mock): ) def test_list_models(self, get_endpoint_with_models_mock): - aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) ept = aiplatform.Endpoint(_TEST_ID) my_models = ept.list_models() @@ -1022,7 +993,6 @@ def test_list_models(self, get_endpoint_with_models_mock): @pytest.mark.usefixtures("get_endpoint_with_models_mock") @pytest.mark.parametrize("sync", [True, False]) def test_undeploy_all(self, sdk_private_undeploy_mock, sync): - aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) ept = aiplatform.Endpoint(_TEST_ID) ept.undeploy_all(sync=sync) @@ -1101,7 +1071,6 @@ def test_list_endpoint_order_by_display_name(self, list_endpoints_mock): def test_delete_endpoint_without_force( self, sdk_undeploy_all_mock, delete_endpoint_mock, sync ): - aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) ept = aiplatform.Endpoint(_TEST_ID) ept.delete(sync=sync) @@ -1119,7 +1088,6 @@ def test_delete_endpoint_without_force( def test_delete_endpoint_with_force( self, sdk_undeploy_all_mock, delete_endpoint_mock, sync ): - aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) ept = aiplatform.Endpoint(_TEST_ID) ept.delete(force=True, sync=sync) From 40dcc4ca3c4c3fc5883815278d8aac3a52e6488f Mon Sep 17 00:00:00 2001 From: Vinny Senthil Date: Tue, 29 Jun 2021 13:58:55 -0500 Subject: [PATCH 4/4] Update location passed to _instantiate_prediction_client() --- google/cloud/aiplatform/models.py | 18 ++---------------- tests/unit/aiplatform/test_endpoints.py | 6 ++++-- 2 files changed, 6 insertions(+), 18 deletions(-) diff --git a/google/cloud/aiplatform/models.py b/google/cloud/aiplatform/models.py index 129efe8415..56cad667cb 100644 --- a/google/cloud/aiplatform/models.py +++ b/google/cloud/aiplatform/models.py @@ -117,13 +117,8 @@ def __init__( ) self._gca_resource = self._get_gca_resource(resource_name=endpoint_name) - project, location = self._get_and_validate_project_location( - resource_name=self._gca_resource.name, project=project, location=location - ) - self._prediction_client = self._instantiate_prediction_client( - location=location or initializer.global_config.location, - credentials=credentials, + location=self.location, credentials=credentials, ) @property @@ -363,17 +358,8 @@ def _construct_sdk_resource_from_gapic( endpoint._gca_resource = gapic_resource - project, location = endpoint._get_and_validate_project_location( - resource_name=gapic_resource.name, project=project, location=location - ) - - endpoint.project = project or initializer.global_config.project - endpoint.location = location or initializer.global_config.location - endpoint.credentials = credentials or initializer.global_config.credentials - endpoint._prediction_client = cls._instantiate_prediction_client( - location=location or initializer.global_config.location, - credentials=credentials, + location=endpoint.location, credentials=credentials, ) return endpoint diff --git a/tests/unit/aiplatform/test_endpoints.py b/tests/unit/aiplatform/test_endpoints.py index 1d683715da..e9f7de971a 100644 --- a/tests/unit/aiplatform/test_endpoints.py +++ b/tests/unit/aiplatform/test_endpoints.py @@ -401,10 +401,12 @@ def test_constructor_with_custom_project(self, get_endpoint_mock): @pytest.mark.usefixtures("get_endpoint_mock") def test_constructor_with_conflicting_location(self): - """get_endpoint_mock returns resource name with `_TEST_LOCATION` instead of `_TEST_LOCATION_2`""" + """Passing a full resource name with `_TEST_LOCATION` and providing `_TEST_LOCATION_2` as location""" with pytest.raises(RuntimeError) as err: - models.Endpoint(endpoint_name=_TEST_ID, location=_TEST_LOCATION_2) + models.Endpoint( + endpoint_name=_TEST_ENDPOINT_NAME, location=_TEST_LOCATION_2 + ) assert err.match( regexp=r"is provided, but different from the resource location"