Skip to content

Commit a1d4cce

Browse files
andsteingcopybara-github
authored andcommitted
Correctly forward config.optim_dtype.
PiperOrigin-RevId: 555396325
1 parent ac6e056 commit a1d4cce

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

vit_jax/train.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ def init_model():
140140
optax.sgd(
141141
learning_rate=lr_fn,
142142
momentum=0.9,
143-
accumulator_dtype='bfloat16',
143+
accumulator_dtype=config.optim_dtype,
144144
),
145145
)
146146

@@ -224,7 +224,6 @@ def init_model():
224224
img_sec_core_test = (
225225
config.batch_eval * ds_test.cardinality().numpy() /
226226
(time.time() - lt0) / jax.device_count())
227-
lt0 = time.time()
228227

229228
lr = float(lr_fn(step))
230229
logging.info(f'Step: {step} ' # pylint: disable=logging-fstring-interpolation
@@ -237,6 +236,7 @@ def init_model():
237236
accuracy_test=accuracy_test,
238237
lr=lr,
239238
img_sec_core_test=img_sec_core_test))
239+
lt0, lstep = time.time(), step
240240

241241
# Store checkpoint.
242242
if ((config.checkpoint_every and step % config.eval_every == 0) or
@@ -246,5 +246,6 @@ def init_model():
246246
flax.jax_utils.unreplicate(opt_state_repl), step), step)
247247
logging.info('Stored checkpoint at step %d to "%s"', step,
248248
checkpoint_path)
249+
lt0, lstep = time.time(), step
249250

250251
return flax.jax_utils.unreplicate(params_repl)

0 commit comments

Comments
 (0)