Skip to content

Commit

Permalink
fix: Endpoint.undeploy_all() doesn't undeploy all models
Browse files Browse the repository at this point in the history
#1441

PiperOrigin-RevId: 500253890
  • Loading branch information
jaycee-li authored and copybara-github committed Jan 6, 2023
1 parent f87fef0 commit 9fb24d7
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 12 deletions.
12 changes: 11 additions & 1 deletion google/cloud/aiplatform/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1709,11 +1709,21 @@ def undeploy_all(self, sync: bool = True) -> "Endpoint":
"""
self._sync_gca_resource()

models_to_undeploy = sorted( # Undeploy zero traffic models first
models_in_traffic_split = sorted( # Undeploy zero traffic models first
self._gca_resource.traffic_split.keys(),
key=lambda id: self._gca_resource.traffic_split[id],
)

# Some deployed models may not be in the traffic_split dict.
# These models have 0% traffic and should be undeployed first.
models_not_in_traffic_split = [
deployed_model.id
for deployed_model in self._gca_resource.deployed_models
if deployed_model.id not in models_in_traffic_split
]

models_to_undeploy = models_not_in_traffic_split + models_in_traffic_split

for deployed_model in models_to_undeploy:
self._undeploy(deployed_model_id=deployed_model, sync=sync)

Expand Down
16 changes: 5 additions & 11 deletions tests/unit/aiplatform/test_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,14 +102,12 @@
"m2": 10,
"m3": 30,
"m4": 0,
"m5": 5,
"m6": 8,
"m7": 7,
"m5": 20,
}
_TEST_LONG_TRAFFIC_SPLIT_SORTED_IDS = ["m4", "m5", "m7", "m6", "m2", "m3", "m1"]
_TEST_LONG_TRAFFIC_SPLIT_SORTED_IDS = ["m4", "m2", "m5", "m3", "m1"]
_TEST_LONG_DEPLOYED_MODELS = [
gca_endpoint.DeployedModel(id=id, display_name=f"{id}_display_name")
for id in _TEST_LONG_TRAFFIC_SPLIT.keys()
for id in ["m1", "m2", "m3", "m4", "m5", "m6", "m7"]
]

_TEST_MACHINE_TYPE = "n1-standard-32"
Expand Down Expand Up @@ -1861,11 +1859,6 @@ def test_list_models(self, get_endpoint_with_models_mock):
@pytest.mark.parametrize("sync", [True, False])
def test_undeploy_all(self, sdk_private_undeploy_mock, sync):

# Ensure mock traffic split deployed model IDs are same as expected IDs
assert set(_TEST_LONG_TRAFFIC_SPLIT_SORTED_IDS) == set(
_TEST_LONG_TRAFFIC_SPLIT.keys()
)

ept = aiplatform.Endpoint(_TEST_ID)
ept.undeploy_all(sync=sync)

Expand All @@ -1874,10 +1867,11 @@ def test_undeploy_all(self, sdk_private_undeploy_mock, sync):

# undeploy_all() results in an undeploy() call for each deployed_model
# Models are undeployed in ascending order of traffic percentage
expected_models_to_undeploy = ["m6", "m7"] + _TEST_LONG_TRAFFIC_SPLIT_SORTED_IDS
sdk_private_undeploy_mock.assert_has_calls(
[
mock.call(deployed_model_id=deployed_model_id, sync=sync)
for deployed_model_id in _TEST_LONG_TRAFFIC_SPLIT_SORTED_IDS
for deployed_model_id in expected_models_to_undeploy
],
)

Expand Down

0 comments on commit 9fb24d7

Please sign in to comment.