Skip to content

Commit 85d034f

Browse files
feat: add UTs for get_trial_metrics.
Signed-off-by: Electronic-Waste <[email protected]>
1 parent ad7de0a commit 85d034f

File tree

2 files changed

+95
-11
lines changed

2 files changed

+95
-11
lines changed

sdk/python/v1beta1/kubeflow/katib/api/katib_client_test.py

Lines changed: 93 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from kubeflow.katib import V1beta1TrialParameterSpec
1414
from kubeflow.katib import V1beta1TrialTemplate
1515
from kubeflow.katib.constants import constants
16+
import kubeflow.katib.katib_api_pb2 as katib_api_pb2
1617
from kubernetes.client import V1ObjectMeta
1718
import pytest
1819

@@ -238,7 +239,7 @@ def create_experiment(
238239

239240

240241
@pytest.fixture
241-
def katib_client():
242+
def katib_client_create_experiment():
242243
with patch(
243244
"kubernetes.client.CustomObjectsApi",
244245
return_value=Mock(
@@ -255,14 +256,103 @@ def katib_client():
255256

256257

257258
@pytest.mark.parametrize("test_name,kwargs,expected_output", test_create_experiment_data)
258-
def test_create_experiment(katib_client, test_name, kwargs, expected_output):
259+
def test_create_experiment(katib_client_create_experiment, test_name, kwargs, expected_output):
259260
"""
260261
test create_experiment function of katib client
261262
"""
262263
print("\n\nExecuting test:", test_name)
263264
try:
264-
katib_client.create_experiment(**kwargs)
265+
katib_client_create_experiment.create_experiment(**kwargs)
265266
assert expected_output == TEST_RESULT_SUCCESS
266267
except Exception as e:
267268
assert type(e) is expected_output
268269
print("test execution complete")
270+
271+
272+
def get_observation_log_response(*args, **kwargs):
273+
if kwargs.get("timeout") == 0:
274+
raise TimeoutError
275+
elif args[0].trial_name == "invalid":
276+
raise RuntimeError
277+
else:
278+
return katib_api_pb2.GetObservationLogReply(
279+
observation_log=katib_api_pb2.ObservationLog(
280+
metric_logs=[
281+
katib_api_pb2.MetricLog(
282+
time_stamp="2024-07-29T15:09:08Z",
283+
metric=katib_api_pb2.Metric(name="result",value="0.99")
284+
)
285+
]
286+
)
287+
)
288+
289+
test_get_trial_metrics_data = [
290+
(
291+
"valid trial name",
292+
{
293+
"name": "example",
294+
"namespace": "valid",
295+
"timeout": constants.DEFAULT_TIMEOUT
296+
},
297+
[
298+
katib_api_pb2.MetricLog(
299+
time_stamp="2024-07-29T15:09:08Z",
300+
metric=katib_api_pb2.Metric(name="result",value="0.99")
301+
)
302+
]
303+
),
304+
(
305+
"invalid trial name",
306+
{
307+
"name": "invalid",
308+
"namespace": "invalid",
309+
"timeout": constants.DEFAULT_TIMEOUT
310+
},
311+
RuntimeError
312+
),
313+
(
314+
"GetObservationLog timeout error",
315+
{
316+
"name": "example",
317+
"namespace": "valid",
318+
"timeout": 0
319+
},
320+
RuntimeError
321+
)
322+
]
323+
324+
325+
@pytest.fixture
326+
def katib_client_get_trial_metrics():
327+
with patch(
328+
"kubernetes.client.CustomObjectsApi",
329+
return_value=Mock(),
330+
), patch(
331+
"kubernetes.config.load_kube_config",
332+
return_value=Mock()
333+
):
334+
client = KatibClient()
335+
yield client
336+
337+
338+
@pytest.fixture
339+
def mock_get_observation_log():
340+
with patch("kubeflow.katib.katib_api_pb2_grpc.DBManagerStub") as mock:
341+
mock_instance = mock.return_value
342+
mock_instance.GetObservationLog.side_effect = get_observation_log_response
343+
yield mock_instance
344+
345+
346+
@pytest.mark.parametrize("test_name,kwargs,expected_output", test_get_trial_metrics_data)
347+
def test_get_trial_metrics(test_name, kwargs, expected_output, katib_client_get_trial_metrics, mock_get_observation_log):
348+
"""
349+
test get_trial_metrics function of katib client
350+
"""
351+
print("\n\nExecuting test:", test_name)
352+
try:
353+
metrics = katib_client_get_trial_metrics.get_trial_metrics(**kwargs)
354+
for i in range(len(metrics)):
355+
assert metrics[i] == expected_output[i]
356+
except Exception as e:
357+
assert type(e) is expected_output
358+
print("test execution complete")

sdk/python/v1beta1/kubeflow/katib/api/report_metrics_test.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ def report_observation_log_response(*args, **kwargs):
2121
"metrics": {
2222
"result": 0.99
2323
},
24-
"db_manager_address": constants.DEFAULT_DB_MANAGER_ADDRESS,
2524
"timeout": constants.DEFAULT_TIMEOUT
2625

2726
},
@@ -34,7 +33,6 @@ def report_observation_log_response(*args, **kwargs):
3433
"metrics": {
3534
"result": "0.99"
3635
},
37-
"db_manager_address": constants.DEFAULT_DB_MANAGER_ADDRESS,
3836
"timeout": constants.DEFAULT_TIMEOUT
3937
},
4038
TEST_RESULT_SUCCESS,
@@ -46,7 +44,6 @@ def report_observation_log_response(*args, **kwargs):
4644
"metrics": {
4745
"result": 1
4846
},
49-
"db_manager_address": constants.DEFAULT_DB_MANAGER_ADDRESS,
5047
"timeout": constants.DEFAULT_TIMEOUT
5148
},
5249
TEST_RESULT_SUCCESS,
@@ -58,7 +55,6 @@ def report_observation_log_response(*args, **kwargs):
5855
"metrics": {
5956
"result": 0.99
6057
},
61-
"db_manager_address": constants.DEFAULT_DB_MANAGER_ADDRESS,
6258
"timeout": 0
6359
},
6460
RuntimeError,
@@ -70,7 +66,6 @@ def report_observation_log_response(*args, **kwargs):
7066
"metrics": {
7167
"result": "abc"
7268
},
73-
"db_manager_address": constants.DEFAULT_DB_MANAGER_ADDRESS,
7469
"timeout": constants.DEFAULT_TIMEOUT
7570
},
7671
ValueError,
@@ -82,7 +77,6 @@ def report_observation_log_response(*args, **kwargs):
8277
"metrics": {
8378
"result": 0.99
8479
},
85-
"db_manager_address": constants.DEFAULT_DB_MANAGER_ADDRESS,
8680
"timeout": constants.DEFAULT_TIMEOUT
8781
},
8882
ValueError,
@@ -117,7 +111,7 @@ def mock_report_observation_log():
117111

118112

119113
@pytest.mark.parametrize(
120-
"test_name, kwargs, expected_output, mock_getenv",
114+
"test_name,kwargs,expected_output,mock_getenv",
121115
test_report_metrics_data,
122116
indirect=["mock_getenv"]
123117
)
@@ -131,4 +125,4 @@ def test_report_metrics(test_name, kwargs, expected_output, mock_getenv, mock_ge
131125
assert expected_output == TEST_RESULT_SUCCESS
132126
except Exception as e:
133127
assert type(e) is expected_output
134-
print("test execution complete")
128+
print("test execution complete")

0 commit comments

Comments
 (0)