Skip to content

Commit be2b26d

Browse files
authored
Validate possible operations for Grid suggestion (#1205)
* Create common function to test validate algorithm settings Validate db exhausted for chocolate * remove parentheses * Use common util to test Suggestions * Fix API name * Fix indexing
1 parent c24d303 commit be2b26d

File tree

5 files changed

+297
-137
lines changed

5 files changed

+297
-137
lines changed

pkg/suggestion/v1beta1/chocolate/service.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,15 @@
1818
from pkg.apis.manager.v1beta1.python import api_pb2
1919
from pkg.apis.manager.v1beta1.python import api_pb2_grpc
2020

21-
from pkg.suggestion.v1beta1.internal.constant import DOUBLE
21+
from pkg.suggestion.v1beta1.internal.constant import INTEGER, DOUBLE, CATEGORICAL, DISCRETE
2222
from pkg.suggestion.v1beta1.internal.search_space import HyperParameterSearchSpace
2323
from pkg.suggestion.v1beta1.internal.trial import Trial, Assignment
2424
from pkg.suggestion.v1beta1.chocolate.base_service import BaseChocolateService
2525
from pkg.suggestion.v1beta1.internal.base_health_service import HealthServicer
2626

27+
import numpy as np
28+
import itertools
29+
2730
logger = logging.getLogger(__name__)
2831

2932

@@ -38,11 +41,32 @@ def ValidateAlgorithmSettings(self, request, context):
3841
if algorithm_name == "grid":
3942
search_space = HyperParameterSearchSpace.convert(
4043
request.experiment)
44+
available_space = {}
4145
for param in search_space.params:
42-
if param.type == DOUBLE:
46+
if param.type == INTEGER:
47+
available_space[param.name] = range(int(param.min), int(param.max)+1, int(param.step))
48+
49+
elif param.type == DOUBLE:
4350
if param.step == "" or param.step is None:
4451
return self._set_validate_context_error(
45-
context, "param {} step is nil".format(param.name))
52+
context, "Param: {} step is nil".format(param.name))
53+
double_list = np.arange(float(param.min), float(param.max)+float(param.step), float(param.step))
54+
if double_list[-1] > float(param.max):
55+
double_list = double_list[:-1]
56+
available_space[param.name] = double_list
57+
58+
elif param.type == CATEGORICAL or param.type == DISCRETE:
59+
available_space[param.name] = param.list
60+
61+
num_combinations = len(list(itertools.product(*available_space.values())))
62+
max_trial_count = request.experiment.spec.max_trial_count
63+
64+
if max_trial_count > num_combinations:
65+
return self._set_validate_context_error(
66+
context, "Max Trial Count: {} > all possible search space combinations: {}".format(
67+
max_trial_count, num_combinations)
68+
)
69+
4670
return api_pb2.ValidateAlgorithmSettingsReply()
4771

4872
def GetSuggestions(self, request, context):

test/suggestion/v1beta1/test_chocolate_service.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020

2121
from pkg.suggestion.v1beta1.chocolate.service import ChocolateService
2222

23+
import utils
24+
2325

2426
class TestChocolate(unittest.TestCase):
2527
def setUp(self):
@@ -181,6 +183,103 @@ def test_get_suggestion(self):
181183
self.assertEqual(code, grpc.StatusCode.OK)
182184
self.assertEqual(2, len(response.parameter_assignments))
183185

