|
39 | 39 | GlobalParams = collections.namedtuple('GlobalParams', [ |
40 | 40 | 'width_coefficient', 'depth_coefficient', 'image_size', 'dropout_rate', |
41 | 41 | 'num_classes', 'batch_norm_momentum', 'batch_norm_epsilon', |
42 | | - 'drop_connect_rate', 'depth_divisor', 'min_depth']) |
| 42 | + 'drop_connect_rate', 'depth_divisor', 'min_depth', 'include_top']) |
43 | 43 |
|
44 | 44 | # Parameters for an individual model block |
45 | 45 | BlockArgs = collections.namedtuple('BlockArgs', [ |
@@ -475,7 +475,7 @@ def efficientnet_params(model_name): |
475 | 475 |
|
476 | 476 |
|
477 | 477 | def efficientnet(width_coefficient=None, depth_coefficient=None, image_size=None, |
478 | | - dropout_rate=0.2, drop_connect_rate=0.2, num_classes=1000): |
| 478 | + dropout_rate=0.2, drop_connect_rate=0.2, num_classes=1000, include_top=True): |
479 | 479 | """Create BlockArgs and GlobalParams for efficientnet model. |
480 | 480 |
|
481 | 481 | Args: |
@@ -517,6 +517,7 @@ def efficientnet(width_coefficient=None, depth_coefficient=None, image_size=None |
517 | 517 | drop_connect_rate=drop_connect_rate, |
518 | 518 | depth_divisor=8, |
519 | 519 | min_depth=None, |
| 520 | + include_top=include_top, |
520 | 521 | ) |
521 | 522 |
|
522 | 523 | return blocks_args, global_params |
|
0 commit comments