Skip to content

Commit 5304afe

Browse files
authored
Updated ConvBlock dropout hyperparameter to have a range (#1708)
* Updated ConvBlock dropout hyperparameter to have a range * Adjusted dropout docstrings to fix CI failure * Adjust linting errors
1 parent f9a8451 commit 5304afe

File tree

1 file changed

+16
-13
lines changed

1 file changed

+16
-13
lines changed

autokeras/blocks/basic.py

Lines changed: 16 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -262,9 +262,10 @@ class ConvBlock(block_module.Block):
262262
unspecified, it will be tuned automatically.
263263
separable: Boolean. Whether to use separable conv layers.
264264
If left unspecified, it will be tuned automatically.
265-
dropout: Float. Between 0 and 1. The dropout rate for after the
266-
convolutional layers. If left unspecified, it will be tuned
267-
automatically.
265+
dropout: Float or kerastuner.engine.hyperparameters.
266+
Choice range Between 0 and 1.
267+
The dropout rate after convolutional layers.
268+
If left unspecified, it will be tuned automatically.
268269
"""
269270

270271
def __init__(
@@ -275,7 +276,7 @@ def __init__(
275276
filters: Optional[Union[int, hyperparameters.Choice]] = None,
276277
max_pooling: Optional[bool] = None,
277278
separable: Optional[bool] = None,
278-
dropout: Optional[float] = None,
279+
dropout: Optional[Union[float, hyperparameters.Choice]] = None,
279280
**kwargs,
280281
):
281282
super().__init__(**kwargs)
@@ -303,7 +304,11 @@ def __init__(
303304
)
304305
self.max_pooling = max_pooling
305306
self.separable = separable
306-
self.dropout = dropout
307+
self.dropout = utils.get_hyperparameter(
308+
dropout,
309+
hyperparameters.Choice("dropout", [0.0, 0.25, 0.5], default=0.0),
310+
float,
311+
)
307312

308313
def get_config(self):
309314
config = super().get_config()
@@ -315,7 +320,7 @@ def get_config(self):
315320
"filters": hyperparameters.serialize(self.filters),
316321
"max_pooling": self.max_pooling,
317322
"separable": self.separable,
318-
"dropout": self.dropout,
323+
"dropout": hyperparameters.serialize(self.dropout),
319324
}
320325
)
321326
return config
@@ -326,6 +331,7 @@ def from_config(cls, config):
326331
config["num_blocks"] = hyperparameters.deserialize(config["num_blocks"])
327332
config["num_layers"] = hyperparameters.deserialize(config["num_layers"])
328333
config["filters"] = hyperparameters.deserialize(config["filters"])
334+
config["dropout"] = hyperparameters.deserialize(config["dropout"])
329335
return cls(**config)
330336

331337
def build(self, hp, inputs=None):
@@ -350,11 +356,6 @@ def build(self, hp, inputs=None):
350356
max_pooling = hp.Boolean("max_pooling", default=True)
351357
pool = layer_utils.get_max_pooling(input_node.shape)
352358

353-
if self.dropout is not None:
354-
dropout = self.dropout
355-
else:
356-
dropout = hp.Choice("dropout", [0.0, 0.25, 0.5], default=0)
357-
358359
for i in range(utils.add_to_hp(self.num_blocks, hp)):
359360
for j in range(utils.add_to_hp(self.num_layers, hp)):
360361
output_node = conv(
@@ -370,8 +371,10 @@ def build(self, hp, inputs=None):
370371
kernel_size - 1,
371372
padding=self._get_padding(kernel_size - 1, output_node),
372373
)(output_node)
373-
if dropout > 0:
374-
output_node = layers.Dropout(dropout)(output_node)
374+
if utils.add_to_hp(self.dropout, hp) > 0:
375+
output_node = layers.Dropout(utils.add_to_hp(self.dropout, hp))(
376+
output_node
377+
)
375378
return output_node
376379

377380
@staticmethod

0 commit comments

Comments (0)