186+
def test_validate_algorithm_settings(self):
187+
# Valid case.
188+
experiment_spec = api_pb2.ExperimentSpec(
189+
algorithm=api_pb2.AlgorithmSpec(
190+
algorithm_name="grid",
191+
),
192+
parameter_specs=api_pb2.ExperimentSpec.ParameterSpecs(
193+
parameters=[
194+
api_pb2.ParameterSpec(
195+
name="param-1",
196+
parameter_type=api_pb2.INT,
197+
feasible_space=api_pb2.FeasibleSpace(
198+
max="5", min="1", list=[]),
199+
),
200+
api_pb2.ParameterSpec(
201+
name="param-2",
202+
parameter_type=api_pb2.CATEGORICAL,
203+
feasible_space=api_pb2.FeasibleSpace(
204+
max=None, min=None, list=["cat1", "cat2", "cat3"])
205+
),
206+
api_pb2.ParameterSpec(
207+
name="param-3",
208+
parameter_type=api_pb2.DISCRETE,
209+
feasible_space=api_pb2.FeasibleSpace(
210+
max=None, min=None, list=["3", "2", "6"])
211+
),
212+
api_pb2.ParameterSpec(
213+
name="param-4",
214+
parameter_type=api_pb2.DOUBLE,
215+
feasible_space=api_pb2.FeasibleSpace(
216+
max="2.9", min="1", list=[], step="0.5")
217+
)
218+
]
219+
),
220+
max_trial_count=12,
221+
parallel_trial_count=3,
222+
)
223+
224+
_, _, code, _ = utils.call_validate(self.test_server, experiment_spec)
225+
self.assertEqual(code, grpc.StatusCode.OK)
226+
227+
# Invalid cases.
228+
# Empty step.
229+
experiment_spec = api_pb2.ExperimentSpec(
230+
algorithm=api_pb2.AlgorithmSpec(
231+
algorithm_name="grid",
232+
),
233+
parameter_specs=api_pb2.ExperimentSpec.ParameterSpecs(
234+
parameters=[
235+
api_pb2.ParameterSpec(
236+
name="param-1",
237+
parameter_type=api_pb2.DOUBLE,
238+
feasible_space=api_pb2.FeasibleSpace(
239+
max="3", min="1", list=[])
240+
)
241+
]
242+
),
243+
)
244+
245+
_, _, code, details = utils.call_validate(self.test_server, experiment_spec)
246+
self.assertEqual(code, grpc.StatusCode.INVALID_ARGUMENT)
247+
self.assertEqual(details, 'Param: param-1 step is nil')
248+
249+
# Max trial count > search space combinations.
250+
experiment_spec = api_pb2.ExperimentSpec(
251+
algorithm=api_pb2.AlgorithmSpec(
252+
algorithm_name="grid",
253+
),
254+
parameter_specs=api_pb2.ExperimentSpec.ParameterSpecs(
255+
parameters=[
256+
api_pb2.ParameterSpec(
257+
name="param-1",
258+
parameter_type=api_pb2.INT,
259+
feasible_space=api_pb2.FeasibleSpace(
260+
max="2", min="1", list=[]),
261+
),
262+
api_pb2.ParameterSpec(
263+
name="param-2",
264+
parameter_type=api_pb2.CATEGORICAL,
265+
feasible_space=api_pb2.FeasibleSpace(
266+
max=None, min=None, list=["cat1", "cat2"])
267+
),
268+
api_pb2.ParameterSpec(
269+
name="param-4",
270+
parameter_type=api_pb2.DOUBLE,
271+
feasible_space=api_pb2.FeasibleSpace(
272+
max="2", min="1", list=[], step="0.5")
273+
)
274+
]
275+
),
276+
max_trial_count=15,
277+
)
278+
279+
_, _, code, details = utils.call_validate(self.test_server, experiment_spec)
280+
self.assertEqual(code, grpc.StatusCode.INVALID_ARGUMENT)
281+
self.assertEqual(details, 'Max Trial Count: 15 > all possible search space combinations: 12')
282+
184283

185284
if __name__ == '__main__':
186285
unittest.main()

test/suggestion/v1beta1/test_hyperopt_service.py

Lines changed: 85 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,13 @@
2020
from pkg.apis.manager.v1beta1.python import api_pb2
2121
from pkg.suggestion.v1beta1.hyperopt.service import HyperoptService
2222

23+
import utils
24+
2325

2426
class TestHyperopt(unittest.TestCase):
2527
def setUp(self):
2628
servicers = {
27-
api_pb2.DESCRIPTOR.services_by_name['Suggestion']: HyperoptService(
29+
api_pb2.DESCRIPTOR.services_by_name["Suggestion"]: HyperoptService(
2830
)
2931
}
3032

@@ -189,8 +191,8 @@ def test_get_suggestion(self):
189191

190192
get_suggestion = self.test_server.invoke_unary_unary(
191193
method_descriptor=(api_pb2.DESCRIPTOR
192-
.services_by_name['Suggestion']
193-
.methods_by_name['GetSuggestions']),
194+
.services_by_name["Suggestion"]
195+
.methods_by_name["GetSuggestions"]),
194196
invocation_metadata={},
195197
request=request, timeout=1)
196198

