Skip to content
63 changes: 63 additions & 0 deletions pkg/suggestion/v1beta1/nas/common/validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Copyright 2022 The Kubeflow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from pkg.apis.manager.v1beta1.python import api_pb2


def validate_operations(operations: list[api_pb2.Operation]) -> (bool, str):

# Validate each operation
for operation in operations:

# Check OperationType
if not operation.operation_type:
return False, "Missing operationType in Operation:\n{}".format(operation)

# Check ParameterConfigs
if not operation.parameter_specs.parameters:
return False, "Missing ParameterConfigs in Operation:\n{}".format(operation)

# Validate each ParameterConfig in Operation
parameters_list = list(operation.parameter_specs.parameters)
for parameter in parameters_list:

# Check Name
if not parameter.name:
return False, "Missing Name in ParameterConfig:\n{}".format(parameter)

# Check ParameterType
if not parameter.parameter_type:
return False, "Missing ParameterType in ParameterConfig:\n{}".format(parameter)

# Check List in Categorical or Discrete Type
if parameter.parameter_type == api_pb2.CATEGORICAL or parameter.parameter_type == api_pb2.DISCRETE:
if not parameter.feasible_space.list:
return False, "Missing List in ParameterConfig.feasibleSpace:\n{}".format(parameter)

# Check Max, Min, Step in Int or Double Type
elif parameter.parameter_type == api_pb2.INT or parameter.parameter_type == api_pb2.DOUBLE:
if not parameter.feasible_space.min and not parameter.feasible_space.max:
return False, "Missing Max and Min in ParameterConfig.feasibleSpace:\n{}".format(parameter)

try:
if (parameter.parameter_type == api_pb2.DOUBLE and
(not parameter.feasible_space.step or float(parameter.feasible_space.step) <= 0)):
return False, \
"Step parameter should be > 0 in ParameterConfig.feasibleSpace:\n{}".format(parameter)
except Exception as e:
return False, \
"failed to validate ParameterConfig.feasibleSpace \n{parameter}):\n{exception}".format(
parameter=parameter, exception=e)

return True, ""
65 changes: 64 additions & 1 deletion pkg/suggestion/v1beta1/nas/darts/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@
import logging
from logging import getLogger, StreamHandler, INFO
import json
import grpc

from pkg.suggestion.v1beta1.internal.base_health_service import HealthServicer
from pkg.apis.manager.v1beta1.python import api_pb2
from pkg.apis.manager.v1beta1.python import api_pb2_grpc
from pkg.suggestion.v1beta1.nas.common.validation import validate_operations


class DartsService(api_pb2_grpc.SuggestionServicer, HealthServicer):
Expand All @@ -36,8 +38,12 @@ def __init__(self):
self.logger.addHandler(handler)
self.logger.propagate = False

# TODO: Add validation
def ValidateAlgorithmSettings(self, request, context):
is_valid, message = validate_algorithm_settings(request.experiment.spec)
if not is_valid:
context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
context.set_details(message)
self.logger.error(message)
return api_pb2.ValidateAlgorithmSettingsReply()

def GetSuggestions(self, request, context):
Expand Down Expand Up @@ -134,3 +140,60 @@ def get_algorithm_settings(settings_raw):
algorithm_settings_default[s_name] = s_value

return algorithm_settings_default


def validate_algorithm_settings(spec: api_pb2.ExperimentSpec) -> (bool, str):
# Validate Operations
is_valid, message = validate_operations(spec.nas_config.operations.operation)
if not is_valid:
return False, message

# Validate AlgorithmSettings
is_valid, message = validate_algorithm_spec(spec.algorithm.algorithm_settings)
if not is_valid:
return False, message

return True, ""


def validate_algorithm_spec(algorithm_settings: list[api_pb2.AlgorithmSetting]) -> (bool, str):
for s in algorithm_settings:
try:
if s.name == "num_epochs":
if not int(s.value) > 0:
return False, "{} should be greate than zero".format(s.name)

