Skip to content

Commit

Permalink
[SDK] fix grpc related bugs in Python SDK (#2398)
Browse files Browse the repository at this point in the history
* fix: fix bugs in report_metrics.

Signed-off-by: Electronic-Waste <[email protected]>

* fix: fix bugs in tune.

Signed-off-by: Electronic-Waste <[email protected]>

* fix: fix bugs in get_trial_metrics.

Signed-off-by: Electronic-Waste <[email protected]>

* fix: update .gitignore and setup.py.

Signed-off-by: Electronic-Waste <[email protected]>

* fix: update Makefile.

Signed-off-by: Electronic-Waste <[email protected]>

* feat: add report_metrics_test.py.

Signed-off-by: Electronic-Waste <[email protected]>

* fix: fix lint error.

Signed-off-by: Electronic-Waste <[email protected]>

* feat: add UTs for get_trial_metrics.

Signed-off-by: Electronic-Waste <[email protected]>

* fix: update post_gen.py.

Signed-off-by: Electronic-Waste <[email protected]>

* refactor: rebase to master.

Signed-off-by: Electronic-Waste <[email protected]>

* test(sdk): use single katib_client.

Signed-off-by: Electronic-Waste <[email protected]>

* fix(sdk): add TODO for import rewrite.

Signed-off-by: Electronic-Waste <[email protected]>

* fix(sdk): fix lint error with black.

Signed-off-by: Electronic-Waste <[email protected]>

* fix(sdk): fix lint error with isort.

Signed-off-by: Electronic-Waste <[email protected]>

* fix(sdk): reformat import in katib_client_test.py.

Signed-off-by: Electronic-Waste <[email protected]>

---------

Signed-off-by: Electronic-Waste <[email protected]>
  • Loading branch information
Electronic-Waste authored Aug 23, 2024
1 parent 0e2ba6e commit a524f33
Show file tree
Hide file tree
Showing 9 changed files with 240 additions and 50 deletions.
6 changes: 5 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -166,13 +166,17 @@ ifeq ("$(wildcard $(TEST_TENSORFLOW_EVENT_FILE_PATH))", "")
python examples/v1beta1/trial-images/tf-mnist-with-summaries/mnist.py --epochs 5 --batch-size 200 --log-path $(TEST_TENSORFLOW_EVENT_FILE_PATH)
endif

# TODO(Electronic-Waste): Remove the import rewrite when protobuf supports `python_package` option.
# REF: https://github.com/protocolbuffers/protobuf/issues/7061
pytest: prepare-pytest prepare-pytest-testdata
pytest ./test/unit/v1beta1/suggestion --ignore=./test/unit/v1beta1/suggestion/test_skopt_service.py
pytest ./test/unit/v1beta1/earlystopping
pytest ./test/unit/v1beta1/metricscollector
cp ./pkg/apis/manager/v1beta1/python/api_pb2.py ./sdk/python/v1beta1/kubeflow/katib/katib_api_pb2.py
cp ./pkg/apis/manager/v1beta1/python/api_pb2_grpc.py ./sdk/python/v1beta1/kubeflow/katib/katib_api_pb2_grpc.py
sed -i "s/api_pb2/kubeflow\.katib\.katib_api_pb2/g" ./sdk/python/v1beta1/kubeflow/katib/katib_api_pb2_grpc.py
pytest ./sdk/python/v1beta1/kubeflow/katib
rm ./sdk/python/v1beta1/kubeflow/katib/katib_api_pb2.py
rm ./sdk/python/v1beta1/kubeflow/katib/katib_api_pb2.py ./sdk/python/v1beta1/kubeflow/katib/katib_api_pb2_grpc.py

# The skopt service doesn't work appropriately with Python 3.11.
# So, we need to run the test with Python 3.9.
Expand Down
4 changes: 2 additions & 2 deletions hack/gen-python-sdk/post_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ def _rewrite_helper(input_file, output_file, rewrite_rules):
if output_file == "sdk/python/v1beta1/kubeflow/katib/__init__.py":
lines.append("# Import Katib API client.\n")
lines.append("from kubeflow.katib.api.katib_client import KatibClient\n")
lines.append("# Import Katib report metrics functions")
lines.append("from kubeflow.katib.api.report_metrics import report_metrics")
lines.append("# Import Katib report metrics functions\n")
lines.append("from kubeflow.katib.api.report_metrics import report_metrics\n")
lines.append("# Import Katib helper functions.\n")
lines.append("import kubeflow.katib.api.search as search\n")
lines.append("# Import Katib helper constants.\n")
Expand Down
1 change: 1 addition & 0 deletions sdk/python/v1beta1/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ dist/

# Katib gRPC APIs
kubeflow/katib/katib_api_pb2.py
kubeflow/katib/katib_api_pb2_grpc.py
4 changes: 3 additions & 1 deletion sdk/python/v1beta1/kubeflow/katib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,9 @@

# Import Katib API client.
from kubeflow.katib.api.katib_client import KatibClient
# Import Katib report metrics functionsfrom kubeflow.katib.api.report_metrics import report_metrics# Import Katib helper functions.
# Import Katib report metrics functions
from kubeflow.katib.api.report_metrics import report_metrics
# Import Katib helper functions.
import kubeflow.katib.api.search as search
# Import Katib helper constants.
from kubeflow.katib.constants.constants import BASE_IMAGE_TENSORFLOW
Expand Down
30 changes: 14 additions & 16 deletions sdk/python/v1beta1/kubeflow/katib/api/katib_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import grpc
import kubeflow.katib.katib_api_pb2 as katib_api_pb2
import kubeflow.katib.katib_api_pb2_grpc as katib_api_pb2_grpc
from kubeflow.katib import models
from kubeflow.katib.api_client import ApiClient
from kubeflow.katib.constants import constants
Expand Down Expand Up @@ -1305,21 +1306,18 @@ def get_trial_metrics(

namespace = namespace or self.namespace

db_manager_address = db_manager_address.split(":")
channel = grpc.beta.implementations.insecure_channel(
db_manager_address[0], int(db_manager_address[1])
)
channel = grpc.insecure_channel(db_manager_address)

with katib_api_pb2.beta_create_DBManager_stub(channel) as client:
try:
# When metric name is empty, we select all logs from the Katib DB.
observation_logs = client.GetObservationLog(
katib_api_pb2.GetObservationLogRequest(trial_name=name),
timeout=timeout,
)
except Exception as e:
raise RuntimeError(
f"Unable to get metrics for Trial {namespace}/{name}. Exception: {e}"
)
client = katib_api_pb2_grpc.DBManagerStub(channel)
try:
# When metric name is empty, we select all logs from the Katib DB.
observation_logs = client.GetObservationLog(
katib_api_pb2.GetObservationLogRequest(trial_name=name),
timeout=timeout,
)
except Exception as e:
raise RuntimeError(
f"Unable to get metrics for Trial {namespace}/{name}. Exception: {e}"
)

return observation_logs.observation_log.metric_logs
return observation_logs.observation_log.metric_logs
71 changes: 70 additions & 1 deletion sdk/python/v1beta1/kubeflow/katib/api/katib_client_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import List, Optional
from unittest.mock import Mock, patch

import kubeflow.katib.katib_api_pb2 as katib_api_pb2
import pytest
from kubeflow.katib import (
KatibClient,
Expand Down Expand Up @@ -38,6 +39,24 @@ def create_namespaced_custom_object_response(*args, **kwargs):
return {"metadata": {"name": "12345-experiment-mnist-ci-test"}}


def get_observation_log_response(*args, **kwargs):
if kwargs.get("timeout") == 0:
raise TimeoutError
elif args[0].trial_name == "invalid":
raise RuntimeError
else:
return katib_api_pb2.GetObservationLogReply(
observation_log=katib_api_pb2.ObservationLog(
metric_logs=[
katib_api_pb2.MetricLog(
time_stamp="2024-07-29T15:09:08Z",
metric=katib_api_pb2.Metric(name="result", value="0.99"),
)
]
)
)


def generate_trial_template() -> V1beta1TrialTemplate:
trial_spec = {
"apiVersion": "batch/v1",
Expand Down Expand Up @@ -223,6 +242,34 @@ def create_experiment(
]


test_get_trial_metrics_data = [
(
"valid trial name",
{"name": "example", "namespace": "valid", "timeout": constants.DEFAULT_TIMEOUT},
[
katib_api_pb2.MetricLog(
time_stamp="2024-07-29T15:09:08Z",
metric=katib_api_pb2.Metric(name="result", value="0.99"),
)
],
),
(
"invalid trial name",
{
"name": "invalid",
"namespace": "invalid",
"timeout": constants.DEFAULT_TIMEOUT,
},
RuntimeError,
),
(
"GetObservationLog timeout error",
{"name": "example", "namespace": "valid", "timeout": 0},
RuntimeError,
),
]


@pytest.fixture
def katib_client():
with patch(
Expand All @@ -232,7 +279,12 @@ def katib_client():
side_effect=create_namespaced_custom_object_response
)
),
), patch("kubernetes.config.load_kube_config", return_value=Mock()):
), patch("kubernetes.config.load_kube_config", return_value=Mock()), patch(
"kubeflow.katib.katib_api_pb2_grpc.DBManagerStub",
return_value=Mock(
GetObservationLog=Mock(side_effect=get_observation_log_response)
),
):
client = KatibClient()
yield client

Expand All @@ -251,3 +303,20 @@ def test_create_experiment(katib_client, test_name, kwargs, expected_output):
except Exception as e:
assert type(e) is expected_output
print("test execution complete")


@pytest.mark.parametrize(
"test_name,kwargs,expected_output", test_get_trial_metrics_data
)
def test_get_trial_metrics(katib_client, test_name, kwargs, expected_output):
"""
test get_trial_metrics function of katib client
"""
print("\n\nExecuting test:", test_name)
try:
metrics = katib_client.get_trial_metrics(**kwargs)
for i in range(len(metrics)):
assert metrics[i] == expected_output[i]
except Exception as e:
assert type(e) is expected_output
print("test execution complete")
54 changes: 25 additions & 29 deletions sdk/python/v1beta1/kubeflow/katib/api/report_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import grpc
import kubeflow.katib.katib_api_pb2 as katib_api_pb2
import kubeflow.katib.katib_api_pb2_grpc as katib_api_pb2_grpc
from kubeflow.katib.constants import constants
from kubeflow.katib.utils import utils

Expand All @@ -38,9 +39,9 @@ def report_metrics(
timeout: Optional, gRPC API Server timeout in seconds to report metrics.
Raises:
ValueError: The Trial name is not passed to environment variables.
RuntimeError: Unable to push Trial metrics to Katib DB or
ValueError: The Trial name is not passed to environment variables or
metrics value has incorrect format (cannot be converted to type `float`).
RuntimeError: Unable to push Trial metrics to Katib DB.
"""

# Get Trial's namespace and name
Expand All @@ -50,37 +51,32 @@ def report_metrics(
raise ValueError("The Trial name is not passed to environment variables")

# Get channel for grpc call to db manager
db_manager_address = db_manager_address.split(":")
channel = grpc.beta.implementations.insecure_channel(
db_manager_address[0], int(db_manager_address[1])
)
channel = grpc.insecure_channel(db_manager_address)

# Validate metrics value in dict
for value in metrics.values():
utils.validate_metrics_value(value)

# Dial katib db manager to report metrics
with katib_api_pb2.beta_create_DBManager_stub(channel) as client:
try:
timestamp = datetime.now(timezone.utc).strftime(constants.RFC3339_FORMAT)
client.ReportObservationLog(
request=katib_api_pb2.ReportObservationLogRequest(
trial_name=name,
observation_logs=katib_api_pb2.ObservationLog(
metric_logs=[
katib_api_pb2.MetricLog(
time_stamp=timestamp,
metric=katib_api_pb2.Metric(
name=name, value=str(value)
),
)
for name, value in metrics.items()
]
),
client = katib_api_pb2_grpc.DBManagerStub(channel)
try:
timestamp = datetime.now(timezone.utc).strftime(constants.RFC3339_FORMAT)
client.ReportObservationLog(
request=katib_api_pb2.ReportObservationLogRequest(
trial_name=name,
observation_log=katib_api_pb2.ObservationLog(
metric_logs=[
katib_api_pb2.MetricLog(
time_stamp=timestamp,
metric=katib_api_pb2.Metric(name=name, value=str(value)),
)
for name, value in metrics.items()
]
),
timeout=timeout,
)
except Exception as e:
raise RuntimeError(
f"Unable to push metrics to Katib DB for Trial {namespace}/{name}. Exception: {e}"
)
),
timeout=timeout,
)
except Exception as e:
raise RuntimeError(
f"Unable to push metrics to Katib DB for Trial {namespace}/{name}. Exception: {e}"
)
104 changes: 104 additions & 0 deletions sdk/python/v1beta1/kubeflow/katib/api/report_metrics_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
from unittest.mock import patch

import pytest
from kubeflow.katib import report_metrics
from kubeflow.katib.constants import constants

TEST_RESULT_SUCCESS = "success"
ENV_VARIABLE_EMPTY = True
ENV_VARIABLE_NOT_EMPTY = False


def report_observation_log_response(*args, **kwargs):
if kwargs.get("timeout") == 0:
raise TimeoutError


test_report_metrics_data = [
(
"valid metrics with float type",
{"metrics": {"result": 0.99}, "timeout": constants.DEFAULT_TIMEOUT},
TEST_RESULT_SUCCESS,
ENV_VARIABLE_NOT_EMPTY,
),
(
"valid metrics with string type",
{"metrics": {"result": "0.99"}, "timeout": constants.DEFAULT_TIMEOUT},
TEST_RESULT_SUCCESS,
ENV_VARIABLE_NOT_EMPTY,
),
(
"valid metrics with int type",
{"metrics": {"result": 1}, "timeout": constants.DEFAULT_TIMEOUT},
TEST_RESULT_SUCCESS,
ENV_VARIABLE_NOT_EMPTY,
),
(
"ReportObservationLog timeout error",
{"metrics": {"result": 0.99}, "timeout": 0},
RuntimeError,
ENV_VARIABLE_NOT_EMPTY,
),
(
"invalid metrics with type string",
{"metrics": {"result": "abc"}, "timeout": constants.DEFAULT_TIMEOUT},
ValueError,
ENV_VARIABLE_NOT_EMPTY,
),
(
"Trial name is not passed to env variables",
{"metrics": {"result": 0.99}, "timeout": constants.DEFAULT_TIMEOUT},
ValueError,
ENV_VARIABLE_EMPTY,
),
]


@pytest.fixture
def mock_getenv(request):
with patch("os.getenv") as mock:
if request.param is ENV_VARIABLE_EMPTY:
mock.side_effect = ValueError
else:
mock.return_value = "example"
yield mock


@pytest.fixture
def mock_get_current_k8s_namespace():
with patch("kubeflow.katib.utils.utils.get_current_k8s_namespace") as mock:
mock.return_value = "test"
yield mock


@pytest.fixture
def mock_report_observation_log():
with patch("kubeflow.katib.katib_api_pb2_grpc.DBManagerStub") as mock:
mock_instance = mock.return_value
mock_instance.ReportObservationLog.side_effect = report_observation_log_response
yield mock_instance


@pytest.mark.parametrize(
"test_name,kwargs,expected_output,mock_getenv",
test_report_metrics_data,
indirect=["mock_getenv"],
)
def test_report_metrics(
test_name,
kwargs,
expected_output,
mock_getenv,
mock_get_current_k8s_namespace,
mock_report_observation_log,
):
"""
test report_metrics function
"""
print("\n\nExecuting test:", test_name)
try:
report_metrics(**kwargs)
assert expected_output == TEST_RESULT_SUCCESS
except Exception as e:
assert type(e) is expected_output
print("test execution complete")
Loading

0 comments on commit a524f33

Please sign in to comment.