|
19 | 19 |
|
20 | 20 | from pkg.apis.manager.v1beta1.python import api_pb2
|
21 | 21 |
|
22 |
| -from pkg.suggestion.v1beta1.nas.darts.service import DartsService, validate_algorithm_spec |
| 22 | +from pkg.suggestion.v1beta1.nas.darts.service import DartsService, validate_algorithm_settings |
23 | 23 |
|
24 | 24 |
|
25 | 25 | class TestDarts(unittest.TestCase):
|
@@ -132,37 +132,37 @@ def test_validate_algorithm_spec(self):
|
132 | 132 | api_pb2.AlgorithmSetting(name="num_nodes", value="4"),
|
133 | 133 | api_pb2.AlgorithmSetting(name="stem_multiplier", value="3"),
|
134 | 134 | ]
|
135 |
| - is_valid, _ = validate_algorithm_spec(valid) |
| 135 | + is_valid, _ = validate_algorithm_settings(valid) |
136 | 136 | self.assertEqual(is_valid, True)
|
137 | 137 |
|
138 | 138 | # Invalid num_epochs
|
139 | 139 | invalid = [api_pb2.AlgorithmSetting(name="num_epochs", value="0")]
|
140 |
| - is_valid, _ = validate_algorithm_spec(invalid) |
| 140 | + is_valid, _ = validate_algorithm_settings(invalid) |
141 | 141 | self.assertEqual(is_valid, False)
|
142 | 142 |
|
143 | 143 | # Invalid w_lr
|
144 | 144 | invalid = [api_pb2.AlgorithmSetting(name="w_lr", value="-0.1")]
|
145 |
| - is_valid, _ = validate_algorithm_spec(invalid) |
| 145 | + is_valid, _ = validate_algorithm_settings(invalid) |
146 | 146 | self.assertEqual(is_valid, False)
|
147 | 147 |
|
148 | 148 | # Invalid alpha_weight_decay
|
149 | 149 | invalid = [api_pb2.AlgorithmSetting(name="alpha_weight_decay", value="-0.02")]
|
150 |
| - is_valid, _ = validate_algorithm_spec(invalid) |
| 150 | + is_valid, _ = validate_algorithm_settings(invalid) |
151 | 151 | self.assertEqual(is_valid, False)
|
152 | 152 |
|
153 | 153 | # Invalid w_momentum
|
154 | 154 | invalid = [api_pb2.AlgorithmSetting(name="w_momentum", value="-0.8")]
|
155 |
| - is_valid, _ = validate_algorithm_spec(invalid) |
| 155 | + is_valid, _ = validate_algorithm_settings(invalid) |
156 | 156 | self.assertEqual(is_valid, False)
|
157 | 157 |
|
158 | 158 | # Invalid batch_size
|
159 | 159 | invalid = [api_pb2.AlgorithmSetting(name="batch_size", value="0")]
|
160 |
| - is_valid, _ = validate_algorithm_spec(invalid) |
| 160 | + is_valid, _ = validate_algorithm_settings(invalid) |
161 | 161 | self.assertEqual(is_valid, False)
|
162 | 162 |
|
163 | 163 | # Invalid print_step
|
164 | 164 | invalid = [api_pb2.AlgorithmSetting(name="print_step", value="0")]
|
165 |
| - is_valid, _ = validate_algorithm_spec(invalid) |
| 165 | + is_valid, _ = validate_algorithm_settings(invalid) |
166 | 166 | self.assertEqual(is_valid, False)
|
167 | 167 |
|
168 | 168 |
|
|
0 commit comments