diff --git a/google/cloud/aiplatform/models.py b/google/cloud/aiplatform/models.py
index 6a6bbee463..de3be61f6f 100644
--- a/google/cloud/aiplatform/models.py
+++ b/google/cloud/aiplatform/models.py
@@ -1709,11 +1709,21 @@ def undeploy_all(self, sync: bool = True) -> "Endpoint":
         """
         self._sync_gca_resource()
 
-        models_to_undeploy = sorted(  # Undeploy zero traffic models first
+        models_in_traffic_split = sorted(  # Undeploy zero traffic models first
             self._gca_resource.traffic_split.keys(),
             key=lambda id: self._gca_resource.traffic_split[id],
         )
 
+        # Some deployed models may not be in the traffic_split dict.
+        # These models have 0% traffic and should be undeployed first.
+        models_not_in_traffic_split = [
+            deployed_model.id
+            for deployed_model in self._gca_resource.deployed_models
+            if deployed_model.id not in models_in_traffic_split
+        ]
+
+        models_to_undeploy = models_not_in_traffic_split + models_in_traffic_split
+
         for deployed_model in models_to_undeploy:
             self._undeploy(deployed_model_id=deployed_model, sync=sync)
 
diff --git a/tests/unit/aiplatform/test_endpoints.py b/tests/unit/aiplatform/test_endpoints.py
index 46c1959d90..90d6da8455 100644
--- a/tests/unit/aiplatform/test_endpoints.py
+++ b/tests/unit/aiplatform/test_endpoints.py
@@ -102,14 +102,12 @@
     "m2": 10,
     "m3": 30,
     "m4": 0,
-    "m5": 5,
-    "m6": 8,
-    "m7": 7,
+    "m5": 20,
 }
-_TEST_LONG_TRAFFIC_SPLIT_SORTED_IDS = ["m4", "m5", "m7", "m6", "m2", "m3", "m1"]
+_TEST_LONG_TRAFFIC_SPLIT_SORTED_IDS = ["m4", "m2", "m5", "m3", "m1"]
 _TEST_LONG_DEPLOYED_MODELS = [
     gca_endpoint.DeployedModel(id=id, display_name=f"{id}_display_name")
-    for id in _TEST_LONG_TRAFFIC_SPLIT.keys()
+    for id in ["m1", "m2", "m3", "m4", "m5", "m6", "m7"]
 ]
 
 _TEST_MACHINE_TYPE = "n1-standard-32"
@@ -1861,11 +1859,6 @@ def test_list_models(self, get_endpoint_with_models_mock):
 
     @pytest.mark.parametrize("sync", [True, False])
     def test_undeploy_all(self, sdk_private_undeploy_mock, sync):
-        # Ensure mock traffic split deployed model IDs are same as expected IDs
-        assert set(_TEST_LONG_TRAFFIC_SPLIT_SORTED_IDS) == set(
-            _TEST_LONG_TRAFFIC_SPLIT.keys()
-        )
-
         ept = aiplatform.Endpoint(_TEST_ID)
         ept.undeploy_all(sync=sync)
 
@@ -1874,10 +1867,11 @@ def test_undeploy_all(self, sdk_private_undeploy_mock, sync):
 
         # undeploy_all() results in an undeploy() call for each deployed_model
         # Models are undeployed in ascending order of traffic percentage
+        expected_models_to_undeploy = ["m6", "m7"] + _TEST_LONG_TRAFFIC_SPLIT_SORTED_IDS
         sdk_private_undeploy_mock.assert_has_calls(
             [
                 mock.call(deployed_model_id=deployed_model_id, sync=sync)
-                for deployed_model_id in _TEST_LONG_TRAFFIC_SPLIT_SORTED_IDS
+                for deployed_model_id in expected_models_to_undeploy
             ],
         )
 