Skip to content

Commit

Permalink
fix: Improve handling of undeploying model without redistributing rem…
Browse files Browse the repository at this point in the history
…aining traffic (#898)

- Add informative error when undeploying a model with traffic from an Endpoint with multiple deployed models, without providing a new traffic split.
- Improve accuracy of docstring for `Endpoint.undeploy()`
- Add tests to cover CUJs

Fixes [b/198290421](http://b/198290421) 🦕
  • Loading branch information
vinnysenthil authored Dec 16, 2021
1 parent 321cf9e commit 8a8a4fa
Show file tree
Hide file tree
Showing 2 changed files with 137 additions and 16 deletions.
39 changes: 29 additions & 10 deletions google/cloud/aiplatform/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -997,21 +997,22 @@ def undeploy(
) -> None:
"""Undeploys a deployed model.
Proportionally adjusts the traffic_split among the remaining deployed
models of the endpoint.
The model to be undeployed should have no traffic or user must provide
a new traffic_split with the remaining deployed models. Refer
to `Endpoint.traffic_split` for the current traffic split mapping.
Args:
deployed_model_id (str):
Required. The ID of the DeployedModel to be undeployed from the
Endpoint.
traffic_split (Dict[str, int]):
Optional. A map from a DeployedModel's ID to the percentage of
Optional. A map of DeployedModel IDs to the percentage of
this Endpoint's traffic that should be forwarded to that DeployedModel.
If a DeployedModel's ID is not listed in this map, then it receives
no traffic. The traffic percentage values must add up to 100, or
map must be empty if the Endpoint is to not accept any traffic at
the moment. Key for model being deployed is "0". Should not be
provided if traffic_percentage is provided.
Required if undeploying a model with non-zero traffic from an Endpoint
with multiple deployed models. The traffic percentage values must add
up to 100, or map must be empty if the Endpoint is to not accept any traffic
at the moment. If a DeployedModel's ID is not listed in this map, then it
receives no traffic.
metadata (Sequence[Tuple[str, str]]):
Optional. Strings which should be sent along with the request as
metadata.
Expand All @@ -1026,6 +1027,19 @@ def undeploy(
"Sum of all traffic within traffic split needs to be 100."
)

# Two or more models deployed to Endpoint and remaining traffic will be zero
elif (
len(self.traffic_split) > 1
and deployed_model_id in self._gca_resource.traffic_split
and self._gca_resource.traffic_split[deployed_model_id] == 100
):
raise ValueError(
f"Undeploying deployed model '{deployed_model_id}' would leave the remaining "
"traffic split at 0%. Traffic split must add up to 100% when models are "
"deployed. Please undeploy the other models first or provide an updated "
"traffic_split."
)

self._undeploy(
deployed_model_id=deployed_model_id,
traffic_split=traffic_split,
Expand Down Expand Up @@ -1282,8 +1296,13 @@ def undeploy_all(self, sync: bool = True) -> "Endpoint":
"""
self._sync_gca_resource()

for deployed_model in self._gca_resource.deployed_models:
self._undeploy(deployed_model_id=deployed_model.id, sync=sync)
models_to_undeploy = sorted( # Undeploy zero traffic models first
self._gca_resource.traffic_split.keys(),
key=lambda id: self._gca_resource.traffic_split[id],
)

for deployed_model in models_to_undeploy:
self._undeploy(deployed_model_id=deployed_model, sync=sync)

return self

Expand Down
114 changes: 108 additions & 6 deletions tests/unit/aiplatform/test_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# limitations under the License.
#

import copy
import pytest

from unittest import mock
Expand Down Expand Up @@ -56,8 +57,10 @@

_TEST_DISPLAY_NAME = "test-display-name"
_TEST_DISPLAY_NAME_2 = "test-display-name-2"
_TEST_DISPLAY_NAME_3 = "test-display-name-3"
_TEST_ID = "1028944691210842416"
_TEST_ID_2 = "4366591682456584192"
_TEST_ID_3 = "5820582938582924817"
_TEST_DESCRIPTION = "test-description"

_TEST_ENDPOINT_NAME = (
Expand All @@ -80,6 +83,24 @@
_TEST_DEPLOYED_MODELS = [
gca_endpoint.DeployedModel(id=_TEST_ID, display_name=_TEST_DISPLAY_NAME),
gca_endpoint.DeployedModel(id=_TEST_ID_2, display_name=_TEST_DISPLAY_NAME_2),
gca_endpoint.DeployedModel(id=_TEST_ID_3, display_name=_TEST_DISPLAY_NAME_3),
]

_TEST_TRAFFIC_SPLIT = {_TEST_ID: 0, _TEST_ID_2: 100, _TEST_ID_3: 0}

_TEST_LONG_TRAFFIC_SPLIT = {
"m1": 40,
"m2": 10,
"m3": 30,
"m4": 0,
"m5": 5,
"m6": 8,
"m7": 7,
}
_TEST_LONG_TRAFFIC_SPLIT_SORTED_IDS = ["m4", "m5", "m7", "m6", "m2", "m3", "m1"]
_TEST_LONG_DEPLOYED_MODELS = [
gca_endpoint.DeployedModel(id=id, display_name=f"{id}_display_name")
for id in _TEST_LONG_TRAFFIC_SPLIT.keys()
]

_TEST_MACHINE_TYPE = "n1-standard-32"
Expand Down Expand Up @@ -200,6 +221,21 @@ def get_endpoint_with_models_mock():
display_name=_TEST_DISPLAY_NAME,
name=_TEST_ENDPOINT_NAME,
deployed_models=_TEST_DEPLOYED_MODELS,
traffic_split=_TEST_TRAFFIC_SPLIT,
)
yield get_endpoint_mock


@pytest.fixture
def get_endpoint_with_many_models_mock():
with mock.patch.object(
endpoint_service_client.EndpointServiceClient, "get_endpoint"
) as get_endpoint_mock:
get_endpoint_mock.return_value = gca_endpoint.Endpoint(
display_name=_TEST_DISPLAY_NAME,
name=_TEST_ENDPOINT_NAME,
deployed_models=_TEST_LONG_DEPLOYED_MODELS,
traffic_split=_TEST_LONG_TRAFFIC_SPLIT,
)
yield get_endpoint_mock

Expand Down Expand Up @@ -990,23 +1026,84 @@ def test_undeploy_with_traffic_split(self, undeploy_model_mock, sync):
@pytest.mark.usefixtures("get_endpoint_mock")
@pytest.mark.parametrize("sync", [True, False])
def test_undeploy_raise_error_traffic_split_total(self, sync):
with pytest.raises(ValueError):
with pytest.raises(ValueError) as e:
test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME)
test_endpoint.undeploy(
deployed_model_id="model1", traffic_split={"model2": 99}, sync=sync
)

assert e.match("Sum of all traffic within traffic split needs to be 100.")

@pytest.mark.usefixtures("get_endpoint_mock")
@pytest.mark.parametrize("sync", [True, False])
def test_undeploy_raise_error_undeployed_model_traffic(self, sync):
with pytest.raises(ValueError):
with pytest.raises(ValueError) as e:
test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME)
test_endpoint.undeploy(
deployed_model_id="model1",
traffic_split={"model1": 50, "model2": 50},
sync=sync,
)

