Skip to content

Commit cdab893

Browse files
committed
fix enas example codes
1 parent bc865fe commit cdab893

File tree

1 file changed

+17
-9
lines changed
  • examples/v1beta1/trial-images/enas-cnn-cifar10

1 file changed

+17
-9
lines changed

examples/v1beta1/trial-images/enas-cnn-cifar10/RunTrial.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@
1616
from keras.datasets import cifar10
1717
from ModelConstructor import ModelConstructor
1818
from tensorflow.keras.utils import to_categorical
19-
from tensorflow.python.keras.utils.multi_gpu_utils import multi_gpu_model
2019
from keras.preprocessing.image import ImageDataGenerator
20+
import tensorflow as tf
2121
import argparse
2222

2323
if __name__ == "__main__":
@@ -50,16 +50,24 @@
5050

5151
print("\n>>> Constructing Model...")
5252
constructor = ModelConstructor(arch, nn_config)
53-
test_model = constructor.build_model()
54-
print(">>> Model Constructed Successfully\n")
5553

56-
if num_gpus > 1:
57-
test_model = multi_gpu_model(test_model, gpus=num_gpus)
54+
physical_gpus_num = len(tf.config.experimental.list_physical_devices('GPU'))
55+
devices = []
56+
if 1 <= num_gpus <= physical_gpus_num:
57+
devices = ["/gpu:"+str(i) for i in range(physical_gpus_num)]
58+
else:
59+
physical_cpu_num = len(tf.config.experimental.list_physical_devices('CPU'))
60+
devices = ["/cpu:"+str(j) for j in range(physical_cpu_num)]
61+
62+
strategy = tf.distribute.MirroredStrategy(devices=devices)
63+
with strategy.scope():
64+
test_model = constructor.build_model()
65+
test_model.summary()
66+
test_model.compile(loss=keras.losses.categorical_crossentropy,
67+
optimizer=keras.optimizers.Adam(learning_rate=1e-3, decay=1e-4),
68+
metrics=['accuracy'])
5869

59-
test_model.summary()
60-
test_model.compile(loss=keras.losses.categorical_crossentropy,
61-
optimizer=keras.optimizers.Adam(learning_rate=1e-3, decay=1e-4),
62-
metrics=['accuracy'])
70+
print(">>> Model Constructed Successfully\n")
6371

6472
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
6573
x_train = x_train.astype('float32')

0 commit comments

Comments
 (0)