Skip to content

Commit cae12e6

Browse files
tenzen-y and anencore94 authored
Implement validations for darts suggestion service (#1926)
* implement validation for darts service
* Update pkg/suggestion/v1beta1/nas/darts/service.py Co-authored-by: Jaeyeon Kim(김재연) <[email protected]>
* Update pkg/suggestion/v1beta1/nas/darts/service.py Co-authored-by: Jaeyeon Kim(김재연) <[email protected]>
* [review] delete todo comment
* [review] change function name validate_algorithm_settings to validate_algorithm_spec
* [review] fix violation comments
* [review] fix condition to validate batch_size
* [review] add comment for developers
* [review] use set instead of list

Co-authored-by: Jaeyeon Kim(김재연) <[email protected]>
1 parent 478e01d commit cae12e6

File tree

7 files changed

+398
-70
lines changed

7 files changed

+398
-70
lines changed
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
# Copyright 2022 The Kubeflow Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from pkg.apis.manager.v1beta1.python import api_pb2
16+
17+
18+
def validate_operations(operations: list[api_pb2.Operation]) -> tuple[bool, str]:
    """Validate the NAS operations of an experiment.

    Every operation must carry an operation type and at least one parameter.
    Every parameter must carry a name, a parameter type and a feasible space
    matching that type:

    * CATEGORICAL / DISCRETE: ``feasible_space.list`` must be non-empty.
    * INT / DOUBLE: at least one of ``feasible_space.min`` and
      ``feasible_space.max`` must be set (only the case where both are
      missing is rejected).
    * DOUBLE: ``feasible_space.step`` must additionally be set and parse as
      a number greater than zero.

    Args:
        operations: Operations to check.

    Returns:
        ``(True, "")`` if all operations are valid, otherwise
        ``(False, message)`` where ``message`` describes the first violation
        that was found.
    """
    for operation in operations:

        # Check OperationType
        if not operation.operation_type:
            return False, "Missing operationType in Operation:\n{}".format(operation)

        # Check ParameterConfigs
        if not operation.parameter_specs.parameters:
            return False, "Missing ParameterConfigs in Operation:\n{}".format(operation)

        # Validate each ParameterConfig in Operation
        for parameter in operation.parameter_specs.parameters:

            # Check Name
            if not parameter.name:
                return False, "Missing Name in ParameterConfig:\n{}".format(parameter)

            # Check ParameterType
            if not parameter.parameter_type:
                return False, "Missing ParameterType in ParameterConfig:\n{}".format(parameter)

            feasible_space = parameter.feasible_space

            # Check List in Categorical or Discrete Type
            if parameter.parameter_type in (api_pb2.CATEGORICAL, api_pb2.DISCRETE):
                if not feasible_space.list:
                    return False, "Missing List in ParameterConfig.feasibleSpace:\n{}".format(parameter)

            # Check Max, Min, Step in Int or Double Type
            elif parameter.parameter_type in (api_pb2.INT, api_pb2.DOUBLE):
                if not feasible_space.min and not feasible_space.max:
                    return False, "Missing Max and Min in ParameterConfig.feasibleSpace:\n{}".format(parameter)

                if parameter.parameter_type == api_pb2.DOUBLE:
                    # Only the conversion can fail, so only it is guarded.
                    try:
                        step = float(feasible_space.step) if feasible_space.step else None
                    except (TypeError, ValueError) as e:
                        return False, \
                            "failed to validate ParameterConfig.feasibleSpace \n{parameter}:\n{exception}".format(
                                parameter=parameter, exception=e)

                    # "not step > 0" rather than "step <= 0": a NaN step compares
                    # False against everything and must be rejected as well.
                    if step is None or not step > 0:
                        return False, \
                            "Step parameter should be > 0 in ParameterConfig.feasibleSpace:\n{}".format(parameter)

    return True, ""

pkg/suggestion/v1beta1/nas/darts/README.md

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,6 @@ Currently, it supports running only on single GPU and second-order approximation
6969

7070
- Integrate E2E test in CI. Create simple example, which can run on CPU.
7171

72-
- Add validation to Suggestion service.
73-
7472
- Support multi GPU training. Add functionality to select GPU for training.
7573

7674
- Support DARTS in Katib UI.

pkg/suggestion/v1beta1/nas/darts/service.py

Lines changed: 67 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,12 @@
1515
import logging
1616
from logging import getLogger, StreamHandler, INFO
1717
import json
18+
import grpc
1819

1920
from pkg.suggestion.v1beta1.internal.base_health_service import HealthServicer
2021
from pkg.apis.manager.v1beta1.python import api_pb2
2122
from pkg.apis.manager.v1beta1.python import api_pb2_grpc
23+
from pkg.suggestion.v1beta1.nas.common.validation import validate_operations
2224

2325

2426
class DartsService(api_pb2_grpc.SuggestionServicer, HealthServicer):
@@ -36,8 +38,12 @@ def __init__(self):
3638
self.logger.addHandler(handler)
3739
self.logger.propagate = False
3840

39-
# TODO: Add validation
4041
def ValidateAlgorithmSettings(self, request, context):
    """Validate the experiment spec carried by ``request``.

    Delegates to ``validate_algorithm_spec``. On a violation the gRPC
    context is marked INVALID_ARGUMENT with the violation message as
    details, and the message is logged at error level. An empty
    ``ValidateAlgorithmSettingsReply`` is returned in both cases.
    """
    spec_ok, error_message = validate_algorithm_spec(request.experiment.spec)
    if spec_ok:
        return api_pb2.ValidateAlgorithmSettingsReply()

    # The reply message has no payload; the failure is reported through the context.
    context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
    context.set_details(error_message)
    self.logger.error(error_message)
    return api_pb2.ValidateAlgorithmSettingsReply()
4248

4349
def GetSuggestions(self, request, context):
@@ -130,7 +136,66 @@ def get_algorithm_settings(settings_raw):
130136

131137
for setting in settings_raw:
132138
s_name = setting.name
133-
s_value = setting.value
139+
s_value = None if setting.value == "None" else setting.value
134140
algorithm_settings_default[s_name] = s_value
135141

136142
return algorithm_settings_default
143+
144+
145+
def validate_algorithm_spec(spec: api_pb2.ExperimentSpec) -> tuple[bool, str]:
    """Validate the parts of an experiment spec used by the DARTS suggestion.

    The NAS operations are checked first, then the algorithm settings; the
    first violation found is reported.

    Args:
        spec: Experiment spec to validate.

    Returns:
        ``(True, "")`` if the spec is valid, otherwise ``(False, message)``
        with the message produced by the failing validator.
    """
    # Validate Operations
    is_valid, message = validate_operations(spec.nas_config.operations.operation)
    if not is_valid:
        return False, message

    # Validate AlgorithmSettings
    is_valid, message = validate_algorithm_settings(spec.algorithm.algorithm_settings)
    if not is_valid:
        return False, message

    return True, ""
157+
158+
159+
# validate_algorithm_settings is implemented based on quark0/darts and pt.darts.
160+
# quark0/darts: https://github.com/quark0/darts
161+
# pt.darts: https://github.com/khanrc/pt.darts
162+
def validate_algorithm_settings(algorithm_settings: list[api_pb2.AlgorithmSetting]) -> tuple[bool, str]:
    """Validate the DARTS algorithm settings.

    Rules applied per setting name (settings with other names are ignored):

    * ``num_epochs``: integer greater than zero.
    * ``w_lr``, ``w_lr_min``, ``alpha_lr``, ``w_weight_decay``,
      ``alpha_weight_decay``, ``w_momentum``, ``w_grad_clip``: float greater
      than or equal to zero.
    * ``batch_size``: the string ``"None"`` or an integer >= 1.
    * ``num_workers``: integer >= 0.
    * ``init_channels``, ``print_step``, ``num_nodes``, ``stem_multiplier``:
      integer >= 1.

    Args:
        algorithm_settings: Settings whose ``name`` and ``value`` are checked.

    Returns:
        ``(True, "")`` if every setting is valid, otherwise
        ``(False, message)`` describing the first invalid setting. A value
        that cannot be converted to the required numeric type is reported the
        same way instead of raising.
    """
    # Learning rates, weight decays, momentum and gradient clipping share one rule.
    non_negative_floats = {
        "w_lr", "w_lr_min", "alpha_lr",
        "w_weight_decay", "alpha_weight_decay",
        "w_momentum", "w_grad_clip",
    }
    positive_ints = {"init_channels", "print_step", "num_nodes", "stem_multiplier"}

    for s in algorithm_settings:
        try:
            if s.name == "num_epochs":
                if not int(s.value) > 0:
                    return False, "{} should be greater than zero".format(s.name)

            elif s.name in non_negative_floats:
                # "not x >= 0.0" also rejects NaN, which compares False to everything.
                if not float(s.value) >= 0.0:
                    return False, "{} should be greater than or equal to zero".format(s.name)

            elif s.name == "batch_size":
                # "None" is an accepted placeholder and is not parsed as a number.
                if s.value != "None" and not int(s.value) >= 1:
                    return False, "batch_size should be greater than or equal to one"

            elif s.name == "num_workers":
                if not int(s.value) >= 0:
                    return False, "num_workers should be greater than or equal to zero"

            elif s.name in positive_ints:
                if not int(s.value) >= 1:
                    return False, "{} should be greater than or equal to one".format(s.name)

        except (TypeError, ValueError) as e:
            # Raised by int()/float() when the value is not a valid number.
            return False, "failed to validate {name}({value}): {exception}".format(name=s.name, value=s.value,
                                                                                   exception=e)

    return True, ""

pkg/suggestion/v1beta1/nas/enas/service.py

Lines changed: 23 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from pkg.suggestion.v1beta1.nas.enas.AlgorithmSettings import (
2727
parseAlgorithmSettings, algorithmSettingsValidator, enableNoneSettingsList)
2828
from pkg.suggestion.v1beta1.internal.base_health_service import HealthServicer
29+
from pkg.suggestion.v1beta1.nas.common.validation import validate_operations
2930

3031

3132
class EnasExperiment:
@@ -161,67 +162,29 @@ def __init__(self, logger=None):
161162

162163
def ValidateAlgorithmSettings(self, request, context):
163164
self.logger.info("Validate Algorithm Settings start")
164-
graph_config = request.experiment.spec.nas_config.graph_config
165+
nas_config = request.experiment.spec.nas_config
166+
graph_config = nas_config.graph_config
165167

166168
# Validate GraphConfig
167169
# Check InputSize
168170
if not graph_config.input_sizes:
169-
return self.SetValidateContextError(context, "Missing InputSizes in GraphConfig:\n{}".format(graph_config))
171+
return self.set_validate_context_error(context,
172+
"Missing InputSizes in GraphConfig:\n{}".format(graph_config))
170173

171174
# Check OutputSize
172175
if not graph_config.output_sizes:
173-
return self.SetValidateContextError(context, "Missing OutputSizes in GraphConfig:\n{}".format(graph_config))
176+
return self.set_validate_context_error(context,
177+
"Missing OutputSizes in GraphConfig:\n{}".format(graph_config))
174178

175179
# Check NumLayers
176180
if not graph_config.num_layers:
177-
return self.SetValidateContextError(context, "Missing NumLayers in GraphConfig:\n{}".format(graph_config))
178-
179-
# Validate each operation
180-
operations_list = list(
181-
request.experiment.spec.nas_config.operations.operation)
182-
for operation in operations_list:
183-
184-
# Check OperationType
185-
if not operation.operation_type:
186-
return self.SetValidateContextError(context, "Missing operationType in Operation:\n{}".format(
187-
operation))
188-
189-
# Check ParameterConfigs
190-
if not operation.parameter_specs.parameters:
191-
return self.SetValidateContextError(context, "Missing ParameterConfigs in Operation:\n{}".format(
192-
operation))
193-
194-
# Validate each ParameterConfig in Operation
195-
parameters_list = list(operation.parameter_specs.parameters)
196-
for parameter in parameters_list:
197-
198-
# Check Name
199-
if not parameter.name:
200-
return self.SetValidateContextError(context, "Missing Name in ParameterConfig:\n{}".format(
201-
parameter))
202-
203-
# Check ParameterType
204-
if not parameter.parameter_type:
205-
return self.SetValidateContextError(context, "Missing ParameterType in ParameterConfig:\n{}".format(
206-
parameter))
207-
208-
# Check List in Categorical or Discrete Type
209-
if parameter.parameter_type == api_pb2.CATEGORICAL or parameter.parameter_type == api_pb2.DISCRETE:
210-
if not parameter.feasible_space.list:
211-
return self.SetValidateContextError(
212-
context, "Missing List in ParameterConfig.feasibleSpace:\n{}".format(parameter))
213-
214-
# Check Max, Min, Step in Int or Double Type
215-
elif parameter.parameter_type == api_pb2.INT or parameter.parameter_type == api_pb2.DOUBLE:
216-
if not parameter.feasible_space.min and not parameter.feasible_space.max:
217-
return self.SetValidateContextError(
218-
context, "Missing Max and Min in ParameterConfig.feasibleSpace:\n{}".format(parameter))
219-
220-
if (parameter.parameter_type == api_pb2.DOUBLE and
221-
(not parameter.feasible_space.step or float(parameter.feasible_space.step) <= 0)):
222-
return self.SetValidateContextError(
223-
context, "Step parameter should be > 0 in ParameterConfig.feasibleSpace:\n{}".format(
224-
parameter))
181+
return self.set_validate_context_error(context,
182+
"Missing NumLayers in GraphConfig:\n{}".format(graph_config))
183+
184+
# Validate Operations
185+
is_valid, message = validate_operations(nas_config.operations.operation)
186+
if not is_valid:
187+
return self.set_validate_context_error(context, message)
225188

226189
# Validate Algorithm Settings
227190
settings_raw = request.experiment.spec.algorithm.algorithm_settings
@@ -233,14 +196,15 @@ def ValidateAlgorithmSettings(self, request, context):
233196
setting_range = algorithmSettingsValidator[setting.name][1]
234197
try:
235198
converted_value = setting_type(setting.value)
236-
except Exception:
237-
return self.SetValidateContextError(context, "Algorithm Setting {} must be {} type".format(
238-
setting.name, setting_type.__name__))
199+
except Exception as e:
200+
return self.set_validate_context_error(context,
201+
"Algorithm Setting {} must be {} type: exception {}".format(
202+
setting.name, setting_type.__name__, e))
239203

240204
if setting_type == float:
241205
if (converted_value <= setting_range[0] or
242206
(setting_range[1] != 'inf' and converted_value > setting_range[1])):
243-
return self.SetValidateContextError(
207+
return self.set_validate_context_error(
244208
context, "Algorithm Setting {}: {} with {} type must be in range ({}, {}]".format(
245209
setting.name,
246210
converted_value,
@@ -250,7 +214,7 @@ def ValidateAlgorithmSettings(self, request, context):
250214
)
251215

252216
elif converted_value < setting_range[0]:
253-
return self.SetValidateContextError(
217+
return self.set_validate_context_error(
254218
context, "Algorithm Setting {}: {} with {} type must be in range [{}, {})".format(
255219
setting.name,
256220
converted_value,
@@ -259,12 +223,13 @@ def ValidateAlgorithmSettings(self, request, context):
259223
setting_range[1])
260224
)
261225
else:
262-
return self.SetValidateContextError(context, "Unknown Algorithm Setting name: {}".format(setting.name))
226+
return self.set_validate_context_error(context,
227+
"Unknown Algorithm Setting name: {}".format(setting.name))
263228

264229
self.logger.info("All Experiment Settings are Valid")
265230
return api_pb2.ValidateAlgorithmSettingsReply()
266231

267-
def SetValidateContextError(self, context, error_message):
232+
def set_validate_context_error(self, context, error_message):
268233
context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
269234
context.set_details(error_message)
270235
self.logger.info(error_message)

0 commit comments

Comments
 (0)