Skip to content

Commit 5d083e0

Browse files
committed
[review] fix vaiolation comments
1 parent 9c6efdf commit 5d083e0

File tree

2 files changed

+13
-13
lines changed

2 files changed

+13
-13
lines changed

pkg/suggestion/v1beta1/nas/darts/service.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -171,26 +171,26 @@ def validate_algorithm_settings(algorithm_settings: list[api_pb2.AlgorithmSettin
171171
# Validate weight decay
172172
if s.name in ["w_weight_decay", "alpha_weight_decay"]:
173173
if not float(s.value) >= 0.0:
174-
return False, "{} should be greate or equal than zero".format(s.name)
174+
return False, "{} should be greater than or equal to zero".format(s.name)
175175

176176
# Validate w_momentum and w_grad_clip
177177
if s.name in ["w_momentum", "w_grad_clip"]:
178178
if not float(s.value) >= 0.0:
179-
return False, "{} should be greate or equal than zero".format(s.name)
179+
return False, "{} should be greater than or equal to zero".format(s.name)
180180

181181
if s.name == "batch_size":
182182
if s.value is not "None":
183183
if not int(s.value) >= 1:
184-
return False, "batch_size should be greate or equal than one"
184+
return False, "batch_size should be greater than or equal to one"
185185

186186
if s.name == "num_workers":
187187
if not int(s.value) >= 0:
188-
return False, "num_workers should be greate or equal than zero"
188+
return False, "num_workers should be greater than or equal to zero"
189189

190190
# Validate "init_channels", "print_step", "num_nodes" and "stem_multiplier"
191191
if s.name in ["init_channels", "print_step", "num_nodes", "stem_multiplier"]:
192192
if not int(s.value) >= 1:
193-
return False, "{} should be greate or equal than one".format(s.name)
193+
return False, "{} should be greater than or equal to one".format(s.name)
194194

195195
except Exception as e:
196196
return False, "failed to validate {name}({value}): {exception}".format(name=s.name, value=s.value,

test/unit/v1beta1/suggestion/test_darts_service.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
from pkg.apis.manager.v1beta1.python import api_pb2
2121

22-
from pkg.suggestion.v1beta1.nas.darts.service import DartsService, validate_algorithm_spec
22+
from pkg.suggestion.v1beta1.nas.darts.service import DartsService, validate_algorithm_settings
2323

2424

2525
class TestDarts(unittest.TestCase):
@@ -132,37 +132,37 @@ def test_validate_algorithm_spec(self):
132132
api_pb2.AlgorithmSetting(name="num_nodes", value="4"),
133133
api_pb2.AlgorithmSetting(name="stem_multiplier", value="3"),
134134
]
135-
is_valid, _ = validate_algorithm_spec(valid)
135+
is_valid, _ = validate_algorithm_settings(valid)
136136
self.assertEqual(is_valid, True)
137137

138138
# Invalid num_epochs
139139
invalid = [api_pb2.AlgorithmSetting(name="num_epochs", value="0")]
140-
is_valid, _ = validate_algorithm_spec(invalid)
140+
is_valid, _ = validate_algorithm_settings(invalid)
141141
self.assertEqual(is_valid, False)
142142

143143
# Invalid w_lr
144144
invalid = [api_pb2.AlgorithmSetting(name="w_lr", value="-0.1")]
145-
is_valid, _ = validate_algorithm_spec(invalid)
145+
is_valid, _ = validate_algorithm_settings(invalid)
146146
self.assertEqual(is_valid, False)
147147

148148
# Invalid alpha_weight_decay
149149
invalid = [api_pb2.AlgorithmSetting(name="alpha_weight_decay", value="-0.02")]
150-
is_valid, _ = validate_algorithm_spec(invalid)
150+
is_valid, _ = validate_algorithm_settings(invalid)
151151
self.assertEqual(is_valid, False)
152152

153153
# Invalid w_momentum
154154
invalid = [api_pb2.AlgorithmSetting(name="w_momentum", value="-0.8")]
155-
is_valid, _ = validate_algorithm_spec(invalid)
155+
is_valid, _ = validate_algorithm_settings(invalid)
156156
self.assertEqual(is_valid, False)
157157

158158
# Invalid batch_size
159159
invalid = [api_pb2.AlgorithmSetting(name="batch_size", value="0")]
160-
is_valid, _ = validate_algorithm_spec(invalid)
160+
is_valid, _ = validate_algorithm_settings(invalid)
161161
self.assertEqual(is_valid, False)
162162

163163
# Invalid print_step
164164
invalid = [api_pb2.AlgorithmSetting(name="print_step", value="0")]
165-
is_valid, _ = validate_algorithm_spec(invalid)
165+
is_valid, _ = validate_algorithm_settings(invalid)
166166
self.assertEqual(is_valid, False)
167167

168168

0 commit comments

Comments
 (0)