Skip to content

Commit 9c89dcb

Browse files
committed
Use common util to test Suggestions
1 parent 6a22286 commit 9c89dcb

File tree

3 files changed

+94
-78
lines changed

3 files changed

+94
-78
lines changed

test/suggestion/v1beta1/test_chocolate_service.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ def test_get_suggestion(self):
184184
self.assertEqual(2, len(response.parameter_assignments))
185185

186186
def test_validate_algorithm_settings(self):
187-
# Valid case
187+
# Valid case.
188188
experiment_spec = api_pb2.ExperimentSpec(
189189
algorithm=api_pb2.AlgorithmSpec(
190190
algorithm_name="grid",
@@ -224,8 +224,8 @@ def test_validate_algorithm_settings(self):
224224
_, _, code, _ = utils.call_validate(self.test_server, experiment_spec)
225225
self.assertEqual(code, grpc.StatusCode.OK)
226226

227-
# Invalid cases
228-
# Empty step
227+
# Invalid cases.
228+
# Empty step.
229229
experiment_spec = api_pb2.ExperimentSpec(
230230
algorithm=api_pb2.AlgorithmSpec(
231231
algorithm_name="grid",
@@ -246,7 +246,7 @@ def test_validate_algorithm_settings(self):
246246
self.assertEqual(code, grpc.StatusCode.INVALID_ARGUMENT)
247247
self.assertEqual(details, 'Param: param-1 step is nil')
248248

249-
# Max trial count > search space combinations
249+
# Max trial count > search space combinations.
250250
experiment_spec = api_pb2.ExperimentSpec(
251251
algorithm=api_pb2.AlgorithmSpec(
252252
algorithm_name="grid",

test/suggestion/v1beta1/test_hyperopt_service.py

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
class TestHyperopt(unittest.TestCase):
2727
def setUp(self):
2828
servicers = {
29-
api_pb2.DESCRIPTOR.services_by_name['Suggestion']: HyperoptService(
29+
api_pb2.DESCRIPTOR.services_by_name["Suggestion"]: HyperoptService(
3030
)
3131
}
3232

@@ -191,8 +191,8 @@ def test_get_suggestion(self):
191191

192192
get_suggestion = self.test_server.invoke_unary_unary(
193193
method_descriptor=(api_pb2.DESCRIPTOR
194-
.services_by_name['Suggestion']
195-
.methods_by_name['GetSuggestions']),
194+
.services_by_name["Suggestion"]
195+
.methods_by_name["GetSuggestions"]),
196196
invocation_metadata={},
197197
request=request, timeout=1)
198198

@@ -202,7 +202,7 @@ def test_get_suggestion(self):
202202
self.assertEqual(2, len(response.parameter_assignments))
203203

204204
def test_validate_algorithm_settings(self):
205-
# valid cases
205+
# Valid cases.
206206
experiment_spec = api_pb2.ExperimentSpec(
207207
algorithm=api_pb2.AlgorithmSpec(
208208
algorithm_name="tpe",
@@ -230,16 +230,19 @@ def test_validate_algorithm_settings(self):
230230
_, _, code, _ = utils.call_validate(self.test_server, experiment_spec)
231231
self.assertEqual(code, grpc.StatusCode.OK)
232232

233-
# invalid cases
233+
# Invalid cases.
234+
# Unknown algorithm name.
234235
experiment_spec = api_pb2.ExperimentSpec(
235236
algorithm=api_pb2.AlgorithmSpec(
236237
algorithm_name="unknown"
237238
)
238239
)
240+
239241
_, _, code, details = utils.call_validate(self.test_server, experiment_spec)
240242
self.assertEqual(code, grpc.StatusCode.INVALID_ARGUMENT)
241-
self.assertEqual(details, 'unknown algorithm name unknown')
243+
self.assertEqual(details, "unknown algorithm name unknown")
242244

245+
# Unknown algorithm setting name.
243246
experiment_spec = api_pb2.ExperimentSpec(
244247
algorithm=api_pb2.AlgorithmSpec(
245248
algorithm_name="random",
@@ -248,10 +251,12 @@ def test_validate_algorithm_settings(self):
248251
]
249252
)
250253
)
254+
251255
_, _, code, details = utils.call_validate(self.test_server, experiment_spec)
252256
self.assertEqual(code, grpc.StatusCode.INVALID_ARGUMENT)
253-
self.assertEqual(details, 'unknown setting unknown_conf for algorithm random')
257+
self.assertEqual(details, "unknown setting unknown_conf for algorithm random")
254258

259+
# Invalid gamma value.
255260
experiment_spec = api_pb2.ExperimentSpec(
256261
algorithm=api_pb2.AlgorithmSpec(
257262
algorithm_name="tpe",
@@ -260,10 +265,12 @@ def test_validate_algorithm_settings(self):
260265
]
261266
)
262267
)
268+
263269
_, _, code, details = utils.call_validate(self.test_server, experiment_spec)
264270
self.assertEqual(code, grpc.StatusCode.INVALID_ARGUMENT)
265-
self.assertEqual(details, 'gamma should be in the range of (0, 1)')
271+
self.assertEqual(details, "gamma should be in the range of (0, 1)")
266272

273+
# Invalid n_EI_candidates value.
267274
experiment_spec = api_pb2.ExperimentSpec(
268275
algorithm=api_pb2.AlgorithmSpec(
269276
algorithm_name="tpe",
@@ -272,10 +279,12 @@ def test_validate_algorithm_settings(self):
272279
]
273280
)
274281
)
282+
275283
_, _, code, details = utils.call_validate(self.test_server, experiment_spec)
276284
self.assertEqual(code, grpc.StatusCode.INVALID_ARGUMENT)
277-
self.assertEqual(details, 'n_EI_candidates should be great than zero')
285+
self.assertEqual(details, "n_EI_candidates should be great than zero")
278286

287+
# Invalid random_state value.
279288
experiment_spec = api_pb2.ExperimentSpec(
280289
algorithm=api_pb2.AlgorithmSpec(
281290
algorithm_name="tpe",
@@ -284,10 +293,12 @@ def test_validate_algorithm_settings(self):
284293
]
285294
)
286295
)
296+
287297
_, _, code, details = utils.call_validate(self.test_server, experiment_spec)
288298
self.assertEqual(code, grpc.StatusCode.INVALID_ARGUMENT)
289-
self.assertEqual(details, 'random_state should be great or equal than zero')
299+
self.assertEqual(details, "random_state should be great or equal than zero")
290300

301+
# Invalid prior_weight value.
291302
experiment_spec = api_pb2.ExperimentSpec(
292303
algorithm=api_pb2.AlgorithmSpec(
293304
algorithm_name="tpe",
@@ -296,10 +307,11 @@ def test_validate_algorithm_settings(self):
296307
]
297308
)
298309
)
310+
299311
_, _, code, details = utils.call_validate(self.test_server, experiment_spec)
300312
self.assertEqual(code, grpc.StatusCode.INVALID_ARGUMENT)
301-
self.assertTrue(details.startswith('failed to validate prior_weight(aaa)'))
313+
self.assertTrue(details.startswith("failed to validate prior_weight(aaa)"))
302314

303315

304-
if __name__ == '__main__':
316+
if __name__ == "__main__":
305317
unittest.main()

test/suggestion/v1beta1/test_skopt_service.py

Lines changed: 66 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,13 @@
2020
from pkg.apis.manager.v1beta1.python import api_pb2
2121
from pkg.suggestion.v1beta1.skopt.service import SkoptService
2222

23+
import utils
24+
2325

2426
class TestSkopt(unittest.TestCase):
2527
def setUp(self):
2628
servicers = {
27-
api_pb2.DESCRIPTOR.services_by_name['Suggestion']: SkoptService(
29+
api_pb2.DESCRIPTOR.services_by_name["Suggestion"]: SkoptService(
2830
)
2931
}
3032

@@ -177,8 +179,8 @@ def test_get_suggestion(self):
177179

178180
get_suggestion = self.test_server.invoke_unary_unary(
179181
method_descriptor=(api_pb2.DESCRIPTOR
180-
.services_by_name['Suggestion']
181-
.methods_by_name['GetSuggestions']),
182+
.services_by_name["Suggestion"]
183+
.methods_by_name["GetSuggestions"]),
182184
invocation_metadata={},
183185
request=request, timeout=1)
184186

@@ -188,124 +190,126 @@ def test_get_suggestion(self):
188190
self.assertEqual(2, len(response.parameter_assignments))
189191

190192
def test_validate_algorithm_settings(self):
191-
experiment_spec = [None]
192-
193-
def call_validate():
194-
experiment = api_pb2.Experiment(name="test", spec=experiment_spec[0])
195-
request = api_pb2.ValidateAlgorithmSettingsRequest(experiment=experiment)
196-
197-
validate_algorithm_settings = self.test_server.invoke_unary_unary(
198-
method_descriptor=(api_pb2.DESCRIPTOR
199-
.services_by_name['Suggestion']
200-
.methods_by_name['ValidateAlgorithmSettings']),
201-
invocation_metadata={},
202-
request=request, timeout=1)
193+
# Valid cases.
194+
experiment_spec = api_pb2.ExperimentSpec(
195+
algorithm=api_pb2.AlgorithmSpec(
196+
algorithm_name="bayesianoptimization",
197+
algorithm_settings=[
198+
api_pb2.AlgorithmSetting(
199+
name="random_state",
200+
value="10"
201+
)
202+
],
203+
)
204+
)
203205

204-
return validate_algorithm_settings.termination()
206+
_, _, code, _ = utils.call_validate(self.test_server, experiment_spec)
207+
self.assertEqual(code, grpc.StatusCode.OK)
205208

206-
# valid cases
207-
algorithm_spec = api_pb2.AlgorithmSpec(
208-
algorithm_name="bayesianoptimization",
209-
algorithm_settings=[
210-
api_pb2.AlgorithmSetting(
211-
name="random_state",
212-
value="10"
213-
)
214-
],
209+
# Invalid cases.
210+
# Unknown algorithm name.
211+
experiment_spec = api_pb2.ExperimentSpec(
212+
algorithm=api_pb2.AlgorithmSpec(algorithm_name="unknown")
215213
)
216-
experiment_spec[0] = api_pb2.ExperimentSpec(algorithm=algorithm_spec)
217-
self.assertEqual(call_validate()[2], grpc.StatusCode.OK)
218214

219-
# invalid cases
220-
# unknown algorithm name
221-
experiment_spec[0] = api_pb2.ExperimentSpec(
222-
algorithm=api_pb2.AlgorithmSpec(algorithm_name="unknown"))
223-
_, _, code, details = call_validate()
215+
_, _, code, details = utils.call_validate(self.test_server, experiment_spec)
224216
self.assertEqual(code, grpc.StatusCode.INVALID_ARGUMENT)
225-
self.assertEqual(details, 'unknown algorithm name unknown')
217+
self.assertEqual(details, "unknown algorithm name unknown")
226218

227-
# unknown config name
228-
experiment_spec[0] = api_pb2.ExperimentSpec(
219+
# Unknown config name.
220+
experiment_spec = api_pb2.ExperimentSpec(
229221
algorithm=api_pb2.AlgorithmSpec(
230222
algorithm_name="bayesianoptimization",
231223
algorithm_settings=[
232224
api_pb2.AlgorithmSetting(name="unknown_conf", value="1111")]
233-
))
234-
_, _, code, details = call_validate()
225+
)
226+
)
227+
228+
_, _, code, details = utils.call_validate(self.test_server, experiment_spec)
235229
self.assertEqual(code, grpc.StatusCode.INVALID_ARGUMENT)
236-
self.assertEqual(details, 'unknown setting unknown_conf for algorithm bayesianoptimization')
230+
self.assertEqual(details, "unknown setting unknown_conf for algorithm bayesianoptimization")
237231

238-
# unknown base_estimator
239-
experiment_spec[0] = api_pb2.ExperimentSpec(
232+
# Unknown base_estimator
233+
experiment_spec = api_pb2.ExperimentSpec(
240234
algorithm=api_pb2.AlgorithmSpec(
241235
algorithm_name="bayesianoptimization",
242236
algorithm_settings=[
243237
api_pb2.AlgorithmSetting(name="base_estimator", value="unknown estimator")]
244-
))
245-
_, _, code, details = call_validate()
246-
wrong_algorithm_setting = experiment_spec[0].algorithm.algorithm_settings[0]
238+
)
239+
)
240+
wrong_algorithm_setting = experiment_spec.algorithm.algorithm_settings[0]
241+
242+
_, _, code, details = utils.call_validate(self.test_server, experiment_spec)
247243
self.assertEqual(code, grpc.StatusCode.INVALID_ARGUMENT)
248244
self.assertEqual(details,
249245
"{name} {value} is not supported in Bayesian optimization".format(
250246
name=wrong_algorithm_setting.name,
251247
value=wrong_algorithm_setting.value))
252248

253-
# wrong n_initial_points
254-
experiment_spec[0] = api_pb2.ExperimentSpec(
249+
# Wrong n_initial_points
250+
experiment_spec = api_pb2.ExperimentSpec(
255251
algorithm=api_pb2.AlgorithmSpec(
256252
algorithm_name="bayesianoptimization",
257253
algorithm_settings=[
258254
api_pb2.AlgorithmSetting(name="n_initial_points", value="-1")]
259-
))
260-
_, _, code, details = call_validate()
261-
wrong_algorithm_setting = experiment_spec[0].algorithm.algorithm_settings[0]
255+
)
256+
)
257+
wrong_algorithm_setting = experiment_spec.algorithm.algorithm_settings[0]
258+
259+
_, _, code, details = utils.call_validate(self.test_server, experiment_spec)
262260
self.assertEqual(code, grpc.StatusCode.INVALID_ARGUMENT)
263261
self.assertEqual(details, "{name} should be great or equal than zero".format(name=wrong_algorithm_setting.name))
264262

265-
# unknown acq_func
266-
experiment_spec[0] = api_pb2.ExperimentSpec(
263+
# Unknown acq_func
264+
experiment_spec = api_pb2.ExperimentSpec(
267265
algorithm=api_pb2.AlgorithmSpec(
268266
algorithm_name="bayesianoptimization",
269267
algorithm_settings=[
270268
api_pb2.AlgorithmSetting(name="acq_func", value="unknown")]
271-
))
272-
_, _, code, details = call_validate()
269+
)
270+
)
273271
wrong_algorithm_setting = experiment_spec[0].algorithm.algorithm_settings[0]
272+
273+
_, _, code, details = utils.call_validate(self.test_server, experiment_spec)
274274
self.assertEqual(code, grpc.StatusCode.INVALID_ARGUMENT)
275275
self.assertEqual(details,
276276
"{name} {value} is not supported in Bayesian optimization".format(
277277
name=wrong_algorithm_setting.name,
278278
value=wrong_algorithm_setting.value
279279
))
280280

