Skip to content

Commit a524f33

Browse files
[SDK] fix grpc related bugs in Python SDK (#2398)
* fix: fix bugs in report_metrics.
  Signed-off-by: Electronic-Waste <[email protected]>
* fix: fix bugs in tune.
  Signed-off-by: Electronic-Waste <[email protected]>
* fix: fix bugs in get_trial_metrics.
  Signed-off-by: Electronic-Waste <[email protected]>
* fix: update .gitignore and setup.py.
  Signed-off-by: Electronic-Waste <[email protected]>
* fix: update Makefile.
  Signed-off-by: Electronic-Waste <[email protected]>
* feat: add report_metrics_test.py.
  Signed-off-by: Electronic-Waste <[email protected]>
* fix: fix lint error.
  Signed-off-by: Electronic-Waste <[email protected]>
* feat: add UTs for get_trial_metrics.
  Signed-off-by: Electronic-Waste <[email protected]>
* fix: update post_gen.py.
  Signed-off-by: Electronic-Waste <[email protected]>
* refactor: rebase to master.
  Signed-off-by: Electronic-Waste <[email protected]>
* test(sdk): use single katib_client.
  Signed-off-by: Electronic-Waste <[email protected]>
* fix(sdk): add TODO for import rewrite.
  Signed-off-by: Electronic-Waste <[email protected]>
* fix(sdk): fix lint error with black.
  Signed-off-by: Electronic-Waste <[email protected]>
* fix(sdk): fix lint error with isort.
  Signed-off-by: Electronic-Waste <[email protected]>
* fix(sdk): reformat import in katib_client_test.py.
  Signed-off-by: Electronic-Waste <[email protected]>
---------
Signed-off-by: Electronic-Waste <[email protected]>
1 parent 0e2ba6e commit a524f33

File tree

9 files changed

+240
-50
lines changed

9 files changed

+240
-50
lines changed

Makefile

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,13 +166,17 @@ ifeq ("$(wildcard $(TEST_TENSORFLOW_EVENT_FILE_PATH))", "")
166166
python examples/v1beta1/trial-images/tf-mnist-with-summaries/mnist.py --epochs 5 --batch-size 200 --log-path $(TEST_TENSORFLOW_EVENT_FILE_PATH)
167167
endif
168168

169+
# TODO(Electronic-Waste): Remove the import rewrite when protobuf supports `python_package` option.
170+
# REF: https://github.com/protocolbuffers/protobuf/issues/7061
169171
pytest: prepare-pytest prepare-pytest-testdata
170172
pytest ./test/unit/v1beta1/suggestion --ignore=./test/unit/v1beta1/suggestion/test_skopt_service.py
171173
pytest ./test/unit/v1beta1/earlystopping
172174
pytest ./test/unit/v1beta1/metricscollector
173175
cp ./pkg/apis/manager/v1beta1/python/api_pb2.py ./sdk/python/v1beta1/kubeflow/katib/katib_api_pb2.py
176+
cp ./pkg/apis/manager/v1beta1/python/api_pb2_grpc.py ./sdk/python/v1beta1/kubeflow/katib/katib_api_pb2_grpc.py
177+
sed -i "s/api_pb2/kubeflow\.katib\.katib_api_pb2/g" ./sdk/python/v1beta1/kubeflow/katib/katib_api_pb2_grpc.py
174178
pytest ./sdk/python/v1beta1/kubeflow/katib
175-
rm ./sdk/python/v1beta1/kubeflow/katib/katib_api_pb2.py
179+
rm ./sdk/python/v1beta1/kubeflow/katib/katib_api_pb2.py ./sdk/python/v1beta1/kubeflow/katib/katib_api_pb2_grpc.py
176180

177181
# The skopt service doesn't work appropriately with Python 3.11.
178182
# So, we need to run the test with Python 3.9.

hack/gen-python-sdk/post_gen.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,8 @@ def _rewrite_helper(input_file, output_file, rewrite_rules):
4141
if output_file == "sdk/python/v1beta1/kubeflow/katib/__init__.py":
4242
lines.append("# Import Katib API client.\n")
4343
lines.append("from kubeflow.katib.api.katib_client import KatibClient\n")
44-
lines.append("# Import Katib report metrics functions")
45-
lines.append("from kubeflow.katib.api.report_metrics import report_metrics")
44+
lines.append("# Import Katib report metrics functions\n")
45+
lines.append("from kubeflow.katib.api.report_metrics import report_metrics\n")
4646
lines.append("# Import Katib helper functions.\n")
4747
lines.append("import kubeflow.katib.api.search as search\n")
4848
lines.append("# Import Katib helper constants.\n")

sdk/python/v1beta1/.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@ dist/
33

44
# Katib gRPC APIs
55
kubeflow/katib/katib_api_pb2.py
6+
kubeflow/katib/katib_api_pb2_grpc.py

sdk/python/v1beta1/kubeflow/katib/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,9 @@
7171

7272
# Import Katib API client.
7373
from kubeflow.katib.api.katib_client import KatibClient
74-
# Import Katib report metrics functionsfrom kubeflow.katib.api.report_metrics import report_metrics# Import Katib helper functions.
74+
# Import Katib report metrics functions
75+
from kubeflow.katib.api.report_metrics import report_metrics
76+
# Import Katib helper functions.
7577
import kubeflow.katib.api.search as search
7678
# Import Katib helper constants.
7779
from kubeflow.katib.constants.constants import BASE_IMAGE_TENSORFLOW

sdk/python/v1beta1/kubeflow/katib/api/katib_client.py

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
import grpc
2323
import kubeflow.katib.katib_api_pb2 as katib_api_pb2
24+
import kubeflow.katib.katib_api_pb2_grpc as katib_api_pb2_grpc
2425
from kubeflow.katib import models
2526
from kubeflow.katib.api_client import ApiClient
2627
from kubeflow.katib.constants import constants
@@ -1305,21 +1306,18 @@ def get_trial_metrics(
13051306

13061307
namespace = namespace or self.namespace
13071308

1308-
db_manager_address = db_manager_address.split(":")
1309-
channel = grpc.beta.implementations.insecure_channel(
1310-
db_manager_address[0], int(db_manager_address[1])
1311-
)
1309+
channel = grpc.insecure_channel(db_manager_address)
13121310

1313-
with katib_api_pb2.beta_create_DBManager_stub(channel) as client:
1314-
try:
1315-
# When metric name is empty, we select all logs from the Katib DB.
1316-
observation_logs = client.GetObservationLog(
1317-
katib_api_pb2.GetObservationLogRequest(trial_name=name),
1318-
timeout=timeout,
1319-
)
1320-
except Exception as e:
1321-
raise RuntimeError(
1322-
f"Unable to get metrics for Trial {namespace}/{name}. Exception: {e}"
1323-
)
1311+
client = katib_api_pb2_grpc.DBManagerStub(channel)
1312+
try:
1313+
# When metric name is empty, we select all logs from the Katib DB.
1314+
observation_logs = client.GetObservationLog(
1315+
katib_api_pb2.GetObservationLogRequest(trial_name=name),
1316+
timeout=timeout,
1317+
)
1318+
except Exception as e:
1319+
raise RuntimeError(
1320+
f"Unable to get metrics for Trial {namespace}/{name}. Exception: {e}"
1321+
)
13241322

1325-
return observation_logs.observation_log.metric_logs
1323+
return observation_logs.observation_log.metric_logs

sdk/python/v1beta1/kubeflow/katib/api/katib_client_test.py

Lines changed: 70 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from typing import List, Optional
33
from unittest.mock import Mock, patch
44

5+
import kubeflow.katib.katib_api_pb2 as katib_api_pb2
56
import pytest
67
from kubeflow.katib import (
78
KatibClient,
@@ -38,6 +39,24 @@ def create_namespaced_custom_object_response(*args, **kwargs):
3839
return {"metadata": {"name": "12345-experiment-mnist-ci-test"}}
3940

4041

42+
def get_observation_log_response(*args, **kwargs):
43+
if kwargs.get("timeout") == 0:
44+
raise TimeoutError
45+
elif args[0].trial_name == "invalid":
46+
raise RuntimeError
47+
else:
48+
return katib_api_pb2.GetObservationLogReply(
49+
observation_log=katib_api_pb2.ObservationLog(
50+
metric_logs=[
51+
katib_api_pb2.MetricLog(
52+
time_stamp="2024-07-29T15:09:08Z",
53+
metric=katib_api_pb2.Metric(name="result", value="0.99"),
54+
)
55+
]
56+
)
57+
)
58+
59+
4160
def generate_trial_template() -> V1beta1TrialTemplate:
4261
trial_spec = {
4362
"apiVersion": "batch/v1",
@@ -223,6 +242,34 @@ def create_experiment(
223242
]
224243

225244

245+
test_get_trial_metrics_data = [
246+
(
247+
"valid trial name",
248+
{"name": "example", "namespace": "valid", "timeout": constants.DEFAULT_TIMEOUT},
249+
[
250+
katib_api_pb2.MetricLog(
251+
time_stamp="2024-07-29T15:09:08Z",
252+
metric=katib_api_pb2.Metric(name="result", value="0.99"),
253+
)
254+
],
255+
),
256+
(
257+
"invalid trial name",
258+
{
259+
"name": "invalid",
260+
"namespace": "invalid",
261+
"timeout": constants.DEFAULT_TIMEOUT,
262+
},
263+
RuntimeError,
264+
),
265+
(
266+
"GetObservationLog timeout error",
267+
{"name": "example", "namespace": "valid", "timeout": 0},
268+
RuntimeError,
269+
),
270+
]
271+
272+
226273
@pytest.fixture
227274
def katib_client():
228275
with patch(
@@ -232,7 +279,12 @@ def katib_client():
232279
side_effect=create_namespaced_custom_object_response
233280
)
234281
),
235-
), patch("kubernetes.config.load_kube_config", return_value=Mock()):
282+
), patch("kubernetes.config.load_kube_config", return_value=Mock()), patch(
283+
"kubeflow.katib.katib_api_pb2_grpc.DBManagerStub",
284+
return_value=Mock(
285+
GetObservationLog=Mock(side_effect=get_observation_log_response)
286+
),
287+
):
236288
client = KatibClient()
237289
yield client
238290

@@ -251,3 +303,20 @@ def test_create_experiment(katib_client, test_name, kwargs, expected_output):
251303
except Exception as e:
252304
assert type(e) is expected_output
253305
print("test execution complete")
306+
307+
308+
@pytest.mark.parametrize(
309+
"test_name,kwargs,expected_output", test_get_trial_metrics_data
310+
)
311+
def test_get_trial_metrics(katib_client, test_name, kwargs, expected_output):
312+
"""
313+
test get_trial_metrics function of katib client
314+
"""
315+
print("\n\nExecuting test:", test_name)
316+
try:
317+
metrics = katib_client.get_trial_metrics(**kwargs)
318+
for i in range(len(metrics)):
319+
assert metrics[i] == expected_output[i]
320+
except Exception as e:
321+
assert type(e) is expected_output
322+
print("test execution complete")

sdk/python/v1beta1/kubeflow/katib/api/report_metrics.py

Lines changed: 25 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import grpc
2020
import kubeflow.katib.katib_api_pb2 as katib_api_pb2
21+
import kubeflow.katib.katib_api_pb2_grpc as katib_api_pb2_grpc
2122
from kubeflow.katib.constants import constants
2223
from kubeflow.katib.utils import utils
2324

@@ -38,9 +39,9 @@ def report_metrics(
3839
timeout: Optional, gRPC API Server timeout in seconds to report metrics.
3940
4041
Raises:
41-
ValueError: The Trial name is not passed to environment variables.
42-
RuntimeError: Unable to push Trial metrics to Katib DB or
42+
ValueError: The Trial name is not passed to environment variables or
4343
metrics value has incorrect format (cannot be converted to type `float`).
44+
RuntimeError: Unable to push Trial metrics to Katib DB.
4445
"""
4546

4647
# Get Trial's namespace and name
@@ -50,37 +51,32 @@ def report_metrics(
5051
raise ValueError("The Trial name is not passed to environment variables")
5152

5253
# Get channel for grpc call to db manager
53-
db_manager_address = db_manager_address.split(":")
54-
channel = grpc.beta.implementations.insecure_channel(
55-
db_manager_address[0], int(db_manager_address[1])
56-
)
54+
channel = grpc.insecure_channel(db_manager_address)
5755

5856
# Validate metrics value in dict
5957
for value in metrics.values():
6058
utils.validate_metrics_value(value)
6159

6260
# Dial katib db manager to report metrics
63-
with katib_api_pb2.beta_create_DBManager_stub(channel) as client:
64-
try:
65-
timestamp = datetime.now(timezone.utc).strftime(constants.RFC3339_FORMAT)
66-
client.ReportObservationLog(
67-
request=katib_api_pb2.ReportObservationLogRequest(
68-
trial_name=name,
69-
observation_logs=katib_api_pb2.ObservationLog(
70-
metric_logs=[
71-
katib_api_pb2.MetricLog(
72-
time_stamp=timestamp,
73-
metric=katib_api_pb2.Metric(
74-
name=name, value=str(value)
75-
),
76-
)
77-
for name, value in metrics.items()
78-
]
79-
),
61+
client = katib_api_pb2_grpc.DBManagerStub(channel)
62+
try:
63+
timestamp = datetime.now(timezone.utc).strftime(constants.RFC3339_FORMAT)
64+
client.ReportObservationLog(
65+
request=katib_api_pb2.ReportObservationLogRequest(
66+
trial_name=name,
67+
observation_log=katib_api_pb2.ObservationLog(
68+
metric_logs=[
69+
katib_api_pb2.MetricLog(
70+
time_stamp=timestamp,
71+
metric=katib_api_pb2.Metric(name=name, value=str(value)),
72+
)
73+
for name, value in metrics.items()
74+
]
8075
),
81-
timeout=timeout,
82-
)
83-
except Exception as e:
84-
raise RuntimeError(
85-
f"Unable to push metrics to Katib DB for Trial {namespace}/{name}. Exception: {e}"
86-
)
76+
),
77+
timeout=timeout,
78+
)
79+
except Exception as e:
80+
raise RuntimeError(
81+
f"Unable to push metrics to Katib DB for Trial {namespace}/{name}. Exception: {e}"
82+
)
report_metrics_test.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
from unittest.mock import patch
2+
3+
import pytest
4+
from kubeflow.katib import report_metrics
5+
from kubeflow.katib.constants import constants
6+
7+
TEST_RESULT_SUCCESS = "success"
8+
ENV_VARIABLE_EMPTY = True
9+
ENV_VARIABLE_NOT_EMPTY = False
10+
11+
12+
def report_observation_log_response(*args, **kwargs):
13+
if kwargs.get("timeout") == 0:
14+
raise TimeoutError
15+
16+
17+
test_report_metrics_data = [
18+
(
19+
"valid metrics with float type",
20+
{"metrics": {"result": 0.99}, "timeout": constants.DEFAULT_TIMEOUT},
21+
TEST_RESULT_SUCCESS,
22+
ENV_VARIABLE_NOT_EMPTY,
23+
),
24+
(
25+
"valid metrics with string type",
26+
{"metrics": {"result": "0.99"}, "timeout": constants.DEFAULT_TIMEOUT},
27+
TEST_RESULT_SUCCESS,
28+
ENV_VARIABLE_NOT_EMPTY,
29+
),
30+
(
31+
"valid metrics with int type",
32+
{"metrics": {"result": 1}, "timeout": constants.DEFAULT_TIMEOUT},
33+
TEST_RESULT_SUCCESS,
34+
ENV_VARIABLE_NOT_EMPTY,
35+
),
36+
(
37+
"ReportObservationLog timeout error",
38+
{"metrics": {"result": 0.99}, "timeout": 0},
39+
RuntimeError,
40+
ENV_VARIABLE_NOT_EMPTY,
41+
),
42+
(
43+
"invalid metrics with type string",
44+
{"metrics": {"result": "abc"}, "timeout": constants.DEFAULT_TIMEOUT},
45+
ValueError,
46+
ENV_VARIABLE_NOT_EMPTY,
47+
),
48+
(
49+
"Trial name is not passed to env variables",
50+
{"metrics": {"result": 0.99}, "timeout": constants.DEFAULT_TIMEOUT},
51+
ValueError,
52+
ENV_VARIABLE_EMPTY,
53+
),
54+
]
55+
56+
57+
@pytest.fixture
58+
def mock_getenv(request):
59+
with patch("os.getenv") as mock:
60+
if request.param is ENV_VARIABLE_EMPTY:
61+
mock.side_effect = ValueError
62+
else:
63+
mock.return_value = "example"
64+
yield mock
65+
66+
67+
@pytest.fixture
68+
def mock_get_current_k8s_namespace():
69+
with patch("kubeflow.katib.utils.utils.get_current_k8s_namespace") as mock:
70+
mock.return_value = "test"
71+
yield mock
72+
73+
74+
@pytest.fixture
75+
def mock_report_observation_log():
76+
with patch("kubeflow.katib.katib_api_pb2_grpc.DBManagerStub") as mock:
77+
mock_instance = mock.return_value
78+
mock_instance.ReportObservationLog.side_effect = report_observation_log_response
79+
yield mock_instance
80+
81+
82+
@pytest.mark.parametrize(
83+
"test_name,kwargs,expected_output,mock_getenv",
84+
test_report_metrics_data,
85+
indirect=["mock_getenv"],
86+
)
87+
def test_report_metrics(
88+
test_name,
89+
kwargs,
90+
expected_output,
91+
mock_getenv,
92+
mock_get_current_k8s_namespace,
93+
mock_report_observation_log,
94+
):
95+
"""
96+
test report_metrics function
97+
"""
98+
print("\n\nExecuting test:", test_name)
99+
try:
100+
report_metrics(**kwargs)
101+
assert expected_output == TEST_RESULT_SUCCESS
102+
except Exception as e:
103+
assert type(e) is expected_output
104+
print("test execution complete")

0 commit comments

Comments
 (0)