Skip to content

Commit 85dc9a1

Browse files
committed
added test for create_experiment in katib_client
Signed-off-by: tariq-hasan <[email protected]>
1 parent 7c03cb4 commit 85dc9a1

File tree

3 files changed

+272
-0
lines changed

3 files changed

+272
-0
lines changed

Makefile

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,9 @@ pytest: prepare-pytest prepare-pytest-testdata
178178
PYTHONPATH=$(PYTHONPATH) pytest ./test/unit/v1beta1/suggestion --ignore=./test/unit/v1beta1/suggestion/test_skopt_service.py
179179
PYTHONPATH=$(PYTHONPATH) pytest ./test/unit/v1beta1/earlystopping
180180
PYTHONPATH=$(PYTHONPATH) pytest ./test/unit/v1beta1/metricscollector
181+
cp ./pkg/apis/manager/v1beta1/python/api_pb2.py ./sdk/python/v1beta1/kubeflow/katib/katib_api_pb2.py
182+
PYTHONPATH=$(PYTHONPATH) pytest ./sdk/python/v1beta1/kubeflow/katib
183+
rm ./sdk/python/v1beta1/kubeflow/katib/katib_api_pb2.py
181184

182185
# The skopt service doesn't work appropriately with Python 3.11.
183186
# So, we need to run the test with Python 3.9.
Lines changed: 266 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,266 @@
1+
import multiprocessing
2+
from typing import List, Optional
3+
from unittest.mock import patch, Mock
4+
5+
import pytest
6+
from kubernetes.client import V1ObjectMeta
7+
8+
from kubeflow.katib import KatibClient
9+
from kubeflow.katib import V1beta1AlgorithmSpec
10+
from kubeflow.katib import V1beta1Experiment
11+
from kubeflow.katib import V1beta1ExperimentSpec
12+
from kubeflow.katib import V1beta1FeasibleSpace
13+
from kubeflow.katib import V1beta1ObjectiveSpec
14+
from kubeflow.katib import V1beta1ParameterSpec
15+
from kubeflow.katib import V1beta1TrialParameterSpec
16+
from kubeflow.katib import V1beta1TrialTemplate
17+
from kubeflow.katib.constants import constants
18+
19+
20+
class ConflictException(Exception):
21+
def __init__(self):
22+
self.status = 409
23+
24+
25+
def create_namespaced_custom_object_response(*args, **kwargs):
26+
if args[2] == "timeout":
27+
raise multiprocessing.TimeoutError()
28+
elif args[2] == "conflict":
29+
raise ConflictException()
30+
elif args[2] == "runtime":
31+
raise Exception()
32+
elif args[2] in ("test", "test-name"):
33+
return {"metadata": {"name": "experiment-mnist-ci-test"}}
34+
elif args[2] == "test-generate-name":
35+
return {"metadata": {"name": "12345-experiment-mnist-ci-test"}}
36+
37+
38+
def generate_trial_template() -> V1beta1TrialTemplate:
39+
trial_spec={
40+
"apiVersion": "batch/v1",
41+
"kind": "Job",
42+
"spec": {
43+
"template": {
44+
"metadata": {
45+
"annotations": {
46+
"sidecar.istio.io/inject": "false"
47+
}
48+
},
49+
"spec": {
50+
"containers": [
51+
{
52+
"name": "training-container",
53+
"image": "docker.io/kubeflowkatib/pytorch-mnist-cpu:v0.14.0",
54+
"command": [
55+
"python3",
56+
"/opt/pytorch-mnist/mnist.py",
57+
"--epochs=1",
58+
"--batch-size=64",
59+
"--lr=${trialParameters.learningRate}",
60+
"--momentum=${trialParameters.momentum}",
61+
]
62+
}
63+
],
64+
"restartPolicy": "Never"
65+
}
66+
}
67+
}
68+
}
69+
70+
return V1beta1TrialTemplate(
71+
primary_container_name="training-container",
72+
trial_parameters=[
73+
V1beta1TrialParameterSpec(
74+
name="learningRate",
75+
description="Learning rate for the training model",
76+
reference="lr"
77+
),
78+
V1beta1TrialParameterSpec(
79+
name="momentum",
80+
description="Momentum for the training model",
81+
reference="momentum"
82+
),
83+
],
84+
trial_spec=trial_spec
85+
)
86+
87+
88+
def generate_experiment(
89+
metadata: V1ObjectMeta,
90+
algorithm_spec: V1beta1AlgorithmSpec,
91+
objective_spec: V1beta1ObjectiveSpec,
92+
parameters: List[V1beta1ParameterSpec],
93+
trial_template: V1beta1TrialTemplate,
94+
) -> V1beta1Experiment:
95+
return V1beta1Experiment(
96+
api_version=constants.API_VERSION,
97+
kind=constants.EXPERIMENT_KIND,
98+
metadata=metadata,
99+
spec=V1beta1ExperimentSpec(
100+
max_trial_count=3,
101+
parallel_trial_count=2,
102+
max_failed_trial_count=1,
103+
algorithm=algorithm_spec,
104+
objective=objective_spec,
105+
parameters=parameters,
106+
trial_template=trial_template,
107+
)
108+
)
109+
110+
111+
def create_experiment(
112+
name: Optional[str] = None,
113+
generate_name: Optional[str] = None
114+
) -> V1beta1Experiment:
115+
experiment_namespace = "test"
116+
117+
if name is not None:
118+
metadata = V1ObjectMeta(name=name, namespace=experiment_namespace)
119+
elif generate_name is not None:
120+
metadata = V1ObjectMeta(generate_name=generate_name, namespace=experiment_namespace)
121+
else:
122+
metadata = V1ObjectMeta(namespace=experiment_namespace)
123+
124+
algorithm_spec=V1beta1AlgorithmSpec(
125+
algorithm_name="random"
126+
)
127+
128+
objective_spec=V1beta1ObjectiveSpec(
129+
type="minimize",
130+
goal= 0.001,
131+
objective_metric_name="loss",
132+
)
133+
134+
parameters=[
135+
V1beta1ParameterSpec(
136+
name="lr",
137+
parameter_type="double",
138+
feasible_space=V1beta1FeasibleSpace(
139+
min="0.01",
140+
max="0.06"
141+
),
142+
),
143+
V1beta1ParameterSpec(
144+
name="momentum",
145+
parameter_type="double",
146+
feasible_space=V1beta1FeasibleSpace(
147+
min="0.5",
148+
max="0.9"
149+
),
150+
),
151+
]
152+
153+
trial_template = generate_trial_template()
154+
155+
experiment = generate_experiment(
156+
metadata,
157+
algorithm_spec,
158+
objective_spec,
159+
parameters,
160+
trial_template
161+
)
162+
return experiment
163+
164+
165+
test_create_experiment_data = [
166+
(
167+
"experiment name and generate_name missing",
168+
{"experiment": create_experiment()},
169+
ValueError,
170+
),
171+
(
172+
"create_namespaced_custom_object timeout error",
173+
{
174+
"experiment": create_experiment(name="experiment-mnist-ci-test"),
175+
"namespace": "timeout",
176+
},
177+
TimeoutError,
178+
),
179+
(
180+
"create_namespaced_custom_object conflict error",
181+
{
182+
"experiment": create_experiment(name="experiment-mnist-ci-test"),
183+
"namespace": "conflict",
184+
},
185+
Exception,
186+
),
187+
(
188+
"create_namespaced_custom_object runtime error",
189+
{
190+
"experiment": create_experiment(name="experiment-mnist-ci-test"),
191+
"namespace": "runtime",
192+
},
193+
RuntimeError,
194+
),
195+
(
196+
"valid flow with experiment type V1beta1Experiment and name",
197+
{
198+
"experiment": create_experiment(name="experiment-mnist-ci-test"),
199+
"namespace": "test-name",
200+
},
201+
constants.TEST_RESULT_SUCCESS,
202+
),
203+
(
204+
"valid flow with experiment type V1beta1Experiment and generate_name",
205+
{
206+
"experiment": create_experiment(generate_name="experiment-mnist-ci-test"),
207+
"namespace": "test-generate-name",
208+
},
209+
constants.TEST_RESULT_SUCCESS,
210+
),
211+
(
212+
"valid flow with experiment JSON and name",
213+
{
214+
"experiment": {
215+
"metadata": {
216+
"name": "experiment-mnist-ci-test",
217+
}
218+
},
219+
"namespace": "test-name",
220+
},
221+
constants.TEST_RESULT_SUCCESS,
222+
),
223+
(
224+
"valid flow with experiment JSON and generate_name",
225+
{
226+
"experiment": {
227+
"metadata": {
228+
"generate_name": "experiment-mnist-ci-test",
229+
}
230+
},
231+
"namespace": "test-generate-name",
232+
},
233+
constants.TEST_RESULT_SUCCESS,
234+
),
235+
]
236+
237+
238+
@pytest.fixture
239+
def katib_client():
240+
with patch(
241+
"kubernetes.client.CustomObjectsApi",
242+
return_value=Mock(
243+
create_namespaced_custom_object=Mock(
244+
side_effect=create_namespaced_custom_object_response
245+
)
246+
),
247+
), patch(
248+
"kubernetes.config.load_kube_config",
249+
return_value=Mock()
250+
):
251+
client = KatibClient()
252+
yield client
253+
254+
255+
@pytest.mark.parametrize("test_name,kwargs,expected_output", test_create_experiment_data)
256+
def test_create_experiment(katib_client, test_name, kwargs, expected_output):
257+
"""
258+
test create_experiment function of katib client
259+
"""
260+
print("\n\nExecuting test:", test_name)
261+
try:
262+
katib_client.create_experiment(**kwargs)
263+
assert expected_output == constants.TEST_RESULT_SUCCESS
264+
except Exception as e:
265+
assert type(e) is expected_output
266+
print("test execution complete")

sdk/python/v1beta1/kubeflow/katib/constants/constants.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,3 +57,6 @@
5757
BASE_IMAGE_MXNET = "docker.io/mxnet/python:1.9.1_native_py3"
5858

5959
DEFAULT_DB_MANAGER_ADDRESS = "katib-db-manager.kubeflow:6789"
60+
61+
# Test result constants
62+
TEST_RESULT_SUCCESS = "success"

0 commit comments

Comments
 (0)