Skip to content

Commit 747cc7f

Browse files
committed
fix enas example codes
1 parent bc865fe commit 747cc7f

File tree

1 file changed

+16
-9
lines changed
  • examples/v1beta1/trial-images/enas-cnn-cifar10

1 file changed

+16
-9
lines changed

examples/v1beta1/trial-images/enas-cnn-cifar10/RunTrial.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@
1616
from keras.datasets import cifar10
1717
from ModelConstructor import ModelConstructor
1818
from tensorflow.keras.utils import to_categorical
19-
from tensorflow.python.keras.utils.multi_gpu_utils import multi_gpu_model
2019
from keras.preprocessing.image import ImageDataGenerator
20+
import tensorflow as tf
2121
import argparse
2222

2323
if __name__ == "__main__":
@@ -50,16 +50,23 @@
5050

5151
print("\n>>> Constructing Model...")
5252
constructor = ModelConstructor(arch, nn_config)
53-
test_model = constructor.build_model()
54-
print(">>> Model Constructed Successfully\n")
5553

56-
if num_gpus > 1:
57-
test_model = multi_gpu_model(test_model, gpus=num_gpus)
54+
num_physical_gpus = len(tf.config.experimental.list_physical_devices('GPU'))
55+
if 1 <= num_gpus <= num_physical_gpus:
56+
devices = ["/gpu:"+str(i) for i in range(num_physical_gpus)]
57+
else:
58+
num_physical_cpu = len(tf.config.experimental.list_physical_devices('CPU'))
59+
devices = ["/cpu:"+str(j) for j in range(num_physical_cpu)]
60+
61+
strategy = tf.distribute.MirroredStrategy(devices)
62+
with strategy.scope():
63+
test_model = constructor.build_model()
64+
test_model.summary()
65+
test_model.compile(loss=keras.losses.categorical_crossentropy,
66+
optimizer=keras.optimizers.Adam(learning_rate=1e-3, decay=1e-4),
67+
metrics=['accuracy'])
5868

59-
test_model.summary()
60-
test_model.compile(loss=keras.losses.categorical_crossentropy,
61-
optimizer=keras.optimizers.Adam(learning_rate=1e-3, decay=1e-4),
62-
metrics=['accuracy'])
69+
print(">>> Model Constructed Successfully\n")
6370

6471
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
6572
x_train = x_train.astype('float32')

0 commit comments

Comments
 (0)