Skip to content

Commit 2cb4d86

Browse files
committed
Support for grid search algorithm in Optuna Suggestion Service
Signed-off-by: Yuki Iwai <[email protected]>
1 parent db72ce1 commit 2cb4d86

File tree

5 files changed

+162
-274
lines changed

5 files changed

+162
-274
lines changed

manifests/v1beta1/components/controller/katib-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ data:
3131
"image": "docker.io/kubeflowkatib/suggestion-hyperopt:latest"
3232
},
3333
"grid": {
34-
"image": "docker.io/kubeflowkatib/suggestion-chocolate:latest"
34+
"image": "docker.io/kubeflowkatib/suggestion-optuna:latest"
3535
},
3636
"hyperband": {
3737
"image": "docker.io/kubeflowkatib/suggestion-hyperband:latest"

pkg/suggestion/v1beta1/internal/search_space.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,12 @@
1313
# limitations under the License.
1414

1515
import logging
16+
import numpy as np
1617

1718
from pkg.apis.manager.v1beta1.python import api_pb2 as api
19+
from pkg.suggestion.v1beta1.internal.constant import INTEGER, DOUBLE, CATEGORICAL, DISCRETE
1820
import pkg.suggestion.v1beta1.internal.constant as constant
1921

20-
2122
logging.basicConfig(level=logging.DEBUG)
2223
logger = logging.getLogger(__name__)
2324

@@ -36,15 +37,38 @@ def convert(experiment):
3637
search_space.goal = constant.MIN_GOAL
3738
for p in experiment.spec.parameter_specs.parameters:
3839
search_space.params.append(
39-
HyperParameterSearchSpace.convertParameter(p))
40+
HyperParameterSearchSpace.convert_parameter(p))
4041
return search_space
4142

43+
@staticmethod
44+
def convert_to_combinations(search_space):
45+
combinations = {}
46+
47+
for parameter in search_space.params:
48+
if parameter.type == INTEGER:
49+
combinations[parameter.name] = range(int(parameter.min), int(parameter.max)+1, int(parameter.step))
50+
elif parameter.type == DOUBLE:
51+
if parameter.step == "" or parameter.step is None:
52+
raise Exception(
53+
"Param {} step is nil; For discrete search space, all parameters must include step".
54+
format(parameter.name)
55+
)
56+
double_list = np.arange(float(parameter.min), float(parameter.max)+float(parameter.step),
57+
float(parameter.step))
58+
if double_list[-1] > float(parameter.max):
59+
double_list = double_list[:-1]
60+
combinations[parameter.name] = double_list
61+
elif parameter.type == CATEGORICAL or parameter.type == DISCRETE:
62+
combinations[parameter.name] = parameter.list
63+
64+
return combinations
65+
4266
def __str__(self):
4367
return "HyperParameterSearchSpace(goal: {}, ".format(self.goal) + \
4468
"params: {})".format(", ".join([element.__str__() for element in self.params]))
4569

4670
@staticmethod
47-
def convertParameter(p):
71+
def convert_parameter(p):
4872
if p.parameter_type == api.INT:
4973
# Default value for INT parameter step is 1
5074
step = 1

pkg/suggestion/v1beta1/optuna/base_service.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
from pkg.suggestion.v1beta1.internal.constant import INTEGER, DOUBLE, CATEGORICAL, DISCRETE, MAX_GOAL
1919
from pkg.suggestion.v1beta1.internal.trial import Assignment
20+
from pkg.suggestion.v1beta1.internal.search_space import HyperParameterSearchSpace
2021

2122

2223
class BaseOptunaService(object):
@@ -48,6 +49,10 @@ def _create_sampler(self):
4849
elif self.algorithm_name == "random":
4950
return optuna.samplers.RandomSampler(**self.algorithm_config)
5051

52+
elif self.algorithm_name == "grid":
53+
combinations = HyperParameterSearchSpace.convert_to_combinations(self.search_space)
54+
return optuna.samplers.GridSampler(combinations, **self.algorithm_config)
55+
5156
def get_suggestions(self, trials, current_request_number):
5257
if len(trials) != 0:
5358
self._tell(trials)

