@@ -167,17 +167,17 @@ def validate_algorithm_settings(algorithm_settings: list[api_pb2.AlgorithmSettin
167
167
return False , "{} should be greater than zero" .format (s .name )
168
168
169
169
# Validate learning rate
170
- if s .name in [ "w_lr" , "w_lr_min" , "alpha_lr" ] :
170
+ if s .name in { "w_lr" , "w_lr_min" , "alpha_lr" } :
171
171
if not float (s .value ) >= 0.0 :
172
172
return False , "{} should be greater than or equal to zero" .format (s .name )
173
173
174
174
# Validate weight decay
175
- if s .name in [ "w_weight_decay" , "alpha_weight_decay" ] :
175
+ if s .name in { "w_weight_decay" , "alpha_weight_decay" } :
176
176
if not float (s .value ) >= 0.0 :
177
177
return False , "{} should be greater than or equal to zero" .format (s .name )
178
178
179
179
# Validate w_momentum and w_grad_clip
180
- if s .name in [ "w_momentum" , "w_grad_clip" ] :
180
+ if s .name in { "w_momentum" , "w_grad_clip" } :
181
181
if not float (s .value ) >= 0.0 :
182
182
return False , "{} should be greater than or equal to zero" .format (s .name )
183
183
@@ -190,7 +190,7 @@ def validate_algorithm_settings(algorithm_settings: list[api_pb2.AlgorithmSettin
190
190
return False , "num_workers should be greater than or equal to zero"
191
191
192
192
# Validate "init_channels", "print_step", "num_nodes" and "stem_multiplier"
193
- if s .name in [ "init_channels" , "print_step" , "num_nodes" , "stem_multiplier" ] :
193
+ if s .name in { "init_channels" , "print_step" , "num_nodes" , "stem_multiplier" } :
194
194
if not int (s .value ) >= 1 :
195
195
return False , "{} should be greater than or equal to one" .format (s .name )
196
196
0 commit comments