Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cmd/suggestion/optuna/v1beta1/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
grpcio>=1.41.1
protobuf>=3.19.5, <=3.20.3
googleapis-common-protos==1.53.0
optuna>=3.0.0
optuna==3.3.0
99 changes: 70 additions & 29 deletions pkg/suggestion/v1beta1/optuna/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@


class OptunaService(api_pb2_grpc.SuggestionServicer, HealthServicer):

def __init__(self):
super(OptunaService, self).__init__()
self.lock = threading.Lock()
Expand All @@ -39,23 +38,29 @@ def GetSuggestions(self, request, context):
Main function to provide suggestion.
"""
with self.lock:
name, config = OptimizerConfiguration.convert_algorithm_spec(request.experiment.spec.algorithm)
name, config = OptimizerConfiguration.convert_algorithm_spec(
request.experiment.spec.algorithm
)
if self.base_service is None:
search_space = HyperParameterSearchSpace.convert(request.experiment)
self.base_service = BaseOptunaService(
algorithm_name=name,
algorithm_config=config,
search_space=search_space)
search_space=search_space,
)

trials = Trial.convert(request.trials)
list_of_assignments = self.base_service.get_suggestions(trials, request.current_request_number)
list_of_assignments = self.base_service.get_suggestions(
trials, request.current_request_number
)
return api_pb2.GetSuggestionsReply(
parameter_assignments=Assignment.generate(list_of_assignments)
)

def ValidateAlgorithmSettings(self, request, context):
is_valid, message = OptimizerConfiguration.validate_algorithm_spec(
request.experiment)
request.experiment
)
if not is_valid:
context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
context.set_details(message)
Expand Down Expand Up @@ -88,7 +93,7 @@ class OptimizerConfiguration(object):
},
"grid": {
"seed": lambda x: int(x),
}
},
}

@classmethod
Expand Down Expand Up @@ -117,11 +122,12 @@ def validate_algorithm_spec(cls, experiment):
algorithm_spec = experiment.spec.algorithm
algorithm_name = algorithm_spec.algorithm_name
algorithm_settings = algorithm_spec.algorithm_settings
parameters = experiment.spec.parameter_specs.parameters

if algorithm_name == "tpe" or algorithm_name == "multivariate-tpe":
return cls._validate_tpe_setting(algorithm_spec)
elif algorithm_name == "cmaes":
return cls._validate_cmaes_setting(algorithm_settings)
return cls._validate_cmaes_setting(algorithm_settings, parameters)
elif algorithm_name == "random":
return cls._validate_random_setting(algorithm_settings)
elif algorithm_name == "grid":
Expand All @@ -138,37 +144,58 @@ def _validate_tpe_setting(cls, algorithm_spec):
try:
if s.name in ["n_startup_trials", "n_ei_candidates", "random_state"]:
if not int(s.value) >= 0:
return False, "{} should be greate or equal than zero".format(s.name)
return False, "{} should be greate or equal than zero".format(
s.name
)
else:
return False, "unknown setting {} for algorithm {}".format(s.name, algorithm_name)
return False, "unknown setting {} for algorithm {}".format(
s.name, algorithm_name
)
except Exception as e:
return False, "failed to validate {name}({value}): {exception}".format(name=s.name, value=s.value,
exception=e)
return False, "failed to validate {name}({value}): {exception}".format(
name=s.name, value=s.value, exception=e
)

return True, ""

@classmethod
def _validate_cmaes_setting(cls, algorithm_settings):
if len(algorithm_settings) < 2:
return False, "cmaes only supports two or more dimensional continuous search space."

def _validate_cmaes_setting(cls, algorithm_settings, parameters):
for s in algorithm_settings:
try:
if s.name == "restart_strategy":
if s.value not in ["ipop", "None", "none"]:
return False, "restart_strategy {} is not supported in CMAES optimization".format(s.value)
return (
False,
"restart_strategy {} is not supported in CMAES optimization".format(
s.value
),
)
elif s.name == "sigma":
if not float(s.value) >= 0:
return False, "sigma should be greate or equal than zero"
elif s.name == "random_state":
if not int(s.value) >= 0:
return False, "random_state should be greate or equal than zero"
else:
return False, "unknown setting {} for algorithm cmaes".format(s.name)
return False, "unknown setting {} for algorithm cmaes".format(
s.name
)

except Exception as e:
return False, "failed to validate {name}({value}): {exception}".format(name=s.name, value=s.value,
exception=e)
return False, "failed to validate {name}({value}): {exception}".format(
name=s.name, value=s.value, exception=e
)

cnt = 0
for p in parameters:
if p.parameter_type == api_pb2.DOUBLE or p.parameter_type == api_pb2.INT:
cnt += 1
if cnt < 2:
return (
False,
"cmaes only supports two or more dimensional continuous search space.",
)

return True, ""

@classmethod
Expand All @@ -179,11 +206,14 @@ def _validate_random_setting(cls, algorithm_settings):
if not int(s.value) >= 0:
return False, ""
else:
return False, "unknown setting {} for algorithm random".format(s.name)
return False, "unknown setting {} for algorithm random".format(
s.name
)

except Exception as e:
return False, "failed to validate {name}({value}): {exception}".format(name=s.name, value=s.value,
exception=e)
return False, "failed to validate {name}({value}): {exception}".format(
name=s.name, value=s.value, exception=e
)

return True, ""

Expand All @@ -201,19 +231,30 @@ def _validate_grid_setting(cls, experiment):
return False, "unknown setting {} for algorithm grid".format(s.name)

except Exception as e:
return False, "failed to validate {name}({value}): {exception}".format(name=s.name, value=s.value,
exception=e)
return False, "failed to validate {name}({value}): {exception}".format(
name=s.name, value=s.value, exception=e
)

try:
combinations = HyperParameterSearchSpace.convert_to_combinations(search_space)
combinations = HyperParameterSearchSpace.convert_to_combinations(
search_space
)
num_combinations = len(list(itertools.product(*combinations.values())))
max_trial_count = experiment.spec.max_trial_count
if max_trial_count > num_combinations:
return False, "Max Trial Count: {max_trial} > all possible search combinations: {combinations}".\
format(max_trial=max_trial_count, combinations=num_combinations)
return (
False,
"Max Trial Count: {max_trial} > all possible search combinations: {combinations}".format(
max_trial=max_trial_count, combinations=num_combinations
),
)

except Exception as e:
return False, "failed to validate parameters({parameters}): {exception}".\
format(parameters=search_space.params, exception=e)
return (
False,
"failed to validate parameters({parameters}): {exception}".format(
parameters=search_space.params, exception=e
),
)

return True, ""
Loading