Thanks for your great work. I have a question about how a torch2jax-wrapped function behaves under `jax.vmap`. I see that it uses `ffi.ffi_call("torch_call", outshapes, vmap_method="sequential")`. Does this mean that, under `vmap`, the wrapped PyTorch function is called sequentially, once per batch element? If so, is there a way to improve this by using one of the other `vmap_method` options?