# Validate learning rate
if s.name in ["w_lr", "w_lr_min", "alpha_lr"]:
if not float(s.value) >= 0.0:
return False, "{} should be greate or equal than zero".format(s.name)

# Validate weight decay
if s.name in ["w_weight_decay", "alpha_weight_decay"]:
if not float(s.value) >= 0.0:
return False, "{} should be greate or equal than zero".format(s.name)

# Validate w_momentum and w_grad_clip
if s.name in ["w_momentum", "w_grad_clip"]:
if not float(s.value) >= 0.0:
return False, "{} should be greate or equal than zero".format(s.name)

if s.name == "batch_size":
if s.value is not "None":
if not int(s.value) >= 1:
return False, "batch_size should be greate or equal than one"

if s.name == "num_workers":
if not int(s.value) >= 0:
return False, "num_workers should be greate or equal than zero"

# Validate "init_channels", "print_step", "num_nodes" and "stem_multiplier"
if s.name in ["init_channels", "print_step", "num_nodes", "stem_multiplier"]:
if not int(s.value) >= 1:
return False, "{} should be greate or equal than one".format(s.name)

except Exception as e:
return False, "failed to validate {name}({value}): {exception}".format(name=s.name, value=s.value,
exception=e)

return True, ""
81 changes: 23 additions & 58 deletions pkg/suggestion/v1beta1/nas/enas/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from pkg.suggestion.v1beta1.nas.enas.AlgorithmSettings import (
parseAlgorithmSettings, algorithmSettingsValidator, enableNoneSettingsList)
from pkg.suggestion.v1beta1.internal.base_health_service import HealthServicer
from pkg.suggestion.v1beta1.nas.common.validation import validate_operations


class EnasExperiment:
Expand Down Expand Up @@ -161,67 +162,29 @@ def __init__(self, logger=None):

def ValidateAlgorithmSettings(self, request, context):
self.logger.info("Validate Algorithm Settings start")
graph_config = request.experiment.spec.nas_config.graph_config
nas_config = request.experiment.spec.nas_config
graph_config = nas_config.graph_config

# Validate GraphConfig
# Check InputSize
if not graph_config.input_sizes:
return self.SetValidateContextError(context, "Missing InputSizes in GraphConfig:\n{}".format(graph_config))
return self.set_validate_context_error(context,
"Missing InputSizes in GraphConfig:\n{}".format(graph_config))

# Check OutputSize
if not graph_config.output_sizes:
return self.SetValidateContextError(context, "Missing OutputSizes in GraphConfig:\n{}".format(graph_config))
return self.set_validate_context_error(context,
"Missing OutputSizes in GraphConfig:\n{}".format(graph_config))

# Check NumLayers
if not graph_config.num_layers:
return self.SetValidateContextError(context, "Missing NumLayers in GraphConfig:\n{}".format(graph_config))

# Validate each operation
operations_list = list(
request.experiment.spec.nas_config.operations.operation)
for operation in operations_list:

# Check OperationType
if not operation.operation_type:
return self.SetValidateContextError(context, "Missing operationType in Operation:\n{}".format(
operation))

# Check ParameterConfigs
if not operation.parameter_specs.parameters:
return self.SetValidateContextError(context, "Missing ParameterConfigs in Operation:\n{}".format(
operation))

# Validate each ParameterConfig in Operation
parameters_list = list(operation.parameter_specs.parameters)
for parameter in parameters_list:

# Check Name
if not parameter.name:
return self.SetValidateContextError(context, "Missing Name in ParameterConfig:\n{}".format(
parameter))

# Check ParameterType
if not parameter.parameter_type:
return self.SetValidateContextError(context, "Missing ParameterType in ParameterConfig:\n{}".format(
parameter))

