26
26
from pkg .suggestion .v1beta1 .nas .enas .AlgorithmSettings import (
27
27
parseAlgorithmSettings , algorithmSettingsValidator , enableNoneSettingsList )
28
28
from pkg .suggestion .v1beta1 .internal .base_health_service import HealthServicer
29
+ from pkg .suggestion .v1beta1 .nas .common .validation import validate_operations
29
30
30
31
31
32
class EnasExperiment :
@@ -161,67 +162,29 @@ def __init__(self, logger=None):
161
162
162
163
def ValidateAlgorithmSettings (self , request , context ):
163
164
self .logger .info ("Validate Algorithm Settings start" )
164
- graph_config = request .experiment .spec .nas_config .graph_config
165
+ nas_config = request .experiment .spec .nas_config
166
+ graph_config = nas_config .graph_config
165
167
166
168
# Validate GraphConfig
167
169
# Check InputSize
168
170
if not graph_config .input_sizes :
169
- return self .SetValidateContextError (context , "Missing InputSizes in GraphConfig:\n {}" .format (graph_config ))
171
+ return self .set_validate_context_error (context ,
172
+ "Missing InputSizes in GraphConfig:\n {}" .format (graph_config ))
170
173
171
174
# Check OutputSize
172
175
if not graph_config .output_sizes :
173
- return self .SetValidateContextError (context , "Missing OutputSizes in GraphConfig:\n {}" .format (graph_config ))
176
+ return self .set_validate_context_error (context ,
177
+ "Missing OutputSizes in GraphConfig:\n {}" .format (graph_config ))
174
178
175
179
# Check NumLayers
176
180
if not graph_config .num_layers :
177
- return self .SetValidateContextError (context , "Missing NumLayers in GraphConfig:\n {}" .format (graph_config ))
178
-
179
- # Validate each operation
180
- operations_list = list (
181
- request .experiment .spec .nas_config .operations .operation )
182
- for operation in operations_list :
183
-
184
- # Check OperationType
185
- if not operation .operation_type :
186
- return self .SetValidateContextError (context , "Missing operationType in Operation:\n {}" .format (
187
- operation ))
188
-
189
- # Check ParameterConfigs
190
- if not operation .parameter_specs .parameters :
191
- return self .SetValidateContextError (context , "Missing ParameterConfigs in Operation:\n {}" .format (
192
- operation ))
193
-
194
- # Validate each ParameterConfig in Operation
195
- parameters_list = list (operation .parameter_specs .parameters )
196
- for parameter in parameters_list :
197
-
198
- # Check Name
199
- if not parameter .name :
200
- return self .SetValidateContextError (context , "Missing Name in ParameterConfig:\n {}" .format (
201
- parameter ))
202
-
203
- # Check ParameterType
204
- if not parameter .parameter_type :
205
- return self .SetValidateContextError (context , "Missing ParameterType in ParameterConfig:\n {}" .format (
206
- parameter ))
207
-
208
- # Check List in Categorical or Discrete Type
209
- if parameter .parameter_type == api_pb2 .CATEGORICAL or parameter .parameter_type == api_pb2 .DISCRETE :
210
- if not parameter .feasible_space .list :
211
- return self .SetValidateContextError (
212
- context , "Missing List in ParameterConfig.feasibleSpace:\n {}" .format (parameter ))
213
-
214
- # Check Max, Min, Step in Int or Double Type
215
- elif parameter .parameter_type == api_pb2 .INT or parameter .parameter_type == api_pb2 .DOUBLE :
216
- if not parameter .feasible_space .min and not parameter .feasible_space .max :
217
- return self .SetValidateContextError (
218
- context , "Missing Max and Min in ParameterConfig.feasibleSpace:\n {}" .format (parameter ))
219
-
220
- if (parameter .parameter_type == api_pb2 .DOUBLE and
221
- (not parameter .feasible_space .step or float (parameter .feasible_space .step ) <= 0 )):
222
- return self .SetValidateContextError (
223
- context , "Step parameter should be > 0 in ParameterConfig.feasibleSpace:\n {}" .format (
224
- parameter ))
181
+ return self .set_validate_context_error (context ,
182
+ "Missing NumLayers in GraphConfig:\n {}" .format (graph_config ))
183
+
184
+ # Validate Operations
185
+ is_valid , message = validate_operations (nas_config .operations .operation )
186
+ if not is_valid :
187
+ return self .set_validate_context_error (context , message )
225
188
226
189
# Validate Algorithm Settings
227
190
settings_raw = request .experiment .spec .algorithm .algorithm_settings
@@ -233,14 +196,15 @@ def ValidateAlgorithmSettings(self, request, context):
233
196
setting_range = algorithmSettingsValidator [setting .name ][1 ]
234
197
try :
235
198
converted_value = setting_type (setting .value )
236
- except Exception :
237
- return self .SetValidateContextError (context , "Algorithm Setting {} must be {} type" .format (
238
- setting .name , setting_type .__name__ ))
199
+ except Exception as e :
200
+ return self .set_validate_context_error (context ,
201
+ "Algorithm Setting {} must be {} type: exception {}" .format (
202
+ setting .name , setting_type .__name__ , e ))
239
203
240
204
if setting_type == float :
241
205
if (converted_value <= setting_range [0 ] or
242
206
(setting_range [1 ] != 'inf' and converted_value > setting_range [1 ])):
243
- return self .SetValidateContextError (
207
+ return self .set_validate_context_error (
244
208
context , "Algorithm Setting {}: {} with {} type must be in range ({}, {}]" .format (
245
209
setting .name ,
246
210
converted_value ,
@@ -250,7 +214,7 @@ def ValidateAlgorithmSettings(self, request, context):
250
214
)
251
215
252
216
elif converted_value < setting_range [0 ]:
253
- return self .SetValidateContextError (
217
+ return self .set_validate_context_error (
254
218
context , "Algorithm Setting {}: {} with {} type must be in range [{}, {})" .format (
255
219
setting .name ,
256
220
converted_value ,
@@ -259,12 +223,13 @@ def ValidateAlgorithmSettings(self, request, context):
259
223
setting_range [1 ])
260
224
)
261
225
else :
262
- return self .SetValidateContextError (context , "Unknown Algorithm Setting name: {}" .format (setting .name ))
226
+ return self .set_validate_context_error (context ,
227
+ "Unknown Algorithm Setting name: {}" .format (setting .name ))
263
228
264
229
self .logger .info ("All Experiment Settings are Valid" )
265
230
return api_pb2 .ValidateAlgorithmSettingsReply ()
266
231
267
- def SetValidateContextError (self , context , error_message ):
232
+ def set_validate_context_error (self , context , error_message ):
268
233
context .set_code (grpc .StatusCode .INVALID_ARGUMENT )
269
234
context .set_details (error_message )
270
235
self .logger .info (error_message )
0 commit comments