15
15
import threading
16
16
import grpc
17
17
import logging
18
+ import itertools
18
19
19
20
from pkg .apis .manager .v1beta1 .python import api_pb2
20
21
from pkg .apis .manager .v1beta1 .python import api_pb2_grpc
21
- from pkg .suggestion .v1beta1 .internal .constant import INTEGER , DOUBLE
22
22
from pkg .suggestion .v1beta1 .internal .search_space import HyperParameterSearchSpace
23
23
from pkg .suggestion .v1beta1 .internal .trial import Trial , Assignment
24
24
from pkg .suggestion .v1beta1 .optuna .base_service import BaseOptunaService
@@ -55,7 +55,7 @@ def GetSuggestions(self, request, context):
55
55
56
56
def ValidateAlgorithmSettings (self , request , context ):
57
57
is_valid , message = OptimizerConfiguration .validate_algorithm_spec (
58
- request .experiment . spec . algorithm )
58
+ request .experiment )
59
59
if not is_valid :
60
60
context .set_code (grpc .StatusCode .INVALID_ARGUMENT )
61
61
context .set_details (message )
@@ -86,6 +86,9 @@ class OptimizerConfiguration(object):
86
86
"random" : {
87
87
"seed" : lambda x : int (x ),
88
88
},
89
+ "grid" : {
90
+ "seed" : lambda x : int (x ),
91
+ }
89
92
}
90
93
91
94
@classmethod
@@ -110,7 +113,8 @@ def convert_algorithm_spec(cls, algorithm_spec):
110
113
return algorithm_spec .algorithm_name , config
111
114
112
115
@classmethod
113
- def validate_algorithm_spec (cls , algorithm_spec ):
116
+ def validate_algorithm_spec (cls , experiment ):
117
+ algorithm_spec = experiment .spec .algorithm
114
118
algorithm_name = algorithm_spec .algorithm_name
115
119
algorithm_settings = algorithm_spec .algorithm_settings
116
120
@@ -120,6 +124,10 @@ def validate_algorithm_spec(cls, algorithm_spec):
120
124
return cls ._validate_cmaes_setting (algorithm_settings )
121
125
elif algorithm_name == "random" :
122
126
return cls ._validate_random_setting (algorithm_settings )
127
+ elif algorithm_name == "grid" :
128
+ return cls ._validate_grid_setting (experiment )
129
+ else :
130
+ return False , "unknown algorithm name {}" .format (algorithm_name )
123
131
124
132
@classmethod
125
133
def _validate_tpe_setting (cls , algorithm_spec ):
@@ -178,3 +186,34 @@ def _validate_random_setting(cls, algorithm_settings):
178
186
exception = e )
179
187
180
188
return True , ""
189
+
190
+ @classmethod
191
+ def _validate_grid_setting (cls , experiment ):
192
+ algorithm_settings = experiment .spec .algorithm .algorithm_settings
193
+ search_space = HyperParameterSearchSpace .convert (experiment )
194
+
195
+ for s in algorithm_settings :
196
+ try :
197
+ if s .name == "random_state" :
198
+ if not int (s .value ) >= 0 :
199
+ return False , ""
200
+ else :
201
+ return False , "unknown setting {} for algorithm grid" .format (s .name )
202
+
203
+ except Exception as e :
204
+ return False , "failed to validate {name}({value}): {exception}" .format (name = s .name , value = s .value ,
205
+ exception = e )
206
+
207
+ try :
208
+ combinations = HyperParameterSearchSpace .convert_to_combinations (search_space )
209
+ num_combinations = len (list (itertools .product (* combinations .values ())))
210
+ max_trial_count = experiment .spec .max_trial_count
211
+ if max_trial_count > num_combinations :
212
+ return False , "Max Trial Count: {max_trial} > all possible search combinations: {combinations}" .\
213
+ format (max_trial = max_trial_count , combinations = num_combinations )
214
+
215
+ except Exception as e :
216
+ return False , "failed to validate parameters({parameters}): {exception}" .\
217
+ format (parameters = search_space .params , exception = e )
218
+
219
+ return True , ""
0 commit comments