# Check List in Categorical or Discrete Type
if parameter.parameter_type == api_pb2.CATEGORICAL or parameter.parameter_type == api_pb2.DISCRETE:
if not parameter.feasible_space.list:
return self.SetValidateContextError(
context, "Missing List in ParameterConfig.feasibleSpace:\n{}".format(parameter))

# Check Max, Min, Step in Int or Double Type
elif parameter.parameter_type == api_pb2.INT or parameter.parameter_type == api_pb2.DOUBLE:
if not parameter.feasible_space.min and not parameter.feasible_space.max:
return self.SetValidateContextError(
context, "Missing Max and Min in ParameterConfig.feasibleSpace:\n{}".format(parameter))

if (parameter.parameter_type == api_pb2.DOUBLE and
(not parameter.feasible_space.step or float(parameter.feasible_space.step) <= 0)):
return self.SetValidateContextError(
context, "Step parameter should be > 0 in ParameterConfig.feasibleSpace:\n{}".format(
parameter))
return self.set_validate_context_error(context,
"Missing NumLayers in GraphConfig:\n{}".format(graph_config))

# Validate Operations
is_valid, message = validate_operations(nas_config.operations.operation)
if not is_valid:
return self.set_validate_context_error(context, message)

# Validate Algorithm Settings
settings_raw = request.experiment.spec.algorithm.algorithm_settings
Expand All @@ -233,14 +196,15 @@ def ValidateAlgorithmSettings(self, request, context):
setting_range = algorithmSettingsValidator[setting.name][1]
try:
converted_value = setting_type(setting.value)
except Exception:
return self.SetValidateContextError(context, "Algorithm Setting {} must be {} type".format(
setting.name, setting_type.__name__))
except Exception as e:
return self.set_validate_context_error(context,
"Algorithm Setting {} must be {} type: exception {}".format(
setting.name, setting_type.__name__, e))