assert e.match("Model being undeployed should have 0 traffic.")

@pytest.mark.usefixtures("get_endpoint_with_models_mock")
@pytest.mark.parametrize("sync", [True, False])
def test_undeploy_raises_error_on_zero_leftover_traffic(self, sync):
"""
Attempting to undeploy model with 100% traffic on an Endpoint with
multiple models deployed without an updated traffic_split should
raise an informative error.
"""

traffic_remaining = _TEST_TRAFFIC_SPLIT[_TEST_ID_2]

assert traffic_remaining == 100 # Confirm this model has all traffic
assert sum(_TEST_TRAFFIC_SPLIT.values()) == 100 # Mock traffic sums to 100%

with pytest.raises(ValueError) as e:
test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME)
test_endpoint.undeploy(
deployed_model_id=_TEST_ID_2, sync=sync,
)

assert e.match(
f"Undeploying deployed model '{_TEST_ID_2}' would leave the remaining "
f"traffic split at 0%."
)

@pytest.mark.usefixtures("get_endpoint_with_models_mock")
@pytest.mark.parametrize("sync", [True, False])
def test_undeploy_zero_traffic_model_without_new_traffic_split(
self, undeploy_model_mock, sync
):
"""
Attempting to undeploy model with zero traffic without providing
a new traffic split should not raise any errors.
"""

traffic_remaining = _TEST_TRAFFIC_SPLIT[_TEST_ID_3]

assert not traffic_remaining # Confirm there is zero traffic

test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME)
test_endpoint.undeploy(
deployed_model_id=_TEST_ID_3, sync=sync,
)

if not sync:
test_endpoint.wait()

expected_new_traffic_split = copy.deepcopy(_TEST_TRAFFIC_SPLIT)
expected_new_traffic_split.pop(_TEST_ID_3)

undeploy_model_mock.assert_called_once_with(
endpoint=test_endpoint.resource_name,
deployed_model_id=_TEST_ID_3,
traffic_split=expected_new_traffic_split,
metadata=(),
)

def test_predict(self, get_endpoint_mock, predict_client_predict_mock):

test_endpoint = models.Endpoint(_TEST_ID)
Expand Down Expand Up @@ -1057,23 +1154,28 @@ def test_list_models(self, get_endpoint_with_models_mock):

assert my_models == _TEST_DEPLOYED_MODELS

@pytest.mark.usefixtures("get_endpoint_with_models_mock")
@pytest.mark.usefixtures("get_endpoint_with_many_models_mock")
@pytest.mark.parametrize("sync", [True, False])
def test_undeploy_all(self, sdk_private_undeploy_mock, sync):

# Ensure mock traffic split deployed model IDs are same as expected IDs
assert set(_TEST_LONG_TRAFFIC_SPLIT_SORTED_IDS) == set(
_TEST_LONG_TRAFFIC_SPLIT.keys()
)

ept = aiplatform.Endpoint(_TEST_ID)
ept.undeploy_all(sync=sync)

if not sync:
ept.wait()

# undeploy_all() results in an undeploy() call for each deployed_model
# Models are undeployed in ascending order of traffic percentage
sdk_private_undeploy_mock.assert_has_calls(
[
mock.call(deployed_model_id=deployed_model.id, sync=sync)
for deployed_model in _TEST_DEPLOYED_MODELS
mock.call(deployed_model_id=deployed_model_id, sync=sync)
for deployed_model_id in _TEST_LONG_TRAFFIC_SPLIT_SORTED_IDS
],
any_order=True,
)

@pytest.mark.usefixtures("list_endpoints_mock")
Expand Down

0 comments on commit 8a8a4fa

Please sign in to comment.