@@ -200,103 +202,116 @@ def test_get_suggestion(self):
200202
self.assertEqual(2, len(response.parameter_assignments))
201203

202204
def test_validate_algorithm_settings(self):
203-
experiment_spec = [None]
204-
205-
def call_validate():
206-
experiment = api_pb2.Experiment(name="test", spec=experiment_spec[0])
207-
request = api_pb2.ValidateAlgorithmSettingsRequest(experiment=experiment)
208-
209-
validate_algorithm_settings = self.test_server.invoke_unary_unary(
210-
method_descriptor=(api_pb2.DESCRIPTOR
211-
.services_by_name['Suggestion']
212-
.methods_by_name['ValidateAlgorithmSettings']),
213-
invocation_metadata={},
214-
request=request, timeout=1)
205+
# Valid cases.
206+
experiment_spec = api_pb2.ExperimentSpec(
207+
algorithm=api_pb2.AlgorithmSpec(
208+
algorithm_name="tpe",
209+
algorithm_settings=[
210+
api_pb2.AlgorithmSetting(
211+
name="random_state",
212+
value="10"
213+
),
214+
api_pb2.AlgorithmSetting(
215+
name="gamma",
216+
value="0.25"
217+
),
218+
api_pb2.AlgorithmSetting(
219+
name="prior_weight",
220+
value="1.0"
221+
),
222+
api_pb2.AlgorithmSetting(
223+
name="n_EI_candidates",
224+
value="24"
225+
),
226+
]
227+
)
228+
)
215229

216-
return validate_algorithm_settings.termination()
230+
_, _, code, _ = utils.call_validate(self.test_server, experiment_spec)
231+
self.assertEqual(code, grpc.StatusCode.OK)
217232

218-
# valid cases
219-
algorithm_spec = api_pb2.AlgorithmSpec(
220-
algorithm_name="tpe",
221-
algorithm_settings=[
222-
api_pb2.AlgorithmSetting(
223-
name="random_state",
224-
value="10"
225-
),
226-
api_pb2.AlgorithmSetting(
227-
name="gamma",
228-
value="0.25"
229-
),
230-
api_pb2.AlgorithmSetting(
231-
name="prior_weight",
232-
value="1.0"
233-
),
234-
api_pb2.AlgorithmSetting(
235-
name="n_EI_candidates",
236-
value="24"
237-
),
238-
],
233+
# Invalid cases.
234+
# Unknown algorithm name.
235+
experiment_spec = api_pb2.ExperimentSpec(
236+
algorithm=api_pb2.AlgorithmSpec(
237+
algorithm_name="unknown"
238+
)
239239
)
240-
experiment_spec[0] = api_pb2.ExperimentSpec(algorithm=algorithm_spec)
241-
self.assertEqual(call_validate()[2], grpc.StatusCode.OK)
242240

243-
# invalid cases
244-
experiment_spec[0] = api_pb2.ExperimentSpec(
245-
algorithm=api_pb2.AlgorithmSpec(algorithm_name="unknown"))
246-
_, _, code, details = call_validate()
241+
_, _, code, details = utils.call_validate(self.test_server, experiment_spec)
247242
self.assertEqual(code, grpc.StatusCode.INVALID_ARGUMENT)
248-
self.assertEqual(details, 'unknown algorithm name unknown')
243+
self.assertEqual(details, "unknown algorithm name unknown")
249244

250-
experiment_spec[0] = api_pb2.ExperimentSpec(
245+
# Unknown algorithm setting name.
246+
experiment_spec = api_pb2.ExperimentSpec(
251247
algorithm=api_pb2.AlgorithmSpec(
252248
algorithm_name="random",
253249
algorithm_settings=[
254-
api_pb2.AlgorithmSetting(name="unknown_conf", value="1111")]
255-
))
256-
_, _, code, details = call_validate()
250+
api_pb2.AlgorithmSetting(name="unknown_conf", value="1111")
251+
]
252+
)
253+
)
254+
255+
_, _, code, details = utils.call_validate(self.test_server, experiment_spec)
257256
self.assertEqual(code, grpc.StatusCode.INVALID_ARGUMENT)
258-
self.assertEqual(details, 'unknown setting unknown_conf for algorithm random')
257+
self.assertEqual(details, "unknown setting unknown_conf for algorithm random")
259258

