Skip to content

Commit 702a11b

Browse files
andsteingcopybara-github
authored andcommitted
Correctly forward config.optim_dtype.
PiperOrigin-RevId: 555396325
1 parent ac6e056 commit 702a11b

File tree

2 files changed

+7
-5
lines changed

2 files changed

+7
-5
lines changed

.github/workflows/build.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ jobs:
1313
runs-on: ubuntu-latest
1414
strategy:
1515
matrix:
16-
python-version: [3.8]
16+
python-version: ['3.10']
1717
steps:
1818
- name: Cancel previous
1919
uses: styfle/[email protected]

vit_jax/train.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ def init_model():
140140
optax.sgd(
141141
learning_rate=lr_fn,
142142
momentum=0.9,
143-
accumulator_dtype='bfloat16',
143+
accumulator_dtype=config.optim_dtype,
144144
),
145145
)
146146

@@ -212,7 +212,7 @@ def init_model():
212212
(step == total_steps)):
213213

214214
accuracies = []
215-
lt0 = time.time()
215+
tt0 = time.time()
216216
for test_batch in input_pipeline.prefetch(ds_test, config.prefetch):
217217
logits = infer_fn_repl(
218218
dict(params=params_repl), test_batch['image'])
@@ -223,8 +223,7 @@ def init_model():
223223
accuracy_test = np.mean(accuracies)
224224
img_sec_core_test = (
225225
config.batch_eval * ds_test.cardinality().numpy() /
226-
(time.time() - lt0) / jax.device_count())
227-
lt0 = time.time()
226+
(time.time() - tt0) / jax.device_count())
228227

229228
lr = float(lr_fn(step))
230229
logging.info(f'Step: {step} ' # pylint: disable=logging-fstring-interpolation
@@ -237,14 +236,17 @@ def init_model():
237236
accuracy_test=accuracy_test,
238237
lr=lr,
239238
img_sec_core_test=img_sec_core_test))
239+
lt0 += time.time() - tt0
240240

241241
# Store checkpoint.
242242
if ((config.checkpoint_every and step % config.eval_every == 0) or
243243
step == total_steps):
244+
tt0 = time.time()
244245
checkpoint_path = flax_checkpoints.save_checkpoint(
245246
workdir, (flax.jax_utils.unreplicate(params_repl),
246247
flax.jax_utils.unreplicate(opt_state_repl), step), step)
247248
logging.info('Stored checkpoint at step %d to "%s"', step,
248249
checkpoint_path)
250+
lt0 += time.time() - tt0
249251

250252
return flax.jax_utils.unreplicate(params_repl)

0 commit comments

Comments
 (0)