
Example using vmap/pmap from JAX? #18570

@asmith26

Description

I've written a custom JAX training loop, but my script doesn't seem to use much GPU memory. I've tried increasing the batch_size, but that doesn't seem to make much difference. So I thought I'd try to increase throughput with something like vmap, or run several models in parallel (where applicable) with something like pmap (or even flax.linen.vmap()).
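A minimal sketch of the single-device case, assuming a toy linear-regression model (init_params, loss_fn and train_step are hypothetical names, not from any library). Here jax.vmap maps one SGD step over a leading "model" axis of stacked parameters, so several independent models train in a single batched, jitted call. This may help use more of one GPU when each model on its own is too small to keep it busy.

```python
import jax
import jax.numpy as jnp

def init_params(key, in_dim=3):
    # Hypothetical linear model y = x @ w + b; random w so each model starts differently.
    return {"w": 0.1 * jax.random.normal(key, (in_dim,)), "b": jnp.zeros(())}

def loss_fn(params, x, y):
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

def train_step(params, x, y, lr=0.1):
    # Plain SGD step for a single model.
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return new_params, loss

# in_axes=(0, None, None): map over the leading (model) axis of params,
# while every model sees the same x and y batch.
ensemble_step = jax.jit(jax.vmap(train_step, in_axes=(0, None, None)))

n_models = 4
keys = jax.random.split(jax.random.PRNGKey(0), n_models)
params0 = jax.vmap(init_params)(keys)  # w: (4, 3), b: (4,)

x = jax.random.normal(jax.random.PRNGKey(1), (32, 3))
y = x @ jnp.array([1.0, -2.0, 0.5]) + 0.3

params = params0
for _ in range(300):
    params, losses = ensemble_step(params, x, y)  # losses: one value per model
```

Because x and y use in_axes=None, all models share the same batch (an ensemble-style setup); passing in_axes=0 for them as well would give each model its own data, with a matching leading axis on x and y.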

Are there any examples of how to do this? I've come across this guide, but it only appears to cover training on multiple devices.
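For the multi-device case, a minimal data-parallel sketch with jax.pmap, assuming the same toy linear model (all names are hypothetical): the parameters are replicated, the batch is split along a leading device axis, and gradients are averaged across replicas with jax.lax.pmean. On a machine with a single GPU, jax.local_device_count() is 1, so pmap runs only one replica and offers no speed-up over jax.jit; for one device, jax.vmap is the relevant tool. Newer JAX releases also steer new code towards jax.jit with explicit shardings rather than pmap, so it is worth checking the documentation for the version in use.

```python
import jax
import jax.numpy as jnp

def loss_fn(params, x, y):
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

def train_step(params, x, y, lr=0.1):
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    # Average gradients and loss over all replicas sharing the "batch" axis name,
    # so every device applies the same update.
    grads = jax.lax.pmean(grads, axis_name="batch")
    loss = jax.lax.pmean(loss, axis_name="batch")
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return new_params, loss

# pmap compiles train_step and runs one copy per device; the leading axis of
# every argument must equal the number of devices used.
p_train_step = jax.pmap(train_step, axis_name="batch")

n_dev = jax.local_device_count()
per_device = 32

params = {"w": jnp.zeros(3), "b": jnp.zeros(())}
# Replicate the parameters: add a leading device axis of size n_dev.
replicated = jax.tree_util.tree_map(lambda p: jnp.stack([p] * n_dev), params)

x = jax.random.normal(jax.random.PRNGKey(0), (n_dev * per_device, 3))
y = x @ jnp.array([1.0, -2.0, 0.5]) + 0.3
# Shard the global batch: leading axis = device, then the per-device batch.
xs = x.reshape(n_dev, per_device, 3)
ys = y.reshape(n_dev, per_device)

for _ in range(300):
    replicated, loss = p_train_step(replicated, xs, ys)  # loss: (n_dev,)
```

Since every replica receives the same averaged gradients, the parameter copies stay identical, and replicated["w"][0] can be read back as the trained weights.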

I might be misunderstanding my problem, but many thanks for any help! :)