281-
# unknown acq_optimizer
282-
experiment_spec[0] = api_pb2.ExperimentSpec(
281+
# Unknown acq_optimizer
282+
experiment_spec = api_pb2.ExperimentSpec(
283283
algorithm=api_pb2.AlgorithmSpec(
284284
algorithm_name="bayesianoptimization",
285285
algorithm_settings=[
286286
api_pb2.AlgorithmSetting(name="acq_optimizer", value="unknown")]
287-
))
288-
_, _, code, details = call_validate()
287+
)
288+
)
289289
wrong_algorithm_setting = experiment_spec[0].algorithm.algorithm_settings[0]
290+
291+
_, _, code, details = utils.call_validate(self.test_server, experiment_spec)
290292
self.assertEqual(code, grpc.StatusCode.INVALID_ARGUMENT)
291293
self.assertEqual(details,
292294
"{name} {value} is not supported in Bayesian optimization".format(
293295
name=wrong_algorithm_setting.name,
294296
value=wrong_algorithm_setting.value
295297
))
296298

297-
# wrong random_state
298-
experiment_spec[0] = api_pb2.ExperimentSpec(
299+
# Wrong random_state
300+
experiment_spec = api_pb2.ExperimentSpec(
299301
algorithm=api_pb2.AlgorithmSpec(
300302
algorithm_name="bayesianoptimization",
301303
algorithm_settings=[
302304
api_pb2.AlgorithmSetting(name="random_state", value="-1")]
303-
))
304-
_, _, code, details = call_validate()
305-
wrong_algorithm_setting = experiment_spec[0].algorithm.algorithm_settings[0]
305+
)
306+
)
307+
wrong_algorithm_setting = experiment_spec.algorithm.algorithm_settings[0]
308+
309+
_, _, code, details = utils.call_validate(self.test_server, experiment_spec)
306310
self.assertEqual(code, grpc.StatusCode.INVALID_ARGUMENT)
307311
self.assertEqual(details, "{name} should be great or equal than zero".format(name=wrong_algorithm_setting.name))
308312

309313

310-
if __name__ == '__main__':
314+
if __name__ == "__main__":
311315
unittest.main()

0 commit comments

Comments
 (0)