pkg/suggestion/v1beta1/optuna/service.py

Lines changed: 42 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,10 @@
1515
import threading
1616
import grpc
1717
import logging
18+
import itertools
1819

1920
from pkg.apis.manager.v1beta1.python import api_pb2
2021
from pkg.apis.manager.v1beta1.python import api_pb2_grpc
21-
from pkg.suggestion.v1beta1.internal.constant import INTEGER, DOUBLE
2222
from pkg.suggestion.v1beta1.internal.search_space import HyperParameterSearchSpace
2323
from pkg.suggestion.v1beta1.internal.trial import Trial, Assignment
2424
from pkg.suggestion.v1beta1.optuna.base_service import BaseOptunaService
@@ -55,7 +55,7 @@ def GetSuggestions(self, request, context):
5555

5656
def ValidateAlgorithmSettings(self, request, context):
5757
is_valid, message = OptimizerConfiguration.validate_algorithm_spec(
58-
request.experiment.spec.algorithm)
58+
request.experiment)
5959
if not is_valid:
6060
context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
6161
context.set_details(message)
@@ -86,6 +86,9 @@ class OptimizerConfiguration(object):
8686
"random": {
8787
"seed": lambda x: int(x),
8888
},
89+
"grid": {
90+
"seed": lambda x: int(x),
91+
}
8992
}
9093

9194
@classmethod
@@ -110,7 +113,8 @@ def convert_algorithm_spec(cls, algorithm_spec):
110113
return algorithm_spec.algorithm_name, config
111114

112115
@classmethod
113-
def validate_algorithm_spec(cls, algorithm_spec):
116+
def validate_algorithm_spec(cls, experiment):
117+
algorithm_spec = experiment.spec.algorithm
114118
algorithm_name = algorithm_spec.algorithm_name
115119
algorithm_settings = algorithm_spec.algorithm_settings
116120

@@ -120,6 +124,10 @@ def validate_algorithm_spec(cls, algorithm_spec):
120124
return cls._validate_cmaes_setting(algorithm_settings)
121125
elif algorithm_name == "random":
122126
return cls._validate_random_setting(algorithm_settings)
127+
elif algorithm_name == "grid":
128+
return cls._validate_grid_setting(experiment)
129+
else:
130+
return False, "unknown algorithm name {}".format(algorithm_name)
123131

124132
@classmethod
125133
def _validate_tpe_setting(cls, algorithm_spec):
@@ -178,3 +186,34 @@ def _validate_random_setting(cls, algorithm_settings):
178186
exception=e)
179187

180188
return True, ""
189+
190+
@classmethod
191+
def _validate_grid_setting(cls, experiment):
192+
algorithm_settings = experiment.spec.algorithm.algorithm_settings
193+
search_space = HyperParameterSearchSpace.convert(experiment)
194+
195+
for s in algorithm_settings:
196+
try:
197+
if s.name == "random_state":
198+
if not int(s.value) >= 0:
199+
return False, ""
200+
else:
201+
return False, "unknown setting {} for algorithm grid".format(s.name)
202+
203+
except Exception as e:
204+
return False, "failed to validate {name}({value}): {exception}".format(name=s.name, value=s.value,
205+
exception=e)
206+
207+
try:
208+
combinations = HyperParameterSearchSpace.convert_to_combinations(search_space)
209+
num_combinations = len(list(itertools.product(*combinations.values())))
210+
max_trial_count = experiment.spec.max_trial_count
211+
if max_trial_count > num_combinations:
212+
return False, "Max Trial Count: {max_trial} > all possible search combinations: {combinations}".\
213+
format(max_trial=max_trial_count, combinations=num_combinations)
214+
215+
except Exception as e:
216+
return False, "failed to validate parameters({parameters}): {exception}".\
217+
format(parameters=search_space.params, exception=e)
218+
219+
return True, ""

0 commit comments

Comments
 (0)