I've written a custom JAX training loop, but unfortunately my script doesn't seem to be using very much GPU memory. I've tried increasing the batch_size, but that doesn't seem to make much of a difference. So I thought I'd try to increase throughput with something like vmap, or run several models in parallel (where applicable) with something like pmap (or even flax.linen.vmap()).
Are there any examples of how to do this? I've come across this guide, but it only appears to cover multiple devices.
I might be misunderstanding my problem, but many thanks for any help! :)
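For concreteness, here's roughly the kind of thing I have in mind: vmap-ping a single-model training step over a stacked set of parameters so several models train in parallel on one GPU. This is just a minimal sketch; the model, loss, and data below are toy placeholders, not my actual setup.

```python
import jax
import jax.numpy as jnp

def init_params(key):
    # Toy linear model; stands in for the real network.
    w_key, _ = jax.random.split(key)
    return {
        "w": jax.random.normal(w_key, (10, 1)) * 0.01,
        "b": jnp.zeros((1,)),
    }

def loss_fn(params, x, y):
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

def train_step(params, x, y, lr=1e-2):
    # Plain SGD update for a single model.
    grads = jax.grad(loss_fn)(params, x, y)
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

# Initialise N independent models: every parameter leaf gets a leading "model" axis.
n_models = 8
keys = jax.random.split(jax.random.PRNGKey(0), n_models)
stacked_params = jax.vmap(init_params)(keys)

# vmap the update over the model axis; the batch is shared across models (in_axes=None).
parallel_step = jax.jit(jax.vmap(train_step, in_axes=(0, None, None)))

x = jnp.ones((32, 10))  # dummy batch
y = jnp.ones((32, 1))
stacked_params = parallel_step(stacked_params, x, y)
```

Is this the right direction for a single-GPU setup, or would pmap / something else be more appropriate?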