@@ -140,7 +140,7 @@ def init_model():
       optax.sgd(
           learning_rate=lr_fn,
           momentum=0.9,
-          accumulator_dtype='bfloat16',
+          accumulator_dtype=config.optim_dtype,
       ),
   )

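This hunk replaces the hard-coded 'bfloat16' momentum-accumulator dtype with a value read from the config. Below is a minimal sketch of building and using such an optimizer; the constant schedule and the `Config` class are hypothetical stand-ins (only `accumulator_dtype=config.optim_dtype` comes from the diff), not the repo's actual config plumbing.

```python
import jax.numpy as jnp
import optax

# Hypothetical stand-ins for values used in the diff: a constant learning-rate
# schedule and a config object exposing an `optim_dtype` field.
lr_fn = optax.constant_schedule(0.01)

class Config:
  optim_dtype = 'bfloat16'  # or 'float32' for a full-precision momentum buffer

config = Config()

# SGD with momentum; `accumulator_dtype` sets the dtype of the momentum
# accumulator, trading optimizer-state memory against precision.
tx = optax.sgd(
    learning_rate=lr_fn,
    momentum=0.9,
    accumulator_dtype=config.optim_dtype,
)

# One update step, to show the optimizer is usable as constructed.
params = {'w': jnp.zeros((3, 3))}
opt_state = tx.init(params)
grads = {'w': jnp.ones((3, 3))}
updates, opt_state = tx.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)
```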
@@ -212,7 +212,7 @@ def init_model():
         (step == total_steps)):

       accuracies = []
-      lt0 = time.time()
+      tt0 = time.time()
       for test_batch in input_pipeline.prefetch(ds_test, config.prefetch):
         logits = infer_fn_repl(
             dict(params=params_repl), test_batch['image'])
@@ -223,8 +223,7 @@ def init_model():
       accuracy_test = np.mean(accuracies)
       img_sec_core_test = (
           config.batch_eval * ds_test.cardinality().numpy() /
-          (time.time() - lt0) / jax.device_count())
-      lt0 = time.time()
+          (time.time() - tt0) / jax.device_count())

       lr = float(lr_fn(step))
       logging.info(f'Step: {step} '  # pylint: disable=logging-fstring-interpolation
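The eval throughput computed in this hunk is total evaluated images divided by wall-clock eval time, normalized per device. A quick worked sketch with illustrative numbers (batch size, batch count, and device count are made up, not taken from the source):

```python
# Illustrative numbers only.
batch_eval = 512         # images per evaluation batch
num_eval_batches = 98    # ds_test.cardinality().numpy() in the diff
eval_seconds = 40.0      # time.time() - tt0
num_devices = 8          # jax.device_count()

img_sec_core_test = batch_eval * num_eval_batches / eval_seconds / num_devices
print(img_sec_core_test)  # 512 * 98 / 40 / 8 = 156.8 images/sec/core
```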
@@ -237,14 +236,17 @@ def init_model():
               accuracy_test=accuracy_test,
               lr=lr,
               img_sec_core_test=img_sec_core_test))
+      lt0 += time.time() - tt0

     # Store checkpoint.
     if ((config.checkpoint_every and step % config.eval_every == 0) or
         step == total_steps):
+      tt0 = time.time()
       checkpoint_path = flax_checkpoints.save_checkpoint(
           workdir, (flax.jax_utils.unreplicate(params_repl),
                     flax.jax_utils.unreplicate(opt_state_repl), step), step)
       logging.info('Stored checkpoint at step %d to "%s"', step,
                    checkpoint_path)
+      lt0 += time.time() - tt0

   return flax.jax_utils.unreplicate(params_repl)
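The added `lt0 += time.time() - tt0` lines appear to push the training timer forward by however long evaluation and checkpointing took, so that time is presumably excluded from the training-throughput window measured against `lt0` elsewhere in the loop. A minimal sketch of that pattern in isolation; `train_step`, `evaluate`, and the step counts are placeholders, not the repo's code:

```python
import time

def train_step():   # placeholder for one training step
  time.sleep(0.005)

def evaluate():     # placeholder for an eval / checkpoint pass
  time.sleep(0.3)

report_every, eval_every = 100, 150
lt0 = time.time()                      # training-throughput timer

for step in range(1, 601):
  train_step()

  if step % report_every == 0:
    # Training throughput over the last window; time spent in evaluate() has
    # already been added to lt0, so it does not count against training speed.
    print(f'step {step}: {report_every / (time.time() - lt0):.1f} steps/s')
    lt0 = time.time()

  if step % eval_every == 0:
    tt0 = time.time()                  # time the non-training work separately
    evaluate()
    lt0 += time.time() - tt0           # push the timer forward past the eval
```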