if setting_type == float:
if (converted_value <= setting_range[0] or
(setting_range[1] != 'inf' and converted_value > setting_range[1])):
return self.SetValidateContextError(
return self.set_validate_context_error(
context, "Algorithm Setting {}: {} with {} type must be in range ({}, {}]".format(
setting.name,
converted_value,
Expand All @@ -250,7 +214,7 @@ def ValidateAlgorithmSettings(self, request, context):
)

elif converted_value < setting_range[0]:
return self.SetValidateContextError(
return self.set_validate_context_error(
context, "Algorithm Setting {}: {} with {} type must be in range [{}, {})".format(
setting.name,
converted_value,
Expand All @@ -259,12 +223,13 @@ def ValidateAlgorithmSettings(self, request, context):
setting_range[1])
)
else:
return self.SetValidateContextError(context, "Unknown Algorithm Setting name: {}".format(setting.name))
return self.set_validate_context_error(context,
"Unknown Algorithm Setting name: {}".format(setting.name))

self.logger.info("All Experiment Settings are Valid")
return api_pb2.ValidateAlgorithmSettingsReply()

def SetValidateContextError(self, context, error_message):
def set_validate_context_error(self, context, error_message):
context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
context.set_details(error_message)
self.logger.info(error_message)
Expand Down
64 changes: 58 additions & 6 deletions test/unit/v1beta1/suggestion/test_darts_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,18 @@

from pkg.apis.manager.v1beta1.python import api_pb2

from pkg.suggestion.v1beta1.nas.darts.service import DartsService
from pkg.suggestion.v1beta1.nas.darts.service import DartsService, validate_algorithm_spec


class TestDarts(unittest.TestCase):
def setUp(self):
servicers = {
services = {
api_pb2.DESCRIPTOR.services_by_name['Suggestion']: DartsService(
)
}

self.test_server = grpc_testing.server_from_dictionary(
servicers, grpc_testing.strict_real_time())
services, grpc_testing.strict_real_time())

def test_get_suggestion(self):
experiment = api_pb2.Experiment(
Expand Down Expand Up @@ -102,17 +102,69 @@ def test_get_suggestion(self):

exp_search_space = ["separable_convolution_3x3", "separable_convolution_5x5"]
for pa in response.parameter_assignments[0].assignments:
if (pa.name == "algorithm-settings"):
if pa.name == "algorithm-settings":
algorithm_settings = pa.value.replace("\'", "\"")
algorithm_settings = json.loads(algorithm_settings)
self.assertDictContainsSubset(exp_algorithm_settings, algorithm_settings)
elif (pa.name == "num-layers"):
elif pa.name == "num-layers":
self.assertEqual(exp_num_layers, int(pa.value))
elif (pa.name == "search-space"):
elif pa.name == "search-space":
search_space = pa.value.replace("\'", "\"")
search_space = json.loads(search_space)
self.assertEqual(exp_search_space, search_space)

def test_validate_algorithm_spec(self):

# Valid Case
valid = [
api_pb2.AlgorithmSetting(name="num_epoch", value="10"),
api_pb2.AlgorithmSetting(name="w_lr", value="0.01"),
api_pb2.AlgorithmSetting(name="w_lr_min", value="0.01"),
api_pb2.AlgorithmSetting(name="alpha_lr", value="0.01"),
api_pb2.AlgorithmSetting(name="w_weight_decay", value="0.25"),
api_pb2.AlgorithmSetting(name="alpha_weight_decay", value="0.25"),
api_pb2.AlgorithmSetting(name="w_momentum", value="0.9"),
api_pb2.AlgorithmSetting(name="w_grad_clip", value="5.0"),
api_pb2.AlgorithmSetting(name="batch_size", value="100"),
api_pb2.AlgorithmSetting(name="num_workers", value="0"),
api_pb2.AlgorithmSetting(name="init_channels", value="1"),
api_pb2.AlgorithmSetting(name="print_step", value="100"),
api_pb2.AlgorithmSetting(name="num_nodes", value="4"),
api_pb2.AlgorithmSetting(name="stem_multiplier", value="3"),
]
is_valid, _ = validate_algorithm_spec(valid)
self.assertEqual(is_valid, True)

# Invalid num_epochs
invalid = [api_pb2.AlgorithmSetting(name="num_epochs", value="0")]
is_valid, _ = validate_algorithm_spec(invalid)
self.assertEqual(is_valid, False)

# Invalid w_lr
invalid = [api_pb2.AlgorithmSetting(name="w_lr", value="-0.1")]
is_valid, _ = validate_algorithm_spec(invalid)
self.assertEqual(is_valid, False)

# Invalid alpha_weight_decay
invalid = [api_pb2.AlgorithmSetting(name="alpha_weight_decay", value="-0.02")]
is_valid, _ = validate_algorithm_spec(invalid)
self.assertEqual(is_valid, False)

# Invalid w_momentum
invalid = [api_pb2.AlgorithmSetting(name="w_momentum", value="-0.8")]
is_valid, _ = validate_algorithm_spec(invalid)
self.assertEqual(is_valid, False)

# Invalid batch_size
invalid = [api_pb2.AlgorithmSetting(name="batch_size", value="0")]
is_valid, _ = validate_algorithm_spec(invalid)
self.assertEqual(is_valid, False)

# Invalid print_step
invalid = [api_pb2.AlgorithmSetting(name="print_step", value="0")]
is_valid, _ = validate_algorithm_spec(invalid)
self.assertEqual(is_valid, False)


if __name__ == '__main__':
unittest.main()
4 changes: 2 additions & 2 deletions test/unit/v1beta1/suggestion/test_enas_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,13 @@

class TestEnas(unittest.TestCase):
def setUp(self):
servicers = {
services = {
api_pb2.DESCRIPTOR.services_by_name['Suggestion']: EnasService(
)
}

self.test_server = grpc_testing.server_from_dictionary(
servicers, grpc_testing.strict_real_time())
services, grpc_testing.strict_real_time())

def test_get_suggestion(self):
trials = [
Expand Down
Loading