
[Question] jax.vmap & nested jax.vmap behavior #28

@ZaberKo

Hello,

Thanks for your great work. I have a question about how torch2jax-wrapped functions behave under `jax.vmap`. I see that the wrapper uses `ffi.ffi_call("torch_call", outshapes, vmap_method="sequential")`. Does this mean that, under `jax.vmap`, the PyTorch function is called sequentially, once per batch element? If so, is there any way to improve this by using one of the other `vmap_method` options?