260-
experiment_spec[0] = api_pb2.ExperimentSpec(
259+
# Invalid gamma value.
260+
experiment_spec = api_pb2.ExperimentSpec(
261261
algorithm=api_pb2.AlgorithmSpec(
262262
algorithm_name="tpe",
263263
algorithm_settings=[
264-
api_pb2.AlgorithmSetting(name="gamma", value="1.5")]
265-
))
266-
_, _, code, details = call_validate()
264+
api_pb2.AlgorithmSetting(name="gamma", value="1.5")
265+
]
266+
)
267+
)
268+
269+
_, _, code, details = utils.call_validate(self.test_server, experiment_spec)
267270
self.assertEqual(code, grpc.StatusCode.INVALID_ARGUMENT)
268-
self.assertEqual(details, 'gamma should be in the range of (0, 1)')
271+
self.assertEqual(details, "gamma should be in the range of (0, 1)")
269272

270-
experiment_spec[0] = api_pb2.ExperimentSpec(
273+
# Invalid n_EI_candidates value.
274+
experiment_spec = api_pb2.ExperimentSpec(
271275
algorithm=api_pb2.AlgorithmSpec(
272276
algorithm_name="tpe",
273277
algorithm_settings=[
274-
api_pb2.AlgorithmSetting(name="n_EI_candidates", value="0")]
275-
))
276-
_, _, code, details = call_validate()
278+
api_pb2.AlgorithmSetting(name="n_EI_candidates", value="0")
279+
]
280+
)
281+
)
282+
283+
_, _, code, details = utils.call_validate(self.test_server, experiment_spec)
277284
self.assertEqual(code, grpc.StatusCode.INVALID_ARGUMENT)
278-
self.assertEqual(details, 'n_EI_candidates should be great than zero')
285+
self.assertEqual(details, "n_EI_candidates should be great than zero")
279286

280-
experiment_spec[0] = api_pb2.ExperimentSpec(
287+
# Invalid random_state value.
288+
experiment_spec = api_pb2.ExperimentSpec(
281289
algorithm=api_pb2.AlgorithmSpec(
282290
algorithm_name="tpe",
283291
algorithm_settings=[
284-
api_pb2.AlgorithmSetting(name="random_state", value="-1")]
285-
))
286-
_, _, code, details = call_validate()
292+
api_pb2.AlgorithmSetting(name="random_state", value="-1")
293+
]
294+
)
295+
)
296+
297+
_, _, code, details = utils.call_validate(self.test_server, experiment_spec)
287298
self.assertEqual(code, grpc.StatusCode.INVALID_ARGUMENT)
288-
self.assertEqual(details, 'random_state should be great or equal than zero')
299+
self.assertEqual(details, "random_state should be great or equal than zero")
289300

290-
experiment_spec[0] = api_pb2.ExperimentSpec(
301+
# Invalid prior_weight value.
302+
experiment_spec = api_pb2.ExperimentSpec(
291303
algorithm=api_pb2.AlgorithmSpec(
292304
algorithm_name="tpe",
293305
algorithm_settings=[
294-
api_pb2.AlgorithmSetting(name="prior_weight", value="aaa")]
295-
))
296-
_, _, code, details = call_validate()
306+
api_pb2.AlgorithmSetting(name="prior_weight", value="aaa")
307+
]
308+
)
309+
)
310+
311+
_, _, code, details = utils.call_validate(self.test_server, experiment_spec)
297312
self.assertEqual(code, grpc.StatusCode.INVALID_ARGUMENT)
298-
self.assertTrue(details.startswith('failed to validate prior_weight(aaa)'))
313+
self.assertTrue(details.startswith("failed to validate prior_weight(aaa)"))
299314

300315

301-
if __name__ == '__main__':
316+
if __name__ == "__main__":
302317
unittest.main()

0 commit comments

Comments
 (0)