File tree Expand file tree Collapse file tree 1 file changed +3
-2
lines changed Expand file tree Collapse file tree 1 file changed +3
-2
lines changed Original file line number Diff line number Diff line change @@ -140,7 +140,7 @@ def init_model():
140
140
optax .sgd (
141
141
learning_rate = lr_fn ,
142
142
momentum = 0.9 ,
143
- accumulator_dtype = 'bfloat16' ,
143
+ accumulator_dtype = config . optim_dtype ,
144
144
),
145
145
)
146
146
@@ -224,7 +224,6 @@ def init_model():
224
224
img_sec_core_test = (
225
225
config .batch_eval * ds_test .cardinality ().numpy () /
226
226
(time .time () - lt0 ) / jax .device_count ())
227
- lt0 = time .time ()
228
227
229
228
lr = float (lr_fn (step ))
230
229
logging .info (f'Step: { step } ' # pylint: disable=logging-fstring-interpolation
@@ -237,6 +236,7 @@ def init_model():
237
236
accuracy_test = accuracy_test ,
238
237
lr = lr ,
239
238
img_sec_core_test = img_sec_core_test ))
239
+ lt0 , lstep = time .time (), step
240
240
241
241
# Store checkpoint.
242
242
if ((config .checkpoint_every and step % config .eval_every == 0 ) or
@@ -246,5 +246,6 @@ def init_model():
246
246
flax .jax_utils .unreplicate (opt_state_repl ), step ), step )
247
247
logging .info ('Stored checkpoint at step %d to "%s"' , step ,
248
248
checkpoint_path )
249
+ lt0 , lstep = time .time (), step
249
250
250
251
return flax .jax_utils .unreplicate (params_repl )
You can’t perform that action at this time.
0 commit comments